Skip to content

Commit d9393f3

Browse files
authored
Merge branch 'main' into add_robust_coreg
2 parents 9e7561f + 153cd71 commit d9393f3

File tree

6 files changed

+232
-15
lines changed

6 files changed

+232
-15
lines changed

docs/usage.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,15 @@ manual) fixed during robust template estimation to improve reproducibility.
201201
Iterations are automatically disabled to reduce runtime when :option:`--hmc-init-frame-fix` is
202202
used.
203203

204+
When motion correction is undesirable, use :option:`--hmc-off` to disable head motion
205+
correction entirely and keep the data unmodified apart from downstream
206+
processing steps.
207+
204208
Examples: ::
205209

206210
$ petprep /data/bids_root /out participant --hmc-fwhm 8 --hmc-start-time 60
207211
$ petprep /data/bids_root /out participant --hmc-init-frame 10 --hmc-init-frame-fix
212+
$ petprep /data/bids_root /out participant --hmc-off
208213

209214
Anatomical co-registration
210215
--------------------------

petprep/cli/parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,12 @@ def _bids_filter(value, parser):
569569
action='store_true',
570570
help=('Keep the chosen initial reference frame fixed during head-motion estimation.'),
571571
)
572+
g_hmc.add_argument(
573+
'--hmc-off',
574+
dest='hmc_off',
575+
action='store_true',
576+
help='Disable head-motion correction and use the uncorrected data.',
577+
)
572578

573579
g_seg = parser.add_argument_group('Segmentation options')
574580
g_seg.add_argument(

petprep/cli/tests/test_parser.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,19 @@ def test_hmc_init_frame_parsing(tmp_path):
366366
opts = parser.parse_args(base_args + ['--hmc-init-frame', '3', '--hmc-init-frame-fix'])
367367
assert opts.hmc_init_frame == 3
368368
assert opts.hmc_fix_frame is True
369+
370+
371+
def test_hmc_off_flag(tmp_path):
372+
"""Ensure disabling motion correction is parsed correctly."""
373+
datapath = tmp_path / 'data'
374+
outpath = tmp_path / 'out'
375+
datapath.mkdir()
376+
377+
parser = _build_parser()
378+
base_args = [str(datapath), str(outpath), 'participant']
379+
380+
opts = parser.parse_args(base_args)
381+
assert opts.hmc_off is False
382+
383+
opts = parser.parse_args(base_args + ['--hmc-off'])
384+
assert opts.hmc_off is True

petprep/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,8 @@ class workflow(_Config):
604604
"""Index of initial frame for head-motion estimation ('auto' selects highest uptake)."""
605605
hmc_fix_frame: bool = False
606606
"""Whether to fix the reference frame during head-motion estimation."""
607+
hmc_off: bool = False
608+
"""Disable head-motion correction and keep data uncorrected."""
607609
seg = 'gtm'
608610
"""Segmentation approach ('gtm', 'brainstem', 'thalamicNuclei',
609611
'hippocampusAmygdala', 'wm', 'raphe', 'limbic')."""

petprep/workflows/pet/fit.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
#
2121
# https://www.nipreps.org/community/licensing/
2222
#
23+
from collections.abc import Sequence
2324
from pathlib import Path
2425

2526
import nibabel as nb
27+
import numpy as np
2628
from nipype.interfaces import utility as niu
2729
from nipype.pipeline import engine as pe
28-
from nitransforms.linear import Affine
30+
from nitransforms.linear import Affine, LinearTransformsMapping
2931
from niworkflows.interfaces.header import ValidateImage
3032
from niworkflows.utils.connections import listify
3133

@@ -53,6 +55,70 @@
5355
from .registration import init_pet_reg_wf
5456

5557

58+
def _extract_twa_image(
59+
pet_file: str,
60+
output_dir: Path,
61+
frame_start_times: Sequence[float] | None,
62+
frame_durations: Sequence[float] | None,
63+
) -> str:
64+
"""Return a time-weighted average (twa) reference image from a 4D PET series."""
65+
66+
output_dir.mkdir(parents=True, exist_ok=True)
67+
img = nb.load(pet_file)
68+
if img.ndim < 4 or img.shape[-1] == 1:
69+
return pet_file
70+
71+
if frame_start_times is None or frame_durations is None:
72+
raise ValueError(
73+
'Frame timing metadata are required to compute a time-weighted reference image.'
74+
)
75+
76+
frame_start_times = np.asarray(frame_start_times, dtype=float)
77+
frame_durations = np.asarray(frame_durations, dtype=float)
78+
79+
if frame_start_times.ndim != 1 or frame_durations.ndim != 1:
80+
raise ValueError('Frame timing metadata must be one-dimensional sequences.')
81+
82+
if len(frame_start_times) != len(frame_durations):
83+
raise ValueError('FrameTimesStart and FrameDuration must have the same length.')
84+
85+
if len(frame_durations) != img.shape[-1]:
86+
raise ValueError(
87+
'Frame timing metadata must match the number of frames in the PET series.'
88+
)
89+
90+
if np.any(frame_durations <= 0):
91+
raise ValueError('FrameDuration values must all be positive.')
92+
93+
if np.any(np.diff(frame_start_times) < 0):
94+
raise ValueError('FrameTimesStart values must be non-decreasing.')
95+
96+
hdr = img.header.copy()
97+
data = np.asanyarray(img.dataobj)
98+
weighted_average = np.average(data, axis=-1, weights=frame_durations).astype(np.float32)
99+
hdr.set_data_shape(weighted_average.shape)
100+
101+
pet_path = Path(pet_file)
102+
# Drop all suffixes (e.g., `.nii.gz`) before appending the reference label
103+
pet_stem = pet_path
104+
while pet_stem.suffix:
105+
pet_stem = pet_stem.with_suffix('')
106+
107+
out_file = output_dir / f'{pet_stem.name}_timeavgref.nii.gz'
108+
img.__class__(weighted_average, img.affine, hdr).to_filename(out_file)
109+
return str(out_file)
110+
111+
112+
def _write_identity_xforms(num_frames: int, filename: Path) -> Path:
113+
"""Write ``num_frames`` identity transforms to ``filename``."""
114+
115+
filename = Path(filename)
116+
filename.parent.mkdir(parents=True, exist_ok=True)
117+
n_xforms = max(int(num_frames or 0), 1)
118+
LinearTransformsMapping([Affine() for _ in range(n_xforms)]).to_filename(filename, fmt='itk')
119+
return filename
120+
121+
56122
def init_pet_fit_wf(
57123
*,
58124
pet_series: list[str],
@@ -158,6 +224,13 @@ def init_pet_fit_wf(
158224
if (petref is None) ^ (hmc_xforms is None):
159225
raise ValueError("Both 'petref' and 'hmc' transforms must be provided together.")
160226

227+
if config.workflow.hmc_off and (petref or hmc_xforms):
228+
config.loggers.workflow.warning(
229+
'Ignoring precomputed motion correction derivatives because --hmc-off was set.'
230+
)
231+
petref = None
232+
hmc_xforms = None
233+
161234
workflow = Workflow(name=name)
162235

163236
inputnode = pe.Node(
@@ -202,19 +275,6 @@ def init_pet_fit_wf(
202275
)
203276
hmc_buffer = pe.Node(niu.IdentityInterface(fields=['hmc_xforms']), name='hmc_buffer')
204277

205-
if pet_tlen <= 1: # 3D PET
206-
petref = pet_file
207-
idmat_fname = config.execution.work_dir / 'idmat.tfm'
208-
Affine().to_filename(idmat_fname, fmt='itk')
209-
hmc_xforms = idmat_fname
210-
config.loggers.workflow.debug('3D PET file - motion correction not needed')
211-
if petref:
212-
petref_buffer.inputs.petref = petref
213-
config.loggers.workflow.debug(f'(Re)using motion correction reference: {petref}')
214-
if hmc_xforms:
215-
hmc_buffer.inputs.hmc_xforms = hmc_xforms
216-
config.loggers.workflow.debug(f'(Re)using motion correction transforms: {hmc_xforms}')
217-
218278
timing_parameters = prepare_timing_parameters(metadata)
219279
frame_durations = timing_parameters.get('FrameDuration')
220280
frame_start_times = timing_parameters.get('FrameTimesStart')
@@ -231,6 +291,31 @@ def init_pet_fit_wf(
231291
registration_method = (
232292
'mri_robust_register' if config.workflow.pet2anat_robust else 'mri_coreg'
233293
)
294+
hmc_disabled = bool(config.workflow.hmc_off)
295+
if hmc_disabled:
296+
config.execution.work_dir.mkdir(parents=True, exist_ok=True)
297+
petref = petref or _extract_twa_image(
298+
pet_file,
299+
config.execution.work_dir,
300+
frame_start_times,
301+
frame_durations,
302+
)
303+
idmat_fname = config.execution.work_dir / 'idmat.tfm'
304+
n_frames = len(frame_durations)
305+
hmc_xforms = _write_identity_xforms(n_frames, idmat_fname)
306+
config.loggers.workflow.info('Head motion correction disabled; using identity transforms.')
307+
308+
if pet_tlen <= 1: # 3D PET
309+
petref = pet_file
310+
idmat_fname = config.execution.work_dir / 'idmat.tfm'
311+
hmc_xforms = _write_identity_xforms(pet_tlen, idmat_fname)
312+
config.loggers.workflow.debug('3D PET file - motion correction not needed')
313+
if petref:
314+
petref_buffer.inputs.petref = petref
315+
config.loggers.workflow.debug(f'(Re)using motion correction reference: {petref}')
316+
if hmc_xforms:
317+
hmc_buffer.inputs.hmc_xforms = hmc_xforms
318+
config.loggers.workflow.debug(f'(Re)using motion correction transforms: {hmc_xforms}')
234319

235320
summary = pe.Node(
236321
FunctionalSummary(

petprep/workflows/pet/tests/test_fit.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import nibabel as nb
4+
import nitransforms as nt
45
import numpy as np
56
import pytest
67
import yaml
@@ -11,7 +12,7 @@
1112
from ....utils import bids
1213
from ...tests import mock_config
1314
from ...tests.test_base import BASE_LAYOUT
14-
from ..fit import init_pet_fit_wf, init_pet_native_wf
15+
from ..fit import _extract_twa_image, init_pet_fit_wf, init_pet_native_wf
1516
from ..outputs import init_refmask_report_wf
1617

1718

@@ -370,6 +371,108 @@ def test_pet_fit_stage1_with_cached_baseline(bids_root: Path, tmp_path: Path):
370371
assert not any(name.startswith('pet_hmc_wf') for name in wf.list_node_names())
371372

372373

374+
def test_pet_fit_hmc_off_disables_stage1(bids_root: Path, tmp_path: Path):
375+
"""Disabling HMC should skip Stage 1 and use identity transforms."""
376+
pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')]
377+
data = np.stack(
378+
(
379+
np.ones((2, 2, 2), dtype=np.float32),
380+
np.full((2, 2, 2), 3.0, dtype=np.float32),
381+
),
382+
axis=-1,
383+
)
384+
img = nb.Nifti1Image(data, np.eye(4))
385+
for path in pet_series:
386+
img.to_filename(path)
387+
388+
sidecar = Path(pet_series[0]).with_suffix('').with_suffix('.json')
389+
sidecar.write_text('{"FrameTimesStart": [0, 2], "FrameDuration": [2, 4]}')
390+
391+
with mock_config(bids_dir=bids_root):
392+
config.workflow.hmc_off = True
393+
wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1)
394+
395+
assert not any(name.startswith('pet_hmc_wf') for name in wf.list_node_names())
396+
hmc_buffer = wf.get_node('hmc_buffer')
397+
assert str(hmc_buffer.inputs.hmc_xforms).endswith('idmat.tfm')
398+
hmc = nt.linear.load(hmc_buffer.inputs.hmc_xforms)
399+
assert hmc.matrix.shape[0] == data.shape[-1]
400+
assert np.allclose(hmc.matrix, np.tile(np.eye(4), (data.shape[-1], 1, 1)))
401+
petref_buffer = wf.get_node('petref_buffer')
402+
petref_name = Path(petref_buffer.inputs.petref).name
403+
assert petref_name.endswith('_timeavgref.nii.gz')
404+
assert '.nii_timeavgref' not in petref_name
405+
petref_img = nb.load(petref_buffer.inputs.petref)
406+
assert np.allclose(petref_img.get_fdata(), 14.0 / 6.0)
407+
408+
409+
@pytest.mark.parametrize(
410+
('frame_start_times', 'frame_durations', 'message'),
411+
[
412+
(None, [1, 1], 'Frame timing metadata are required'),
413+
([0, 1], None, 'Frame timing metadata are required'),
414+
([[0, 1]], [1, 1], 'must be one-dimensional'),
415+
([0, 1], [1], 'the same length'),
416+
([0, 1, 2], [1, 1, 1], 'match the number of frames'),
417+
([0, 1], [1, -1], 'must all be positive'),
418+
([1, 0], [1, 1], 'must be non-decreasing'),
419+
],
420+
)
421+
def test_extract_twa_image_validation(
422+
tmp_path: Path, frame_start_times, frame_durations, message: str
423+
):
424+
"""Validate error handling for malformed frame timing metadata."""
425+
426+
pet_img = nb.Nifti1Image(np.zeros((2, 2, 2, 2), dtype=np.float32), np.eye(4))
427+
pet_file = tmp_path / 'pet.nii.gz'
428+
pet_img.to_filename(pet_file)
429+
430+
with pytest.raises(ValueError, match=message): # noqa: PT011
431+
_extract_twa_image(
432+
str(pet_file),
433+
tmp_path / 'out',
434+
frame_start_times,
435+
frame_durations,
436+
)
437+
438+
439+
def test_pet_fit_hmc_off_ignores_precomputed(bids_root: Path, tmp_path: Path):
440+
"""Precomputed derivatives are ignored when ``--hmc-off`` is set."""
441+
442+
pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')]
443+
data = np.stack((np.ones((2, 2, 2)), np.full((2, 2, 2), 2.0)), axis=-1)
444+
img = nb.Nifti1Image(data, np.eye(4))
445+
for path in pet_series:
446+
img.to_filename(path)
447+
448+
sidecar = Path(pet_series[0]).with_suffix('').with_suffix('.json')
449+
sidecar.write_text('{"FrameTimesStart": [0, 1], "FrameDuration": [1, 1]}')
450+
451+
precomputed_petref = tmp_path / 'precomputed_petref.nii.gz'
452+
precomputed_hmc = tmp_path / 'precomputed_hmc.txt'
453+
img.to_filename(precomputed_petref)
454+
np.savetxt(precomputed_hmc, np.eye(4))
455+
456+
with mock_config(bids_dir=bids_root):
457+
config.workflow.hmc_off = True
458+
wf = init_pet_fit_wf(
459+
pet_series=pet_series,
460+
precomputed={
461+
'petref': str(precomputed_petref),
462+
'transforms': {'hmc': str(precomputed_hmc)},
463+
},
464+
omp_nthreads=1,
465+
)
466+
467+
petref_buffer = wf.get_node('petref_buffer')
468+
hmc_buffer = wf.get_node('hmc_buffer')
469+
470+
assert petref_buffer.inputs.petref != str(precomputed_petref)
471+
assert Path(petref_buffer.inputs.petref).name.endswith('_timeavgref.nii.gz')
472+
assert hmc_buffer.inputs.hmc_xforms != str(precomputed_hmc)
473+
assert Path(hmc_buffer.inputs.hmc_xforms).name == 'idmat.tfm'
474+
475+
373476
def test_init_refmask_report_wf(tmp_path: Path):
374477
"""Ensure the refmask report workflow initializes without errors."""
375478
wf = init_refmask_report_wf(output_dir=str(tmp_path), ref_name='test')

0 commit comments

Comments
 (0)