Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 224 additions & 4 deletions src/imperandi/qc/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.path import Path as MplPath
from matplotlib.widgets import EllipseSelector, LassoSelector, RectangleSelector
import ipywidgets as widgets
from IPython.display import clear_output, display

Expand Down Expand Up @@ -51,8 +53,8 @@
JUMP_NAV_BUTTON_WIDTH = "120px"


def load_nifti(file_path, orientation="LAS"):
"""Load a NIfTI file and return the image data oriented in RAS+."""
def load_nifti(file_path, orientation="LAS", with_affine=False):
"""Load a NIfTI file and return oriented data and optionally oriented affine."""
img = nib.load(Path(file_path).resolve())
data = img.get_fdata()
affine = img.affine
Expand All @@ -62,7 +64,13 @@ def load_nifti(file_path, orientation="LAS"):
elif orientation == "LAS":
new_ornt = np.array([[0, -1], [1, 1], [2, 1]])
transform = nib.orientations.ornt_transform(current_ornt, new_ornt)
return nib.orientations.apply_orientation(data, transform)
oriented_data = nib.orientations.apply_orientation(data, transform)
if not with_affine:
return oriented_data

inv_affine = nib.orientations.inv_ornt_aff(transform, data.shape)
oriented_affine = affine @ inv_affine
return oriented_data, oriented_affine


def clip_hu_values(ct_scan, min_hu, max_hu):
Expand Down Expand Up @@ -138,6 +146,9 @@ def __init__(
self.canvas_size_px = DISPLAY_CANVAS_PX
self.figure_dpi = FIGURE_DPI
self.image_aspect = "auto"
self.ct_affine = np.eye(4)
self.annotation_selector = None
self.annotations_current_scan = []

if self.exploration_mode == "random":
self.explored_history = [self.current_index]
Expand Down Expand Up @@ -248,6 +259,20 @@ def init_widgets(self):
)

self.info_display = widgets.HTML(value="")
self.annotation_summary = widgets.HTML(value="<i>No annotations</i>")
self.annotation_mode = widgets.Dropdown(
options=[
("Bounding box", "bounding_box"),
("Circle", "circle"),
("Freehand", "freehand"),
],
value="bounding_box",
description="Tool",
layout=widgets.Layout(width="100%", min_width="0px"),
)
self.annotation_mode.observe(self.on_annotation_mode_change, names="value")
self.clear_annotations_button = widgets.Button(description="Clear all")
self.clear_annotations_button.on_click(self.on_clear_annotations)

if self.segmentation_cols:
for seg_name in self.segmentation_cols:
Expand Down Expand Up @@ -435,6 +460,15 @@ def init_widgets(self):
[widgets.HTML("<b>Largest Surface Slice</b>"), center_button_and_dropdown],
layout=group_layout,
)
annotation_group = widgets.VBox(
[
widgets.HTML("<b>Annotations (class: tumor)</b>"),
self.annotation_mode,
self.clear_annotations_button,
self.annotation_summary,
],
layout=group_layout,
)
progress_group = widgets.VBox(
[self.progress_bar],
layout=widgets.Layout(
Expand All @@ -454,6 +488,7 @@ def init_widgets(self):
right_items = [
progress_group,
center_group,
annotation_group,
overlay_group,
window_group,
self.info_container,
Expand Down Expand Up @@ -836,6 +871,181 @@ def on_key_press(self, event):
elif key in {"right", "down"}:
self.on_next_slice_manual(None)

def on_annotation_mode_change(self, change):
self._activate_annotation_selector()

def on_clear_annotations(self, button):
self.annotations_current_scan = []
self._persist_annotations_to_df()
self._refresh_annotation_summary()
self.update_display()

def _activate_annotation_selector(self):
if self.ax is None:
return
if self.annotation_selector is not None:
self.annotation_selector.set_active(False)
self.annotation_selector = None

mode = self.annotation_mode.value
if mode == "bounding_box":
self.annotation_selector = RectangleSelector(
self.ax,
self._on_rectangle_annotation,
useblit=False,
button=[1],
interactive=False,
)
elif mode == "circle":
self.annotation_selector = EllipseSelector(
self.ax,
self._on_ellipse_annotation,
useblit=False,
button=[1],
interactive=False,
)
else:
self.annotation_selector = LassoSelector(
self.ax,
self._on_freehand_annotation,
useblit=False,
button=[1],
)

def _display_to_voxel(self, x_display, y_display, slice_idx=None):
if slice_idx is None:
slice_idx = int(self.slice_slider.value)
xi = int(round(x_display))
yi = int(round(y_display))

x_max, y_max, z_max = self.ct_scan_raw.shape
if self.view_plane == "axial":
voxel = np.array([yi, xi, slice_idx], dtype=float)
elif self.view_plane == "sagittal":
voxel = np.array([slice_idx, yi, xi], dtype=float)
else:
voxel = np.array([yi, slice_idx, xi], dtype=float)

voxel[0] = np.clip(voxel[0], 0, x_max - 1)
voxel[1] = np.clip(voxel[1], 0, y_max - 1)
voxel[2] = np.clip(voxel[2], 0, z_max - 1)
return voxel

def _voxel_to_ct_coordinate(self, voxel_xyz):
return nib.affines.apply_affine(self.ct_affine, voxel_xyz)

def _sample_ellipse(self, x1, y1, x2, y2, num_samples=32):
cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
rx = abs(x2 - x1) / 2.0
ry = abs(y2 - y1) / 2.0
theta = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
return [(cx + rx * np.cos(t), cy + ry * np.sin(t)) for t in theta]

def _record_annotation(self, shape_type, display_points):
if not display_points:
return
slice_idx = int(self.slice_slider.value)
voxel_points = [self._display_to_voxel(x, y, slice_idx) for x, y in display_points]
ct_points = [self._voxel_to_ct_coordinate(p).tolist() for p in voxel_points]

annotation = {
"label": "tumor",
"shape": shape_type,
"plane": self.view_plane,
"slice_idx": slice_idx,
"display_points": [[float(x), float(y)] for x, y in display_points],
"voxel_points": [p.tolist() for p in voxel_points],
"ct_points": ct_points,
}
self.annotations_current_scan.append(annotation)
self._persist_annotations_to_df()
self._refresh_annotation_summary()
self.update_display()

def _on_rectangle_annotation(self, eclick, erelease):
if eclick.xdata is None or eclick.ydata is None:
return
if erelease.xdata is None or erelease.ydata is None:
return
x1, y1 = float(eclick.xdata), float(eclick.ydata)
x2, y2 = float(erelease.xdata), float(erelease.ydata)
self._record_annotation("bounding_box", [(x1, y1), (x2, y2)])

def _on_ellipse_annotation(self, eclick, erelease):
if eclick.xdata is None or eclick.ydata is None:
return
if erelease.xdata is None or erelease.ydata is None:
return
x1, y1 = float(eclick.xdata), float(eclick.ydata)
x2, y2 = float(erelease.xdata), float(erelease.ydata)
self._record_annotation("circle", self._sample_ellipse(x1, y1, x2, y2))

def _on_freehand_annotation(self, verts):
if not verts:
return
self._record_annotation("freehand", [(float(x), float(y)) for x, y in verts])

def _persist_annotations_to_df(self):
if "annotations" not in self.df.columns:
self.df["annotations"] = None
self.df.at[self.current_index, "annotations"] = list(self.annotations_current_scan)

def _refresh_annotation_summary(self):
if not self.annotations_current_scan:
self.annotation_summary.value = "<i>No annotations</i>"
return
counts = {}
for item in self.annotations_current_scan:
shape = item.get("shape", "unknown")
counts[shape] = counts.get(shape, 0) + 1
parts = [f"{shape}: {count}" for shape, count in sorted(counts.items())]
self.annotation_summary.value = (
f"<b>{len(self.annotations_current_scan)}</b> annotation(s)<br>" + ", ".join(parts)
)

def _draw_annotations_overlay(self):
if not self.annotations_current_scan:
return
slice_idx = int(self.slice_slider.value)
for annotation in self.annotations_current_scan:
if annotation.get("plane") != self.view_plane:
continue
if int(annotation.get("slice_idx", -1)) != slice_idx:
continue
pts = annotation.get("display_points", [])
if len(pts) < 2:
continue
if annotation.get("shape") == "bounding_box":
(x1, y1), (x2, y2) = pts[0], pts[1]
rect = mpatches.Rectangle(
(min(x1, x2), min(y1, y2)),
abs(x2 - x1),
abs(y2 - y1),
fill=False,
edgecolor="yellow",
linewidth=1.8,
)
self.ax.add_patch(rect)
elif annotation.get("shape") == "circle":
path = MplPath(np.asarray(pts, dtype=float))
patch = mpatches.PathPatch(
path,
fill=False,
edgecolor="yellow",
linewidth=1.8,
)
self.ax.add_patch(patch)
else:
poly = mpatches.Polygon(
np.asarray(pts, dtype=float),
closed=True,
fill=False,
edgecolor="yellow",
linewidth=1.8,
)
self.ax.add_patch(poly)

def load_data(self):
self.progress_bar.layout.visibility = "visible"
self.progress_bar.value = 0
Expand All @@ -844,7 +1054,15 @@ def load_data(self):

row = self.df.iloc[self.current_index]
self.progress_bar.value = 0.1
self.ct_scan_raw = load_nifti(row[self.ct_scan_col])
self.ct_scan_raw, self.ct_affine = load_nifti(
row[self.ct_scan_col], with_affine=True
)
existing_annotations = row.get("annotations", None)
if isinstance(existing_annotations, list):
self.annotations_current_scan = list(existing_annotations)
else:
self.annotations_current_scan = []
self._refresh_annotation_summary()

self.segmentations = {}
if self.segmentation_cols:
Expand Down Expand Up @@ -895,6 +1113,7 @@ def update_slice_slider(self):
self.slice_slider.value = min(self.slice_idx, self.slice_slider.max)
self.slice_slider.observe(self.on_slice_change, names="value")
self.update_display()
self._activate_annotation_selector()

def update_display(self, *_):
if self.ct_scan_raw is None or self.ax is None:
Expand Down Expand Up @@ -973,6 +1192,7 @@ def update_display(self, *_):
framealpha=0.6,
)

self._draw_annotations_overlay()
self.ax.axis("off")
if self._uses_output_fallback:
self._render_output_figure()
Expand Down
47 changes: 47 additions & 0 deletions tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pytest
import ipywidgets as widgets
import numpy as np
import pandas as pd

# Ensure src/ is on sys.path for imports
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
Expand Down Expand Up @@ -67,3 +69,48 @@ def test_exam_nav_buttons_disabled_when_single_exam():
assert viewer.next_date_button.disabled is True
assert viewer.prev_patient_button.disabled is True
assert viewer.next_patient_button.disabled is True


def test_display_to_voxel_and_ct_coordinates_axial():
viewer = CTScanViewer.__new__(CTScanViewer)
viewer.view_plane = "axial"
viewer.ct_scan_raw = np.zeros((8, 9, 10))
viewer.slice_slider = widgets.IntSlider(value=4)
viewer.ct_affine = np.array(
[
[2.0, 0.0, 0.0, 10.0],
[0.0, 3.0, 0.0, 20.0],
[0.0, 0.0, 4.0, 30.0],
[0.0, 0.0, 0.0, 1.0],
]
)

voxel = viewer._display_to_voxel(6.0, 5.0)
assert voxel.tolist() == [5.0, 6.0, 4.0]

coord = viewer._voxel_to_ct_coordinate(voxel)
assert coord.tolist() == [20.0, 38.0, 46.0]


def test_record_annotation_persists_in_dataframe():
viewer = CTScanViewer.__new__(CTScanViewer)
viewer.view_plane = "axial"
viewer.ct_scan_raw = np.zeros((10, 10, 10))
viewer.slice_slider = widgets.IntSlider(value=3)
viewer.ct_affine = np.eye(4)
viewer.current_index = 0
viewer.df = pd.DataFrame([{"patient_key": "p1"}])
viewer.annotations_current_scan = []
viewer.annotation_summary = widgets.HTML(value="")
viewer.update_display = lambda *_: None

viewer._record_annotation("bounding_box", [(1.0, 2.0), (4.0, 5.0)])

stored = viewer.df.at[0, "annotations"]
assert isinstance(stored, list)
assert len(stored) == 1
ann = stored[0]
assert ann["label"] == "tumor"
assert ann["shape"] == "bounding_box"
assert ann["slice_idx"] == 3
assert ann["voxel_points"][0] == [2.0, 1.0, 3.0]