# Copyright (c) 2025, Maxime Paschoud.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# (http://opensource.org/licenses/BSD-3-Clause)
#
# __author__ = "Maxime Paschoud, ETHZ: CMBM"
#
from pipeline import DAGNode
import numpy as np
class CorrespondingPointsFromDisparity(DAGNode):
    """
    Extract corresponding points from disparity maps.

    This node takes disparity maps (horizontal and vertical displacements between
    consecutive images) and follows every pixel of the first image through the
    chain of maps of each face, yielding its corresponding location in every
    subsequent image. Tracks that leave the image at any step are discarded and
    a random subset of the remaining tracks is returned.

    Attributes
    ----------
    n_points : int
        Number of points to sample per face (upper bound; see ``eval``).
    points : list of list of np.ndarray
        Result of the last call to ``eval`` (set by ``eval``).
    """

    def __init__(self, n_points=10000):
        """
        Initialize the node.

        Parameters
        ----------
        n_points : int, optional
            The number of points to extract from the disparity maps. Default is 10000.
            If fewer valid points exist, only as many as the smallest face provides
            are returned.
        """
        super().__init__()
        self.n_points = n_points

    @staticmethod
    def _track_face(face, grid, h, w):
        """
        Follow the pixel grid through the chained disparity maps of one face.

        Parameters
        ----------
        face : list of np.ndarray
            Disparity maps of the face, each indexed as ``disp_map[0][y, x]`` (x shift)
            and ``disp_map[1][y, x]`` (y shift).
        grid : np.ndarray
            (h*w, 2) array of integer (x, y) pixel coordinates of the first image.
        h, w : int
            Image height and width used for the bounds check.

        Returns
        -------
        list of np.ndarray
            ``[grid, pos_1, pos_2, ...]`` where ``pos_k`` holds the positions after the
            first k maps; rows of points that left the image are NaN.
        """
        tracks = [grid]
        # Float positions so that lost points can be marked with NaN.
        curr = grid.astype(float)
        for disp_map in face:
            alive = ~np.isnan(curr).any(axis=1)
            # Sample the disparity at the pixel containing each point. Out-of-range
            # points are NaN, so the floor stays within [0, w-1] x [0, h-1].
            pix = np.floor(curr[alive]).astype(int)
            delta = np.column_stack([
                disp_map[0][pix[:, 1], pix[:, 0]],
                disp_map[1][pix[:, 1], pix[:, 0]],
            ])
            update = np.full_like(curr, np.nan)
            # Move the exact (sub-pixel) position rather than the floored one, so that
            # chaining several maps does not accumulate a rounding bias.
            update[alive] = curr[alive] + delta
            outside = ((update[:, 0] < 0) | (update[:, 1] < 0)
                       | (update[:, 0] >= w) | (update[:, 1] >= h))
            update[outside] = np.nan
            tracks.append(update)
            curr = update
        return tracks

    def eval(self, disparity_maps):
        """
        Compute corresponding points using disparity maps.

        Every pixel of the first image is pushed through the disparity maps of each
        face in turn; points that fall outside the image at any step are removed.
        The same random indices are then drawn for every face.

        Parameters
        ----------
        disparity_maps : list of list of np.ndarray
            A nested list of disparity maps. Each face (outer list) contains a list of disparity maps.
            Each disparity map is a (2, H, W) array representing the x and y displacements respectively.

        Returns
        -------
        list
            A one-element list wrapping the per-face lists of (n, 2) point arrays
            (one array for the grid plus one per disparity map). ``n`` is
            ``min(self.n_points, smallest number of valid tracks over all faces)``.

        Raises
        ------
        ValueError
            If ``disparity_maps`` is empty or its first face holds no disparity map.
        """
        print(f"Start of corresponding points from disparity maps, n_points = {self.n_points}")
        if len(disparity_maps) == 0 or len(disparity_maps[0]) == 0:
            raise ValueError(
                "disparity_maps must contain at least one face holding at least one disparity map."
            )
        h, w = disparity_maps[0][0][0].shape
        # Pixel grid of the first image as (x, y) rows in row-major order.
        y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        grid = np.column_stack([x.ravel(), y.ravel()])

        # Per face, keep only the tracks that stay inside the image in every view.
        valid_points = []
        for face in disparity_maps:
            stacked = np.asarray(self._track_face(face, grid, h, w))
            keep = ~np.isnan(stacked).any(axis=(0, 2))
            valid_points.append([np.asarray(p[keep]) for p in stacked])

        # The same indices are drawn for every face, so they must be valid for the smallest one.
        min_n_points = min(face_points[0].shape[0] for face_points in valid_points)
        n_samples = min(self.n_points, min_n_points)
        if n_samples < self.n_points:
            print(f"Only {min_n_points} valid points available, sampling {n_samples} "
                  f"instead of {self.n_points}")
        sampled_indices = np.random.choice(min_n_points, n_samples, replace=False)
        self.points = [[p[sampled_indices] for p in face_points] for face_points in valid_points]
        print("End of corresponding points from disparity maps")
        return [self.points]
class CorrespondingPointsFromMatches(DAGNode):
    """
    Identify corresponding points across an image collection using feature matches.

    This class takes a series of matched points between consecutive images (typically
    the output of a feature matching algorithm) and keeps the feature points that have
    consistent correspondences across image pairs (e.g., a <-> b, b <-> c implies
    a <-> b <-> c).

    Attributes
    ----------
    atol : float
        Grid spacing used to compare coordinates: two points are treated as the same
        when they round to the same multiple of ``atol`` (``0`` means exact equality).
        Note this is a rounding grid, so two points closer than ``atol`` that straddle
        a cell boundary are not matched.
    newMatches : list of np.ndarray
        Result of the last call to ``eval`` (set by ``eval``).
    """

    def __init__(self, atol):
        """
        Initialize the node.

        Parameters
        ----------
        atol : float
            Absolute tolerance (rounding grid spacing) used when comparing coordinates.
        """
        super().__init__()
        self.atol = atol

    def round_to_nearest_multiple(self, arr, atol):
        """
        Round coordinates to the nearest multiple of the given absolute tolerance.

        Parameters
        ----------
        arr : np.ndarray
            Input coordinate array.
        atol : float
            Absolute tolerance for rounding. ``0`` returns ``arr`` unchanged.

        Returns
        -------
        np.ndarray
            Array with coordinates rounded to the nearest multiple of `atol`.
        """
        if atol == 0.:
            return arr
        return np.round(arr / atol) * atol

    def find_duplicates(self, arr):
        """
        Find repeated 2D points in an array.

        Parameters
        ----------
        arr : np.ndarray
            A 2D array of shape (n, 2), where n is the number of keypoints.

        Returns
        -------
        np.ndarray
            Sorted indices of every occurrence of a point after its first one
            (the first occurrence itself is not reported).
        """
        arr = np.asarray(arr)
        if arr.shape[0] == 0:
            return np.empty(0, dtype=np.intp)
        _, first_idx = np.unique(arr, return_index=True, axis=0)
        is_repeat = np.ones(arr.shape[0], dtype=bool)
        is_repeat[first_idx] = False
        return np.flatnonzero(is_repeat)

    @staticmethod
    def _rows_in(points, reference):
        """Return a boolean mask marking the rows of ``points`` that also occur in ``reference``."""
        lookup = set(map(tuple, np.asarray(reference).tolist()))
        return np.array([tuple(row) in lookup for row in np.asarray(points).tolist()], dtype=bool)

    def eval(self, matches):
        """
        Filter and propagate consistent feature matches across an image sequence.

        Parameters
        ----------
        matches : list of tuple of np.ndarray
            A list of match pairs. Each element is a tuple of two (N, 2) numpy arrays (`pts_1`, `pts_2`)
            representing keypoints matched between two consecutive images.

        Returns
        -------
        list
            A one-element list wrapping the list of keypoint arrays corresponding
            across all images, one per image in the sequence (row i of every array
            belongs to the same physical point).

        Raises
        ------
        ValueError
            If ``matches`` is empty.
        RuntimeError
            If the two inlier masks disagree in size (internal consistency check).
        """
        if len(matches) == 0:
            raise ValueError("matches must contain at least one pair of matched points.")
        newMatches = [np.asarray(matches[0][0]), np.asarray(matches[0][1])]
        for k in range(1, len(matches)):  # number of pairs
            pts_src = np.asarray(matches[k][0])
            pts_dst = np.asarray(matches[k][1])
            # Compare on rounded coordinates to allow some tolerance.
            rounded_matches = self.round_to_nearest_multiple(pts_src, self.atol)
            rounded_newMatches = self.round_to_nearest_multiple(newMatches[-1], self.atol)
            # Only keep a <-> b, b <-> c when the point in b is present in both rounded sets...
            mask_matches = self._rows_in(rounded_matches, rounded_newMatches)
            mask_new = self._rows_in(rounded_newMatches, rounded_matches)
            # ...and only once: rounding can merge nearby points, so drop repeated occurrences.
            mask_matches[self.find_duplicates(rounded_matches)] = False
            mask_new[self.find_duplicates(rounded_newMatches)] = False
            if np.count_nonzero(mask_matches) != np.count_nonzero(mask_new):
                raise RuntimeError('Masks do not have the same number of inliers.')
            # Both kept sets now hold the same unique rounded points, so sorting each by its
            # ROUNDED coordinates lines the correspondences up. Sorting by raw coordinates
            # could mis-pair points whose raw order differs inside one rounding cell.
            key_new = rounded_newMatches[mask_new]
            key_matches = rounded_matches[mask_matches]
            order_new = np.lexsort((key_new[:, 1], key_new[:, 0]))
            order_matches = np.lexsort((key_matches[:, 1], key_matches[:, 0]))
            # Propagate selection and ordering to all previous images.
            for u in range(k + 1):
                newMatches[u] = newMatches[u][mask_new][order_new]
            # Append the points of the new image (pts3).
            newMatches.append(pts_dst[mask_matches][order_matches])
        self.newMatches = newMatches
        return [newMatches]