# Copyright (c) Thomas Else 2023.
# License: MIT
import logging
from abc import ABC, abstractmethod
from typing import Sequence, Tuple, Optional, List
import numpy as np
from ..core.image_structures.pa_time_data import PATimeSeries
from ..core.image_structures.reconstruction_image import Reconstruction
from ..io.attribute_tags import ReconAttributeTags
from ..io.msot_data import PAData, HDF5Tags
from ..processing.processing_algorithm import TimeSeriesProcessingAlgorithm, ProcessingResult
class ReconstructionAlgorithm(TimeSeriesProcessingAlgorithm, ABC):
    """Abstract base class for image reconstruction algorithms.

    Subclasses implement :meth:`reconstruct`. :meth:`run` gathers the
    acquisition parameters, optionally splits the time series into batches,
    calls :meth:`reconstruct` and wraps the result in a ``Reconstruction``.
    """

    def pre_prepare_data(self, x):
        """Return ``x`` unchanged."""
        return x

    def __init__(self,
                 n_pixels: Sequence[int],
                 field_of_view: Sequence[float],
                 **kwargs):
        """Store the reconstruction grid settings.

        Parameters
        ----------
        n_pixels
            Number of pixels per axis. :meth:`run` reads indices 0, 1 and 2.
        field_of_view
            Field of view per axis. :meth:`run` reads indices 0, 1 and 2.
            Each entry is either a scalar, or a pair of edges (tuple, list or
            array) describing an off-centre field of view.
        **kwargs
            Extra keyword arguments forwarded to :meth:`reconstruct` on every
            call to :meth:`run`.
        """
        super().__init__()
        self.n_pixels = n_pixels
        self.field_of_view = field_of_view
        self.custom_params = kwargs
        # When True, run() may split large inputs into batches.
        self._batch = True
        # Extra attributes copied onto every Reconstruction produced by run().
        self.attributes = {}

    @abstractmethod
    def reconstruct(self, raw_data: np.ndarray,
                    fs: float,
                    geometry: np.ndarray,
                    n_pixels: Sequence[int],
                    field_of_view: Sequence[float],
                    speed_of_sound,
                    **kwargs) -> np.ndarray:
        """Reconstruct images from raw time series data.

        Parameters
        ----------
        raw_data
            Time series data to reconstruct. When :meth:`run` batches, this
            is a slice along the first axis of the flattened data.
        fs
            The ``"fs"`` attribute of the time series (presumably the
            sampling frequency -- confirm against the data format).
        geometry
            Scan geometry, re-centred by :meth:`run` when the field of view
            is off-centre.
        n_pixels
            Number of pixels per axis.
        field_of_view
            Extent of the field of view per axis.
        speed_of_sound
            Speed of sound to use for the reconstruction.
        **kwargs
            ``irf`` (the impulse response, or None) plus any extra
            parameters given to the constructor or to :meth:`run`.

        Returns
        -------
        The reconstructed data. The first axis must match the first axis of
        ``raw_data`` so that batched results can be concatenated.
        """
        pass

    @staticmethod
    def get_algorithm_name() -> str:
        """Return the name of the algorithm.

        The name is used as the HDF5 sub-name and as the algorithm attribute
        of the output of :meth:`run`. This method is not marked abstract, so
        the base implementation returns None; subclasses are expected to
        override it.
        """
        pass

    def run(self, time_series: PATimeSeries,
            pa_data: PAData,
            speed_of_sound=None,
            geometry=None,
            **kwargs) -> Tuple[Reconstruction, dict, Optional[List[ProcessingResult]]]:
        """Reconstruct ``time_series`` and wrap the result.

        Parameters
        ----------
        time_series
            The time series to reconstruct.
        pa_data
            Source of the speed of sound, scan geometry, impulse response and
            wavelengths. May be None, in which case ``speed_of_sound`` and
            ``geometry`` are used as given and the wavelengths are taken from
            the time series coordinates.
        speed_of_sound
            Overrides the speed of sound from ``pa_data`` when not None.
        geometry
            Overrides the scan geometry from ``pa_data`` when not None.
        **kwargs
            Extra keyword arguments forwarded to :meth:`reconstruct`. They
            take precedence over constructor keyword arguments of the same
            name.

        Returns
        -------
        A tuple ``(reconstruction, {}, None)``.
        """
        from .. import PAT_MAXIMUM_BATCH_SIZE

        if pa_data is not None:
            if speed_of_sound is None:
                speed_of_sound = pa_data.get_speed_of_sound()
            if geometry is None:
                geometry = pa_data.get_scan_geometry()

        logging.debug(f"{time_series.attributes}, {geometry.shape if geometry is not None else None}, "
                      f"{time_series.raw_data.shape}, {speed_of_sound}")

        irf = pa_data.get_impulse_response() if pa_data is not None else None
        wavelengths = time_series.da.coords.get("wavelengths") if pa_data is None else pa_data.get_wavelengths()

        if isinstance(self.field_of_view[0], (tuple, list, np.ndarray)):
            # Off-centre field of view: each axis is given as a pair of edges.
            # Reconstruct on a centred grid of the same extent and shift the
            # geometry by the centre of each axis instead.
            fov = [abs(edge_a - edge_b) for (edge_a, edge_b) in self.field_of_view]
            # Copy so that the caller's geometry array is not modified.
            geometry = np.array(geometry)
            for axis, extent in enumerate(self.field_of_view):
                geometry[:, axis] -= np.mean(extent)
        else:
            fov = self.field_of_view

        # Merge into one dict so a parameter supplied both at construction and
        # at call time does not raise a duplicate-keyword TypeError. The
        # call-time value wins.
        recon_kwargs = {**self.custom_params, **kwargs}

        fs = time_series.attributes["fs"]
        batch_size = PAT_MAXIMUM_BATCH_SIZE
        n_frames = time_series.shape[0] * time_series.shape[1]

        # Process in batches to avoid GPU running out of memory.
        # A batch size of -1 (or any other non-positive value) disables batching.
        if self._batch and batch_size > 0 and n_frames > batch_size:
            ts_raw = time_series.raw_data
            shape = ts_raw.shape
            # Flatten all leading axes so batches are slices of one axis.
            ts_raw = ts_raw.reshape((-1,) + shape[-2:])
            new_recons = []
            for start in range(0, ts_raw.shape[0], batch_size):
                raw = self.reconstruct(ts_raw[start:start + batch_size], fs,
                                       geometry, self.n_pixels, fov, speed_of_sound,
                                       irf=irf, **recon_kwargs)
                new_recons.append(np.asarray(raw))
            # Restore the two leading axes of the input on the stacked output.
            raw_data = np.concatenate(new_recons, axis=0).reshape(shape[:2] + new_recons[0].shape[1:])
        else:
            raw_data = self.reconstruct(time_series.raw_data,
                                        fs,
                                        geometry,
                                        self.n_pixels,
                                        fov,
                                        speed_of_sound,
                                        irf=irf,
                                        **recon_kwargs)

        output_data = Reconstruction(raw_data, wavelengths,
                                     hdf5_sub_name=self.get_algorithm_name(),
                                     field_of_view=self.field_of_view)

        output_data.attributes[HDF5Tags.SPEED_OF_SOUND] = speed_of_sound
        output_data.attributes[ReconAttributeTags.RECONSTRUCTION_ALGORITHM] = self.get_algorithm_name()
        output_data.attributes[ReconAttributeTags.X_NUMBER_OF_PIXELS] = self.n_pixels[0]
        output_data.attributes[ReconAttributeTags.Y_NUMBER_OF_PIXELS] = self.n_pixels[1]
        output_data.attributes[ReconAttributeTags.Z_NUMBER_OF_PIXELS] = self.n_pixels[2]
        output_data.attributes[ReconAttributeTags.X_FIELD_OF_VIEW] = self.field_of_view[0]
        output_data.attributes[ReconAttributeTags.Y_FIELD_OF_VIEW] = self.field_of_view[1]
        output_data.attributes[ReconAttributeTags.Z_FIELD_OF_VIEW] = self.field_of_view[2]
        # Only the call-time keyword arguments are recorded here.
        output_data.attributes[ReconAttributeTags.ADDITIONAL_PARAMETERS] = kwargs
        for a in self.attributes:
            output_data.attributes[a] = self.attributes[a]
        output_data.attributes[HDF5Tags.WAVELENGTH] = wavelengths
        # Copied last, so time-series attributes overwrite any attribute of
        # the same name set above.
        for a in time_series.attributes:
            output_data.attributes[a] = time_series.attributes[a]
        return output_data, {}, None