From 83d5af09bf9e67f38a10fe226f0bf6b08fca9078 Mon Sep 17 00:00:00 2001 From: Andrewwango Date: Fri, 22 May 2026 09:35:21 +0100 Subject: [PATCH 1/6] allow multicoil dataset coil map estim + save to fastmri_multicoil --- examples/fastmri_inference_plot.py | 107 ++++++++++++++++++----------- mri_recon/distortions/base.py | 4 +- mri_recon/utils/plot.py | 8 ++- 3 files changed, 73 insertions(+), 46 deletions(-) diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 1379db0..8d05706 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -52,49 +52,51 @@ ) FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" +FASTMRI_MULTICOIL_REPORT_DIR = Path("reports") / "fastmri_multicoil_inference_plot" OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot" FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) +FASTMRI_MULTICOIL_REPORT_DIR.mkdir(parents=True, exist_ok=True) OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ - "zero-filled", - # "conjugate-gradient", + # "zero-filled", + "conjugate-gradient", # "ram", # "dip", - "tv-pgd", + # "tv-pgd", # "wavelet-fista", - "tv-fista", + # "tv-fista", # "tv-pdhg", - *list(EXPLICIT_UNET_ALGORITHMS), + # *list(EXPLICIT_UNET_ALGORITHMS), ] DISTORTIONS = [ "Cartesian undersampling (variable density)", - "Cartesian undersampling (uniform random)", - "Cartesian undersampling (uniform random, zero ACS)", - "Cartesian undersampling (equispaced)", - "Cartesian undersampling (equispaced, zero ACS)", - "Partial Fourier", - "Phase-encode ghosting", - "Segmented translation motion", - "Segmented rotational motion", - "Translation motion", - "Rotational motion", - "Off-center anisotropic Gaussian bias field", - "Gaussian bias field", - "Anisotropic LP", - "Hann taper LP", - "Kaiser taper LP", - "Gaussian noise", - "Isotropic LP", - "Radial high-pass emphasis", + # "Cartesian undersampling (uniform random)", + # "Cartesian undersampling (uniform random, zero ACS)", + # "Cartesian undersampling (equispaced)", + # "Cartesian undersampling (equispaced, zero ACS)", + # "Partial Fourier", + # "Phase-encode ghosting", + # "Segmented translation motion", + # "Segmented rotational motion", + # "Translation motion", + # "Rotational motion", + # "Off-center anisotropic Gaussian bias field", + # "Gaussian bias field", + # "Anisotropic LP", + # "Hann taper LP", + # "Kaiser taper LP", + # "Gaussian noise", + # "Isotropic LP", + # "Radial high-pass emphasis", ] METRICS = [ "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", + # "NMSE", + # "SSIM", + # "HaarPSI", + # "SharpnessIndex", + # "BlurStrength", ] @@ -248,16 +250,22 @@ def prepare_measurement_sample( """ if dataset_name == "oasis": - reference_image = sample_batch["x"].to(run_device) - return reference_image, image_to_kspace(reference_image) - - # FastMRI batches are tuples such as (x, y) or (x, y, params). - y_fastmri = sample_batch[1].to(run_device) - if use_oasis_fft_path: - reference_image = fastmri_measurement_to_image(y_fastmri, device=run_device) - return reference_image, fastmri_measurement_to_oasis_kspace(y_fastmri, device=run_device) - - return None, y_fastmri + x = sample_batch["x"].to(run_device) + y = image_to_kspace(x) + coil_maps = None + elif dataset_name in ("fastmri",) and use_oasis_fft_path: + y = sample_batch[1].to(run_device) + x = fastmri_measurement_to_image(y) + y = fastmri_measurement_to_oasis_kspace(y, device=run_device) + coil_maps = None + elif dataset_name == "fastmri_multicoil" and use_oasis_fft_path: + raise NotImplementedError("TODO allow use_oasis_fft_path to take fastmri_multicoil dataset") + elif dataset_name in ("fastmri", "fastmri_multicoil"): + x = None + y = sample_batch[1].to(run_device) + coil_maps = sample_batch[2]["coil_maps"].to(run_device) if isinstance(sample_batch, (tuple, list)) and len(sample_batch) == 3 and "coil_maps" in sample_batch[2] else None + + return x, y, coil_maps def build_physics_pair( @@ -265,6 +273,7 @@ def build_physics_pair( distortion_operator: BaseDistortion, run_device: torch.device | str, use_oasis_fft_path: bool, + coil_maps: torch.Tensor | None = None, ) -> tuple[object, object]: """Build clean and distorted physics operators for the active path.""" @@ -276,11 +285,13 @@ def build_physics_pair( clean_physics = DistortedKspaceMultiCoilMRI( distortion=BaseDistortion(), img_size=(1, 2, *image_shape), + coil_maps=coil_maps, device=run_device, ) distorted_physics = DistortedKspaceMultiCoilMRI( distortion=distortion_operator, img_size=(1, 2, *image_shape), + coil_maps=coil_maps, device=run_device, ) return clean_physics, distorted_physics @@ -295,7 +306,7 @@ def build_physics_pair( type=Path, help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) - parser.add_argument("--dataset", choices=("fastmri", "oasis"), default="fastmri") + parser.add_argument("--dataset", choices=("fastmri", "oasis", "fastmri_multicoil"), default="fastmri") parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) parser.add_argument( @@ -334,7 +345,14 @@ def build_physics_pair( validate_algorithm_dataset_compatibility(args.dataset, algo_name) # set up report dir - REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR + if args.dataset == "fastmri": + REPORT_DIR = FASTMRI_REPORT_DIR + elif args.dataset == "oasis": + REPORT_DIR = OASIS_REPORT_DIR + elif args.dataset == "fastmri_multicoil": + REPORT_DIR = FASTMRI_MULTICOIL_REPORT_DIR + else: + raise NotImplementedError(f"Invalid dataset: {args.dataset}") # set up device, dataset, metrics device = dinv.utils.get_device() @@ -345,8 +363,12 @@ def build_physics_pair( split_csv=split_csv, sample_rate=0.6, ) - else: + elif args.dataset == "fastmri": dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") + elif args.dataset == "fastmri_multicoil": + dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle", transform=dinv.datasets.MRISliceTransform(estimate_coil_maps=True, acs=15,),) + else: + raise NotImplementedError(f"Invalid dataset: {args.dataset}") metrics = [choose_metric(m) for m in METRICS] for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))): @@ -356,7 +378,7 @@ def build_physics_pair( for algo_name in selected_algorithms: use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) - x_reference, y = prepare_measurement_sample( + x_reference, y, coil_maps = prepare_measurement_sample( sample_batch=batch, dataset_name=args.dataset, use_oasis_fft_path=use_oasis_path, @@ -383,6 +405,7 @@ def build_physics_pair( distortion_operator=distortion, run_device=device, use_oasis_fft_path=use_oasis_path, + coil_maps=coil_maps, ) y_distorted = distortion.A(y) diff --git a/mri_recon/distortions/base.py b/mri_recon/distortions/base.py index 61bc7f3..823e626 100644 --- a/mri_recon/distortions/base.py +++ b/mri_recon/distortions/base.py @@ -179,9 +179,9 @@ def A(self, x: torch.Tensor) -> torch.Tensor: y = y.squeeze(2) # remove coil dim if singlecoil return self.distortion(y) - def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: + def A_adjoint(self, y: torch.Tensor, **kwargs) -> torch.Tensor: if len(y.shape) == (5 if self.three_d else 4): y = y.unsqueeze(2) # add coil dim if singlecoil y = self.distortion.A_adjoint(y) - return super().A_adjoint(y) + return super().A_adjoint(y, **kwargs) diff --git a/mri_recon/utils/plot.py b/mri_recon/utils/plot.py index 068c222..6c687f5 100644 --- a/mri_recon/utils/plot.py +++ b/mri_recon/utils/plot.py @@ -9,9 +9,13 @@ def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to a log-magnitude image for visualization.""" + """Convert k-space tensor to a log-magnitude image for visualization. + NOTE: for multicoil, just plot the first coil data. + """ - if kspace.ndim == 4: + if kspace.ndim == 5: # multicoil + kspace = kspace[:, :, 0] + if kspace.ndim == 4: # batched kspace = kspace[0] if kspace.ndim != 3 or kspace.shape[0] != 2: raise ValueError( From 17077c71f373cf83db6b97acc2240b9a175ed4f1 Mon Sep 17 00:00:00 2001 From: Andrewwango Date: Fri, 22 May 2026 09:50:40 +0100 Subject: [PATCH 2/6] also add cmrxrecon --- examples/fastmri_inference_plot.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 8d05706..74ab766 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -54,9 +54,11 @@ FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" FASTMRI_MULTICOIL_REPORT_DIR = Path("reports") / "fastmri_multicoil_inference_plot" OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot" +CMRXRECON_REPORT_DIR = Path("reports") / "cmrxrecon_inference_plot" FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) FASTMRI_MULTICOIL_REPORT_DIR.mkdir(parents=True, exist_ok=True) OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) +CMRXRECON_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ # "zero-filled", "conjugate-gradient", @@ -264,7 +266,11 @@ def prepare_measurement_sample( x = None y = sample_batch[1].to(run_device) coil_maps = sample_batch[2]["coil_maps"].to(run_device) if isinstance(sample_batch, (tuple, list)) and len(sample_batch) == 3 and "coil_maps" in sample_batch[2] else None - + elif dataset_name in ("cmrxrecon"): + x = sample_batch[0].to(run_device) + y = sample_batch[1].to(run_device) + coil_maps = sample_batch[2]["coil_maps"].to(run_device) if isinstance(sample_batch, (tuple, list)) and len(sample_batch) == 3 and "coil_maps" in sample_batch[2] else None + return x, y, coil_maps @@ -306,7 +312,7 @@ def build_physics_pair( type=Path, help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) - parser.add_argument("--dataset", choices=("fastmri", "oasis", "fastmri_multicoil"), default="fastmri") + parser.add_argument("--dataset", choices=("fastmri", "oasis", "fastmri_multicoil", "cmrxrecon"), default="fastmri") parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) parser.add_argument( @@ -351,6 +357,8 @@ def build_physics_pair( REPORT_DIR = OASIS_REPORT_DIR elif args.dataset == "fastmri_multicoil": REPORT_DIR = FASTMRI_MULTICOIL_REPORT_DIR + elif args.dataset == "cmrxrecon": + REPORT_DIR = CMRXRECON_REPORT_DIR else: raise NotImplementedError(f"Invalid dataset: {args.dataset}") @@ -367,6 +375,8 @@ def build_physics_pair( dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") elif args.dataset == "fastmri_multicoil": dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle", transform=dinv.datasets.MRISliceTransform(estimate_coil_maps=True, acs=15,),) + elif args.dataset == "cmrxrecon": + dataset = dinv.datasets.CMRxReconSliceDataset(str(args.source), data_dir="SingleCoil/Cine/TrainingSet/FullSample", apply_mask=False) else: raise NotImplementedError(f"Invalid dataset: {args.dataset}") metrics = [choose_metric(m) for m in METRICS] From 15c53f88b426619af0fa2ffeff2bd11e9229f7cc Mon Sep 17 00:00:00 2001 From: Melanie Dohmen Date: Fri, 22 May 2026 09:17:34 +0000 Subject: [PATCH 3/6] run precommit hooks --- examples/fastmri_inference_plot.py | 36 +++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 74ab766..0864888 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -35,7 +35,6 @@ ) from mri_recon.reconstruction import ( ConjugateGradientReconstructor, - EXPLICIT_UNET_ALGORITHMS, OASISSinglecoilUnetReconstructor, choose_reconstructor, uses_oasis_centered_path, @@ -265,11 +264,23 @@ def prepare_measurement_sample( elif dataset_name in ("fastmri", "fastmri_multicoil"): x = None y = sample_batch[1].to(run_device) - coil_maps = sample_batch[2]["coil_maps"].to(run_device) if isinstance(sample_batch, (tuple, list)) and len(sample_batch) == 3 and "coil_maps" in sample_batch[2] else None + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) elif dataset_name in ("cmrxrecon"): x = sample_batch[0].to(run_device) y = sample_batch[1].to(run_device) - coil_maps = sample_batch[2]["coil_maps"].to(run_device) if isinstance(sample_batch, (tuple, list)) and len(sample_batch) == 3 and "coil_maps" in sample_batch[2] else None + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) return x, y, coil_maps @@ -312,7 +323,11 @@ def build_physics_pair( type=Path, help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) - parser.add_argument("--dataset", choices=("fastmri", "oasis", "fastmri_multicoil", "cmrxrecon"), default="fastmri") + parser.add_argument( + "--dataset", + choices=("fastmri", "oasis", "fastmri_multicoil", "cmrxrecon"), + default="fastmri", + ) parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) parser.add_argument( @@ -374,9 +389,18 @@ def build_physics_pair( elif args.dataset == "fastmri": dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") elif args.dataset == "fastmri_multicoil": - dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle", transform=dinv.datasets.MRISliceTransform(estimate_coil_maps=True, acs=15,),) + dataset = dinv.datasets.FastMRISliceDataset( + str(args.source), + slice_index="middle", + transform=dinv.datasets.MRISliceTransform( + estimate_coil_maps=True, + acs=15, + ), + ) elif args.dataset == "cmrxrecon": - dataset = dinv.datasets.CMRxReconSliceDataset(str(args.source), data_dir="SingleCoil/Cine/TrainingSet/FullSample", apply_mask=False) + dataset = dinv.datasets.CMRxReconSliceDataset( + str(args.source), data_dir="SingleCoil/Cine/TrainingSet/FullSample", apply_mask=False + ) else: raise NotImplementedError(f"Invalid dataset: {args.dataset}") metrics = [choose_metric(m) for m in METRICS] From 634d1cd3801c832a7165b36d586f3963b4d252c3 Mon Sep 17 00:00:00 2001 From: Melanie Dohmen Date: Fri, 22 May 2026 09:31:11 +0000 Subject: [PATCH 4/6] pre-commit on plot.py --- mri_recon/utils/plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mri_recon/utils/plot.py b/mri_recon/utils/plot.py index 6c687f5..b167482 100644 --- a/mri_recon/utils/plot.py +++ b/mri_recon/utils/plot.py @@ -13,9 +13,9 @@ def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: NOTE: for multicoil, just plot the first coil data. """ - if kspace.ndim == 5: # multicoil + if kspace.ndim == 5: # multicoil kspace = kspace[:, :, 0] - if kspace.ndim == 4: # batched + if kspace.ndim == 4: # batched kspace = kspace[0] if kspace.ndim != 3 or kspace.shape[0] != 2: raise ValueError( From 30baff3ea941eac4e5f187743eed0538fba1676a Mon Sep 17 00:00:00 2001 From: Melanie Dohmen Date: Fri, 22 May 2026 10:10:08 +0000 Subject: [PATCH 5/6] use oasis fft path on multi-coil data Co-authored-by: Copilot --- examples/fastmri_inference_plot.py | 182 ++++++++++++++++------------- mri_recon/utils/oasis_adapter.py | 11 +- 2 files changed, 110 insertions(+), 83 deletions(-) diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 0864888..bb07501 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -39,6 +39,7 @@ choose_reconstructor, uses_oasis_centered_path, validate_algorithm_dataset_compatibility, + EXPLICIT_UNET_ALGORITHMS, ) from mri_recon.utils import ( OasisCenteredFFTPhysics, @@ -67,15 +68,15 @@ # "wavelet-fista", # "tv-fista", # "tv-pdhg", - # *list(EXPLICIT_UNET_ALGORITHMS), + *list(EXPLICIT_UNET_ALGORITHMS), ] DISTORTIONS = [ - "Cartesian undersampling (variable density)", + # "Cartesian undersampling (variable density)", # "Cartesian undersampling (uniform random)", # "Cartesian undersampling (uniform random, zero ACS)", # "Cartesian undersampling (equispaced)", - # "Cartesian undersampling (equispaced, zero ACS)", + "Cartesian undersampling (equispaced, zero ACS)", # "Partial Fourier", # "Phase-encode ghosting", # "Segmented translation motion", @@ -260,7 +261,17 @@ def prepare_measurement_sample( y = fastmri_measurement_to_oasis_kspace(y, device=run_device) coil_maps = None elif dataset_name == "fastmri_multicoil" and use_oasis_fft_path: - raise NotImplementedError("TODO allow use_oasis_fft_path to take fastmri_multicoil dataset") + y = sample_batch[1].to(run_device) + + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + x = fastmri_measurement_to_image(y, coil_maps=coil_maps) + y = fastmri_measurement_to_oasis_kspace(y, coil_maps=coil_maps, device=run_device) elif dataset_name in ("fastmri", "fastmri_multicoil"): x = None y = sample_batch[1].to(run_device) @@ -411,84 +422,91 @@ def build_physics_pair( break for algo_name in selected_algorithms: - use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) - x_reference, y, coil_maps = prepare_measurement_sample( - sample_batch=batch, - dataset_name=args.dataset, - use_oasis_fft_path=use_oasis_path, - run_device=device, - ) - algo = choose_reconstructor( - algo_name, - img_size=y.shape[-2:], - device=device, - verbose=args.verbose, - dataset=args.dataset, - ).to(device) - - for distortion_name in selected_distortions: - distortion = choose_distortion( - distortion_name, - keep_fraction=args.keep_fraction, - center_fraction=args.center_fraction, - cartesian_axis=-1 if use_oasis_path else -2, - ) - - physics_clean, physics = build_physics_pair( - image_shape=y.shape[-2:], - distortion_operator=distortion, - run_device=device, + try: + use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) + x_reference, y, coil_maps = prepare_measurement_sample( + sample_batch=batch, + dataset_name=args.dataset, use_oasis_fft_path=use_oasis_path, - coil_maps=coil_maps, - ) - y_distorted = distortion.A(y) - - # generate reference reconstructions (CG) for both clean and distorted k-space - # without correction for the distortion, i.e. using physics_clean in both cases - if use_oasis_path: - x_clean = x_reference - x_distorted = kspace_to_image(y_distorted) - else: - x_clean = ConjugateGradientReconstructor()(y, physics_clean) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", - distortion_name, + run_device=device, ) - - print(f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}...") - - # actual reconstruction with the algo being evaluated - x_uncorrected = algo(y_distorted, physics_clean) - x_corrected = algo(y_distorted, physics) - - print("done!") - - dinv.utils.plot( - { - "Undistorted ksp, CG recon": x_clean, - "Distorted ksp, CG recon": x_distorted, - f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, - f"Distorted ksp, {algo_name} recon, corrected": x_corrected, - }, - subtitles=[ - "", - "", - "\n".join( - f"{m.__class__.__name__} {m(x_uncorrected, x_clean).item():.2f}" - for m in metrics - ), - "\n".join( - f"{m.__class__.__name__} {m(x_corrected, x_clean).item():.2f}" - for m in metrics - ), - ], - show=False, - close=True, - suptitle=f"Algo {algo_name}, distortion {distortion_name}, Sample {i}", - save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", - fontsize=3, + algo = choose_reconstructor( + algo_name, + img_size=y.shape[-2:], + device=device, + verbose=args.verbose, + dataset=args.dataset, + ).to(device) + + for distortion_name in selected_distortions: + distortion = choose_distortion( + distortion_name, + keep_fraction=args.keep_fraction, + center_fraction=args.center_fraction, + cartesian_axis=-1 if use_oasis_path else -2, + ) + + physics_clean, physics = build_physics_pair( + image_shape=y.shape[-2:], + distortion_operator=distortion, + run_device=device, + use_oasis_fft_path=use_oasis_path, + coil_maps=coil_maps, + ) + y_distorted = distortion.A(y) + + # generate reference reconstructions (CG) for both clean and distorted k-space + # without correction for the distortion, i.e. using physics_clean in both cases + if use_oasis_path: + x_clean = x_reference + x_distorted = kspace_to_image(y_distorted) + else: + x_clean = ConjugateGradientReconstructor()(y, physics_clean) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}..." + ) + + # actual reconstruction with the algo being evaluated + x_uncorrected = algo(y_distorted, physics_clean) + x_corrected = algo(y_distorted, physics) + + print("done!") + + dinv.utils.plot( + { + "Undistorted ksp, CG recon": x_clean, + "Distorted ksp, CG recon": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join( + f"{m.__class__.__name__} {m(x_uncorrected, x_clean).item():.2f}" + for m in metrics + ), + "\n".join( + f"{m.__class__.__name__} {m(x_corrected, x_clean).item():.2f}" + for m in metrics + ), + ], + show=False, + close=True, + suptitle=f"Algo {algo_name}, distortion {distortion_name}, Sample {i}", + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + except Exception as e: + print( + f"Error processing algo {algo_name}, distortion {distortion_name}, sample {i}: {e}" ) diff --git a/mri_recon/utils/oasis_adapter.py b/mri_recon/utils/oasis_adapter.py index 0f833aa..a55b34f 100644 --- a/mri_recon/utils/oasis_adapter.py +++ b/mri_recon/utils/oasis_adapter.py @@ -169,6 +169,7 @@ def kspace_to_image(y: torch.Tensor) -> torch.Tensor: def fastmri_measurement_to_image( y: torch.Tensor, + coil_maps: torch.Tensor | None = None, device: torch.device | str | None = None, ) -> torch.Tensor: """Convert FastMRI measurements to image space using the repo's native physics. @@ -177,6 +178,9 @@ def fastmri_measurement_to_image( ---------- y : torch.Tensor FastMRI measurement tensor with shape ``(B, 2, H, W)``. + coil_maps : torch.Tensor, optional + Coil sensitivity maps with shape ``(B, 2, H, W)``. If provided, these will be applied to the image before the centered FFT, matching + the OASIS U-Net training setup. If not provided, the function will still return an image but without coil sensitivity modulation. device : torch.device | str, optional Device on which to instantiate the temporary native physics operator. @@ -191,6 +195,7 @@ def fastmri_measurement_to_image( physics = DistortedKspaceMultiCoilMRI( distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), + coil_maps=coil_maps, device=device, ) return physics.A_adjoint(y) @@ -198,6 +203,7 @@ def fastmri_measurement_to_image( def fastmri_measurement_to_oasis_kspace( y: torch.Tensor, + coil_maps: torch.Tensor | None = None, device: torch.device | str | None = None, ) -> torch.Tensor: """Adapt FastMRI measurements to the centered OASIS k-space convention. @@ -206,6 +212,9 @@ def fastmri_measurement_to_oasis_kspace( ---------- y : torch.Tensor FastMRI measurement tensor with shape ``(B, 2, H, W)``. + coil_maps : torch.Tensor, optional + Coil sensitivity maps with shape ``(B, 2, H, W)``. If provided, these will be applied to the image before the centered FFT, matching + the OASIS U-Net training setup. If not provided, the function will still return centered k-space but without coil sensitivity modulation. device : torch.device | str, optional Device on which to instantiate the temporary native physics operator. @@ -215,7 +224,7 @@ def fastmri_measurement_to_oasis_kspace( Centered OASIS-convention k-space tensor with shape ``(B, 2, H, W)``. """ - return image_to_kspace(fastmri_measurement_to_image(y, device=device)) + return image_to_kspace(fastmri_measurement_to_image(y, coil_maps=coil_maps, device=device)) class OasisCenteredFFTPhysics: From 63324b94e2f782b9cf74ba586c730cd41274af63 Mon Sep 17 00:00:00 2001 From: Melanie Dohmen Date: Fri, 22 May 2026 10:28:46 +0000 Subject: [PATCH 6/6] allow no distortions for different reconstruction methods --- examples/fastmri_inference_plot.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index bb07501..b7a1913 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -72,6 +72,7 @@ ] DISTORTIONS = [ + "no distortion", # "Cartesian undersampling (variable density)", # "Cartesian undersampling (uniform random)", # "Cartesian undersampling (uniform random, zero ACS)", @@ -215,6 +216,8 @@ def choose_distortion( return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) case "Gaussian noise": return GaussianNoiseDistortion(sigma=0.00001) + case "no distortion": + return BaseDistortion() case _: raise ValueError(f"Unknown distortion {name!r}")