From 66b93e1b3f45ae48b4e14dd6dfd651ea4909de2d Mon Sep 17 00:00:00 2001 From: "J.C. Macdonald" <72512262+jc-macdonald@users.noreply.github.com> Date: Sun, 12 Apr 2026 19:19:07 -0400 Subject: [PATCH] feat(estimators): expose niter_broadprior on VBPCA constructor Add niter_broadprior as a named keyword argument on VBPCA.__init__() so users can control the broad-prior warmup phase without resorting to **opts. Default remains 100 (backward compatible). Wire the parameter through fit() and get_options(). - Add 3 tests: stored value, default fallback, iteration count effect Closes #51 --- src/vbpca_py/estimators.py | 8 ++++++++ tests/test_estimators.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/vbpca_py/estimators.py b/src/vbpca_py/estimators.py index 821ff84..ca563f3 100644 --- a/src/vbpca_py/estimators.py +++ b/src/vbpca_py/estimators.py @@ -32,6 +32,7 @@ def __init__( # noqa: PLR0913 hp_va: float | None = None, hp_vb: float | None = None, hp_v: float | None = None, + niter_broadprior: int | None = None, **opts: object, ) -> None: """ @@ -46,6 +47,8 @@ def __init__( # noqa: PLR0913 hp_va: Prior hyperparameter for loadings variance (default 0.001). hp_vb: Prior hyperparameter for score variance (default 0.001). hp_v: Prior hyperparameter for noise variance (default 0.001). + niter_broadprior: Number of iterations to run under the broad + prior before convergence checks activate (default 100). **opts: Additional options passed to the underlying PCA_FULL implementation. """ self.n_components = n_components @@ -56,6 +59,7 @@ def __init__( # noqa: PLR0913 self.hp_va = hp_va self.hp_vb = hp_vb self.hp_v = hp_v + self.niter_broadprior = niter_broadprior self.opts = opts self.components_: np.ndarray | None = None self.scores_: np.ndarray | None = None @@ -111,6 +115,8 @@ def fit( # noqa: C901, PLR0912, PLR0914, PLR0915 opts["hp_vb"] = self.hp_vb if self.hp_v is not None: opts["hp_v"] = self.hp_v + if self.niter_broadprior is not None: + opts["niter_broadprior"] = self.niter_broadprior opts.update(self.opts) if xprobe is not None: opts["xprobe"] = xprobe @@ -241,6 +247,8 @@ def get_options(self) -> dict[str, object]: opts["hp_vb"] = self.hp_vb if self.hp_v is not None: opts["hp_v"] = self.hp_v + if self.niter_broadprior is not None: + opts["niter_broadprior"] = self.niter_broadprior opts.update(self.opts) return _build_options(opts) diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 5e66003..1552ac7 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -175,3 +175,38 @@ def test_hp_params_affect_fit() -> None: assert m_default.noise_variance_ != pytest.approx( m_strong.noise_variance_, rel=1e-3 ) + + +def test_niter_broadprior_stored_on_estimator() -> None: + """niter_broadprior should be stored and passed through to options.""" + model = VBPCA(n_components=2, niter_broadprior=10) + assert model.niter_broadprior == 10 + + opts = model.get_options() + assert opts["niter_broadprior"] == 10 + + +def test_niter_broadprior_default_is_none() -> None: + """When not set, niter_broadprior is None and options use library default.""" + model = VBPCA(n_components=2) + assert model.niter_broadprior is None + + opts = model.get_options() + assert opts["niter_broadprior"] == 100 + + +def test_niter_broadprior_affects_iteration_count() -> None: + """Lower niter_broadprior should allow earlier convergence.""" + rng = np.random.default_rng(42) + w = rng.standard_normal((10, 2)) + s = rng.standard_normal((2, 30)) + x = w @ s + 0.1 * rng.standard_normal((10, 30)) + + from vbpca_py._pca_full import pca_full + + r_default = pca_full(x, 2, bias=True, maxiters=200, niter_broadprior=100) + r_low = pca_full(x, 2, bias=True, maxiters=200, niter_broadprior=5) + + iters_default = len(r_default["lc"]["rms"]) + iters_low = len(r_low["lc"]["rms"]) + assert iters_low <= iters_default