# Source code for patato.recon.reconstruction_algorithm

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

import logging
from abc import ABC, abstractmethod
from typing import Sequence, Tuple, Optional, List

import numpy as np

from ..core.image_structures.pa_time_data import PATimeSeries
from ..core.image_structures.reconstruction_image import Reconstruction
from ..io.attribute_tags import ReconAttributeTags
from ..io.msot_data import PAData, HDF5Tags
from ..processing.processing_algorithm import (
    TimeSeriesProcessingAlgorithm,
    ProcessingResult,
)


class ReconstructionAlgorithm(TimeSeriesProcessingAlgorithm, ABC):
    """Base class for reconstruction algorithms operating on time series data.

    Subclasses implement :meth:`reconstruct` and override
    :meth:`get_algorithm_name`. :meth:`run` gathers metadata (speed of sound,
    scan geometry, impulse response, wavelengths) from a :class:`PAData`
    object, handles off-centre fields of view, splits large inputs into
    batches, and packages the result as a :class:`Reconstruction`.
    """

    def pre_prepare_data(self, x):
        """Identity hook: return ``x`` unchanged (subclasses may override)."""
        return x

    def __init__(
        self, n_pixels: Sequence[int], field_of_view: Sequence[float], **kwargs
    ):
        """Store the reconstruction grid settings.

        Parameters
        ----------
        n_pixels
            Number of pixels along x, y and z (:meth:`run` reads three entries).
        field_of_view
            Extent along each axis. Each entry is either a single width, or a
            pair of bounds describing an off-centre region (see :meth:`run`).
        **kwargs
            Extra keyword arguments forwarded to every :meth:`reconstruct` call.
        """
        super().__init__()
        self.n_pixels = n_pixels
        self.field_of_view = field_of_view
        self.custom_params = kwargs
        # When False, run() always reconstructs all frames in a single call.
        self._batch = True
        # Extra attributes copied onto every Reconstruction produced by run().
        self.attributes = {}

    @abstractmethod
    def reconstruct(
        self,
        raw_data: np.ndarray,
        fs: float,
        geometry: np.ndarray,
        n_pixels: Sequence[int],
        field_of_view: Sequence[float],
        speed_of_sound,
        **kwargs,
    ) -> np.ndarray:
        """Reconstruct images from raw time series data.

        Parameters
        ----------
        raw_data
            Time series data. :meth:`run` passes either the full array or a
            batch flattened to ``(frames,) + raw_data.shape[-2:]``.
        fs
            The ``"fs"`` attribute of the time series (presumably the sampling
            frequency - TODO confirm units).
        geometry
            Scan geometry, shifted when an off-centre field of view is used.
        n_pixels, field_of_view
            Output grid size and extent (widths only, never bound pairs).
        speed_of_sound
            Speed of sound used for the reconstruction.
        **kwargs
            Includes ``irf`` (impulse response or ``None``), the keyword
            arguments given to :meth:`run`, and the constructor's extra kwargs.

        Returns
        -------
        np.ndarray
            Reconstructed images with one leading entry per input frame.
        """
        pass

    @staticmethod
    def get_algorithm_name() -> str:
        """Return the algorithm's name.

        Used as the HDF5 sub-name and the ``RECONSTRUCTION_ALGORITHM``
        attribute of the output. The base implementation returns ``None``.
        """
        pass

    def run(
        self,
        time_series: PATimeSeries,
        pa_data: PAData,
        speed_of_sound=None,
        geometry=None,
        **kwargs,
    ) -> Tuple[Reconstruction, dict, Optional[List[ProcessingResult]]]:
        """Reconstruct ``time_series`` and wrap the result with metadata.

        Parameters
        ----------
        time_series
            Input time series. Its last two axes form one frame; all leading
            axes are treated as independent frames.
        pa_data
            Optional data source used to fill in the speed of sound, scan
            geometry, impulse response and wavelengths. If ``None``, the
            wavelengths are taken from ``time_series.da.coords``.
        speed_of_sound
            Overrides ``pa_data.get_speed_of_sound()`` when given.
        geometry
            Overrides ``pa_data.get_scan_geometry()`` when given.
        **kwargs
            Forwarded to :meth:`reconstruct` and stored under
            ``ADDITIONAL_PARAMETERS``.

        Returns
        -------
        tuple
            ``(reconstruction, {}, None)``.
        """
        from .. import PAT_MAXIMUM_BATCH_SIZE

        if speed_of_sound is None and pa_data is not None:
            speed_of_sound = pa_data.get_speed_of_sound()
        if geometry is None and pa_data is not None:
            geometry = pa_data.get_scan_geometry()

        ts_raw = time_series.raw_data
        logging.debug(
            "%s, %s, %s, %s",
            time_series.attributes,
            geometry.shape if geometry is not None else None,
            ts_raw.shape,
            speed_of_sound,
        )

        irf = pa_data.get_impulse_response() if pa_data is not None else None
        if pa_data is None:
            wavelengths = time_series.da.coords.get("wavelengths")
        else:
            wavelengths = pa_data.get_wavelengths()

        if isinstance(self.field_of_view[0], (tuple, list, np.ndarray)):
            # Off-centre field of view: each axis is a pair of bounds.
            # Reconstruct over the widths, and shift a copy of the geometry so
            # the centre of the requested region lies at the origin.
            fov = [abs(x1 - x0) for (x1, x0) in self.field_of_view]
            if geometry is not None:
                geometry = np.array(geometry)
                for axis, bounds in enumerate(self.field_of_view):
                    geometry[:, axis] -= np.mean(bounds)
        else:
            fov = self.field_of_view

        fs = time_series.attributes["fs"]

        def call_reconstruct(data):
            # Every call shares identical settings; only the data differs.
            return self.reconstruct(
                data,
                fs,
                geometry,
                self.n_pixels,
                fov,
                speed_of_sound,
                irf=irf,
                **kwargs,
                **self.custom_params,
            )

        # Process in batches to avoid the GPU running out of memory. All axes
        # except the last two are flattened into a single frame axis, so the
        # frame count and the final reshape must both use shape[:-2].
        frame_shape = tuple(ts_raw.shape[:-2])
        n_frames = int(np.prod(frame_shape))
        if (
            self._batch
            and PAT_MAXIMUM_BATCH_SIZE != -1
            and n_frames > PAT_MAXIMUM_BATCH_SIZE
        ):
            flat = ts_raw.reshape((-1,) + tuple(ts_raw.shape[-2:]))
            batches = [
                np.asarray(call_reconstruct(flat[i : i + PAT_MAXIMUM_BATCH_SIZE]))
                for i in range(0, flat.shape[0], PAT_MAXIMUM_BATCH_SIZE)
            ]
            raw_data = np.concatenate(batches, axis=0).reshape(
                frame_shape + batches[0].shape[1:]
            )
        else:
            raw_data = call_reconstruct(ts_raw)

        algorithm_name = self.get_algorithm_name()
        output_data = Reconstruction(
            raw_data,
            wavelengths,
            hdf5_sub_name=algorithm_name,
            field_of_view=self.field_of_view,
        )
        attrs = output_data.attributes
        attrs[HDF5Tags.SPEED_OF_SOUND] = speed_of_sound
        attrs[ReconAttributeTags.RECONSTRUCTION_ALGORITHM] = algorithm_name
        pixel_tags = (
            ReconAttributeTags.X_NUMBER_OF_PIXELS,
            ReconAttributeTags.Y_NUMBER_OF_PIXELS,
            ReconAttributeTags.Z_NUMBER_OF_PIXELS,
        )
        for axis, tag in enumerate(pixel_tags):
            attrs[tag] = self.n_pixels[axis]
        fov_tags = (
            ReconAttributeTags.X_FIELD_OF_VIEW,
            ReconAttributeTags.Y_FIELD_OF_VIEW,
            ReconAttributeTags.Z_FIELD_OF_VIEW,
        )
        for axis, tag in enumerate(fov_tags):
            attrs[tag] = self.field_of_view[axis]
        attrs[ReconAttributeTags.ADDITIONAL_PARAMETERS] = kwargs
        # Later writes win: algorithm attributes, then wavelengths, then the
        # input time series' own attributes.
        for key in self.attributes:
            attrs[key] = self.attributes[key]
        attrs[HDF5Tags.WAVELENGTH] = wavelengths
        for key in time_series.attributes:
            attrs[key] = time_series.attributes[key]
        return output_data, {}, None