Add a configurable `va_init` option for the initial broad prior variance.

- src/vbpca_py/_full_update.py: `_initialize_parameters` reads `va_init`
  from `ctx.opts` (falling back to 1000.0) instead of hard-coding 1000.0
  for `va` and `vmu` when `ctx.use_prior` is set. The `else` branch
  (infinite variances) is unchanged.
- src/vbpca_py/_pca_full.py: `_build_options` gains the default
  `"va_init": 1000.0`.
- src/vbpca_py/estimators.py: the estimator `__init__` accepts and stores
  `va_init`; `fit` and `get_options` copy it into the options when it is
  not None (entries in `self.opts` are applied afterwards and still win).
- tests/test_estimators.py: cover the default, the propagation of a custom
  value, and the effect on the fit. The last test compares the RMS traces
  with `np.array_equal`, which works whether the trace is a list or an
  ndarray (a bare `!=` on two ndarrays cannot be used in `assert`).

diff --git a/src/vbpca_py/_full_update.py b/src/vbpca_py/_full_update.py
index 0257a3b..c5c7350 100644
--- a/src/vbpca_py/_full_update.py
+++ b/src/vbpca_py/_full_update.py
@@ -649,7 +649,7 @@ def _missing_patterns_info(
 # ---------------------------------------------------------------------------
 
 
-def _initialize_parameters(
+def _initialize_parameters(  # noqa: PLR0914
     ctx: InitContext,
 ) -> tuple[
     np.ndarray,
@@ -688,9 +688,10 @@ def _initialize_parameters(
     mu_variances = init_result.muv.reshape(-1, 1)
 
     # Priors on loadings and mu
+    va_init = float(cast("float", ctx.opts.get("va_init", 1000.0)))
     if ctx.use_prior:
-        va = np.full(ctx.shapes.n_components, 1000.0, dtype=float)
-        vmu = 1000.0
+        va = np.full(ctx.shapes.n_components, va_init, dtype=float)
+        vmu = va_init
     else:
         va = np.full(ctx.shapes.n_components, np.inf, dtype=float)
         vmu = float("inf")
diff --git a/src/vbpca_py/_pca_full.py b/src/vbpca_py/_pca_full.py
index 0911758..8e99882 100644
--- a/src/vbpca_py/_pca_full.py
+++ b/src/vbpca_py/_pca_full.py
@@ -1935,6 +1935,7 @@ def _build_options(kwargs: Mapping[str, object]) -> dict[str, object]:
         "hp_va": 0.001,
         "hp_vb": 0.001,
         "hp_v": 0.001,
+        "va_init": 1000.0,
         "earlystop": False,
         "rmsstop": np.array([100, 1e-4, 1e-3]),
         "cfstop": np.array([]),
diff --git a/src/vbpca_py/estimators.py b/src/vbpca_py/estimators.py
index ca563f3..9ad4138 100644
--- a/src/vbpca_py/estimators.py
+++ b/src/vbpca_py/estimators.py
@@ -33,6 +33,7 @@ def __init__(  # noqa: PLR0913
         hp_vb: float | None = None,
         hp_v: float | None = None,
         niter_broadprior: int | None = None,
+        va_init: float | None = None,
         **opts: object,
     ) -> None:
         """
@@ -49,6 +50,8 @@ def __init__(  # noqa: PLR0913
             hp_v: Prior hyperparameter for noise variance (default 0.001).
             niter_broadprior: Number of iterations to run under the broad prior
                 before convergence checks activate (default 100).
+            va_init: Initial value of the broad prior variances for the loadings
+                and the bias; unused when the prior is disabled (default 1000).
             **opts: Additional options passed to the underlying PCA_FULL implementation.
         """
         self.n_components = n_components
@@ -60,6 +63,7 @@ def __init__(  # noqa: PLR0913
         self.hp_vb = hp_vb
         self.hp_v = hp_v
         self.niter_broadprior = niter_broadprior
+        self.va_init = va_init
         self.opts = opts
         self.components_: np.ndarray | None = None
         self.scores_: np.ndarray | None = None
@@ -117,6 +121,8 @@ def fit(  # noqa: C901, PLR0912, PLR0914, PLR0915
             opts["hp_v"] = self.hp_v
         if self.niter_broadprior is not None:
             opts["niter_broadprior"] = self.niter_broadprior
+        if self.va_init is not None:
+            opts["va_init"] = self.va_init
         opts.update(self.opts)
         if xprobe is not None:
             opts["xprobe"] = xprobe
@@ -249,6 +255,8 @@ def get_options(self) -> dict[str, object]:
             opts["hp_v"] = self.hp_v
         if self.niter_broadprior is not None:
             opts["niter_broadprior"] = self.niter_broadprior
+        if self.va_init is not None:
+            opts["va_init"] = self.va_init
         opts.update(self.opts)
         return _build_options(opts)
 
diff --git a/tests/test_estimators.py b/tests/test_estimators.py
index 1552ac7..436cda1 100644
--- a/tests/test_estimators.py
+++ b/tests/test_estimators.py
@@ -210,3 +210,44 @@ def test_niter_broadprior_affects_iteration_count() -> None:
     iters_default = len(r_default["lc"]["rms"])
     iters_low = len(r_low["lc"]["rms"])
     assert iters_low <= iters_default
+
+
+# ---------------------------------------------------------------------------
+# va_init parameter
+# ---------------------------------------------------------------------------
+
+
+def test_va_init_default_is_none() -> None:
+    """When not set, va_init is None and options use the library default."""
+    model = VBPCA(n_components=2)
+    assert model.va_init is None
+
+    opts = model.get_options()
+    assert opts["va_init"] == pytest.approx(1000.0)
+
+
+def test_va_init_custom_value_propagates() -> None:
+    """A custom va_init should appear in resolved options."""
+    model = VBPCA(n_components=2, va_init=500.0)
+    assert model.va_init == pytest.approx(500.0)
+
+    opts = model.get_options()
+    assert opts["va_init"] == pytest.approx(500.0)
+
+
+def test_va_init_affects_initial_prior() -> None:
+    """Different va_init values should produce different model fits."""
+    rng = np.random.default_rng(42)
+    w = rng.standard_normal((10, 2))
+    s = rng.standard_normal((2, 30))
+    x = w @ s + 0.1 * rng.standard_normal((10, 30))
+
+    from vbpca_py._pca_full import pca_full
+
+    r_default = pca_full(x, 2, bias=True, maxiters=10, va_init=1000.0)
+    r_custom = pca_full(x, 2, bias=True, maxiters=10, va_init=100.0)
+
+    # The RMS traces should differ when starting from different priors
+    rms_default = r_default["lc"]["rms"]
+    rms_custom = r_custom["lc"]["rms"]
+    assert not np.array_equal(rms_default, rms_custom)