diff --git a/src/vbpca_py/__init__.py b/src/vbpca_py/__init__.py index c2ec9f2..e325d86 100644 --- a/src/vbpca_py/__init__.py +++ b/src/vbpca_py/__init__.py @@ -2,6 +2,7 @@ from importlib.metadata import version +from vbpca_py._missing import make_xprobe_mask from vbpca_py.estimators import VBPCA from vbpca_py.model_selection import ( CVConfig, @@ -37,6 +38,7 @@ "SelectionConfig", "check_data", "cross_validate_components", + "make_xprobe_mask", "select_n_components", ] diff --git a/src/vbpca_py/_missing.py b/src/vbpca_py/_missing.py index 16e5134..7150364 100644 --- a/src/vbpca_py/_missing.py +++ b/src/vbpca_py/_missing.py @@ -13,6 +13,7 @@ import logging import numpy as np +import scipy.sparse as sp logger = logging.getLogger(__name__) @@ -104,3 +105,100 @@ def _missing_patterns( ) return n_patterns, pattern_columns, column_pattern_index + + +def _xprobe_sparse( + x: sp.csr_matrix, + fraction: float, + rng: np.random.Generator, +) -> tuple[sp.csr_matrix, sp.csr_matrix]: + """Hold out probe entries from a sparse matrix. + + Returns: + Tuple of (x_masked, xprobe) as CSR matrices. + """ + x_csr = sp.csr_matrix(x, copy=True) + n_probe = max(1, round(x_csr.nnz * fraction)) + probe_idx = rng.choice(x_csr.nnz, size=n_probe, replace=False) + + rows, cols = x_csr.nonzero() + sp_rows = rows[probe_idx] + sp_cols = cols[probe_idx] + sp_vals = np.array(x_csr[sp_rows, sp_cols]).ravel() + + xprobe_sp = sp.lil_matrix(x_csr.shape, dtype=float) + for r, c, v in zip(sp_rows, sp_cols, sp_vals, strict=True): + xprobe_sp[r, c] = v + xprobe = sp.csr_matrix(xprobe_sp) + + for r, c in zip(sp_rows, sp_cols, strict=True): + x_csr[r, c] = 0.0 + x_csr.eliminate_zeros() + x_csr.sort_indices() + return x_csr, xprobe + + +def _xprobe_dense( + x: np.ndarray, + fraction: float, + rng: np.random.Generator, +) -> tuple[np.ndarray, np.ndarray]: + """Hold out probe entries from a dense matrix. + + Returns: + Tuple of (x_masked, xprobe) as dense arrays. + """ + x_dense = np.array(x, dtype=float, copy=True) + obs_rows, obs_cols = np.nonzero(~np.isnan(x_dense)) + n_probe = max(1, round(len(obs_rows) * fraction)) + probe_idx = rng.choice(len(obs_rows), size=n_probe, replace=False) + + probe_rows = obs_rows[probe_idx] + probe_cols = obs_cols[probe_idx] + + xprobe = np.full(x_dense.shape, np.nan, dtype=float) + xprobe[probe_rows, probe_cols] = x_dense[probe_rows, probe_cols] + x_dense[probe_rows, probe_cols] = np.nan + return x_dense, xprobe + + +def make_xprobe_mask( + x: np.ndarray | sp.csr_matrix, + fraction: float = 0.10, + rng: np.random.Generator | None = None, +) -> tuple[np.ndarray | sp.csr_matrix, np.ndarray | sp.csr_matrix]: + """Hold out a fraction of observed entries as probe data. + + Selects a random subset of observed entries (non-NaN for dense, + structurally non-zero for sparse) and returns a modified data matrix + with those entries masked out plus a probe matrix containing only the + held-out values. + + Args: + x: Data matrix of shape ``(n_features, n_samples)``. + Dense or sparse (CSR). + fraction: Fraction of observed entries to hold out (default 0.10). + Must be in ``(0, 1)``. + rng: NumPy random generator. If ``None``, a new default generator + is created. + + Returns: + x_masked: Copy of *x* with probe entries set to NaN (dense) or + removed (sparse). + xprobe: Matrix of the same shape as *x* containing only the + held-out probe values; all other entries are NaN (dense) or + absent (sparse). + + Raises: + ValueError: If *fraction* is not in ``(0, 1)``. + """ + if not 0.0 < fraction < 1.0: + msg = f"fraction must be in (0, 1), got {fraction}" + raise ValueError(msg) + + if rng is None: + rng = np.random.default_rng() + + if sp.issparse(x): + return _xprobe_sparse(sp.csr_matrix(x), fraction, rng) + return _xprobe_dense(np.asarray(x), fraction, rng) diff --git a/src/vbpca_py/estimators.py b/src/vbpca_py/estimators.py index 9ad4138..fee2c86 100644 --- a/src/vbpca_py/estimators.py +++ b/src/vbpca_py/estimators.py @@ -12,6 +12,7 @@ format_bytes, resolve_max_dense_bytes, ) +from vbpca_py._missing import make_xprobe_mask from vbpca_py._pca_full import Matrix, _build_options, pca_full from vbpca_py.model_selection import SelectionConfig, select_n_components @@ -34,6 +35,7 @@ def __init__( # noqa: PLR0913 hp_v: float | None = None, niter_broadprior: int | None = None, va_init: float | None = None, + xprobe_fraction: float = 0.0, **opts: object, ) -> None: """ @@ -52,6 +54,10 @@ def __init__( # noqa: PLR0913 prior before convergence checks activate (default 100). va_init: Initial broad prior value for loadings and bias variances (default 1000). + xprobe_fraction: Fraction of observed entries to hold out as + probe data (default 0.0, disabled). When positive and no + explicit *xprobe* is passed to :meth:`fit`, a random probe + set is generated automatically. **opts: Additional options passed to the underlying PCA_FULL implementation. """ self.n_components = n_components @@ -64,6 +70,7 @@ def __init__( # noqa: PLR0913 self.hp_v = hp_v self.niter_broadprior = niter_broadprior self.va_init = va_init + self.xprobe_fraction = xprobe_fraction self.opts = opts self.components_: np.ndarray | None = None self.scores_: np.ndarray | None = None @@ -126,6 +133,9 @@ def fit( # noqa: C901, PLR0912, PLR0914, PLR0915 opts.update(self.opts) if xprobe is not None: opts["xprobe"] = xprobe + elif self.xprobe_fraction > 0.0: + x, xprobe_gen = make_xprobe_mask(x, fraction=self.xprobe_fraction) + opts["xprobe"] = xprobe_gen max_dense_bytes = resolve_max_dense_bytes( opts.get("max_dense_bytes", 2_000_000_000) diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 436cda1..4a3eae9 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -251,3 +251,50 @@ def test_va_init_affects_initial_prior() -> None: rms_default = r_default["lc"]["rms"] rms_custom = r_custom["lc"]["rms"] assert rms_default != rms_custom + + +# --------------------------------------------------------------------------- +# xprobe_fraction parameter +# --------------------------------------------------------------------------- + + +def test_xprobe_fraction_default_is_zero() -> None: + """Default xprobe_fraction is 0.0 (disabled).""" + model = VBPCA(n_components=2) + assert model.xprobe_fraction == pytest.approx(0.0) + + +def test_xprobe_fraction_generates_probe() -> None: + """When xprobe_fraction > 0, fit() auto-generates a probe set.""" + rng = np.random.default_rng(42) + x = rng.standard_normal((8, 20)) + model = VBPCA(n_components=2, maxiters=5, xprobe_fraction=0.10) + model.fit(x) + # prms should be a real number (not NaN) when probe is active + assert model.prms_ is not None + assert np.isfinite(model.prms_) + + +def test_xprobe_fraction_explicit_xprobe_takes_precedence() -> None: + """An explicit xprobe passed to fit() overrides xprobe_fraction.""" + rng = np.random.default_rng(42) + x = rng.standard_normal((8, 20)) + xprobe = np.full(x.shape, np.nan, dtype=float) + xprobe[0, 0] = x[0, 0] + + model = VBPCA(n_components=2, maxiters=5, xprobe_fraction=0.10) + model.fit(x, xprobe=xprobe) + # Should still produce a finite prms from the explicit probe + assert model.prms_ is not None + assert np.isfinite(model.prms_) + + +def test_xprobe_fraction_no_probe_when_zero() -> None: + """With xprobe_fraction=0.0 and no xprobe, prms should be NaN.""" + rng = np.random.default_rng(42) + x = rng.standard_normal((8, 20)) + model = VBPCA(n_components=2, maxiters=5, xprobe_fraction=0.0) + model.fit(x) + # No probe -> prms is NaN + assert model.prms_ is not None + assert np.isnan(model.prms_) diff --git a/tests/test_missing.py b/tests/test_missing.py index cb02cb0..893fdf5 100644 --- a/tests/test_missing.py +++ b/tests/test_missing.py @@ -3,9 +3,10 @@ import numpy as np import pytest +import scipy.sparse as sp import vbpca_py._missing as missing_mod -from vbpca_py._missing import _missing_patterns +from vbpca_py._missing import _missing_patterns, make_xprobe_mask # --------------------------------------------------------------------------- # Happy paths @@ -157,3 +158,71 @@ def test_missing_patterns_bad_dim_raises(bad_mask: np.ndarray) -> None: """_missing_patterns should raise when mask is not 2D.""" with pytest.raises(ValueError, match=r"mask must be a 2D array."): _missing_patterns(bad_mask) + + +# --------------------------------------------------------------------------- +# make_xprobe_mask +# --------------------------------------------------------------------------- + + +def test_make_xprobe_mask_dense_basic() -> None: + """Dense: probe entries are held out and data is masked.""" + rng = np.random.default_rng(0) + x = rng.standard_normal((6, 10)) + x_masked, xprobe = make_xprobe_mask(x, fraction=0.2, rng=np.random.default_rng(0)) + + # xprobe has some non-NaN entries + probe_obs = ~np.isnan(xprobe) + assert probe_obs.any() + + # Those entries are NaN in x_masked + assert np.all(np.isnan(x_masked[probe_obs])) + + # Non-probe entries are unchanged + non_probe = ~probe_obs + np.testing.assert_array_equal(x_masked[non_probe], x[non_probe]) + + # Probe values match original + np.testing.assert_array_equal(xprobe[probe_obs], x[probe_obs]) + + +def test_make_xprobe_mask_dense_with_existing_nans() -> None: + """Dense: existing NaN entries are not selected as probes.""" + rng = np.random.default_rng(1) + x = rng.standard_normal((6, 10)) + x[0, 0] = np.nan + x[2, 3] = np.nan + + x_masked, xprobe = make_xprobe_mask(x, fraction=0.1, rng=np.random.default_rng(1)) + + # Original NaN positions remain NaN in both + assert np.isnan(x_masked[0, 0]) + assert np.isnan(xprobe[0, 0]) + + +def test_make_xprobe_mask_sparse() -> None: + """Sparse CSR: probe entries are removed from data.""" + rng = np.random.default_rng(2) + dense = rng.standard_normal((6, 10)) + dense[dense < 0.3] = 0.0 + x = sp.csr_matrix(dense) + nnz_before = x.nnz + + x_masked, xprobe = make_xprobe_mask(x, fraction=0.2, rng=np.random.default_rng(2)) + + assert sp.issparse(x_masked) + assert sp.issparse(xprobe) + assert xprobe.nnz > 0 + assert x_masked.nnz < nnz_before + assert x_masked.nnz + xprobe.nnz == nnz_before + + +def test_make_xprobe_mask_bad_fraction_raises() -> None: + """Fraction outside (0, 1) should raise ValueError.""" + x = np.ones((3, 3)) + with pytest.raises(ValueError, match="fraction must be in"): + make_xprobe_mask(x, fraction=0.0) + with pytest.raises(ValueError, match="fraction must be in"): + make_xprobe_mask(x, fraction=1.0) + with pytest.raises(ValueError, match="fraction must be in"): + make_xprobe_mask(x, fraction=-0.1)