FID-3D provides a clean, modular implementation of the Frechet Inception Distance (FID) adapted to volumetric medical images (e.g., CT/MRI). It loads 3D volumes, extracts features using a configurable backbone, and computes the Frechet distance between real and generated samples. Lower scores indicate closer alignment between generated and real data distributions.
FID compares feature statistics (mean and covariance) between two sets of images. Extending this to 3D volumes enables quantitative evaluation of generative models for medical imaging, where per-slice fidelity and volumetric consistency both matter.
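Concretely, with means mu_r, mu_g and covariances Sigma_r, Sigma_g of the two feature sets, the score is ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2(Sigma_r Sigma_g)^(1/2)). A minimal NumPy sketch of this statistic (independent of this package's `calculate_fid`, which may differ in numerical details; the trace of the matrix square root is computed via eigenvalues here):

```python
import numpy as np

def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Frechet distance between Gaussians fit to two [N, D] feature arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    sigma_a = np.cov(feats_a, rowvar=False)
    sigma_b = np.cov(feats_b, rowvar=False)
    # Tr((Sa Sb)^(1/2)) equals the sum of square roots of the eigenvalues of
    # Sa @ Sb, which are real and non-negative for covariance matrices.
    eigvals = np.linalg.eigvals(sigma_a @ sigma_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(sigma_a) + np.trace(sigma_b) - 2.0 * tr_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(1000, 16))
fake_close = rng.normal(0.0, 1.0, size=(1000, 16))  # same distribution
fake_far = rng.normal(2.0, 1.0, size=(1000, 16))    # shifted mean
print(frechet_distance(real, fake_close))  # small
print(frechet_distance(real, fake_far))    # much larger, dominated by the mean shift
```

Matching distributions score near zero; a mean shift of 2 in every one of the 16 dimensions contributes roughly 64 to the distance.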
src/fid3d/
- `data.py` – volume loading utilities and PyTorch dataset for 3D stacks
- `features.py` – feature extraction helpers and default Inception backbone
- `metrics.py` – mean/covariance computation and Frechet distance
- `utils.py` – common helpers (device selection, file discovery, seeds)
- `__init__.py` – public API exports

Other directories:
- `scripts/compute_fid3d.py` – CLI entry point to compute FID-3D from folders or precomputed features
- `examples/compute_fid_from_dirs.py` – minimal Python example
- `tests/` – small metric tests
- `notebooks/` – place for demos or exploratory notebooks
- `requirements.txt` – runtime dependencies
```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```

```bash
python scripts/compute_fid3d.py \
  --real /path/to/real_volumes \
  --fake /path/to/generated_volumes \
  --device cuda
```

Use precomputed features (NumPy `.npy` saved arrays of shape `[N, D]`):
```bash
python scripts/compute_fid3d.py \
  --real-features real_features.npy \
  --fake-features fake_features.npy
```

Or call the library directly from Python:

```python
from torch.utils.data import DataLoader

from fid3d.data import VolumeDataset
from fid3d.features import build_inception_v3, extract_features, default_slice_transform
from fid3d.metrics import calculate_fid
from fid3d.utils import list_volume_files

device = "cuda"
model = build_inception_v3(device=device)
transform = default_slice_transform()

real_ds = VolumeDataset(list_volume_files("data/real"), slice_transform=transform)
fake_ds = VolumeDataset(list_volume_files("data/fake"), slice_transform=transform)
real_loader = DataLoader(real_ds, batch_size=1, shuffle=False)
fake_loader = DataLoader(fake_ds, batch_size=1, shuffle=False)

real_feats = extract_features(real_loader, model, device=device)
fake_feats = extract_features(fake_loader, model, device=device)

fid = calculate_fid(real_feats, fake_feats)
print(f"FID-3D: {fid:.4f}")
```

You can plug in any 2D or 3D PyTorch model. The only requirement is that it returns a feature tensor shaped `[batch, feature_dim]` (or any shape that flattens to that). Replace `build_inception_v3` with your model and keep `extract_features` the same, or author your own aggregation before calling `calculate_fid`.
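As an illustration, a custom backbone could be as simple as the toy model below (`TinyBackbone` is purely hypothetical and not part of fid3d; any `nn.Module` whose output flattens to `[batch, feature_dim]` works the same way):

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for build_inception_v3: takes 3-channel 2D slices
    and returns a [batch, feature_dim] feature tensor."""
    def __init__(self, feature_dim: int = 64):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)   # global average pool to 1x1
        self.proj = nn.Linear(16, feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.pool(torch.relu(self.conv(x))).flatten(1)  # [batch, 16]
        return self.proj(h)                                 # [batch, feature_dim]

model = TinyBackbone().eval()
with torch.no_grad():
    slices = torch.randn(8, 3, 299, 299)  # a batch of 3-channel slices
    feats = model(slices)
print(feats.shape)  # torch.Size([8, 64])
```

A real backbone would be pretrained; the point is only the output contract expected by the distance computation.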
- 3D medical volumes stored as `.npy`, `.npz`, `.nii`, or `.nii.gz`
- Volumes are treated as `(H, W, D)` arrays; slices are taken along the last dimension and broadcast to 3 channels before running through a 2D backbone.
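As a sketch of that slicing convention, the helper below (hypothetical, not part of the package) turns one `(H, W, D)` volume into a batch of 3-channel slices ready for a 2D backbone:

```python
import numpy as np

def volume_to_slices(volume: np.ndarray) -> np.ndarray:
    """Convert an (H, W, D) volume into (D, 3, H, W): one 3-channel image
    per slice along the last axis, with the grayscale slice repeated
    across the channel dimension."""
    slices = np.moveaxis(volume, -1, 0)             # (D, H, W)
    slices = np.repeat(slices[:, None], 3, axis=1)  # (D, 3, H, W)
    return slices

vol = np.random.rand(64, 64, 32).astype(np.float32)
stack = volume_to_slices(vol)
print(stack.shape)  # (32, 3, 64, 64)
```

Resizing and normalization for the specific backbone (e.g. Inception's expected input size and statistics) would be applied per slice after this step.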