# Copyright (c) Thomas Else 2023-25.
# License: MIT
import argparse
import glob
from os.path import join, split
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Slider, RangeSlider
from .. import PAData
from ..utils import sort_key
from ..core.image_structures.single_parameter_data import SingleParameterData
def _str2bool(value):
    """Parse a command-line truth value such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true, so the boolean flags use this converter instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if lowered in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def init_argparse():
    """Build the command-line parser for the reconstruction viewer.

    Arguments:
        input: folder searched for ``*.hdf5`` scans.
        -f/--filter: only show the file named ``Scan_<filter>.hdf5``.
        -r/--recon: which reconstruction entry to display (default 0).
        -fn/--filtername: case-insensitive substring the scan name must contain.
        -t/--thb, -s/--so2: display the THb / sO2 image group instead of the
            raw reconstructions. Each accepts a bare flag (``-t``) or an explicit
            value (``-t True`` / ``-t False``).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="View MSOT Recons.")
    parser.add_argument("input", type=str, help="Data Folder")
    parser.add_argument("-f", "--filter", type=str, help="Choose scan", default=None)
    parser.add_argument(
        "-r", "--recon", type=int, help="Reconstruction number", default=0
    )
    parser.add_argument(
        "-fn", "--filtername", type=str, help="Choose scan name filter", default=None
    )
    # nargs="?" + const=True: a bare "-t" enables the flag, while the explicit
    # "-t True" / "-t False" forms still parse (and "False" now really is False).
    parser.add_argument(
        "-t",
        "--thb",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Draw on THb",
    )
    parser.add_argument(
        "-s",
        "--so2",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Draw on SigSo2",
    )
    return parser
def _select_reconstruction(data, args, file):
    """Pick the reconstruction requested on the command line.

    Plain reconstructions are used unless ``--thb``/``--so2`` request the
    corresponding ``SingleParameterData`` image group (THb wins if both are
    set). ``--recon`` indexes into the mapping's keys in iteration order.

    Returns:
        The selected reconstruction object, or ``None`` (after printing a
        message) when the scan has nothing usable.
    """
    if not args.thb and not args.so2:
        reconstructions = data.get_scan_reconstructions()
    else:
        group_name = "thb" if args.thb else "so2"
        reconstructions = data.get_scan_images(group_name, SingleParameterData)
    if not reconstructions:
        print(f"{file} has not been reconstructed. Skipping.")
        return None
    methods = list(reconstructions.keys())
    # Honour --recon (default 0 selects the first entry).
    if not (-len(methods) <= args.recon < len(methods)):
        print(
            f"{file} has no reconstruction number {args.recon} "
            f"({len(methods)} available). Skipping."
        )
        return None
    method = methods[args.recon]
    if len(methods) > 1:
        print("Multiple reconstructions available, using", method)
    return reconstructions[method]


def _initial_clim_range(image):
    """Return a robust ``(low, high)`` colour range for *image*.

    The window is the NaN-ignoring median +/- 3 x (95th - 5th percentile).
    If that is not a finite, non-empty interval (e.g. a constant or all-NaN
    image), fall back to +/- 0.5 around the median (or around 0) so the
    RangeSlider built from it gets distinct limits.
    """
    spread = np.nanpercentile(image, 95) - np.nanpercentile(image, 5)
    # nanmedian (not median): a single NaN pixel would otherwise make the
    # whole window NaN.
    centre = np.nanmedian(image)
    low, high = centre - 3 * spread, centre + 3 * spread
    if np.isfinite(low) and np.isfinite(high) and low < high:
        return low, high
    centre = centre if np.isfinite(centre) else 0.0
    return centre - 0.5, centre + 0.5


def _draw_histogram(ax, image, clim):
    """Histogram the finite pixels of *image* on *ax* and mark *clim* with vertical lines.

    Returns:
        The ``(lower, upper)`` ``axvline`` artists, so they can be moved later.
    """
    values = np.ravel(image)
    # np.histogram cannot auto-range data containing NaN/inf.
    values = values[np.isfinite(values)]
    if values.size:
        ax.hist(values)
    return ax.axvline(clim[0]), ax.axvline(clim[1])


def _show_viewer(recon, extents, scan_name):
    """Show one reconstruction with frame, wavelength and colour-range sliders.

    ``recon`` is indexed as ``recon[frame, wavelength]``: axis 0 drives the
    "Frame" slider and axis 1 the "Wavelength" slider. The remaining axes are
    squeezed for ``imshow``; presumably they reduce to a 2-D image (TODO confirm).
    Blocks in ``plt.show()`` until the window is closed.
    """
    first_image = recon[0, 0]
    range_interest = _initial_clim_range(first_image)

    fig = plt.figure()
    fig.subplots_adjust(bottom=0.3)
    ax, ax2 = fig.subplots(1, 2)
    image_artist = ax.imshow(
        np.squeeze(first_image), extent=extents, clim=range_interest
    )
    # A list so update() can swap in new marker lines after redrawing the histogram.
    markers = list(_draw_histogram(ax2, first_image, range_interest))

    ax_clims = fig.add_axes([0.25, 0.15, 0.5, 0.03])
    ax_slide = fig.add_axes([0.25, 0.11, 0.5, 0.03])
    ax_frame = fig.add_axes([0.25, 0.07, 0.5, 0.03])
    wavelength = Slider(
        ax_slide, "Wavelength", 0, recon.shape[1] - 1, valinit=0, valstep=1
    )
    frame = Slider(ax_frame, "Frame", 0, recon.shape[0] - 1, valinit=0, valstep=1)
    clims = RangeSlider(
        ax_clims,
        "Clim Range",
        range_interest[0],
        range_interest[1],
        valinit=range_interest,
    )

    ax_text = fig.add_axes([0.3, 0.95, 0.4, 0.04])
    ax_text.axis("off")
    ax_text.annotate(
        scan_name,
        (0.5, 0),
        xycoords="axes fraction",
        annotation_clip=False,
        horizontalalignment="center",
        verticalalignment="bottom",
    )

    def update(_):
        # Slider values may come back as floats; numpy indexing needs ints.
        current = recon[int(frame.val), int(wavelength.val)]
        image_artist.set_data(np.squeeze(current))
        # Re-histogram the newly selected image, keeping the markers at the
        # current colour limits.
        ax2.cla()
        markers[:] = _draw_histogram(ax2, current, clims.val)
        fig.canvas.draw_idle()

    def update_clim(val):
        cmin, cmax = val
        markers[0].set_data(([cmin, cmin], [0, 1]))
        markers[1].set_data(([cmax, cmax], [0, 1]))
        image_artist.set_clim(cmin, cmax)
        fig.canvas.draw_idle()

    frame.on_changed(update)
    wavelength.on_changed(update)
    clims.on_changed(update_clim)
    plt.show()


def main():
    """Entry point: open an interactive viewer for each matching scan.

    Every ``*.hdf5`` file in the input folder (ordered by ``sort_key``) is
    loaded with ``PAData.from_hdf5``. Files can be narrowed by file name
    (``--filter``) and by scan name (``--filtername``), and are shown one at
    a time.
    """
    args = init_argparse().parse_args()
    data_folder = args.input
    for file in sorted(glob.glob(join(data_folder, "*.hdf5")), key=sort_key):
        # --filter N only accepts the file named exactly "Scan_N.hdf5".
        if args.filter is not None:
            if split(file)[-1] != "Scan_" + args.filter + ".hdf5":
                continue
        data = PAData.from_hdf5(file)
        scan_name = data.get_scan_name()
        # --filtername is a case-insensitive substring match on the scan name.
        if args.filtername is not None:
            if args.filtername.lower() not in scan_name.lower():
                continue
        print(file)
        recon_data = _select_reconstruction(data, args, file)
        if recon_data is None:
            continue
        extents = recon_data.extent
        recon = recon_data.raw_data
        _show_viewer(recon, extents, scan_name)