# Copyright (c) 2025, Maxime Paschoud.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# (http://opensource.org/licenses/BSD-3-Clause)
#
# __author__ = "Maxime Paschoud, ETHZ: CMBM"
#
from .image_helper.image import Image
from .image_helper.image_pair import ImagePair
from math import sqrt, ceil, pi
from operator import attrgetter
import os
import copy
import numpy as np
import matplotlib.pyplot as plt
import cv2
import warnings
from pipeline import DAGNode
class ImageCollection(DAGNode):
    """
    Represents a collection of images, with support for downsampling, cropping, rotating, segmentation,
    and pairing based on image metadata.

    Attributes
    ----------
    metadata_type : str
        Type of metadata ('none', 'fei', or 'custom').
    metadata : list of dict or None
        Metadata associated with each image, in the order of ``paths_to_images``. Each dict may contain:
            - 'StageTa': Tilt angle in radians.
            - 'tilt_axis': Tilt axis as a 3D vector (e.g., [1., 0., 0.]).
            - 'StageR': Rotation angle of the microscope stage in the (x, y) plane, in radians.
    downsample_factor : float
        User-requested downsampling factor. The factor actually applied in ``eval`` is the minimum of
        this value and the factors needed to fit the images into ``max_dim``.
    crop_dim : list
        Crop dimensions per face. Format: [[[width, height], [x, y]], ...]
    two_by_two : bool
        If True, images are grouped into fixed pairs: (1,2), (3,4), etc.
        If False, grouped as overlapping pairs: (1,2), (2,3), etc.
    use_segmentation : bool
        Whether ``eval`` reads segmentation masks.
    max_dim : tuple of int
        Maximum image dimensions as (width, height).
    verbose : bool
        Verbosity flag (stored; not read by the code of this class).
    paths_to_images : list of str
        File paths to the images in the collection.
    paths_to_masks : list of str or None
        Directories holding the segmentation mask of each image, in the order of ``paths_to_images``.
    images : list of Image
        Loaded Image objects, sorted by (rotation_z, tilt).
    sorted_idx : tuple of int
        ``sorted_idx[k]`` is the position in ``paths_to_images`` of ``images[k]``.
    original_images : list of Image
        Deep copy of the images taken after rotation and before cropping (used by ``show``).
    masks : dict or None
        Mapping from Image objects to their segmentation masks, or None when no masks were read.
    pairs : list of ImagePair
        List of image pairs used for multi-view geometry.
    faces : list of tuple
        Grouped image pairs, organized by face.
    """

    def __init__(
            self, metadata_type: str = 'none',
            metadata=None,
            downsample=1.,
            crop_dim=None,
            two_by_two=False,
            use_segmentation=False,
            max_dim=(800, 800),
            verbose=True):
        """
        Initializes the ImageCollection with optional parameters.

        Parameters
        ----------
        metadata_type : str, optional
            Type of metadata to use ('none', 'fei', 'custom'). Default is 'none'.
        metadata : list of dict, optional
            Metadata per image, in the order of the image paths given to ``eval`` (default is None).
        downsample : float, optional
            Initial downsampling factor (default is 1.0).
        crop_dim : list, optional
            Crop dimensions, either [[w, h], [x, y]] for all faces or one such entry per face
            (default is None, meaning no cropping).
        two_by_two : bool, optional
            Whether to group images into fixed pairs (default is False).
        use_segmentation : bool, optional
            Whether to use segmentation masks (default is False).
        max_dim : tuple, optional
            Maximum image dimensions (default is (800, 800)).
        verbose : bool, optional
            Verbosity flag (default is True).
        """
        super().__init__()
        self.metadata_type = metadata_type
        self.metadata = metadata
        self.downsample_factor = downsample
        # None default instead of a mutable [] shared between all instances
        self.crop_dim = [] if crop_dim is None else crop_dim
        self.two_by_two = two_by_two
        self.use_segmentation = use_segmentation
        self.verbose = verbose
        self.max_dim = max_dim
        # filled by eval() when use_segmentation is True
        self.masks = None

    def eval(self, paths_to_images, paths_to_masks=None):
        """
        Loads and preprocesses the image collection: reading, sorting, pairing, rotating, cropping,
        downsampling, and reading segmentation masks.

        Parameters
        ----------
        paths_to_images : str or list/tuple of str
            Path(s) to the image files.
        paths_to_masks : str or list/tuple of str, optional
            Path(s) to the segmentation mask directories, one per image, in the order of
            ``paths_to_images``.

        Returns
        -------
        ImageCollection
            The processed image collection (``self``).

        Raises
        ------
        ValueError
            If a face has no image pair or ``crop_dim`` has the wrong shape.
        """
        self.paths_to_images = self._as_list(paths_to_images)
        # keep None as None so that _read_masks sees "no masks" instead of a bogus [None] path list
        self.paths_to_masks = None if paths_to_masks is None else self._as_list(paths_to_masks)
        # read and sort image collection
        self.images = self._read_images(metadata=self.metadata)
        self._sort_collection()
        self.pairs, self.faces = self._get_pairs(self.two_by_two)
        # fail fast, before any image processing, if a face has no pair
        self._validate_num_images_per_face()
        # downsample images such that they all have the same dimensions
        self._all_to_same_dim()
        # normalize self.crop_dim format to one crop per face
        self.crop_dim = self._normalize_crop_dim(self.crop_dim)
        # post process image collection (rotate, crop and downsample)
        if self.metadata_type != 'none':
            self.rotate()
        # copy images to plot in crop_utility mode (see show method)
        self.original_images = copy.deepcopy(self.images)
        if len(self.crop_dim) != 0:
            self.crop(*self.crop_dim)
        down_fac = self._get_downsample_factor()
        self.downsample(down_fac, down_fac)
        if self.use_segmentation:
            self.masks = self._read_masks(self.paths_to_masks)
        return self

    @staticmethod
    def _as_list(paths):
        """Return ``paths`` as a new list; a single path (str, Path, ...) becomes a one-element list."""
        if isinstance(paths, (list, tuple)):
            return list(paths)
        return [paths]

    @staticmethod
    def _is_per_face_crop(crop):
        """
        Return True if ``crop`` is a sequence of per-face crops [[[w, h], [x, y]], ...] rather than a
        single crop [[w, h], [x, y]]. Raises TypeError/IndexError if ``crop`` cannot be indexed twice.
        """
        return isinstance(crop[0][0], (list, tuple))

    def _normalize_crop_dim(self, crop_dim):
        """
        Bring ``crop_dim`` to the per-face format by replicating a single crop for every face.

        An empty ``crop_dim`` and an already per-face ``crop_dim`` are returned unchanged.

        Raises
        ------
        ValueError
            If ``crop_dim`` cannot be interpreted as either format.
        """
        try:
            if len(crop_dim) != 0 and not self._is_per_face_crop(crop_dim):
                return [crop_dim] * len(self.faces)
        except (TypeError, IndexError, KeyError) as err:
            raise ValueError("crop_dim has not the right shape. It should be [[dim_x, dim_y], [top_left_x, top_left_y]] or [[[dim_x_face1, dim_y_face1], [top_left_x_face1, top_left_y_face1]], [[..., ...],[..., ...]], ...]") from err
        return crop_dim

    def __getitem__(self, index):
        """
        Accesses an image by index or by face and image position.

        Parameters
        ----------
        index : int or tuple
            If int: index into the flat (sorted) image list.
            If tuple: (face index, image index within the face), where the images of a face are
            counted in pair order without repetition (standard Python indexing, so -1 is the last
            image of the face).

        Returns
        -------
        Image
            The requested Image object.
        """
        if isinstance(index, tuple):
            return self._face_images(index[0])[index[1]]
        return self.images[index]

    def _face_images(self, face_idx):
        """
        Return the distinct images of face ``face_idx`` in pair order.

        Overlapping pairs (1,2), (2,3), ... share an image, so such a face holds len(face) + 1 images;
        fixed two-by-two pairs (1,2), (3,4), ... share none, so such a face holds 2 * len(face) images.
        """
        face = self.faces[face_idx]
        if not face:
            return []
        if self.two_by_two:
            return [img for pair in face for img in (pair[0], pair[1])]
        return [face[0][0]] + [pair[1] for pair in face]

    def _validate_num_images_per_face(self):
        """
        Ensures each face has at least one pair of images.

        Raises
        ------
        ValueError
            If any face has fewer than two images.
        """
        for i, face in enumerate(self.faces):
            if len(face) == 0:
                raise ValueError(f"Face {i+1} has no image pair. Be sure to have at least two images per face.")

    def _get_downsample_factor(self):
        """
        Computes the downsampling factor so that the largest image fits into ``max_dim``, capped by the
        user-requested ``downsample_factor``.

        Returns
        -------
        float
            The downsampling factor.
        """
        max_w, max_h = self._get_max_shape()
        fac_x = self.max_dim[0] / max_w
        fac_y = self.max_dim[1] / max_h
        # retain smallest factor
        return min([fac_x, fac_y, self.downsample_factor])

    def _all_to_same_dim(self):
        """
        Resizes all images in the collection to the smallest width and height found in the collection.
        """
        min_w, min_h = self._get_min_shape()
        for img in self.images:
            # per-axis factor mapping this image's size onto the common (min_w, min_h)
            img.downsample(min_w / img.shape[1], min_h / img.shape[0])

    def _shapes_hw(self):
        """Return an (n_images, 2) array holding the (height, width) of every image."""
        # keep only the first two axes so that grayscale and colour images can be mixed
        return np.asarray([tuple(img.shape[:2]) for img in self.images])

    def _get_min_shape(self):
        """Return the per-axis minimum [width, height] over the collection."""
        min_hw = np.min(self._shapes_hw(), axis=0)
        return [min_hw[1], min_hw[0]]

    def _get_max_shape(self):
        """Return the per-axis maximum [width, height] over the collection."""
        max_hw = np.max(self._shapes_hw(), axis=0)
        return [max_hw[1], max_hw[0]]

    def _read_masks(self, paths):
        """
        Reads grayscale segmentation masks from the given directories and maps them to images.

        Parameters
        ----------
        paths : list of str or None
            Directories containing one segmentation mask each, in the order of ``paths_to_images``
            (i.e. the order before ``_sort_collection``).

        Returns
        -------
        dict or None
            Mapping from Image objects to masks (grayscale values divided by 255), or None if
            ``paths`` is None. The result is also stored in ``self.masks``.

        Raises
        ------
        AttributeError
            If number of masks does not match number of images.
        ValueError
            If a mask directory is empty or a mask file cannot be read.
        NotImplementedError
            If there are more than one segmentation mask per image
        """
        if paths is None:
            self.masks = None
            return None
        if len(paths) != len(self.images):
            raise AttributeError('Number of images and number of segmentation masks are not equal')
        # mask paths follow the caller's order, whereas self.images was reordered by _sort_collection;
        # sorted_idx maps each sorted image back to its input position
        input_order = getattr(self, 'sorted_idx', range(len(self.images)))
        masks = {}
        for img, src_idx in zip(self.images, input_order):
            directory_path = paths[src_idx]
            filename_list = sorted(os.listdir(directory_path))
            if len(filename_list) > 1:
                raise NotImplementedError("Currently only supports one segmentation mask per image.")
            if not filename_list:
                raise ValueError(f'{directory_path} contains no segmentation mask')
            file_path = os.path.join(directory_path, filename_list[0])
            # cv2.imread returns None on failure, so check before doing arithmetic on the result
            mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise ValueError(f'{file_path} could not be read')
            # Normalize the image to 0., 1.
            masks[img] = mask / 255.
        self.masks = masks
        return masks

    def _read_images(self, metadata=None):
        """
        Reads image files and wraps them in Image objects with metadata.

        Parameters
        ----------
        metadata : list of dict, optional
            Metadata for each image, in the order of ``paths_to_images``.

        Returns
        -------
        list of Image
            Loaded images, in the order of ``paths_to_images``.
        """
        images = []
        for i, path in enumerate(self.paths_to_images):
            image_name = os.path.splitext(os.path.basename(path))[0]
            images.append(Image(path, image_name, self.metadata_type, metadata=metadata[i] if metadata is not None else None))
        return images

    def _sort_collection(self):
        """
        Sorts ``self.images`` lexicographically by (rotation_z, tilt) and records in ``self.sorted_idx``
        the input position of each sorted image.
        """
        # sorted() is stable: images with equal keys keep their input order
        sort_key = attrgetter('rotation_z', 'tilt')
        idx_img_sorted = sorted(enumerate(self.images), key=lambda item: sort_key(item[1]))
        self.sorted_idx = tuple(idx for idx, _ in idx_img_sorted)
        self.images = [img for _, img in idx_img_sorted]

    def _get_pairs(self, pairs_bool):
        """
        Pairs images into fixed or overlapping pairs and groups them into faces.

        Parameters
        ----------
        pairs_bool : bool
            Whether to use fixed (two-by-two) pairing. In that mode all pairs go into a single face.

        Returns
        -------
        tuple
            - pairs : list of ImagePair
            - faces : list of tuple of ImagePair

        Raises
        ------
        ValueError
            In fixed-pair mode, if the number of images is odd or the two images of a pair do not
            share the same rotation_z.
        """
        pairs = []
        faces = [[]]
        if pairs_bool and (len(self.paths_to_images) % 2 != 0):
            raise ValueError('pairs_bool is True but number of image paths is odd')
        increment = 2 if pairs_bool else 1
        # consecutive images whose rotation_z differ by at least this much start a new face
        # (10 degrees, assuming rotation_z is in radians)
        new_face_threshold = 10 * np.pi / 180
        for i in range(0, len(self.images) - 1, increment):
            img1 = self.images[i]
            img2 = self.images[i + 1]
            pair = ImagePair(img1, img2, np.array([1., 0, 0]))
            if i == 0:  # reference pair
                self.pair_ref = pair
            # NOTE(review): atol=1 here differs from new_face_threshold; the unit of rotation_z is not
            # visible from this file - confirm both tolerances are intended
            if not np.isclose(img1.rotation_z, img2.rotation_z, atol=1) and pairs_bool:
                raise ValueError(f'pairs is True but images {i} and {i+1} do not have same rotation_z value. Probably the rotation {img1.rotation_z} has odd number of images')
            # overlapping mode: a jump in rotation_z closes the current face and the cross-face pair is dropped
            if np.abs(img1.rotation_z - img2.rotation_z) >= new_face_threshold and not pairs_bool:
                faces.append([])  # start new face
                continue
            pairs.append(pair)
            faces[-1].append(pair)
        # [(face1_pair1, face1_pair2, ...), (face2_pair1, ...)]
        faces = [tuple(face) for face in faces]
        return pairs, faces

    def apply_to_collection(self, func, *args, **kwargs):
        """
        Applies a function to each image in the collection.

        Parameters
        ----------
        func : callable
            Function to apply to each image.
        *args
            Additional positional arguments to pass to the function.
        **kwargs
            Additional keyword arguments to pass to the function.

        Returns
        -------
        list
            List of results from the function, in the order of ``self.images``.
        """
        return [func(image, *args, **kwargs) for image in self.images]

    def apply_to_face(self, face_idx, func, *args, **kwargs):
        """
        Applies a function to all images within a face (each image once, in pair order).

        Parameters
        ----------
        face_idx : int
            Index of the face.
        func : callable
            Function to apply.
        *args
            Additional positional arguments.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        list
            List of results from the function.
        """
        return [func(image, *args, **kwargs) for image in self._face_images(face_idx)]

    def save_collection(self, path_to_dir):
        """
        Saves all images to the specified directory.

        Parameters
        ----------
        path_to_dir : str
            Directory path where images will be saved.

        Returns
        -------
        list
            Results from saving images.
        """
        return self.apply_to_collection(Image.save, path_to_dir)

    def downsample(self, factor_x: float, factor_y: float):
        """
        Downsamples all images in the collection.

        Parameters
        ----------
        factor_x : float
            Scale factor for width (as used by ``_all_to_same_dim``: new size = factor * current size).
        factor_y : float
            Scale factor for height.
        """
        self.images = self.apply_to_collection(Image.downsample, factor_x, factor_y)

    def crop(self, *crop_arg):
        """
        Crops images in the collection.

        Parameters
        ----------
        crop_arg
            Either a single crop passed as two arguments, ``crop([width, height], [x_offset, y_offset])``,
            which is applied to every image, or one ``[[width, height], [x_offset, y_offset]]`` per face,
            ``crop(crop_face1, crop_face2, ...)``, applied to the images of the corresponding face.
        """
        func = Image.crop
        if not self._is_per_face_crop(crop_arg):
            self.images = self.apply_to_collection(func, *crop_arg)
            return
        # faces cover self.images contiguously and in order, so the cropped images of each face are
        # written back to the matching slice of self.images
        idx_img = 0
        for face_idx, face_crop in enumerate(crop_arg):
            cropped = self.apply_to_face(face_idx, func, *face_crop)
            self.images[idx_img:idx_img + len(cropped)] = cropped
            idx_img += len(cropped)

    def rotate(self, angle=None, rad=False):
        """
        Rotates all images of the collection according to `angle` or `rotation_z` from the metadata.

        Parameters
        ----------
        angle : float or None, optional
            Angle to rotate. If None, ``Image.rotate_by_transpose`` is applied to every image
            (presumably rotating by each image's ``rotation_z``).
        rad : bool, optional
            If True, angle is in radians. Default is False.
        """
        if angle is None:
            self.apply_to_collection(Image.rotate_by_transpose)
        else:
            self.apply_to_collection(Image.rotate, angle, rad=rad)

    def register_pair(self, paths_to_images):
        """
        Not implemented. Used to manually add image pairs to the collection.

        Raises
        ------
        NotImplementedError
        """
        raise NotImplementedError

    @staticmethod
    def _add_crop_overlay(ax, shape, crop):
        """Shade in white the part of an image of the given shape lying outside crop = [[w, h], [x, y]]."""
        height, width = shape[:2]
        (rec_width, rec_height), (left, top) = crop
        # left band, right band, then the strips above and below the crop window
        # (imshow uses image coordinates: y grows downwards, so `top` is the offset from the first row)
        bands = (
            ((0, 0), left, height),
            ((left + rec_width, 0), width - (left + rec_width), height),
            ((left, 0), rec_width, top),
            ((left, top + rec_height), rec_width, height - (top + rec_height)),
        )
        for xy, band_width, band_height in bands:
            ax.add_patch(plt.Rectangle(xy, band_width, band_height, alpha=0.7, color='white'))

    def show(self, crop_utility=True, use_segmentation=False):
        """
        Displays all images in the collection, one row per face, optionally overlaying crop regions and
        segmentation masks.

        Parameters
        ----------
        crop_utility : bool, optional
            If True (default), show the uncropped images stored in ``original_images`` with axes visible
            and the area outside each face's crop shaded in white. If False, show the processed images
            without axes.
        use_segmentation : bool, optional
            Whether to overlay segmentation masks (default is False).
        """
        num_faces = len(self.faces)
        images_per_face = [self._face_images(i) for i in range(num_faces)]
        max_num_img = max((len(images) for images in images_per_face), default=0)
        fig = plt.figure(figsize=(5 * max_num_img, 6 * num_faces))
        img_idx = 0  # position in self.original_images, used if crop_utility is True
        for i, face_images in enumerate(images_per_face):
            # a face has no crop when crop_dim is empty (or lists fewer crops than there are faces)
            face_crop = self.crop_dim[i] if i < len(self.crop_dim) else None
            for j, img in enumerate(face_images):
                ax = fig.add_subplot(num_faces, max_num_img, i * max_num_img + j + 1)
                if crop_utility:
                    shown = self.original_images[img_idx]
                    img_idx += 1
                    if face_crop is not None:
                        self._add_crop_overlay(ax, shown.shape, face_crop)
                else:
                    ax.axis('off')
                    shown = img
                ax.imshow(shown.image)
                if use_segmentation:
                    # masks are keyed by the collection's images, not by their deep copies
                    ax.imshow(self.masks[img], alpha=0.32, cmap='magma')
        plt.tight_layout(pad=0.5)
        plt.show()