Source code for patato.core.image_structures.image_sequence

"""Image sequence - abstract classes for processing datasets from PA data."""

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Tuple, Iterable

import numpy as np
import xarray

try:
    import cupy as cp
except ImportError:
    cp = None

from xarray import DataArray

try:
    import jax.numpy as jnp

    if jnp.ndarray not in xarray.core.variable.NON_NUMPY_SUPPORTED_ARRAY_TYPES:
        xarray.core.variable.NON_NUMPY_SUPPORTED_ARRAY_TYPES += (jnp.ndarray,)
except ImportError:
    jnp = None

from dask.array.core import Array as DaskArray

from ...io.attribute_tags import ReconAttributeTags
from ...processing.processing_algorithm import ProcessingResult
from ...utils.plotting import type_cmaps
from ...utils.rois.roi_type import ROI


def _get_matplotlib_scalebar_size(scalebar):
    """Return the formatted size label for a scale bar attached to an axes.

    The value/units pair is either the fixed one configured on ``scalebar``
    or, when no fixed value is set, the one chosen by the scale bar's own
    ``_calculate_best_length`` for the current axis limits. The pair is then
    passed through ``scalebar.scale_formatter``.

    :param scalebar: Scale bar artist; must expose ``axes``, ``rotation``,
        ``length_fraction``, ``fixed_value``, ``fixed_units``, ``units``,
        ``dimension``, ``scale_formatter`` and ``_calculate_best_length``.
    :return: Whatever ``scalebar.scale_formatter`` returns for the value and
        the LaTeX form of the units.
    """
    axes = scalebar.axes
    x_limits = axes.get_xlim()
    y_limits = axes.get_ylim()
    from matplotlib import rcParams

    def _resolve(name, fallback):
        # Prefer the attribute set on the scale bar; fall back to rcParams.
        current = getattr(scalebar, name)
        if current is not None:
            return current
        return rcParams.get("scalebar." + name, fallback)

    orientation = _resolve("rotation", "horizontal").lower()
    fraction = _resolve("length_fraction", 0.2)
    preset_value = scalebar.fixed_value
    preset_units = scalebar.fixed_units or scalebar.units

    # A vertical bar is measured along the y axis instead of the x axis.
    measured_limits = y_limits if orientation == "vertical" else x_limits

    if scalebar.fixed_value is not None:
        # Fixed mode: use the configured value and units directly.
        value, units = preset_value, preset_units
    else:
        # Auto mode: let the scale bar choose a suitable length.
        target_px = abs(measured_limits[1] - measured_limits[0]) * fraction
        _, value, units = scalebar._calculate_best_length(target_px)

    latex_units = scalebar.dimension.to_latex(units)
    return scalebar.scale_formatter(value, latex_units)


class DataSequence(ProcessingResult, ABC):
    """
    Abstract base class for defining a sequence of data, e.g. raw data, reconstructed images, unmixed images.
    Enables consistent saving and processing for all of these data types.
    """

    # Number of trailing axes that make up one image volume.
    n_im_dim = 3

    @property
    def attributes(self):
        """Attribute dictionary (``attrs``) of the underlying DataArray."""
        return self.da.attrs

    def __init__(
        self,
        data,
        dimensions,
        coordinates=None,
        attributes=None,
        hdf5_sub_name=None,
        algorithm_id=None,
    ):
        """Wrap ``data`` in an xarray ``DataArray`` stored as ``self.da``.

        :param data: Array-like data to wrap.
        :param dimensions: Dimension names, one per axis of ``data``.
        :param coordinates: Mapping of dimension name to coordinate values.
            When omitted, each dimension gets ``np.arange`` of its length.
        :param attributes: Passed to the DataArray as ``attrs``.
        :param hdf5_sub_name: Stored unchanged on the instance.
        :param algorithm_id: Stored unchanged on the instance.
        """
        if coordinates is None:
            coordinates = {
                x: np.arange(data.shape[i]) for i, x in enumerate(dimensions)
            }
        ProcessingResult.__init__(self)
        self.da = DataArray(
            data=data, dims=dimensions, coords=coordinates, attrs=attributes
        )
        self._cmap = None
        self.hdf5_sub_name = hdf5_sub_name
        self.algorithm_id = algorithm_id

    def __getitem__(self, item):
        """Return a shallow copy of this object whose DataArray is indexed by ``item``."""
        c = self.copy()
        try:
            c.da = c.da[item]
        except TypeError:
            # Fudging the JAX implementation...
            # Indexing failed on the wrapped array, so retry on a copy whose
            # backing data has been converted to a NumPy array.
            c.da = c.da.copy()
            c.da.variable._data = np.asarray(c.da.variable._data)
            c.da = c.da[item]
        return c

    # TODO: Implement a concatenate function

    def copy(self, cls=None):
        """Return a shallow copy, optionally re-typed to ``cls``.

        :param cls: If given, the copy's ``__class__`` is set to this class.
        """
        from copy import copy

        c = copy(self)
        if cls is not None:
            c.__class__ = cls
        return c

    @property
    def shape_2d(self):
        """Shape of the image plane.

        If any of the last three axes has length 1, the non-singleton lengths
        among those three axes are returned; otherwise the last two lengths.
        """
        if np.any(np.array(self.shape[-3:]) == 1):
            return tuple(x for x in self.shape[-3:] if x != 1)
        else:
            return self.shape[-2:]

    @property
    def extent(self):
        """Flat tuple ``(min, max, min, max)`` of the coordinates of the two axes from ``two_dims``."""
        coords = [self.da.coords[x] for x in self.two_dims()]
        return sum(((np.min(np.array(c)), np.max(np.array(c))) for c in coords), ())

    def two_dims(self):
        """Return the names of the two spatial dimensions that are not singleton.

        Checks ``x`` then ``y`` for a size-1 coordinate; defaults to ``("x", "y")``.
        """
        if self.da.coords["x"].size == 1:
            return "y", "z"
        elif self.da.coords["y"].size == 1:
            return "x", "z"
        else:
            return "x", "y"

    def to_2d(self):
        """Index this object down to a single 2D image.

        Leading (non-image) axes are indexed at 0. When ``n_im_dim`` is
        greater than 2, the two largest image axes are kept whole and the
        remaining image axes are indexed at 0.
        """
        s = (0,) * (len(self.shape) - self.n_im_dim)
        if self.n_im_dim > 2:
            slicer = [0] * self.n_im_dim
            for i in np.argsort(self.shape[-self.n_im_dim :])[-2:]:
                slicer[i] = slice(None)
            s += tuple(slicer)
        return self[s]

    def imshow(
        self,
        ax=None,
        roi_mask: "ROI | Iterable[ROI] | None" = None,
        mask_roi=True,
        cmap=None,
        scale_kwargs=None,
        return_scalebar_dimension=False,
        scalebar=True,
        transpose=False,
        log=False,
        **kwargs,
    ):
        """Display a 2D view of this data with matplotlib.

        :param ax: Axes to draw on; defaults to ``plt.gca()``.
        :param roi_mask: A single ROI, or an indexable collection of ROIs,
            used to select the displayed slice via ``to_mask_slice``.
        :param mask_roi: If true, pixels outside the ROI mask(s) are set to NaN.
        :param cmap: Colormap; defaults to ``self.cmap``.
        :param scale_kwargs: Overrides for the default ``ScaleBar`` keyword
            arguments.
        :param return_scalebar_dimension: If true, also return the formatted
            scale bar size label.
        :param scalebar: Whether to add a scale bar to the axes.
        :param transpose: If true, transpose the image and swap the two halves
            of the extent.
        :param log: If true, display ``np.log`` of the image.
        :param kwargs: Forwarded to ``ax.imshow``. ``origin`` defaults to
            ``"lower"`` and ``interpolation`` to ``"nearest"`` for boolean
            images or ``"antialiased"`` otherwise; both may be overridden.
        :return: The image returned by ``ax.imshow``, or a tuple of that image
            and the scale bar size label when ``return_scalebar_dimension``
            is true.
        :raises ValueError: If ``roi_mask`` is neither a ROI nor a collection
            of ROIs, or if a scale bar dimension is requested while
            ``scalebar`` is false.
        """
        if return_scalebar_dimension and not scalebar:
            # Without a scale bar there is nothing to measure; fail clearly
            # before drawing anything.
            raise ValueError(
                "return_scalebar_dimension=True requires scalebar to be enabled"
            )

        if scale_kwargs is None:
            scale_kwargs = {}
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()

        if roi_mask is not None:
            if not isinstance(roi_mask, ROI):
                try:
                    # The first ROI determines which slice is displayed.
                    _, image_slice = roi_mask[0].to_mask_slice(self)
                    image_slice = image_slice.to_2d()
                    display_image = np.squeeze(image_slice.numpy_array).astype(
                        np.float64
                    )
                    overall_mask = np.zeros(display_image.shape, dtype=bool)
                    for roi in roi_mask:
                        mask, _ = roi.to_mask_slice(self)
                        if mask_roi:
                            overall_mask[np.squeeze(mask)] = True
                    if mask_roi:
                        display_image[~overall_mask] = np.nan
                except TypeError as err:
                    raise ValueError(
                        "roi_mask must be a ROI or a tuple of ROIs"
                    ) from err
            else:
                mask, image_slice = roi_mask.to_mask_slice(self)
                image_slice = image_slice.to_2d()
                display_image = np.squeeze(image_slice.numpy_array).astype(np.float64)
                if mask_roi:
                    display_image[~np.squeeze(mask)] = np.nan
        else:
            display_image = self.to_2d().numpy_array

        # Boolean images are shown without smoothing.
        interpolation = "antialiased"
        if display_image.dtype == np.bool_:
            interpolation = "nearest"

        if cmap is None:
            cmap = self.cmap

        if np.iscomplexobj(display_image):
            display_image = np.real(display_image)
        display_image = np.squeeze(display_image)

        kwargs.setdefault("origin", "lower")
        # Use setdefault so a caller-supplied interpolation does not collide
        # with the automatically chosen one.
        kwargs.setdefault("interpolation", interpolation)

        if transpose:
            extent = self.extent[2:] + self.extent[:2]
            display_image = display_image.T
        else:
            extent = self.extent

        if log:
            display_image = np.log(display_image)

        im = ax.imshow(display_image, extent=extent, cmap=cmap, **kwargs)
        ax.axis("off")

        scalebar_artist = None
        if scalebar:
            from matplotlib_scalebar.scalebar import ScaleBar

            scale_kwargs_defaults = dict(
                length_fraction=0.1,
                location="lower right",
                font_properties=dict(size="xx-small"),
                box_alpha=0.0,
                color="w",
            )
            scale_kwargs_defaults.update(scale_kwargs)
            scalebar_artist = ScaleBar(1, "m", **scale_kwargs_defaults)
            ax.add_artist(scalebar_artist)

        if return_scalebar_dimension:
            return im, _get_matplotlib_scalebar_size(scalebar_artist)
        else:
            return im

    @property
    def cmap(self):
        """Colormap for display: the explicitly set one, else ``type_cmaps`` keyed by the HDF5 group name."""
        if self._cmap is None:
            return type_cmaps[self.get_hdf5_group_name()]
        else:
            return self._cmap

    @cmap.setter
    def cmap(self, x):
        self._cmap = x

    @staticmethod
    def get_ax1_label_meaning():
        """Name of the coordinate holding the axis-1 labels (empty string here)."""
        return ""

    @abstractmethod
    def get_hdf5_group_name(self):
        """Return the HDF5 group name for this data type (implemented by subclasses)."""
        pass

    @property
    def values(self):
        """Alias for ``raw_data``."""
        return self.raw_data

    @property
    def numpy_array(self):
        """The data converted with ``np.asarray``."""
        return np.asarray(self.values)

    @property
    def raw_data(self):
        """The backing array; dask-backed data is converted with ``np.array``."""
        # TODO: check if this has side effects in future. Previously, this said if type(self.da.data) == Array:
        if isinstance(self.da.variable._data, DaskArray):
            return np.array(self.da.data)
        else:
            return self.da.variable._data

    @raw_data.setter
    def raw_data(self, value):
        self.da.values = value

    @property
    def shape(self):
        """Shape of the underlying DataArray."""
        return self.da.shape

    @property
    def ndim(self):
        """Number of dimensions of the underlying DataArray."""
        return self.da.ndim

    @property
    def dtype(self):
        """dtype of the underlying DataArray."""
        return self.da.dtype

    @property
    def ax_1_labels(self):
        """Axis-1 labels as an array, or ``None`` when the stored label is a scalar ``None``.

        An empty array is returned when the coordinate is absent.
        """
        r = np.asarray(self.da.coords.get(self.get_ax1_label_meaning(), []))
        if r.ndim == 0 and r.item() is None:
            return r.item()
        else:
            return r

    @property
    def ax_0_labels(self):
        """The ``frames`` coordinate of the underlying DataArray."""
        return self.da.coords["frames"]

    @staticmethod
    def ax_0_exists():
        """Whether this data type has a leading ``frames`` axis (true here)."""
        return True
class ImageSequence(DataSequence):
    """Sequence of image volumes with ``x``, ``y`` and ``z`` spatial coordinates."""

    @staticmethod
    def is_single_instance():
        """Return ``False``."""
        return False

    def get_hdf5_group_name(self):
        """Not implemented on this class; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def __add__(self, other):
        """Concatenate ``self`` and ``other`` along ``other``'s first dimension.

        The result is built as an ``ImageSequence`` from this object's labels,
        field of view, attributes and identifiers, then re-typed to this
        object's class.
        """
        # A really lazy implementation of concatenating these datasets. There is 100% a better way to do this..
        new_data = xarray.concat([self.da, other.da], dim=other.da.dims[0])
        output = ImageSequence(
            new_data.values,
            self.ax_1_labels,
            self.algorithm_id,
            self.fov_3d,
            self.attributes,
            self.hdf5_sub_name,
            ax1_meaning=self.get_ax1_label_meaning(),
        )
        output.__class__ = self.__class__
        return output

    @property
    def fov(self):
        """Field of view along the two axes whose pixel count is not 1.

        Pixel counts and fields of view are read from the attributes. When
        every per-axis pixel count is 1 or missing, the old-style single
        pixel-count and field-of-view tags are used for both axes.

        :return: Tuple ``(fov_x, fov_y)``; entries are ``None`` when the
            corresponding attribute is missing.
        """
        n_pixel_tags = [
            ReconAttributeTags.X_NUMBER_OF_PIXELS,
            ReconAttributeTags.Y_NUMBER_OF_PIXELS,
            ReconAttributeTags.Z_NUMBER_OF_PIXELS,
        ]
        fov_tags = [
            ReconAttributeTags.X_FIELD_OF_VIEW,
            ReconAttributeTags.Y_FIELD_OF_VIEW,
            ReconAttributeTags.Z_FIELD_OF_VIEW,
        ]
        n_pixels = np.array([self.attributes.get(tag, 1) for tag in n_pixel_tags])
        if np.all(n_pixels == 1):
            # Old-style data
            n_pixels = np.array(
                [self.attributes.get(ReconAttributeTags.OLD_RECON_NX)] * 2
            )
            fov_tags = [ReconAttributeTags.OLD_FIELD_OF_VIEW] * 2
        axes = np.where(~(n_pixels == 1))[0]
        fov_x = self.attributes.get(fov_tags[axes[0]], None)
        fov_y = self.attributes.get(fov_tags[axes[1]], None)
        return fov_x, fov_y

    @property
    def fov_3d(self):
        """Field of view for the x, y and z axes, in that order.

        Values come from the per-axis field-of-view attributes. When none of
        them is present, the old-style single field-of-view attribute is used
        for every non-singleton axis and 1 for singleton axes.

        :return: NumPy array of three entries; an entry is ``None`` when its
            attribute is missing.
        """
        fov_tags = [
            ReconAttributeTags.X_FIELD_OF_VIEW,
            ReconAttributeTags.Y_FIELD_OF_VIEW,
            ReconAttributeTags.Z_FIELD_OF_VIEW,
        ]
        fov = [self.attributes.get(tag, None) for tag in fov_tags]
        # Test the looked-up values, not the tag constants: the tags are never
        # None, so testing them made this fallback unreachable.
        if all(x is None for x in fov):
            # Old-style data
            # The data axes end in (z, y, x); reverse so the result is in the
            # same (x, y, z) order as the per-axis tags above.
            n_pixels = self.raw_data.shape[-3:][::-1]
            fov = [
                self.attributes.get(ReconAttributeTags.OLD_FIELD_OF_VIEW)
                if x != 1
                else 1
                for x in n_pixels
            ]
        return np.array(fov)

    def __init__(
        self,
        raw_data,
        ax_1_labels=None,
        algorithm_id="",
        field_of_view=None,
        attributes=None,
        hdf5_sub_name=None,
        ax1_meaning=None,
    ):
        """Build the dimension names and coordinates, then initialise the DataSequence.

        :param raw_data: Image data; its last three axes are used as
            ``z``, ``y``, ``x``.
        :param ax_1_labels: Labels for axis 1 (e.g. wavelengths).
        :param algorithm_id: Forwarded to ``DataSequence.__init__``.
        :param field_of_view: Three entries in x, y, z order. Scalars are
            expanded to ``(-v / 2, v / 2)`` (``None`` becomes ``(0, 0)``);
            ``(start, stop)`` pairs are used as given. Defaults to
            ``[1, 1, 1]``.
        :param attributes: Forwarded to ``DataSequence.__init__``.
        :param hdf5_sub_name: Forwarded to ``DataSequence.__init__``.
        :param ax1_meaning: Name of the axis-1 dimension; defaults to
            ``get_ax1_label_meaning()``.
        :raises ValueError: If ``ax_1_labels`` is given and its length differs
            from ``raw_data.shape[1]``.
        """
        # Ax1 labels = e.g. Wavelength
        if ax1_meaning is None:
            ax1_meaning = self.get_ax1_label_meaning()

        # Quick bit of validation
        if ax_1_labels is not None and ax1_meaning is not None:
            if raw_data.shape[1] != len(ax_1_labels):
                raise ValueError("Axis 1 labels must match raw data size.")

        if field_of_view is None:
            field_of_view = [1, 1, 1]

        if not isinstance(field_of_view[0], (tuple, list, np.ndarray)):
            # Scalar extents: centre each axis on zero.
            field_of_view = [
                (-x / 2, x / 2) if x is not None else (0, 0) for x in field_of_view
            ]

        # field_of_view is ordered x, y, z while the data axes end in z, y, x.
        xs = [
            np.linspace(x, y, N)
            for (x, y), N in zip(field_of_view, raw_data.shape[-3:][::-1])
        ]

        dims = ["frames", "z", "y", "x"]
        coords = {
            "frames": np.arange(raw_data.shape[0]),
            "x": xs[0],
            "y": xs[1],
            "z": xs[2],
        }
        if not self.ax_0_exists():
            del coords["frames"]
            dims = dims[1:]

        if ax1_meaning is not None:
            dims.insert(1, ax1_meaning)
            coords[ax1_meaning] = ax_1_labels
        else:
            # If there isn't really an axis 1 (e.g. for delta so2).
            # NOTE(review): this branch requires ax_1_labels to be indexable
            # and stores the label under the key None - confirm intended.
            coords[ax1_meaning] = ax_1_labels[0]

        DataSequence.__init__(
            self, raw_data, dims, coords, attributes, hdf5_sub_name, algorithm_id
        )