# Source code for miop.image_collection

# Copyright (c) 2025, Maxime Paschoud.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# (http://opensource.org/licenses/BSD-3-Clause)
#
# __author__ = "Maxime Paschoud, ETHZ: CMBM"
#

from .image_helper.image import Image
from .image_helper.image_pair import ImagePair
from math import sqrt, ceil, pi
from operator import attrgetter
import os
import copy
import numpy as np
import matplotlib.pyplot as plt
import cv2

import warnings

from pipeline import DAGNode

class ImageCollection(DAGNode):
    """
    Represents a collection of images, with support for downsampling,
    cropping, rotating, segmentation, and pairing based on image metadata.

    Attributes
    ----------
    metadata_type : str
        Type of metadata ('none', 'fei', or 'custom').
    metadata : list of dict
        Metadata associated with each image. Each dict may contain:

        - 'StageTa': Tilt angle in radians.
        - 'tilt_axis': Tilt axis as a 3D vector (e.g., [1., 0., 0.]).
        - 'StageR': Rotation angle of the microscope stage in the (x, y)
          plane, in radians.
    downsample_factor : float
        Factor by which to downsample all images.
    crop_dim : list
        Crop dimensions per face.
        Format: [[[width, height], [x, y]], ...]
    two_by_two : bool
        If True, images are grouped into fixed pairs: (1,2), (3,4), etc.
        If False, grouped as overlapping pairs: (1,2), (2,3), etc.
    use_segmentation : bool
        Whether to use segmentation masks.
    max_dim : tuple of int
        Maximum image dimensions as (width, height).
    verbose : bool
        If True, prints warnings and logs.
    paths_to_images : list of str
        File paths to the images in the collection.
    paths_to_masks : list of str or None
        Paths to the corresponding segmentation mask directories, or None
        when no masks were supplied.
    images : list of Image
        List of loaded Image objects.
    masks : dict or None
        Mapping from Image objects to their segmentation masks.
    pairs : list of ImagePair
        List of image pairs used for multi-view geometry.
    faces : list of tuple
        Grouped image pairs, organized by face.
    """

    def __init__(
            self,
            metadata_type: str = 'none',
            metadata: list = None,
            downsample=1.,
            crop_dim=None,
            two_by_two=False,
            use_segmentation=False,
            max_dim=(800, 800),
            verbose=True):
        """
        Initializes the ImageCollection with optional parameters.

        Parameters
        ----------
        metadata_type : str, optional
            Type of metadata to use ('none', 'fei', 'custom').
            Default is 'none'.
        metadata : list of dict, optional
            Metadata per image (default is None).
        downsample : float, optional
            Initial downsampling factor (default is 1.0).
        crop_dim : list, optional
            Crop dimensions. Default is None, which is stored as an empty
            list and means "no cropping".
        two_by_two : bool, optional
            Whether to group images into fixed pairs (default is False).
        use_segmentation : bool, optional
            Whether to use segmentation masks (default is False).
        max_dim : tuple, optional
            Maximum image dimensions (default is (800, 800)).
        verbose : bool, optional
            Verbosity flag (default is True).
        """
        super().__init__()
        self.metadata_type = metadata_type
        self.metadata = metadata
        self.downsample_factor = downsample
        # A fresh list per instance: a literal `[]` default would be shared
        # between every instance created without an explicit crop_dim.
        self.crop_dim = [] if crop_dim is None else crop_dim
        self.two_by_two = two_by_two
        self.use_segmentation = use_segmentation
        self.verbose = verbose
        self.max_dim = max_dim

    @staticmethod
    def _as_list(value):
        """Return `value` as a list: lists are kept, tuples converted, anything else wrapped."""
        if isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        return [value]

    def eval(self, paths_to_images, paths_to_masks=None):
        """
        Loads and preprocesses the image collection: reading, cropping,
        downsampling, rotating, and handling segmentation.

        Parameters
        ----------
        paths_to_images : str or list of str
            Paths to the image files. A single path is wrapped in a list.
        paths_to_masks : str or list of str, optional
            Paths to the segmentation mask directories.

        Returns
        -------
        ImageCollection
            The processed image collection.

        Raises
        ------
        ValueError
            If no image path is given, if a face has no image pair, or if
            `crop_dim` does not have the expected nesting.
        """
        self.paths_to_images = self._as_list(paths_to_images)
        # None must stay None (not become [None]) so that _read_masks can
        # recognise "no masks supplied".
        self.paths_to_masks = (
            None if paths_to_masks is None else self._as_list(paths_to_masks))

        if not self.paths_to_images:
            raise ValueError("paths_to_images is empty.")

        # read and sort image collection
        self.images = self._read_images(metadata=self.metadata)
        self._sort_collection()
        self.pairs, self.faces = self._get_pairs(self.two_by_two)
        # Validate before any per-face processing: crop() indexes the first
        # pair of every face and would otherwise fail with a bare IndexError.
        self._validate_num_images_per_face()

        # downsample images such that they all have the same dimensions
        self._all_to_same_dim()

        # normalize self.crop_dim format: a single [[w, h], [x, y]] entry is
        # replicated once per face
        try:
            if (len(self.crop_dim) != 0
                    and not isinstance(self.crop_dim[0][0], (list, tuple))):
                self.crop_dim = [self.crop_dim] * len(self.faces)
        except (TypeError, IndexError, KeyError) as err:
            raise ValueError(
                "crop_dim has not the right shape. It should be "
                "[[dim_x, dim_y], [top_left_x, top_left_y]] or "
                "[[[dim_x_face1, dim_y_face1], "
                "[top_left_x_face1, top_left_y_face1]], "
                "[[..., ...],[..., ...]], ...]") from err

        # post process image collection (rotate, downsample and crop)
        if self.metadata_type != 'none':
            self.rotate()

        # copy images to plot in crop_utility mode (see show method)
        self.original_images = copy.deepcopy(self.images)

        if len(self.crop_dim) != 0:
            self.crop(*self.crop_dim)

        down_fac = self._get_downsample_factor()
        self.downsample(down_fac, down_fac)

        if self.use_segmentation:
            self.masks = self._read_masks(self.paths_to_masks)

        return self

    def __getitem__(self, index):
        """
        Accesses an image by index or by face and image position.

        Parameters
        ----------
        index : int or tuple
            If int: index into the flat image list.
            If tuple: (face index, image index within the face). Image 0 is
            the first image of the face's first pair; image k > 0 is the
            second image of pair k - 1.

        Returns
        -------
        Image
            The requested Image object.
        """
        if isinstance(index, tuple):
            face_idx, img_in_face = index[0], index[1]
            pair_idx = (img_in_face - 1) if img_in_face > 0 else 0
            img_idx = int(img_in_face > 0)
            return self.faces[face_idx][pair_idx][img_idx]
        return self.images[index]

    def _validate_num_images_per_face(self):
        """
        Ensures each face has at least one pair of images.

        Raises
        ------
        ValueError
            If any face has no image pair.
        """
        for i, face in enumerate(self.faces):
            if len(face) == 0:
                raise ValueError(f"Face {i+1} has no image pair. Be sure to have at least two images per face.")

    def _get_downsample_factor(self):
        """
        Computes the downsampling factor from `max_dim`, the largest image
        shape in the collection and the user-provided factor.

        Returns
        -------
        float
            The smallest of the width ratio, the height ratio and the user
            factor.
        """
        max_shape = self._get_max_shape()
        fac_x = self.max_dim[0] / max_shape[0]
        fac_y = self.max_dim[1] / max_shape[1]
        # retain smallest factor
        return min(fac_x, fac_y, self.downsample_factor)

    def _all_to_same_dim(self):
        """
        Resizes all images in the collection to the smallest width and
        height found in the collection, using downsampling.
        """
        min_width, min_height = self._get_min_shape()
        for img in self.images:
            # img.shape is indexed as (height, width, ...)
            factor_w = min_width / img.shape[1]
            factor_h = min_height / img.shape[0]
            img.downsample(factor_w, factor_h)

    def _get_min_shape(self):
        """Gets the minimum shape in the collection as [width, height]."""
        shapes = np.asarray([list(img.shape) for img in self.images])
        min_shape = np.min(shapes[:, :2], axis=0)
        return [min_shape[1], min_shape[0]]

    def _get_max_shape(self):
        """Gets the maximum shape in the collection as [width, height]."""
        shapes = np.asarray([list(img.shape) for img in self.images])
        max_shape = np.max(shapes[:, :2], axis=0)
        return [max_shape[1], max_shape[0]]

    def _read_masks(self, paths):
        """
        Reads grayscale segmentation masks from the given directories and
        maps them to images. The result is stored in `self.masks` and also
        returned.

        Parameters
        ----------
        paths : list of str or None
            Paths to directories containing segmentation masks, in the same
            order as `self.images`.

        Returns
        -------
        dict or None
            Mapping from Image to its mask scaled to [0., 1.], or None when
            `paths` is None.

        Raises
        ------
        AttributeError
            If number of masks does not match number of images.
        ValueError
            If a mask directory is empty or a mask file cannot be read.
        NotImplementedError
            If there is more than one segmentation mask per image.
        """
        # NOTE(review): self.images has been re-ordered by _sort_collection
        # at this point; confirm callers pass mask directories in that order
        # rather than in the order of paths_to_images.
        if paths is None:
            self.masks = None
            return None

        if len(paths) != len(self.images):
            raise AttributeError('Number of images and number of segmentation masks are not equal')

        masks = {}
        for image, directory_path in zip(self.images, paths):
            filename_list = sorted(os.listdir(directory_path))
            if len(filename_list) > 1:
                raise NotImplementedError("Currently only supports one segmentation mask per image.")
            if not filename_list:
                raise ValueError(f'{directory_path} does not contain a segmentation mask')

            file_path = os.path.join(directory_path, filename_list[0])
            mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None, so this check
            # has to come before any arithmetic on the result.
            if mask is None:
                raise ValueError(f'{file_path} could not be read')
            # Normalize the image to 0., 1.
            masks[image] = mask / 255.

        self.masks = masks
        return masks

    def _read_images(self, metadata=None):
        """
        Reads image files and wraps them in Image objects with metadata.

        Parameters
        ----------
        metadata : list of dict, optional
            Metadata for each image, indexed like `self.paths_to_images`.

        Returns
        -------
        list of Image
            Loaded images.
        """
        images = []
        for i, path in enumerate(self.paths_to_images):
            image_name = os.path.splitext(os.path.basename(path))[0]
            image_metadata = metadata[i] if metadata is not None else None
            images.append(
                Image(path, image_name, self.metadata_type, metadata=image_metadata))
        return images

    def _sort_collection(self):
        """
        Sorts images in lexicographical order of (rotation_z, tilt) and
        records the original positions in `self.sorted_idx`.
        """
        sort_key = attrgetter('rotation_z', 'tilt')
        ordered = sorted(enumerate(self.images), key=lambda item: sort_key(item[1]))
        self.sorted_idx = tuple(idx for idx, _ in ordered)
        self.images = [img for _, img in ordered]

    def _get_pairs(self, pairs_bool):
        """
        Pairs images into fixed or overlapping pairs and groups them into
        faces. Also stores the first pair built as `self.pair_ref`.

        Parameters
        ----------
        pairs_bool : bool
            Whether to use fixed (two-by-two) pairing.

        Returns
        -------
        tuple
            - pairs : list of ImagePair
            - faces : list of tuple of ImagePair

        Raises
        ------
        ValueError
            If `pairs_bool` is True and the number of image paths is odd or
            the two images of a pair differ in `rotation_z`.
        """
        pairs = []
        faces = [[]]

        if pairs_bool and (len(self.paths_to_images) % 2 != 0):
            raise ValueError('pairs_bool is True but number of image paths is odd')

        increment = 2 if pairs_bool else 1
        for i in range(0, len(self.images) - 1, increment):
            img1 = self.images[i]
            img2 = self.images[i + 1]
            pair = ImagePair(img1, img2, np.array([1., 0, 0]))

            if i == 0:
                # reference pair
                self.pair_ref = pair

            # NOTE(review): atol=1 is a very loose tolerance if rotation_z
            # is in radians (as the 10-degree threshold below suggests) -
            # confirm the intended unit.
            if not np.isclose(img1.rotation_z, img2.rotation_z, atol=1) and pairs_bool:
                raise ValueError(
                    f'pairs is True but images {i} and {i+1} do not have same rotation_z value. '
                    f'Probably the rotation {img1.rotation_z} has odd number of images')

            # In overlapping mode a jump of 10 degrees or more in rotation_z
            # marks the boundary between two faces: the cross-face pair is
            # dropped and a new face is started.
            if np.abs(img1.rotation_z - img2.rotation_z) >= 10 * np.pi / 180 and not pairs_bool:
                faces.append([])
                continue

            pairs.append(pair)
            faces[-1].append(pair)

        # convert list of pairs to tuples
        # [(face1_pair1, face1_pair2, ...), (face2_pair1, ...)]
        faces = list(map(tuple, faces))
        return pairs, faces

    def apply_to_collection(self, func, *args, **kwargs):
        """
        Applies a function to each image in the collection.

        Parameters
        ----------
        func : callable
            Function to apply to each image.
        *args
            Additional positional arguments to pass to the function.
        **kwargs
            Additional keyword arguments to pass to the function.

        Returns
        -------
        list
            List of results from the function.
        """
        return [func(image, *args, **kwargs) for image in self.images]

    def apply_to_face(self, face_idx, func, *args, **kwargs):
        """
        Applies a function to the images of a face: the first image of the
        first pair, then the second image of every pair.

        Parameters
        ----------
        face_idx : int
            Index of the face.
        func : callable
            Function to apply.
        *args
            Additional positional arguments.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        list
            List of results from the function.
        """
        # NOTE(review): this visits every image only when pairs overlap
        # ((1,2), (2,3), ...). With two_by_two pairing the first image of
        # every pair after the first is skipped - confirm whether intended.
        face = self.faces[face_idx]
        result = [func(face[0][0], *args, **kwargs)]
        result.extend(func(pair[1], *args, **kwargs) for pair in face)
        return result

    def save_collection(self, path_to_dir):
        """
        Saves all images to the specified directory.

        Parameters
        ----------
        path_to_dir : str
            Directory path where images will be saved.

        Returns
        -------
        list
            Results from saving images.
        """
        return self.apply_to_collection(Image.save, path_to_dir)

    def downsample(self, factor_x: float, factor_y: float):
        """
        Downsamples all images in the collection.

        Parameters
        ----------
        factor_x : float
            Downsampling factor for width.
        factor_y : float
            Downsampling factor for height.
        """
        # The image list is replaced by the return values of
        # Image.downsample (presumably the resized Image - TODO confirm).
        self.images = self.apply_to_collection(Image.downsample, factor_x, factor_y)

    def crop(self, *crop_arg):
        """
        Crops images in the collection.

        Parameters
        ----------
        crop_arg : list
            Crop parameters. Either a single `[width, height],
            [x_offset, y_offset]` pair of arguments, applied to every image,
            or one `[[width, height], [x_offset, y_offset]]` entry per face.
            Calling without arguments does nothing.
        """
        if not crop_arg:
            return

        func = Image.crop
        # not a double list: one crop specification for the whole collection
        if not isinstance(crop_arg[0][0], (list, tuple)):
            self.images = self.apply_to_collection(func, *crop_arg)
            return

        # double list: one crop specification per face
        idx_img = 0
        for i, crop_tuple in enumerate(crop_arg):
            # a face of n overlapping pairs holds n + 1 images
            len_face = len(self.faces[i]) + 1
            self.images[idx_img:idx_img + len_face] = self.apply_to_face(i, func, *crop_tuple)
            idx_img += len_face

    def rotate(self, angle=None, rad=False):
        """
        Rotates all images of the collection according to `angle` or
        `rotation_z` from the metadata.

        Parameters
        ----------
        angle : float or None, optional
            Angle to rotate. If None, uses `rotation_z` from metadata.
        rad : bool, optional
            If True, angle is in radians. Default is False.
        """
        if angle is None:
            self.apply_to_collection(Image.rotate_by_transpose)
        else:
            self.apply_to_collection(Image.rotate, angle, rad=rad)

    def register_pair(self, paths_to_images):
        """
        Not implemented. Used to manually add image pairs to the collection.

        Raises
        ------
        NotImplementedError
        """
        raise NotImplementedError

    @staticmethod
    def _draw_crop_overlay(ax, img, crop_entry):
        """Veil everything outside the crop rectangle of `crop_entry` with four white patches."""
        (rec_width, rec_height), (left, bottom) = crop_entry
        height, width = img.shape[:2]
        veil = dict(alpha=0.7, color='white')
        # (anchor, width, height) of the strips left, right, above and below
        # the crop rectangle
        strips = (
            ((0, 0), left, height),
            ((left + rec_width, 0), width - (left + rec_width), height),
            ((left, 0), rec_width, bottom),
            ((left, bottom + rec_height), rec_width, height - (bottom + rec_height)),
        )
        for anchor, strip_width, strip_height in strips:
            ax.add_patch(plt.Rectangle(anchor, strip_width, strip_height, **veil))

    def show(self, crop_utility=True, use_segmentation=False):
        """
        Displays all images in the collection, one row per face, optionally
        overlaying crop regions and segmentation masks.

        Parameters
        ----------
        crop_utility : bool, optional
            If True (default), the un-cropped copies of the images are shown
            and, when crop dimensions are set, the area outside the crop
            rectangle is veiled. If False, the processed images are shown.
        use_segmentation : bool, optional
            Whether to overlay segmentation masks (default is False).
        """
        num_faces = len(self.faces)
        max_num_img = max((len(face) + 1 for face in self.faces), default=0)

        masks = getattr(self, 'masks', None)
        if use_segmentation and not masks:
            warnings.warn("use_segmentation is True but no segmentation masks "
                          "are loaded; masks are not shown.")
            use_segmentation = False

        # Without crop dimensions there is no rectangle to draw.
        draw_overlay = crop_utility and len(self.crop_dim) != 0

        fig = plt.figure(figsize=(5 * max_num_img, 6 * num_faces))

        img_idx = 0  # running index into original_images (crop_utility only)
        for i, face in enumerate(self.faces):
            # face contains pairs of images: show the first image of the
            # first pair, then the second image of every pair
            face_images = [face[0][0]] + [pair[1] for pair in face]
            for j, processed_img in enumerate(face_images):
                ax = fig.add_subplot(num_faces, max_num_img, i * max_num_img + j + 1)

                if crop_utility:
                    img = self.original_images[img_idx]
                    img_idx += 1
                    if draw_overlay:
                        self._draw_crop_overlay(ax, img, self.crop_dim[i])
                else:
                    ax.axis('off')
                    img = processed_img

                ax.imshow(img.image)
                if use_segmentation:
                    # Masks are keyed by the collection's own Image objects,
                    # not by the deep copies kept in original_images.
                    ax.imshow(masks[processed_img], alpha=0.32, cmap='magma')

        plt.tight_layout(pad=0.5)
        plt.show()