From f701a43267be76fe9c4e6eeb1b3161ceacd9f9ad Mon Sep 17 00:00:00 2001 From: Joshua Date: Mon, 13 Apr 2026 15:39:17 -0400 Subject: [PATCH 1/4] Expose initial broad prior value (va_init) as constructor parameter - Add va_init option to _build_options defaults (1000.0) - Thread va_init through _initialize_parameters in _full_update.py - Add va_init parameter to VBPCA.__init__, fit(), and get_options() - Add tests for default, propagation, and behavioral effect Closes #94 --- src/vbpca_py/_full_update.py | 5 +++-- src/vbpca_py/_pca_full.py | 1 + src/vbpca_py/estimators.py | 8 +++++++ tests/test_estimators.py | 41 ++++++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/vbpca_py/_full_update.py b/src/vbpca_py/_full_update.py index 0257a3b..c8a4ca6 100644 --- a/src/vbpca_py/_full_update.py +++ b/src/vbpca_py/_full_update.py @@ -688,9 +688,10 @@ def _initialize_parameters( mu_variances = init_result.muv.reshape(-1, 1) # Priors on loadings and mu + va_init = float(ctx.opts.get("va_init", 1000.0)) if ctx.use_prior: - va = np.full(ctx.shapes.n_components, 1000.0, dtype=float) - vmu = 1000.0 + va = np.full(ctx.shapes.n_components, va_init, dtype=float) + vmu = va_init else: va = np.full(ctx.shapes.n_components, np.inf, dtype=float) vmu = float("inf") diff --git a/src/vbpca_py/_pca_full.py b/src/vbpca_py/_pca_full.py index 0911758..8e99882 100644 --- a/src/vbpca_py/_pca_full.py +++ b/src/vbpca_py/_pca_full.py @@ -1935,6 +1935,7 @@ def _build_options(kwargs: Mapping[str, object]) -> dict[str, object]: "hp_va": 0.001, "hp_vb": 0.001, "hp_v": 0.001, + "va_init": 1000.0, "earlystop": False, "rmsstop": np.array([100, 1e-4, 1e-3]), "cfstop": np.array([]), diff --git a/src/vbpca_py/estimators.py b/src/vbpca_py/estimators.py index ca563f3..9ad4138 100644 --- a/src/vbpca_py/estimators.py +++ b/src/vbpca_py/estimators.py @@ -33,6 +33,7 @@ def __init__( # noqa: PLR0913 hp_vb: float | None = None, hp_v: float | None = None, niter_broadprior: int | None = None, + va_init: float | None = None, **opts: object, ) -> None: """ @@ -49,6 +50,8 @@ def __init__( # noqa: PLR0913 hp_v: Prior hyperparameter for noise variance (default 0.001). niter_broadprior: Number of iterations to run under the broad prior before convergence checks activate (default 100). + va_init: Initial broad prior value for loadings and bias + variances (default 1000). **opts: Additional options passed to the underlying PCA_FULL implementation. """ self.n_components = n_components @@ -60,6 +63,7 @@ def __init__( # noqa: PLR0913 self.hp_vb = hp_vb self.hp_v = hp_v self.niter_broadprior = niter_broadprior + self.va_init = va_init self.opts = opts self.components_: np.ndarray | None = None self.scores_: np.ndarray | None = None @@ -117,6 +121,8 @@ def fit( # noqa: C901, PLR0912, PLR0914, PLR0915 opts["hp_v"] = self.hp_v if self.niter_broadprior is not None: opts["niter_broadprior"] = self.niter_broadprior + if self.va_init is not None: + opts["va_init"] = self.va_init opts.update(self.opts) if xprobe is not None: opts["xprobe"] = xprobe @@ -249,6 +255,8 @@ def get_options(self) -> dict[str, object]: opts["hp_v"] = self.hp_v if self.niter_broadprior is not None: opts["niter_broadprior"] = self.niter_broadprior + if self.va_init is not None: + opts["va_init"] = self.va_init opts.update(self.opts) return _build_options(opts) diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 1552ac7..68d3ea0 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -210,3 +210,44 @@ def test_niter_broadprior_affects_iteration_count() -> None: iters_default = len(r_default["lc"]["rms"]) iters_low = len(r_low["lc"]["rms"]) assert iters_low <= iters_default + + +# --------------------------------------------------------------------------- +# va_init parameter +# --------------------------------------------------------------------------- + + +def test_va_init_default_is_none() -> None: + """When not set, va_init is None and options use the library default.""" + model = VBPCA(n_components=2) + assert model.va_init is None + + opts = model.get_options() + assert opts["va_init"] == 1000.0 + + +def test_va_init_custom_value_propagates() -> None: + """A custom va_init should appear in resolved options.""" + model = VBPCA(n_components=2, va_init=500.0) + assert model.va_init == 500.0 + + opts = model.get_options() + assert opts["va_init"] == 500.0 + + +def test_va_init_affects_initial_prior() -> None: + """Different va_init values should produce different model fits.""" + rng = np.random.default_rng(42) + w = rng.standard_normal((10, 2)) + s = rng.standard_normal((2, 30)) + x = w @ s + 0.1 * rng.standard_normal((10, 30)) + + from vbpca_py._pca_full import pca_full + + r_default = pca_full(x, 2, bias=True, maxiters=10, va_init=1000.0) + r_custom = pca_full(x, 2, bias=True, maxiters=10, va_init=100.0) + + # The RMS traces should differ when starting from different priors + rms_default = r_default["lc"]["rms"] + rms_custom = r_custom["lc"]["rms"] + assert rms_default != rms_custom From c594ae7fbfd2ecb14db5fa9949c7e5139b5928e1 Mon Sep 17 00:00:00 2001 From: Joshua Date: Mon, 13 Apr 2026 15:39:48 -0400 Subject: [PATCH 2/4] Fix lint: inline va_init to avoid PLR0914, use pytest.approx for RUF069 --- src/vbpca_py/_full_update.py | 9 ++++++--- tests/test_estimators.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/vbpca_py/_full_update.py b/src/vbpca_py/_full_update.py index c8a4ca6..48b8a76 100644 --- a/src/vbpca_py/_full_update.py +++ b/src/vbpca_py/_full_update.py @@ -688,10 +688,13 @@ def _initialize_parameters( mu_variances = init_result.muv.reshape(-1, 1) # Priors on loadings and mu - va_init = float(ctx.opts.get("va_init", 1000.0)) if ctx.use_prior: - va = np.full(ctx.shapes.n_components, va_init, dtype=float) - vmu = va_init + va = np.full( + ctx.shapes.n_components, + float(ctx.opts.get("va_init", 1000.0)), + dtype=float, + ) + vmu = float(ctx.opts.get("va_init", 1000.0)) else: va = np.full(ctx.shapes.n_components, np.inf, dtype=float) vmu = float("inf") diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 68d3ea0..436cda1 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -223,16 +223,16 @@ def test_va_init_default_is_none() -> None: assert model.va_init is None opts = model.get_options() - assert opts["va_init"] == 1000.0 + assert opts["va_init"] == pytest.approx(1000.0) def test_va_init_custom_value_propagates() -> None: """A custom va_init should appear in resolved options.""" model = VBPCA(n_components=2, va_init=500.0) - assert model.va_init == 500.0 + assert model.va_init == pytest.approx(500.0) opts = model.get_options() - assert opts["va_init"] == 500.0 + assert opts["va_init"] == pytest.approx(500.0) def test_va_init_affects_initial_prior() -> None: From 9858392c072725ba046e5be3b53d7991667b34b9 Mon Sep 17 00:00:00 2001 From: Joshua Date: Mon, 13 Apr 2026 15:40:44 -0400 Subject: [PATCH 3/4] Fix mypy: cast opts.get for va_init, suppress PLR0914 on function def --- src/vbpca_py/_full_update.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/vbpca_py/_full_update.py b/src/vbpca_py/_full_update.py index 48b8a76..dbcbd1e 100644 --- a/src/vbpca_py/_full_update.py +++ b/src/vbpca_py/_full_update.py @@ -649,7 +649,7 @@ def _missing_patterns_info( # --------------------------------------------------------------------------- -def _initialize_parameters( +def _initialize_parameters( # noqa: PLR0914 ctx: InitContext, ) -> tuple[ np.ndarray, @@ -688,13 +688,10 @@ def _initialize_parameters( mu_variances = init_result.muv.reshape(-1, 1) # Priors on loadings and mu + _va_init = float(cast("float", ctx.opts.get("va_init", 1000.0))) if ctx.use_prior: - va = np.full( - ctx.shapes.n_components, - float(ctx.opts.get("va_init", 1000.0)), - dtype=float, - ) - vmu = float(ctx.opts.get("va_init", 1000.0)) + va = np.full(ctx.shapes.n_components, _va_init, dtype=float) + vmu = _va_init else: va = np.full(ctx.shapes.n_components, np.inf, dtype=float) vmu = float("inf") From e34a26e2db751a239722ce3b97ec88b8b02a1ee0 Mon Sep 17 00:00:00 2001 From: Joshua Date: Mon, 13 Apr 2026 15:41:14 -0400 Subject: [PATCH 4/4] Fix lint: drop leading underscore from va_init local (RUF052) --- src/vbpca_py/_full_update.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vbpca_py/_full_update.py b/src/vbpca_py/_full_update.py index dbcbd1e..c5c7350 100644 --- a/src/vbpca_py/_full_update.py +++ b/src/vbpca_py/_full_update.py @@ -688,10 +688,10 @@ def _initialize_parameters( # noqa: PLR0914 mu_variances = init_result.muv.reshape(-1, 1) # Priors on loadings and mu - _va_init = float(cast("float", ctx.opts.get("va_init", 1000.0))) + va_init = float(cast("float", ctx.opts.get("va_init", 1000.0))) if ctx.use_prior: - va = np.full(ctx.shapes.n_components, _va_init, dtype=float) - vmu = _va_init + va = np.full(ctx.shapes.n_components, va_init, dtype=float) + vmu = va_init else: va = np.full(ctx.shapes.n_components, np.inf, dtype=float) vmu = float("inf")