Skip to content

Commit b99ea5a

Browse files
committed
ENH: Add transform analysis utils
Add transform analysis utils. Transfer contents from the `NiFreeze` projects so that hey can be reused across projects requiring transform analysis: https://github.com/nipreps/nifreeze/blob/d27ba7552bbd9095c3c13b46443d87a4b5504c4c/src/nifreeze/analysis/motion.py https://github.com/nipreps/nifreeze/blob/d27ba7552bbd9095c3c13b46443d87a4b5504c4c/src/nifreeze/data/utils.py Add a fixture to be able to reuse a random number generator across tests.
1 parent 7845e64 commit b99ea5a

File tree

4 files changed

+384
-0
lines changed

4 files changed

+384
-0
lines changed

nitransforms/analysis/__init__.py

Whitespace-only changes.

nitransforms/analysis/utils.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
"""
4+
Utilities to aid in performing and evaluating image registration.
5+
6+
This module provides functions to compute displacements of image coordinates
7+
under a transformation, useful for assessing the accuracy of image registration
8+
processes.
9+
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from itertools import product
15+
from typing import Tuple
16+
17+
import nibabel as nb
18+
import numpy as np
19+
from scipy.stats import zscore
20+
21+
from nitransforms.base import TransformBase
22+
23+
24+
RADIUS = 50.0
25+
"""Typical radius (in mm) of a sphere mimicking the size of a typical human brain."""
26+
27+
28+
def compute_fd_from_motion(motion_parameters: np.ndarray, radius: float = RADIUS) -> np.ndarray:
29+
"""Compute framewise displacement (FD) from motion parameters.
30+
31+
Each row in the motion parameters represents one frame, and columns
32+
represent each coordinate axis ``x``, `y``, and ``z``. Translation
33+
parameters are followed by rotation parameters column-wise.
34+
35+
Parameters
36+
----------
37+
motion_parameters : :obj:`numpy.ndarray`
38+
Motion parameters.
39+
radius : :obj:`float`, optional
40+
Radius (in mm) of a sphere mimicking the size of a typical human brain.
41+
42+
Returns
43+
-------
44+
:obj:`numpy.ndarray`
45+
The framewise displacement (FD) as the sum of absolute differences
46+
between consecutive frames.
47+
"""
48+
49+
translations = motion_parameters[:, :3]
50+
rotations_deg = motion_parameters[:, 3:]
51+
rotations_rad = np.deg2rad(rotations_deg)
52+
53+
# Compute differences between consecutive frames
54+
d_translations = np.vstack([np.zeros((1, 3)), np.diff(translations, axis=0)])
55+
d_rotations = np.vstack([np.zeros((1, 3)), np.diff(rotations_rad, axis=0)])
56+
57+
# Convert rotations from radians to displacement on a sphere
58+
rotation_displacement = d_rotations * radius
59+
60+
# Compute FD as sum of absolute differences
61+
return np.sum(np.abs(d_translations) + np.abs(rotation_displacement), axis=1)
62+
63+
64+
def compute_fd_from_transform(
65+
img: nb.spatialimages.SpatialImage,
66+
test_xfm: TransformBase,
67+
radius: float = RADIUS,
68+
) -> float:
69+
"""
70+
Compute the framewise displacement (FD) for a given transformation.
71+
72+
Parameters
73+
----------
74+
img : :obj:`~nibabel.spatialimages.SpatialImage`
75+
The reference image. Used to extract the center coordinates.
76+
test_xfm : :obj:`~nitransforms.base.TransformBase`
77+
The transformation to test. Applied to coordinates around the image center.
78+
radius : :obj:`float`, optional
79+
The radius (in mm) of the spherical neighborhood around the center of the image.
80+
81+
Returns
82+
-------
83+
:obj:`float`
84+
The average framewise displacement (FD) for the test transformation.
85+
86+
"""
87+
affine = img.affine
88+
# Compute the center of the image in voxel space
89+
center_ijk = 0.5 * (np.array(img.shape[:3]) - 1)
90+
# Convert to world coordinates
91+
center_xyz = nb.affines.apply_affine(affine, center_ijk)
92+
# Generate coordinates of points at radius distance from center
93+
fd_coords = np.array(list(product(*((radius, -radius),) * 3))) + center_xyz
94+
# Compute the average displacement from the test transformation
95+
return np.mean(np.linalg.norm(test_xfm.map(fd_coords) - fd_coords, axis=-1))
96+
97+
98+
def displacements_within_mask(
99+
mask_img: nb.spatialimages.SpatialImage,
100+
test_xfm: TransformBase,
101+
reference_xfm: TransformBase | None = None,
102+
) -> np.ndarray:
103+
"""
104+
Compute the distance between voxel coordinates mapped through two transforms.
105+
106+
Parameters
107+
----------
108+
mask_img : :obj:`~nibabel.spatialimages.SpatialImage`
109+
A mask image that defines the region of interest. Voxel coordinates
110+
within the mask are transformed.
111+
test_xfm : :obj:`~nitransforms.base.TransformBase`
112+
The transformation to test. This transformation is applied to the
113+
voxel coordinates.
114+
reference_xfm : :obj:`~nitransforms.base.TransformBase`, optional
115+
A reference transformation to compare with. If ``None``, the identity
116+
transformation is assumed (no transformation).
117+
118+
Returns
119+
-------
120+
:obj:`~numpy.ndarray`
121+
An array of displacements (in mm) for each voxel within the mask.
122+
123+
"""
124+
# Mask data as boolean (True for voxels inside the mask)
125+
maskdata = np.asanyarray(mask_img.dataobj) > 0
126+
# Convert voxel coordinates to world coordinates using affine transform
127+
xyz = nb.affines.apply_affine(
128+
mask_img.affine,
129+
np.argwhere(maskdata),
130+
)
131+
# Apply the test transformation
132+
targets = test_xfm.map(xyz)
133+
134+
# Compute the difference (displacement) between the test and reference transformations
135+
diffs = targets - xyz if reference_xfm is None else targets - reference_xfm.map(xyz)
136+
return np.linalg.norm(diffs, axis=-1)
137+
138+
139+
def extract_motion_parameters(affine: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
140+
"""Extract translation (mm) and rotation (degrees) parameters from an affine matrix.
141+
142+
Parameters
143+
----------
144+
affine : :obj:`~numpy.ndarray`
145+
The affine transformation matrix.
146+
147+
Returns
148+
-------
149+
:obj:`tuple`
150+
Extracted translation and rotation parameters.
151+
"""
152+
153+
translation = affine[:3, 3]
154+
rotation_rad = np.arctan2(
155+
[affine[2, 1], affine[0, 2], affine[1, 0]], [affine[2, 2], affine[0, 0], affine[1, 1]]
156+
)
157+
rotation_deg = np.rad2deg(rotation_rad)
158+
return *translation, *rotation_deg
159+
160+
161+
def identify_spikes(fd: np.ndarray, threshold: float = 2.0):
162+
"""Identify motion spikes in framewise displacement data.
163+
164+
Identifies high-motion frames as timepoint exceeding a given threshold value
165+
based on z-score normalized framewise displacement (FD) values.
166+
167+
Parameters
168+
----------
169+
fd : :obj:`~numpy.ndarray`
170+
Framewise displacement data.
171+
threshold : :obj:`float`, optional
172+
Threshold value to determine motion spikes.
173+
174+
Returns
175+
-------
176+
indices : :obj:`~numpy.ndarray`
177+
Indices of identified motion spikes.
178+
mask : :obj:`~numpy.ndarray`
179+
Mask of identified motion spikes.
180+
"""
181+
182+
# Normalize (z-score)
183+
fd_norm = zscore(fd)
184+
185+
mask = fd_norm > threshold
186+
indices = np.where(mask)[0]
187+
188+
return indices, mask

nitransforms/tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
4+
import numpy as np
5+
import pytest
6+
7+
8+
@pytest.fixture(autouse=True)
9+
def random_number_generator(request):
10+
"""Automatically set a fixed-seed random number generator for all tests."""
11+
request.node.rng = np.random.default_rng(1234)
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
4+
import numpy as np
5+
import nibabel as nb
6+
import pytest
7+
8+
import nitransforms as nt
9+
10+
from nitransforms.analysis.utils import (
11+
compute_fd_from_motion,
12+
compute_fd_from_transform,
13+
compute_percentage_change,
14+
displacements_within_mask,
15+
extract_motion_parameters,
16+
identify_spikes,
17+
)
18+
19+
20+
@pytest.fixture
21+
def identity_affine():
22+
return np.eye(4)
23+
24+
25+
@pytest.fixture
26+
def simple_mask_img(identity_affine):
27+
# 3x3x3 mask with center voxel as 1, rest 0
28+
data = np.zeros((3, 3, 3), dtype=np.uint8)
29+
data[1, 1, 1] = 1
30+
return nb.Nifti1Image(data, identity_affine)
31+
32+
33+
@pytest.fixture
34+
def translation_transform():
35+
# Simple translation of (1, 2, 3) mm
36+
return nt.linear.Affine(map=np.array([
37+
[1, 0, 0, 1],
38+
[0, 1, 0, 2],
39+
[0, 0, 1, 3],
40+
[0, 0, 0, 1],
41+
]))
42+
43+
44+
@pytest.fixture
45+
def rotation_transform():
46+
# 90 degree rotation around z axis
47+
angle = np.pi / 2
48+
rot = np.array([
49+
[np.cos(angle), -np.sin(angle), 0, 0],
50+
[np.sin(angle), np.cos(angle), 0, 0],
51+
[0, 0, 1, 0],
52+
[0, 0, 0, 1],
53+
])
54+
return nt.linear.Affine(map=rot)
55+
56+
57+
@pytest.mark.parametrize(
58+
"test_xfm, reference_xfm, expected",
59+
[
60+
(nt.linear.Affine(np.eye(4)), None, np.zeros(1)),
61+
(nt.linear.Affine(np.array([
62+
[1, 0, 0, 1],
63+
[0, 1, 0, 2],
64+
[0, 0, 1, 3],
65+
[0, 0, 0, 1],
66+
])), None, [np.linalg.norm([1, 2, 3])]),
67+
(nt.linear.Affine(np.array([
68+
[1, 0, 0, 1],
69+
[0, 1, 0, 2],
70+
[0, 0, 1, 3],
71+
[0, 0, 0, 1],
72+
])), nt.linear.Affine(np.eye(4)), [np.linalg.norm([1, 2, 3])]),
73+
],
74+
)
75+
def test_displacements_within_mask(simple_mask_img, test_xfm, reference_xfm, expected):
76+
disp = displacements_within_mask(simple_mask_img, test_xfm, reference_xfm)
77+
np.testing.assert_allclose(disp, expected)
78+
79+
80+
@pytest.mark.parametrize(
81+
"test_xfm, expected",
82+
[
83+
(nt.linear.Affine(np.eye(4)), 0),
84+
(nt.linear.Affine(np.array([
85+
[1, 0, 0, 1],
86+
[0, 1, 0, 2],
87+
[0, 0, 1, 3],
88+
[0, 0, 0, 1],
89+
])), np.linalg.norm([1, 2, 3])),
90+
],
91+
)
92+
def test_compute_fd_from_transform(simple_mask_img, test_xfm, expected):
93+
fd = compute_fd_from_transform(simple_mask_img, test_xfm)
94+
assert np.isclose(fd, expected)
95+
96+
97+
@pytest.mark.parametrize(
98+
"motion_params, radius, expected",
99+
[
100+
(np.zeros((5, 6)), 50, np.zeros(5)), # 5 frames, 3 trans, 3 rot
101+
(
102+
np.array([
103+
[0,0,0,0,0,0],
104+
[2,0,0,0,0,0], # 2mm translation in x at frame 1
105+
[2,0,0,90,0,0],
106+
]), # 90deg rotation in x at frame 2
107+
50,
108+
[0, 2, abs(np.deg2rad(90)) * 50]
109+
), # First frame: 0, Second: translation 2mm, Third: rotation (pi/2)*50
110+
],
111+
)
112+
def test_compute_fd_from_motion(motion_params, radius, expected):
113+
fd = compute_fd_from_motion(motion_params, radius=radius)
114+
np.testing.assert_allclose(fd, expected, atol=1e-4)
115+
116+
117+
@pytest.mark.parametrize(
118+
"affine, expected_trans, expected_rot",
119+
[
120+
(np.eye(4) + np.array([[0,0,0,10],[0,0,0,15],[0,0,0,20],[0,0,0,0]]), # translation only
121+
[10, 15, 20], [0, 0, 0]),
122+
(np.array([
123+
[1, 0, 0, 0],
124+
[0, np.cos(np.deg2rad(30)), -np.sin(np.deg2rad(30)), 0],
125+
[0, np.sin(np.deg2rad(30)), np.cos(np.deg2rad(30)), 0],
126+
[0, 0, 0, 1], # rotation only
127+
]), [0, 0, 0], [30, 0, 0]), # Only one rot will be close to 30
128+
],
129+
)
130+
def test_extract_motion_parameters(affine, expected_trans, expected_rot):
131+
params = extract_motion_parameters(affine)
132+
assert np.allclose(params[:3], expected_trans)
133+
# For rotation case, at least one value close to 30
134+
if np.any(np.abs(expected_rot)):
135+
assert np.any(np.isclose(np.abs(params[3:]), 30, atol=1e-4))
136+
else:
137+
assert np.allclose(params[3:], expected_rot)
138+
139+
140+
@pytest.mark.parametrize(
141+
"reference, test, mask, expected",
142+
[
143+
(
144+
np.array([[1.0, 2.0], [0.0, 4.0]]),
145+
np.array([[2.0, 1.0], [3.0, 8.0]]),
146+
np.array([[True, True], [True, True]]),
147+
np.array([[(2.0 - 1.0) / 1.0, (1.0 - 2.0) / 2.0], [0, (8.0 - 4.0) / 4.0]]) * 100,
148+
),
149+
(
150+
np.zeros((2,2)),
151+
np.ones((2,2)),
152+
np.ones((2,2), dtype=bool),
153+
np.zeros((2,2)),
154+
),
155+
(
156+
np.array([[5, 10], [15, 20]]),
157+
np.array([[10, 5], [30, 10]]),
158+
np.array([[False, True], [True, False]]),
159+
np.array([[0, (5 - 10) / 10], [(30 - 15) / 15, 0]]) * 100,
160+
),
161+
],
162+
)
163+
def test_compute_percentage_change_param(reference, test, mask, expected):
164+
result = compute_percentage_change(reference, test, mask)
165+
np.testing.assert_array_almost_equal(result, expected)
166+
167+
168+
def test_identify_spikes(request):
169+
rng = request.node.rng
170+
171+
n_samples = 450
172+
173+
fd = rng.normal(0, 5, n_samples)
174+
threshold = 2.0
175+
176+
expected_indices = np.asarray(
177+
[5, 57, 85, 100, 127, 180, 191, 202, 335, 393, 409]
178+
)
179+
expected_mask = np.zeros(n_samples, dtype=bool)
180+
expected_mask[expected_indices] = True
181+
182+
obtained_indices, obtained_mask = identify_spikes(fd, threshold=threshold)
183+
184+
assert np.array_equal(obtained_indices, expected_indices)
185+
assert np.array_equal(obtained_mask, expected_mask)

0 commit comments

Comments
 (0)