# Source code for patato.recon.backprojection_reference

#  Copyright (c) Thomas Else 2023-25.
#  License: MIT

from typing import Sequence

import numpy as np

from .backprojection_implementation.jax_implementation import full_recon
from .reconstruction_algorithm import ReconstructionAlgorithm

try:
    import jax
except ImportError:
    print(
        """WARNING: JAX must be installed to support the standard backprojection algorithm and filtering.
    Alternatively, try the Numpy implementations.
    """
    )
    jax = None


class ReferenceBackprojection(ReconstructionAlgorithm):
    """Reference backprojection: Uses JAX in the background.

    Every frame of the input is reconstructed independently by the JAX
    ``full_recon`` kernel, vectorised over frames with ``jax.vmap``.
    """

    @staticmethod
    def _pixel_spacing(extent: float, n: int) -> float:
        """Return the spacing of ``n`` pixels that span ``extent``.

        A single pixel along an axis has no neighbour, so its spacing is
        defined as 0 (this also avoids a division by zero).
        """
        return extent / (n - 1) if n != 1 else 0

    def reconstruct(
        self,
        time_series: np.ndarray,
        fs: float,
        geometry: np.ndarray,
        n_pixels: Sequence[int],
        field_of_view: Sequence[float],
        speed_of_sound,
        **kwargs,
    ) -> np.ndarray:
        """Backproject a stack of time series onto a regular pixel grid.

        Parameters
        ----------
        time_series
            Array whose last two axes form a single frame; every leading
            axis is treated as an independent frame. (Presumably
            ``(..., n_detectors, n_samples)`` -- TODO confirm against
            ``full_recon``.)
        fs
            Sampling frequency of the time series.
        geometry
            Detector geometry, passed unchanged to ``full_recon`` (its
            shape is not checked here).
        n_pixels
            Number of pixels along x, y and z (first three entries used).
        field_of_view
            Extent of the reconstruction along x, y and z (first three
            entries used).
        speed_of_sound
            Speed of sound; ``speed_of_sound / fs`` gives the distance
            travelled in one sample period.
        kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        np.ndarray
            Reconstruction of shape
            ``time_series.shape[:-2] + tuple(n_pixels)[::-1]``.
            NOTE(review): this comes straight from ``jax.vmap`` so it is
            likely a JAX array rather than a NumPy array -- confirm
            whether callers need ``np.asarray``.

        Raises
        ------
        ImportError
            If JAX is not installed.
        """
        if jax is None:
            # Fail with an actionable message instead of an opaque
            # AttributeError on ``None.vmap``.
            raise ImportError(
                "JAX must be installed to use ReferenceBackprojection; "
                "alternatively, try the Numpy implementations."
            )

        # Distance travelled by sound during one sample period.
        dl = speed_of_sound / fs

        # Flatten all leading axes into a single frame axis so that the
        # reconstruction can be vectorised over it.
        original_shape = time_series.shape[:-2]
        frames = int(np.prod(original_shape))  # np.prod(()) == 1 for a single frame
        signal = time_series.reshape((frames,) + time_series.shape[-2:])

        nx, ny, nz = n_pixels[0], n_pixels[1], n_pixels[2]
        dx, dy, dz = (
            self._pixel_spacing(field_of_view[axis], n_pixels[axis])
            for axis in range(3)
        )

        # Map over axis 0 of ``signal`` only; the remaining 8 arguments are
        # shared (broadcast) across all frames.
        recon_all = jax.vmap(full_recon, in_axes=(0,) + (None,) * 8, out_axes=0)
        output = recon_all(signal, geometry, dl, nx, ny, nz, dx, dy, dz)

        # Restore the leading frame axes; pixel axes come out in reverse
        # order relative to ``n_pixels``.
        return output.reshape(original_shape + tuple(n_pixels)[::-1])

    @staticmethod
    def get_algorithm_name() -> str:
        """Return the human-readable name of this algorithm.

        Returns
        -------
        str
            The constant string ``"Reference Backprojection"``.
        """
        return "Reference Backprojection"