|
| 1 | +"""Run a synthetic DSC-MRI Bayesian fitting demonstration.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import math |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +from src.bayesian_fitter import DSCBayesianFitter, gamma_variate_curve |
| 12 | + |
| 13 | + |
| 14 | +def main() -> None: |
| 15 | + """Generate synthetic DSC data, fit it, and save a comparison plot.""" |
| 16 | + rng = np.random.default_rng(seed=42) |
| 17 | + t = np.linspace(0.0, 50.0, 50) |
| 18 | + |
| 19 | + true_params = { |
| 20 | + "A": 5.0, |
| 21 | + "alpha": 3.0, |
| 22 | + "beta": 1.5, |
| 23 | + "t0": 10.0, |
| 24 | + } |
| 25 | + |
| 26 | + clean_curve = gamma_variate_curve( |
| 27 | + t, |
| 28 | + amplitude=true_params["A"], |
| 29 | + alpha=true_params["alpha"], |
| 30 | + beta=true_params["beta"], |
| 31 | + t0=true_params["t0"], |
| 32 | + ) |
| 33 | + noise_sigma = 0.35 |
| 34 | + c_obs = clean_curve + rng.normal(loc=0.0, scale=noise_sigma, size=t.shape) |
| 35 | + |
| 36 | + stan_path = Path("src") / "stan_models" / "gamma_variate.stan" |
| 37 | + fitter = DSCBayesianFitter(stan_file=stan_path) |
| 38 | + fit_result = fitter.fit_curve(t, c_obs, num_samples=1000) |
| 39 | + |
| 40 | + fitted_curve = gamma_variate_curve( |
| 41 | + t, |
| 42 | + amplitude=fit_result["A"], |
| 43 | + alpha=fit_result["alpha"], |
| 44 | + beta=fit_result["beta"], |
| 45 | + t0=fit_result["t0"], |
| 46 | + ) |
| 47 | + true_cbv = float( |
| 48 | + true_params["A"] |
| 49 | + * math.gamma(true_params["alpha"] + 1.0) |
| 50 | + * (true_params["beta"] ** (true_params["alpha"] + 1.0)) |
| 51 | + ) |
| 52 | + true_mtt = float(true_params["alpha"] * true_params["beta"]) |
| 53 | + |
| 54 | + print("True parameters:") |
| 55 | + for key, value in true_params.items(): |
| 56 | + print(f" {key}: {value:.4f}") |
| 57 | + print(f" cbv: {true_cbv:.4f}") |
| 58 | + print(f" mtt: {true_mtt:.4f}") |
| 59 | + |
| 60 | + print("\nPosterior means with 95% HDI:") |
| 61 | + for key in ("A", "alpha", "beta", "t0", "sigma", "cbv", "mtt"): |
| 62 | + print( |
| 63 | + f" {key}: {fit_result[key]:.4f} " |
| 64 | + f"[{fit_result[f'{key}_hdi_lower']:.4f}, {fit_result[f'{key}_hdi_upper']:.4f}]" |
| 65 | + ) |
| 66 | + |
| 67 | + plt.figure(figsize=(10, 6)) |
| 68 | + plt.scatter(t, c_obs, color="tab:blue", label="Noisy observations", alpha=0.8) |
| 69 | + plt.plot(t, clean_curve, color="tab:green", linewidth=2, label="Ground truth") |
| 70 | + plt.plot(t, fitted_curve, color="tab:red", linewidth=2, label="Posterior mean fit") |
| 71 | + plt.xlabel("Time (s)") |
| 72 | + plt.ylabel("Concentration") |
| 73 | + plt.title("Bayesian DSC-MRI Gamma-Variate Fit") |
| 74 | + plt.legend() |
| 75 | + plt.tight_layout() |
| 76 | + plt.savefig("fit_result.png", dpi=200) |
| 77 | + plt.close() |
| 78 | + |
| 79 | + |
| 80 | +if __name__ == "__main__": |
| 81 | + main() |
0 commit comments