# Copyright (c) 2025, Maxime Paschoud.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# (http://opensource.org/licenses/BSD-3-Clause)
#
# __author__ = "Maxime Paschoud, ETHZ: CMBM"
#
import numpy as np
from PIL import Image
from scipy.interpolate import griddata
import matplotlib
import matplotlib.pyplot as plt
from pipeline import DAGNode
class DenseRoma(DAGNode):
    """
    Perform dense feature matching between image pairs using the RoMa model.

    RoMa can be installed from the researcher's repository:
    https://github.com/Parskatt/RoMa
    """

    def __init__(self, num_matches=5000, tiny=False):
        """
        Initialize the DenseRoma node.

        Parameters
        ----------
        num_matches : int, optional
            Number of matches to sample per image pair. Default is 5000.
        tiny : bool, optional
            Whether to use the RoMa tiny model variant. Default is False.
        """
        import torch

        super().__init__()
        self.num_matches = num_matches
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.tiny = tiny  # use the RoMa tiny model
        # Both attributes are filled in by `eval`; `show` reads them.
        self.img_collection = None
        self.points = None

    def eval(self, img_collection):
        """
        Run RoMa matching on an image collection.

        Parameters
        ----------
        img_collection : ImageCollection
            A collection of faces, each containing a list of image pairs to match.

        Returns
        -------
        list
            A one-element list wrapping the nested list ``points``, indexed as
            ``points[face][pair]``. Each entry is a ``(2, N, 2)`` array:
            ``[0]`` holds the sampled keypoints in the first image of the pair
            and ``[1]`` the corresponding keypoints in the second image, as
            converted by RoMa's ``to_pixel_coordinates``. ``N`` is the number
            of matches returned by RoMa's ``sample`` (``num_matches`` requested).
        """
        from romatch import roma_outdoor, tiny_roma_v1_outdoor

        self.img_collection = img_collection
        # The size read from img_collection[0] is used for every pair below.
        h, w = img_collection[0].shape[0], img_collection[0].shape[1]
        print(f"Start of RoMa feature matching..., tiny={self.tiny}")
        if self.tiny:
            roma_model = tiny_roma_v1_outdoor(device=self.device)
        else:
            # RoMa takes `upsample_res` as (height, width).
            roma_model = roma_outdoor(device=self.device, coarse_res=560, upsample_res=(h, w))
        # The full model's `match` is given the device explicitly; the tiny one is not.
        match_kwargs = {} if self.tiny else {'device': self.device}

        points = []
        for face in img_collection.faces:
            face_points = []
            for pair in face:
                # Convert each image of the pair to PIL exactly once.
                im_a = Image.fromarray(pair[0].image)
                im_b = Image.fromarray(pair[1].image)
                warp, certainty = roma_model.match(im_a, im_b, **match_kwargs)
                matches, certainty = roma_model.sample(warp, certainty, num=self.num_matches)
                k1, k2 = roma_model.to_pixel_coordinates(matches, h, w, h, w)
                # Stack into a (2, N, 2) array: [keypoints in A, keypoints in B].
                face_points.append(np.stack([np.asarray(k1.cpu()), np.asarray(k2.cpu())]))
            points.append(face_points)
        self.points = points
        print("End of RoMa feature matching")
        return [points]

    def plot_images(self, img1, img2):
        """
        Plot two images side by side in a new figure, without axes or ticks.

        Parameters
        ----------
        img1 : np.ndarray
            First image. Shape (H, W, 3) or (H, W).
        img2 : np.ndarray
            Second image. Shape (H, W, 3) or (H, W).
        """
        fig, axes = plt.subplots(1, 2, figsize=(10, 10))
        for ax, img in zip(axes, (img1, img2)):
            ax.imshow(img, cmap='gray')
            ax.get_yaxis().set_ticks([])
            ax.get_xaxis().set_ticks([])
            ax.set_axis_off()
        fig.tight_layout(pad=0.5)

    def plot_keypoints(self, kpts, colors="lime"):
        """
        Plot keypoints on the axes of the current figure.

        ``kpts[i]`` is drawn on the i-th axes of the current figure.

        Parameters
        ----------
        kpts : List[np.ndarray]
            List of (N, 2) arrays containing (x, y) keypoint coordinates.
        colors : str or List[str], optional
            Color or list of colors for the keypoints. Default is "lime".
        """
        if not isinstance(colors, list):
            colors = [colors] * len(kpts)
        for ax, k, c in zip(plt.gcf().axes, kpts, colors):
            ax.scatter(k[:, 0], k[:, 1], c=c, s=4, linewidths=0, alpha=1.0)

    def plot_matches(self, kpts0, kpts1):
        """
        Plot lines between matching keypoints in a pair of images.

        Draws on the first two axes of the current figure (e.g. the figure
        created by `plot_images`). Each match gets a random HSV color, which
        its connecting line and both endpoints share.

        Parameters
        ----------
        kpts0 : np.ndarray
            (N, 2) array of keypoints in the first image.
        kpts1 : np.ndarray
            (N, 2) array of keypoints in the second image.
        """
        fig = plt.gcf()
        ax0, ax1 = fig.axes[0], fig.axes[1]
        point_size = 4
        line_width = 1.5
        colors = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
        for p0, p1, color in zip(kpts0, kpts1, colors):
            # Each end of the patch is expressed in the data coordinates of its own axes.
            line = matplotlib.patches.ConnectionPatch(
                xyA=(p0[0], p0[1]),
                xyB=(p1[0], p1[1]),
                coordsA=ax0.transData,
                coordsB=ax1.transData,
                axesA=ax0,
                axesB=ax1,
                zorder=1,
                color=color,
                linewidth=line_width,
                clip_on=True,
                alpha=1.0,
                label=None,
                picker=5.0,
            )
            line.set_annotation_clip(True)
            fig.add_artist(line)
        # Freeze the axes so the scatter calls below cannot rescale them and
        # change the transforms the connection lines were placed with.
        ax0.autoscale(enable=False)
        ax1.autoscale(enable=False)
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=colors, s=point_size)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=colors, s=point_size)

    def show(self, idx, n):
        """
        Visualize matching keypoints for a given face and image index.

        Parameters
        ----------
        idx : tuple of int
            A tuple of (face_index, image_index).
        n : int
            Number of keypoints/matches to display.

        Raises
        ------
        ValueError
            If `eval` has not been run yet.
        TypeError
            If `idx` cannot be converted to a tuple.
        """
        if self.img_collection is None:
            raise ValueError("self.img_collection is None. You have to run the component first.")
        try:
            idx = tuple(idx)
        except TypeError as err:
            raise TypeError("Idx must be a sequence type (list, arrays, etc.). The first element is the face index, the second is the image index.") from err
        face_idx, img_idx = idx[0], idx[1]
        img0 = self.img_collection[idx].image
        # NOTE(review): assumes pair `img_idx` matches image img_idx against image img_idx + 1 — confirm.
        img1 = self.img_collection[face_idx, img_idx + 1].image
        self.plot_images(img0, img1)
        kpts0, kpts1 = self.points[face_idx][img_idx][0], self.points[face_idx][img_idx][1]
        self.plot_keypoints([kpts0, kpts1])
        self.plot_matches(kpts0[:n], kpts1[:n])
class Interpolate(DAGNode):
    """
    Densify sparse feature matches into full disparity maps using interpolation.

    Interpolation is piecewise-linear over a Delaunay triangulation of the
    first image's keypoints. This is the same scheme as
    `scipy.interpolate.griddata(method='linear')`, but the triangulation is
    built only once per pair.
    """

    def __init__(self):
        """
        Initialize the Interpolate node.
        """
        super().__init__()
        # Filled in by `eval`.
        self.disp_maps = None

    def eval(self, matches):
        """
        Generate dense disparity maps from sparse matches.

        Parameters
        ----------
        matches : List[List[np.ndarray]]
            Nested list indexed as ``matches[face][pair]``. Each entry is a
            ``(2, K, 2)`` array of matched keypoints given as (x, y):
            ``[0]`` holds the keypoints in image 1 and ``[1]`` the matching
            keypoints in image 2.

        Returns
        -------
        list
            A one-element list wrapping the nested list ``disp_maps``, indexed
            as ``disp_maps[face][pair]``. Each entry is a ``(2, H, W)`` array:
            ``[0]`` is the horizontal (x) disparity and ``[1]`` the vertical
            (y) disparity, defined as image-2 position minus image-1 position.
            Pixels outside the convex hull of the image-1 keypoints are NaN.
            ``H`` and ``W`` are the integer parts of the largest y and x
            keypoint coordinates found across all pairs (both images).

        Raises
        ------
        ValueError
            If `matches` contains no image pairs.
        """
        from scipy.interpolate import LinearNDInterpolator

        all_pairs = [m for face in matches for m in face]
        if not all_pairs:
            raise ValueError("Interpolate.eval received no matches to interpolate.")
        # The output grid size is inferred from the largest keypoint
        # coordinates (x -> width, y -> height), truncated to int.
        w = np.max([np.max(m[:, :, 0]) for m in all_pairs])
        h = np.max([np.max(m[:, :, 1]) for m in all_pairs])
        # Pixel grid of shape (H, W): `y` holds row indices, `x` column indices.
        y, x = np.meshgrid(np.arange(int(h)), np.arange(int(w)), indexing='ij')

        disp_maps = []
        for face in matches:
            face_maps = []
            for match in face:
                k1, k2 = match[0], match[1]
                # Per-match displacement, anchored at its location in image 1.
                disparity = k2 - k1
                # One triangulation serves both channels: values of shape (K, 2)
                # interpolate to (H, W, 2). griddata would triangulate once per channel.
                interp = LinearNDInterpolator(k1, disparity)
                dense = interp((x, y))
                # Reorder to (2, H, W) = [x-disparity map, y-disparity map].
                face_maps.append(np.ascontiguousarray(np.moveaxis(dense, -1, 0)))
            disp_maps.append(face_maps)
        self.disp_maps = disp_maps
        return [self.disp_maps]