Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ format:

# Type-check library code (strict mode).
typecheck:
uv run mypy --strict src
uv run --extra dev --extra plot mypy --strict src

# Run the test suite quietly.
test:
Expand Down
276 changes: 245 additions & 31 deletions src/vbpca_py/_converge.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,64 @@ def _plateau_stop(
return None


def _relative_elbo_stop(
cost: np.ndarray,
threshold: float | None,
) -> str | None:
"""Return a message if relative ELBO decrease is below *threshold*.

Checks ``|ELBO[t] - ELBO[t-1]| / |ELBO[t]| < threshold``. This is
scale-invariant and more robust than an absolute plateau check.

Returns:
A human-readable stop message, or ``None``.
"""
if threshold is None or threshold <= 0 or cost.size < 2:
return None

curr, prev = cost[-1], cost[-2]
if not (np.isfinite(curr) and np.isfinite(prev)):
return None

rel_change = abs(curr - prev) / (abs(curr) + np.finfo(float).eps)
if rel_change < threshold:
return (
f"Stop: relative ELBO change {rel_change:.3e} "
f"is below cfstop_rel = {threshold:.3e}."
)
return None


def _elbo_curvature_stop(
cost: np.ndarray,
threshold: float | None,
) -> str | None:
"""Return a message if ELBO curvature (2nd difference) is below *threshold*.

Checks ``|ΔELBO[t] - ΔELBO[t-1]| < threshold``, i.e. whether the
*rate of improvement* has itself stabilised.

Returns:
A human-readable stop message, or ``None``.
"""
if threshold is None or threshold <= 0 or cost.size < 3:
return None

d1 = cost[-1] - cost[-2]
d0 = cost[-2] - cost[-3]

if not (np.isfinite(d1) and np.isfinite(d0)):
return None

curvature = abs(d1 - d0)
if curvature < threshold:
return (
f"Stop: ELBO curvature {curvature:.3e} "
f"is below cfstop_curv = {threshold:.3e}."
)
return None


def _slowing_down_message(sd_iter: int | None) -> str | None:
"""Return a slowing-down message if sd_iter hits the threshold."""
if sd_iter is not None and sd_iter == 40:
Expand All @@ -174,6 +232,160 @@ def _slowing_down_message(sd_iter: int | None) -> str | None:
return None


def _cost_criteria(
opts: Mapping[str, Any],
cost: np.ndarray,
) -> str | None:
"""Evaluate all cost/ELBO-based stopping criteria in priority order.

Returns:
The first triggered message, or ``None``.
"""
# Cost plateau
cfstop = opts.get("cfstop")
if cost.size >= 2 and cfstop is not None:
plateau_msg = _plateau_stop(cost, cfstop, "cost")
if plateau_msg:
return plateau_msg

# Relative ELBO decrease
cfstop_rel = opts.get("cfstop_rel")
if cfstop_rel is not None:
rel_msg = _relative_elbo_stop(cost, float(cfstop_rel))
if rel_msg:
return rel_msg

# ELBO curvature (2nd difference)
cfstop_curv = opts.get("cfstop_curv")
if cfstop_curv is not None:
curv_msg = _elbo_curvature_stop(cost, float(cfstop_curv))
if curv_msg:
return curv_msg

return None


def _check_sub_criterion(
key: str,
threshold: float,
angle_a: float,
rms: np.ndarray,
cost: np.ndarray,
) -> str | None:
"""Evaluate a single composite sub-criterion.

Returns:
A short summary string like ``"angle=1.2e-04<1.0e-03"`` when the
criterion is satisfied, or ``None`` when it is not met.

Raises:
ValueError: If *key* is not a recognised sub-criterion name.
"""
eps = np.finfo(float).eps

if key == "angle":
if np.isfinite(angle_a) and angle_a < threshold:
return f"angle={angle_a:.2e}<{threshold:.2e}"
return None

if key == "rms":
return _rel_change_check("rms_rel", rms, threshold, eps)

if key == "elbo_rel":
return _rel_change_check("elbo_rel", cost, threshold, eps)

msg = f"Unknown composite_stop key: {key!r}"
raise ValueError(msg)


def _rel_change_check(
label: str,
series: np.ndarray,
threshold: float,
eps: float,
) -> str | None:
"""Check relative change between last two elements of *series*.

Returns:
A summary string when change is below *threshold*, else ``None``.
"""
if series.size < 2:
return None
curr, prev = series[-1], series[-2]
if not (np.isfinite(curr) and np.isfinite(prev)):
return None
rel = abs(curr - prev) / (abs(curr) + eps)
if rel >= threshold:
return None
return f"{label}={rel:.2e}<{threshold:.2e}"


def _composite_stop(
composite_cfg: Mapping[str, float],
angle_a: float,
rms: np.ndarray,
cost: np.ndarray,
) -> str | None:
"""Check whether **all** sub-criteria in *composite_cfg* are satisfied.

Supported keys (all optional, but at least one must be present):

- ``"angle"``: subspace angle must be below this value.
- ``"rms"``: relative RMS change over the last two iterations
must be below this value.
- ``"elbo_rel"``: relative ELBO change must be below this value.

Returns:
A stop message listing which sub-criteria were satisfied, or
``None`` if any sub-criterion is **not** met.
"""
satisfied: list[str] = []
for key, threshold in composite_cfg.items():
result = _check_sub_criterion(key, threshold, angle_a, rms, cost)
if result is None:
return None
satisfied.append(result)

if not satisfied:
return None

detail = ", ".join(satisfied)
return f"Composite stop: all criteria met ({detail})."


def _apply_patience(
msg: str,
lc: Mapping[str, Sequence[float]],
patience: int,
) -> str:
"""Gate *msg* through a patience counter stored in ``lc["_patience"]``.

When a criterion fires (``msg`` is non-empty), the counter is
incremented. The message is only returned once the counter reaches
*patience*. If no criterion fires, the counter resets to zero.

Args:
msg: The candidate convergence message (may be empty).
lc: Learning-curve dict; ``lc["_patience"]`` is mutated in-place.
patience: Required number of consecutive satisfied iterations.

Returns:
The convergence message when patience is exhausted, otherwise
an empty string.
"""
# Obtain the mutable patience list from lc.
patience_list: list[float] = lc.get("_patience", [0]) # type: ignore[assignment]

if msg:
patience_list[0] = float(patience_list[0]) + 1
if int(patience_list[0]) >= patience:
return msg
return ""

patience_list[0] = 0.0
return ""


# ---------------------------------------------------------------------------
# Public convergence check
# ---------------------------------------------------------------------------
Expand All @@ -193,47 +405,49 @@ def convergence_check(
1. Subspace-angle stop (``minangle``).
2. Early stopping based on probe RMS (``earlystop``).
3. RMS plateau stop (``rmsstop = [window, abs_tol, rel_tol]``).
4. Cost plateau stop (``cfstop = [window, abs_tol, rel_tol]``).
5. “Slowing-down'' stop based on ``sd_iter`` (gradient backtracking).
4. Cost / ELBO criteria (``cfstop``, ``cfstop_rel``, ``cfstop_curv``).
5. Composite stop (``composite_stop``).
6. "Slowing-down'' stop based on ``sd_iter`` (gradient backtracking).

When ``patience`` is set (> 1), the winning criterion must fire for
that many **consecutive** iterations before the message is returned.

Returns:
A non-empty convergence message when a criterion triggers,
otherwise an empty string.
"""
# 1. Angle-based stop
angle_msg = _angle_stop_message(opts, angle_a)
if angle_msg:
return angle_msg

rms = np.asarray(lc.get("rms", []), dtype=float)
prms = np.asarray(lc.get("prms", []), dtype=float)
cost = np.asarray(lc.get("cost", []), dtype=float)

# 2. Early stopping on probe RMS
early_msg = _early_stop_message(opts, prms)
if early_msg:
return early_msg

# 3. RMS plateau
rmsstop = opts.get("rmsstop")
if rms.size >= 2 and rmsstop is not None:
plateau_msg = _plateau_stop(rms, rmsstop, "RMS")
if plateau_msg:
return plateau_msg

# 4. Cost plateau
cfstop = opts.get("cfstop")
if cost.size >= 2 and cfstop is not None:
plateau_msg = _plateau_stop(cost, cfstop, "cost")
if plateau_msg:
return plateau_msg

# 5. Slowing-down criterion
slow_msg = _slowing_down_message(sd_iter)
if slow_msg:
return slow_msg

return ""
composite_cfg = opts.get("composite_stop")

# Evaluate criteria in priority order; return first trigger.
checks: list[str | None] = [
# 1. Angle-based stop
_angle_stop_message(opts, angle_a),
# 2. Early stopping on probe RMS
_early_stop_message(opts, prms),
# 3. RMS plateau
_plateau_stop(rms, rmsstop, "RMS")
if rms.size >= 2 and rmsstop is not None
else None,
# 4. Cost / ELBO criteria
_cost_criteria(opts, cost),
# 5. Composite stop
_composite_stop(composite_cfg, angle_a, rms, cost) if composite_cfg else None,
# 6. Slowing-down criterion
_slowing_down_message(sd_iter),
]
candidate = next((msg for msg in checks if msg), "")

# Apply patience window if configured.
patience_val = opts.get("patience")
patience = int(patience_val) if patience_val is not None else 1
if patience <= 1:
return candidate
return _apply_patience(candidate, lc, patience)


# ---------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions src/vbpca_py/_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ def _initial_monitoring(
"time": [0.0],
"cost": [float("nan")],
"angle": [float("nan")],
"_patience": [0.0],
"phase_scores_sec": [0.0],
"phase_loadings_sec": [0.0],
"phase_rms_sec": [0.0],
Expand Down
4 changes: 4 additions & 0 deletions src/vbpca_py/_pca_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,10 @@ def _build_options(kwargs: Mapping[str, object]) -> dict[str, object]:
"earlystop": False,
"rmsstop": np.array([100, 1e-4, 1e-3]),
"cfstop": np.array([]),
"cfstop_rel": None,
"cfstop_curv": None,
"composite_stop": None,
"patience": 1,
"verbose": 1,
"num_cpu": None,
"num_cpu_score_update": None,
Expand Down
Loading
Loading