Source code for patato.recon.jax_model_based

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT
from typing import Sequence

import numpy as np

from .. import ReconstructionAlgorithm
import numpy.typing as npt
from ... import PAData
from ...processing.preprocessing_algorithm import (
    TimeSeriesProcessingAlgorithm,
    PreprocessingAttributeTags,
)


[docs] class ModelBasedPreProcessor(TimeSeriesProcessingAlgorithm):
[docs] @staticmethod def get_algorithm_name() -> str: """ Get the name of the algorithm. Returns ------- str or None """ return "Model Based Preprocessor"
[docs] @staticmethod def get_hdf5_group_name(): """ Return the name of the group in the HDF5 file. Returns ------- str or None """ return None
[docs] def __init__(self): super().__init__()
def run(self, time_series, pa_data, **kwargs): new_ts = time_series.copy() ts_raw = ( time_series.raw_data - np.mean(time_series.raw_data, axis=-1)[:, :, :, None] ) if pa_data is not None: overall_correction_factor = pa_data.get_overall_correction_factor() ts_raw /= overall_correction_factor[:, :, None, None] else: overall_correction_factor = None new_ts.raw_data = ts_raw # Update the results' attributes. for a in time_series.attributes: if a not in new_ts.attributes: new_ts.attributes[a] = time_series.attributes[a] new_ts.attributes[PreprocessingAttributeTags.PROCESSING_ALGORITHM] = ( self.get_algorithm_name() ) new_ts.attributes["CorrectionFactorApplied"] = ( overall_correction_factor is not None ) return new_ts, {}, None
[docs] class JAXModelBasedReconstruction(ReconstructionAlgorithm): """JAX-based, two-dimensional model based reconstruction algorithm."""
[docs] def __init__(self, n_pixels, field_of_view, pa_example: "PAData" = None, **kwargs): if pa_example is not None: kwargs["model_geometry"] = kwargs.get( "model_geometry", pa_example.get_scan_geometry() ) kwargs["model_fs"] = kwargs.get( "model_fs", pa_example.get_sampling_frequency() ) kwargs["model_nt"] = kwargs.get( "model_nt", pa_example.get_time_series().shape[-1] ) kwargs["model_c"] = kwargs.get("model_c", pa_example.get_speed_of_sound()) kwargs["model_irf"] = kwargs.get( "model_irf", pa_example.get_impulse_response() ) super().__init__(n_pixels, field_of_view, **kwargs) self._batch = False axes = [i for i, x in enumerate(n_pixels) if x != 1] dx = field_of_view[axes[0]] / (n_pixels[axes[0]] - 1) if ( dx != field_of_view[axes[1]] / (n_pixels[axes[1]] - 1) or field_of_view[axes[1]] != field_of_view[axes[0]] ): raise ValueError( "Current model based implementation only supports square reconstruction grids." ) detx = kwargs["model_geometry"][:, axes[0]] dety = kwargs["model_geometry"][:, axes[1]] fs = kwargs["model_fs"] nx = n_pixels[axes[0]] x_0 = -field_of_view[axes[0]] / 2 nt = kwargs["model_nt"] c = kwargs["model_c"] self._nx_model = nx self._model_matrix = self._generate_model(detx, dety, fs, dx, nx, x_0, nt, c) self._model_matrix.eliminate_zeros() self._model_matrix.sort_indices() from jax.experimental.sparse import CSR self._model_regulariser = ( kwargs.get("model_regulariser", None), kwargs.get("model_regulariser_lambda", None), ) # if self._model_regulariser[0] is None: # self._model_regulariser = None self._model_matrix = CSR( ( self._model_matrix.data, self._model_matrix.indices, self._model_matrix.indptr, ), shape=self._model_matrix.shape, ) self._model_max_iter = kwargs.get("model_max_iter", 50) self._model_constraint = kwargs.get("constraint", "positive") self.attributes["model_max_iter"] = self._model_max_iter self.attributes["model_constraint"] = self._model_constraint self.attributes["model_regulariser"] = self._model_regulariser[0] self.attributes["model_regulariser_weighting"] = self._model_regulariser[1] self.attributes["model_c"] = c
def _generate_model( self, detx: npt.ArrayLike, dety: npt.ArrayLike, fs: float, dx: float, nx: int, x_0: float, nt: int, c: float, ): """ Generates the model matrix for given parameters. Parameters ---------- detx dety fs dx nx x_0 nt Returns ------- """ from ..model_based.numpy_implementation import get_model dl = c / fs model_matrix = get_model(detx, dety, dl, dx, nx, x_0, nt, cache=False) return model_matrix def reconstruct( self, raw_data: np.ndarray, fs: float = None, geometry: np.ndarray = None, n_pixels=None, field_of_view=None, speed_of_sound=None, **kwargs, ) -> np.ndarray: import jax import jax.numpy as jnp from jax.scipy.signal import convolve2d import jaxopt from jaxopt.projection import projection_non_negative, projection_box from tqdm.auto import tqdm if self._model_constraint == "positive": projection = projection_non_negative hyperparams = None elif self._model_constraint == "none": projection = projection_box hyperparams = (-np.inf, np.inf) else: raise ValueError("Constraint must either be 'positive' or 'none'.") M = self._model_matrix if self._model_regulariser is not None or all( [x is None for x in self._model_regulariser] ): method, lambda_reg = self._model_regulariser else: method, lambda_reg = None, None if method == "laplacian": conv_mat = jnp.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) / 8 else: conv_mat = None if method not in ["identity", "laplacian", None]: raise ValueError( "Regularisation method must either be identity, laplacian or None." ) @jax.jit def forward(params, y, M, lambda_reg=None, conv_mat=None): x = M @ params.flatten() residuals = x - y.flatten() loss1 = jnp.mean(residuals**2) if lambda_reg is not None: if conv_mat is None: loss2 = lambda_reg * jnp.mean(params**2) else: loss2 = lambda_reg * jnp.mean(convolve2d(params, conv_mat) ** 2) else: loss2 = 0 return loss1 + loss2, {"loss1": loss1, "loss2": loss2} # value, aux def rec(time_series): opt = jaxopt.ProjectedGradient( forward, projection=projection, maxiter=self._model_max_iter, has_aux=True, acceleration=True, ) result = opt.run( jnp.zeros((self._nx_model, self._nx_model)), y=jnp.array(time_series), M=M, lambda_reg=lambda_reg, conv_mat=conv_mat, hyperparams_proj=hyperparams, ) return result.params, result.state output_shape = raw_data.shape[:-2] raw_data = raw_data.reshape((-1,) + raw_data.shape[-2:]) output = np.zeros((raw_data.shape[0], M.shape[1])) for i in tqdm(range(raw_data.shape[0]), leave=False): ts = jnp.array(raw_data[i] - np.mean(raw_data[i], axis=-1)[:, None]) params, state = rec(ts) output[i] = np.array(params.reshape(output[i].shape)) self._prev_state = state return output.reshape(output_shape + self.n_pixels[::-1]) @staticmethod def get_algorithm_name() -> str: return "JAX Model Based Reconstruction"