# Copyright (c) Thomas Else 2023.
# License: MIT
from __future__ import annotations
from typing import Union, TYPE_CHECKING, Tuple, Optional
import numpy as np
from .processing_algorithm import TimeSeriesProcessingAlgorithm, ProcessingResult
if TYPE_CHECKING:
from ..io.msot_data import PAData
from ..core.image_structures.pa_time_data import PATimeSeries
from scipy.fft import fft, ifft, fftshift, fftfreq
from ..io.attribute_tags import PreprocessingAttributeTags
from patato.unmixing.spectra import SPECTRA_NAMES
[docs]
class NumpyPreProcessor(TimeSeriesProcessingAlgorithm):
[docs]
@staticmethod
def get_algorithm_name() -> str:
"""
Get the name of the algorithm.
Returns
-------
str or None
"""
return "CPU Standard Preprocessor"
[docs]
@staticmethod
def get_hdf5_group_name() -> Union[str, None]:
"""
Return the name of the group in the HDF5 file
Returns
-------
str or None
"""
return None
[docs]
def __init__(self, time_factor=3, detector_factor=2,
irf=True, hilbert=True, lp_filter=None,
hp_filter=None, filter_window_size=512,
window: Union[str, None] = "hann", absolute: Union[bool, str] = None,
couplant_correction=None, couplant_path_length=0):
super().__init__()
self.time_factor = time_factor
self.detector_factor = detector_factor
absolute = "imag" if absolute is None and hilbert else absolute
self.irf_correct = irf
self.hilbert = hilbert
self.lp_filter = lp_filter
self.hp_filter = hp_filter
self.n_filter = filter_window_size
self.window = window
self.absolute = absolute
if couplant_correction is not None:
self.couplant_correction = SPECTRA_NAMES[couplant_correction]
else:
self.couplant_correction = None
self.couplant_path_length = couplant_path_length
def run(self, time_series: PATimeSeries,
pa_data: PAData, irf=None, detectors=None, **kwargs) -> Tuple["PATimeSeries", dict,
Optional[ProcessingResult]]:
if irf is None:
irf = pa_data.get_impulse_response()
if detectors is None:
detectors = pa_data.get_scan_geometry()
fs = time_series.attributes["fs"]
overall_correction_factor = pa_data.get_overall_correction_factor()
# Generate the filter
ft_filter = self.make_filter(time_series.shape[-1],
fs,
irf if self.irf_correct else None,
self.hilbert, self.lp_filter,
self.hp_filter,
n_filter=self.n_filter,
window=self.window)
new_time_series, new_parameters = self._run(time_series,
ft_filter, detectors, overall_correction_factor, **kwargs)
# Update the results' attributes.
for a in time_series.attributes:
if a not in new_time_series.attributes:
new_time_series.attributes[a] = time_series.attributes[a]
new_time_series.attributes[PreprocessingAttributeTags.IMPULSE_RESPONSE] = self.irf_correct
new_time_series.attributes[PreprocessingAttributeTags.PROCESSING_ALGORITHM] = self.get_algorithm_name()
new_time_series.attributes[PreprocessingAttributeTags.WINDOW_SIZE] = self.window
new_time_series.attributes[PreprocessingAttributeTags.ENVELOPE_DETECTION] = self.absolute == "abs"
new_time_series.attributes[PreprocessingAttributeTags.HILBERT_TRANSFORM] = self.hilbert
new_time_series.attributes[PreprocessingAttributeTags.DETECTOR_INTERPOLATION] = self.detector_factor
new_time_series.attributes[PreprocessingAttributeTags.TIME_INTERPOLATION] = self.time_factor
new_time_series.attributes[PreprocessingAttributeTags.LOW_PASS_FILTER] = self.lp_filter
new_time_series.attributes[PreprocessingAttributeTags.HIGH_PASS_FILTER] = self.hp_filter
new_time_series.attributes["CorrectionFactorApplied"] = overall_correction_factor is not None
return new_time_series, new_parameters, None
def _run(self, time_series: PATimeSeries, ft_filter, detectors, overall_correction_factor, **kwargs) -> Tuple[
PATimeSeries, dict]:
new_parameters = {}
# Subtract mean
time_series = time_series.copy()
extend = (slice(None, None),) * (time_series.raw_data.ndim - 1) + (None,)
# Subtract mean
raw_data = np.array(time_series.raw_data)
time_series.raw_data = raw_data - np.mean(raw_data, axis=-1)[extend]
# Apply a fourier domain filter.
time_series_ft = fft(time_series.raw_data, axis=-1)
time_series_ft = self.apply_filter(time_series_ft, ft_filter=ft_filter)
# Go back to the time domain.
operation = np.real if self.absolute == "real" or self.absolute is None else np.imag if self.absolute == "imag" else np.abs
time_series_td = operation(ifft(time_series_ft, axis=-1))
time_series.raw_data = time_series_td
# Apply interpolation in time and detector domains.
time_series, interp_params = self.interpolate(time_series, detectors)
new_parameters.update(interp_params)
# Apply energy correction factor
extend = (slice(None, None),) * overall_correction_factor.ndim + (None, None)
time_series.raw_data /= overall_correction_factor[extend]
time_series.raw_data = time_series.raw_data.copy()
return time_series, new_parameters
def interpolate(self, time_series: PATimeSeries, detectors, exact_ratios=True) -> Tuple[PATimeSeries, dict]:
# Interpolate the data in the time and detector domains.
# Interpolate in the detector domain
if self.time_factor == 1 and self.detector_factor == 1:
return time_series, {"geometry": detectors}
detector_ind = np.arange(detectors.shape[0])
new_detector_ind = np.arange((detectors.shape[0] - 1) * self.detector_factor + 1) / self.detector_factor
if exact_ratios:
new_detector_ind = np.linspace(0, detectors.shape[0] - 1, self.detector_factor * detectors.shape[0])
signal = time_series.raw_data
signal = np.apply_along_axis(lambda x: np.interp(new_detector_ind, detector_ind, x),
-2, signal)
# Get the new detector locations
detectors = np.apply_along_axis(lambda x: np.interp(new_detector_ind, detector_ind, x),
0, detectors)
# Interpolate in the sample domain
sample_ind = np.arange(signal.shape[-1])
new_samp_ind = np.arange((signal.shape[-1] - 1) * self.time_factor + 1) / self.time_factor
if exact_ratios:
new_samp_ind = np.linspace(0, signal.shape[-1] - 1, self.time_factor * signal.shape[-1])
signal = np.apply_along_axis(lambda x: np.interp(new_samp_ind, sample_ind, x),
-1, signal)
# Update the xarray coordinates of the new dataset.
coords = dict(time_series.da.coords)
coords["detectors"] = new_detector_ind
coords["timeseries"] = new_samp_ind
attributes = dict(time_series.da.attrs)
attributes["fs"] *= self.time_factor
new_data = PATimeSeries(signal.copy(), time_series.da.dims, coords, attributes=attributes)
return new_data, {"geometry": detectors}
@staticmethod
def apply_filter(pa_data: np.ndarray, ft_filter) -> np.ndarray:
extend = (None,) * (pa_data.ndim - 1) + (slice(None, None),)
pa_data *= ft_filter[extend]
return pa_data
@staticmethod
def make_filter(n_samples, fs, irf,
hilbert, lp_filter,
hp_filter, rise=0.2,
n_filter=1024, window=None) -> np.ndarray:
# at the moment, it looks like it is shifting the data a bit??
# Impulse Response Correction
output = np.ones((n_samples,), dtype=np.cdouble)
if irf is not None:
irf_shifted = np.zeros_like(irf)
irf_shifted[:irf.shape[0] // 2] = irf[irf.shape[0] // 2:]
irf_shifted[-irf.shape[0] // 2:] = irf[:irf.shape[0] // 2]
output *= np.conj(fft(irf_shifted)) / np.abs(fft(irf_shifted)) ** 2
from scipy.signal.windows import hann
output *= fftshift(hann(n_samples))
# Hilbert Transform
frequencies = fftfreq(n_samples)
if hilbert:
output *= (1 + np.sign(frequencies)) / 2
frequencies = np.abs(fftfreq(n_filter, 1 / fs))
fir_filter = np.ones_like(frequencies, dtype=np.cdouble)
if hp_filter is not None:
fir_filter[frequencies < hp_filter * (1 - rise)] = 0
in_rise = np.logical_and(frequencies > hp_filter * (1 - rise), frequencies < hp_filter)
fir_filter[in_rise] = (frequencies[in_rise] - hp_filter * (1 - rise)) / (hp_filter * rise)
if lp_filter is not None:
fir_filter[frequencies > lp_filter * (1 + rise)] = 0
in_rise = np.logical_and(frequencies < lp_filter * (1 + rise), frequencies > lp_filter)
fir_filter[in_rise] = 1 - (frequencies[in_rise] - lp_filter) / (lp_filter * rise)
time_series = ifft(fir_filter)
if window == "hann":
from scipy.signal.windows import hann
time_series *= fftshift(hann(n_filter))
filter_time = np.zeros_like(output)
filter_time[:n_filter // 2] = time_series[:n_filter // 2]
filter_time[-n_filter // 2:] = time_series[-n_filter // 2:]
fir_filter = fft(filter_time)
output *= fir_filter
return output