Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/vbpca_py/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__( # noqa: PLR0913
hp_va: float | None = None,
hp_vb: float | None = None,
hp_v: float | None = None,
niter_broadprior: int | None = None,
**opts: object,
) -> None:
"""
Expand All @@ -46,6 +47,8 @@ def __init__( # noqa: PLR0913
hp_va: Prior hyperparameter for loadings variance (default 0.001).
hp_vb: Prior hyperparameter for score variance (default 0.001).
hp_v: Prior hyperparameter for noise variance (default 0.001).
niter_broadprior: Number of iterations to run under the broad
prior before convergence checks activate (default 100).
**opts: Additional options passed to the underlying PCA_FULL implementation.
"""
self.n_components = n_components
Expand All @@ -56,6 +59,7 @@ def __init__( # noqa: PLR0913
self.hp_va = hp_va
self.hp_vb = hp_vb
self.hp_v = hp_v
self.niter_broadprior = niter_broadprior
self.opts = opts
self.components_: np.ndarray | None = None
self.scores_: np.ndarray | None = None
Expand Down Expand Up @@ -111,6 +115,8 @@ def fit( # noqa: C901, PLR0912, PLR0914, PLR0915
opts["hp_vb"] = self.hp_vb
if self.hp_v is not None:
opts["hp_v"] = self.hp_v
if self.niter_broadprior is not None:
opts["niter_broadprior"] = self.niter_broadprior
opts.update(self.opts)
if xprobe is not None:
opts["xprobe"] = xprobe
Expand Down Expand Up @@ -241,6 +247,8 @@ def get_options(self) -> dict[str, object]:
opts["hp_vb"] = self.hp_vb
if self.hp_v is not None:
opts["hp_v"] = self.hp_v
if self.niter_broadprior is not None:
opts["niter_broadprior"] = self.niter_broadprior
opts.update(self.opts)
return _build_options(opts)

Expand Down
35 changes: 35 additions & 0 deletions tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,38 @@ def test_hp_params_affect_fit() -> None:
assert m_default.noise_variance_ != pytest.approx(
m_strong.noise_variance_, rel=1e-3
)


def test_niter_broadprior_stored_on_estimator() -> None:
"""niter_broadprior should be stored and passed through to options."""
model = VBPCA(n_components=2, niter_broadprior=10)
assert model.niter_broadprior == 10

opts = model.get_options()
assert opts["niter_broadprior"] == 10


def test_niter_broadprior_default_is_none() -> None:
"""When not set, niter_broadprior is None and options use library default."""
model = VBPCA(n_components=2)
assert model.niter_broadprior is None

opts = model.get_options()
assert opts["niter_broadprior"] == 100


def test_niter_broadprior_affects_iteration_count() -> None:
"""Lower niter_broadprior should allow earlier convergence."""
rng = np.random.default_rng(42)
w = rng.standard_normal((10, 2))
s = rng.standard_normal((2, 30))
x = w @ s + 0.1 * rng.standard_normal((10, 30))

from vbpca_py._pca_full import pca_full

r_default = pca_full(x, 2, bias=True, maxiters=200, niter_broadprior=100)
r_low = pca_full(x, 2, bias=True, maxiters=200, niter_broadprior=5)

iters_default = len(r_default["lc"]["rms"])
iters_low = len(r_low["lc"]["rms"])
assert iters_low <= iters_default
Loading