Source code for patato.processing.jax_preprocessing_algorithm

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

from functools import partial
from typing import Union, Tuple, Optional, Dict

import numpy as np
from patato.io.attribute_tags import PreprocessingAttributeTags
from scipy.signal.windows import hann

from .processing_algorithm import TimeSeriesProcessingAlgorithm

try:
    import jax.numpy as jnp
    import jax
except ImportError:
    jnp = None
    jax = None

import warnings

from ..core.image_structures.pa_time_data import PATimeSeries

# Array annotation used by the signatures in this module; an alias of
# numpy.typing.NDArray.
Array = np.typing.NDArray


@jax.jit
def subtract_mean(time_series):
    """
    Remove the mean along the last axis of an array.

    Parameters
    ----------
    time_series : array
        Input array. The mean is computed over its final axis.

    Returns
    -------
    array
        Array with the same shape as the input, where the last-axis mean has
        been subtracted from every element along that axis.
    """
    # keepdims leaves a trailing length-1 axis so the mean broadcasts back
    # over the axis it was computed along.
    offset = jnp.mean(time_series, axis=-1, keepdims=True)
    return time_series - offset
@partial(jax.jit, static_argnums=(1,))
def interpolate_detectors(detectors, ndet):
    """
    Linearly upsample an array along its first axis.

    Parameters
    ----------
    detectors : array
        Array whose first axis is resampled. Each slice along the last axis
        is interpolated independently.
    ndet : int
        Upsampling factor. The output has ``ndet * detectors.shape[0]``
        entries along the first axis. Static under ``jax.jit``.

    Returns
    -------
    array
        The interpolated array, sampled at evenly spaced positions between
        the first and last original index (both inclusive).
    """
    n_original = detectors.shape[0]
    source_index = jnp.arange(n_original)
    target_index = jnp.linspace(0, n_original - 1, ndet * n_original)

    # jnp.interp is one-dimensional, so map it over the last axis.
    resample_columns = jax.vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)
    return resample_columns(target_index, source_index, detectors)
@partial(jax.jit, static_argnums=(1, 2))
def partial_interpolate(time_series: Array, nt: int, ndet: int) -> Array:
    """
    Linearly upsample an array along its last two axes.

    The second-to-last axis is resampled first, then the last axis.

    Parameters
    ----------
    time_series : Array
        Input array; the code indexes it as ``(..., detectors, samples)`` and
        the ``vmap`` calls below treat it as two-dimensional.
    nt : int
        Upsampling factor for the last axis. Static under ``jax.jit``.
    ndet : int
        Upsampling factor for the second-to-last axis. Static under
        ``jax.jit``.

    Returns
    -------
    Array
        Interpolated array with ``ndet`` times as many entries along the
        second-to-last axis and ``nt`` times as many along the last axis.
    """
    n_detectors = time_series.shape[-2]
    n_times = time_series.shape[-1]

    # Detector grid includes both end points.
    source_detectors = jnp.arange(n_detectors)
    target_detectors = jnp.linspace(0, n_detectors - 1, ndet * n_detectors)

    # Time grid: generate one extra point and drop it, so the final original
    # sample itself is not part of the new grid.
    source_times = jnp.arange(n_times)
    target_times = jnp.linspace(0, n_times - 1, nt * n_times + 1)[:-1]

    # Resample across detectors, independently for every time sample.
    across_detectors = jax.vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)
    resampled = across_detectors(target_detectors, source_detectors, time_series)

    # Resample across time, independently for every (new) detector.
    across_time = jax.vmap(jnp.interp, in_axes=(None, None, 0), out_axes=0)
    return across_time(target_times, source_times, resampled)
def make_filter(
    n_samples: int,
    fs: float,
    irf: Array,
    hilbert: bool,
    lp_filter: Optional[float],
    hp_filter: Optional[float],
    rise: float = 0.2,
    n_filter: int = 1024,
    window: Optional[str] = None,
):
    """
    Build the frequency-domain filter applied to each time series.

    The result is the product of up to three parts: an impulse-response
    deconvolution, a one-sided (Hilbert-style) frequency mask, and a band-pass
    response designed on ``n_filter`` points and zero-padded to ``n_samples``.

    Parameters
    ----------
    n_samples : int
        Length of the returned filter.
    fs : float
        Sampling frequency, used to build the frequency axis that
        ``lp_filter`` and ``hp_filter`` are compared against.
    irf : Array or None
        Impulse response to deconvolve; skipped when ``None``.
        NOTE(review): presumably of length ``n_samples`` so that its spectrum
        multiplies the filter element-wise - confirm against callers.
    hilbert : bool
        If True, multiply by ``(1 + sign(f)) / 2``: negative frequencies are
        zeroed, positive ones kept, and the zero-frequency bin is halved.
    lp_filter : float or None
        Low-pass corner frequency; ``None`` disables the low-pass.
    hp_filter : float or None
        High-pass corner frequency; ``None`` disables the high-pass.
    rise : float
        Width of the linear transition band as a fraction of the corner
        frequency.
    n_filter : int
        Number of points on which the band-pass response is designed.
        NOTE(review): the zero-padding below appears to assume
        ``n_filter <= n_samples`` - confirm.
    window : str or None
        Only ``"hann"`` is recognised; it tapers the band-pass kernel in the
        time domain. Any other value applies no window.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(n_samples,)``.
    """
    response = np.ones((n_samples,), dtype=np.cdouble)

    # Impulse response correction.
    if irf is not None:
        irf_spectrum = np.fft.fft(np.fft.fftshift(irf))
        # conj(F) / |F|**2 is 1 / F, i.e. a deconvolution by the IRF.
        response *= np.conj(irf_spectrum) / np.abs(irf_spectrum) ** 2
        # Suppress high frequencies to avoid amplifying noise.
        response *= np.fft.fftshift(hann(n_samples))

    # Hilbert transform: one-sided spectrum mask.
    if hilbert:
        # TODO: check this
        signs = np.sign(np.fft.fftfreq(n_samples))
        response *= (1 + signs) / 2

    # Band-pass response, designed on its own n_filter-point frequency grid.
    band_freqs = np.abs(np.fft.fftfreq(n_filter, 1 / fs))
    band = np.ones_like(band_freqs, dtype=np.cdouble)

    if hp_filter is not None:
        hp_start = hp_filter * (1 - rise)
        band[band_freqs < hp_start] = 0
        ramp = (band_freqs > hp_start) & (band_freqs < hp_filter)
        # Linear ramp from 0 at hp_start up to 1 at hp_filter.
        band[ramp] = (band_freqs[ramp] - hp_start) / (hp_filter * rise)

    if lp_filter is not None:
        lp_stop = lp_filter * (1 + rise)
        band[band_freqs > lp_stop] = 0
        ramp = (band_freqs < lp_stop) & (band_freqs > lp_filter)
        # Linear ramp from 1 at lp_filter down to 0 at lp_stop.
        band[ramp] = 1 - (band_freqs[ramp] - lp_filter) / (lp_filter * rise)

    kernel = np.fft.ifft(band)
    if window == "hann":
        kernel *= np.fft.fftshift(hann(n_filter))

    # Zero-pad the kernel to n_samples: the leading part stays at the start,
    # the trailing part goes to the end. The tail index is floor(-n_filter / 2)
    # on purpose, so odd n_filter keeps every tap.
    head = n_filter // 2
    tail = -n_filter // 2
    padded = np.zeros_like(response)
    padded[:head] = kernel[:head]
    padded[tail:] = kernel[tail:]

    response *= np.fft.fft(padded)
    return response
class PreProcessor(TimeSeriesProcessingAlgorithm):
    """Preprocesses MSOT time series data. Uses JAX in the background."""

    @staticmethod
    def get_algorithm_name() -> Union[str, None]:
        """
        Get the name of the algorithm.

        Returns
        -------
        str or None
        """
        return "Standard Preprocessor"

    @staticmethod
    def get_hdf5_group_name() -> Union[str, None]:
        """
        Return the name of the group in the HDF5 file.

        Returns
        -------
        str or None
            Always ``None`` for this algorithm.
        """
        return None

    def __init__(
        self,
        time_factor: int = 3,
        detector_factor: int = 2,
        irf: bool = True,
        hilbert: bool = True,
        lp_filter: Optional[float] = None,
        hp_filter: Optional[float] = None,
        filter_window_size: int = 512,
        window: str = "hann",
        absolute: Optional[str] = None,
        universal_backprojection=False,
    ):
        """
        Initialise the preprocessor.

        Parameters
        ----------
        time_factor : int
            Upsampling factor along the time axis.
        detector_factor : int
            Upsampling factor along the detector axis.
        irf : bool
            Whether to apply impulse response correction.
        hilbert : bool
            Whether to apply the one-sided (Hilbert-style) frequency mask.
        lp_filter : float or None
            Low-pass corner frequency, or ``None`` for no low-pass.
        hp_filter : float or None
            High-pass corner frequency, or ``None`` for no high-pass.
        filter_window_size : int
            Stored as ``self.n_filter``.
        window : str
            Stored as ``self.window``.
        absolute : str or None
            Which part of the filtered complex signal to keep: ``"imag"``
            (imaginary part), ``"abs"`` (magnitude), anything else (real
            part). ``None`` becomes ``"imag"`` when ``hilbert`` is True.
        universal_backprojection : bool
            Whether to apply the derivative term in ``_run``.
        """
        super().__init__()
        self.time_factor = time_factor
        self.detector_factor = detector_factor
        self.hilbert = hilbert
        self.ubp = universal_backprojection
        self.irf_correct = irf
        self.lp_filter = lp_filter
        self.hp_filter = hp_filter
        self.n_filter = filter_window_size
        self.window = window
        # With the Hilbert mask enabled and no explicit choice, default to the
        # imaginary part of the filtered signal.
        self.absolute = "imag" if absolute is None and hilbert else absolute
        # Filled in lazily by pre_compute_filter().
        self.filter = None

    def pre_compute_filter(self, n_samples: int, fs: float, irf: Array = None):
        """
        Precompute the filter to be applied.

        Parameters
        ----------
        n_samples : int
            Number of time samples; the length of the filter.
        fs : float
            Sampling frequency.
        irf : Array
            Impulse response; only used when impulse response correction is
            enabled on this instance.
        """
        # NOTE(review): self.n_filter and self.window are stored by __init__
        # but are not forwarded here, so make_filter uses its own defaults for
        # n_filter and window - confirm whether that is intended.
        self.filter = jnp.array(
            make_filter(
                n_samples,
                fs,
                irf if self.irf_correct else None,
                self.hilbert,
                self.lp_filter,
                self.hp_filter,
            )
        )

    def _run(
        self, time_series: Array, detectors: Array, overall_correction_factor, **kwargs
    ):
        """
        Run the preprocessing step on a given time series and detectors.

        This allows batch processing, e.g. if the data doesn't fit into memory.
        ``self.filter`` must already have been computed.

        Parameters
        ----------
        time_series : Array
            Raw data; the last two axes are (detectors, samples), any leading
            axes are treated as batch axes.
        detectors : Array
            Detector geometry, interpolated along its first axis.
        overall_correction_factor : Array or None
            Divides the result, broadcast over the last two axes. A warning is
            issued when it is ``None``.
        kwargs : dict
            Unused.

        Returns
        -------
        tuple of Array, Array
            The processed time series and the (possibly interpolated)
            detectors.
        """
        shape = time_series.shape
        # Collapse all leading axes into one batch axis.
        time_series = jnp.array(time_series.reshape((-1,) + shape[-2:]))
        detectors = jnp.array(detectors)

        # Subtract mean
        time_series = subtract_mean(time_series)

        # Apply filters
        time_series_ft = jnp.fft.fft(time_series)
        time_series_ft = time_series_ft * self.filter.reshape(
            (1,) * (time_series.ndim - 1) + (-1,)
        )

        if self.absolute == "imag":
            op = jnp.imag
        elif self.absolute == "abs":
            # Envelope detection. run() records ENVELOPE_DETECTION as True for
            # this setting, so the magnitude must actually be taken here.
            op = jnp.abs
        else:
            op = jnp.real
        time_series = op(jnp.fft.ifft(time_series_ft))

        # Allow for universal backprojection here.
        if self.ubp:
            time_series -= jnp.gradient(time_series, axis=-1) * jnp.arange(
                time_series.shape[-1]
            )

        # Interpolate in time domain and detector domain
        if not (self.detector_factor == 1 and self.time_factor == 1):
            full_interpolate = jax.vmap(
                partial_interpolate, in_axes=(0, None, None), out_axes=0
            )
            time_series = full_interpolate(
                time_series, self.time_factor, self.detector_factor
            )
            detectors = interpolate_detectors(detectors, self.detector_factor)

        # Restore the original leading (batch) axes.
        time_series = time_series.reshape(shape[:-2] + time_series.shape[1:])

        # Apply energy correction factor:
        if overall_correction_factor is not None:
            # Append two new axes so the factor broadcasts over
            # (detectors, samples).
            extend = (slice(None, None),) * overall_correction_factor.ndim + (
                None,
                None,
            )
            time_series /= overall_correction_factor[extend]
        else:
            warnings.warn("No energy correction factor applied.")
        return time_series, detectors

    def run(
        self, time_series, pa_data=None, irf=None, detectors=None, **kwargs
    ) -> Tuple[PATimeSeries, Dict, Optional[list]]:
        """
        Run the preprocessing step on a given time series and detectors.

        The data are processed in slices when the number of frames exceeds
        ``PAT_MAXIMUM_BATCH_SIZE`` (a value of ``-1`` disables batching).

        Parameters
        ----------
        time_series
            Time series object; ``raw_data``, ``attributes``, ``shape`` and
            ``da`` are read from it.
        pa_data
            Optional data object used to look up the impulse response, scan
            geometry and overall correction factor when not given explicitly.
        irf
            Impulse response; taken from ``pa_data`` when ``None``.
        detectors
            Detector geometry; taken from ``pa_data`` when ``None``.
        kwargs
            Unused.

        Returns
        -------
        tuple of PATimeSeries, dict, None
            The processed time series, ``{"geometry": detectors}`` and
            ``None``.
        """
        from .. import PAT_MAXIMUM_BATCH_SIZE

        # Impulse response
        if irf is None and pa_data is not None:
            irf = pa_data.get_impulse_response()

        # Photoacoustic transducers
        if detectors is None and pa_data is not None:
            detectors = pa_data.get_scan_geometry()

        if pa_data is not None:
            overall_correction_factor = pa_data.get_overall_correction_factor()
        else:
            overall_correction_factor = None

        # Sampling frequency
        fs = time_series.attributes["fs"]

        # The filter is computed once and then reused on later calls.
        if self.filter is None:
            self.pre_compute_filter(time_series.shape[-1], fs=fs, irf=irf)

        new_detectors = detectors

        n_frames = time_series.shape[0] * time_series.shape[1]
        if n_frames > PAT_MAXIMUM_BATCH_SIZE and PAT_MAXIMUM_BATCH_SIZE != -1:
            new_timeseries = []
            ts_raw = time_series.raw_data
            shape = ts_raw.shape
            ts_raw = ts_raw.reshape((-1,) + shape[-2:])
            for i in range(0, ts_raw.shape[0], PAT_MAXIMUM_BATCH_SIZE):
                if overall_correction_factor is not None:
                    overall_correction_factor_sliced = (
                        overall_correction_factor.flatten()[
                            i : i + PAT_MAXIMUM_BATCH_SIZE
                        ]
                    )
                else:
                    overall_correction_factor_sliced = None
                new_ts, new_detectors = self._run(
                    ts_raw[i : i + PAT_MAXIMUM_BATCH_SIZE],
                    detectors,
                    overall_correction_factor_sliced,
                )
                # Move each batch to host memory before concatenating.
                new_timeseries.append(np.asarray(new_ts))
            # Restore the two leading axes of the raw data.
            new_ts = np.concatenate(new_timeseries, axis=0).reshape(
                shape[:2] + new_timeseries[0].shape[-2:]
            )
        else:
            new_ts, new_detectors = self._run(
                time_series.raw_data, detectors, overall_correction_factor
            )

        # Convert timeseries into an xarray
        attributes = dict(time_series.attributes)
        attributes["fs"] *= self.time_factor
        attributes[PreprocessingAttributeTags.IMPULSE_RESPONSE] = self.irf_correct
        attributes[PreprocessingAttributeTags.PROCESSING_ALGORITHM] = (
            self.get_algorithm_name()
        )
        # NOTE(review): WINDOW_SIZE is given the window name (self.window),
        # not self.n_filter - confirm which one readers of this tag expect.
        attributes[PreprocessingAttributeTags.WINDOW_SIZE] = self.window
        attributes[PreprocessingAttributeTags.ENVELOPE_DETECTION] = (
            self.absolute == "abs"
        )
        attributes[PreprocessingAttributeTags.HILBERT_TRANSFORM] = self.hilbert
        attributes[PreprocessingAttributeTags.DETECTOR_INTERPOLATION] = (
            self.detector_factor
        )
        attributes[PreprocessingAttributeTags.TIME_INTERPOLATION] = self.time_factor
        attributes[PreprocessingAttributeTags.LOW_PASS_FILTER] = self.lp_filter
        attributes[PreprocessingAttributeTags.HIGH_PASS_FILTER] = self.hp_filter
        attributes["UniversalBackProjection"] = self.ubp
        attributes["CorrectionFactorApplied"] = overall_correction_factor is not None

        # Coordinates use the same index grids as the interpolation routines.
        coords = dict(time_series.da.coords)
        coords["detectors"] = np.linspace(
            0, time_series.shape[-2] - 1, self.detector_factor * time_series.shape[-2]
        )
        coords["timeseries"] = np.linspace(
            0, time_series.shape[-1] - 1, self.time_factor * time_series.shape[-1] + 1
        )[:-1]

        new_data = PATimeSeries(
            new_ts, time_series.da.dims, coords, attributes=attributes
        )
        return new_data, {"geometry": new_detectors}, None