# Copyright (c) 2025, Maxime Paschoud.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# (http://opensource.org/licenses/BSD-3-Clause)
#
# __author__ = "Maxime Paschoud, ETHZ: CMBM"
#
import numpy as np
import matplotlib.pyplot as plt
#import torch
import os
#from torchvision.models.optical_flow import Raft_Large_Weights
#from torchvision.models.optical_flow import raft_large
from pipeline import DAGNode
# [docs] -- Sphinx "view source" link residue left by HTML extraction; not part of the module.
class RaftFlow(DAGNode):
    """
    Computes dense optical flow (displacement maps) between image pairs using the RAFT model.

    The RAFT network (torchvision ``raft_large`` with its default weights) is
    built lazily on the first call to :meth:`compute_flow` and then reused for
    every subsequent image pair.

    Attributes
    ----------
    device : torch.device
        The device (CPU, CUDA, MPS) used for inference.
    disp_maps : list of list of np.ndarray or None
        The computed displacement maps for each image pair in each face of the collection.
        ``None`` until :meth:`eval` has been run.
        Each displacement map is a 2D flow field of shape (2, H, W), where:

        - disp[0, :, :] is the x-direction displacement,
        - disp[1, :, :] is the y-direction displacement.
    """

    def __init__(self, device=None):
        """
        Initializes the RAFT flow computation node.

        Parameters
        ----------
        device : torch.device, optional
            The device on which to run the RAFT model (default: CPU).
        """
        # torch is imported locally so that importing this module does not
        # require torch to be installed.
        import torch

        if device is None:
            device = torch.device('cpu')
        super().__init__()
        self.device = device
        self.disp_maps = None
        # RAFT network and its preprocessing transform; created on first use
        # by _load_model() so that constructing the node stays cheap.
        self._model = None
        self._transforms = None

    def _load_model(self):
        """Return the cached ``(model, transforms)`` pair, building it on first use."""
        model = getattr(self, "_model", None)
        transforms = getattr(self, "_transforms", None)
        if model is None or transforms is None:
            # Imported here (not in __init__): function-local imports are not
            # visible from other methods.
            from torchvision.models.optical_flow import Raft_Large_Weights
            from torchvision.models.optical_flow import raft_large

            weights = Raft_Large_Weights.DEFAULT
            transforms = weights.transforms()
            model = raft_large(weights=weights, progress=True).eval()
            self._model = model
            self._transforms = transforms
        # Follow self.device even if it was reassigned after the model was
        # cached; this is a no-op when the model already lives there.
        return model.to(self.device), transforms

    def eval(self, img_collection):
        """
        Evaluates the optical flow for each image pair in the image collection.

        Parameters
        ----------
        img_collection : ImageCollection
            An image collection containing pairs of images organized by faces.
            Each face is iterated to yield pairs, and ``pair[0].image`` /
            ``pair[1].image`` are passed to :meth:`compute_flow`.

        Returns
        -------
        list
            A one-element list whose single item is the nested list of
            displacement maps (one inner list per face, one map per image
            pair). The same nested list is stored in ``self.disp_maps``.
            NOTE(review): the outer wrapping is presumably the DAGNode
            output convention -- confirm against the pipeline.
        """
        print(f"Start of Raft Flow")
        self.disp_maps = [
            [self.compute_flow(pair[0].image, pair[1].image) for pair in face]
            for face in img_collection.faces
        ]
        print(f"End of Raft Flow")
        return [self.disp_maps]

    def compute_flow(self, image_1, image_2) -> np.ndarray:
        """
        Computes the dense optical flow (displacement map) between two RGB images using the RAFT model.

        Parameters
        ----------
        image_1 : np.ndarray
            The first image of shape (H, W, 3) in uint8 or float format.
        image_2 : np.ndarray
            The second image of shape (H, W, 3) in uint8 or float format.

        Returns
        -------
        np.ndarray
            A displacement map of shape (2, H, W), where the first channel is x-displacement
            and the second channel is y-displacement.

        Notes
        -----
        NOTE(review): torchvision's RAFT is understood to require H and W to be
        divisible by 8 -- confirm and resize/pad upstream if necessary.
        """
        import torch

        device = self.device
        model, transforms = self._load_model()

        def to_batch(image):
            # (H, W, C) array -> (1, C, H, W) tensor. ascontiguousarray lets
            # negatively-strided views (e.g. channel-flipped images) through,
            # which torch.from_numpy would otherwise reject.
            tensor = torch.from_numpy(np.ascontiguousarray(image))
            return tensor.permute(2, 0, 1).unsqueeze(0)

        img1_batch, img2_batch = transforms(to_batch(image_1), to_batch(image_2))

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            list_of_flows = model(img1_batch.to(device), img2_batch.to(device))

        # The model returns a list of flow predictions; the last entry is the
        # one used as the result.
        predicted_flows = list_of_flows[-1]
        return predicted_flows.squeeze(0).cpu().detach().numpy()

    def show(self, face=0, pair=0, cmap="Spectral"):
        """
        Visualizes the displacement map for a specific image pair using matplotlib.

        Parameters
        ----------
        face : int, optional
            Index of the face to visualize (default: 0).
        pair : int, optional
            Index of the image pair within the face to visualize (default: 0).
        cmap : str, optional
            Colormap to use for the visualization (default: "Spectral").

        Raises
        ------
        AttributeError
            If self.disp_maps is not available (missing, or still ``None``
            because :meth:`eval` has not been run yet).
        """
        # __init__ always creates the attribute (as None), so a plain hasattr
        # check can never fail; test for an actual value instead.
        disp_maps = getattr(self, 'disp_maps', None)
        if disp_maps is None:
            raise AttributeError("To use the show() method, your class should have a self.disp_maps attribute.")

        disp = disp_maps[face][pair]
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))

        # draw displacement in y direction
        p0 = ax[0].imshow(disp[1, :, :], cmap=cmap)
        ax[0].set_title("Displacement in y direction")
        fig.colorbar(p0, ax=ax[0])

        # draw displacement in x direction
        p1 = ax[1].imshow(disp[0, :, :], cmap=cmap)
        ax[1].set_title("Displacement in x direction")
        fig.colorbar(p1, ax=ax[1])