# Source code for patato.io.hdf.fileimporter

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

import copy
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from typing import Union, Tuple

import dask.array as da
import h5py
import numpy as np

from ...core.image_structures.pa_time_data import PATimeSeries
from ...utils.mask_operations import interpolate_rois


def slice_1d(data, test_data, slices, dim=-1):
    """Apply a sequence of index expressions to values along one axis.

    ``data`` holds values along a single axis of a larger reference array
    ``test_data``. Each item in ``slices`` is an index expression for the
    full reference array; only the component of the item that addresses the
    tracked axis (``test_data.ndim + dim``) is applied to ``data``.

    Parameters
    ----------
    data : array-like or None
        Values along the tracked axis. ``None`` is returned unchanged.
    test_data : array-like
        Reference array whose shape ``slices`` index. It is sliced alongside
        ``data`` so the dimensionality seen by later items stays correct.
    slices : sequence
        Index expressions, applied in order.
    dim : int, optional
        Offset from ``test_data.ndim`` that gives the tracked axis. The
        default ``-1`` tracks the last axis. ``0`` points one past the last
        axis, so ``data`` is left untouched by ordinary index tuples (length
        <= ndim).

    Returns
    -------
    array-like or None
        ``data`` restricted to the positions selected on the tracked axis.
    """
    if data is None:
        return None

    result = data
    for item in slices:
        # Resolve the tracked axis against the shape *before* this item is
        # applied: the item's components address the pre-slice axes, and
        # integer indices may remove dimensions. For example, with a
        # (frames, wavelengths) reference array, ``[0]`` keeps every
        # wavelength while ``[:, 0]`` keeps only the first.
        axis = test_data.ndim + dim
        if not isinstance(item, tuple):
            # A bare (non-tuple) index only addresses axis 0.
            if axis == 0:
                result = result[item]
        elif len(item) > axis:
            result = result[item[axis]]
        test_data = test_data[item]
    return result
class ReaderInterface(metaclass=ABCMeta):
    """Abstract base class for scan-file readers.

    Subclasses implement the abstract ``_get_*`` hooks and a few abstract
    public getters. The public ``get_*`` wrappers apply every index
    expression accumulated through ``reader[item]`` (stored in
    ``self.slices``). Some wrappers also honour overrides assigned through
    the ``raw_data``, ``sampling_frequency`` and ``scan_geometry`` properties.
    """

    def __init__(self):
        # Overrides assigned via the property setters; ``None`` means "use
        # whatever the concrete reader returns".
        self._raw_time_series_data = None
        self._sampling_frequency = None
        self._scan_geometry = None
        # Index expressions applied, in order, by the public getters.
        self.slices = []

    def _apply_slices(self, data):
        """Index ``data`` successively with every stored slice.

        ``None`` (a field the reader does not provide) is returned unchanged,
        consistent with how ``slice_1d`` treats missing data.
        """
        if data is None:
            return None
        for item in self.slices:
            data = data[item]
        return data

    def close(self):
        """Release any resources held by the reader (no-op by default)."""
        pass

    def is_clinical(self):
        """Return True when no scanner z-positions were recorded.

        The scan counts as clinical when every z-position is NaN or every
        z-position is exactly zero.
        """
        z_positions = self.get_scanner_z_position()[:]
        return np.all(np.isnan(z_positions)) or np.all(0.0 == z_positions)

    def save_to_hdf5(self, filename):
        """Write this reader's contents to ``filename`` using ``HDF5Writer``.

        The file is opened in append mode and is always closed afterwards,
        even if writing fails.

        Returns
        -------
        Whatever ``HDF5Writer.save_file`` returns.
        """
        from ..hdf.hdf5_interface import HDF5Writer

        with h5py.File(filename, "a") as file:
            writer = HDF5Writer(file)
            return writer.save_file(self)

    def __getitem__(self, item):
        """Return a shallow copy of this reader with ``item`` added to its slices.

        ``item`` is first tried against the (unsliced) run-number array as a
        sanity check. Errors other than ``KeyError`` (e.g. numpy's
        ``IndexError`` for out-of-range integers) propagate unchanged.

        Raises
        ------
        KeyError
            If indexing the run numbers with ``item`` raises ``KeyError``.
        """
        sliced = copy.copy(self)
        # Give the copy its own slice list so the original is not affected.
        sliced.slices = copy.deepcopy(sliced.slices)
        try:
            sliced._get_run_numbers()[item]
        except KeyError as err:
            raise KeyError("Invalid slice") from err
        sliced.slices.append(item)
        return sliced

    @abstractmethod
    def _get_rois(self):
        pass

    def get_rois(self, interpolate=False):
        """Return the ROI mapping from ``_get_rois``.

        Parameters
        ----------
        interpolate : bool
            If True, ROIs are grouped by the first element of their key and
            each group is passed to ``interpolate_rois`` together with the
            scanner z-positions. The results are stored under
            ``(roi.roi_class + "_" + roi.position, str(index))``.
        """
        output = self._get_rois()
        if interpolate:
            groups = defaultdict(list)
            for (name, _number), roi in output.items():
                groups[name].append(roi)
            # Same for every group, so read once.
            z_positions = self.get_scanner_z_position()
            for rois in groups.values():
                interpolated_rois = interpolate_rois(rois, z_positions)
                # NOTE(review): this writes into the mapping returned by
                # _get_rois; if that mapping is cached by the subclass the
                # interpolated entries persist across calls — confirm.
                for i, roi in enumerate(interpolated_rois):
                    output[(roi.roi_class + "_" + roi.position, str(i))] = roi
        return output

    @abstractmethod
    def _get_segmentation(self):
        pass

    def get_segmentation(self):
        """Return the segmentation, filtered through ``slice_1d`` with ``dim=0``."""
        run_numbers = self._get_run_numbers()
        segmentation = self._get_segmentation()
        return slice_1d(segmentation, run_numbers, self.slices, 0)

    @abstractmethod
    def get_scan_datetime(self):
        pass

    @property
    def raw_data(self):
        """Photoacoustic time series; equivalent to ``get_pa_data()``."""
        return self.get_pa_data()

    @raw_data.setter
    def raw_data(self, x):
        # Replaces the dataset used by get_pa_data (attributes still come
        # from _get_pa_data).
        self._raw_time_series_data = x

    @property
    def sampling_frequency(self):
        """Sampling frequency; equivalent to ``get_sampling_frequency()``."""
        return self.get_sampling_frequency()

    @sampling_frequency.setter
    def sampling_frequency(self, x):
        self._sampling_frequency = x

    @property
    def scan_geometry(self):
        """Sensor geometry; equivalent to ``get_sensor_geometry()``."""
        return self.get_sensor_geometry()

    @scan_geometry.setter
    def scan_geometry(self, x):
        self._scan_geometry = x

    def get_pa_data(self):
        """Return the time-series data wrapped in a ``PATimeSeries``.

        The dataset is indexed on four axes labelled ``frames``, the
        ``PATimeSeries`` axis-1 label, ``detectors`` and ``timeseries``. It is
        wrapped in a dask array chunked one frame at a time, and the stored
        slices are then applied to the resulting object.
        """
        dataset, attributes = self._get_pa_data()
        if self._raw_time_series_data is not None:
            dataset = self._raw_time_series_data
        wavelengths = self._get_wavelengths()

        cls = PATimeSeries
        dims = ["frames", cls.get_ax1_label_meaning(), "detectors", "timeseries"]
        dim_coords = [
            np.arange(dataset.shape[0]),
            wavelengths,
            np.arange(dataset.shape[2]),
            np.arange(dataset.shape[3]),
        ]
        coordinates = dict(zip(dims, dim_coords))
        new_cls = cls(
            da.from_array(dataset, chunks=(1,) + dataset.shape[1:]),
            dims,
            coordinates,
            attributes,
        )
        if not cls.is_single_instance():
            # NOTE(review): requires ``dataset`` to have an HDF5-style ``name``
            # path; a plain array assigned via ``raw_data`` would not.
            new_cls.hdf5_sub_name = dataset.name.split("/")[-2]
        return self._apply_slices(new_cls)

    @abstractmethod
    def _get_pa_data(self):
        pass

    @abstractmethod
    def get_scan_name(self):
        pass

    @abstractmethod
    def _get_temperature(self):
        pass

    def get_temperature(self):
        """Return the temperature data with the stored slices applied."""
        return self._apply_slices(self._get_temperature())

    @abstractmethod
    def _get_correction_factor(self):
        pass

    def get_correction_factor(self):
        """Return the correction factors with the stored slices applied."""
        return self._apply_slices(self._get_correction_factor())

    @abstractmethod
    def _get_scanner_z_position(self):
        pass

    def get_scanner_z_position(self):
        """Return the scanner z-positions with the stored slices applied."""
        return self._apply_slices(self._get_scanner_z_position())

    @abstractmethod
    def _get_run_numbers(self):
        pass

    def get_run_numbers(self):
        """Return the run numbers with the stored slices applied."""
        return self._apply_slices(self._get_run_numbers())

    @abstractmethod
    def _get_repetition_numbers(self):
        pass

    def get_repetition_numbers(self):
        """Return the repetition numbers with the stored slices applied."""
        return self._apply_slices(self._get_repetition_numbers())

    @abstractmethod
    def _get_scan_times(self):
        pass

    def get_scan_times(self):
        """Return the scan times with the stored slices applied."""
        return self._apply_slices(self._get_scan_times())

    @abstractmethod
    def _get_sensor_geometry(self):
        pass

    def get_sensor_geometry(self):
        """Return the ``scan_geometry`` override if set, else the reader's value."""
        if self._scan_geometry is not None:
            return self._scan_geometry
        return self._get_sensor_geometry()

    @abstractmethod
    def get_impulse_response(self):
        pass

    @abstractmethod
    def _get_wavelengths(self):
        pass

    def get_wavelengths(self):
        """Return the wavelengths restricted by the stored slices.

        Slicing follows the last axis of the run-number array (``dim=-1``).
        """
        run_numbers = self._get_run_numbers()
        wavelengths = self._get_wavelengths()
        return slice_1d(wavelengths, run_numbers, self.slices, -1)

    @abstractmethod
    def _get_water_absorption(self):
        pass

    def get_water_absorption(self):
        """Return ``(water_absorption, pathlength)``.

        Only the absorption values are restricted by the stored slices
        (along the last run-number axis); the path length is passed through.
        """
        absorption, pathlength = self._get_water_absorption()
        run_numbers = self._get_run_numbers()
        absorption = slice_1d(absorption, run_numbers, self.slices, -1)
        return absorption, pathlength

    @abstractmethod
    def _get_datasets(self):
        # Make this return an image sequence type
        pass

    def get_datasets(self):
        """Return the nested image mapping with the stored slices applied.

        The mapping from ``_get_datasets`` is iterated as
        ``{dataset_type: {reconstruction_type: image}}``. Each truthy image is
        replaced, in place, by the image indexed with every stored slice.
        """
        all_datasets = self._get_datasets()
        if all_datasets is not None:
            for item in self.slices:
                for group in all_datasets.values():
                    for key in group:
                        if group[key]:
                            group[key] = group[key][item]
        return all_datasets

    @abstractmethod
    def get_scan_comment(self):
        pass

    @abstractmethod
    def _get_sampling_frequency(self):
        pass

    def get_sampling_frequency(self):
        """Return the ``sampling_frequency`` override if set, else the reader's value."""
        if self._sampling_frequency is not None:
            return self._sampling_frequency
        return self._get_sampling_frequency()

    @abstractmethod
    def get_speed_of_sound(self):
        pass
class WriterInterface(metaclass=ABCMeta):
    """Abstract base class for writers that persist data taken from a reader.

    ``save_file`` pulls each field from a ``ReaderInterface`` through its
    public getters and hands it to the matching abstract setter, which
    concrete writers implement.
    """

    def close(self):
        """Release any resources held by the writer (no-op by default)."""
        pass

    def save_file(self, reader: "ReaderInterface"):
        """Copy every supported field from ``reader`` into this writer.

        The segmentation is written only when the reader returns one. Images
        from ``reader.get_datasets()`` are added group by group, ordered by
        the integer value of the second element of each key.
        """
        # TODO: implement updating.
        # Each getter is queried once: ReaderInterface.get_datasets re-slices
        # the mapping from _get_datasets in place, so calling it twice could
        # index the same images twice if that mapping is shared.
        segmentation = reader.get_segmentation()
        if segmentation is not None:
            self.set_segmentation(segmentation)
        self.set_scan_datetime(reader.get_scan_datetime())
        self.set_pa_data(reader.get_pa_data())
        self.set_scan_name(reader.get_scan_name())
        self.set_temperature(reader.get_temperature())
        self.set_correction_factor(reader.get_correction_factor())
        self.set_scanner_z_position(reader.get_scanner_z_position())
        self.set_run_numbers(reader.get_run_numbers())
        self.set_repetition_numbers(reader.get_repetition_numbers())
        self.set_scan_times(reader.get_scan_times())
        self.set_sensor_geometry(reader.get_sensor_geometry())
        self.set_impulse_response(reader.get_impulse_response())
        self.set_wavelengths(reader.get_wavelengths())
        self.set_water_absorption(*reader.get_water_absorption())

        datasets = reader.get_datasets()
        if datasets is not None:
            for image_group in datasets.values():
                for key in sorted(image_group, key=lambda k: int(k[1])):
                    self.add_image(image_group[key])

        self.set_scan_comment(reader.get_scan_comment())
        self.set_sampling_frequency(reader.get_sampling_frequency())
        # NOTE(review): set_speed_of_sound and add_roi belong to this
        # interface but are not called here, so speed of sound and ROIs are
        # not copied by save_file — confirm this is intended.

    @abstractmethod
    def set_segmentation(self, seg):
        pass

    @abstractmethod
    def set_scan_datetime(self, datetime):
        pass

    @abstractmethod
    def set_pa_data(self, pa_data: "PATimeSeries"):
        pass

    @abstractmethod
    def set_scan_name(self, scan_name: str):
        pass

    @abstractmethod
    def set_temperature(self, temperature: "np.ndarray"):
        pass

    @abstractmethod
    def set_correction_factor(self, correction_factor):
        pass

    @abstractmethod
    def set_scanner_z_position(self, z_position):
        pass

    @abstractmethod
    def set_run_numbers(self, run_numbers):
        pass

    @abstractmethod
    def set_repetition_numbers(self, repetition_numbers):
        pass

    @abstractmethod
    def set_scan_times(self, scan_times):
        pass

    @abstractmethod
    def set_sensor_geometry(self, sensor_geometry):
        pass

    @abstractmethod
    def set_impulse_response(self, impulse_response):
        pass

    @abstractmethod
    def set_wavelengths(self, wavelengths):
        pass

    @abstractmethod
    def set_water_absorption(self, water_absorption, pathlength):
        pass

    @abstractmethod
    def add_image(self, image):
        pass

    @abstractmethod
    def delete_images(self, image):
        pass

    @abstractmethod
    def set_scan_comment(self, comment: str):
        pass

    @abstractmethod
    def set_sampling_frequency(self, frequency: float):
        pass

    @abstractmethod
    def set_speed_of_sound(self, c: float):
        pass

    @abstractmethod
    def add_roi(self, roi_data, generated: bool = False):
        pass

    @abstractmethod
    def rename_roi(
        self, old_name: Union[str, Tuple], new_name: str, new_position: str
    ) -> None:
        """
        Rename a region of interest.

        Parameters
        ----------
        old_name : str or tuple
            Old roi name e.g. "tumour_left/0" or ("tumour_left", "0")
        new_name : str
            New roi name e.g. "brain"
        new_position : str
            New roi position e.g. "left"
        """
        pass

    @abstractmethod
    def delete_rois(self, name_position=None, number=None):
        pass

    @abstractmethod
    def delete_recons(self, name, recon_groups):
        pass

    @abstractmethod
    def delete_dso2s(self) -> None:
        pass