From 6352d27c3a0c969154c5b98cd67e620ef34fa2f9 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 20:01:01 +0000 Subject: [PATCH 1/3] =?UTF-8?q?evo2=20infer:=20add=20`generate`=20CLI=20mo?= =?UTF-8?q?de=20=E2=80=94=20steer=20from=20the=20command=20line?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fourth CLI mode alongside serve/encode/batch: `launch_inference.sh generate --prompt ATGC... --clamp FEATURE_ID[:STRENGTH]` runs steered generation from the command line (repeat --clamp for several features). Reuses the verified engine.generate; `_parse_clamps` turns the repeatable --clamp args into feature specs. Usage docs the encode->generate steering loop. Stacked on #1622 (the engine + server), so the verified core stays frozen. test_cli.py covers the clamp parsing (CPU); the generation path itself is the engine's (GPU-verified in #1622's test_steering). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 6 +- .../recipes/evo2/src/evo2_sae/cli.py | 57 ++++++++++++++++++- .../recipes/evo2/tests/test_cli.py | 33 +++++++++++ 3 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index 266bbc0669..26768a4c46 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -1,9 +1,13 @@ #!/bin/bash -# Launch the Evo2 SAE inference engine. One engine, three modes: +# Launch the Evo2 SAE inference engine. 
One engine, four modes: # # ./launch_inference.sh serve # live HTTP server on :8001 (viz backend) # ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features # ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet +# ./launch_inference.sh generate --prompt ATGC... --clamp 29244:300 # steer + generate DNA +# +# Steering loop: `encode` a sequence to find an active feature id, then +# `generate --clamp ID:STRENGTH` (strength ~2-3x the feature's max_activation; repeat --clamp). # # Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults): # FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES. diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py index b68a06b22d..98185dd51a 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Evo2 SAE inference CLI — one engine, three modes. +"""Evo2 SAE inference CLI — one engine, four modes. serve : start the FastAPI server (one sequence at a time, interactive) encode : annotate ONE sequence -> top features (stdout JSON) batch : run a FASTA of MANY sequences -> parquet of per-sequence top features + generate: generate DNA, optionally steering SAE features (stdout JSON) -All three build the same `Evo2SAE` engine; config comes from flags or env +They all build the same `Evo2SAE` engine; config comes from flags or env (EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER). 
""" @@ -73,9 +74,21 @@ def _engine(args): ) +def _parse_clamps(clamps: list[str]) -> list[dict]: + """Parse repeated ``--clamp FEATURE_ID[:STRENGTH]`` args into [{feature_id, strength}]. + + Strength defaults to 1.0 if omitted (e.g. ``--clamp 29244:300`` or ``--clamp 29244``). + """ + specs = [] + for c in clamps: + fid, sep, strength = c.partition(":") + specs.append({"feature_id": int(fid), "strength": float(strength) if (sep and strength) else 1.0}) + return specs + + def main(): """Parse args and dispatch to the serve / encode / batch subcommand.""" - ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch)") + ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch | generate)") sub = ap.add_subparsers(dest="cmd", required=True) ps = sub.add_parser("serve", help="start the FastAPI inference server") @@ -96,6 +109,23 @@ def main(): pb.add_argument("--top-k", type=int, default=16) pb.add_argument("--batch-size", type=int, default=8) + pg = sub.add_parser("generate", help="generate DNA, optionally steering SAE features") + _add_common(pg) + pg.add_argument("--prompt", default="", help="DNA to seed; steering applies to the continuation") + pg.add_argument("--organism", default="None (raw DNA)") + pg.add_argument( + "--clamp", + action="append", + default=[], + metavar="FEATURE_ID[:STRENGTH]", + help="clamp a feature on the continuation; repeatable (e.g. --clamp 29244:300). 
" + "Find feature ids with `encode`.", + ) + pg.add_argument("--n-tokens", type=int, default=120) + pg.add_argument("--temperature", type=float, default=1.0) + pg.add_argument("--top-k", type=int, default=0) + pg.add_argument("--compare-baseline", action="store_true", help="also generate unsteered, for comparison") + args = ap.parse_args() if args.cmd == "serve": @@ -141,6 +171,27 @@ def main(): df.to_parquet(args.out, index=False) print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}") + elif args.cmd == "generate": + out = eng.generate( + prompt=args.prompt, + organism=args.organism, + features=_parse_clamps(args.clamp), + n_tokens=args.n_tokens, + temperature=args.temperature, + top_k=args.top_k, + compare_baseline=args.compare_baseline, + ) + result = { + "prompt": out["prompt"], + "organism": out["organism"], + "steered": out["steered"], + "features": out["features"], + "sequence": out["generation"]["sequence"], + } + if out.get("baseline"): + result["baseline_sequence"] = out["baseline"]["sequence"] + print(json.dumps(result, indent=2)) + if __name__ == "__main__": main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py new file mode 100644 index 0000000000..e3381af409 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU test for the generate CLI's --clamp parsing (no model).""" + +from evo2_sae.cli import _parse_clamps + + +def test_parse_clamps_id_and_strength(): + assert _parse_clamps(["29244:300", "88:1.5"]) == [ + {"feature_id": 29244, "strength": 300.0}, + {"feature_id": 88, "strength": 1.5}, + ] + + +def test_parse_clamps_default_strength(): + assert _parse_clamps(["29244"]) == [{"feature_id": 29244, "strength": 1.0}] + + +def test_parse_clamps_empty(): + assert _parse_clamps([]) == [] From dde5ea89e55dd07a2a48e3feb87782b566245462 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 21:16:36 +0000 Subject: [PATCH 2/3] evo2 SAE steering: shared clamp primitive + dose-response harness (on #1622) Consolidates all steering onto the serve engine (#1622), where its consumers live: - sae/steering.py: model-agnostic delta-clamp hook + steer() context manager (the intervention sibling of sae.eval.probing), with CPU tests. - recipes/evo2/scripts/steer.py: dose-response / selectivity harness using the shared hook on the Evo2 decoder layer; reconciled to the evo2_sae package name. Replaces #1631 (was stacked on the probing base) and the steering half of #1629. Follow-up: dedup Evo2SAE._clamp_hook (decode-only) onto sae.steering via a decode_only flag. 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/steer.py | 121 ++++++++++++++++++ .../sae/src/sae/steering.py | 77 +++++++++++ .../sae/tests/test_steering.py | 64 +++++++++ 3 files changed, 262 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py new file mode 100644 index 0000000000..8d54dee633 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation. + +Uses ``sae.steering.clamp_hook`` (the shared delta-clamp) registered on the Evo2 decoder layer +the SAE was trained on. 
Workflow: encode a sequence to find its active features, then for a +**target** feature sweep the clamp strength (dose-response) and for **control** features apply +the same clamp (selectivity), each time comparing the steered continuation to the baseline. + +GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test. + + python steer.py --evo2-ckpt-dir EVO2_CKPT_DIR --sae-checkpoint SAE_CKPT_PATH --layer 26 \ + --sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200 + +Note: ``sae.steering.clamp_hook`` clamps on *every* forward (prefill + decode), so it steers +the prompt as well as the continuation. The decode-only ("continuation-only") variant lives in +``evo2_sae.core.Evo2SAE._clamp_hook``; unifying the two onto ``sae.steering`` (with a +``decode_only`` flag) is a planned follow-up. +""" + +from __future__ import annotations + +import argparse +import sys +from contextlib import nullcontext +from pathlib import Path + + +_HERE = Path(__file__).resolve().parent +sys.path.insert(0, str(_HERE)) +sys.path.insert(0, str(_HERE.parent / "src")) # recipes/evo2/src -> evo2_sae package +sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) + +from sae.steering import steer # noqa: E402 + + +def _divergence(a: str, b: str): + """Return (first differing index, fraction of differing chars) over the first min(len(a), len(b)) chars.""" + n = min(len(a), len(b)) + first = next((i for i in range(n) if a[i] != b[i]), n) + diff = sum(1 for i in range(n) if a[i] != b[i]) / max(1, n) + return first, diff + + +def main(): + """Encode a sequence, then steer a target feature (dose-response) + control features (selectivity).""" + p = argparse.ArgumentParser(description="Evo2 SAE steering harness (clamp -> continuation effect).") + p.add_argument("--evo2-ckpt-dir", required=True) + p.add_argument("--sae-checkpoint", required=True) + p.add_argument("--layer", type=int, required=True) + p.add_argument("--sequence", required=True) + p.add_argument("--organism", 
default="None (raw DNA)") + p.add_argument("--feature", type=int, default=None, help="Target feature id (default: top labeled feature).") + p.add_argument("--controls", default="", help="Comma-separated control feature ids (selectivity).") + p.add_argument("--strengths", default="0,50,100,200", help="Comma-separated clamp strengths to sweep.") + p.add_argument("--n-tokens", type=int, default=60) + p.add_argument("--device", default="cuda") + a = p.parse_args() + + from bionemo.evo2.run import infer as INF # noqa: E402, I001, RUF100 + from evo2_sae.core import Evo2SAE, clean_dna # noqa: E402, RUF100 + from megatron.core.utils import unwrap_model # noqa: E402, RUF100 + + eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() + + # 1. Encode -> the sequence's most-active features (pick a target if not given). + codes = eng.encode(a.sequence) + vals, ids = codes.max(0).values.topk(10) + print(f"top features on {a.sequence[:24]}...:") + target = a.feature + for v, i in zip(vals.tolist(), ids.tolist()): + lab = eng.labels.get(int(i)) + print(f" feat {int(i):6d} {str(lab):18s} max_act {v:7.2f}") + if target is None and lab: + target = int(i) + controls = [int(c) for c in a.controls.split(",") if c.strip()] + strengths = [float(s) for s in a.strengths.split(",")] + + # 2. The Evo2 decoder layer the SAE hooks + a clean (tag + DNA) prompt. 
+ comp = eng._ensure_engine() + prompt = (eng.resolve_tag(a.organism, None) or "") + clean_dna(a.sequence) + layer_mod = unwrap_model(comp.model).decoder.layers[a.layer] + + def gen(clamps): + ctx = steer(layer_mod, eng.sae, clamps) if clamps else nullcontext() + with ctx: + out = INF.generate(comp, [prompt], max_new_tokens=a.n_tokens, temperature=0.0, top_k=1) + return clean_dna(INF._unwrap_result(out[0]).generated_text) + + base = gen({}) + print(f"\nbaseline: {base[:60]}") + print(f"\n=== dose-response: feature {target} ({eng.labels.get(target)}) ===") + for s in strengths: + steered = gen({target: s}) + first, diff = _divergence(base, steered) + print(f" strength {s:7.1f}: diverges@{first:3d} {diff:6.1%} changed {steered[:44]}") + + if controls: + s = strengths[-1] + print(f"\n=== selectivity: control features clamped to {s} ===") + for c in controls: + steered = gen({c: s}) + first, diff = _divergence(base, steered) + print(f" control {c:6d} ({str(eng.labels.get(c)):16s}): diverges@{first:3d} {diff:6.1%} changed") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py new file mode 100644 index 0000000000..c061e38533 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta. + +A forward hook on the layer the SAE was trained on: it re-encodes the layer output through +the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the +activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the +SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves +the activation. Model-agnostic: needs only the SAE (``encode_pre_act`` / ``decode`` / ``top_k``) +and the module to hook. Measure the effect (e.g. ΔP of a target token) by running the model +with vs. without the hook. +""" + +from contextlib import contextmanager +from typing import Dict + +import torch + + +def clamp_hook(sae, clamps: Dict[int, float]): + """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method. + + The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's + output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative + value reverses its decoder direction. Works whether the module returns a tensor or a tuple + whose first element is the hidden state. + + Args: + sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``, + and ``top_k``. + clamps: Map of feature index -> absolute code value to force at every position. + + Returns: + A ``register_forward_hook``-compatible ``hook(module, inputs, output)``. 
+ """ + items = [(int(f), float(v)) for f, v in clamps.items()] + + def hook(module, inputs, output): + h, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None) + dtype, shape = h.dtype, h.shape + h_flat = h.reshape(-1, h.shape[-1]).float() + with torch.no_grad(): + pre_act, info = sae.encode_pre_act(h_flat) + codes = torch.relu(pre_act) + kvals, kidx = torch.topk(codes, sae.top_k, dim=-1) + codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals) + codes_clamped = codes_orig.clone() + for f, v in items: + codes_clamped[:, f] = v + delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info) + h_out = (h_flat + delta).to(dtype).reshape(shape) + return (h_out, *rest) if rest is not None else h_out + + return hook + + +@contextmanager +def steer(module, sae, clamps: Dict[int, float]): + """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.""" + handle = module.register_forward_hook(clamp_hook(sae, clamps)) + try: + yield + finally: + handle.remove() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py new file mode 100644 index 0000000000..5ef9d15746 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for sae.steering: the delta-clamp adds exactly decode(clamped) - decode(orig).""" + +import torch +from sae.architectures import TopKSAE +from sae.steering import clamp_hook, steer +from torch import nn + + +def _sae(): + torch.manual_seed(0) + return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False) + + +def test_delta_clamp_is_exact_and_cancels_recon(): + """No-op clamp leaves the activation unchanged (recon error cancels); a real clamp shifts + it by exactly decode(clamped) - decode(orig) — the two halves of the delta-clamp contract.""" + sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) + + # No-op: decode(orig) != x, but the added delta is 0, so the output is unchanged. + with steer(m, sae, {}): + assert torch.allclose(m(x), x, atol=1e-5) + + # Real clamp: output == x + (decode(clamped) - decode(orig)), recon error cancelled. + with torch.no_grad(): + pre, info = sae.encode_pre_act(x.float()) + codes = torch.relu(pre) + kv, ki = torch.topk(codes, sae.top_k, dim=-1) + co = torch.zeros_like(codes).scatter(-1, ki, kv) + cc = co.clone() + cc[:, 3] = 5.0 + expected = x + (sae.decode(cc, info) - sae.decode(co, info)) + with steer(m, sae, {3: 5.0}): + assert torch.allclose(m(x), expected, atol=1e-4) + + +def test_tuple_output_steers_only_hidden_state(): + """When the hooked module returns a tuple, only element 0 is steered; the rest passes through.""" + + class M(nn.Module): + def forward(self, x): + return (x, "meta") + + sae, x = _sae(), torch.randn(3, 8) + m = M() + handle = m.register_forward_hook(clamp_hook(sae, {0: 2.0})) + out = m(x) + handle.remove() + assert isinstance(out, tuple) and out[1] == "meta" + assert out[0].shape == x.shape and not torch.allclose(out[0], x) # clamp moved it From d385efc5a286f99f4fc65c26b8734b77285ac13a Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 21:48:44 +0000 Subject: 
[PATCH 3/3] evo2 steering: unify all clamping onto sae.steering (B), drop core's variant Always clamp with the faithful encode->topk->decode delta. core.py's generate had its own decode-only, no-topk linear inject (A) that ignored top-K gating; replace it with sae.steering.clamp_hook (B) + a new decode_only flag (skips prompt prefill, seq dim != 1). Now ONE clamp implementation serves every surface: the /generate endpoint, the generate CLI, the dashboard steering tab (all via core.generate), and the steer.py harness. - sae/steering.py: add decode_only (continuation-only steering on [S,B,H] decode steps). - core.py: generate() registers sae.steering.clamp_hook(decode_only=True); removed Evo2SAE._clamp_hook + per-feature spec building. - test_steering.py: +decode_only test (prefill untouched, decode-step steered). NOTE: this changes the production steering math (A->B: top-K gating now applies), so it needs a 7B/L26 GPU smoke before this PR leaves draft. Signed-off-by: Polina Binder --- .../recipes/evo2/src/evo2_sae/core.py | 49 ++++--------------- .../sae/src/sae/steering.py | 12 +++-- .../sae/tests/test_steering.py | 12 +++++ 3 files changed, 30 insertions(+), 43 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index ad3a5fb869..444748b591 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -294,30 +294,6 @@ def top_features(self, codes: torch.Tensor, tag_len: int = 0, k: int = 8) -> lis ] # ------------------------------------------------------------------ generate - def _clamp_hook(self, specs, pre_bias): - """Forward hook that clamps SAE features on the residual during DECODE steps only. 
- - A decode step processes a single new token (sequence dim == 1); the prompt prefill - (sequence dim > 1) is left untouched, giving continuation-only steering through - `infer.generate`: h <- h + Σ_f (t_f - a_f(h)) · d_f - `specs` = list of (enc_f [H], b_f float, dec_f [H], target float). - """ - - def hook(_module, _inp, output): - hs = output[0] if isinstance(output, tuple) else output # [S, B, H] - if hs.shape[0] != 1: # prefill (whole prompt) — leave untouched - return output - x = hs.float() - xc = x - pre_bias - add = torch.zeros_like(x) - for enc_f, b_f, dec_f, target in specs: - a = torch.relu(torch.matmul(xc, enc_f) + b_f) - add = add + (target - a).unsqueeze(-1) * dec_f - new = (x + add).to(hs.dtype) - return (new, *output[1:]) if isinstance(output, tuple) else new - - return hook - def generate( self, prompt="", @@ -354,23 +330,16 @@ def generate( with self._lock: comp = self._ensure_engine() hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] - pre_bias = self.sae.pre_bias.detach().float().to(self.device) - specs, feat_meta = [], [] - for f in features: - fid = int(f["feature_id"]) - specs.append( - ( - self.sae.encoder.weight[fid].detach().float().to(self.device), - float(self.sae.latent_bias[fid].detach()), - self.sae.decoder.weight[:, fid].detach().float().to(self.device), - float(f.get("strength", 1.0)), - ) - ) - feat_meta.append({"id": fid, "label": self.labels.get(fid), "strength": float(f.get("strength", 1.0))}) + from sae.steering import clamp_hook + + clamps = {int(f["feature_id"]): float(f.get("strength", 1.0)) for f in features} + feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] def _run(steer: bool) -> str: handle = ( - hook_layer.register_forward_hook(self._clamp_hook(specs, pre_bias)) if (steer and specs) else None + hook_layer.register_forward_hook(clamp_hook(self.sae, clamps, decode_only=True)) + if (steer and clamps) + else None ) try: out = INF.generate( @@ -382,7 
+351,7 @@ def _run(steer: bool) -> str: handle.remove() main_dna = _run(steer=True) - base_dna = _run(steer=False) if (compare_baseline and specs) else None + base_dna = _run(steer=False) if (compare_baseline and clamps) else None resp = { "prompt": dna, @@ -391,7 +360,7 @@ def _run(steer: bool) -> str: "tag_len": len(resolved_tag), "n_tokens": n_tokens, "features": feat_meta, - "steered": bool(specs), + "steered": bool(clamps), "generation": {"sequence": main_dna, "activations": self.feature_tracks(main_dna, fids)}, "baseline": None, } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py index c061e38533..d390c9066b 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py @@ -30,7 +30,7 @@ import torch -def clamp_hook(sae, clamps: Dict[int, float]): +def clamp_hook(sae, clamps: Dict[int, float], decode_only: bool = False): """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method. The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's @@ -42,6 +42,10 @@ def clamp_hook(sae, clamps: Dict[int, float]): sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``, and ``top_k``. clamps: Map of feature index -> absolute code value to force at every position. + decode_only: If True, steer only autoregressive *decode* steps and leave the prompt + prefill untouched (continuation-only steering). Assumes a ``(sequence, batch, hidden)`` + layout — the convention for Evo2/megatron decoder layers — and applies the clamp only + when the sequence dimension is 1 (a single new token). Returns: A ``register_forward_hook``-compatible ``hook(module, inputs, output)``. 
@@ -50,6 +54,8 @@ def clamp_hook(sae, clamps: Dict[int, float]): def hook(module, inputs, output): h, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None) + if decode_only and h.shape[0] != 1: # prefill (seq dim > 1) — leave untouched + return output dtype, shape = h.dtype, h.shape h_flat = h.reshape(-1, h.shape[-1]).float() with torch.no_grad(): @@ -68,9 +74,9 @@ def hook(module, inputs, output): @contextmanager -def steer(module, sae, clamps: Dict[int, float]): +def steer(module, sae, clamps: Dict[int, float], decode_only: bool = False): """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.""" - handle = module.register_forward_hook(clamp_hook(sae, clamps)) + handle = module.register_forward_hook(clamp_hook(sae, clamps, decode_only=decode_only)) try: yield finally: diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py index 5ef9d15746..0b28517c9d 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py @@ -62,3 +62,15 @@ def forward(self, x): handle.remove() assert isinstance(out, tuple) and out[1] == "meta" assert out[0].shape == x.shape and not torch.allclose(out[0], x) # clamp moved it + + +def test_decode_only_skips_prefill(): + """decode_only steers single-token decode steps ([1,B,H]) but leaves multi-token prefill alone.""" + sae, m = _sae(), nn.Identity() + prefill = torch.randn(5, 2, 8) # [S=5, B, H] — prompt prefill, must pass through + decode = torch.randn(1, 2, 8) # [S=1, B, H] — a single new token, must be steered + handle = m.register_forward_hook(clamp_hook(sae, {3: 5.0}, decode_only=True)) + out_prefill, out_decode = m(prefill), m(decode) + handle.remove() + assert torch.allclose(out_prefill, prefill, atol=1e-5) # prefill untouched + assert 
not torch.allclose(out_decode, decode) # decode step steered