Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/vbpca_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from importlib.metadata import version

from vbpca_py._missing import make_xprobe_mask
from vbpca_py.estimators import VBPCA
from vbpca_py.model_selection import (
CVConfig,
Expand Down Expand Up @@ -37,6 +38,7 @@
"SelectionConfig",
"check_data",
"cross_validate_components",
"make_xprobe_mask",
"select_n_components",
]

Expand Down
98 changes: 98 additions & 0 deletions src/vbpca_py/_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging

import numpy as np
import scipy.sparse as sp

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,3 +105,100 @@ def _missing_patterns(
)

return n_patterns, pattern_columns, column_pattern_index


def _xprobe_sparse(
x: sp.csr_matrix,
fraction: float,
rng: np.random.Generator,
) -> tuple[sp.csr_matrix, sp.csr_matrix]:
"""Hold out probe entries from a sparse matrix.

Returns:
Tuple of (x_masked, xprobe) as CSR matrices.
"""
x_csr = sp.csr_matrix(x, copy=True)
n_probe = max(1, round(x_csr.nnz * fraction))
probe_idx = rng.choice(x_csr.nnz, size=n_probe, replace=False)

rows, cols = x_csr.nonzero()
sp_rows = rows[probe_idx]
sp_cols = cols[probe_idx]
sp_vals = np.array(x_csr[sp_rows, sp_cols]).ravel()

xprobe_sp = sp.lil_matrix(x_csr.shape, dtype=float)
for r, c, v in zip(sp_rows, sp_cols, sp_vals, strict=True):
xprobe_sp[r, c] = v
xprobe = sp.csr_matrix(xprobe_sp)

for r, c in zip(sp_rows, sp_cols, strict=True):
x_csr[r, c] = 0.0
x_csr.eliminate_zeros()
x_csr.sort_indices()
return x_csr, xprobe


def _xprobe_dense(
x: np.ndarray,
fraction: float,
rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
"""Hold out probe entries from a dense matrix.

Returns:
Tuple of (x_masked, xprobe) as dense arrays.
"""
x_dense = np.array(x, dtype=float, copy=True)
obs_rows, obs_cols = np.nonzero(~np.isnan(x_dense))
n_probe = max(1, round(len(obs_rows) * fraction))
probe_idx = rng.choice(len(obs_rows), size=n_probe, replace=False)

probe_rows = obs_rows[probe_idx]
probe_cols = obs_cols[probe_idx]

xprobe = np.full(x_dense.shape, np.nan, dtype=float)
xprobe[probe_rows, probe_cols] = x_dense[probe_rows, probe_cols]
x_dense[probe_rows, probe_cols] = np.nan
return x_dense, xprobe


def make_xprobe_mask(
x: np.ndarray | sp.csr_matrix,
fraction: float = 0.10,
rng: np.random.Generator | None = None,
) -> tuple[np.ndarray | sp.csr_matrix, np.ndarray | sp.csr_matrix]:
"""Hold out a fraction of observed entries as probe data.

Selects a random subset of observed entries (non-NaN for dense,
structurally non-zero for sparse) and returns a modified data matrix
with those entries masked out plus a probe matrix containing only the
held-out values.

Args:
x: Data matrix of shape ``(n_features, n_samples)``.
Dense or sparse (CSR).
fraction: Fraction of observed entries to hold out (default 0.10).
Must be in ``(0, 1)``.
rng: NumPy random generator. If ``None``, a new default generator
is created.

Returns:
x_masked: Copy of *x* with probe entries set to NaN (dense) or
removed (sparse).
xprobe: Matrix of the same shape as *x* containing only the
held-out probe values; all other entries are NaN (dense) or
absent (sparse).

Raises:
ValueError: If *fraction* is not in ``(0, 1)``.
"""
if not 0.0 < fraction < 1.0:
msg = f"fraction must be in (0, 1), got {fraction}"
raise ValueError(msg)

if rng is None:
rng = np.random.default_rng()

if sp.issparse(x):
return _xprobe_sparse(sp.csr_matrix(x), fraction, rng)
return _xprobe_dense(np.asarray(x), fraction, rng)
10 changes: 10 additions & 0 deletions src/vbpca_py/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
format_bytes,
resolve_max_dense_bytes,
)
from vbpca_py._missing import make_xprobe_mask
from vbpca_py._pca_full import Matrix, _build_options, pca_full
from vbpca_py.model_selection import SelectionConfig, select_n_components

Expand All @@ -34,6 +35,7 @@ def __init__( # noqa: PLR0913
hp_v: float | None = None,
niter_broadprior: int | None = None,
va_init: float | None = None,
xprobe_fraction: float = 0.0,
**opts: object,
) -> None:
"""
Expand All @@ -52,6 +54,10 @@ def __init__( # noqa: PLR0913
prior before convergence checks activate (default 100).
va_init: Initial broad prior value for loadings and bias
variances (default 1000).
xprobe_fraction: Fraction of observed entries to hold out as
probe data (default 0.0, disabled). When positive and no
explicit *xprobe* is passed to :meth:`fit`, a random probe
set is generated automatically.
**opts: Additional options passed to the underlying PCA_FULL implementation.
"""
self.n_components = n_components
Expand All @@ -64,6 +70,7 @@ def __init__( # noqa: PLR0913
self.hp_v = hp_v
self.niter_broadprior = niter_broadprior
self.va_init = va_init
self.xprobe_fraction = xprobe_fraction
self.opts = opts
self.components_: np.ndarray | None = None
self.scores_: np.ndarray | None = None
Expand Down Expand Up @@ -126,6 +133,9 @@ def fit( # noqa: C901, PLR0912, PLR0914, PLR0915
opts.update(self.opts)
if xprobe is not None:
opts["xprobe"] = xprobe
elif self.xprobe_fraction > 0.0:
x, xprobe_gen = make_xprobe_mask(x, fraction=self.xprobe_fraction)
opts["xprobe"] = xprobe_gen

max_dense_bytes = resolve_max_dense_bytes(
opts.get("max_dense_bytes", 2_000_000_000)
Expand Down
47 changes: 47 additions & 0 deletions tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,50 @@ def test_va_init_affects_initial_prior() -> None:
rms_default = r_default["lc"]["rms"]
rms_custom = r_custom["lc"]["rms"]
assert rms_default != rms_custom


# ---------------------------------------------------------------------------
# xprobe_fraction parameter
# ---------------------------------------------------------------------------


def test_xprobe_fraction_default_is_zero() -> None:
"""Default xprobe_fraction is 0.0 (disabled)."""
model = VBPCA(n_components=2)
assert model.xprobe_fraction == pytest.approx(0.0)


def test_xprobe_fraction_generates_probe() -> None:
"""When xprobe_fraction > 0, fit() auto-generates a probe set."""
rng = np.random.default_rng(42)
x = rng.standard_normal((8, 20))
model = VBPCA(n_components=2, maxiters=5, xprobe_fraction=0.10)
model.fit(x)
# prms should be a real number (not NaN) when probe is active
assert model.prms_ is not None
assert np.isfinite(model.prms_)


def test_xprobe_fraction_explicit_xprobe_takes_precedence() -> None:
"""An explicit xprobe passed to fit() overrides xprobe_fraction."""
rng = np.random.default_rng(42)
x = rng.standard_normal((8, 20))
xprobe = np.full(x.shape, np.nan, dtype=float)
xprobe[0, 0] = x[0, 0]

model = VBPCA(n_components=2, maxiters=5, xprobe_fraction=0.10)
model.fit(x, xprobe=xprobe)
# Should still produce a finite prms from the explicit probe
assert model.prms_ is not None
assert np.isfinite(model.prms_)


def test_xprobe_fraction_no_probe_when_zero() -> None:
"""With xprobe_fraction=0.0 and no xprobe, prms should be NaN."""
rng = np.random.default_rng(42)
x = rng.standard_normal((8, 20))
model = VBPCA(n_components=2, maxiters=5, xprobe_fraction=0.0)
model.fit(x)
# No probe -> prms is NaN
assert model.prms_ is not None
assert np.isnan(model.prms_)
71 changes: 70 additions & 1 deletion tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import numpy as np
import pytest
import scipy.sparse as sp

import vbpca_py._missing as missing_mod
from vbpca_py._missing import _missing_patterns
from vbpca_py._missing import _missing_patterns, make_xprobe_mask

# ---------------------------------------------------------------------------
# Happy paths
Expand Down Expand Up @@ -157,3 +158,71 @@ def test_missing_patterns_bad_dim_raises(bad_mask: np.ndarray) -> None:
"""_missing_patterns should raise when mask is not 2D."""
with pytest.raises(ValueError, match=r"mask must be a 2D array."):
_missing_patterns(bad_mask)


# ---------------------------------------------------------------------------
# make_xprobe_mask
# ---------------------------------------------------------------------------


def test_make_xprobe_mask_dense_basic() -> None:
"""Dense: probe entries are held out and data is masked."""
rng = np.random.default_rng(0)
x = rng.standard_normal((6, 10))
x_masked, xprobe = make_xprobe_mask(x, fraction=0.2, rng=np.random.default_rng(0))

# xprobe has some non-NaN entries
probe_obs = ~np.isnan(xprobe)
assert probe_obs.any()

# Those entries are NaN in x_masked
assert np.all(np.isnan(x_masked[probe_obs]))

# Non-probe entries are unchanged
non_probe = ~probe_obs
np.testing.assert_array_equal(x_masked[non_probe], x[non_probe])

# Probe values match original
np.testing.assert_array_equal(xprobe[probe_obs], x[probe_obs])


def test_make_xprobe_mask_dense_with_existing_nans() -> None:
"""Dense: existing NaN entries are not selected as probes."""
rng = np.random.default_rng(1)
x = rng.standard_normal((6, 10))
x[0, 0] = np.nan
x[2, 3] = np.nan

x_masked, xprobe = make_xprobe_mask(x, fraction=0.1, rng=np.random.default_rng(1))

# Original NaN positions remain NaN in both
assert np.isnan(x_masked[0, 0])
assert np.isnan(xprobe[0, 0])


def test_make_xprobe_mask_sparse() -> None:
"""Sparse CSR: probe entries are removed from data."""
rng = np.random.default_rng(2)
dense = rng.standard_normal((6, 10))
dense[dense < 0.3] = 0.0
x = sp.csr_matrix(dense)
nnz_before = x.nnz

x_masked, xprobe = make_xprobe_mask(x, fraction=0.2, rng=np.random.default_rng(2))

assert sp.issparse(x_masked)
assert sp.issparse(xprobe)
assert xprobe.nnz > 0
assert x_masked.nnz < nnz_before
assert x_masked.nnz + xprobe.nnz == nnz_before


def test_make_xprobe_mask_bad_fraction_raises() -> None:
"""Fraction outside (0, 1) should raise ValueError."""
x = np.ones((3, 3))
with pytest.raises(ValueError, match="fraction must be in"):
make_xprobe_mask(x, fraction=0.0)
with pytest.raises(ValueError, match="fraction must be in"):
make_xprobe_mask(x, fraction=1.0)
with pytest.raises(ValueError, match="fraction must be in"):
make_xprobe_mask(x, fraction=-0.1)
Loading