Source code for miop.raft_flow

# Copyright (c) 2025, Maxime Paschoud.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# (http://opensource.org/licenses/BSD-3-Clause)
#
# __author__ = "Maxime Paschoud, ETHZ: CMBM"
#

import numpy as np
import matplotlib.pyplot as plt
#import torch
import os
#from torchvision.models.optical_flow import Raft_Large_Weights
#from torchvision.models.optical_flow import raft_large
from pipeline import DAGNode

class RaftFlow(DAGNode):
    """
    Computes dense optical flow (displacement maps) between image pairs using
    the RAFT model.

    The RAFT network is loaded lazily on the first call to
    :meth:`compute_flow` and then reused for every subsequent image pair.

    Attributes
    ----------
    device : torch.device
        The device (CPU, CUDA, MPS) used for inference.
    disp_maps : list of list of np.ndarray or None
        The computed displacement maps for each image pair in each face of
        the collection. ``None`` until :meth:`eval` has been run.
        Each displacement map is a 2D flow field of shape (2, H, W), where:

        - disp[0, :, :] is the x-direction displacement,
        - disp[1, :, :] is the y-direction displacement.
    """

    def __init__(self, device=None):
        """
        Initializes the RAFT flow computation node.

        Parameters
        ----------
        device : torch.device, optional
            The device on which to run the RAFT model (default: CPU).
        """
        # torch is imported lazily so that importing this module does not
        # require (or pay for) loading torch.
        import torch

        if device is None:
            device = torch.device('cpu')

        super().__init__()
        self.device = device
        self.disp_maps = None

        # Filled in by _get_model() on first use.
        self._raft_model = None
        self._raft_transforms = None

    def _get_model(self):
        """
        Returns the RAFT model (on ``self.device``, in eval mode) together
        with its input preprocessing transforms, loading them on first use.
        """
        from torchvision.models.optical_flow import Raft_Large_Weights
        from torchvision.models.optical_flow import raft_large

        # getattr() keeps this working even if __init__ was bypassed.
        if getattr(self, "_raft_model", None) is None:
            weights = Raft_Large_Weights.DEFAULT
            self._raft_transforms = weights.transforms()
            self._raft_model = raft_large(weights=weights, progress=True).eval()

        # Re-applying .to() is cheap when the model already lives on the
        # device and keeps the model in sync if self.device was changed.
        return self._raft_model.to(self.device), self._raft_transforms

    def eval(self, img_collection):
        """
        Evaluates the optical flow for each image pair in the image
        collection and stores the result in ``self.disp_maps``.

        Parameters
        ----------
        img_collection : ImageCollection
            An image collection exposing ``faces``, an iterable of faces.
            Each face is an iterable of pairs, and both elements of a pair
            expose an ``image`` attribute.

        Returns
        -------
        list
            A single-element list ``[disp_maps]``. ``disp_maps`` is a nested
            list where each element corresponds to a face and contains the
            displacement maps for each image pair in that face.
        """
        print("Start of Raft Flow")

        self.disp_maps = [
            [self.compute_flow(pair[0].image, pair[1].image) for pair in face]
            for face in img_collection.faces
        ]

        print("End of Raft Flow")
        return [self.disp_maps]

    def compute_flow(self, image_1, image_2) -> np.ndarray:
        """
        Computes the dense optical flow (displacement map) between two RGB
        images using the RAFT model.

        Parameters
        ----------
        image_1 : np.ndarray
            The first image of shape (H, W, 3) in uint8 or float format.
        image_2 : np.ndarray
            The second image of shape (H, W, 3) in uint8 or float format.

        Returns
        -------
        np.ndarray
            A displacement map of shape (2, H, W), where the first channel is
            x-displacement and the second channel is y-displacement.

        Notes
        -----
        torchvision's RAFT implementation expects H and W to be divisible
        by 8; resize or pad the images beforehand if they are not.
        """
        import torch

        model, transforms = self._get_model()
        device = self.device

        # (H, W, C) numpy image -> (1, C, H, W) tensor batch.
        img1_batch = torch.from_numpy(image_1).permute(2, 0, 1).unsqueeze(0)
        img2_batch = torch.from_numpy(image_2).permute(2, 0, 1).unsqueeze(0)
        img1_batch, img2_batch = transforms(img1_batch, img2_batch)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            list_of_flows = model(img1_batch.to(device), img2_batch.to(device))

        # The model returns one flow estimate per refinement iteration;
        # the last one is the final prediction.
        predicted_flows = list_of_flows[-1]
        return predicted_flows.squeeze(0).detach().cpu().numpy()

    def show(self, face=0, pair=0, cmap="Spectral"):
        """
        Visualizes the displacement map for a specific image pair using
        matplotlib.

        Parameters
        ----------
        face : int, optional
            Index of the face to visualize (default: 0).
        pair : int, optional
            Index of the image pair within the face to visualize (default: 0).
        cmap : str, optional
            Colormap to use for the visualization (default: "Spectral").

        Raises
        ------
        AttributeError
            If self.disp_maps is not available, i.e. the attribute is missing
            or still ``None`` because :meth:`eval` has not been run.
        """
        # __init__ always creates disp_maps (as None), so a hasattr() check
        # alone would never fire; test for None as well.
        if getattr(self, 'disp_maps', None) is None:
            raise AttributeError("To use the show() method, your class should have a self.disp_maps attribute.")

        disp = self.disp_maps[face][pair]
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))

        # draw displacement in y direction
        p0 = ax[0].imshow(disp[1, :, :], cmap=cmap)
        ax[0].set_title("Displacement in y direction")
        fig.colorbar(p0, ax=ax[0])

        # draw displacement in x direction
        p1 = ax[1].imshow(disp[0, :, :], cmap=cmap)
        ax[1].set_title("Displacement in x direction")
        fig.colorbar(p1, ax=ax[1])