diff --git a/petprep/cli/parser.py b/petprep/cli/parser.py index 38a28f61..2334c0b3 100644 --- a/petprep/cli/parser.py +++ b/petprep/cli/parser.py @@ -583,13 +583,14 @@ def _bids_filter(value, parser): g_hmc.add_argument( '--petref', default='template', - choices=['template', 'twa', 'sum', 'first5min'], + choices=['template', 'twa', 'sum', 'first5min', 'auto'], help=( "Strategy for generating the PET reference. 'template' uses the " "motion correction template, while 'twa' computes a time-weighted " "average, 'sum' produces a summed image of the motion-corrected " "series, and 'first5min' averages the early (0-5 minute) portion " - 'of the acquisition.' + "of the acquisition. 'auto' evaluates multiple strategies to " + 'select the best reference.' ), ) @@ -788,8 +789,8 @@ def parse_args(args=None, namespace=None): config.execution.log_level = int(max(25 - 5 * opts.verbose_count, logging.DEBUG)) config.from_dict(vars(opts), init=['nipype']) - config.workflow._petref_cli_set = '--petref' in argv - config.workflow._pet2anat_method_cli_set = '--pet2anat-method' in argv + config.workflow.petref_specified = '--petref' in argv + config.workflow.pet2anat_method_specified = '--pet2anat-method' in argv if config.execution.session_label: config.execution.bids_filters = config.execution.bids_filters or {} diff --git a/petprep/config.py b/petprep/config.py index bedea122..3b5c9673 100644 --- a/petprep/config.py +++ b/petprep/config.py @@ -561,7 +561,7 @@ class workflow(_Config): pet2anat_method_specified: bool = False """Flag indicating whether ``--pet2anat-method`` was explicitly provided.""" petref: str = 'template' - """Strategy for building the PET reference (``'template'``, ``'twa'``, ``'sum'`` or ``'first5min'``).""" + """Strategy for building the PET reference (``'template'``, ``'twa'``, ``'sum'``, ``'first5min'`` or ``'auto'``).""" petref_specified: bool = False """Flag indicating whether ``--petref`` was explicitly provided.""" cifti_output = None diff --git a/petprep/interfaces/reports.py b/petprep/interfaces/reports.py index 93b75121..1c8a89bc 100644 --- a/petprep/interfaces/reports.py +++ b/petprep/interfaces/reports.py @@ -252,11 +252,17 @@ class FunctionalSummaryInputSpec(TraitedSpec): 'twa', 'sum', 'first5min', + 'auto', mandatory=True, desc='PET reference generation strategy', ) requested_petref_strategy = traits.Enum( - 'template', 'twa', 'sum', 'first5min', desc='User-requested PET reference strategy' + 'template', + 'twa', + 'sum', + 'first5min', + 'auto', + desc='User-requested PET reference strategy', ) hmc_disabled = traits.Bool(False, desc='Head motion correction disabled') @@ -295,6 +301,7 @@ def _generate_segment(self): 'twa': 'Time-weighted average of motion-corrected series', 'sum': 'Summed motion-corrected series', 'first5min': 'Early (0-5 minute) average of motion-corrected series', + 'auto': 'Automatically selected reference', } petref_strategy = reference_map.get(self.inputs.petref_strategy, 'Unknown') requested = getattr(self.inputs, 'requested_petref_strategy', None) diff --git a/petprep/workflows/pet/fit.py b/petprep/workflows/pet/fit.py index 71a7a1ca..2203cbc7 100644 --- a/petprep/workflows/pet/fit.py +++ b/petprep/workflows/pet/fit.py @@ -208,6 +208,34 @@ def _extract_sum_image(pet_file: str, output_dir: 'Path') -> str: return str(out_file) +def _select_best_petref(labels, scores, transforms, inv_transforms, winners, petrefs): + """Select the PET reference with the lowest registration cost.""" + + if not labels or not scores: + raise ValueError('No PET reference candidates were provided for selection.') + + best_idx = None + best_score = float('inf') + for idx, score in enumerate(scores): + if score is None: + continue + if score < best_score: + best_idx = idx + best_score = score + + if best_idx is None: + raise ValueError('No registration scores were available for selection.') + + return ( + labels[best_idx], + best_score, + transforms[best_idx], + inv_transforms[best_idx], + winners[best_idx], + petrefs[best_idx], + ) + + def _write_identity_xforms(num_frames: int, filename: Path) -> Path: """Write ``num_frames`` identity transforms to ``filename``.""" @@ -388,6 +416,15 @@ def init_pet_fit_wf( requested_petref_strategy = getattr(config.workflow, 'petref', 'template') hmc_disabled = bool(config.workflow.hmc_off) petref_strategy = requested_petref_strategy + petref_candidates = None + petref_candidate_labels: list[str] = [] + if requested_petref_strategy == 'auto': + petref_strategy = 'auto' + petref_candidate_labels = ['template', 'twa', 'sum', 'first5min'] + petref_candidates = pe.Node( + niu.IdentityInterface(fields=petref_candidate_labels), name='petref_candidates' + ) + if hmc_disabled and petref_strategy == 'template': config.loggers.workflow.warning( 'Head motion correction disabled (--hmc-off); using a time-weighted average ' @@ -395,7 +432,7 @@ def init_pet_fit_wf( ) petref_strategy = 'twa' - use_corrected_reference = petref_strategy in {'twa', 'sum', 'first5min'} + use_corrected_reference = petref_strategy in {'twa', 'sum', 'first5min', 'auto'} reference_function = _extract_twa_image reference_kwargs: dict[str, object] = { 'output_dir': config.execution.work_dir, @@ -462,6 +499,42 @@ def init_pet_fit_wf( corrected_reference.inputs.frame_start_times = frame_start_times corrected_reference.inputs.frame_durations = frame_durations + reference_nodes: dict[str, pe.Node] = {} + if petref_strategy == 'auto': + reference_nodes['twa'] = pe.Node( + niu.Function( + function=_extract_twa_image, + input_names=reference_input_names, + output_names=['out_file'], + ), + name='auto_twa_reference', + ) + reference_nodes['twa'].inputs.output_dir = config.execution.work_dir + reference_nodes['twa'].inputs.frame_start_times = frame_start_times + reference_nodes['twa'].inputs.frame_durations = frame_durations + + reference_nodes['sum'] = pe.Node( + niu.Function( + function=_extract_sum_image, + input_names=['pet_file', 'output_dir'], + output_names=['out_file'], + ), + name='auto_sum_reference', + ) + reference_nodes['sum'].inputs.output_dir = config.execution.work_dir + + reference_nodes['first5min'] = pe.Node( + niu.Function( + function=_extract_first5min_image, + input_names=reference_input_names, + output_names=['out_file'], + ), + name='auto_first5min_reference', + ) + reference_nodes['first5min'].inputs.output_dir = config.execution.work_dir + reference_nodes['first5min'].inputs.frame_start_times = frame_start_times + reference_nodes['first5min'].inputs.frame_durations = frame_durations + registration_method = 'Precomputed' if not petref2anat_xform: registration_method = { @@ -477,6 +550,8 @@ def init_pet_fit_wf( n_frames = len(frame_durations) hmc_xforms = _write_identity_xforms(n_frames, idmat_fname) config.loggers.workflow.info('Head motion correction disabled; using identity transforms.') + if petref_strategy == 'auto' and petref_candidates is not None: + petref_candidates.inputs.template = petref if pet_tlen <= 1: # 3D PET petref = pet_file @@ -581,7 +656,9 @@ def init_pet_fit_wf( (pet_hmc_wf, ds_petref_wf, [('outputnode.petref', 'inputnode.petref')]), ]) # fmt:skip - if use_corrected_reference: + if petref_strategy == 'auto': + workflow.connect([(pet_hmc_wf, petref_candidates, [('outputnode.petref', 'template')])]) + elif use_corrected_reference: workflow.connect([ (pet_hmc_wf, corrected_pet_for_report, [('outputnode.petref', 'ref_file')]), (val_pet, corrected_pet_for_report, [('out_file', 'in_file')]), @@ -592,6 +669,19 @@ def init_pet_fit_wf( else: workflow.connect([(pet_hmc_wf, petref_buffer, [('outputnode.petref', 'petref')])]) + if petref_strategy == 'auto' and use_corrected_reference: + workflow.connect([ + (pet_hmc_wf, corrected_pet_for_report, [('outputnode.petref', 'ref_file')]), + (val_pet, corrected_pet_for_report, [('out_file', 'in_file')]), + (hmc_buffer, corrected_pet_for_report, [('hmc_xforms', 'transforms')]), + (corrected_pet_for_report, reference_nodes['twa'], [('out_file', 'pet_file')]), + (corrected_pet_for_report, reference_nodes['sum'], [('out_file', 'pet_file')]), + (corrected_pet_for_report, reference_nodes['first5min'], [('out_file', 'pet_file')]), + (reference_nodes['twa'], petref_candidates, [('out_file', 'twa')]), + (reference_nodes['sum'], petref_candidates, [('out_file', 'sum')]), + (reference_nodes['first5min'], petref_candidates, [('out_file', 'first5min')]), + ]) # fmt:skip + if report_pet_reference: workflow.connect([ (pet_hmc_wf, report_pet_for_coreg, [('outputnode.petref', 'ref_file')]), @@ -612,7 +702,20 @@ def init_pet_fit_wf( ]) # fmt:skip val_pet.inputs.in_file = pet_file - if use_corrected_reference: + if petref_strategy == 'auto': + corrected_pet_for_report.inputs.ref_file = petref + petref_candidates.inputs.template = petref + workflow.connect([ + (val_pet, corrected_pet_for_report, [('out_file', 'in_file')]), + (hmc_buffer, corrected_pet_for_report, [('hmc_xforms', 'transforms')]), + (corrected_pet_for_report, reference_nodes['twa'], [('out_file', 'pet_file')]), + (corrected_pet_for_report, reference_nodes['sum'], [('out_file', 'pet_file')]), + (corrected_pet_for_report, reference_nodes['first5min'], [('out_file', 'pet_file')]), + (reference_nodes['twa'], petref_candidates, [('out_file', 'twa')]), + (reference_nodes['sum'], petref_candidates, [('out_file', 'sum')]), + (reference_nodes['first5min'], petref_candidates, [('out_file', 'first5min')]), + ]) # fmt:skip + elif use_corrected_reference: corrected_pet_for_report.inputs.ref_file = petref workflow.connect( [ @@ -651,36 +754,131 @@ def init_pet_fit_wf( ) petref2anat_xform = None + pet_to_t1_source = None + pet_to_t1_field = None + if not petref2anat_xform: config.loggers.workflow.info('PET Stage 2: Adding co-registration workflow of PET to T1w') - # calculate PET registration to T1w - pet_reg_wf = init_pet_reg_wf( - pet2anat_dof=config.workflow.pet2anat_dof, - omp_nthreads=omp_nthreads, - mem_gb=mem_gb['resampled'], - pet2anat_method=config.workflow.pet2anat_method, - sloppy=config.execution.sloppy, - ) - ds_petreg_wf = init_ds_registration_wf( - bids_root=layout.root, - output_dir=config.execution.petprep_dir, - source='petref', - dest='T1w', - name='ds_petreg_wf', - ) + if petref_strategy == 'auto': + ds_petreg_wf = init_ds_registration_wf( + bids_root=layout.root, + output_dir=config.execution.petprep_dir, + source='petref', + dest='T1w', + name='ds_petreg_wf', + ) - workflow.connect([ - (inputnode, pet_reg_wf, [ - ('t1w_preproc', 'inputnode.anat_preproc'), - ('t1w_mask', 'inputnode.anat_mask'), - ]), - (petref_buffer, pet_reg_wf, [('petref', 'inputnode.ref_pet_brain')]), - (val_pet, ds_petreg_wf, [('out_file', 'inputnode.source_files')]), - (pet_reg_wf, ds_petreg_wf, [('outputnode.itk_pet_to_t1', 'inputnode.xform')]), - (ds_petreg_wf, outputnode, [('outputnode.xform', 'petref2anat_xfm')]), - (pet_reg_wf, summary, [('outputnode.registration_winner', 'registration_winner')]), - ]) # fmt:skip + score_merge = pe.Node(niu.Merge(len(petref_candidate_labels)), name='merge_scores') + xfm_merge = pe.Node(niu.Merge(len(petref_candidate_labels)), name='merge_xfms') + inv_merge = pe.Node(niu.Merge(len(petref_candidate_labels)), name='merge_inv_xfms') + winner_merge = pe.Node( + niu.Merge(len(petref_candidate_labels)), name='merge_reg_winners' + ) + label_merge = pe.Node(niu.Merge(len(petref_candidate_labels)), name='merge_labels') + petref_merge = pe.Node(niu.Merge(len(petref_candidate_labels)), name='merge_petrefs') + + for idx, label in enumerate(petref_candidate_labels): + reg_wf = init_pet_reg_wf( + pet2anat_dof=config.workflow.pet2anat_dof, + omp_nthreads=omp_nthreads, + mem_gb=mem_gb['resampled'], + pet2anat_method=config.workflow.pet2anat_method, + sloppy=config.execution.sloppy, + name=f'pet_reg_wf_{label}', + ) + + label_src = pe.Node(niu.IdentityInterface(fields=['label']), name=f'label_{label}') + label_src.inputs.label = label + + workflow.connect([ + (inputnode, reg_wf, [ + ('t1w_preproc', 'inputnode.anat_preproc'), + ('t1w_mask', 'inputnode.anat_mask'), + ]), + (petref_candidates, reg_wf, [(label, 'inputnode.ref_pet_brain')]), + (reg_wf, score_merge, [( + 'outputnode.registration_score', f'in{idx + 1}' + )]), + (reg_wf, xfm_merge, [('outputnode.itk_pet_to_t1', f'in{idx + 1}')]), + (reg_wf, inv_merge, [('outputnode.itk_t1_to_pet', f'in{idx + 1}')]), + (reg_wf, winner_merge, [('outputnode.registration_winner', f'in{idx + 1}')]), + (petref_candidates, petref_merge, [(label, f'in{idx + 1}')]), + (label_src, label_merge, [('label', f'in{idx + 1}')]), + ]) # fmt:skip + + select_best_ref = pe.Node( + niu.Function( + function=_select_best_petref, + input_names=[ + 'labels', + 'scores', + 'transforms', + 'inv_transforms', + 'winners', + 'petrefs', + ], + output_names=[ + 'best_label', + 'best_score', + 'best_transform', + 'best_inv_transform', + 'best_winner', + 'best_petref', + ], + ), + name='select_best_petref', + ) + + workflow.connect([ + (score_merge, select_best_ref, [('out', 'scores')]), + (xfm_merge, select_best_ref, [('out', 'transforms')]), + (inv_merge, select_best_ref, [('out', 'inv_transforms')]), + (winner_merge, select_best_ref, [('out', 'winners')]), + (label_merge, select_best_ref, [('out', 'labels')]), + (petref_merge, select_best_ref, [('out', 'petrefs')]), + (val_pet, ds_petreg_wf, [('out_file', 'inputnode.source_files')]), + (select_best_ref, ds_petreg_wf, [('best_transform', 'inputnode.xform')]), + (ds_petreg_wf, outputnode, [('outputnode.xform', 'petref2anat_xfm')]), + (select_best_ref, petref_buffer, [('best_petref', 'petref')]), + (select_best_ref, summary, [('best_winner', 'registration_winner')]), + (select_best_ref, summary, [('best_label', 'petref_strategy')]), + ]) # fmt:skip + + pet_to_t1_source = select_best_ref + pet_to_t1_field = 'best_transform' + else: + # calculate PET registration to T1w + pet_reg_wf = init_pet_reg_wf( + pet2anat_dof=config.workflow.pet2anat_dof, + omp_nthreads=omp_nthreads, + mem_gb=mem_gb['resampled'], + pet2anat_method=config.workflow.pet2anat_method, + sloppy=config.execution.sloppy, + ) + + ds_petreg_wf = init_ds_registration_wf( + bids_root=layout.root, + output_dir=config.execution.petprep_dir, + source='petref', + dest='T1w', + name='ds_petreg_wf', + ) + + workflow.connect([ + (inputnode, pet_reg_wf, [ + ('t1w_preproc', 'inputnode.anat_preproc'), + ('t1w_mask', 'inputnode.anat_mask'), + ]), + (petref_buffer, pet_reg_wf, [('petref', 'inputnode.ref_pet_brain')]), + (val_pet, ds_petreg_wf, [('out_file', 'inputnode.source_files')]), + (pet_reg_wf, ds_petreg_wf, [('outputnode.itk_pet_to_t1', 'inputnode.xform')]), + (ds_petreg_wf, outputnode, [('outputnode.xform', 'petref2anat_xfm')]), + (pet_reg_wf, summary, [('outputnode.registration_winner', 'registration_winner')]), + ]) # fmt:skip + + pet_to_t1_source = pet_reg_wf + pet_to_t1_field = 'outputnode.itk_pet_to_t1' else: outputnode.inputs.petref2anat_xfm = petref2anat_xform @@ -699,12 +897,12 @@ def init_pet_fit_wf( petref_mask.inputs.thresh = 0.2 merge_mask = pe.Node(niu.Function(function=_binary_union), name='merge_mask') - if not petref2anat_xform: + if petref2anat_xform: + t1w_mask_tfm.inputs.transforms = petref2anat_xform + elif pet_to_t1_source and pet_to_t1_field: workflow.connect( - [(pet_reg_wf, t1w_mask_tfm, [('outputnode.itk_pet_to_t1', 'transforms')])] + [(pet_to_t1_source, t1w_mask_tfm, [(pet_to_t1_field, 'transforms')])] ) - else: - t1w_mask_tfm.inputs.transforms = petref2anat_xform workflow.connect( [ diff --git a/petprep/workflows/pet/registration.py b/petprep/workflows/pet/registration.py index b9bf3b35..876d7c31 100644 --- a/petprep/workflows/pet/registration.py +++ b/petprep/workflows/pet/registration.py @@ -48,8 +48,8 @@ def _select_best_transform(xfm_ants, xfm_fs, inv_ants, inv_fs, score_ants, score # Default to FreeSurfer branch if scores tie if score_ants > score_fs: - return xfm_ants, inv_ants, 'ants' - return xfm_fs, inv_fs, 'freesurfer' + return xfm_ants, inv_ants, 'ants', score_ants + return xfm_fs, inv_fs, 'freesurfer', score_fs def init_pet_reg_wf( @@ -130,10 +130,13 @@ def init_pet_reg_wf( ) outputnode = pe.Node( - niu.IdentityInterface(fields=['itk_pet_to_t1', 'itk_t1_to_pet', 'registration_winner']), + niu.IdentityInterface( + fields=['itk_pet_to_t1', 'itk_t1_to_pet', 'registration_winner', 'registration_score'] + ), name='outputnode', ) outputnode.inputs.registration_winner = None + outputnode.inputs.registration_score = None mask_brain = pe.Node(ApplyMask(), name='mask_brain') crop_anat_mask = pe.Node(MRIConvert(out_type='niigz'), name='crop_anat_mask') @@ -182,12 +185,24 @@ def init_pet_reg_wf( fs_warp = pe.Node(ApplyTransforms(float=True), name='warp_pet_fs') ants_score = pe.Node( - MeasureImageSimilarity(metric='Mattes', dimension=3), + MeasureImageSimilarity( + metric='Mattes', + dimension=3, + radius_or_number_of_bins=32, + sampling_strategy='Regular', + sampling_percentage=0.25, + ), name='score_ants', mem_gb=config.DEFAULT_MEMORY_MIN_GB, ) fs_score = pe.Node( - MeasureImageSimilarity(metric='Mattes', dimension=3), + MeasureImageSimilarity( + metric='Mattes', + dimension=3, + radius_or_number_of_bins=32, + sampling_strategy='Regular', + sampling_percentage=0.25, + ), name='score_fs', mem_gb=config.DEFAULT_MEMORY_MIN_GB, ) @@ -196,7 +211,7 @@ def init_pet_reg_wf( niu.Function( function=_select_best_transform, input_names=['xfm_ants', 'xfm_fs', 'inv_ants', 'inv_fs', 'score_ants', 'score_fs'], - output_names=['best_xfm', 'best_inv_xfm', 'winner'], + output_names=['best_xfm', 'best_inv_xfm', 'winner', 'best_score'], ), name='select_best', ) @@ -240,6 +255,7 @@ def init_pet_reg_wf( ('best_xfm', 'itk_pet_to_t1'), ('best_inv_xfm', 'itk_t1_to_pet'), ('winner', 'registration_winner'), + ('best_score', 'registration_score'), ]), ] ) # fmt:skip @@ -314,6 +330,18 @@ def init_pet_reg_wf( coreg_output_is_list = False convert_xfm = pe.Node(ConcatenateXFMs(inverse=True), name='convert_xfm') + warp_for_score = pe.Node(ApplyTransforms(float=True), name='warp_for_score') + similarity = pe.Node( + MeasureImageSimilarity( + metric='Mattes', + dimension=3, + radius_or_number_of_bins=32, + sampling_strategy='Regular', + sampling_percentage=0.25, + ), + name='score_registration', + mem_gb=config.DEFAULT_MEMORY_MIN_GB, + ) # Build connections dynamically based on output type if coreg_output_is_list: @@ -340,6 +368,14 @@ def init_pet_reg_wf( ('out_inv', 'itk_t1_to_pet'), ], ), + (inputnode, warp_for_score, [('ref_pet_brain', 'input_image')]), + (robust_fov, warp_for_score, [('out_roi', 'reference_image')]), + (convert_xfm, warp_for_score, [('out_xfm', 'transforms')]), + (warp_for_score, similarity, [('output_image', 'moving_image')]), + (mask_brain, similarity, [('out_file', 'fixed_image')]), + (crop_anat_mask, similarity, [('out_file', 'fixed_image_mask')]), + (crop_anat_mask, similarity, [('out_file', 'moving_image_mask')]), + (similarity, outputnode, [('similarity', 'registration_score')]), ] else: # mri_coreg and mri_robust_register output single transform file @@ -357,6 +393,14 @@ def init_pet_reg_wf( ('out_inv', 'itk_t1_to_pet'), ], ), + (inputnode, warp_for_score, [('ref_pet_brain', 'input_image')]), + (robust_fov, warp_for_score, [('out_roi', 'reference_image')]), + (convert_xfm, warp_for_score, [('out_xfm', 'transforms')]), + (warp_for_score, similarity, [('output_image', 'moving_image')]), + (mask_brain, similarity, [('out_file', 'fixed_image')]), + (crop_anat_mask, similarity, [('out_file', 'fixed_image_mask')]), + (crop_anat_mask, similarity, [('out_file', 'moving_image_mask')]), + (similarity, outputnode, [('similarity', 'registration_score')]), ] workflow.connect(