diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 1379db0..b7a1913 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -35,11 +35,11 @@ ) from mri_recon.reconstruction import ( ConjugateGradientReconstructor, - EXPLICIT_UNET_ALGORITHMS, OASISSinglecoilUnetReconstructor, choose_reconstructor, uses_oasis_centered_path, validate_algorithm_dataset_compatibility, + EXPLICIT_UNET_ALGORITHMS, ) from mri_recon.utils import ( OasisCenteredFFTPhysics, @@ -52,49 +52,54 @@ ) FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" +FASTMRI_MULTICOIL_REPORT_DIR = Path("reports") / "fastmri_multicoil_inference_plot" OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot" +CMRXRECON_REPORT_DIR = Path("reports") / "cmrxrecon_inference_plot" FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) +FASTMRI_MULTICOIL_REPORT_DIR.mkdir(parents=True, exist_ok=True) OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) +CMRXRECON_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ - "zero-filled", - # "conjugate-gradient", + # "zero-filled", + "conjugate-gradient", # "ram", # "dip", - "tv-pgd", + # "tv-pgd", # "wavelet-fista", - "tv-fista", + # "tv-fista", # "tv-pdhg", *list(EXPLICIT_UNET_ALGORITHMS), ] DISTORTIONS = [ - "Cartesian undersampling (variable density)", - "Cartesian undersampling (uniform random)", - "Cartesian undersampling (uniform random, zero ACS)", - "Cartesian undersampling (equispaced)", + "no distortion", + # "Cartesian undersampling (variable density)", + # "Cartesian undersampling (uniform random)", + # "Cartesian undersampling (uniform random, zero ACS)", + # "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", - "Partial Fourier", - "Phase-encode ghosting", - "Segmented translation motion", - "Segmented rotational motion", - "Translation motion", - "Rotational motion", - "Off-center anisotropic Gaussian bias field", - "Gaussian bias field", - "Anisotropic LP", - "Hann taper LP", - "Kaiser taper LP", - "Gaussian noise", - "Isotropic LP", - "Radial high-pass emphasis", + # "Partial Fourier", + # "Phase-encode ghosting", + # "Segmented translation motion", + # "Segmented rotational motion", + # "Translation motion", + # "Rotational motion", + # "Off-center anisotropic Gaussian bias field", + # "Gaussian bias field", + # "Anisotropic LP", + # "Hann taper LP", + # "Kaiser taper LP", + # "Gaussian noise", + # "Isotropic LP", + # "Radial high-pass emphasis", ] METRICS = [ "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", + # "NMSE", + # "SSIM", + # "HaarPSI", + # "SharpnessIndex", + # "BlurStrength", ] @@ -211,6 +216,8 @@ def choose_distortion( return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) case "Gaussian noise": return GaussianNoiseDistortion(sigma=0.00001) + case "no distortion": + return BaseDistortion() case _: raise ValueError(f"Unknown distortion {name!r}") @@ -248,16 +255,48 @@ def prepare_measurement_sample( """ if dataset_name == "oasis": - reference_image = sample_batch["x"].to(run_device) - return reference_image, image_to_kspace(reference_image) - - # FastMRI batches are tuples such as (x, y) or (x, y, params). - y_fastmri = sample_batch[1].to(run_device) - if use_oasis_fft_path: - reference_image = fastmri_measurement_to_image(y_fastmri, device=run_device) - return reference_image, fastmri_measurement_to_oasis_kspace(y_fastmri, device=run_device) + x = sample_batch["x"].to(run_device) + y = image_to_kspace(x) + coil_maps = None + elif dataset_name in ("fastmri",) and use_oasis_fft_path: + y = sample_batch[1].to(run_device) + x = fastmri_measurement_to_image(y) + y = fastmri_measurement_to_oasis_kspace(y, device=run_device) + coil_maps = None + elif dataset_name == "fastmri_multicoil" and use_oasis_fft_path: + y = sample_batch[1].to(run_device) + + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + x = fastmri_measurement_to_image(y, coil_maps=coil_maps) + y = fastmri_measurement_to_oasis_kspace(y, coil_maps=coil_maps, device=run_device) + elif dataset_name in ("fastmri", "fastmri_multicoil"): + x = None + y = sample_batch[1].to(run_device) + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + elif dataset_name in ("cmrxrecon"): + x = sample_batch[0].to(run_device) + y = sample_batch[1].to(run_device) + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) - return None, y_fastmri + return x, y, coil_maps def build_physics_pair( @@ -265,6 +304,7 @@ def build_physics_pair( distortion_operator: BaseDistortion, run_device: torch.device | str, use_oasis_fft_path: bool, + coil_maps: torch.Tensor | None = None, ) -> tuple[object, object]: """Build clean and distorted physics operators for the active path.""" @@ -276,11 +316,13 @@ def build_physics_pair( clean_physics = DistortedKspaceMultiCoilMRI( distortion=BaseDistortion(), img_size=(1, 2, *image_shape), + coil_maps=coil_maps, device=run_device, ) distorted_physics = DistortedKspaceMultiCoilMRI( distortion=distortion_operator, img_size=(1, 2, *image_shape), + coil_maps=coil_maps, device=run_device, ) return clean_physics, distorted_physics @@ -295,7 +337,11 @@ def build_physics_pair( type=Path, help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) - parser.add_argument("--dataset", choices=("fastmri", "oasis"), default="fastmri") + parser.add_argument( + "--dataset", + choices=("fastmri", "oasis", "fastmri_multicoil", "cmrxrecon"), + default="fastmri", + ) parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) parser.add_argument( @@ -334,7 +380,16 @@ def build_physics_pair( validate_algorithm_dataset_compatibility(args.dataset, algo_name) # set up report dir - REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR + if args.dataset == "fastmri": + REPORT_DIR = FASTMRI_REPORT_DIR + elif args.dataset == "oasis": + REPORT_DIR = OASIS_REPORT_DIR + elif args.dataset == "fastmri_multicoil": + REPORT_DIR = FASTMRI_MULTICOIL_REPORT_DIR + elif args.dataset == "cmrxrecon": + REPORT_DIR = CMRXRECON_REPORT_DIR + else: + raise NotImplementedError(f"Invalid dataset: {args.dataset}") # set up device, dataset, metrics device = dinv.utils.get_device() @@ -345,8 +400,23 @@ def build_physics_pair( split_csv=split_csv, sample_rate=0.6, ) - else: + elif args.dataset == "fastmri": dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") + elif args.dataset == "fastmri_multicoil": + dataset = dinv.datasets.FastMRISliceDataset( + str(args.source), + slice_index="middle", + transform=dinv.datasets.MRISliceTransform( + estimate_coil_maps=True, + acs=15, + ), + ) + elif args.dataset == "cmrxrecon": + dataset = dinv.datasets.CMRxReconSliceDataset( + str(args.source), data_dir="SingleCoil/Cine/TrainingSet/FullSample", apply_mask=False + ) + else: + raise NotImplementedError(f"Invalid dataset: {args.dataset}") metrics = [choose_metric(m) for m in METRICS] for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))): @@ -355,83 +425,91 @@ def build_physics_pair( break for algo_name in selected_algorithms: - use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) - x_reference, y = prepare_measurement_sample( - sample_batch=batch, - dataset_name=args.dataset, - use_oasis_fft_path=use_oasis_path, - run_device=device, - ) - algo = choose_reconstructor( - algo_name, - img_size=y.shape[-2:], - device=device, - verbose=args.verbose, - dataset=args.dataset, - ).to(device) - - for distortion_name in selected_distortions: - distortion = choose_distortion( - distortion_name, - keep_fraction=args.keep_fraction, - center_fraction=args.center_fraction, - cartesian_axis=-1 if use_oasis_path else -2, - ) - - physics_clean, physics = build_physics_pair( - image_shape=y.shape[-2:], - distortion_operator=distortion, - run_device=device, + try: + use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) + x_reference, y, coil_maps = prepare_measurement_sample( + sample_batch=batch, + dataset_name=args.dataset, use_oasis_fft_path=use_oasis_path, + run_device=device, ) - y_distorted = distortion.A(y) - - # generate reference reconstructions (CG) for both clean and distorted k-space - # without correction for the distortion, i.e. using physics_clean in both cases - if use_oasis_path: - x_clean = x_reference - x_distorted = kspace_to_image(y_distorted) - else: - x_clean = ConjugateGradientReconstructor()(y, physics_clean) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", - distortion_name, - ) - - print(f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}...") - - # actual reconstruction with the algo being evaluated - x_uncorrected = algo(y_distorted, physics_clean) - x_corrected = algo(y_distorted, physics) - - print("done!") - - dinv.utils.plot( - { - "Undistorted ksp, CG recon": x_clean, - "Distorted ksp, CG recon": x_distorted, - f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, - f"Distorted ksp, {algo_name} recon, corrected": x_corrected, - }, - subtitles=[ - "", - "", - "\n".join( - f"{m.__class__.__name__} {m(x_uncorrected, x_clean).item():.2f}" - for m in metrics - ), - "\n".join( - f"{m.__class__.__name__} {m(x_corrected, x_clean).item():.2f}" - for m in metrics - ), - ], - show=False, - close=True, - suptitle=f"Algo {algo_name}, distortion {distortion_name}, Sample {i}", - save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", - fontsize=3, + algo = choose_reconstructor( + algo_name, + img_size=y.shape[-2:], + device=device, + verbose=args.verbose, + dataset=args.dataset, + ).to(device) + + for distortion_name in selected_distortions: + distortion = choose_distortion( + distortion_name, + keep_fraction=args.keep_fraction, + center_fraction=args.center_fraction, + cartesian_axis=-1 if use_oasis_path else -2, + ) + + physics_clean, physics = build_physics_pair( + image_shape=y.shape[-2:], + distortion_operator=distortion, + run_device=device, + use_oasis_fft_path=use_oasis_path, + coil_maps=coil_maps, + ) + y_distorted = distortion.A(y) + + # generate reference reconstructions (CG) for both clean and distorted k-space + # without correction for the distortion, i.e. using physics_clean in both cases + if use_oasis_path: + x_clean = x_reference + x_distorted = kspace_to_image(y_distorted) + else: + x_clean = ConjugateGradientReconstructor()(y, physics_clean) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}..." + ) + + # actual reconstruction with the algo being evaluated + x_uncorrected = algo(y_distorted, physics_clean) + x_corrected = algo(y_distorted, physics) + + print("done!") + + dinv.utils.plot( + { + "Undistorted ksp, CG recon": x_clean, + "Distorted ksp, CG recon": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join( + f"{m.__class__.__name__} {m(x_uncorrected, x_clean).item():.2f}" + for m in metrics + ), + "\n".join( + f"{m.__class__.__name__} {m(x_corrected, x_clean).item():.2f}" + for m in metrics + ), + ], + show=False, + close=True, + suptitle=f"Algo {algo_name}, distortion {distortion_name}, Sample {i}", + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + except Exception as e: + print( + f"Error processing algo {algo_name}, distortion {distortion_name}, sample {i}: {e}" ) diff --git a/mri_recon/distortions/base.py b/mri_recon/distortions/base.py index 61bc7f3..823e626 100644 --- a/mri_recon/distortions/base.py +++ b/mri_recon/distortions/base.py @@ -179,9 +179,9 @@ def A(self, x: torch.Tensor) -> torch.Tensor: y = y.squeeze(2) # remove coil dim if singlecoil return self.distortion(y) - def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: + def A_adjoint(self, y: torch.Tensor, **kwargs) -> torch.Tensor: if len(y.shape) == (5 if self.three_d else 4): y = y.unsqueeze(2) # add coil dim if singlecoil y = self.distortion.A_adjoint(y) - return super().A_adjoint(y) + return super().A_adjoint(y, **kwargs) diff --git a/mri_recon/utils/oasis_adapter.py b/mri_recon/utils/oasis_adapter.py index 0f833aa..a55b34f 100644 --- a/mri_recon/utils/oasis_adapter.py +++ b/mri_recon/utils/oasis_adapter.py @@ -169,6 +169,7 @@ def kspace_to_image(y: torch.Tensor) -> torch.Tensor: def fastmri_measurement_to_image( y: torch.Tensor, + coil_maps: torch.Tensor | None = None, device: torch.device | str | None = None, ) -> torch.Tensor: """Convert FastMRI measurements to image space using the repo's native physics. @@ -177,6 +178,9 @@ def fastmri_measurement_to_image( ---------- y : torch.Tensor FastMRI measurement tensor with shape ``(B, 2, H, W)``. + coil_maps : torch.Tensor, optional + Coil sensitivity maps with shape ``(B, 2, H, W)``. If provided, these will be applied to the image before the centered FFT, matching + the OASIS U-Net training setup. If not provided, the function will still return an image but without coil sensitivity modulation. device : torch.device | str, optional Device on which to instantiate the temporary native physics operator. @@ -191,6 +195,7 @@ def fastmri_measurement_to_image( physics = DistortedKspaceMultiCoilMRI( distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), + coil_maps=coil_maps, device=device, ) return physics.A_adjoint(y) @@ -198,6 +203,7 @@ def fastmri_measurement_to_image( def fastmri_measurement_to_oasis_kspace( y: torch.Tensor, + coil_maps: torch.Tensor | None = None, device: torch.device | str | None = None, ) -> torch.Tensor: """Adapt FastMRI measurements to the centered OASIS k-space convention. @@ -206,6 +212,9 @@ def fastmri_measurement_to_oasis_kspace( ---------- y : torch.Tensor FastMRI measurement tensor with shape ``(B, 2, H, W)``. + coil_maps : torch.Tensor, optional + Coil sensitivity maps with shape ``(B, 2, H, W)``. If provided, these will be applied to the image before the centered FFT, matching + the OASIS U-Net training setup. If not provided, the function will still return centered k-space but without coil sensitivity modulation. device : torch.device | str, optional Device on which to instantiate the temporary native physics operator. @@ -215,7 +224,7 @@ def fastmri_measurement_to_oasis_kspace( Centered OASIS-convention k-space tensor with shape ``(B, 2, H, W)``. """ - return image_to_kspace(fastmri_measurement_to_image(y, device=device)) + return image_to_kspace(fastmri_measurement_to_image(y, coil_maps=coil_maps, device=device)) class OasisCenteredFFTPhysics: diff --git a/mri_recon/utils/plot.py b/mri_recon/utils/plot.py index 068c222..b167482 100644 --- a/mri_recon/utils/plot.py +++ b/mri_recon/utils/plot.py @@ -9,9 +9,13 @@ def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to a log-magnitude image for visualization.""" + """Convert k-space tensor to a log-magnitude image for visualization. + NOTE: for multicoil, just plot the first coil data. + """ - if kspace.ndim == 4: + if kspace.ndim == 5: # multicoil + kspace = kspace[:, :, 0] + if kspace.ndim == 4: # batched kspace = kspace[0] if kspace.ndim != 3 or kspace.shape[0] != 2: raise ValueError(