Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/vbpca_py/_full_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def _missing_patterns_info(
# ---------------------------------------------------------------------------


def _initialize_parameters(
def _initialize_parameters( # noqa: PLR0914
ctx: InitContext,
) -> tuple[
np.ndarray,
Expand Down Expand Up @@ -688,9 +688,10 @@ def _initialize_parameters(
mu_variances = init_result.muv.reshape(-1, 1)

# Priors on loadings and mu
va_init = float(cast("float", ctx.opts.get("va_init", 1000.0)))
if ctx.use_prior:
va = np.full(ctx.shapes.n_components, 1000.0, dtype=float)
vmu = 1000.0
va = np.full(ctx.shapes.n_components, va_init, dtype=float)
vmu = va_init
else:
va = np.full(ctx.shapes.n_components, np.inf, dtype=float)
vmu = float("inf")
Expand Down
1 change: 1 addition & 0 deletions src/vbpca_py/_pca_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,6 +1935,7 @@ def _build_options(kwargs: Mapping[str, object]) -> dict[str, object]:
"hp_va": 0.001,
"hp_vb": 0.001,
"hp_v": 0.001,
"va_init": 1000.0,
"earlystop": False,
"rmsstop": np.array([100, 1e-4, 1e-3]),
"cfstop": np.array([]),
Expand Down
8 changes: 8 additions & 0 deletions src/vbpca_py/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__( # noqa: PLR0913
hp_vb: float | None = None,
hp_v: float | None = None,
niter_broadprior: int | None = None,
va_init: float | None = None,
**opts: object,
) -> None:
"""
Expand All @@ -49,6 +50,8 @@ def __init__( # noqa: PLR0913
hp_v: Prior hyperparameter for noise variance (default 0.001).
niter_broadprior: Number of iterations to run under the broad
prior before convergence checks activate (default 100).
va_init: Initial broad prior value for loadings and bias
variances (default 1000).
**opts: Additional options passed to the underlying PCA_FULL implementation.
"""
self.n_components = n_components
Expand All @@ -60,6 +63,7 @@ def __init__( # noqa: PLR0913
self.hp_vb = hp_vb
self.hp_v = hp_v
self.niter_broadprior = niter_broadprior
self.va_init = va_init
self.opts = opts
self.components_: np.ndarray | None = None
self.scores_: np.ndarray | None = None
Expand Down Expand Up @@ -117,6 +121,8 @@ def fit( # noqa: C901, PLR0912, PLR0914, PLR0915
opts["hp_v"] = self.hp_v
if self.niter_broadprior is not None:
opts["niter_broadprior"] = self.niter_broadprior
if self.va_init is not None:
opts["va_init"] = self.va_init
opts.update(self.opts)
if xprobe is not None:
opts["xprobe"] = xprobe
Expand Down Expand Up @@ -249,6 +255,8 @@ def get_options(self) -> dict[str, object]:
opts["hp_v"] = self.hp_v
if self.niter_broadprior is not None:
opts["niter_broadprior"] = self.niter_broadprior
if self.va_init is not None:
opts["va_init"] = self.va_init
opts.update(self.opts)
return _build_options(opts)

Expand Down
41 changes: 41 additions & 0 deletions tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,44 @@ def test_niter_broadprior_affects_iteration_count() -> None:
iters_default = len(r_default["lc"]["rms"])
iters_low = len(r_low["lc"]["rms"])
assert iters_low <= iters_default


# ---------------------------------------------------------------------------
# va_init parameter
# ---------------------------------------------------------------------------


def test_va_init_default_is_none() -> None:
"""When not set, va_init is None and options use the library default."""
model = VBPCA(n_components=2)
assert model.va_init is None

opts = model.get_options()
assert opts["va_init"] == pytest.approx(1000.0)


def test_va_init_custom_value_propagates() -> None:
"""A custom va_init should appear in resolved options."""
model = VBPCA(n_components=2, va_init=500.0)
assert model.va_init == pytest.approx(500.0)

opts = model.get_options()
assert opts["va_init"] == pytest.approx(500.0)


def test_va_init_affects_initial_prior() -> None:
"""Different va_init values should produce different model fits."""
rng = np.random.default_rng(42)
w = rng.standard_normal((10, 2))
s = rng.standard_normal((2, 30))
x = w @ s + 0.1 * rng.standard_normal((10, 30))

from vbpca_py._pca_full import pca_full

r_default = pca_full(x, 2, bias=True, maxiters=10, va_init=1000.0)
r_custom = pca_full(x, 2, bias=True, maxiters=10, va_init=100.0)

# The RMS traces should differ when starting from different priors
rms_default = r_default["lc"]["rms"]
rms_custom = r_custom["lc"]["rms"]
assert rms_default != rms_custom
Loading