diff --git a/.gitignore b/.gitignore index 15c7e8d..80da5e0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,9 @@ data/ *_cache/ reports/ .python-version +*.ckpt +*.pt +manifest.json +# Ignore all yaml files in the examples folder except config.yaml as template: +examples/*.yaml +!examples/config.yaml diff --git a/examples/config.yaml b/examples/config.yaml new file mode 100644 index 0000000..7fbed4e --- /dev/null +++ b/examples/config.yaml @@ -0,0 +1,46 @@ +data: + "fastmri_knee": "/path/to/fastmri/singlecoil_val" + "oasis": "/path/to/oasis" + "fastmri_brain": "/path/to/fastmri/fastMRI_multicoil_brain_test" + "cmrxrecon": "/path/to/CMRxRecon" + "fastmri_prostate": "/path/to/fastmri/fastMRI_prostate_T2_IDS_001_020" + +distortions: + - "BaseDistortion" + - "CartesianUndersamplingVariableDensity" + - "CartesianUndersamplingUniformRandom" + - "CartesianUndersamplingUniformRandomZeroACS" + - "CartesianUndersamplingEquispaced" + - "CartesianUndersamplingEquispacedZeroACS" + - "PartialFourier" + - "PhaseEncodeGhosting" + - "SegmentedTranslationMotion" + - "SegmentedRotationalMotion" + - "TranslationMotion" + - "RotationalMotion" + - "OffCenterAnisotropicGaussianBiasField" + - "GaussianBiasField" + - "AnisotropicLP" + - "HannTaperLP" + - "KaiserTaperLP" + - "GaussianNoise" + - "IsotropicLP" + - "RadialHighPassEmphasis" +reconstruction_algorithms: + - "zero-filled" + - "conjugate-gradient" + #- "ram", + #- "dip", + #- "tv-pgd", + #- "wavelet-fista", + #- "tv-fista", + #- "tv-pdhg", + - "unet-fastmri" + - "unet-oasis-acceleration4" + #- "unet-oasis-acceleration8" + #- "unet-oasis-acceleration10" +num_samples: 1 +keep_fraction: 0.25 +center_fraction: 0.125 +verbose: true +results_dir: "/home/melanie.dohmen/ArtifactLab/reports/experiments_run1" diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 1379db0..d580324 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -35,11 +35,11 @@ ) from mri_recon.reconstruction import ( ConjugateGradientReconstructor, - EXPLICIT_UNET_ALGORITHMS, OASISSinglecoilUnetReconstructor, choose_reconstructor, uses_oasis_centered_path, - validate_algorithm_dataset_compatibility, + compatible_dataset_with_reconstructor, + EXPLICIT_UNET_ALGORITHMS, ) from mri_recon.utils import ( OasisCenteredFFTPhysics, @@ -52,49 +52,54 @@ ) FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" +FASTMRI_MULTICOIL_REPORT_DIR = Path("reports") / "fastmri_multicoil_inference_plot" OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot" +CMRXRECON_REPORT_DIR = Path("reports") / "cmrxrecon_inference_plot" FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) +FASTMRI_MULTICOIL_REPORT_DIR.mkdir(parents=True, exist_ok=True) OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) +CMRXRECON_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ - "zero-filled", - # "conjugate-gradient", + # "zero-filled", + "conjugate-gradient", # "ram", # "dip", - "tv-pgd", + # "tv-pgd", # "wavelet-fista", - "tv-fista", + # "tv-fista", # "tv-pdhg", *list(EXPLICIT_UNET_ALGORITHMS), ] DISTORTIONS = [ - "Cartesian undersampling (variable density)", - "Cartesian undersampling (uniform random)", - "Cartesian undersampling (uniform random, zero ACS)", - "Cartesian undersampling (equispaced)", + "no distortion", + # "Cartesian undersampling (variable density)", + # "Cartesian undersampling (uniform random)", + # "Cartesian undersampling (uniform random, zero ACS)", + # "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", - "Partial Fourier", - "Phase-encode ghosting", - "Segmented translation motion", - "Segmented rotational motion", - "Translation motion", - "Rotational motion", - "Off-center anisotropic Gaussian bias field", - "Gaussian bias field", - "Anisotropic LP", - "Hann taper LP", - "Kaiser taper LP", - "Gaussian noise", - "Isotropic LP", - "Radial high-pass emphasis", + # "Partial Fourier", + # "Phase-encode ghosting", + # "Segmented translation motion", + # "Segmented rotational motion", + # "Translation motion", + # "Rotational motion", + # "Off-center anisotropic Gaussian bias field", + # "Gaussian bias field", + # "Anisotropic LP", + # "Hann taper LP", + # "Kaiser taper LP", + # "Gaussian noise", + # "Isotropic LP", + # "Radial high-pass emphasis", ] METRICS = [ "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", + # "NMSE", + # "SSIM", + # "HaarPSI", + # "SharpnessIndex", + # "BlurStrength", ] @@ -211,6 +216,8 @@ def choose_distortion( return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) case "Gaussian noise": return GaussianNoiseDistortion(sigma=0.00001) + case "no distortion": + return BaseDistortion() case _: raise ValueError(f"Unknown distortion {name!r}") @@ -248,16 +255,48 @@ def prepare_measurement_sample( """ if dataset_name == "oasis": - reference_image = sample_batch["x"].to(run_device) - return reference_image, image_to_kspace(reference_image) - - # FastMRI batches are tuples such as (x, y) or (x, y, params). - y_fastmri = sample_batch[1].to(run_device) - if use_oasis_fft_path: - reference_image = fastmri_measurement_to_image(y_fastmri, device=run_device) - return reference_image, fastmri_measurement_to_oasis_kspace(y_fastmri, device=run_device) + x = sample_batch["x"].to(run_device) + y = image_to_kspace(x) + coil_maps = None + elif dataset_name in ("fastmri",) and use_oasis_fft_path: + y = sample_batch[1].to(run_device) + x = fastmri_measurement_to_image(y) + y = fastmri_measurement_to_oasis_kspace(y, device=run_device) + coil_maps = None + elif dataset_name == "fastmri_multicoil" and use_oasis_fft_path: + y = sample_batch[1].to(run_device) + + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + x = fastmri_measurement_to_image(y, coil_maps=coil_maps) + y = fastmri_measurement_to_oasis_kspace(y, coil_maps=coil_maps, device=run_device) + elif dataset_name in ("fastmri", "fastmri_multicoil"): + x = None + y = sample_batch[1].to(run_device) + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + elif dataset_name in ("cmrxrecon"): + x = sample_batch[0].to(run_device) + y = sample_batch[1].to(run_device) + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) - return None, y_fastmri + return x, y, coil_maps def build_physics_pair( @@ -265,6 +304,7 @@ def build_physics_pair( distortion_operator: BaseDistortion, run_device: torch.device | str, use_oasis_fft_path: bool, + coil_maps: torch.Tensor | None = None, ) -> tuple[object, object]: """Build clean and distorted physics operators for the active path.""" @@ -276,11 +316,13 @@ def build_physics_pair( clean_physics = DistortedKspaceMultiCoilMRI( distortion=BaseDistortion(), img_size=(1, 2, *image_shape), + coil_maps=coil_maps, device=run_device, ) distorted_physics = DistortedKspaceMultiCoilMRI( distortion=distortion_operator, img_size=(1, 2, *image_shape), + coil_maps=coil_maps, device=run_device, ) return clean_physics, distorted_physics @@ -295,7 +337,11 @@ def build_physics_pair( type=Path, help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) - parser.add_argument("--dataset", choices=("fastmri", "oasis"), default="fastmri") + parser.add_argument( + "--dataset", + choices=("fastmri", "oasis", "fastmri_multicoil", "cmrxrecon"), + default="fastmri", + ) parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) parser.add_argument( @@ -330,11 +376,25 @@ def build_physics_pair( selected_algorithms = ALGORITHMS if args.algorithm == "" else [args.algorithm] selected_distortions = DISTORTIONS if args.distortion == "" else [args.distortion] - for algo_name in selected_algorithms: - validate_algorithm_dataset_compatibility(args.dataset, algo_name) + + # skip non-compatible algorithm-dataset pairs + selected_algorithms = [ + algo_name + for algo_name in selected_algorithms + if compatible_dataset_with_reconstructor(args.dataset, algo_name) + ] # set up report dir - REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR + if args.dataset == "fastmri": + REPORT_DIR = FASTMRI_REPORT_DIR + elif args.dataset == "oasis": + REPORT_DIR = OASIS_REPORT_DIR + elif args.dataset == "fastmri_multicoil": + REPORT_DIR = FASTMRI_MULTICOIL_REPORT_DIR + elif args.dataset == "cmrxrecon": + REPORT_DIR = CMRXRECON_REPORT_DIR + else: + raise NotImplementedError(f"Invalid dataset: {args.dataset}") # set up device, dataset, metrics device = dinv.utils.get_device() @@ -345,8 +405,23 @@ def build_physics_pair( split_csv=split_csv, sample_rate=0.6, ) - else: + elif args.dataset == "fastmri": dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") + elif args.dataset == "fastmri_multicoil": + dataset = dinv.datasets.FastMRISliceDataset( + str(args.source), + slice_index="middle", + transform=dinv.datasets.MRISliceTransform( + estimate_coil_maps=True, + acs=15, + ), + ) + elif args.dataset == "cmrxrecon": + dataset = dinv.datasets.CMRxReconSliceDataset( + str(args.source), data_dir="SingleCoil/Cine/TrainingSet/FullSample", apply_mask=False + ) + else: + raise NotImplementedError(f"Invalid dataset: {args.dataset}") metrics = [choose_metric(m) for m in METRICS] for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))): @@ -355,83 +430,91 @@ def build_physics_pair( break for algo_name in selected_algorithms: - use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) - x_reference, y = prepare_measurement_sample( - sample_batch=batch, - dataset_name=args.dataset, - use_oasis_fft_path=use_oasis_path, - run_device=device, - ) - algo = choose_reconstructor( - algo_name, - img_size=y.shape[-2:], - device=device, - verbose=args.verbose, - dataset=args.dataset, - ).to(device) - - for distortion_name in selected_distortions: - distortion = choose_distortion( - distortion_name, - keep_fraction=args.keep_fraction, - center_fraction=args.center_fraction, - cartesian_axis=-1 if use_oasis_path else -2, - ) - - physics_clean, physics = build_physics_pair( - image_shape=y.shape[-2:], - distortion_operator=distortion, - run_device=device, + try: + use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) + x_reference, y, coil_maps = prepare_measurement_sample( + sample_batch=batch, + dataset_name=args.dataset, use_oasis_fft_path=use_oasis_path, + run_device=device, ) - y_distorted = distortion.A(y) - - # generate reference reconstructions (CG) for both clean and distorted k-space - # without correction for the distortion, i.e. using physics_clean in both cases - if use_oasis_path: - x_clean = x_reference - x_distorted = kspace_to_image(y_distorted) - else: - x_clean = ConjugateGradientReconstructor()(y, physics_clean) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", - distortion_name, - ) - - print(f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}...") - - # actual reconstruction with the algo being evaluated - x_uncorrected = algo(y_distorted, physics_clean) - x_corrected = algo(y_distorted, physics) - - print("done!") - - dinv.utils.plot( - { - "Undistorted ksp, CG recon": x_clean, - "Distorted ksp, CG recon": x_distorted, - f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, - f"Distorted ksp, {algo_name} recon, corrected": x_corrected, - }, - subtitles=[ - "", - "", - "\n".join( - f"{m.__class__.__name__} {m(x_uncorrected, x_clean).item():.2f}" - for m in metrics - ), - "\n".join( - f"{m.__class__.__name__} {m(x_corrected, x_clean).item():.2f}" - for m in metrics - ), - ], - show=False, - close=True, - suptitle=f"Algo {algo_name}, distortion {distortion_name}, Sample {i}", - save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", - fontsize=3, + algo = choose_reconstructor( + algo_name, + img_size=y.shape[-2:], + device=device, + verbose=args.verbose, + dataset=args.dataset, + ).to(device) + + for distortion_name in selected_distortions: + distortion = choose_distortion( + distortion_name, + keep_fraction=args.keep_fraction, + center_fraction=args.center_fraction, + cartesian_axis=-1 if use_oasis_path else -2, + ) + + physics_clean, physics = build_physics_pair( + image_shape=y.shape[-2:], + distortion_operator=distortion, + run_device=device, + use_oasis_fft_path=use_oasis_path, + coil_maps=coil_maps, + ) + y_distorted = distortion.A(y) + + # generate reference reconstructions (CG) for both clean and distorted k-space + # without correction for the distortion, i.e. using physics_clean in both cases + if use_oasis_path: + x_clean = x_reference + x_distorted = kspace_to_image(y_distorted) + else: + x_clean = ConjugateGradientReconstructor()(y, physics_clean) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}..." + ) + + # actual reconstruction with the algo being evaluated + x_uncorrected = algo(y_distorted, physics_clean) + x_corrected = algo(y_distorted, physics) + + print("done!") + + dinv.utils.plot( + { + "Undistorted ksp, CG recon": x_clean, + "Distorted ksp, CG recon": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join( + f"{m.__class__.__name__} {m(x_uncorrected, x_clean).item():.2f}" + for m in metrics + ), + "\n".join( + f"{m.__class__.__name__} {m(x_corrected, x_clean).item():.2f}" + for m in metrics + ), + ], + show=False, + close=True, + suptitle=f"Algo {algo_name}, distortion {distortion_name}, Sample {i}", + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + except Exception as e: + print( + f"Error processing algo {algo_name}, distortion {distortion_name}, sample {i}: {e}" ) diff --git a/examples/run_all.py b/examples/run_all.py new file mode 100644 index 0000000..e051129 --- /dev/null +++ b/examples/run_all.py @@ -0,0 +1,363 @@ +"""Inference various reconstructors for various distortion operators. + +Usage: + python examples/run_all.py config.yaml +""" + +import os +import sys + + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from datetime import datetime +import deepinv as dinv +import torch +import yaml +from tifffile import imwrite + +from mri_recon.distortions import ( + BaseDistortion, + DistortedKspaceMultiCoilMRI, + choose_distortion, +) +from mri_recon.reconstruction import ( + choose_reconstructor, + uses_oasis_centered_path, + compatible_dataset_with_reconstructor, +) +from mri_recon.utils import ( + OasisCenteredFFTPhysics, + OasisCenterSliceFolderDataset, + FastMRIProstateDataset, + fastmri_measurement_to_oasis_kspace, + oasis_kspace_to_fastmri_measurement, + image_to_kspace, + _kspace_to_log_magnitude, + convert_image_for_save, +) + + +def get_measurement_sample( + sample_batch: object, + dataset_name: str, + run_device: torch.device | str, +) -> tuple[torch.Tensor | None, torch.Tensor]: + """Prepare one input measurement and its clean image reference. + + Always prepare a (fast-MRI-like) non-centered k-space measurement + as well as a (oasis-like) centered k-space version of the measurement + and a reference reconstruction in the image domain. + """ + coil_maps = None + if dataset_name == "oasis": + # reference image, shape: (B, 2, H, W) dtype: float32 + x = sample_batch["x"].to(run_device) + # centered k-space data, shape: (B, 2, H, W) dtype: float32 + y_centered = image_to_kspace(x) + # k-space data, shape: (B, 2, H, W) dtype: float32 + y = oasis_kspace_to_fastmri_measurement(y_centered) + elif dataset_name == "fastmri_knee": + # reference image, shape: (B, 1, H/2, H/2) dtype: float32 + x = sample_batch[0].to(run_device) + # kspace data, shape: (B, 2, H, W) dtype: float32 + y = sample_batch[1].to(run_device) + # centered k-space data, shape: (B, 2, H, W) dtype: float32 + y_centered = fastmri_measurement_to_oasis_kspace(y, device=run_device) + # reconstructed reference image: + # shape: (B, 1, H, W) dtype: float32 + elif dataset_name == "fastmri_brain": + # reference image, shape: (B, 1, H/2, H/2) dtype: float32 + x = sample_batch[0].to(run_device) + # kspace data, shape: (B, 2, num_coils, H, W) dtype: float32 + y = sample_batch[1].to(run_device) + # coil maps, shape: (B, num_coils, H, W) dtype: complex64 + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + # centered k-space data, shape: (B, 2, H, W) dtype: float32 + y_centered = fastmri_measurement_to_oasis_kspace(y, coil_maps=coil_maps, device=run_device) + + elif dataset_name == "cmrxrecon": + # reference image, shape: (B, 2, n_timepoints, (n_coils), H, W) + x = sample_batch[0].to(run_device) + # k-space data, shape: (B, 2, n_timepoints, (n_coils), H, W) dtype: float32 + y = sample_batch[1].to(run_device) + # not available for all samples, either None or + # shape (1, num_coils, H, W) + coil_maps = ( + sample_batch[2]["coil_maps"].to(run_device) + if isinstance(sample_batch, (tuple, list)) + and len(sample_batch) == 3 + and "coil_maps" in sample_batch[2] + else None + ) + + # centered k-space data, shape: (B, 2, num_clois, H, W) dtype: float32 + y_centered = fastmri_measurement_to_oasis_kspace(y, coil_maps=coil_maps, device=run_device) + + elif dataset_name == "fastmri_prostate": + # reference image, shape: (B, W, H): dtype float32 + x = sample_batch[0].to(run_device) + + # add zero imaginary channel: + # (B, H, W) -> (B, 2, H, W) + x = torch.stack([x, torch.zeros_like(x)], dim=1) + + # (B, 2, H, W) + y_centered = image_to_kspace(x) + y = oasis_kspace_to_fastmri_measurement(y_centered) + + return x, y, y_centered, coil_maps + + +if __name__ == "__main__": + # read config file in yaml format as first argument from commmand line + if len(sys.argv) < 2: + print("Usage: python examples/run_all.py ") + sys.exit(1) + + with open(sys.argv[1], "r") as f: + config = yaml.safe_load(f) + + os.makedirs(config["results_dir"], exist_ok=True) + + # set up device + device = dinv.utils.get_device() + + for dataset_name, dataset_rootdir in config["data"].items(): + print(f"=== {dataset_name} ===") + + # initialize dataset + if dataset_name == "oasis": + dataset = OasisCenterSliceFolderDataset( + data_path=dataset_rootdir, + ) + elif dataset_name == "fastmri_knee": + dataset = dinv.datasets.FastMRISliceDataset(str(dataset_rootdir), slice_index="middle") + elif dataset_name == "fastmri_brain": + dataset = dinv.datasets.FastMRISliceDataset( + str(dataset_rootdir), + slice_index="middle", + transform=dinv.datasets.MRISliceTransform( + estimate_coil_maps=True, + acs=15, + ), + ) + elif dataset_name == "cmrxrecon": + dataset = dinv.datasets.CMRxReconSliceDataset( + str(dataset_rootdir), + data_dir="SingleCoil/Cine/TrainingSet/FullSample", + apply_mask=False, + ) + elif dataset_name == "fastmri_prostate": + dataset = FastMRIProstateDataset( + data_path=dataset_rootdir, num_samples=config["num_samples"], slice_index="middle" + ) + else: + raise NotImplementedError(f"Invalid dataset: {dataset_name}") + + # loop through samples of dataset + for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))): + # exit loop if we have processed the specified number of samples + if i >= config["num_samples"]: + break + + print(f"{dataset_name} sample {i}...") + x_reference, y, y_centered, coil_maps = get_measurement_sample( + sample_batch=batch, + dataset_name=dataset_name, + run_device=device, + ) + + # save reference image + imwrite( + os.path.join( + config["results_dir"], f"image_{dataset_name}_sample_{i}_reference.tiff" + ), + convert_image_for_save(x_reference), + ) + + # use fast-mri type samples first, later proceed with oasis-centered fft path + physics_clean = DistortedKspaceMultiCoilMRI( + BaseDistortion(), img_size=y.shape[-2:], coil_maps=coil_maps, device=device + ) + + # reference from dataset: + for distortion_name in config["distortions"]: + print(f"\t{distortion_name} ...") + + distortion = choose_distortion( + distortion_name, + keep_fraction=config["keep_fraction"], + center_fraction=config["center_fraction"], + cartesian_axis=-2, + ) + + y_distorted = distortion.A(y) + + physics_distorted = DistortedKspaceMultiCoilMRI( + distortion, + img_size=y.shape[-2:], + coil_maps=coil_maps, + device=device, + ) + + for reconstructor_name in config["reconstruction_algorithms"]: + # only run on reconstructors, that use the fastmri-like k-space + if not uses_oasis_centered_path(reconstructor_name): + print(f"\t\t{reconstructor_name} ...") + start = datetime.now() + if compatible_dataset_with_reconstructor(dataset_name, reconstructor_name): + reconstructor = choose_reconstructor( + reconstructor_name, + img_size=y_distorted.shape[-2:], + device=device, + verbose=config["verbose"], + ).to(device) + + # save reference and distorted k-space for debugging purposes + imwrite( + os.path.join( + config["results_dir"], + f"kspace_{dataset_name}_sample_{i}_reference.tiff", + ), + _kspace_to_log_magnitude(y).numpy(), + ) + imwrite( + os.path.join( + config["results_dir"], + f"kspace_{dataset_name}_sample_{i}_{distortion_name}.tiff", + ), + _kspace_to_log_magnitude(y_distorted).numpy(), + ) + + # actual reconstruction with the selected reconstructor + try: + x_uncorrected = reconstructor(y_distorted, physics_clean) + x_corrected = reconstructor(y_distorted, physics_distorted) + + # crop recostructed image to reference image size: + if x_uncorrected.shape[-2:] != x_reference.shape[-2:]: + x_uncorrected = physics_clean.crop( + x_uncorrected, shape=x_reference.shape[-2:] + ) + + if x_corrected.shape[-2:] != x_reference.shape[-2:]: + x_corrected = physics_distorted.crop( + x_corrected, shape=x_reference.shape[-2:] + ) + + # save reconstructed images + imwrite( + os.path.join( + config["results_dir"], + f"image_{dataset_name}_sample_{i}_{distortion_name}_{reconstructor_name}_uncorrected.tiff", + ), + convert_image_for_save(x_uncorrected), + ) + imwrite( + os.path.join( + config["results_dir"], + f"image_{dataset_name}_sample_{i}_{distortion_name}_{reconstructor_name}_corrected.tiff", + ), + convert_image_for_save(x_corrected), + ) + print(f"\t\t... done in {datetime.now() - start}") + + except Exception as e: + print(f"Error using {reconstructor_name}: {e}") + + else: + print(f"\t\t ... not compatible with {dataset_name}") + + # now proceed with oasis-centered fft path + physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) + + for distortion_name in config["distortions"]: + print(f"\t{distortion_name} ...") + distortion = choose_distortion( + distortion_name, + keep_fraction=config["keep_fraction"], + center_fraction=config["center_fraction"], + cartesian_axis=-1, + ) + + y_distorted = torch.fft.fftshift( + distortion.A(torch.fft.fftshift(y_centered, dim=(-1, -2))), dim=(-2, -1) + ) + + physics_distorted = OasisCenteredFFTPhysics(distortion) + + for reconstructor_name in config["reconstruction_algorithms"]: + # skip all reconstructors, that don't use the oasis-centered path + if uses_oasis_centered_path(reconstructor_name): + print(f"\t\t{reconstructor_name} ...") + start = datetime.now() + if compatible_dataset_with_reconstructor(dataset_name, reconstructor_name): + reconstructor = choose_reconstructor( + reconstructor_name, + img_size=y_distorted.shape[-2:], + device=device, + verbose=config["verbose"], + ).to(device) + + # save reference and distorted k-space for debugging purposes + imwrite( + os.path.join( + config["results_dir"], + f"kspace_centered_{dataset_name}_sample_{i}_reference.tiff", + ), + _kspace_to_log_magnitude(y_centered).numpy(), + ) + imwrite( + os.path.join( + config["results_dir"], + f"kspace_centered_{dataset_name}_sample_{i}_{distortion_name}.tiff", + ), + _kspace_to_log_magnitude(y_distorted).numpy(), + ) + + # actual reconstruction with the algo being evaluated + try: + x_uncorrected = reconstructor(y_distorted, physics_clean) + x_corrected = reconstructor(y_distorted, physics_distorted) + + if x_uncorrected.shape[-2:] != x_reference.shape[-2:]: + x_uncorrected = physics_clean.crop( + x_uncorrected, shape=x_reference.shape[-2:] + ) + + if x_corrected.shape[-2:] != x_reference.shape[-2:]: + x_corrected = physics_distorted.crop( + x_corrected, shape=x_reference.shape[-2:] + ) + + # save reconstructed images + imwrite( + os.path.join( + config["results_dir"], + f"image_{dataset_name}_sample_{i}_{distortion_name}_{reconstructor_name}_uncorrected.tiff", + ), + convert_image_for_save(x_uncorrected), + ) + imwrite( + os.path.join( + config["results_dir"], + f"image_{dataset_name}_sample_{i}_{distortion_name}_{reconstructor_name}_corrected.tiff", + ), + convert_image_for_save(x_corrected), + ) + print(f"\t\t... done in {datetime.now() - start}") + + except Exception as e: + print( + f"\t\tError using {reconstructor_name} with distortion {distortion_name} on sample {i}: {e}" + ) + + else: + print(f"\t\t ... not compatible with {dataset_name}") diff --git a/mri_recon/distortions/__init__.py b/mri_recon/distortions/__init__.py index af22fd1..ea85943 100644 --- a/mri_recon/distortions/__init__.py +++ b/mri_recon/distortions/__init__.py @@ -20,3 +20,4 @@ RadialHighPassEmphasisDistortion, ) from .undersampling import CartesianUndersampling, PartialFourierDistortion +from .utils import choose_distortion diff --git a/mri_recon/distortions/base.py b/mri_recon/distortions/base.py index 61bc7f3..d89efd0 100644 --- a/mri_recon/distortions/base.py +++ b/mri_recon/distortions/base.py @@ -176,12 +176,18 @@ def __init__(self, distortion: BaseDistortion = None, *args, **kwargs): def A(self, x: torch.Tensor) -> torch.Tensor: y = super().A(x) - y = y.squeeze(2) # remove coil dim if singlecoil + + y = y.squeeze(2) # remove coil dimension if single coil + return self.distortion(y) def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: if len(y.shape) == (5 if self.three_d else 4): - y = y.unsqueeze(2) # add coil dim if singlecoil + y = y.unsqueeze(2) # add coil dimension for single coil y = self.distortion.A_adjoint(y) + return super().A_adjoint(y) + + # def A_dagger(self, y: torch.Tensor, **kwargs) -> torch.Tensor: + # return super().A_dagger(y, coil_maps=self.coil_maps, **kwargs) diff --git a/mri_recon/distortions/utils.py b/mri_recon/distortions/utils.py new file mode 100644 index 0000000..96d00b7 --- /dev/null +++ b/mri_recon/distortions/utils.py @@ -0,0 +1,140 @@ +import torch + +from .base import BaseDistortion +from .resolution import ( + HannTaperResolutionReduction, + IsotropicResolutionReduction, + AnisotropicResolutionReduction, + KaiserTaperResolutionReduction, + RadialHighPassEmphasisDistortion, +) +from .undersampling import CartesianUndersampling, PartialFourierDistortion +from .biasfield import OffCenterAnisotropicGaussianKspaceBiasField, GaussianKspaceBiasField +from .noise import GaussianNoiseDistortion +from .motion import ( + RotationalMotionDistortion, + SegmentedRotationalMotionDistortion, + TranslationMotionDistortion, + SegmentedTranslationMotionDistortion, +) +from .ghosting import PhaseEncodeGhostingDistortion + + +def choose_distortion( + name: str, + keep_fraction: float = 0.25, + center_fraction: float = 0.125, + cartesian_axis: int = -2, +) -> BaseDistortion: + """Build one distortion operator for the inference comparison script. + + The ``cartesian_axis`` is supplied by the active measurement convention: + FastMRI-native runs use the repository's existing axis, while OASIS-native + and FastMRI-to-OASIS runs use the centered OASIS axis. + """ + + match name: + case "PhaseEncodeGhosting": + return PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + case "CartesianUndersamplingVariableDensity": + return CartesianUndersampling( + keep_fraction=keep_fraction, + center_fraction=center_fraction, + pattern="variable_density_random", + axis=cartesian_axis, + seed=42, + ) + case "CartesianUndersamplingUniformRandom": + return CartesianUndersampling( + keep_fraction=keep_fraction, + center_fraction=center_fraction, + pattern="uniform_random", + axis=cartesian_axis, + seed=42, + ) + case "CartesianUndersamplingUniformRandomZeroACS": + return CartesianUndersampling( + keep_fraction=keep_fraction, + center_fraction=0.0, + pattern="uniform_random", + axis=cartesian_axis, + seed=42, + ) + case "CartesianUndersamplingEquispaced": + return CartesianUndersampling( + keep_fraction=keep_fraction, + center_fraction=center_fraction, + pattern="equispaced", + axis=cartesian_axis, + seed=42, + ) + case "CartesianUndersamplingEquispacedZeroACS": + return CartesianUndersampling( + keep_fraction=keep_fraction, + center_fraction=0.0, + pattern="equispaced", + axis=cartesian_axis, + seed=42, + ) + case "PartialFourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=center_fraction, + axis=cartesian_axis, + side="high", + ) + case "AnisotropicLP": + return AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.25, + ) + case "HannTaperLP": + return HannTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + ) + case "KaiserTaperLP": + return KaiserTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + beta=8.6, + ) + case "RadialHighPassEmphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) + case "IsotropicLP": + return IsotropicResolutionReduction(radius_fraction=0.1) + case "OffCenterAnisotropicGaussianKspaceBiasField": + return OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.2, + width_y_fraction=0.35, + center_x_fraction=0.15, + center_y_fraction=-0.1, + edge_gain=0.3, + ) + case "TranslationMotion": + return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) + case "RotationalMotion": + return RotationalMotionDistortion(angle_radians=torch.pi / 6) + case "SegmentedRotationalMotion": + return SegmentedRotationalMotionDistortion( + angle_radians=(0.0, torch.pi / 20, -torch.pi / 24, torch.pi / 16), + ) + case "SegmentedTranslationMotion": + return SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 20.0, 50.0, -50.0), + shift_y_pixels=(0.0, 10.0, -20.0, 20.0), + ) + case "GaussianKspaceBiasField": + return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + case "GaussianNoise": + return GaussianNoiseDistortion(sigma=0.00001) + case "BaseDistortion": + return BaseDistortion() + + case _: + raise ValueError(f"Unknown distortion {name!r}") diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index 5eb75ea..fef846c 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -10,7 +10,7 @@ OASIS_UNET_ALGORITHMS, choose_reconstructor, uses_oasis_centered_path, - validate_algorithm_dataset_compatibility, + compatible_dataset_with_reconstructor, ) from .classic import ( ZeroFilledReconstructor, diff --git a/mri_recon/reconstruction/inference.py b/mri_recon/reconstruction/inference.py index b0fe1d3..8155ff6 100644 --- a/mri_recon/reconstruction/inference.py +++ b/mri_recon/reconstruction/inference.py @@ -28,29 +28,36 @@ def uses_oasis_centered_path( - dataset: str, algorithm: str, ) -> bool: """Return whether inference should use the centered OASIS k-space path. - OASIS samples always use the centered FFT convention. FastMRI only switches - to that path when the selected algorithm is one of the explicit OASIS U-Net + The centered FFT path is only used with the explicit OASIS U-Net variants. """ - - if dataset == "oasis": - return True return algorithm in OASIS_UNET_ALGORITHMS -def validate_algorithm_dataset_compatibility(dataset: str, algorithm: str) -> None: - """Raise a clear error when an explicit algorithm is incompatible with a dataset.""" +def compatible_dataset_with_reconstructor(dataset: str, reconstructor_name: str) -> bool: + """Check if dataset and trained reconstructor are compatible""" - if dataset == "oasis" and algorithm == FASTMRI_UNET_ALGORITHM: - raise ValueError( - "The algorithm 'unet-fastmri' is not supported on the OASIS dataset. " - "Use one of the explicit OASIS U-Net algorithms instead." - ) + # fast mri u-net is only trained with knee data + if reconstructor_name == FASTMRI_UNET_ALGORITHM: + if dataset == "fastmri_knee": + return True + else: + return False + + # oasis is only trained with brain data + elif reconstructor_name in OASIS_UNET_ALGORITHMS: + if dataset in ["fastmri_brain", "oasis"]: + return True + else: + return False + + # all other (classic) reconstructors work with any dataset: + else: + return True def choose_reconstructor( @@ -58,7 +65,7 @@ def choose_reconstructor( img_size: tuple = (640, 368), device: torch.device | str = "cpu", verbose: bool = False, - dataset: str = "fastmri", + dataset: str | None = None, ) -> dinv.models.Reconstructor: """Create a reconstructor while enforcing the supported dataset/model matrix. @@ -78,7 +85,10 @@ def choose_reconstructor( explicit algorithm names that are dataset-specific. """ - validate_algorithm_dataset_compatibility(dataset, name) + if dataset is not None and not compatible_dataset_with_reconstructor(dataset, name): + raise ValueError( + f"Reconstructor {name} is not compatible with dataset {dataset}, because it was trained with a different image domain." + ) match name: case "zero-filled": diff --git a/mri_recon/utils/__init__.py b/mri_recon/utils/__init__.py index 744c860..a0648ad 100644 --- a/mri_recon/utils/__init__.py +++ b/mri_recon/utils/__init__.py @@ -4,13 +4,21 @@ from .io import matches_sha256 as matches_sha256 from .oasis_adapter import OasisCenteredFFTPhysics as OasisCenteredFFTPhysics from .oasis_adapter import OasisSliceDataset as OasisSliceDataset +from .oasis_adapter import OasisCenterSliceFolderDataset as OasisCenterSliceFolderDataset from .oasis_adapter import fastmri_measurement_to_image as fastmri_measurement_to_image from .oasis_adapter import ( fastmri_measurement_to_oasis_kspace as fastmri_measurement_to_oasis_kspace, ) +from .oasis_adapter import ( + oasis_kspace_to_fastmri_measurement as oasis_kspace_to_fastmri_measurement, +) from .oasis_adapter import image_to_kspace as image_to_kspace from .oasis_adapter import kspace_to_image as kspace_to_image +from .oasis_adapter import image_to_fastmri_measurement as image_to_fastmri_measurement +from .prostate_adaptor import FastMRIProstateDataset as FastMRIProstateDataset from .plot import save_kspace_plot as save_kspace_plot +from .plot import _kspace_to_log_magnitude as _kspace_to_log_magnitude +from .plot import convert_image_for_save as convert_image_for_save __all__ = [ "download_file_with_sha256", @@ -23,5 +31,6 @@ "matches_sha256", "OasisCenteredFFTPhysics", "OasisSliceDataset", + "FastMRIProstateDataset", "save_kspace_plot", ] diff --git a/mri_recon/utils/grappa.py b/mri_recon/utils/grappa.py new file mode 100644 index 0000000..a2cbe9c --- /dev/null +++ b/mri_recon/utils/grappa.py @@ -0,0 +1,215 @@ +from typing import Dict, Tuple +import numpy as np +from skimage.util import view_as_windows +from tempfile import NamedTemporaryFile as NTF + + +class Grappa: + def __init__( + self, kspace: np.ndarray, kernel_size: Tuple[int, int] = (5, 5), coil_axis: int = -1 + ) -> None: + self.kspace = kspace + self.kernel_size = kernel_size + self.coil_axis = coil_axis + self.lamda = 0.01 + + self.kernel_var_dict = self.get_kernel_geometries() + + def get_kernel_geometries(self): + """ + Extract unique kernel geometries based on a slice of kspace data + + Returns + ------- + geometries : dict + A dictionary containing the following keys: + - 'patches': an array of overlapping patches from the k-space data. + - 'patch_indices': an array of unique patch indices. + - 'holes_x': a dictionary of x-coordinates for holes in each patch. + - 'holes_y': a dictionary of y-coordinates for holes in each patch. + + Notes + ----- + This function extracts unique kernel geometries from a slice of k-space data. + The geometries correspond to overlapping patches that contain at least one hole. + A hole is defined as a region of k-space data where the absolute value of the + complex signal is equal to zero. The function returns a dictionary containing + information about the patches and holes, which can be used to compute weights + for each geometry using the GRAPPA algorithm. + + """ + self.kspace = np.moveaxis(self.kspace, self.coil_axis, -1) + + # Quit early if there are no holes + if np.sum((np.abs(self.kspace[..., 0]) == 0).flatten()) == 0: + return np.moveaxis(self.kspace, -1, self.coil_axis) + + kx, ky = self.kernel_size[:] + kx2, ky2 = int(kx / 2), int(ky / 2) + nc = self.kspace.shape[-1] + + self.kspace = np.pad(self.kspace, ((kx2, kx2), (ky2, ky2), (0, 0)), mode="constant") + + mask = np.ascontiguousarray(np.abs(self.kspace[..., 0]) > 0) + + with NTF() as fP: + # Get all overlapping patches from the mask + P = np.memmap( + fP, + dtype=mask.dtype, + mode="w+", + shape=(mask.shape[0] - 2 * kx2, mask.shape[1] - 2 * ky2, 1, kx, ky), + ) + P = view_as_windows(mask, (kx, ky)) + Psh = P.shape[:] # save shape for unflattening indices later + P = P.reshape((-1, kx, ky)) + + # Find the unique patches and associate them with indices + P, iidx = np.unique(P, return_inverse=True, axis=0) + + # Filter out geometries that don't have a hole at the center. + # These are all the kernel geometries we actually need to + # compute weights for. + validP = np.argwhere(~P[:, kx2, ky2]).squeeze() + + # ignore empty patches + invalidP = np.argwhere(np.all(P == 0, axis=(1, 2))) + validP = np.setdiff1d(validP, invalidP, assume_unique=True) + + validP = np.atleast_1d(validP) + + # Give P back its coil dimension + P = np.tile(P[..., None], (1, 1, 1, nc)) + + holes_x = {} + holes_y = {} + for ii in validP: + # x, y define where top left corner is, so move to ctr, + # also make sure they are iterable by enforcing atleast_1d + idx = np.unravel_index(np.argwhere(iidx == ii), Psh[:2]) + x, y = idx[0] + kx2, idx[1] + ky2 + x = np.atleast_1d(x.squeeze()) + y = np.atleast_1d(y.squeeze()) + + holes_x[ii] = x + holes_y[ii] = y + + return {"patches": P, "patch_indices": validP, "holes_x": holes_x, "holes_y": holes_y} + + def compute_weights(self, calib: np.ndarray) -> Dict[int, np.ndarray]: + """ + Compute the GRAPPA weights for each slice in the input calibration data. + + Parameters: + ---------- + calib : numpy.ndarray + Calibration data with shape (Nx, Nc, Ny) where Nx, Ny are the size of the image in the x and y dimensions, + respectively, and Nc is the number of coils. + + Returns: + ------- + weights : dict + A dictionary of GRAPPA weights for each patch index. + + Notes: + ----- + The GRAPPA algorithm is used to estimate the missing k-space data in undersampled MRI acquisitions. + The algorithm used to compute the GRAPPA weights involves first extracting patches from the calibration data, + and then solving a linear system to estimate the weights. The resulting weights are stored in a dictionary + where the key is the patch index. The equation to solve for the weights involves taking the product of the + sources and the targets in the patch domain, and then regularizing the matrix using Tikhonov regularization. + The function uses numpy's `memmap` to store temporary files to avoid overwhelming memory usage. + """ + + calib = np.moveaxis(calib, self.coil_axis, -1) + kx, ky = self.kernel_size[:] + kx2, ky2 = int(kx / 2), int(ky / 2) + nc = calib.shape[-1] + + calib = np.pad(calib, ((kx2, kx2), (ky2, ky2), (0, 0)), mode="constant") + + # Store windows in temporary files so we don't overwhelm memory + with NTF() as fA: + # Get all overlapping patches of ACS + try: + A = np.memmap( + fA, + dtype=calib.dtype, + mode="w+", + shape=(calib.shape[0] - 2 * kx, calib.shape[1] - 2 * ky, 1, kx, ky, nc), + ) + A[:] = view_as_windows(calib, (kx, ky, nc)).reshape((-1, kx, ky, nc)) + except ValueError: + A = view_as_windows(calib, (kx, ky, nc)).reshape((-1, kx, ky, nc)) + + weights = {} + + for ii in self.kernel_var_dict["patch_indices"]: + # Get the sources by masking all patches of the ACS and + # get targets by taking the center of each patch. Source + # and targets will have the following sizes: + # S : (# samples, N possible patches in ACS) + # T : (# coils, N possible patches in ACS) + # Solve the equation for the weights: using numpy.linalg.solve, + # and Tikhonov regularization for better conditioning: + # SW = T + # S^HSW = S^HT + # W = (S^HS)^-1 S^HT + # -> W = (S^HS + lamda I)^-1 S^HT + + S = A[:, self.kernel_var_dict["patches"][ii, ...]] + T = A[:, kx2, ky2, :] + ShS = S.conj().T @ S + ShT = S.conj().T @ T + lamda0 = self.lamda * np.linalg.norm(ShS) / ShS.shape[0] + weights[ii] = np.linalg.solve(ShS + lamda0 * np.eye(ShS.shape[0]), ShT).T + + return weights + + def apply_weights(self, kspace: np.ndarray, weights: Dict[int, np.ndarray]) -> np.ndarray: + """ + Applies the computed GRAPPA weights to the k-space data. + + Parameters: + ---------- + kspace : numpy.ndarray + The k-space data to apply the weights to. + + weights : dict + A dictionary containing the GRAPPA weights to apply. + + Returns: + ------- + numpy.ndarray: The reconstructed data after applying the weights. + """ + + # fin_shape = kspace.shape[:] + + # Put the coil dimension at the end + kspace = np.moveaxis(kspace, self.coil_axis, -1) + + # Get shape of kernel + kx, ky = self.kernel_size[:] + kx2, ky2 = int(kx / 2), int(ky / 2) + + # adjustment factor for odd kernel size + adjx = np.mod(kx, 2) + adjy = np.mod(ky, 2) + + # Pad kspace data + kspace = np.pad(kspace, ((kx2, kx2), (ky2, ky2), (0, 0)), mode="constant") + + with NTF() as frecon: + # Initialize recon array + recon = np.memmap(frecon, dtype=kspace.dtype, mode="w+", shape=kspace.shape) + + for ii in self.kernel_var_dict["patch_indices"]: + for xx, yy in zip( + self.kernel_var_dict["holes_x"][ii], self.kernel_var_dict["holes_y"][ii] + ): + # Collect sources for this hole and apply weights + S = kspace[xx - kx2 : xx + kx2 + adjx, yy - ky2 : yy + ky2 + adjy, :] + S = S[self.kernel_var_dict["patches"][ii, ...]] + recon[xx, yy, :] = (weights[ii] @ S[:, None]).squeeze() + + return np.moveaxis((recon[:] + kspace)[kx2:-kx2, ky2:-ky2, :], -1, self.coil_axis) diff --git a/mri_recon/utils/oasis_adapter.py b/mri_recon/utils/oasis_adapter.py index 0f833aa..88ba5de 100644 --- a/mri_recon/utils/oasis_adapter.py +++ b/mri_recon/utils/oasis_adapter.py @@ -7,8 +7,9 @@ import numpy as np import torch from torch.utils.data import Dataset +import deepinv as dinv -from mri_recon.distortions import BaseDistortion, DistortedKspaceMultiCoilMRI +from mri_recon.distortions import BaseDistortion class OasisSliceDataset(Dataset): @@ -122,6 +123,76 @@ def _get_volume(self, subject_id: str) -> np.ndarray: return volume +class OasisCenterSliceFolderDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes. + Select a center slice from all subjects in the folder. + + Parameters + ---------- + data_path : Path + Root directory containing OASIS subject folders. + + """ + + def __init__( + self, + data_path: Path, + ) -> None: + try: + import nibabel as nib + except ImportError as exc: + raise ImportError( + "OASIS loading requires nibabel. Install the project dependencies " + "or add nibabel to your environment before using OasisSliceDataset." + ) from exc + + self._nib = nib + self.data_path = Path(data_path) + self.subject_paths = self._discover_subject_paths() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.subject_paths) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + volume = self._get_volume(list(self.subject_paths.values())[0]) + n_slices, _, _ = volume.shape + slice_num = n_slices // 2 + subject_id = list(self.subject_paths.keys())[0] + target_np = np.ascontiguousarray(volume[slice_num], dtype=np.float32) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _discover_subject_paths(self) -> dict[str, Path]: + subject_paths = {} + for subject_dir in sorted(self.data_path.iterdir()): + if not subject_dir.is_dir(): + continue + image_glob = subject_dir / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if matches: + subject_paths[subject_dir.name] = matches[0] + + if not subject_paths: + raise FileNotFoundError( + "Could not find OASIS subject folders under " + f"{self.data_path} matching PROCESSED/MPRAGE/T88_111/*t88_gfc.img." + ) + return subject_paths + + def _get_volume(self, subject_path: str) -> np.ndarray: + image_data = self._nib.load(subject_path).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + return volume + + def image_to_kspace(x: torch.Tensor) -> torch.Tensor: """Convert channel-first complex images to centered k-space. @@ -169,6 +240,8 @@ def kspace_to_image(y: torch.Tensor) -> torch.Tensor: def fastmri_measurement_to_image( y: torch.Tensor, + coil_maps: torch.Tensor | None = None, + rss: bool = False, device: torch.device | str | None = None, ) -> torch.Tensor: """Convert FastMRI measurements to image space using the repo's native physics. @@ -177,6 +250,11 @@ def fastmri_measurement_to_image( ---------- y : torch.Tensor FastMRI measurement tensor with shape ``(B, 2, H, W)``. + coil_maps : torch.Tensor | None, optional + Coil sensitivity maps with shape ``(B, C, H, W)``, where ``C`` is the number of coils. + rss : bool, optional + If ``True``, return root-sum-of-squares image across coils. Otherwise, + return coil-combined image using the provided coil sensitivity maps. Defaults to ``False``. device : torch.device | str, optional Device on which to instantiate the temporary native physics operator. @@ -188,16 +266,17 @@ def fastmri_measurement_to_image( if device is None: device = y.device - physics = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), + physics = dinv.physics.MultiCoilMRI( img_size=(1, 2, *y.shape[-2:]), + coil_maps=coil_maps, device=device, ) - return physics.A_adjoint(y) + return physics.A_adjoint(y, rss=rss) def fastmri_measurement_to_oasis_kspace( y: torch.Tensor, + coil_maps: torch.Tensor | None = None, device: torch.device | str | None = None, ) -> torch.Tensor: """Adapt FastMRI measurements to the centered OASIS k-space convention. @@ -206,6 +285,8 @@ def fastmri_measurement_to_oasis_kspace( ---------- y : torch.Tensor FastMRI measurement tensor with shape ``(B, 2, H, W)``. + coil_maps : torch.Tensor | None, optional + Coil sensitivity maps with shape ``(B, C, H, W)``, where ``C`` is the number of coils. device : torch.device | str, optional Device on which to instantiate the temporary native physics operator. @@ -215,10 +296,59 @@ def fastmri_measurement_to_oasis_kspace( Centered OASIS-convention k-space tensor with shape ``(B, 2, H, W)``. """ - return image_to_kspace(fastmri_measurement_to_image(y, device=device)) + return image_to_kspace(fastmri_measurement_to_image(y, coil_maps=coil_maps, device=device)) + +def image_to_fastmri_measurement( + x: torch.Tensor, + device: torch.device | str | None = None, +) -> torch.Tensor: + """Perform FFT from image space to k-space (fast-MRI convention). + + Parameters + ---------- + x : torch.Tensor + image tensor with shape ``(B, 2, H, W)``. + device : torch.device | str, optional + Device on which to instantiate the temporary native physics operator. + + Returns + ------- + torch.Tensor + FastMRI-convention k-space tensor with shape ``(B, 2, H, W)``. + """ -class OasisCenteredFFTPhysics: + if device is None: + device = x.device + physics = dinv.physics.MultiCoilMRI( + img_size=(1, 2, *x.shape[-2:]), + coil_maps=None, + device=device, + ) + return physics.A(x) + + +def oasis_kspace_to_fastmri_measurement( + y: torch.Tensor, +) -> torch.Tensor: + """Adapt OASIS-convention k-space to FastMRI measurement convention. + + Parameters + ---------- + y : torch.Tensor + Centered OASIS-convention k-space tensor with shape ``(B, 2, H, W)``. + + + Returns + ------- + torch.Tensor + FastMRI-convention k-space tensor with shape ``(B, 2, H, W)``. + """ + + return image_to_fastmri_measurement(kspace_to_image(y)) + + +class OasisCenteredFFTPhysics(dinv.utils.mixins.MRIMixin, dinv.physics.LinearPhysics): """Physics adapter matching the OASIS U-Net FFT convention. Parameters @@ -227,7 +357,8 @@ class OasisCenteredFFTPhysics: K-space distortion applied after the centered FFT. """ - def __init__(self, distortion: BaseDistortion) -> None: + def __init__(self, distortion: BaseDistortion, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.distortion = distortion def A(self, x: torch.Tensor) -> torch.Tensor: @@ -261,3 +392,14 @@ def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: """ return kspace_to_image(self.distortion.A_adjoint(y)) + + # def A_dagger(self, y: torch.Tensor, **kwargs) -> torch.Tensor: + # r""" + # Computes least squares solution to the MRI inverse problem, as proposed in `SENSE: Sensitivity encoding for fast MRI `_. + + # By default uses conjugate gradient solver. Overwrite default solver arguments by passing `kwargs`. See :func:`deepinv.optim.linear.least_squares` for details. + + # :param dict kwargs: kwargs to pass to base :meth:`deepinv.physics.LinearPhysics.A_dagger`. + # :returns: (:class:`torch.Tensor`) image with shape `(B,2,...,H,W)` + # """ + # return super().A_dagger(y, **kwargs) diff --git a/mri_recon/utils/plot.py b/mri_recon/utils/plot.py index 068c222..35921b7 100644 --- a/mri_recon/utils/plot.py +++ b/mri_recon/utils/plot.py @@ -6,16 +6,26 @@ import matplotlib.pyplot as plt import torch +import numpy as np +import deepinv as dinv def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: """Convert k-space tensor to a log-magnitude image for visualization.""" - if kspace.ndim == 4: + if kspace.ndim == 5: + # show only middle coil for visualization + # (1, 2, C, H, W) -> (2, H, W) + kspace = kspace[0, :, kspace.shape[2] // 2] + elif kspace.ndim == 4: + # (1, 2, H, W) -> (2, H, W) kspace = kspace[0] - if kspace.ndim != 3 or kspace.shape[0] != 2: + elif kspace.ndim == 3: + pass + # (2, H, W) -> (2, H, W) + else: raise ValueError( - f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" + f"Expected k-space with shape (2, H, W) or (1, 2, H, W) or (1, 2, C, H, W),got {tuple(kspace.shape)}" ) kspace = kspace.detach().cpu() @@ -43,11 +53,18 @@ def save_kspace_plot( ) -> None: """Save side-by-side log-magnitude visualizations of clean and distorted k-space.""" + print("transforming k-space to log-magnitude images for visualization...") + print(f"\tclean k-space shape: {clean_kspace.shape}") + print(f"\tdistorted k-space shape: {distorted_kspace.shape}") + images = [ ("Original k-space", _kspace_to_log_magnitude(clean_kspace)), ("Distorted k-space", _kspace_to_log_magnitude(distorted_kspace)), ] + print(f"clean k-space magnitude shape: {images[0][1].shape}") + print(f"distorted k-space magnitude shape: {images[1][1].shape}") + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) fig.suptitle(f"Distortion: {distortion_label}") for ax, (title, image) in zip(axes, images, strict=True): @@ -56,3 +73,20 @@ def save_kspace_plot( ax.axis("off") fig.savefig(save_fn, dpi=200, bbox_inches="tight") plt.close(fig) + + +def convert_image_for_save(im: torch.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor image complex tensor to a real-valued NumPy array + by calculating the magnitude. + (B, 2, H, W) or (B, H, W) with complex type -> (B, H, W) + + Args: + im (torch.Tensor): The input image tensor. + + Returns: + np.ndarray: The converted image array. + """ + if torch.is_complex(im) or im.shape[1] == 2: + im = dinv.utils.signals.complex_abs(im, dim=1, keepdim=False) + return im.detach().cpu().numpy() diff --git a/mri_recon/utils/prostate_adaptor.py b/mri_recon/utils/prostate_adaptor.py new file mode 100644 index 0000000..c7b1d6a --- /dev/null +++ b/mri_recon/utils/prostate_adaptor.py @@ -0,0 +1,49 @@ +import os +import glob +import h5py +import numpy as np +import torch + + +class FastMRIProstateDataset(torch.utils.data.Dataset): + def __init__( + self, data_path: str, num_samples: None | int = None, slice_index: str = "middle" + ) -> None: + self.data_path = data_path + self.num_samples = num_samples + self.slice_index = slice_index + self.image_data = self.get_image_data() + + if num_samples is not None: + self.image_data = self.image_data[:num_samples] + else: + self.num_samples = len(self.image_data) + + def get_image_data(self) -> np.ndarray: + image_result_list = [] + for sample_idx, filename in enumerate(glob.glob(os.path.join(self.data_path, "*.h5"))): + if (self.num_samples is not None) and (sample_idx >= self.num_samples): + break + try: + with h5py.File(filename, "r") as hf: + image_recon = hf["reconstruction_rss"][:] + + except Exception as e: + print(f"Error processing file {filename}: {e}") + continue + + if self.slice_index == "middle": + image_result_list.append(image_recon[image_recon.shape[0] // 2]) + else: + image_result_list.extend( + [image_recon[i, :, :] for i in range(image_recon.shape[0])] + ) + + return image_result_list + + def __len__(self) -> int: + return len(self.image_data) + + def __getitem__(self, idx: int) -> torch.Tensor: + # add batch dimension and convert to torch.Tensor + return torch.from_numpy(self.image_data[idx]).unsqueeze(0) diff --git a/mri_recon/utils/recon_prostate_T2.py b/mri_recon/utils/recon_prostate_T2.py new file mode 100644 index 0000000..f90ace7 --- /dev/null +++ b/mri_recon/utils/recon_prostate_T2.py @@ -0,0 +1,700 @@ +import os +from tempfile import NamedTemporaryFile as NTF +from typing import Dict, Tuple, Optional, Sequence +import xml.etree.ElementTree as etree + +import h5py +import numpy as np +from numpy.fft import fftshift, ifftshift, ifftn +from tifffile import imwrite +from skimage.util import view_as_windows + + +def center_crop_im(im_3d: np.ndarray, crop_to_size: Tuple[int, int]) -> np.ndarray: + """ + Center crop an image to a given size. + + Parameters: + ----------- + im_3d : numpy.ndarray + Input image of shape (slices, x, y). + crop_to_size : tuple + Tuple containing the target size for x and y dimensions. + + Returns: + -------- + numpy.ndarray + Center cropped image of size {slices, x_cropped, y_cropped}. + """ + x_crop = im_3d.shape[-1] / 2 - crop_to_size[0] / 2 + y_crop = im_3d.shape[-2] / 2 - crop_to_size[1] / 2 + + return im_3d[ + :, int(y_crop) : int(crop_to_size[1] + y_crop), int(x_crop) : int(crop_to_size[0] + x_crop) + ] + + +def ifftnd(kspace: np.ndarray, axes: Optional[Sequence[int]] = [-1]) -> np.ndarray: + """ + Compute the n-dimensional inverse Fourier transform of the k-space data along the specified axes. + + Parameters: + ----------- + kspace: np.ndarray + The input k-space data. + axes: list or tuple, optional + The list of axes along which to compute the inverse Fourier transform. Default is [-1]. + + Returns: + -------- + img: ndarray + The output image after inverse Fourier transform. + """ + + if axes is None: + axes = range(kspace.ndim) + img = fftshift(ifftn(ifftshift(kspace, axes=axes), axes=axes), axes=axes) + img *= np.sqrt(np.prod(np.take(img.shape, axes))) + + return img + + +def create_coil_combined_im(multicoil_multislice_kspace: np.ndarray) -> np.ndarray: + """ + Create a coil combined image from a multicoil-multislice k-space array. + + Parameters: + ----------- + multicoil_multislice_kspace : array-like + Input k-space data with shape (slices, coils, readout, phase encode). + + Returns: + -------- + image_mat : array-like + Coil combined image data with shape (slices, x, y). + """ + + k = multicoil_multislice_kspace + image_mat = np.zeros((k.shape[0], k.shape[2], k.shape[3])) + for i in range(image_mat.shape[0]): + data_sl = k[i, :, :, :] + image = ifftnd(data_sl, [1, 2]) + image_rss = rss(image, axis=0) + image_mat[i, :, :] = np.flipud(image_rss) + if i == 15: + print("data_sl.shape:", data_sl.shape) + print("image.shape:", image.shape) + print("image flipped: ", np.flipud(image).shape) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/image_slice_15_kspace.tiff", + np.abs(data_sl[0, :, :]), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/image_slice_15_ifft.tiff", + np.abs(image), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/image_slice_15_rss.tiff", + image_rss, + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/image_slice_15.tiff", + image_mat[i, :, :], + ) + return image_mat + + +def rss(sig: np.ndarray, axis: int = -1) -> np.ndarray: + """ + Compute the Root Sum-of-Squares (RSS) value of a complex signal along a specified axis. + + Parameters + ---------- + sig : np.ndarray + The complex signal to compute the RMS value of. + axis : int, optional + The axis along which to compute the RMS value. Default is -1. + + Returns + ------- + rss : np.ndarray + The RSS value of the complex signal along the specified axis. + """ + return np.sqrt(np.sum(abs(sig) ** 2, axis)) + + +def et_query( + root: etree.Element, qlist: Sequence[str], namespace: str = "http://www.ismrm.org/ISMRMRD" +) -> str: + """ + ElementTree query function. + + This function queries an XML document using ElementTree. + + Parameters: + ----------- + root : Element + Root of the XML document to search through. + qlist : Sequence of str + A sequence of strings for nested searches, e.g., ["Encoding", "matrixSize"]. + namespace : str, optional + XML namespace to prepend query. + + Returns: + -------- + str + The retrieved data as a string. + """ + s = "." + prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + + +def get_padding(hdr: str) -> float: + """ + Extract the padding value from an XML header string. + + Parameters: + ----------- + hdr : str + The XML header string. + + Returns: + -------- + float + The padding value calculated as (x - max_enc)/2, where x is the readout dimension and + max_enc is the maximum phase-encoding dimension. + """ + et_root = etree.fromstring(hdr) + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_x = int(et_query(et_root, enc + ["x"])) + padding = (enc_x - enc_limits_max) / 2 + + return padding + + +def zero_pad_kspace_hdr(hdr: str, unpadded_kspace: np.ndarray) -> np.ndarray: + """ + Perform zero-padding on k-space data to have the same number of + points in the x- and y-directions. + + Parameters + ---------- + hdr : str + The XML header string. + unpadded_kspace : array-like of shape (sl, ro , coils, pe) + The k-space data to be padded. + + Returns + ------- + padded_kspace : ndarray of shape (sl, ro_padded, coils, pe_padded) + The zero-padded k-space data, where ro_padded and pe_padded are + the dimensions of the readout and phase-encoding directions after + padding. + + Notes + ----- + The padding value is calculated using the `get_padding` function, which + extracts the padding value from the XML header string. If the difference + between the readout dimension and the maximum phase-encoding dimension + is not divisible by 2, the padding is applied asymmetrically, with one + side having an additional zero-padding. + + """ + padding = get_padding(hdr) + if padding % 2 != 0: + padding_left = int(np.floor(padding)) + padding_right = int(np.ceil(padding)) + else: + padding_left = int(padding) + padding_right = int(padding) + padded_kspace = np.pad(unpadded_kspace, ((0, 0), (0, 0), (0, 0), (padding_left, padding_right))) + + return padded_kspace + + +def zero_pad_kspace_slice_hdr(hdr: str, unpadded_kspace: np.ndarray) -> np.ndarray: + """ + Perform zero-padding on k-space data to have the same number of + points in the x- and y-directions. + + Parameters + ---------- + hdr : str + The XML header string. + unpadded_kspace : array-like of shape (ro , coils, pe) + The k-space data to be padded. + + Returns + ------- + padded_kspace : ndarray of shape (ro_padded, coils, pe_padded) + The zero-padded k-space data, where ro_padded and pe_padded are + the dimensions of the readout and phase-encoding directions after + padding. + + Notes + ----- + The padding value is calculated using the `get_padding` function, which + extracts the padding value from the XML header string. If the difference + between the readout dimension and the maximum phase-encoding dimension + is not divisible by 2, the padding is applied asymmetrically, with one + side having an additional zero-padding. + + """ + padding = get_padding(hdr) + if padding % 2 != 0: + padding_left = int(np.floor(padding)) + padding_right = int(np.ceil(padding)) + else: + padding_left = int(padding) + padding_right = int(padding) + padded_kspace = np.pad(unpadded_kspace, ((0, 0), (0, 0), (padding_left, padding_right))) + + return padded_kspace + + +class Grappa: + def __init__( + self, kspace: np.ndarray, kernel_size: Tuple[int, int] = (5, 5), coil_axis: int = -1 + ) -> None: + self.kspace = kspace + self.kernel_size = kernel_size + self.coil_axis = coil_axis + self.lamda = 0.01 + + self.kernel_var_dict = self.get_kernel_geometries() + + def get_kernel_geometries(self): + """ + Extract unique kernel geometries based on a slice of kspace data + + Returns + ------- + geometries : dict + A dictionary containing the following keys: + - 'patches': an array of overlapping patches from the k-space data. + - 'patch_indices': an array of unique patch indices. + - 'holes_x': a dictionary of x-coordinates for holes in each patch. + - 'holes_y': a dictionary of y-coordinates for holes in each patch. + + Notes + ----- + This function extracts unique kernel geometries from a slice of k-space data. + The geometries correspond to overlapping patches that contain at least one hole. + A hole is defined as a region of k-space data where the absolute value of the + complex signal is equal to zero. The function returns a dictionary containing + information about the patches and holes, which can be used to compute weights + for each geometry using the GRAPPA algorithm. + + """ + self.kspace = np.moveaxis(self.kspace, self.coil_axis, -1) + + # Quit early if there are no holes + if np.sum((np.abs(self.kspace[..., 0]) == 0).flatten()) == 0: + return np.moveaxis(self.kspace, -1, self.coil_axis) + + kx, ky = self.kernel_size[:] + kx2, ky2 = int(kx / 2), int(ky / 2) + nc = self.kspace.shape[-1] + + self.kspace = np.pad(self.kspace, ((kx2, kx2), (ky2, ky2), (0, 0)), mode="constant") + + mask = np.ascontiguousarray(np.abs(self.kspace[..., 0]) > 0) + + with NTF() as fP: + # Get all overlapping patches from the mask + P = np.memmap( + fP, + dtype=mask.dtype, + mode="w+", + shape=(mask.shape[0] - 2 * kx2, mask.shape[1] - 2 * ky2, 1, kx, ky), + ) + P = view_as_windows(mask, (kx, ky)) + Psh = P.shape[:] # save shape for unflattening indices later + P = P.reshape((-1, kx, ky)) + + # Find the unique patches and associate them with indices + P, iidx = np.unique(P, return_inverse=True, axis=0) + + # Filter out geometries that don't have a hole at the center. + # These are all the kernel geometries we actually need to + # compute weights for. + validP = np.argwhere(~P[:, kx2, ky2]).squeeze() + + # ignore empty patches + invalidP = np.argwhere(np.all(P == 0, axis=(1, 2))) + validP = np.setdiff1d(validP, invalidP, assume_unique=True) + + validP = np.atleast_1d(validP) + + # Give P back its coil dimension + P = np.tile(P[..., None], (1, 1, 1, nc)) + + holes_x = {} + holes_y = {} + for ii in validP: + # x, y define where top left corner is, so move to ctr, + # also make sure they are iterable by enforcing atleast_1d + idx = np.unravel_index(np.argwhere(iidx == ii), Psh[:2]) + x, y = idx[0] + kx2, idx[1] + ky2 + x = np.atleast_1d(x.squeeze()) + y = np.atleast_1d(y.squeeze()) + + holes_x[ii] = x + holes_y[ii] = y + + return {"patches": P, "patch_indices": validP, "holes_x": holes_x, "holes_y": holes_y} + + def compute_weights(self, calib: np.ndarray) -> Dict[int, np.ndarray]: + """ + Compute the GRAPPA weights for each slice in the input calibration data. + + Parameters: + ---------- + calib : numpy.ndarray + Calibration data with shape (Nx, Nc, Ny) where Nx, Ny are the size of the image in the x and y dimensions, + respectively, and Nc is the number of coils. + + Returns: + ------- + weights : dict + A dictionary of GRAPPA weights for each patch index. + + Notes: + ----- + The GRAPPA algorithm is used to estimate the missing k-space data in undersampled MRI acquisitions. + The algorithm used to compute the GRAPPA weights involves first extracting patches from the calibration data, + and then solving a linear system to estimate the weights. The resulting weights are stored in a dictionary + where the key is the patch index. The equation to solve for the weights involves taking the product of the + sources and the targets in the patch domain, and then regularizing the matrix using Tikhonov regularization. + The function uses numpy's `memmap` to store temporary files to avoid overwhelming memory usage. + """ + + calib = np.moveaxis(calib, self.coil_axis, -1) + kx, ky = self.kernel_size[:] + kx2, ky2 = int(kx / 2), int(ky / 2) + nc = calib.shape[-1] + + calib = np.pad(calib, ((kx2, kx2), (ky2, ky2), (0, 0)), mode="constant") + + # Store windows in temporary files so we don't overwhelm memory + with NTF() as fA: + # Get all overlapping patches of ACS + try: + A = np.memmap( + fA, + dtype=calib.dtype, + mode="w+", + shape=(calib.shape[0] - 2 * kx, calib.shape[1] - 2 * ky, 1, kx, ky, nc), + ) + A[:] = view_as_windows(calib, (kx, ky, nc)).reshape((-1, kx, ky, nc)) + except ValueError: + A = view_as_windows(calib, (kx, ky, nc)).reshape((-1, kx, ky, nc)) + + weights = {} + + for ii in self.kernel_var_dict["patch_indices"]: + # Get the sources by masking all patches of the ACS and + # get targets by taking the center of each patch. Source + # and targets will have the following sizes: + # S : (# samples, N possible patches in ACS) + # T : (# coils, N possible patches in ACS) + # Solve the equation for the weights: using numpy.linalg.solve, + # and Tikhonov regularization for better conditioning: + # SW = T + # S^HSW = S^HT + # W = (S^HS)^-1 S^HT + # -> W = (S^HS + lamda I)^-1 S^HT + + S = A[:, self.kernel_var_dict["patches"][ii, ...]] + T = A[:, kx2, ky2, :] + ShS = S.conj().T @ S + ShT = S.conj().T @ T + lamda0 = self.lamda * np.linalg.norm(ShS) / ShS.shape[0] + weights[ii] = np.linalg.solve(ShS + lamda0 * np.eye(ShS.shape[0]), ShT).T + + return weights + + def apply_weights(self, kspace: np.ndarray, weights: Dict[int, np.ndarray]) -> np.ndarray: + """ + Applies the computed GRAPPA weights to the k-space data. + + Parameters: + ---------- + kspace : numpy.ndarray + The k-space data to apply the weights to. + + weights : dict + A dictionary containing the GRAPPA weights to apply. + + Returns: + ------- + numpy.ndarray: The reconstructed data after applying the weights. + """ + + # fin_shape = kspace.shape[:] + + # Put the coil dimension at the end + kspace = np.moveaxis(kspace, self.coil_axis, -1) + + # Get shape of kernel + kx, ky = self.kernel_size[:] + kx2, ky2 = int(kx / 2), int(ky / 2) + + # adjustment factor for odd kernel size + adjx = np.mod(kx, 2) + adjy = np.mod(ky, 2) + + # Pad kspace data + kspace = np.pad(kspace, ((kx2, kx2), (ky2, ky2), (0, 0)), mode="constant") + + with NTF() as frecon: + # Initialize recon array + recon = np.memmap(frecon, dtype=kspace.dtype, mode="w+", shape=kspace.shape) + map_of_holes = np.zeros(shape=kspace.shape[:2], dtype=bool) + for patch_index, ii in enumerate(self.kernel_var_dict["patch_indices"]): + imwrite( + f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/patch_{ii}.tiff", + self.kernel_var_dict["patches"][ii, ...], + ) + map_of_holes = np.zeros(shape=kspace.shape, dtype=bool) + for hole_idx, (xx, yy) in enumerate( + zip(self.kernel_var_dict["holes_x"][ii], self.kernel_var_dict["holes_y"][ii]) + ): + # Collect sources for this hole and apply weights + + map_of_holes[xx - kx2 : xx + kx2 + adjx, yy - ky2 : yy + ky2 + adjy, :] = ( + hole_idx + ) + S = kspace[xx - kx2 : xx + kx2 + adjx, yy - ky2 : yy + ky2 + adjy, :] + # if patch_index < 10: + # print(f"Kernel-patch {ii} from k-space: {S.shape}") + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/kernel_kspace_{ii}_hole{xx}_hole{yy}.tiff", S) + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/abs_kernel_kspace_{ii}_hole{xx}_hole{yy}.tiff", np.abs(S)) + S = S[self.kernel_var_dict["patches"][ii, ...]] + # if patch_index >10: + # print(f"Sources for hole x/y in kspace for patch {ii}: {S.shape}") + # print(f"Weights for hole x/y in kspace for patch {ii}: {weights[ii].shape}") + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/patch_mask_{ii}.tiff", self.kernel_var_dict['patches'][ii, ...]) + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/kernel_patch_kspace_{ii}_hole{xx}_hole{yy}.tiff", S) + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/abs_kernel_patch_kspace_{ii}_hole{xx}_hole{yy}.tiff", np.abs(S)) + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/weights_{ii}_hole{xx}_hole{yy}.tiff", weights[ii]) + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/abs_weights_{ii}_hole{xx}_hole{yy}.tiff", np.abs(weights[ii])) + + recon[xx, yy, :] = (weights[ii] @ S[:, None]).squeeze() + # if patch_index > 10: + # print(f"Reconstructed hole x/y in kspace{ii}: {recon[xx, yy, :].shape}") + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/recon_{ii}_hole{xx}_hole{yy}.tiff", np.moveaxis(recon[xx, yy, :], -1, 0)) + # imwrite(f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/abs_recon_{ii}_hole{xx}_hole{yy}.tiff", np.moveaxis(np.abs(recon[xx, yy, :]), -1,0)) + + imwrite( + f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/map_of_holes_patch_{ii}.tiff", + map_of_holes, + ) + print(f"recon shape before adding to kspace: {recon.shape}") + print(f"(padded) kspace shape before adding recon: {kspace.shape}") + return np.moveaxis((recon[:] + kspace)[kx2:-kx2, ky2:-ky2, :], -1, self.coil_axis) + + +def _kspace_to_log_magnitude(kspace: np.ndarray) -> np.ndarray: + """Convert k-space tensor to a log-magnitude image for visualization.""" + + magnitude = np.log1p(np.abs(kspace)) + + lower = np.quantile(magnitude, 0.05) + upper = np.quantile(magnitude, 0.995) + if float(upper) > float(lower): + magnitude = np.clip(magnitude, lower, upper) + magnitude = (magnitude - lower) / (upper - lower) + else: + mag_max = float(magnitude.max()) + if mag_max > 0.0: + magnitude = magnitude / mag_max + + return np.sqrt(magnitude) + + +if __name__ == "__main__": + filename = "/home/melanie.dohmen/mri_recon/data/fastmri/fastMRI_prostate_T2_IDS_001_020/file_prostate_AXT2_001.h5" + + os.makedirs("/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/", exist_ok=True) + os.makedirs( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/grappa/", exist_ok=True + ) + with h5py.File(filename, "r") as hf: + kspace_data = hf["kspace"][:] + calibration_data = hf["calibration_data"][:] + hdr = hf["ismrmrd_header"][()] + im_recon = hf["reconstruction_rss"][:] + atts = dict() + atts["max"] = hf.attrs["max"] + atts["norm"] = hf.attrs["norm"] + atts["patient_id"] = hf.attrs["patient_id"] + atts["acquisition"] = hf.attrs["acquisition"] + + # (A, S, C, RO, PE) + num_avg, num_slices, num_coils, num_ro, num_pe = kspace_data.shape + + # Calib_data shape: num_slices, num_coils, num_pe_cal + grappa_weight_dict = {} + grappa_weight_dict_2 = {} + + # (A, S, C, RO, PE) -> take first average and slice to get (C, RO, PE) for GRAPPA weight calculation + kspace_slice_regridded = kspace_data[0, 0, ...] + + print("kspace_slice_regridded shape: (A, S, C, RO, PE)") + print(kspace_slice_regridded.shape) + + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_av0_slice0.tiff", + _kspace_to_log_magnitude(kspace_data[0, 0, :, :, :]), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_av1_slice0.tiff", + _kspace_to_log_magnitude(kspace_data[1, 0, :, :, :]), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_av2_slice0.tiff", + _kspace_to_log_magnitude(kspace_data[2, 0, :, :, :]), + ) + + print("kspace_slice_regridded transposed shape for GRAPPA: (PE, C, RO)") + print(np.transpose(kspace_slice_regridded, (2, 0, 1)).shape) + + grappa_obj = Grappa( + np.transpose(kspace_slice_regridded, (2, 0, 1)), kernel_size=(5, 5), coil_axis=1 + ) + + kspace_slice_regridded_2 = kspace_data[1, 0, ...] + grappa_obj_2 = Grappa( + np.transpose(kspace_slice_regridded_2, (2, 0, 1)), kernel_size=(5, 5), coil_axis=1 + ) + + # calculate GRAPPA weights + for slice_num in range(num_slices): + # (S, C, PE, cal) -> (C, PE, cal) + calibration_regridded = calibration_data[slice_num, ...] + # (C, PE, cal) -> (cal, C, PE) for GRAPPA weight calculation# + if slice_num == 0: + print(f"calibration_data shape (S, C, PE, cal): {calibration_data.shape}") + print(f"calibration_regridded shape (C, PE, cal): {calibration_regridded.shape}") + print( + f"calibration_regridded transposed shape for GRAPPA: (cal, C, PE)?: {np.transpose(calibration_regridded, (2, 0, 1)).shape}" + ) + grappa_weight_dict[slice_num] = grappa_obj.compute_weights( + np.transpose(calibration_regridded, (2, 0, 1)) + ) + grappa_weight_dict_2[slice_num] = grappa_obj_2.compute_weights( + np.transpose(calibration_regridded, (2, 0, 1)) + ) + + # apply GRAPPA weights + kspace_post_grappa_all = np.zeros(shape=kspace_data.shape, dtype=complex) + + for average, grappa_obj, grappa_weight_dict in zip( + [0, 1, 2], + [grappa_obj, grappa_obj_2, grappa_obj], + [grappa_weight_dict, grappa_weight_dict_2, grappa_weight_dict], + ): + for slice_num in range(num_slices): + # (A, S, C, RO, PE) -> (C, RO, PE) for GRAPPA application + kspace_slice_regridded = kspace_data[average, slice_num, ...] + + # apply weights to transposed k-space slice (PE, C, RO) + kspace_post_grappa = grappa_obj.apply_weights( + np.transpose(kspace_slice_regridded, (2, 0, 1)), grappa_weight_dict[slice_num] + ) + # and move axes back to (C, RO, PE) after GRAPPA application + kspace_post_grappa_all[average, slice_num, ...] = np.moveaxis( + np.moveaxis(kspace_post_grappa, 0, 1), 1, 2 + ) + if average == 0 and slice_num == 0: + print( + f"k-space transposed shape: {np.transpose(kspace_slice_regridded, (2, 0, 1)).shape}" + ) + print(f"k-space post GRAPPA shape: {kspace_post_grappa.shape}") + print( + f"k-space post GRAPPA moved axes shape: {np.moveaxis(np.moveaxis(kspace_post_grappa, 0, 1), 1, 2).shape}" + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_pre_grappa_av0_slice0.tiff", + _kspace_to_log_magnitude(kspace_slice_regridded), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_post_grappa_av0_slice0.tiff", + _kspace_to_log_magnitude(kspace_post_grappa_all[0, 0, :, :, :]), + ) + + # recon image for each average + im = np.zeros((num_avg, num_slices, num_ro, num_ro)) + for average in range(num_avg): + kspace_grappa = kspace_post_grappa_all[average, ...] + kspace_grappa_padded = zero_pad_kspace_hdr(hdr, kspace_grappa) + im[average] = create_coil_combined_im(kspace_grappa_padded) + imwrite( + f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/im_average_{average}.tiff", + im[average], + ) + + im_3d = np.mean(im, axis=0) + # center crop image to 320 x 320 + img_dict = {} + img_dict["reconstruction_rss"] = center_crop_im(im_3d, [320, 320]) + + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/reconstruction_rss.tiff", + img_dict["reconstruction_rss"], + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/given_reconstruction_rss.tiff", + im_recon, + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/calibratin_data.tiff", + calibration_data[15, 10, :, :], + ) + print("num_avg, num_slices, num_coils, num_ro, num_pe") + print("kspace_data.shape:", kspace_data.shape) + print("kspace_grappa.shape:", kspace_grappa.shape) + print("kspace_grappa_padded.shape:", kspace_grappa_padded.shape) + print("kspace_post_grappa_all.shape:", kspace_post_grappa_all.shape) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_slice_15_coil10.tiff", + _kspace_to_log_magnitude(kspace_data[0, 15, 10, :, :]), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_grappa_slice_15_coil10.tiff", + _kspace_to_log_magnitude(kspace_grappa[15, 10, :, :]), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_grappa_padded_slice_15_coil10.tiff", + _kspace_to_log_magnitude(kspace_grappa_padded[15, 10, :, :]), + ) + imwrite( + "/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/kspace_post_grappa_slice_15_coil10.tiff", + _kspace_to_log_magnitude(kspace_post_grappa_all[0, 15, 10, :, :]), + ) + + for average in range(num_avg): + print("average =", average, ": coil_combined_im(kspace_grappe_padded): ", im[average].shape) + imwrite( + f"/home/melanie.dohmen/mri_recon/reports/test_prostate_T2_recon/im_slice_15_average_{average}.tiff", + im[average][15, :, :], + ) + + print("im_3d.shape:", im_3d.shape) + print("num_slices, num_coils, num_pe_cal") + print("calibration_data.shape:", calibration_data.shape) + print("done") diff --git a/pyproject.toml b/pyproject.toml index 611d63c..c619e37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "deepinv>=0.3.8", "fastmri>=0.3.0", "h5py>=3.16.0", + "mat73>= 0.65", "matplotlib>=3.9.0", "nibabel>=5.3.2", "numpy>=2.4.3", diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index 3c95455..966ad49 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -17,7 +17,7 @@ OASIS_UNET_ALGORITHMS, choose_reconstructor, uses_oasis_centered_path, - validate_algorithm_dataset_compatibility, + compatible_dataset_with_reconstructor, ) from mri_recon.distortions import DistortedKspaceMultiCoilMRI @@ -245,21 +245,22 @@ def fake_resolve(_cls, acceleration, manifest_path=None): def test_validate_algorithm_dataset_compatibility_accepts_supported_explicit_unets(): - validate_algorithm_dataset_compatibility("fastmri", FASTMRI_UNET_ALGORITHM) - validate_algorithm_dataset_compatibility("fastmri", "unet-oasis-acceleration8") - validate_algorithm_dataset_compatibility("oasis", "unet-oasis-acceleration4") + assert compatible_dataset_with_reconstructor("fastmri_knee", FASTMRI_UNET_ALGORITHM) + assert compatible_dataset_with_reconstructor("fastmri_brain", "unet-oasis-acceleration8") + assert compatible_dataset_with_reconstructor("oasis", "unet-oasis-acceleration4") def test_validate_algorithm_dataset_compatibility_rejects_unsupported_oasis_fastmri_combo(): - with pytest.raises(ValueError, match="unet-fastmri"): - validate_algorithm_dataset_compatibility("oasis", FASTMRI_UNET_ALGORITHM) + assert not compatible_dataset_with_reconstructor("oasis", FASTMRI_UNET_ALGORITHM) + assert not compatible_dataset_with_reconstructor("fastmri_brain", FASTMRI_UNET_ALGORITHM) + assert not compatible_dataset_with_reconstructor("fastmri_prostate", FASTMRI_UNET_ALGORITHM) + assert not compatible_dataset_with_reconstructor("fastmri_knee", "unet-oasis-acceleration10") -def test_uses_oasis_centered_path_tracks_dataset_and_explicit_algorithm(): - assert uses_oasis_centered_path("oasis", FASTMRI_UNET_ALGORITHM) is True - assert uses_oasis_centered_path("fastmri", "unet-oasis-acceleration8") is True - assert uses_oasis_centered_path("fastmri", FASTMRI_UNET_ALGORITHM) is False - assert uses_oasis_centered_path("fastmri", "tv-pgd") is False +def test_uses_oasis_centered_algorithm(): + assert uses_oasis_centered_path(FASTMRI_UNET_ALGORITHM) is False + assert uses_oasis_centered_path("unet-oasis-acceleration4") is True + assert uses_oasis_centered_path("tv-pgd") is False def test_choose_reconstructor_selects_oasis_unet_for_fastmri_when_requested(monkeypatch): @@ -280,7 +281,7 @@ def fake_oasis(*, acceleration, device): reconstructor = choose_reconstructor( "unet-oasis-acceleration8", - dataset="fastmri", + dataset="oasis", device="cpu", ) @@ -302,7 +303,7 @@ def fake_fastmri(*, device): reconstructor = choose_reconstructor( FASTMRI_UNET_ALGORITHM, - dataset="fastmri", + dataset="fastmri_knee", device="cpu", ) @@ -327,7 +328,7 @@ def fake_oasis(*, acceleration, device): for algorithm_name in OASIS_UNET_ALGORITHMS: reconstructor = choose_reconstructor( algorithm_name, - dataset="fastmri", + dataset="fastmri_brain", device="cpu", ) assert isinstance(reconstructor, Marker) diff --git a/tests/test_utils_io.py b/tests/test_utils_io.py index 304446d..9192da1 100644 --- a/tests/test_utils_io.py +++ b/tests/test_utils_io.py @@ -2,12 +2,14 @@ from io import BytesIO import torch +import deepinv as dinv -from mri_recon.distortions import BaseDistortion, DistortedKspaceMultiCoilMRI from mri_recon.utils.oasis_adapter import ( fastmri_measurement_to_image, fastmri_measurement_to_oasis_kspace, kspace_to_image, + oasis_kspace_to_fastmri_measurement, + image_to_fastmri_measurement, ) from mri_recon.utils.io import download_file_with_sha256, download_google_drive_file_with_sha256 @@ -78,8 +80,7 @@ def fake_urlopen(url, timeout=30): def test_fastmri_measurement_helpers_match_centered_oasis_path(): x = torch.randn(1, 2, 16, 12) - physics = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), + physics = dinv.physics.MultiCoilMRI( img_size=(1, 2, *x.shape[-2:]), device="cpu", ) @@ -99,3 +100,19 @@ def test_fastmri_measurement_helpers_match_centered_oasis_path(): atol=1e-6, rtol=1e-6, ) + + y_fastmri_from_oasis = oasis_kspace_to_fastmri_measurement(y_oasis) + assert torch.allclose( + y_fastmri_from_oasis, + y_fastmri, + atol=1e-6, + rtol=1e-6, + ) + + y_fastmri_from_image = image_to_fastmri_measurement(x_native) + assert torch.allclose( + y_fastmri_from_image, + y_fastmri, + atol=1e-6, + rtol=1e-6, + ) diff --git a/uv.lock b/uv.lock index d25083a..30e3765 100644 --- a/uv.lock +++ b/uv.lock @@ -121,6 +121,73 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "artifactlab" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "certifi" }, + { name = "deepinv" }, + { name = "fastmri" }, + { name = "h5py" }, + { name = "mat73" }, + { name = "matplotlib" }, + { name = "nibabel" }, + { name = "numpy" }, + { name = "ptwt" }, + { name = "pydicom" }, + { name = "pytest" }, + { name = "python-certifi-win32" }, + { name = "sigpy" }, + { name = "torch", version = "2.11.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torchmetrics" }, + { name = "torchvision", version = "0.26.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torchvision", version = "0.26.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torchvision", version = "0.26.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "tqdm" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pre-commit" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "certifi", specifier = ">=2026.2.25" }, + { name = "deepinv", specifier = ">=0.3.8" }, + { name = "fastmri", specifier = ">=0.3.0" }, + { name = "h5py", specifier = ">=3.16.0" }, + { name = "mat73", specifier = ">=0.65" }, + { name = "matplotlib", specifier = ">=3.9.0" }, + { name = "nibabel", specifier = ">=5.3.2" }, + { name = "numpy", specifier = ">=2.4.3" }, + { name = "ptwt", specifier = ">=1.0.1" }, + { name = "pydicom", specifier = ">=3.0.1" }, + { name = "pytest", specifier = ">=9.0.2" }, + { name = "python-certifi-win32", specifier = ">=1.6.1" }, + { name = "sigpy", specifier = ">=0.1.27" }, + { name = "torch", marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.11.0" }, + { name = "torch", marker = "sys_platform == 'darwin'", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cpu" }, + { name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", marker = "sys_platform == 'win32'", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torchmetrics", specifier = ">=1.9.0" }, + { name = "torchvision", marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.26.0" }, + { name = "torchvision", marker = "sys_platform == 'darwin'", specifier = ">=0.26.0", index = "https://download.pytorch.org/whl/cpu" }, + { name = "torchvision", marker = "sys_platform == 'linux'", specifier = ">=0.26.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torchvision", marker = "sys_platform == 'win32'", specifier = ">=0.26.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "tqdm", specifier = ">=4.67.3" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pre-commit", specifier = ">=4.2.0" }, + { name = "ruff", specifier = ">=0.11.0" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -874,6 +941,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, ] +[[package]] +name = "mat73" +version = "0.65" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h5py" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/61/0e6375513085b13ad23ab150fc83d4d8faa91cba56eb2f30b646259f8214/mat73-0.65.tar.gz", hash = "sha256:ad38a06af3d483632bd939ee572b3724ea8c03d37916765d7278f9de95541ade", size = 19355, upload-time = "2024-07-24T09:12:33.007Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/24/e867b1b89b2a2102a5a3bb64ddcd49c5cb815244b2dadff7740c6a422e4f/mat73-0.65-py3-none-any.whl", hash = "sha256:aadfcd00f328eb8f75dd1d4a060a956dc0abefcf5af20f5bc69a5aae64d62cbf", size = 19665, upload-time = "2024-07-24T09:12:31.976Z" }, +] + [[package]] name = "matplotlib" version = "3.10.8" @@ -937,71 +1017,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, ] -[[package]] -name = "mri-recon" -version = "0.1.0" -source = { virtual = "." } -dependencies = [ - { name = "certifi" }, - { name = "deepinv" }, - { name = "fastmri" }, - { name = "h5py" }, - { name = "matplotlib" }, - { name = "nibabel" }, - { name = "numpy" }, - { name = "ptwt" }, - { name = "pydicom" }, - { name = "pytest" }, - { name = "python-certifi-win32" }, - { name = "sigpy" }, - { name = "torch", version = "2.11.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "torchmetrics" }, - { name = "torchvision", version = "0.26.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torchvision", version = "0.26.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torchvision", version = "0.26.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "tqdm" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pre-commit" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "certifi", specifier = ">=2026.2.25" }, - { name = "deepinv", specifier = ">=0.3.8" }, - { name = "fastmri", specifier = ">=0.3.0" }, - { name = "h5py", specifier = ">=3.16.0" }, - { name = "matplotlib", specifier = ">=3.9.0" }, - { name = "nibabel", specifier = ">=5.3.2" }, - { name = "numpy", specifier = ">=2.4.3" }, - { name = "ptwt", specifier = ">=1.0.1" }, - { name = "pydicom", specifier = ">=3.0.1" }, - { name = "pytest", specifier = ">=9.0.2" }, - { name = "python-certifi-win32", specifier = ">=1.6.1" }, - { name = "sigpy", specifier = ">=0.1.27" }, - { name = "torch", marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.11.0" }, - { name = "torch", marker = "sys_platform == 'darwin'", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cpu" }, - { name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torch", marker = "sys_platform == 'win32'", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torchmetrics", specifier = ">=1.9.0" }, - { name = "torchvision", marker = "sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.26.0" }, - { name = "torchvision", marker = "sys_platform == 'darwin'", specifier = ">=0.26.0", index = "https://download.pytorch.org/whl/cpu" }, - { name = "torchvision", marker = "sys_platform == 'linux'", specifier = ">=0.26.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torchvision", marker = "sys_platform == 'win32'", specifier = ">=0.26.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "tqdm", specifier = ">=4.67.3" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pre-commit", specifier = ">=4.2.0" }, - { name = "ruff", specifier = ">=0.11.0" }, -] - [[package]] name = "multidict" version = "6.7.1"