From ed1c901791933c68091f95be516994d86ec71911 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 19:45:21 +0000 Subject: [PATCH 1/5] sae: shared interpretability primitives (probing metrics + steering hook) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Model-agnostic SAE eval/intervention primitives, factored out of the evo2 eval (#1624) and steering (#1626) PRs so any recipe (evo2/codonfm/esm2) can reuse them and the evo2-specific harnesses stack on top: - sae.eval.probing: ActivationBuffer + scoring lenses (AUROC, linear/softmax decode, instance-level domain_f1) — pure functions of codes + labels. - sae.steering: delta-clamp forward hook + steer() context manager. CPU-testable, no torch-CUDA / no model. Tests: test_probing, test_steering (7). Signed-off-by: Polina Binder --- .../sae/src/sae/eval/__init__.py | 24 ++ .../sae/src/sae/eval/probing.py | 241 ++++++++++++++++++ .../sae/src/sae/steering.py | 77 ++++++ .../sae/tests/test_probing.py | 67 +++++ .../sae/tests/test_steering.py | 68 +++++ 5 files changed, 477 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py index 1039045c8a..80af59e5cd 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py @@ -22,6 +22,19 @@ compute_loss_recovered, evaluate_loss_recovered, ) +from .probing import ( + ActivationBuffer, + auroc_all, + auroc_vec, + best_single_train_test, + decode_eval, + domain_f1, + fit_logreg, + fit_softmax, + macro_auroc, + split_indices, + standardize, +) from .reconstruction import ( ReconstructionMetrics, compute_reconstruction_metrics, @@ -31,16 +44,27 @@ __all__ = [ + "ActivationBuffer", "DeadLatentStats", "DeadLatentTracker", "EvalResults", "LossRecoveredResult", "ReconstructionMetrics", "SparsityMetrics", + "auroc_all", + "auroc_vec", + "best_single_train_test", "compute_loss_recovered", "compute_reconstruction_metrics", + "decode_eval", + "domain_f1", "evaluate_loss_recovered", "evaluate_reconstruction", "evaluate_sae", "evaluate_sparsity", + "fit_logreg", + "fit_softmax", + "macro_auroc", + "split_indices", + "standardize", ] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py new file mode 100644 index 0000000000..ab34c381dc --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-agnostic SAE feature-probing metrics + the activation-buffer artifact. + +Everything here is a pure function of a probing buffer (per-token feature codes, +an optional dense-residual twin, per-token labels, optional instance IDs). Recipe +drivers (e.g. Evo2) only produce the buffer; all scoring lives here so it is shared +and reusable. Companions in this package: loss_recovered (fidelity), reconstruction, +sparsity, dead_latents. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional + +import numpy as np +import torch + + +# ───────────────────────────────────────────────────────────── artifact +@dataclass +class ActivationBuffer: + """A probing buffer: SAE codes (+ optional dense twin), per-token labels, instance IDs.""" + + codes: np.ndarray # [N, F] float16 SAE feature activations + labels: np.ndarray # [N, L] bool + label_names: list + dense: Optional[np.ndarray] = None # [N, H] float16 raw layer residual (dense twin) + instances: Optional[Dict[str, np.ndarray]] = None # {concept: [N] int32, -1 outside} + + def save(self, path: str) -> None: + """Write codes, labels, names (+ optional dense twin / instance ids) to an .npz.""" + d = {"codes": self.codes, "labels": self.labels, "label_names": np.array(self.label_names)} + if self.dense is not None: + d["dense"] = self.dense + for k, v in (self.instances or {}).items(): + d[f"inst_{k}"] = v + np.savez(path, **d) + + @classmethod + def load(cls, path: str) -> "ActivationBuffer": + """Load an ActivationBuffer from an .npz written by save().""" + z = np.load(path, allow_pickle=True) + inst = {k[5:]: z[k] for k in z.files if k.startswith("inst_")} + return cls( + codes=z["codes"], + labels=z["labels"], + label_names=list(z["label_names"]), + dense=z["dense"] if "dense" in z.files else None, + instances=inst or None, + ) + + @property + def name_idx(self): + """Map each label name to its column index in ``labels``.""" + return {n: i for i, n in enumerate(self.label_names)} + + +def split_indices(n, test_frac=0.4, seed=0): + """Deterministic train/test split of ``range(n)``; returns (train_idx, test_idx).""" + perm = torch.randperm(n, generator=torch.Generator().manual_seed(seed)) + nte = int(n * test_frac) + return perm[nte:], perm[:nte] # train, test + + +def standardize(X, tr): + """Return (mean, std) of ``X`` over the train rows ``tr`` (std floored by 1e-6).""" + mu, sd = X[tr].mean(0), X[tr].std(0) + 1e-6 + return mu, sd + + +# ───────────────────────────────────────────────────────────── AUROC +@torch.no_grad() +def auroc_all(X, Y, chunk=1024): + """X [N,F], Y [N,L] bool -> AUROC [F,L] via vectorized rank statistic.""" + N, F = X.shape + L = Y.shape[1] + y = Y.float() + npos = y.sum(0) + nneg = N - npos + valid = (npos > 0) & (nneg > 0) + denom = (npos * nneg).clamp_min(1.0) + half = npos * (npos + 1) / 2.0 + out = torch.full((F, L), 0.5, device=X.device) + for c0 in range(0, F, chunk): + c1 = min(c0 + chunk, F) + ranks = X[:, c0:c1].float().argsort(0).argsort(0).float() + 1.0 + au = (y.t() @ ranks - half[:, None]) / denom[:, None] + out[c0:c1] = au.t() + out[:, ~valid] = 0.5 + return out + + +@torch.no_grad() +def auroc_vec(scores, y): + """AUROC of a single score vector against boolean labels ``y`` (0.5 if degenerate).""" + n = scores.numel() + npos = int(y.sum()) + nneg = n - npos + if npos == 0 or nneg == 0: + return 0.5 + ranks = scores.argsort().argsort().float() + 1.0 + return float((ranks[y].sum() - npos * (npos + 1) / 2) / (npos * nneg)) + + +@torch.no_grad() +def best_single_train_test(Xtr, ytr, Xte, yte, chunk=2048): + """Pick the best single dim on TRAIN, report ITS AUROC on TEST (no winner's curse).""" + + def per_feat(X, y): + n = X.shape[0] + npos = int(y.sum()) + nneg = n - npos + if npos == 0 or nneg == 0: + return None + yf = y.float() + F = X.shape[1] + out = torch.empty(F, device=X.device) + for c0 in range(0, F, chunk): + ranks = X[:, c0 : c0 + chunk].float().argsort(0).argsort(0).float() + 1.0 + out[c0 : c0 + chunk] = (yf @ ranks - npos * (npos + 1) / 2) / (npos * nneg) + return out + + a_tr = per_feat(Xtr, ytr) + if a_tr is None: + return float("nan") + f = int(torch.maximum(a_tr, 1 - a_tr).argmax()) + flip = bool(a_tr[f] < 0.5) + a_te = auroc_vec(Xte[:, f].float(), yte) + return float(1 - a_te if flip else a_te) + + +# ───────────────────────────────────────────────────────────── linear probes +def fit_logreg(Xtr, ytr, steps=400, lr=0.05, wd=1e-2): + """Fit a logistic-regression probe (Adam + BCE-with-logits); returns (w, b).""" + w = torch.zeros(Xtr.shape[1], device=Xtr.device, requires_grad=True) + b = torch.zeros(1, device=Xtr.device, requires_grad=True) + opt = torch.optim.Adam([w, b], lr=lr, weight_decay=wd) + lossf = torch.nn.BCEWithLogitsLoss() + with torch.enable_grad(): + for _ in range(steps): + opt.zero_grad() + lossf(Xtr @ w + b, ytr).backward() + opt.step() + return w.detach(), b.detach() + + +def fit_softmax(Xtr, ytr, nclass, steps=400, lr=0.05, wd=1e-2): + """Fit a multinomial-softmax probe (Adam + cross-entropy); returns (W, b).""" + W = torch.zeros(Xtr.shape[1], nclass, device=Xtr.device, requires_grad=True) + b = torch.zeros(nclass, device=Xtr.device, requires_grad=True) + opt = torch.optim.Adam([W, b], lr=lr, weight_decay=wd) + lossf = torch.nn.CrossEntropyLoss() + with torch.enable_grad(): + for _ in range(steps): + opt.zero_grad() + lossf(Xtr @ W + b, ytr).backward() + opt.step() + return W.detach(), b.detach() + + +@torch.no_grad() +def macro_auroc(logits, y, nclass): + """Macro-averaged one-vs-rest AUROC over ``nclass``; returns (mean_auroc, n_classes_scored).""" + aucs = [] + for c in range(nclass): + yc = y == c + npos = int(yc.sum()) + if npos == 0 or npos == len(y): + continue + ranks = logits[:, c].argsort().argsort().float() + 1.0 + aucs.append(float((ranks[yc].sum() - npos * (npos + 1) / 2) / (npos * (len(y) - npos)))) + return (sum(aucs) / max(1, len(aucs))), len(aucs) + + +def decode_eval(Xtr, ytr, Xte, yte, nclass, **kw): + """Fit a softmax probe on train; return (accuracy, macro_auroc, n_classes) on test.""" + W, b = fit_softmax(Xtr, ytr, nclass, **kw) + logits = Xte @ W + b + acc = float((logits.argmax(1) == yte).float().mean()) + mauc, ncls = macro_auroc(logits, yte, nclass) + return acc, mauc, ncls + + +# ───────────────────────────────────────────────────────────── domain-adjusted F1 +@torch.no_grad() +def domain_f1(codes, fmax, concept_mask, inst_ids, thresholds=(0.15, 0.3, 0.5, 0.6, 0.8), chunk=1024): + """InterPLM domain-adjusted F1 per feature: precision-per-position, recall-per-instance. + + codes [P,F] (>=0), fmax [F], concept_mask [P] bool, inst_ids [P] int (-1 outside). + Returns (best_f1[F], best_threshold[F]) over the threshold sweep. + """ + _, F = codes.shape + dev = codes.device + valid = inst_ids >= 0 + uniq = torch.unique(inst_ids[valid]) + n_inst = len(uniq) + if n_inst == 0: + return torch.zeros(F, device=dev), torch.zeros(F, device=dev) + remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long) + remap[uniq.long()] = torch.arange(n_inst, device=dev) + inst_c = torch.where(valid, remap[inst_ids.long()], torch.full_like(inst_ids, -1, dtype=torch.long)) + best_f1 = torch.zeros(F, device=dev) + best_t = torch.zeros(F, device=dev) + for c0 in range(0, F, chunk): + c1 = min(c0 + chunk, F) + cn = codes[:, c0:c1] / fmax[c0:c1].clamp_min(1e-6) + C = c1 - c0 + cb = torch.zeros(C, device=dev) + ct = torch.zeros(C, device=dev) + for t in thresholds: + fire = cn > t + firing = fire.sum(0).float() + prec = torch.where( + firing > 0, (fire & concept_mask[:, None]).sum(0).float() / firing, torch.zeros(C, device=dev) + ) + bucket = torch.zeros(n_inst, C, device=dev) + vm = inst_c >= 0 + bucket.index_reduce_(0, inst_c[vm], fire[vm].float(), "amax", include_self=False) + recall = (bucket > 0).sum(0).float() / n_inst + f1 = torch.where((prec + recall) > 0, 2 * prec * recall / (prec + recall), torch.zeros(C, device=dev)) + upd = f1 > cb + cb = torch.where(upd, f1, cb) + ct = torch.where(upd, torch.full_like(ct, t), ct) + best_f1[c0:c1] = cb + best_t[c0:c1] = ct + return best_f1, best_t diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py new file mode 100644 index 0000000000..c061e38533 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta. + +A forward hook on the layer the SAE was trained on: it re-encodes the layer output through +the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the +activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the +SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves +the activation. Model-agnostic: needs only the SAE (``encode_pre_act`` / ``decode`` / ``top_k``) +and the module to hook. Measure the effect (e.g. ΔP of a target token) by running the model +with vs. without the hook. +""" + +from contextlib import contextmanager +from typing import Dict + +import torch + + +def clamp_hook(sae, clamps: Dict[int, float]): + """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method. + + The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's + output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative + value reverses its decoder direction. Works whether the module returns a tensor or a tuple + whose first element is the hidden state. + + Args: + sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``, + and ``top_k``. + clamps: Map of feature index -> absolute code value to force at every position. + + Returns: + A ``register_forward_hook``-compatible ``hook(module, inputs, output)``. + """ + items = [(int(f), float(v)) for f, v in clamps.items()] + + def hook(module, inputs, output): + h, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None) + dtype, shape = h.dtype, h.shape + h_flat = h.reshape(-1, h.shape[-1]).float() + with torch.no_grad(): + pre_act, info = sae.encode_pre_act(h_flat) + codes = torch.relu(pre_act) + kvals, kidx = torch.topk(codes, sae.top_k, dim=-1) + codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals) + codes_clamped = codes_orig.clone() + for f, v in items: + codes_clamped[:, f] = v + delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info) + h_out = (h_flat + delta).to(dtype).reshape(shape) + return (h_out, *rest) if rest is not None else h_out + + return hook + + +@contextmanager +def steer(module, sae, clamps: Dict[int, float]): + """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.""" + handle = module.register_forward_hook(clamp_hook(sae, clamps)) + try: + yield + finally: + handle.remove() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py new file mode 100644 index 0000000000..6a83113d72 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for sae.eval.probing: the metrics reproduce on synthetic data (no model/GPU).""" + +import numpy as np +import torch +from sae.eval.probing import ActivationBuffer, auroc_all, split_indices, standardize + + +def test_auroc_separates_predictive_from_noise(): + """A feature equal to the label scores ~1.0 AUROC; a random feature scores near chance.""" + torch.manual_seed(0) + n = 400 + y = (torch.arange(n) % 2).float() + predictive = y + torch.randn(n) * 0.01 # near-perfect detector + noise = torch.randn(n) # uninformative + x = torch.stack([predictive, noise], dim=1) # [N, 2 features] + au = auroc_all(x, y.unsqueeze(1)) # [2 features, 1 label] + assert float(au[0, 0]) > 0.99 + assert float(au[1, 0]) < 0.7 + + +def test_split_indices_disjoint_and_complete(): + """Train/test indices partition range(n) with the requested test fraction.""" + tr, te = split_indices(100, test_frac=0.4, seed=0) + tr_s, te_s = set(tr.tolist()), set(te.tolist()) + assert tr_s.isdisjoint(te_s) + assert tr_s | te_s == set(range(100)) + assert 0.35 < len(te_s) / 100 < 0.45 + + +def test_standardize_zero_means_train_split(): + """standardize() returns train-split mean/std that center the train rows.""" + torch.manual_seed(0) + x = torch.randn(100, 5) * 3 + 7 + tr = torch.arange(80) + mu, sd = standardize(x, tr) + z = (x[tr] - mu) / sd + assert torch.allclose(z.mean(0), torch.zeros(5), atol=1e-4) + + +def test_activation_buffer_roundtrip(tmp_path): + """ActivationBuffer save/load preserves codes, labels, names (+ name_idx mapping).""" + rng = np.random.default_rng(0) + codes = rng.random((10, 4)).astype(np.float16) + labels = np.tile(np.array([True, False]), (10, 1)) + buf = ActivationBuffer(codes=codes, labels=labels, label_names=["motif_atg", "is_prok"]) + path = str(tmp_path / "buf.npz") + buf.save(path) + + loaded = ActivationBuffer.load(path) + assert np.array_equal(loaded.codes, codes) + assert [str(n) for n in loaded.label_names] == ["motif_atg", "is_prok"] + assert loaded.name_idx["is_prok"] == 1 diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py new file mode 100644 index 0000000000..a3ac2c5d3e --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for sae.steering: the delta-clamp hook adds exactly decode(clamped) - decode(orig).""" + +import torch +from sae.architectures import TopKSAE +from sae.steering import clamp_hook, steer +from torch import nn + + +def _sae(): + torch.manual_seed(0) + return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False) + + +def test_no_clamp_is_a_noop(): + """An empty clamp map leaves the activation unchanged.""" + sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) + with steer(m, sae, {}): + out = m(x) + assert torch.allclose(out, x, atol=1e-5) + + +def test_clamp_adds_decoder_delta(): + """Clamping a feature shifts the activation by exactly decode(clamped) - decode(orig).""" + sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) + with torch.no_grad(): + pre, info = sae.encode_pre_act(x.float()) + codes = torch.relu(pre) + kv, ki = torch.topk(codes, sae.top_k, dim=-1) + co = torch.zeros_like(codes).scatter(-1, ki, kv) + cc = co.clone() + cc[:, 3] = 5.0 + expected = x + (sae.decode(cc, info) - sae.decode(co, info)) + with steer(m, sae, {3: 5.0}): + out = m(x) + assert torch.allclose(out, expected, atol=1e-4) + + +def test_tuple_output_first_element_steered_rest_preserved(): + """When the module returns a tuple, only the hidden state (elem 0) is steered.""" + + class M(nn.Module): + def forward(self, x): + return (x, "meta") + + sae, x = _sae(), torch.randn(3, 8) + m = M() + handle = m.register_forward_hook(clamp_hook(sae, {0: 2.0})) + out = m(x) + handle.remove() + assert isinstance(out, tuple) + assert out[1] == "meta" + assert out[0].shape == x.shape + assert not torch.allclose(out[0], x) # the clamp moved it From 2e13d19b367021f508c0312d8bf2298a5cc2a97a Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 21:08:35 +0000 Subject: [PATCH 2/5] sae tests: leaner, stronger probing/steering tests (oracle + reference) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace weak sanity-bound asserts with one strong correctness check per non-trivial metric, and drop trivia: - auroc_all: validated against the pairwise AUROC definition (P(s+>s-)), an oracle independent of the argsort rank-sum impl (no new dependency). - domain_f1, best_single_train_test, decode_eval: previously UNtested — now each has a hand-computed-reference / constructed-flip / separability test. - dropped the standalone standardize + weak auroc-sanity tests (trivia / subsumed); split_indices folded into the buffer roundtrip. - steering: merged the no-op identity into the exact-delta test (recon-cancellation), kept the tuple-output contract. 3 -> 2. Net 7 -> 7 CPU tests, every one now a real correctness check. Signed-off-by: Polina Binder --- .../sae/tests/test_probing.py | 133 +++++++++++++----- .../sae/tests/test_steering.py | 30 ++-- 2 files changed, 110 insertions(+), 53 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py index 6a83113d72..346a796fcd 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py @@ -13,55 +13,116 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""CPU tests for sae.eval.probing: the metrics reproduce on synthetic data (no model/GPU).""" +"""CPU correctness tests for sae.eval.probing (no model / no GPU). + +One strong test per non-trivial metric: each checks the result against an independent +reference (a definitional oracle or a hand-computed value) rather than a loose sanity bound. +The trivial standardize helper is exercised transitively (decode_eval test); split_indices +folds into the buffer roundtrip. +""" import numpy as np import torch -from sae.eval.probing import ActivationBuffer, auroc_all, split_indices, standardize +from sae.eval.probing import ( + ActivationBuffer, + auroc_all, + best_single_train_test, + decode_eval, + domain_f1, + split_indices, +) + + +def _auroc_ref(scores: torch.Tensor, y: torch.Tensor) -> float: + """Definitional AUROC oracle: P(score+ > score-) over all positive/negative pairs. + + Computed by brute-force pair comparison — independent of the argsort rank-sum used by + auroc_all, so agreement validates that implementation (randn inputs => no ties). + """ + pos, neg = scores[y], scores[~y] + return float((pos[:, None] > neg[None, :]).float().mean()) -def test_auroc_separates_predictive_from_noise(): - """A feature equal to the label scores ~1.0 AUROC; a random feature scores near chance.""" +def test_auroc_all_matches_definition(): + """auroc_all matches the pairwise-definition AUROC for every (feature, label).""" torch.manual_seed(0) - n = 400 - y = (torch.arange(n) % 2).float() - predictive = y + torch.randn(n) * 0.01 # near-perfect detector - noise = torch.randn(n) # uninformative - x = torch.stack([predictive, noise], dim=1) # [N, 2 features] - au = auroc_all(x, y.unsqueeze(1)) # [2 features, 1 label] - assert float(au[0, 0]) > 0.99 - assert float(au[1, 0]) < 0.7 - - -def test_split_indices_disjoint_and_complete(): - """Train/test indices partition range(n) with the requested test fraction.""" - tr, te = split_indices(100, test_frac=0.4, seed=0) - tr_s, te_s = set(tr.tolist()), set(te.tolist()) - assert tr_s.isdisjoint(te_s) - assert tr_s | te_s == set(range(100)) - assert 0.35 < len(te_s) / 100 < 0.45 + n, f, ell = 200, 6, 3 + x = torch.randn(n, f) + y = torch.randn(n, ell) > 0 + au = auroc_all(x, y) # [F, L] + for fi in range(f): + for li in range(ell): + assert abs(float(au[fi, li]) - _auroc_ref(x[:, fi], y[:, li])) < 1e-6 + + +def test_best_single_reports_flipped_test_auroc(): + """best_single picks the most-separating TRAIN feature and reports ITS test AUROC, + flipping a feature that separates by firing on the negatives (no winner's curse).""" + torch.manual_seed(0) + y = torch.cat([torch.zeros(10), torch.ones(10)]).bool() + # 'anti' fires on the y=0 class (train AUROC ~0 -> selected via 1-AUROC, flip=True); + # it stays anti-correlated on test, so the reported (flipped) test AUROC is ~1. + anti_tr = torch.cat([torch.ones(10), torch.zeros(10)]) + torch.randn(20) * 0.01 + anti_te = torch.cat([torch.ones(10), torch.zeros(10)]) + torch.randn(20) * 0.01 + xtr = torch.stack([anti_tr, torch.randn(20)], 1) # 2nd feature is noise + xte = torch.stack([anti_te, torch.randn(20)], 1) + assert best_single_train_test(xtr, y, xte, y.clone()) > 0.9 -def test_standardize_zero_means_train_split(): - """standardize() returns train-split mean/std that center the train rows.""" +def test_domain_f1_matches_hand_computed(): + """domain_f1 = precision-per-position, recall-per-instance, best over the threshold sweep. + + Two binary features over 6 positions, 2 annotation instances ({0,1} and {4}): + feat0 fires at an extra non-concept position -> prec 3/4, recall 2/2 -> F1 = 6/7 + feat1 fires exactly on concept positions -> prec 1, recall 2/2 -> F1 = 1 + """ + codes = torch.tensor([[1, 1], [1, 1], [1, 0], [0, 0], [1, 1], [0, 0]], dtype=torch.float) + fmax = codes.max(0).values + concept_mask = torch.tensor([1, 1, 0, 0, 1, 0], dtype=torch.bool) + inst_ids = torch.tensor([0, 0, -1, -1, 1, -1]) + f1, _ = domain_f1(codes, fmax, concept_mask, inst_ids) + assert abs(float(f1[0]) - 6 / 7) < 1e-4 + assert abs(float(f1[1]) - 1.0) < 1e-4 + + +def test_decode_eval_recovers_separable_classes(): + """The softmax decoder (fit_softmax + macro_auroc) separates separable classes and not noise.""" torch.manual_seed(0) - x = torch.randn(100, 5) * 3 + 7 - tr = torch.arange(80) - mu, sd = standardize(x, tr) - z = (x[tr] - mu) / sd - assert torch.allclose(z.mean(0), torch.zeros(5), atol=1e-4) + dim, nclass = 8, 3 + centers = torch.eye(nclass, dim) * 6.0 + + def make(per): + ys = torch.arange(nclass).repeat_interleave(per) + return centers[ys] + torch.randn(len(ys), dim), ys + xtr, ytr = make(40) + xte, yte = make(20) + acc, mauc, ncls = decode_eval(xtr, ytr, xte, yte, nclass, steps=400, lr=0.1) + assert acc > 0.9 and mauc > 0.9 and ncls == 3 -def test_activation_buffer_roundtrip(tmp_path): - """ActivationBuffer save/load preserves codes, labels, names (+ name_idx mapping).""" + # random features/labels -> no better than chance (1/3) + xr, yr = torch.randn(120, dim), torch.randint(0, nclass, (120,)) + acc_rand, _, _ = decode_eval(xr[:90], yr[:90], xr[90:], yr[90:], nclass, steps=400, lr=0.1) + assert acc_rand < 0.6 + + +def test_buffer_roundtrip_and_split(tmp_path): + """ActivationBuffer save/load preserves codes/labels/names/dense/instances; split is a partition.""" rng = np.random.default_rng(0) codes = rng.random((10, 4)).astype(np.float16) - labels = np.tile(np.array([True, False]), (10, 1)) - buf = ActivationBuffer(codes=codes, labels=labels, label_names=["motif_atg", "is_prok"]) + labels = np.tile(np.array([True, False, True]), (10, 1)) + dense = rng.random((10, 8)).astype(np.float16) + instances = {"exon": np.array([0, 0, -1, 1, 1, -1, 2, 2, 2, -1], np.int32)} + buf = ActivationBuffer(codes, labels, ["a", "b", "c"], dense=dense, instances=instances) path = str(tmp_path / "buf.npz") buf.save(path) - loaded = ActivationBuffer.load(path) - assert np.array_equal(loaded.codes, codes) - assert [str(n) for n in loaded.label_names] == ["motif_atg", "is_prok"] - assert loaded.name_idx["is_prok"] == 1 + lo = ActivationBuffer.load(path) + assert np.array_equal(lo.codes, codes) + assert np.array_equal(lo.dense, dense) + assert np.array_equal(lo.instances["exon"], instances["exon"]) + assert lo.name_idx["c"] == 2 + + tr, te = split_indices(100, test_frac=0.4, seed=0) + s_tr, s_te = set(tr.tolist()), set(te.tolist()) + assert s_tr.isdisjoint(s_te) and (s_tr | s_te) == set(range(100)) and len(s_te) == 40 diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py index a3ac2c5d3e..5ef9d15746 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""CPU tests for sae.steering: the delta-clamp hook adds exactly decode(clamped) - decode(orig).""" +"""CPU tests for sae.steering: the delta-clamp adds exactly decode(clamped) - decode(orig).""" import torch from sae.architectures import TopKSAE @@ -26,17 +26,16 @@ def _sae(): return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False) -def test_no_clamp_is_a_noop(): - """An empty clamp map leaves the activation unchanged.""" +def test_delta_clamp_is_exact_and_cancels_recon(): + """No-op clamp leaves the activation unchanged (recon error cancels); a real clamp shifts + it by exactly decode(clamped) - decode(orig) — the two halves of the delta-clamp contract.""" sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) - with steer(m, sae, {}): - out = m(x) - assert torch.allclose(out, x, atol=1e-5) + # No-op: decode(orig) != x, but the added delta is 0, so the output is unchanged. + with steer(m, sae, {}): + assert torch.allclose(m(x), x, atol=1e-5) -def test_clamp_adds_decoder_delta(): - """Clamping a feature shifts the activation by exactly decode(clamped) - decode(orig).""" - sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) + # Real clamp: output == x + (decode(clamped) - decode(orig)), recon error cancelled. with torch.no_grad(): pre, info = sae.encode_pre_act(x.float()) codes = torch.relu(pre) @@ -46,12 +45,11 @@ def test_clamp_adds_decoder_delta(): cc[:, 3] = 5.0 expected = x + (sae.decode(cc, info) - sae.decode(co, info)) with steer(m, sae, {3: 5.0}): - out = m(x) - assert torch.allclose(out, expected, atol=1e-4) + assert torch.allclose(m(x), expected, atol=1e-4) -def test_tuple_output_first_element_steered_rest_preserved(): - """When the module returns a tuple, only the hidden state (elem 0) is steered.""" +def test_tuple_output_steers_only_hidden_state(): + """When the hooked module returns a tuple, only element 0 is steered; the rest passes through.""" class M(nn.Module): def forward(self, x): @@ -62,7 +60,5 @@ def forward(self, x): handle = m.register_forward_hook(clamp_hook(sae, {0: 2.0})) out = m(x) handle.remove() - assert isinstance(out, tuple) - assert out[1] == "meta" - assert out[0].shape == x.shape - assert not torch.allclose(out[0], x) # the clamp moved it + assert isinstance(out, tuple) and out[1] == "meta" + assert out[0].shape == x.shape and not torch.allclose(out[0], x) # clamp moved it From 6521a246f44fdd4604e285fb7930845f41f1ba3d Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 21:14:51 +0000 Subject: [PATCH 3/5] sae: drop steering from the probing base (moves to a steering PR on #1622) Steering's only consumers (the live engine's clamp hook + the steer.py harness) both live in the evo2 serve recipe (#1622), and the harness imports Evo2SAE from it. So the steering primitive + harness move to a dedicated PR stacked on #1622, where the core clamp-hook dedup can happen in-place. This base is now the probing library only. Signed-off-by: Polina Binder --- .../sae/src/sae/steering.py | 77 ------------------- .../sae/tests/test_steering.py | 64 --------------- 2 files changed, 141 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py deleted file mode 100644 index c061e38533..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta. - -A forward hook on the layer the SAE was trained on: it re-encodes the layer output through -the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the -activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the -SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves -the activation. Model-agnostic: needs only the SAE (``encode_pre_act`` / ``decode`` / ``top_k``) -and the module to hook. Measure the effect (e.g. ΔP of a target token) by running the model -with vs. without the hook. -""" - -from contextlib import contextmanager -from typing import Dict - -import torch - - -def clamp_hook(sae, clamps: Dict[int, float]): - """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method. - - The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's - output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative - value reverses its decoder direction. Works whether the module returns a tensor or a tuple - whose first element is the hidden state. - - Args: - sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``, - and ``top_k``. - clamps: Map of feature index -> absolute code value to force at every position. - - Returns: - A ``register_forward_hook``-compatible ``hook(module, inputs, output)``. - """ - items = [(int(f), float(v)) for f, v in clamps.items()] - - def hook(module, inputs, output): - h, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None) - dtype, shape = h.dtype, h.shape - h_flat = h.reshape(-1, h.shape[-1]).float() - with torch.no_grad(): - pre_act, info = sae.encode_pre_act(h_flat) - codes = torch.relu(pre_act) - kvals, kidx = torch.topk(codes, sae.top_k, dim=-1) - codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals) - codes_clamped = codes_orig.clone() - for f, v in items: - codes_clamped[:, f] = v - delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info) - h_out = (h_flat + delta).to(dtype).reshape(shape) - return (h_out, *rest) if rest is not None else h_out - - return hook - - -@contextmanager -def steer(module, sae, clamps: Dict[int, float]): - """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.""" - handle = module.register_forward_hook(clamp_hook(sae, clamps)) - try: - yield - finally: - handle.remove() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py deleted file mode 100644 index 5ef9d15746..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""CPU tests for sae.steering: the delta-clamp adds exactly decode(clamped) - decode(orig).""" - -import torch -from sae.architectures import TopKSAE -from sae.steering import clamp_hook, steer -from torch import nn - - -def _sae(): - torch.manual_seed(0) - return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False) - - -def test_delta_clamp_is_exact_and_cancels_recon(): - """No-op clamp leaves the activation unchanged (recon error cancels); a real clamp shifts - it by exactly decode(clamped) - decode(orig) — the two halves of the delta-clamp contract.""" - sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) - - # No-op: decode(orig) != x, but the added delta is 0, so the output is unchanged. - with steer(m, sae, {}): - assert torch.allclose(m(x), x, atol=1e-5) - - # Real clamp: output == x + (decode(clamped) - decode(orig)), recon error cancelled. - with torch.no_grad(): - pre, info = sae.encode_pre_act(x.float()) - codes = torch.relu(pre) - kv, ki = torch.topk(codes, sae.top_k, dim=-1) - co = torch.zeros_like(codes).scatter(-1, ki, kv) - cc = co.clone() - cc[:, 3] = 5.0 - expected = x + (sae.decode(cc, info) - sae.decode(co, info)) - with steer(m, sae, {3: 5.0}): - assert torch.allclose(m(x), expected, atol=1e-4) - - -def test_tuple_output_steers_only_hidden_state(): - """When the hooked module returns a tuple, only element 0 is steered; the rest passes through.""" - - class M(nn.Module): - def forward(self, x): - return (x, "meta") - - sae, x = _sae(), torch.randn(3, 8) - m = M() - handle = m.register_forward_hook(clamp_hook(sae, {0: 2.0})) - out = m(x) - handle.remove() - assert isinstance(out, tuple) and out[1] == "meta" - assert out[0].shape == x.shape and not torch.allclose(out[0], x) # clamp moved it From 79df727f424df1b4250ec90ff175573cf0bb0503 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 00:31:42 +0000 Subject: [PATCH 4/5] sae.eval.probing: add annotate_features (per-feature best concept by AUROC) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The persistence half of probing that was missing: turns a buffer (codes + concept labels) into a feature->label table — for each feature, the concept it best separates (highest AUROC), kept only above min_auroc. Model-agnostic; the recipe CLI just loads a buffer, calls this, and writes the annotations parquet. + CPU test. Signed-off-by: Polina Binder --- .../sae/src/sae/eval/__init__.py | 2 ++ .../sae/src/sae/eval/probing.py | 29 +++++++++++++++++++ .../sae/tests/test_probing.py | 14 +++++++++ 3 files changed, 45 insertions(+) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py index 80af59e5cd..208c9ef3cd 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py @@ -24,6 +24,7 @@ ) from .probing import ( ActivationBuffer, + annotate_features, auroc_all, auroc_vec, best_single_train_test, @@ -51,6 +52,7 @@ "LossRecoveredResult", "ReconstructionMetrics", "SparsityMetrics", + "annotate_features", "auroc_all", "auroc_vec", "best_single_train_test", diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py index ab34c381dc..55ff09ff09 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py @@ -144,6 +144,35 @@ def per_feat(X, y): return float(1 - a_te if flip else a_te) +@torch.no_grad() +def annotate_features(codes, labels, label_names, min_auroc: float = 0.8, chunk: int = 1024): + """Assign each feature the concept it best separates (by AUROC) -> the feature->label table. + + The persistence half of probing: turns a buffer (codes + concept labels) into per-feature + annotations. For each feature, takes the concept with the highest AUROC and keeps it only if + that AUROC >= ``min_auroc`` (unconfident features stay unlabeled). + + Args: + codes: [N, F] feature activations. + labels: [N, L] bool concept masks. + label_names: length-L concept names. + min_auroc: keep a feature's annotation only if its best AUROC clears this. + chunk: feature chunk size for ``auroc_all``. + + Returns: + ``[{"feature_id": int, "label": str, "auroc": float}]`` sorted by feature_id. + """ + au = auroc_all(codes, labels, chunk=chunk) # [F, L] + best = au.max(dim=1) + names = list(label_names) + out = [] + for f in range(au.shape[0]): + score = float(best.values[f]) + if score >= min_auroc: + out.append({"feature_id": int(f), "label": str(names[int(best.indices[f])]), "auroc": round(score, 4)}) + return out + + # ───────────────────────────────────────────────────────────── linear probes def fit_logreg(Xtr, ytr, steps=400, lr=0.05, wd=1e-2): """Fit a logistic-regression probe (Adam + BCE-with-logits); returns (w, b).""" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py index 346a796fcd..d38868a962 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py @@ -25,6 +25,7 @@ import torch from sae.eval.probing import ( ActivationBuffer, + annotate_features, auroc_all, best_single_train_test, decode_eval, @@ -106,6 +107,19 @@ def make(per): assert acc_rand < 0.6 +def test_annotate_features_assigns_best_concept_above_threshold(): + """Each feature gets the concept it best separates; unconfident features stay unlabeled.""" + torch.manual_seed(0) + n = 200 + labels = torch.stack([torch.arange(n) % 2 == 0, torch.arange(n) < n // 2], 1) # [N, 2]: 'even', 'first_half' + detector = labels[:, 0].float() + torch.randn(n) * 0.01 # cleanly tracks 'even' + noise = torch.randn(n) # tracks nothing + codes = torch.stack([detector, noise], 1) # [N, 2 features] + ann = annotate_features(codes, labels, ["even", "first_half"], min_auroc=0.9) + assert {a["feature_id"]: a["label"] for a in ann} == {0: "even"} # feature 1 (noise) excluded + assert ann[0]["auroc"] > 0.99 + + def test_buffer_roundtrip_and_split(tmp_path): """ActivationBuffer save/load preserves codes/labels/names/dense/instances; split is a partition.""" rng = np.random.default_rng(0) From 57837ec785217d85ad854f9fa2ef5c91b71edde9 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 05:25:44 +0000 Subject: [PATCH 5/5] =?UTF-8?q?docs(probing):=20CodeRabbit=20nitpicks=20?= =?UTF-8?q?=E2=80=94=20document=20allow=5Fpickle=20+=20remap=20+2=20sizing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sparse_autoencoders/sae/src/sae/eval/probing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py index 55ff09ff09..ec36394451 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py @@ -53,7 +53,12 @@ def save(self, path: str) -> None: @classmethod def load(cls, path: str) -> "ActivationBuffer": - """Load an ActivationBuffer from an .npz written by save().""" + """Load an ActivationBuffer from an .npz written by save(). + + Warning: + Uses ``allow_pickle=True`` (the per-concept instance dict is an object array); + only load buffers from trusted sources. + """ z = np.load(path, allow_pickle=True) inst = {k[5:]: z[k] for k in z.files if k.startswith("inst_")} return cls( @@ -240,6 +245,8 @@ def domain_f1(codes, fmax, concept_mask, inst_ids, thresholds=(0.15, 0.3, 0.5, 0 n_inst = len(uniq) if n_inst == 0: return torch.zeros(F, device=dev), torch.zeros(F, device=dev) + # size = max instance id + 2: +1 to index by the max id itself, +1 headroom so a -1 + # sentinel never indexes out of bounds when remapped. remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long) remap[uniq.long()] = torch.arange(n_inst, device=dev) inst_c = torch.where(valid, remap[inst_ids.long()], torch.full_like(inst_ids, -1, dtype=torch.long))