Skip to content

Commit f3236b7

Browse files
jhlegarretaoesteban
andcommitted
ENH: Add transform analysis utils
Add transform analysis utils. Transfer contents from the `NiFreeze` projects so that hey can be reused across projects requiring transform analysis: https://github.com/nipreps/nifreeze/blob/d27ba7552bbd9095c3c13b46443d87a4b5504c4c/src/nifreeze/analysis/motion.py https://github.com/nipreps/nifreeze/blob/d27ba7552bbd9095c3c13b46443d87a4b5504c4c/src/nifreeze/data/utils.py Add a fixture to be able to reuse a random number generator across tests. Co-authored-by: Oscar Esteban <code@oscaresteban.es>
1 parent 7845e64 commit f3236b7

File tree

4 files changed

+355
-0
lines changed

4 files changed

+355
-0
lines changed

nitransforms/analysis/__init__.py

Whitespace-only changes.

nitransforms/analysis/utils.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
"""
4+
Utilities to aid in performing and evaluating image registration.
5+
6+
This module provides functions to compute displacements of image coordinates
7+
under a transformation, useful for assessing the accuracy of image registration
8+
processes.
9+
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from itertools import product
15+
from typing import Tuple
16+
17+
import nibabel as nb
18+
import numpy as np
19+
from scipy.stats import zscore
20+
21+
from nitransforms.base import TransformBase
22+
23+
24+
RADIUS = 50.0
25+
"""Typical radius (in mm) of a sphere mimicking the size of a typical human brain."""
26+
27+
28+
def compute_fd_from_motion(motion_parameters: np.ndarray, radius: float = RADIUS) -> np.ndarray:
29+
"""Compute framewise displacement (FD) from motion parameters.
30+
31+
Each row in the motion parameters represents one frame, and columns
32+
represent each coordinate axis ``x``, `y``, and ``z``. Translation
33+
parameters are followed by rotation parameters column-wise.
34+
35+
Parameters
36+
----------
37+
motion_parameters : :obj:`numpy.ndarray`
38+
Motion parameters.
39+
radius : :obj:`float`, optional
40+
Radius (in mm) of a sphere mimicking the size of a typical human brain.
41+
42+
Returns
43+
-------
44+
:obj:`numpy.ndarray`
45+
The framewise displacement (FD) as the sum of absolute differences
46+
between consecutive frames.
47+
"""
48+
49+
translations = motion_parameters[:, :3]
50+
rotations_deg = motion_parameters[:, 3:]
51+
rotations_rad = np.deg2rad(rotations_deg)
52+
53+
# Compute differences between consecutive frames
54+
d_translations = np.vstack([np.zeros((1, 3)), np.diff(translations, axis=0)])
55+
d_rotations = np.vstack([np.zeros((1, 3)), np.diff(rotations_rad, axis=0)])
56+
57+
# Convert rotations from radians to displacement on a sphere
58+
rotation_displacement = d_rotations * radius
59+
60+
# Compute FD as sum of absolute differences
61+
return np.sum(np.abs(d_translations) + np.abs(rotation_displacement), axis=1)
62+
63+
64+
def compute_fd_from_transform(
65+
img: nb.spatialimages.SpatialImage,
66+
test_xfm: TransformBase,
67+
radius: float = RADIUS,
68+
) -> float:
69+
"""
70+
Compute the framewise displacement (FD) for a given transformation.
71+
72+
Parameters
73+
----------
74+
img : :obj:`~nibabel.spatialimages.SpatialImage`
75+
The reference image. Used to extract the center coordinates.
76+
test_xfm : :obj:`~nitransforms.base.TransformBase`
77+
The transformation to test. Applied to coordinates around the image center.
78+
radius : :obj:`float`, optional
79+
The radius (in mm) of the spherical neighborhood around the center of the image.
80+
81+
Returns
82+
-------
83+
:obj:`float`
84+
The average framewise displacement (FD) for the test transformation.
85+
86+
"""
87+
affine = img.affine
88+
# Compute the center of the image in voxel space
89+
center_ijk = 0.5 * (np.array(img.shape[:3]) - 1)
90+
# Convert to world coordinates
91+
center_xyz = nb.affines.apply_affine(affine, center_ijk)
92+
# Generate coordinates of points at radius distance from center
93+
fd_coords = np.array(list(product(*((radius, -radius),) * 3))) + center_xyz
94+
# Compute the average displacement from the test transformation
95+
return np.mean(np.linalg.norm(test_xfm.map(fd_coords) - fd_coords, axis=-1))
96+
97+
98+
def displacements_within_mask(
99+
mask_img: nb.spatialimages.SpatialImage,
100+
test_xfm: TransformBase,
101+
reference_xfm: TransformBase | None = None,
102+
) -> np.ndarray:
103+
"""
104+
Compute the distance between voxel coordinates mapped through two transforms.
105+
106+
Parameters
107+
----------
108+
mask_img : :obj:`~nibabel.spatialimages.SpatialImage`
109+
A mask image that defines the region of interest. Voxel coordinates
110+
within the mask are transformed.
111+
test_xfm : :obj:`~nitransforms.base.TransformBase`
112+
The transformation to test. This transformation is applied to the
113+
voxel coordinates.
114+
reference_xfm : :obj:`~nitransforms.base.TransformBase`, optional
115+
A reference transformation to compare with. If ``None``, the identity
116+
transformation is assumed (no transformation).
117+
118+
Returns
119+
-------
120+
:obj:`~numpy.ndarray`
121+
An array of displacements (in mm) for each voxel within the mask.
122+
123+
"""
124+
# Mask data as boolean (True for voxels inside the mask)
125+
maskdata = np.asanyarray(mask_img.dataobj) > 0
126+
# Convert voxel coordinates to world coordinates using affine transform
127+
xyz = nb.affines.apply_affine(
128+
mask_img.affine,
129+
np.argwhere(maskdata),
130+
)
131+
# Apply the test transformation
132+
targets = test_xfm.map(xyz)
133+
134+
# Compute the difference (displacement) between the test and reference transformations
135+
diffs = targets - xyz if reference_xfm is None else targets - reference_xfm.map(xyz)
136+
return np.linalg.norm(diffs, axis=-1)
137+
138+
139+
def extract_motion_parameters(affine: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
140+
"""Extract translation (mm) and rotation (degrees) parameters from an affine matrix.
141+
142+
Parameters
143+
----------
144+
affine : :obj:`~numpy.ndarray`
145+
The affine transformation matrix.
146+
147+
Returns
148+
-------
149+
:obj:`tuple`
150+
Extracted translation and rotation parameters.
151+
"""
152+
153+
translation = affine[:3, 3]
154+
rotation_rad = np.arctan2(
155+
[affine[2, 1], affine[0, 2], affine[1, 0]], [affine[2, 2], affine[0, 0], affine[1, 1]]
156+
)
157+
rotation_deg = np.rad2deg(rotation_rad)
158+
return *translation, *rotation_deg
159+
160+
161+
def identify_spikes(fd: np.ndarray, threshold: float = 2.0):
162+
"""Identify motion spikes in framewise displacement data.
163+
164+
Identifies high-motion frames as timepoint exceeding a given threshold value
165+
based on z-score normalized framewise displacement (FD) values.
166+
167+
Parameters
168+
----------
169+
fd : :obj:`~numpy.ndarray`
170+
Framewise displacement data.
171+
threshold : :obj:`float`, optional
172+
Threshold value to determine motion spikes.
173+
174+
Returns
175+
-------
176+
indices : :obj:`~numpy.ndarray`
177+
Indices of identified motion spikes.
178+
mask : :obj:`~numpy.ndarray`
179+
Mask of identified motion spikes.
180+
"""
181+
182+
# Normalize (z-score)
183+
fd_norm = zscore(fd)
184+
185+
mask = fd_norm > threshold
186+
indices = np.where(mask)[0]
187+
188+
return indices, mask

nitransforms/tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
4+
import numpy as np
5+
import pytest
6+
7+
8+
@pytest.fixture(autouse=True)
9+
def random_number_generator(request):
10+
"""Automatically set a fixed-seed random number generator for all tests."""
11+
request.node.rng = np.random.default_rng(1234)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
4+
import numpy as np
5+
import nibabel as nb
6+
import pytest
7+
8+
import nitransforms as nt
9+
10+
from nitransforms.analysis.utils import (
11+
compute_fd_from_motion,
12+
compute_fd_from_transform,
13+
displacements_within_mask,
14+
extract_motion_parameters,
15+
identify_spikes,
16+
)
17+
18+
19+
@pytest.fixture
20+
def identity_affine():
21+
return np.eye(4)
22+
23+
24+
@pytest.fixture
25+
def simple_mask_img(identity_affine):
26+
# 3x3x3 mask with center voxel as 1, rest 0
27+
data = np.zeros((3, 3, 3), dtype=np.uint8)
28+
data[1, 1, 1] = 1
29+
return nb.Nifti1Image(data, identity_affine)
30+
31+
32+
@pytest.fixture
33+
def translation_transform():
34+
# Simple translation of (1, 2, 3) mm
35+
return nt.linear.Affine(map=np.array([
36+
[1, 0, 0, 1],
37+
[0, 1, 0, 2],
38+
[0, 0, 1, 3],
39+
[0, 0, 0, 1],
40+
]))
41+
42+
43+
@pytest.fixture
44+
def rotation_transform():
45+
# 90 degree rotation around z axis
46+
angle = np.pi / 2
47+
rot = np.array([
48+
[np.cos(angle), -np.sin(angle), 0, 0],
49+
[np.sin(angle), np.cos(angle), 0, 0],
50+
[0, 0, 1, 0],
51+
[0, 0, 0, 1],
52+
])
53+
return nt.linear.Affine(map=rot)
54+
55+
56+
@pytest.mark.parametrize(
57+
"test_xfm, reference_xfm, expected",
58+
[
59+
(nt.linear.Affine(np.eye(4)), None, np.zeros(1)),
60+
(nt.linear.Affine(np.array([
61+
[1, 0, 0, 1],
62+
[0, 1, 0, 2],
63+
[0, 0, 1, 3],
64+
[0, 0, 0, 1],
65+
])), None, [np.linalg.norm([1, 2, 3])]),
66+
(nt.linear.Affine(np.array([
67+
[1, 0, 0, 1],
68+
[0, 1, 0, 2],
69+
[0, 0, 1, 3],
70+
[0, 0, 0, 1],
71+
])), nt.linear.Affine(np.eye(4)), [np.linalg.norm([1, 2, 3])]),
72+
],
73+
)
74+
def test_displacements_within_mask(simple_mask_img, test_xfm, reference_xfm, expected):
75+
disp = displacements_within_mask(simple_mask_img, test_xfm, reference_xfm)
76+
np.testing.assert_allclose(disp, expected)
77+
78+
79+
@pytest.mark.parametrize(
80+
"test_xfm, expected",
81+
[
82+
(nt.linear.Affine(np.eye(4)), 0),
83+
(nt.linear.Affine(np.array([
84+
[1, 0, 0, 1],
85+
[0, 1, 0, 2],
86+
[0, 0, 1, 3],
87+
[0, 0, 0, 1],
88+
])), np.linalg.norm([1, 2, 3])),
89+
],
90+
)
91+
def test_compute_fd_from_transform(simple_mask_img, test_xfm, expected):
92+
fd = compute_fd_from_transform(simple_mask_img, test_xfm)
93+
assert np.isclose(fd, expected)
94+
95+
96+
@pytest.mark.parametrize(
97+
"motion_params, radius, expected",
98+
[
99+
(np.zeros((5, 6)), 50, np.zeros(5)), # 5 frames, 3 trans, 3 rot
100+
(
101+
np.array([
102+
[0,0,0,0,0,0],
103+
[2,0,0,0,0,0], # 2mm translation in x at frame 1
104+
[2,0,0,90,0,0],
105+
]), # 90deg rotation in x at frame 2
106+
50,
107+
[0, 2, abs(np.deg2rad(90)) * 50]
108+
), # First frame: 0, Second: translation 2mm, Third: rotation (pi/2)*50
109+
],
110+
)
111+
def test_compute_fd_from_motion(motion_params, radius, expected):
112+
fd = compute_fd_from_motion(motion_params, radius=radius)
113+
np.testing.assert_allclose(fd, expected, atol=1e-4)
114+
115+
116+
@pytest.mark.parametrize(
117+
"affine, expected_trans, expected_rot",
118+
[
119+
(np.eye(4) + np.array([[0,0,0,10],[0,0,0,15],[0,0,0,20],[0,0,0,0]]), # translation only
120+
[10, 15, 20], [0, 0, 0]),
121+
(np.array([
122+
[1, 0, 0, 0],
123+
[0, np.cos(np.deg2rad(30)), -np.sin(np.deg2rad(30)), 0],
124+
[0, np.sin(np.deg2rad(30)), np.cos(np.deg2rad(30)), 0],
125+
[0, 0, 0, 1], # rotation only
126+
]), [0, 0, 0], [30, 0, 0]), # Only one rot will be close to 30
127+
],
128+
)
129+
def test_extract_motion_parameters(affine, expected_trans, expected_rot):
130+
params = extract_motion_parameters(affine)
131+
assert np.allclose(params[:3], expected_trans)
132+
# For rotation case, at least one value close to 30
133+
if np.any(np.abs(expected_rot)):
134+
assert np.any(np.isclose(np.abs(params[3:]), 30, atol=1e-4))
135+
else:
136+
assert np.allclose(params[3:], expected_rot)
137+
138+
139+
def test_identify_spikes(request):
140+
rng = request.node.rng
141+
142+
n_samples = 450
143+
144+
fd = rng.normal(0, 5, n_samples)
145+
threshold = 2.0
146+
147+
expected_indices = np.asarray(
148+
[5, 57, 85, 100, 127, 180, 191, 202, 335, 393, 409]
149+
)
150+
expected_mask = np.zeros(n_samples, dtype=bool)
151+
expected_mask[expected_indices] = True
152+
153+
obtained_indices, obtained_mask = identify_spikes(fd, threshold=threshold)
154+
155+
assert np.array_equal(obtained_indices, expected_indices)
156+
assert np.array_equal(obtained_mask, expected_mask)

0 commit comments

Comments
 (0)