# Source code for patato.io.msot_data

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

from __future__ import annotations

import copy
from typing import Union, Dict, Tuple, Optional, Sequence, TYPE_CHECKING, Type

import h5py
import numpy as np
import xarray

from .hdf.fileimporter import ReaderInterface, WriterInterface
from .hdf.hdf5_interface import HDF5Reader, HDF5Writer
from ..core.image_structures.pa_time_data import PATimeSeries
from ..core.image_structures.single_image import SingleImage
from ..core.image_structures.single_parameter_data import SingleParameterData
from ..utils.roi_operations import get_rim_core_rois
from ..utils.rois.roi_type import ROI

if TYPE_CHECKING:
    try:
        from pyopencl.array import Array
    except ImportError:
        Array = np.ndarray
    from ..core.image_structures.image_sequence import ImageSequence

from functools import lru_cache

from ..io.attribute_tags import HDF5Tags


class PAData:
    """A class that contains the interface to access data from a single scan.

    Any source of scans (e.g. iThera/HDF5/IPASC) can be linked to this through a
    reader object, plus an optional writer object for modifying the scan.
    """

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        Returns the shape of the dataset, minus the image size.

        Returns
        -------
        tuple of int
            Shape of the time-series data with its last two axes removed
            (presumably the detector and time-sample axes).
        """
        return self.get_time_series().shape[:-2]

    def __getitem__(self, item: Union[slice, Tuple, None]) -> "PAData":
        """
        Slice the photoacoustic data. Choose a particular frame/wavelength etc.

        The index is applied to a shallow copy of the reader, so this object is
        left unchanged. The writer object is shared with the returned copy.

        Parameters
        ----------
        item : slice or tuple
            Index forwarded to the scan reader.

        Returns
        -------
        PAData
            Sliced pa data.
        """
        new_data = self.copy()
        new_data.scan_reader = copy.copy(self.scan_reader)[item]
        return new_data

    def __init__(
        self, scan_reader: ReaderInterface, scan_writer: Optional[WriterInterface] = None
    ) -> None:
        """
        Parameters
        ----------
        scan_reader : ReaderInterface
            Object that data is read from.
        scan_writer : WriterInterface, optional
            Object used to modify the scan. Methods that need to write raise
            ``NotImplementedError`` when this is None.
        """
        super().__init__()
        self.scan_reader = scan_reader
        self.scan_writer = scan_writer
        self.default_recon = None
        self.default_unmixing_type = ""
        self.external_roi_interface = None
        # Per-instance cache for get_scan_so2_frequency_components.
        self._so2_frequency_cache = {}

    def _require_writer(self, message: str = "No writing capability enabled."):
        """Return the scan writer, raising NotImplementedError(message) if there is none."""
        if self.scan_writer is None:
            raise NotImplementedError(message)
        return self.scan_writer

    def is_clinical(self):
        """Return the reader's ``is_clinical()`` result for this scan."""
        return self.scan_reader.is_clinical()

    def set_default_recon(self, rec_name: Optional[Tuple[str, str]] = None) -> None:
        """
        Make all returned data be of a particular reconstruction type.

        It is recommended to run this at the start of analysis scripts.

        Parameters
        ----------
        rec_name : tuple of (str, str), optional
            Reconstruction key, as found in ``get_scan_reconstructions()``.
            If None and no default is set yet, the first available
            reconstruction is used. If None and a default is already set, the
            existing default is kept.

        Raises
        ------
        IndexError
            If ``rec_name`` is None, no default is set and there are no
            reconstructions to choose from.
        """
        if rec_name is None:
            if self.default_recon is not None:
                return
            available = list(self.get_scan_reconstructions())
            if not available:
                raise IndexError(
                    "No reconstructions found; cannot choose a default reconstruction."
                )
            rec_name = available[0]
        self.default_recon = rec_name

    def copy(self, cls: Optional[Type[PAData]] = None) -> PAData:
        """
        Make a shallow copy of the pa data.

        Parameters
        ----------
        cls : Type[PAData], optional
            Class to give the copy. Defaults to the class of this object.

        Returns
        -------
        PAData
            Copy of the dataset. Reader and writer objects are shared with the
            original; the frequency-component cache is not.
        """
        if cls is None:
            cls = self.__class__
        duplicate = copy.copy(self)
        duplicate.__class__ = cls
        # A shallow copy would share the cache dict, and copies (e.g. from
        # __getitem__) may be given a different reader, so start afresh.
        duplicate._so2_frequency_cache = {}
        return duplicate

    def get_scan_name(self) -> str:
        """
        Get the scan name.

        Returns
        -------
        str
            Scan name.
        """
        return self.scan_reader.get_scan_name()

    def get_scan_datetime(self):
        """Return the scan date/time as reported by the reader."""
        return self.scan_reader.get_scan_datetime()

    def get_sampling_frequency(self) -> float:
        """
        Get the scan's sampling frequency.

        Returns
        -------
        float
            Sampling Frequency
        """
        return self.scan_reader.get_sampling_frequency()

    def get_overall_correction_factor(
        self,
    ) -> Union[np.ndarray, Array, xarray.DataArray]:
        """
        Return the energy correction factors for the dataset.

        Returns
        -------
        np.ndarray
            Overall correction factor.
        """
        return self.scan_reader.get_correction_factor()

    def get_impulse_response(self) -> Union[np.ndarray, Array]:
        """
        Return the time-domain impulse response function.

        Returns
        -------
        np.ndarray or pyopencl.array.Array
            Impulse response function.
        """
        return self.scan_reader.get_impulse_response()

    def get_n_samples(self) -> int:
        """
        Get the number of time samples in the dataset.

        Returns
        -------
        int
            Number of samples (size of the last axis of the time series).
        """
        return self.get_time_series().shape[-1]

    def get_speed_of_sound(self) -> Union[float, None]:
        """
        Get the speed of sound of the data if it has been set.

        Returns
        -------
        float or None
            Speed of sound
        """
        return self.scan_reader.get_speed_of_sound()

    def get_scan_geometry(self) -> Union[np.ndarray, Array]:
        """
        Get the scan detector geometry.

        Returns
        -------
        np.ndarray or pyopencl.array.Array
            Sensor geometry as returned by the reader.
        """
        return self.scan_reader.get_sensor_geometry()

    def get_wavelengths(self) -> np.ndarray:
        """
        Get the wavelengths used in the scan.

        Returns
        -------
        np.ndarray
            Scan Wavelengths.
        """
        return self.scan_reader.get_wavelengths()

    def get_scan_images(
        self, group: str, ignore_default=False, suffix=""
    ) -> Union[Dict[Tuple[str, str], ImageSequence], ImageSequence]:
        """
        Get the scan images, e.g. reconstructions or so2 etc.

        Parameters
        ----------
        group : str
            Group to get images from.
        ignore_default : bool
            Ignore the default reconstruction.
        suffix : str
            Suffix appended to the second element of the default reconstruction
            key (e.g. for ICG unmixing).

        Returns
        -------
        (dict of {tuple of (str, str) : ImageSequence}) or ImageSequence
            The single matching image sequence if a default reconstruction is
            set (and not ignored), otherwise a dict of the images for all
            reconstructions. An empty dict is returned when the group or the
            requested image does not exist.
        """
        datasets = self.scan_reader.get_datasets()
        if group not in datasets:
            return {}
        group_images = datasets[group]
        if self.default_recon is None or ignore_default:
            return group_images
        image_key = (self.default_recon[0], self.default_recon[1] + suffix)
        return group_images.get(image_key, {})

    def get_scan_reconstructions(self):
        """Get the reconstructed images (``HDF5Tags.RECONSTRUCTION`` group)."""
        return self.get_scan_images(HDF5Tags.RECONSTRUCTION)

    def get_ultrasound(self):
        """
        Get the ultrasound images, ignoring the default reconstruction.

        Returns
        -------
        dict or ImageSequence
            An empty dict if there are none, the single image sequence if
            there is exactly one, otherwise the dict of all of them.
        """
        us_images = self.get_scan_images(HDF5Tags.ULTRASOUND, ignore_default=True)
        if len(us_images) == 0:
            return {}
        if len(us_images) > 1:
            return us_images
        return next(iter(us_images.values()))

    def get_scan_unmixed(self):
        """Get the unmixed images, using ``default_unmixing_type`` as key suffix."""
        return self.get_scan_images(HDF5Tags.UNMIXED, suffix=self.default_unmixing_type)

    def get_scan_so2(self):
        """Get the so2 images (``HDF5Tags.SO2`` group)."""
        return self.get_scan_images(HDF5Tags.SO2)

    def close(self):
        """Close the reader and, if present, the writer."""
        self.scan_reader.close()
        if self.scan_writer is not None:
            self.scan_writer.close()

    def get_scan_mean(self, dataset: ImageSequence, operation=np.mean):
        """
        Reduce an image sequence over its first axis into a single image.

        Parameters
        ----------
        dataset : ImageSequence
            Images to reduce.
        operation : callable, default np.mean
            Reduction called as ``operation(raw_data, axis=0)``.

        Returns
        -------
        SingleImage
            First entry of the reduced array, with the dataset's labels,
            field of view and attributes.

        Raises
        ------
        NotImplementedError
            If ``dataset`` is a dict (e.g. no default reconstruction is set).
        """
        if isinstance(dataset, dict):
            raise NotImplementedError(
                "Expected a single image sequence, got a dict; "
                "run PAData.set_default_recon() first."
            )
        return SingleImage(
            operation(dataset.raw_data, axis=0)[0],
            dataset.ax_1_labels,
            field_of_view=dataset.fov_3d,
            attributes=dataset.attributes,
        )

    def get_scan_so2_time_mean(self):
        """Return the so2 images averaged over the first axis (see get_scan_mean)."""
        return self.get_scan_mean(self.get_scan_so2())

    def get_scan_thb_time_mean(self):
        """Return the THb images averaged over the first axis (see get_scan_mean)."""
        return self.get_scan_mean(self.get_scan_thb())

    def get_scan_so2_time_standard_deviation(self):
        """Return the standard deviation of the so2 images over the first axis."""
        return self.get_scan_mean(self.get_scan_so2(), np.std)

    def get_scan_thb(self):
        """Get the THb images (``HDF5Tags.THB`` group)."""
        return self.get_scan_images(HDF5Tags.THB)

    def get_scan_dso2(self):
        """Get the delta-so2 images (``HDF5Tags.DELTA_SO2`` group)."""
        return self.get_scan_images(HDF5Tags.DELTA_SO2)

    def get_scan_dicg(self):
        """Get the delta-ICG images (``HDF5Tags.DELTA_ICG`` group)."""
        return self.get_scan_images(HDF5Tags.DELTA_ICG)

    def get_scan_baseline_icg(self):
        """Get the baseline ICG images (``HDF5Tags.BASELINE_ICG`` group)."""
        return self.get_scan_images(HDF5Tags.BASELINE_ICG)

    def get_scan_baseline_standard_deviation_icg(self):
        """Get the baseline ICG standard deviation (``HDF5Tags.BASELINE_ICG_SIGMA``)."""
        return self.get_scan_images(HDF5Tags.BASELINE_ICG_SIGMA)

    def get_responding_pixels(self, nsigma=2):
        """
        Find pixels whose delta signal exceeds ``nsigma`` baseline standard deviations.

        Delta-ICG data is used if available, otherwise delta-so2 data.

        Parameters
        ----------
        nsigma : float, default 2
            Threshold in multiples of the baseline standard deviation.

        Returns
        -------
        SingleImage or None
            Boolean "Responding Pixels" image, or None if neither delta
            dataset is available.
        """
        delta_icg = self.get_scan_dicg()
        if delta_icg:
            delta_images = delta_icg
            sigma = self.get_scan_baseline_standard_deviation_icg()
        else:
            delta_so2 = self.get_scan_dso2()
            if not delta_so2:
                return None
            delta_images = delta_so2
            sigma = self.get_scan_baseline_standard_deviation_so2()
        responding = delta_images.raw_data > nsigma * sigma.raw_data
        return SingleImage(
            responding,
            ["Responding Pixels"],
            algorithm_id=delta_images.algorithm_id,
            attributes=delta_images.attributes,
            hdf5_sub_name=delta_images.hdf5_sub_name,
            field_of_view=delta_images.fov_3d,
        )

    def get_scan_baseline_so2(self):
        """Get the baseline so2 images (``HDF5Tags.BASELINE_SO2`` group)."""
        return self.get_scan_images(HDF5Tags.BASELINE_SO2)

    def get_scan_baseline_standard_deviation_so2(self):
        """Get the baseline so2 standard deviation (``HDF5Tags.BASELINE_SO2_STANDARD_DEVIATION``)."""
        return self.get_scan_images(HDF5Tags.BASELINE_SO2_STANDARD_DEVIATION)

    # Some cycling hypoxia analysis here.
    def get_scan_so2_frequency_components(
        self, do_detrend=True, fmin=1e-5, fmax=1000, fnum=1000
    ):
        """
        Compute a Lomb-Scargle periodogram of the so2 signal in every pixel.

        Results are cached per instance, keyed on the arguments, the current
        default reconstruction and the current reader object. Changing either
        of those therefore triggers a recomputation.

        Parameters
        ----------
        do_detrend : bool, default True
            Remove a linear trend (True) or only the mean (False) along the
            first axis before computing the periodogram.
        fmin, fmax : float
            Frequency range; values are multiplied by 2*pi before being passed
            to ``scipy.signal.lombscargle``.
        fnum : int
            Number of frequencies evaluated.

        Returns
        -------
        SingleParameterData
            Periodogram along the first axis. The ``"frames"`` coordinate holds
            the angular frequencies.

        Raises
        ------
        NotImplementedError
            If the so2 images are not unique (no default reconstruction set
            and zero or several reconstructions available).
        """
        recon = None if self.default_recon is None else tuple(self.default_recon)
        key = (do_detrend, fmin, fmax, fnum, recon)
        cache = self.__dict__.setdefault("_so2_frequency_cache", {})
        cached = cache.get(key)
        if cached is not None and cached[0] is self.scan_reader:
            return cached[1]
        result = self._compute_so2_frequency_components(do_detrend, fmin, fmax, fnum)
        cache[key] = (self.scan_reader, result)
        return result

    def _compute_so2_frequency_components(self, do_detrend, fmin, fmax, fnum):
        """Uncached worker for get_scan_so2_frequency_components."""
        from scipy.signal import detrend, lombscargle

        so2 = self.get_scan_images(HDF5Tags.SO2)
        if isinstance(so2, dict):
            if len(so2) != 1:
                raise NotImplementedError(
                    "Frequency components are only enabled when there is one "
                    "reconstruction set. Run PAData.set_default_recon() before "
                    "running this if in doubt."
                )
            so2 = next(iter(so2.values()))

        detrended = detrend(
            so2.raw_data, axis=0, type="linear" if do_detrend else "constant"
        )
        times = self.get_timestamps()[:, 0].copy()
        times -= times[0]
        angular_frequencies = np.linspace(fmin, fmax, fnum) * 2 * np.pi

        def periodogram(series):
            # apply_along_axis hands over views; pass a copy of each 1-D slice.
            return lombscargle(times, series.copy(), angular_frequencies)

        so2_frequency = SingleParameterData(
            np.apply_along_axis(periodogram, 0, detrended),
            so2.ax_1_labels,
            field_of_view=so2.fov_3d,
            attributes=so2.attributes,
        )
        so2_frequency.da.coords["frames"] = angular_frequencies
        return so2_frequency

    def _reduce_so2_frequency_components(self, reducer, fnum):
        """Collapse the so2 periodogram over its frequency axis into a SingleImage."""
        so2 = self.get_scan_so2_frequency_components(fnum=fnum)
        return SingleImage(
            reducer(so2.raw_data, axis=0)[0],
            so2.ax_1_labels,
            field_of_view=so2.fov_3d,
            attributes=so2.attributes,
        )

    def get_scan_so2_frequency_peak(self, fnum=1000):
        """
        Get the maximum of the so2 periodogram over frequency for each pixel.

        Parameters
        ----------
        fnum : int
            Number of frequencies evaluated.

        Returns
        -------
        SingleImage
        """
        return self._reduce_so2_frequency_components(np.max, fnum)

    def get_scan_so2_frequency_sum(self, fnum=1000):
        """
        Get the sum of the so2 periodogram over frequency for each pixel.

        Parameters
        ----------
        fnum : int
            Number of frequencies evaluated.

        Returns
        -------
        SingleImage
        """
        return self._reduce_so2_frequency_components(np.sum, fnum)

    def get_segmentation(self):
        """Return the segmentation provided by the reader."""
        return self.scan_reader.get_segmentation()

    def get_time_series(self) -> PATimeSeries:
        """
        Get the raw photoacoustic time series.

        Returns
        -------
        PATimeSeries

        Raises
        ------
        ValueError
            If the reader does not return a PATimeSeries (exact type).
        """
        dataset = self.scan_reader.get_pa_data()
        if type(dataset) is not PATimeSeries:
            raise ValueError(
                "raw_data attribute must be of PATimeSeries type, "
                f"not {type(dataset).__name__}."
            )
        return dataset

    def get_recon_types(self) -> list:
        """
        Get the list of different reconstruction types.

        Returns
        -------
        list
            Keys of all reconstructions, ignoring the default reconstruction.
        """
        return list(self.get_scan_images(HDF5Tags.RECONSTRUCTION, True).keys())

    def get_z_positions(self) -> np.ndarray:
        """
        Get the z-positions of the sensor.

        Returns
        -------
        np.ndarray
            Z-positions array.
        """
        return self.scan_reader.get_scanner_z_position()

    def get_run_number(self) -> np.ndarray:
        """
        Get the run number for each of the frames.

        Returns
        -------
        np.ndarray
            Run numbers of each of the frames.
        """
        return self.scan_reader.get_run_numbers()

    def get_repetition_numbers(self) -> np.ndarray:
        """
        Get the scan repetition numbers for each frame.

        Returns
        -------
        np.ndarray
            Scan repetition numbers.
        """
        return self.scan_reader.get_repetition_numbers()

    def get_timestamps(self) -> np.ndarray:
        """
        Get the scan timestamps in seconds.

        Returns
        -------
        np.ndarray
            Timestamps in seconds.
        """
        return self.scan_reader.get_scan_times()

    def get_rois(
        self,
        filter_rois=None,
        interpolate: bool = False,
        get_rim_cores=None,
        rim_core_distance=None,
    ) -> Dict[Tuple[str, str], "ROI"]:
        """
        Get the regions of interest from the dataset.

        ROIs are read from ``external_roi_interface`` if it is set, otherwise
        from this scan's reader.

        Parameters
        ----------
        filter_rois : iterable of str, optional
            Only keep ROIs whose name equals an entry, or an entry followed by
            ``"_"``.
        interpolate : bool
            Passed to the reader's ``get_rois``.
        get_rim_cores : container of str, optional
            ROI names for which extra rim and core ROIs are generated, named
            ``"<name>.core_<position>"`` and ``"<name>.rim_<position>"``.
            The name is split from the position at the last underscore.
        rim_core_distance
            Passed to ``get_rim_core_rois``.

        Returns
        -------
        dict of {tuple of (str, str) : ROI}
            Return all the (optionally filtered) rois.
        """
        roi_owner = (
            self.external_roi_interface
            if self.external_roi_interface is not None
            else self
        )
        # Copy so that adding rim/core ROIs never mutates the reader's own dict.
        output_rois = dict(roi_owner.scan_reader.get_rois(interpolate))

        if get_rim_cores is not None:
            rim_core_rois = {}
            for key, roi in output_rois.items():
                roi_name, roi_number = key[0], key[1]
                if roi_name not in get_rim_cores:
                    continue
                core, rim = get_rim_core_rois(roi, rim_core_distance)
                name, separator, position = roi_name.rpartition("_")
                if not separator:
                    # No underscore: the whole string is the name, no position.
                    name, position = position, ""
                rim_core_rois[(name + ".core_" + position, roi_number)] = core
                rim_core_rois[(name + ".rim_" + position, roi_number)] = rim
            output_rois.update(rim_core_rois)

        if filter_rois is not None:
            output_rois = {
                key: roi
                for key, roi in output_rois.items()
                if any(f == key[0] or f + "_" == key[0] for f in filter_rois)
            }
        return output_rois

    def delete_recons(self, name=None, recon_groups: Optional[Sequence[str]] = None):
        """
        Delete the reconstructions.

        Parameters
        ----------
        name : str or None
        recon_groups : (iterable of str) or None

        Raises
        ------
        NotImplementedError
            If there is no writer.
        """
        writer = self._require_writer("Deletion only possible with a writing interface.")
        writer.delete_recons(name, recon_groups)

    def set_speed_of_sound(self, c: float) -> None:
        """
        Change the speed of sound for the dataset.

        Parameters
        ----------
        c : float
            Speed of sound.

        Raises
        ------
        NotImplementedError
            If there is no writer.
        """
        self._require_writer().set_speed_of_sound(c)

    def delete_rois(
        self, name_position: Optional[str] = None, number: Optional[str] = None
    ) -> None:
        """
        Delete a roi with name and number.

        If number is None, will delete all. If name and number is None,
        delete all.

        Parameters
        ----------
        name_position : str or None
        number : str or None

        Raises
        ------
        NotImplementedError
            If there is no writer.
        """
        self._require_writer().delete_rois(name_position, number)

    def add_roi(self, roi_data: "ROI", generated: bool = False) -> None:
        """
        Add a region of interest to the hdf5 file.

        Parameters
        ----------
        roi_data : ROI
        generated : bool, default False

        Raises
        ------
        NotImplementedError
            If there is no writer.
        """
        self._require_writer().add_roi(roi_data, generated)

    def rename_roi(
        self, old_name: Union[str, Tuple], new_name: str, new_position: str
    ) -> None:
        """
        Rename a region of interest.

        Parameters
        ----------
        old_name : str or tuple
            Old roi name e.g. "tumour_left/0" or ("tumour_left", "0")
        new_name : str
            New roi name e.g. "brain"
        new_position : str
            New roi position e.g. "left"

        Raises
        ------
        NotImplementedError
            If there is no writer.
        """
        self._require_writer().rename_roi(old_name, new_name, new_position)

    def clear_dso2(self) -> None:
        """Delete the delta-so2 data via the writer (NotImplementedError if none)."""
        self._require_writer().delete_dso2s()

    @classmethod
    def from_hdf5(cls, filename: Union[str, h5py.File], mode: str = "r"):
        """
        Create a PAData object backed by an HDF5 file.

        Parameters
        ----------
        filename : str or h5py.File
            Path to open, or an already-open file object. Anything that
            ``h5py.File`` rejects with a TypeError is used as-is.
        mode : str, default "r"
            Mode passed to ``h5py.File`` when opening a path.

        Returns
        -------
        PAData
            Instance of ``cls`` with an HDF5Reader and HDF5Writer on the file.
        """
        try:
            file = h5py.File(filename, mode)
        except TypeError:
            # Not a path: assume an already-open h5py object was given.
            file = filename
        return cls(HDF5Reader(file), HDF5Writer(file))

    def save_hdf5(self, filename: str):
        """
        Write this scan's data into an HDF5 file.

        Parameters
        ----------
        filename : str
            Destination file, opened in append mode ("a").

        Returns
        -------
        object
            Whatever ``HDF5Writer.save_file`` returns.
        """
        # NOTE(review): the file opened here is never closed by this method;
        # confirm whether HDF5Writer.save_file or its return value owns it.
        file = h5py.File(filename, "a")
        writer = HDF5Writer(file)
        return writer.save_file(self.scan_reader)

    def save_to_hdf5(self, filename):
        """Alias of save_hdf5."""
        return self.save_hdf5(filename)

    @property
    def dataset(self):
        """
        Returns
        -------
        PATimeSeries
            The raw time series (same as get_time_series()).
        """
        return self.get_time_series()

    def _get_summary_metric_data(self, metrics):
        """Fetch the image data for each metric as a list of (metric, data) pairs.

        Metrics whose data cannot be found (a missing ICG channel) are left out
        entirely, so names and data never fall out of step.
        """
        simple_getters = {
            "thb": self.get_scan_thb,
            "so2": self.get_scan_so2,
            "dso2": self.get_scan_dso2,
            "baseline_so2": self.get_scan_baseline_so2,
            "responding": self.get_responding_pixels,
            "recons": self.get_scan_reconstructions,
        }
        measurements = []
        for metric in metrics:
            if metric == "icg":
                unmixed = self.get_scan_unmixed()
                icg_channels = [
                    i
                    for i, label in enumerate(unmixed.ax_1_labels)
                    if label.lower() == "icg"
                ]
                if not icg_channels:
                    print("No icg channel found")
                    continue
                n = icg_channels[0]
                measurements.append((metric, unmixed[:, n : n + 1]))
                continue
            getter = simple_getters.get(metric)
            if getter is None:
                raise NotImplementedError(f"Metric {metric} not yet implemented.")
            measurements.append((metric, getter()))
        return measurements

    def summary_measurements(
        self,
        metrics=None,
        include_rois=None,
        roi_kwargs=None,
        just_summary=True,
        return_masks=False,
        metric_limits=None,
    ):
        """
        Tabulate image values inside each region of interest.

        Parameters
        ----------
        metrics : list of str, default ["thb", "so2"]
            Any of "thb", "so2", "icg", "dso2", "baseline_so2", "responding",
            "recons".
        include_rois : container of str, optional
            ROI names to keep. A ROI matches if its name, or its name without
            the final character (so "brain" matches "brain_"), is included.
        roi_kwargs : dict, optional
            Keyword arguments passed to get_rois.
        just_summary : bool, default True
            Drop the per-pixel metric columns and keep only the summaries.
        return_masks : bool, default False
            Add a "Mask" column with each ROI's mask on the so2 images.
        metric_limits : dict of {str : (float, float)}, optional
            Per-metric (lower, upper) bounds; values outside are set to NaN.

        Returns
        -------
        pandas.DataFrame
            One row per ROI. Columns are:
            - each metric's pixel values, unless ``just_summary``;
            - ``<metric>_mean``, ``<metric>_median`` and ``<metric>_std``,
              reduced over the last axis and ignoring NaNs;
            - "Timings", "Area", the ROI attributes, "number", "name" and
              "Wavelengths".
            An empty DataFrame is returned if there are no ROIs.

        Raises
        ------
        NotImplementedError
            For an unknown metric name.
        """
        import pandas as pd

        if metrics is None:
            metrics = ["thb", "so2"]
        if roi_kwargs is None:
            roi_kwargs = {}

        rois = self.get_rois(**roi_kwargs)
        if include_rois is not None:
            rois = {
                key: roi
                for key, roi in rois.items()
                if key[0][:-1] in include_rois or key[0] in include_rois
            }

        measurements = self._get_summary_metric_data(metrics)

        if not rois:
            return pd.DataFrame({})

        # Values that do not depend on the ROI: fetch them once.
        reconstructions = self.get_scan_reconstructions()
        timestamps = self.get_timestamps()
        wavelengths = self.get_wavelengths()
        mask_images = self.get_scan_so2() if return_masks else None

        rows = []
        for roi_key, roi in rois.items():
            row = {}
            for metric, data in measurements:
                if not data:
                    print(f"Skipping metric {metric}")
                    continue
                mask, data_slice = roi.to_mask_slice(data)
                values = data_slice.raw_data.T[mask.T].T
                if metric_limits is not None and metric in metric_limits:
                    lower, upper = metric_limits[metric]
                    values[values < lower] = np.nan
                    values[values > upper] = np.nan
                row[metric] = values

            mask, _, selection = roi.to_mask_slice(
                reconstructions, return_selection=True
            )
            if return_masks:
                return_mask, _ = roi.to_mask_slice(mask_images)
                row["Mask"] = return_mask
            row["Timings"] = timestamps[selection]
            row["Area"] = np.sum(mask)
            for attribute in roi.attributes:
                row[attribute] = roi.attributes[attribute]
            row["number"] = roi_key[1]
            row["name"] = roi_key
            row["Wavelengths"] = wavelengths
            rows.append(row)

        output_table = pd.DataFrame(rows)
        summary_methods = {"mean": np.nanmean, "median": np.nanmedian, "std": np.nanstd}
        for summary_name, method in summary_methods.items():
            for metric in metrics:
                if metric in output_table.columns:
                    output_table[metric + "_" + summary_name] = output_table[
                        metric
                    ].apply(lambda t, method=method: np.squeeze(method(t, axis=-1))[()])
        if just_summary:
            for metric in metrics:
                if metric in output_table.columns:
                    del output_table[metric]
        return output_table