# Source code for patato.processing.gpu_preprocessing_algorithm

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

from __future__ import annotations

from typing import Union, TYPE_CHECKING, Tuple, Optional

from .processing_algorithm import ProcessingResult, TimeSeriesProcessingAlgorithm

if TYPE_CHECKING:
    from ..io.msot_data import PAData

from ..core.image_structures.pa_time_data import PATimeSeries

try:
    import cupy as cp
    from cupy.fft import fft, ifft, fftshift, fftfreq
except ImportError:
    cp = None

from ..io.attribute_tags import PreprocessingAttributeTags
import time
from patato.unmixing.spectra import SPECTRA_NAMES
import logging


class GPUMSOTPreProcessor(TimeSeriesProcessingAlgorithm):
    """Preprocess photoacoustic time-series data on the GPU with CuPy.

    ``run`` performs, in order: per-trace mean subtraction, one Fourier-domain
    filter (optional impulse-response deconvolution, a Hann taper of the
    spectrum, optional Hilbert transform and a band-pass FIR filter),
    selection of the real part, imaginary part or magnitude of the complex
    result, bilinear interpolation over the detector and time axes, and
    division by the dataset's overall correction factor (when one exists).

    CuPy must be installed; ``run`` raises ``ImportError`` otherwise.
    """

    @staticmethod
    def get_algorithm_name() -> str:
        """Return the name recorded in the processed data's attributes."""
        return "GPU Standard Preprocessor"

    @staticmethod
    def get_hdf5_group_name() -> Union[str, None]:
        """Return ``None`` (this algorithm declares no HDF5 group name)."""
        return None

    def __init__(
        self,
        time_factor=3,
        detector_factor=2,
        irf=True,
        hilbert=True,
        lp_filter=None,
        hp_filter=None,
        filter_window_size=512,
        window: Union[str, None] = "hann",
        absolute: Union[bool, str] = None,
        couplant_correction=None,
        couplant_path_length=0,
    ):
        """Configure the preprocessing pipeline.

        Parameters
        ----------
        time_factor : int
            Upsampling factor along the time (sample) axis.
        detector_factor : int
            Upsampling factor along the detector axis.
        irf : bool
            If True, deconvolve the impulse response passed to / fetched by
            ``run``.
        hilbert : bool
            If True, zero the negative frequencies (and halve DC) of the
            spectrum, i.e. form an analytic signal.
        lp_filter, hp_filter : float or None
            Low-/high-pass cut-off frequencies, in the same units as the time
            series' ``"fs"`` attribute. ``None`` disables that edge.
        filter_window_size : int
            Number of taps of the band-pass FIR filter.
        window : str or None
            ``"hann"`` applies a Hann window to the FIR taps.
        absolute : str, bool or None
            Part of the complex filtered signal to keep: ``"real"``,
            ``"imag"``, or any other string for the magnitude (e.g.
            ``"abs"``). ``True`` means ``"abs"``; ``False`` and ``None`` both
            select the default, which is ``"imag"`` when ``hilbert`` is True
            and the real part otherwise.
        couplant_correction : str or None
            Key into ``SPECTRA_NAMES``. Not implemented: ``run`` raises
            ``NotImplementedError`` when this is set.
        couplant_path_length : float
            Path length for the couplant correction (currently unused).
        """
        super().__init__()
        self.time_factor = time_factor
        self.detector_factor = detector_factor

        # Normalise the legacy boolean form so that the operation chosen in
        # ``_run`` and the ENVELOPE_DETECTION attribute written in ``run``
        # agree. Previously ``False`` silently selected the magnitude.
        if absolute is True:
            absolute = "abs"
        elif absolute is False:
            absolute = None
        if absolute is None and hilbert:
            absolute = "imag"

        self.irf_correct = irf
        self.hilbert = hilbert
        self.lp_filter = lp_filter
        self.hp_filter = hp_filter
        self.n_filter = filter_window_size
        self.window = window
        self.absolute = absolute
        if couplant_correction is not None:
            self.couplant_correction = SPECTRA_NAMES[couplant_correction]
        else:
            self.couplant_correction = None
        self.couplant_path_length = couplant_path_length

    def run(
        self,
        time_series: PATimeSeries,
        pa_data: PAData,
        irf=None,
        detectors=None,
        **kwargs,
    ) -> Tuple["PATimeSeries", dict, Optional[ProcessingResult]]:
        """Preprocess ``time_series`` on the GPU.

        Parameters
        ----------
        time_series : PATimeSeries
            Raw data; must carry an ``"fs"`` attribute. NOTE: its
            ``raw_data`` is overwritten while processing (see ``_run``).
        pa_data : PAData
            Supplies the impulse response, scan geometry and overall
            correction factor when they are not passed explicitly.
        irf : array-like, optional
            Impulse response; defaults to ``pa_data.get_impulse_response()``.
        detectors : array-like, optional
            Detector geometry; defaults to ``pa_data.get_scan_geometry()``.
        **kwargs
            Forwarded to ``_run``.

        Returns
        -------
        tuple
            ``(processed time series, new parameters, None)``. The parameters
            dict holds the (possibly interpolated) detector geometry under
            ``"geometry"`` as a NumPy array.

        Raises
        ------
        ImportError
            If CuPy is not installed.
        NotImplementedError
            If a couplant correction was requested.
        """
        if cp is None:
            raise ImportError(f"CuPy is required to use {type(self).__name__}.")

        if irf is None:
            irf = pa_data.get_impulse_response()
        if detectors is None:
            detectors = pa_data.get_scan_geometry()
        fs = time_series.attributes["fs"]

        # Always bind the factor: previously it was only defined when present,
        # so datasets without one raised NameError further down.
        correction = pa_data.get_overall_correction_factor()
        overall_correction_factor = (
            None if correction is None else cp.array(correction)
        )

        # Build the combined Fourier-domain filter for this trace length.
        ft_filter = self.make_filter(
            time_series.shape[-1],
            fs,
            cp.array(irf) if self.irf_correct else None,
            self.hilbert,
            self.lp_filter,
            self.hp_filter,
            n_filter=self.n_filter,
            window=self.window,
        )
        new_time_series, new_parameters = self._run(
            time_series,
            ft_filter,
            cp.array(detectors),
            overall_correction_factor,
            **kwargs,
        )

        # Carry over any attributes the processed series does not already have.
        for a in time_series.attributes:
            if a not in new_time_series.attributes:
                new_time_series.attributes[a] = time_series.attributes[a]

        attrs = new_time_series.attributes
        attrs[PreprocessingAttributeTags.IMPULSE_RESPONSE] = self.irf_correct
        attrs[PreprocessingAttributeTags.PROCESSING_ALGORITHM] = (
            self.get_algorithm_name()
        )
        # NOTE(review): this stores the window *type*, not filter_window_size;
        # confirm which one WINDOW_SIZE is meant to hold.
        attrs[PreprocessingAttributeTags.WINDOW_SIZE] = self.window
        attrs[PreprocessingAttributeTags.ENVELOPE_DETECTION] = self.absolute == "abs"
        attrs[PreprocessingAttributeTags.HILBERT_TRANSFORM] = self.hilbert
        attrs[PreprocessingAttributeTags.DETECTOR_INTERPOLATION] = self.detector_factor
        attrs[PreprocessingAttributeTags.TIME_INTERPOLATION] = self.time_factor
        attrs[PreprocessingAttributeTags.LOW_PASS_FILTER] = self.lp_filter
        attrs[PreprocessingAttributeTags.HIGH_PASS_FILTER] = self.hp_filter
        attrs["CorrectionFactorApplied"] = overall_correction_factor is not None
        # TODO: replace pa_data with new attributes for further processing steps.
        return new_time_series, new_parameters, None

    def _run(
        self,
        time_series: PATimeSeries,
        ft_filter,
        detectors,
        overall_correction_factor,
        **kwargs,
    ) -> Tuple[PATimeSeries, dict]:
        """Apply the GPU pipeline to ``time_series``.

        ``time_series.raw_data`` is replaced in place with intermediate CuPy
        arrays. When no interpolation is performed, the returned series is the
        same object as the input. The returned series' ``raw_data`` is copied
        back to host memory with ``.get()``.

        Parameters
        ----------
        time_series : PATimeSeries
            Series whose last axis is time and second-to-last is detector.
        ft_filter : cupy.ndarray
            1-D frequency-domain filter from ``make_filter``.
        detectors : cupy.ndarray
            Detector geometry, one row per detector.
        overall_correction_factor : cupy.ndarray or None
            Divided out of the data, broadcast over the detector and sample
            axes. ``None`` skips the correction.
        **kwargs
            Ignored.

        Returns
        -------
        tuple
            ``(processed series, {"geometry": ...})``.

        Raises
        ------
        NotImplementedError
            If a couplant correction was requested. This is checked before any
            work, so the input is left untouched.
        """
        if self.couplant_correction is not None:
            # TODO: intended behaviour is to multiply the data by
            # exp(couplant_path_length * couplant spectrum at the data's
            # wavelengths), broadcast over detectors and samples.
            raise NotImplementedError("Couplant correction is not implemented.")

        new_parameters = {}

        # Subtract the per-trace mean along the last (time) axis.
        t = time.time()
        raw_data = cp.array(time_series.raw_data)
        time_series.raw_data = raw_data - cp.mean(raw_data, axis=-1, keepdims=True)
        logging.debug(f"Mean subtraction took {time.time() - t}s")

        # Apply the Fourier-domain filter, then return to the time domain.
        t = time.time()
        time_series_ft = fft(time_series.raw_data, axis=-1)
        time_series_ft = self.apply_filter(time_series_ft, ft_filter=ft_filter)
        if self.absolute == "real" or self.absolute is None:
            operation = cp.real
        elif self.absolute == "imag":
            operation = cp.imag
        else:
            operation = cp.abs
        time_series.raw_data = operation(ifft(time_series_ft, axis=-1))
        logging.debug(f"Filter took {time.time() - t}s")

        # Interpolate in the time and detector dimensions.
        time_series, interp_params = self.interpolate(time_series, detectors)
        new_parameters.update(interp_params)

        # Divide out the overall (energy) correction factor, if there is one.
        if overall_correction_factor is not None:
            extend = (slice(None),) * overall_correction_factor.ndim + (None, None)
            time_series.raw_data /= overall_correction_factor[extend]

        time_series.raw_data = time_series.raw_data.get()
        return time_series, new_parameters

    def interpolate(
        self, time_series: PATimeSeries, detectors, exact_ratios=True
    ) -> Tuple[PATimeSeries, dict]:
        """Bilinearly upsample the data over the detector and time axes.

        Parameters
        ----------
        time_series : PATimeSeries
            Series whose ``raw_data`` is a CuPy array with detectors on axis -2
            and samples on axis -1.
        detectors : cupy.ndarray
            Detector geometry, one row per detector. It is linearly
            interpolated column by column onto the new detector positions.
        exact_ratios : bool
            If True, the output has exactly ``factor * n`` points per axis,
            spanning the original endpoints. Otherwise the spacing is exactly
            ``1 / factor`` and the output has ``(n - 1) * factor + 1`` points.

        Returns
        -------
        tuple
            ``(series, {"geometry": numpy array})``. When both factors are 1,
            the input series is returned unchanged.
        """
        if self.time_factor == 1 and self.detector_factor == 1:
            # Return host memory here too, matching the interpolating branch.
            return time_series, {"geometry": cp.asnumpy(detectors)}

        signal = time_series.raw_data
        n_detectors = detectors.shape[0]
        n_samples = signal.shape[-1]
        detector_ind = cp.arange(n_detectors)

        # Fractional source indices at which to sample the data.
        if exact_ratios:
            new_detector_ind = cp.linspace(
                0, n_detectors - 1, self.detector_factor * n_detectors
            )
            new_samp_ind = cp.linspace(0, n_samples - 1, self.time_factor * n_samples)
        else:
            new_detector_ind = (
                cp.arange((n_detectors - 1) * self.detector_factor + 1)
                / self.detector_factor
            )
            new_samp_ind = (
                cp.arange((n_samples - 1) * self.time_factor + 1) / self.time_factor
            )

        t = time.time()
        sind, dind = cp.meshgrid(new_samp_ind, new_detector_ind)
        d0 = dind.astype(cp.int32)
        s0 = sind.astype(cp.int32)
        # Clamp the upper neighbour. At the final grid point its weight is
        # zero, but the index must still be valid. Previously this relied on
        # CuPy wrapping out-of-range indices.
        d1 = cp.minimum(d0 + 1, signal.shape[-2] - 1)
        s1 = cp.minimum(s0 + 1, n_samples - 1)
        wd = dind - cp.floor(dind)
        ws = sind - cp.floor(sind)

        new_signal = cp.zeros(signal.shape[:-2] + dind.shape, dtype=signal.dtype)
        new_signal += signal[..., d0, s0] * (1 - wd) * (1 - ws)
        new_signal += signal[..., d1, s0] * wd * (1 - ws)
        new_signal += signal[..., d0, s1] * (1 - wd) * ws
        new_signal += signal[..., d1, s1] * wd * ws

        detectors = cp.apply_along_axis(
            lambda x: cp.interp(new_detector_ind, detector_ind, x), 0, detectors
        )
        logging.debug(f"Interpolation took {time.time() - t}.")

        # Build the new series with updated coordinates and sampling rate.
        coords = dict(time_series.da.coords)
        coords["detectors"] = new_detector_ind.get()
        coords["timeseries"] = new_samp_ind.get()
        attributes = dict(time_series.da.attrs)
        # NOTE(review): with exact_ratios the true rate increase is
        # (time_factor * n - 1) / (n - 1), slightly above time_factor.
        attributes["fs"] *= self.time_factor
        new_data = PATimeSeries(
            new_signal, time_series.da.dims, coords, attributes=attributes
        )
        return new_data, {"geometry": cp.asnumpy(detectors)}

    @staticmethod
    def apply_filter(pa_data: cp.ndarray, ft_filter) -> cp.ndarray:
        """Multiply ``pa_data`` in place by ``ft_filter`` along the last axis.

        Returns the same (modified) array.
        """
        extend = (None,) * (pa_data.ndim - 1) + (slice(None, None),)
        pa_data *= ft_filter[extend]
        return pa_data

    @staticmethod
    def make_filter(
        n_samples,
        fs,
        irf,
        hilbert,
        lp_filter,
        hp_filter,
        rise=0.2,
        n_filter=1024,
        window=None,
    ) -> cp.ndarray:
        """Build the combined complex frequency-domain filter.

        Parameters
        ----------
        n_samples : int
            Length of the traces (and of the returned filter).
        fs : float
            Sampling rate, in the same units as the cut-offs.
        irf : cupy.ndarray or None
            1-D impulse response of length ``n_samples``, with its peak near
            the centre. It is deconvolved when given.
        hilbert : bool
            Keep positive frequencies, halve DC and zero negative ones.
        lp_filter, hp_filter : float or None
            Band-pass cut-offs. ``None`` disables that edge.
        rise : float
            Relative width of the linear ramp at each cut-off.
        n_filter : int
            Number of FIR taps.
        window : str or None
            ``"hann"`` windows the FIR taps.

        Returns
        -------
        cupy.ndarray
            Complex128 array of length ``n_samples``, in FFT frequency order.
        """
        # NOTE (original author): the filter may shift the data slightly in
        # time; unverified.
        from scipy.signal.windows import hann

        output = cp.ones((n_samples,), dtype=cp.cdouble)

        # Impulse-response deconvolution: conj(H) / |H|^2 == 1 / H.
        if irf is not None:
            # Move the centre of the impulse response to index 0. This is the
            # same as swapping the two halves for even lengths; unlike the
            # manual swap, it also works for odd lengths.
            irf_ft = fft(cp.fft.ifftshift(irf, axes=0))
            output *= cp.conj(irf_ft) / cp.abs(irf_ft) ** 2

        # Taper the spectrum with a Hann window centred on zero frequency.
        output *= fftshift(cp.asarray(hann(n_samples)))

        # Hilbert transform: (1 + sign(f)) / 2 -> 1 for f > 0, 0.5 at DC, 0 for f < 0.
        if hilbert:
            output *= (1 + cp.sign(fftfreq(n_samples))) / 2

        # Band-pass response on the n_filter-point frequency grid, with linear
        # ramps of relative width `rise` at each cut-off.
        frequencies = cp.abs(fftfreq(n_filter, 1 / fs))
        fir_filter = cp.ones_like(frequencies, dtype=cp.cdouble)
        if hp_filter is not None:
            fir_filter[frequencies < hp_filter * (1 - rise)] = 0
            in_rise = cp.logical_and(
                frequencies > hp_filter * (1 - rise), frequencies < hp_filter
            )
            fir_filter[in_rise] = (frequencies[in_rise] - hp_filter * (1 - rise)) / (
                hp_filter * rise
            )
        if lp_filter is not None:
            fir_filter[frequencies > lp_filter * (1 + rise)] = 0
            in_rise = cp.logical_and(
                frequencies < lp_filter * (1 + rise), frequencies > lp_filter
            )
            fir_filter[in_rise] = 1 - (frequencies[in_rise] - lp_filter) / (
                lp_filter * rise
            )

        # FIR taps, optionally windowed (the window is shifted to peak at tap 0).
        fir_taps = ifft(fir_filter)
        if window == "hann":
            fir_taps *= fftshift(cp.asarray(hann(n_filter)))

        # Zero-pad the taps to n_samples: the leading half stays at the start,
        # the trailing half wraps to the end of the buffer.
        filter_time = cp.zeros_like(output)
        filter_time[: n_filter // 2] = fir_taps[: n_filter // 2]
        filter_time[-n_filter // 2 :] = fir_taps[-n_filter // 2 :]
        output *= fft(filter_time)
        return output