From 538592f5757047fabbdcd6de08b11c64093398f1 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 16:44:42 +0000 Subject: [PATCH 01/20] evo2 SAE recipe: live inference engine + steering server + CLI (src/evo2_sae) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the live-serving half of the evo2 SAE recipe on top of the offline extract/train (#1621): src/evo2_sae/ with the Evo2SAE engine (core), the FastAPI server, and the serve/encode/batch CLI. This is what distinguishes evo2 from the esm2/codonfm recipes (static dashboards) — a live backend with interactive SAE-feature steering. - Model loading via predict.load_model_to_layer (#1618); generation via infer.setup_inference_engine + infer.generate (eager, cuda_graph_impl="none") so the decode-only residual-stream steering hook applies. SAE math from the shared sae package. - pyproject now discovers the package from src/ and adds the serving deps (fastapi, uvicorn, pandas). - Tests: test_server.py (mocked engine, CPU, locks the viz API contract) and test_steering.py (clamp math on CPU; GPU steering gated on EVO2_CKPT_DIR/SAE_CKPT_PATH). 7 CPU tests pass. The feature-explorer dashboard builds on this server in a follow-up. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 12 +- .../recipes/evo2/scripts/launch_inference.sh | 31 ++ .../recipes/evo2/src/evo2_sae/__init__.py | 21 + .../recipes/evo2/src/evo2_sae/cli.py | 155 +++++++ .../recipes/evo2/src/evo2_sae/core.py | 396 ++++++++++++++++++ .../recipes/evo2/src/evo2_sae/server.py | 174 ++++++++ .../recipes/evo2/tests/test_server.py | 105 +++++ .../recipes/evo2/tests/test_steering.py | 100 +++++ 8 files changed, 989 insertions(+), 5 deletions(-) create mode 100755 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index f7b23d8eff..f4a6a32c40 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -13,13 +13,15 @@ dependencies = [ "torch>=2.0", "numpy>=1.20", "pyarrow>=23.0.0", + "fastapi>=0.110", + "uvicorn>=0.29", + "pandas>=1.5", ] -# No package code lives here yet — the recipe is just an entry-point for -# scripts/ that depends on the shared `sae` workspace package. Declare no -# packages so setuptools doesn't try to discover anything. -[tool.setuptools] -packages = [] +# The `evo2_sae` package (src/) holds the live inference engine + server + CLI; +# scripts/ (extract, train) are standalone entry points alongside it. +[tool.setuptools.packages.find] +where = ["src"] [tool.uv.sources] sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh new file mode 100755 index 0000000000..64f8d89f5d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Launch the Evo2 SAE inference engine. One engine, three modes: +# +# ./launch_inference.sh serve # live HTTP server on :8001 (viz backend) +# ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features +# ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet +# +# Config via env (sensible defaults below): EVO2_CKPT_DIR, SAE_CKPT_PATH, +# FEATURE_ANNOTATIONS, EMBEDDING_LAYER, DEVICE, PORT, CUDA_VISIBLE_DEVICES. +# +# Requires the evo2_megatron recipe venv (provides bionemo.evo2 + megatron). +set -euo pipefail + +HERE="$(cd "$(dirname "$0")" && pwd)" +RECIPE_DIR="$(cd "$HERE/.." && pwd)" # recipes/evo2 — so the evo2_sae package imports + +VENV="${VENV:-/data/pbinder/bionemo-framework/bionemo-recipes/recipes/evo2_megatron/.venv}" +export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:-/data/interp/evo2/checkpoints/evo2_1b_base_mbridge}" +export SAE_CKPT_PATH="${SAE_CKPT_PATH:-/data/interp/evo2/sae/v2_diverse/layer19_C13_nofilter/checkpoints/checkpoint_final.pt}" +export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-/data/interp/evo2/sae_eval/dashboard_data/l19_C13_nofilter/feature_metadata.parquet}" +export EMBEDDING_LAYER="${EMBEDDING_LAYER:-19}" + +if [[ ! -x "$VENV/bin/python" ]]; then + echo "ERROR: evo2_megatron venv not found at $VENV (build it with the recipe's .ci_build.sh)" >&2 + exit 1 +fi + +source "$VENV/bin/activate" +cd "$RECIPE_DIR" +export PYTHONPATH="$RECIPE_DIR/src${PYTHONPATH:+:$PYTHONPATH}" +exec python -m evo2_sae.cli "$@" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py new file mode 100644 index 0000000000..c79afb9063 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 + SAE inference engine — reused by the live server, the batch CLI, and the viz backend.""" + +from .core import DEFAULT_ORGANISM_TAGS, Evo2SAE, clean_dna + + +__all__ = ["DEFAULT_ORGANISM_TAGS", "Evo2SAE", "clean_dna"] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py new file mode 100644 index 0000000000..b4e7ce0acd --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 SAE inference CLI — one engine, three modes. + + serve : start the FastAPI server (one sequence at a time, interactive) + encode : annotate ONE sequence -> top features (stdout JSON) + batch : run a FASTA of MANY sequences -> parquet of per-sequence top features + +All three build the same `Evo2SAE` engine; config comes from flags or env +(EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER). +""" + +from __future__ import annotations + +import argparse +import gzip +import json +import os + + +def _add_common(p: argparse.ArgumentParser) -> None: + p.add_argument( + "--evo2-ckpt-dir", + default=os.environ.get("EVO2_CKPT_DIR", "/data/interp/evo2/checkpoints/evo2_1b_base_mbridge"), + ) + p.add_argument( + "--sae-ckpt-path", + default=os.environ.get( + "SAE_CKPT_PATH", "/data/interp/evo2/sae/v2_diverse/layer19_C13_nofilter/checkpoints/checkpoint_final.pt" + ), + ) + p.add_argument( + "--feature-annotations", + default=os.environ.get( + "FEATURE_ANNOTATIONS", + "/data/interp/evo2/sae_eval/dashboard_data/l19_C13_nofilter/feature_metadata.parquet", + ), + ) + p.add_argument("--layer", type=int, default=int(os.environ.get("EMBEDDING_LAYER", "19"))) + p.add_argument("--device", default=os.environ.get("DEVICE", "cuda")) + p.add_argument("--max-seq-len", type=int, default=int(os.environ.get("MAX_SEQ_LEN", "8192"))) + + +def _engine(args): + from .core import Evo2SAE + + return Evo2SAE( + evo2_ckpt_dir=args.evo2_ckpt_dir, + sae_ckpt_path=args.sae_ckpt_path, + layer=args.layer, + device=args.device, + max_seq_len=args.max_seq_len, + feature_annotations=args.feature_annotations, + ) + + +def _read_fasta(path: str): + seqs, ids = [], [] + name, parts = None, [] + opener = gzip.open if str(path).endswith(".gz") else open + with opener(path, "rt") as f: + for line in f: + line = line.rstrip() + if line.startswith(">"): + if name is not None: + seqs.append("".join(parts)) + ids.append(name) + name, parts = line[1:].split()[0] if len(line) > 1 else f"seq_{len(ids)}", [] + else: + parts.append(line) + if name is not None: + seqs.append("".join(parts)) + ids.append(name) + return ids, seqs + + +def main(): + """Parse args and dispatch to the serve / encode / batch subcommand.""" + ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch)") + sub = ap.add_subparsers(dest="cmd", required=True) + + ps = sub.add_parser("serve", help="start the FastAPI inference server") + _add_common(ps) + ps.add_argument("--host", default="0.0.0.0") + ps.add_argument("--port", type=int, default=int(os.environ.get("PORT", "8001"))) + + pe = sub.add_parser("encode", help="annotate ONE sequence -> top features (JSON)") + _add_common(pe) + pe.add_argument("--sequence", required=True) + pe.add_argument("--organism", default="None (raw DNA)") + pe.add_argument("--top-k", type=int, default=8) + + pb = sub.add_parser("batch", help="MANY sequences (FASTA) -> parquet of per-sequence top features") + _add_common(pb) + pb.add_argument("--fasta", required=True) + pb.add_argument("--out", required=True) + pb.add_argument("--top-k", type=int, default=16) + pb.add_argument("--batch-size", type=int, default=8) + + args = ap.parse_args() + + if args.cmd == "serve": + import uvicorn + + from .server import build_app + + uvicorn.run(build_app(_engine(args)), host=args.host, port=args.port, log_level="info") + return + + from .core import clean_dna + + eng = _engine(args).load() + + if args.cmd == "encode": + tag = eng.resolve_tag(args.organism, None) or "" + dna = clean_dna(args.sequence) + codes = eng.encode(tag + dna) + tag_len = len(tag) if codes.shape[0] >= len(tag) else 0 + feats = eng.top_features(codes, tag_len=tag_len, k=args.top_k) + print( + json.dumps( + {"sequence": dna, "organism": args.organism, "bases": len(dna), "top_features": feats}, indent=2 + ) + ) + + elif args.cmd == "batch": + import pandas as pd + + ids, seqs = _read_fasta(args.fasta) + print(f"[batch] {len(seqs)} sequences from {args.fasta}; encoding (batch_size={args.batch_size})…") + codes_list = eng.encode_batch(seqs, batch_size=args.batch_size) + rows = [] + for sid, codes in zip(ids, codes_list): + for rank, ft in enumerate(eng.top_features(codes, k=args.top_k)): + rows.append({"sequence_id": sid, "bp": int(codes.shape[0]), "rank": rank, **ft}) + df = pd.DataFrame(rows) + df.to_parquet(args.out, index=False) + print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py new file mode 100644 index 0000000000..789bd9252f --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 + SAE inference core — one importable engine for live and batch use. + +`Evo2SAE` loads a base Evo2 model and a trained SAE once, then exposes: + + encode(dna) -> codes [S, n_features] # ONE sequence (interactive) + encode_batch(seqs) -> list of codes [S_i, n_features] # MANY sequences (batched on GPU) + feature_tracks(dna, f) -> {feature_id: [per-base activation]} + generate(...) -> autoregressive DNA generation with optional additive + SAE-feature clamping on the generated continuation + +It has NO web dependency: the FastAPI server (`server.py`) and the batch CLI +(`cli.py`) are thin wrappers over this class, and the viz backend imports it too. + +The heavy Evo2 machinery is reused from the recipe: model loading via +`predict.load_model_to_layer` and generation via `infer.setup_inference_engine` / +`infer.generate` (run eager, `cuda_graph_impl="none"`, so the residual-stream steering +hook applies). This module only adds the SAE layer: encode, feature labels, and the +decode-only feature-clamp hook. +""" + +from __future__ import annotations + +import logging +import os +import re +import sys +import threading +from pathlib import Path +from typing import Optional + +import torch + + +logger = logging.getLogger("evo2_sae_infer") + +# Disable Inductor CUDA graphs before torch initializes inductor — graph capture +# conflicts with the residual-stream forward hook (which replaces the layer output) +# and with re-feeding a growing sequence each decode step. +os.environ.setdefault("TORCHINDUCTOR_CUDAGRAPHS", "0") + +# Make the local `sae` package importable (sparse_autoencoders/sae/src). +_SAE_SRC = os.environ.get("SAE_SRC", str(Path(__file__).resolve().parents[4] / "sae" / "src")) +if _SAE_SRC not in sys.path: + sys.path.insert(0, _SAE_SRC) + +_VALID_BASES = re.compile(r"[^ACGTN]") + +# Phylogenetic-tag prefixes per organism (Evo2 was trained with lineage tags). +DEFAULT_ORGANISM_TAGS = { + "None (raw DNA)": "", + "Human": "|d__Eukaryota;p__Chordata;c__Mammalia;o__Primates;f__Hominidae;g__Homo;s__Homo sapiens|", + "E. coli": "|d__Bacteria;p__Pseudomonadota;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Escherichia;s__Escherichia coli|", + "S. cerevisiae": "|d__Eukaryota;p__Ascomycota;c__Saccharomycetes;o__Saccharomycetales;f__Saccharomycetaceae;g__Saccharomyces;s__Saccharomyces cerevisiae|", +} + + +def clean_dna(seq: str) -> str: + """Uppercase and strip everything that isn't a nucleotide.""" + return _VALID_BASES.sub("", (seq or "").upper()) + + +def _init_single_process_distributed() -> None: + """Set the env vars Megatron's distributed init expects for a 1-GPU process.""" + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29577") + + +class Evo2SAE: + """Persistent Evo2 + SAE inference engine (single-sequence and batched).""" + + def __init__( + self, + evo2_ckpt_dir: str, + sae_ckpt_path: str, + layer: int, + device: str = "cuda", + max_seq_len: int = 8192, + feature_annotations: Optional[str] = None, + organism_tags: Optional[dict] = None, + ): + """Record config; call .load() to actually load the model + SAE onto the GPU.""" + self.evo2_ckpt_dir = evo2_ckpt_dir + self.sae_ckpt_path = sae_ckpt_path + self.layer = int(layer) + self.device = device + self.max_seq_len = int(max_seq_len) + self.feature_annotations = feature_annotations + self.organism_tags = dict(organism_tags) if organism_tags else dict(DEFAULT_ORGANISM_TAGS) + + self.model = None # truncated model (post_process=False) for activations + self.gen_components = None # recipe inference engine (full model) for generation, lazy + self.tokenizer = None + self.sae = None + self.n_features = None + self.labels: dict[int, str] = {} + self.peaks: dict[int, float] = {} + self._lock = threading.Lock() # serialize GPU access (Megatron isn't thread-safe) + self.ready = False + + # ------------------------------------------------------------------ loading + def load(self) -> "Evo2SAE": + """Load the truncated Evo2 model + SAE + feature labels (one-time, ~1 min).""" + from bionemo.evo2.run import predict as P + + _init_single_process_distributed() + self.model, self.tokenizer = P.load_model_to_layer(self.evo2_ckpt_dir, self.layer, full=False) + self.sae, self.n_features = self._load_sae() + self.labels, self.peaks = self._load_feature_meta() + self.ready = True + logger.info("Evo2SAE ready: layer=%d n_features=%d n_labels=%d", self.layer, self.n_features, len(self.labels)) + return self + + def _ensure_engine(self): + """Lazily build the recipe's inference engine (eager/hookable) for generation. + + cuda_graph_impl="none" keeps decode eager so the residual-stream steering hook + takes effect (a CUDA-graph-captured decode would replay frozen ops and ignore it). + """ + if self.gen_components is None: + from bionemo.evo2.run import infer as INF + + self.gen_components = INF.setup_inference_engine( + Path(self.evo2_ckpt_dir), max_seq_length=self.max_seq_len, cuda_graph_impl="none" + ) + return self.gen_components + + def _load_sae(self): + ckpt = torch.load(self.sae_ckpt_path, map_location="cpu", weights_only=False) + cfg = dict(ckpt["model_config"]) + state = ckpt["model_state_dict"] + if any(k.startswith("module.") for k in state): + state = {k.removeprefix("module."): v for k, v in state.items()} + train_cfg = ckpt.get("config") + arch = ( + str(train_cfg.get("architecture", train_cfg.get("arch", ""))).lower() + if isinstance(train_cfg, dict) + else "" + ) + hint = (arch + " " + self.sae_ckpt_path).lower() + from sae.architectures import ReLUSAE, TopKSAE + + cls = ReLUSAE if "relu" in hint else TopKSAE + sae = cls(**cfg) + sae.load_state_dict(state) + sae.eval().to(self.device) + logger.info("SAE loaded: %s input_dim=%d n_features=%d", cls.__name__, cfg["input_dim"], cfg["hidden_dim"]) + return sae, int(cfg["hidden_dim"]) + + def _load_feature_meta(self): + """feature_id -> (label, natural peak) from the annotation parquet/tsv/csv/json.""" + labels: dict[int, str] = {} + peaks: dict[int, float] = {} + if not self.feature_annotations: + return labels, peaks + path = Path(self.feature_annotations) + if not path.exists(): + logger.warning("Feature annotations %s not found — features unlabeled", path) + return labels, peaks + if path.suffix.lower() == ".parquet": + import pyarrow.parquet as pq + + tbl = pq.read_table(path).to_pydict() + ids = tbl.get("feature_id", []) + names = tbl.get("label", tbl.get("annotation", [None] * len(ids))) + pk = tbl.get("max_activation", [None] * len(ids)) + for i, n, p in zip(ids, names, pk): + if n is not None: + labels[int(i)] = str(n) + if p is not None: + peaks[int(i)] = float(p) + logger.info("Loaded %d labels from %s", len(labels), path) + return labels, peaks + + # ------------------------------------------------------------------ tokenize + def tokenize(self, text: str) -> list[int]: + """Tokenize text to token ids, truncated to max_seq_len.""" + tok = self.tokenizer + ids = tok.tokenize(text) if hasattr(tok, "tokenize") else tok.text_to_ids(text) + return ids[: self.max_seq_len] + + def resolve_tag(self, organism: str, tag: Optional[str]) -> Optional[str]: + """Explicit custom `tag` wins; else look up the organism preset.""" + if tag is not None: + return tag + return self.organism_tags.get(organism) + + # ------------------------------------------------------------------ encode + @torch.no_grad() + def encode(self, dna: str) -> torch.Tensor: + """ONE sequence -> SAE codes [seq_len, n_features] on CPU. No phylo tag.""" + ids = self.tokenize(dna) + if not ids: + return torch.empty(0, self.n_features) + with self._lock: + hidden = self._forward_hidden([ids])[0] # [S, H] + return self.sae.encode(hidden.to(self.device)).detach().cpu() + + @torch.no_grad() + def encode_batch(self, seqs: list[str], batch_size: int = 8) -> list[torch.Tensor]: + """MANY sequences -> list of SAE codes [S_i, n_features], batched on the GPU. + + Sequences are padded to the longest in each micro-batch; padding is masked + out before SAE-encoding so each result has the true per-base length. + """ + out: list[torch.Tensor] = [None] * len(seqs) # type: ignore + order = [(i, self.tokenize(s)) for i, s in enumerate(seqs)] + with self._lock: + for start in range(0, len(order), batch_size): + chunk = order[start : start + batch_size] + id_lists = [ids for _, ids in chunk] + hiddens = self._forward_hidden(id_lists) # list of [S_i, H] + for (orig_i, ids), h in zip(chunk, hiddens): + out[orig_i] = ( + self.sae.encode(h.to(self.device)).detach().cpu() + if h.shape[0] > 0 + else torch.empty(0, self.n_features) + ) + return out + + @torch.no_grad() + def _forward_hidden(self, id_lists: list[list[int]]) -> list[torch.Tensor]: + """Run the truncated model on a (padded) batch of token-id lists. + + Returns the unpadded layer-`layer` hidden states [S_i, H] per sequence. + """ + from bionemo.evo2.run import predict as P + + lens = [len(ids) for ids in id_lists] + maxlen = max(lens) if lens else 0 + if maxlen == 0: + return [torch.empty(0, 0) for _ in id_lists] + b = len(id_lists) + tokens = torch.zeros(b, maxlen, dtype=torch.long, device=self.device) + loss_mask = torch.zeros(b, maxlen, dtype=torch.long, device=self.device) + for i, ids in enumerate(id_lists): + if ids: + tokens[i, : len(ids)] = torch.tensor(ids, dtype=torch.long, device=self.device) + loss_mask[i, : len(ids)] = 1 + batch = { + "tokens": tokens, + "position_ids": torch.arange(maxlen, dtype=torch.long, device=self.device).unsqueeze(0).expand(b, -1), + "loss_mask": loss_mask, + "seq_idx": torch.arange(b, dtype=torch.long, device=self.device), + } + result = P._predict_step(model=self.model, batch=batch, output_embeddings=True) + hidden = result["hidden_embeddings"] # [B, S, H] + return [hidden[i, : lens[i]].float() for i in range(b)] + + def feature_tracks(self, dna: str, fids: list[int]) -> dict: + """Per-base activation of several features on `dna`. {fid: [..]} (encoded once).""" + if not dna: + return {int(f): [] for f in fids} + codes = self.encode(dna) + return {int(f): [round(float(v), 4) for v in codes[:, int(f)].tolist()] for f in fids} + + def top_features(self, codes: torch.Tensor, tag_len: int = 0, k: int = 8) -> list[dict]: + """Top-k features by per-base max activation over the DNA region (excluding the tag). + + `codes` is [S, n_features] from `encode`/`encode_batch`; `tag_len` skips the leading + phylo-tag tokens (ignored if it would drop the whole sequence). Returns the strictly + positive features as [{feature_id, label, max_activation}], used by the CLI and server. + """ + if codes.shape[0] == 0: + return [] + region = codes[tag_len:] if codes.shape[0] > tag_len else codes + per = region.max(dim=0).values + idx = per.topk(min(int(k), per.numel())).indices.tolist() + return [ + {"feature_id": int(i), "label": self.labels.get(int(i)), "max_activation": round(float(per[i]), 4)} + for i in idx + if per[i].item() > 0 + ] + + # ------------------------------------------------------------------ generate + def _clamp_hook(self, specs, pre_bias): + """Forward hook that clamps SAE features on the residual during DECODE steps only. + + A decode step processes a single new token (sequence dim == 1); the prompt prefill + (sequence dim > 1) is left untouched, giving continuation-only steering through + `infer.generate`: h <- h + Σ_f (t_f - a_f(h)) · d_f + `specs` = list of (enc_f [H], b_f float, dec_f [H], target float). + """ + + def hook(_module, _inp, output): + hs = output[0] if isinstance(output, tuple) else output # [S, B, H] + if hs.shape[0] != 1: # prefill (whole prompt) — leave untouched + return output + x = hs.float() + xc = x - pre_bias + add = torch.zeros_like(x) + for enc_f, b_f, dec_f, target in specs: + a = torch.relu(torch.matmul(xc, enc_f) + b_f) + add = add + (target - a).unsqueeze(-1) * dec_f + new = (x + add).to(hs.dtype) + return (new, *output[1:]) if isinstance(output, tuple) else new + + return hook + + def generate( + self, + prompt="", + organism="None (raw DNA)", + tag=None, + features=None, + n_tokens=120, + temperature=1.0, + top_k=0, + compare_baseline=False, + ) -> dict: + """Autoregressively generate DNA, optionally clamping features on the continuation. + + `features` = list of {"feature_id": int, "strength": float} (or []). Generation runs + through the recipe's inference engine (`infer.generate`, eager so the hook applies); + steering is a decode-only forward hook on layer `layer`. Returns + {generation:{sequence,activations}, baseline:..|None, features, steered}. + """ + from megatron.core.utils import unwrap_model + + from bionemo.evo2.run import infer as INF + + features = features or [] + resolved_tag = self.resolve_tag(organism, tag) + if resolved_tag is None: + raise ValueError(f"Unknown organism '{organism}' and no custom tag") + dna = clean_dna(prompt) + full_prompt = resolved_tag + dna + if not full_prompt: + raise ValueError("Provide a prompt or pick an organism (need >=1 token to seed)") + n_tokens = max(1, min(int(n_tokens), 400)) + fids = [int(f["feature_id"]) for f in features] + + with self._lock: + comp = self._ensure_engine() + hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] + pre_bias = self.sae.pre_bias.detach().float().to(self.device) + specs, feat_meta = [], [] + for f in features: + fid = int(f["feature_id"]) + specs.append( + ( + self.sae.encoder.weight[fid].detach().float().to(self.device), + float(self.sae.latent_bias[fid].detach()), + self.sae.decoder.weight[:, fid].detach().float().to(self.device), + float(f.get("strength", 1.0)), + ) + ) + feat_meta.append({"id": fid, "label": self.labels.get(fid), "strength": float(f.get("strength", 1.0))}) + + def _run(steer: bool) -> str: + handle = ( + hook_layer.register_forward_hook(self._clamp_hook(specs, pre_bias)) if (steer and specs) else None + ) + try: + out = INF.generate( + comp, [full_prompt], max_new_tokens=n_tokens, temperature=temperature, top_k=top_k + ) + return clean_dna(INF._unwrap_result(out[0]).generated_text) + finally: + if handle is not None: + handle.remove() + + main_dna = _run(steer=True) + base_dna = _run(steer=False) if (compare_baseline and specs) else None + + resp = { + "prompt": dna, + "organism": organism, + "tag": resolved_tag, + "tag_len": len(resolved_tag), + "n_tokens": n_tokens, + "features": feat_meta, + "steered": bool(specs), + "generation": {"sequence": main_dna, "activations": self.feature_tracks(main_dna, fids)}, + "baseline": None, + } + if base_dna is not None: + resp["baseline"] = {"sequence": base_dna, "activations": self.feature_tracks(base_dna, fids)} + return resp diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py new file mode 100644 index 0000000000..1834a10cd6 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FastAPI server over the Evo2SAE engine — the live backend the viz talks to. + +Endpoints: /health, /features, /annotate (per-base activations for a pasted +sequence), /generate (autoregressive generation + optional SAE-feature clamp). +This is a thin layer; all model work lives in `core.Evo2SAE`. +""" + +from __future__ import annotations + +import logging +from contextlib import asynccontextmanager +from typing import Optional + +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + +from .core import Evo2SAE, clean_dna + + +logger = logging.getLogger("evo2_sae_infer.server") + + +class AnnotateRequest(BaseModel): + """Request body for /annotate (top-k feature scan or an explicit feature pick).""" + + sequence: str + organism: str = "None (raw DNA)" + tag: Optional[str] = None + mode: str = "topk" # "topk" | "pick" + k: int = 8 + feature_ids: Optional[list[int]] = None + feature_id: Optional[int] = None + + +class FeatureClamp(BaseModel): + """A single SAE-feature steering clamp (feature id + target strength).""" + + feature_id: int + strength: float = 1.0 + + +class GenerateRequest(BaseModel): + """Request body for /generate (autoregressive generation + optional SAE-feature clamps).""" + + prompt: str = "" + organism: str = "None (raw DNA)" + tag: Optional[str] = None + features: list[FeatureClamp] = [] + n_tokens: int = 120 + temperature: float = 1.0 + top_k: int = 0 + compare_baseline: bool = False + + +def build_app(engine: Evo2SAE) -> FastAPI: + """Build the FastAPI app; the engine is loaded once in the lifespan handler.""" + + @asynccontextmanager + async def lifespan(app: FastAPI): + try: + engine.load() + logger.info("engine ready") + except Exception: + logger.exception("engine startup failed — /health stays not-ready") + yield + + app = FastAPI(title="Evo2 SAE inference", lifespan=lifespan) + app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) + + @app.get("/health") + def health(): + return { + "ready": bool(engine.ready), + "layer": engine.layer, + "n_features": engine.n_features, + "n_labels": len(engine.labels), + "sae_path": engine.sae_ckpt_path, + "organisms": list(engine.organism_tags.keys()), + "organism_tags": engine.organism_tags, + "device": engine.device, + } + + @app.get("/features") + def features(): + if not engine.ready: + raise HTTPException(503, "Backend not ready") + rows = [ + {"id": int(f), "label": lab, "natural_peak": engine.peaks.get(int(f))} for f, lab in engine.labels.items() + ] + rows.sort(key=lambda r: r["id"]) + return rows + + @app.post("/annotate") + def annotate(req: AnnotateRequest): + if not engine.ready: + raise HTTPException(503, "Backend not ready") + dna = clean_dna(req.sequence) + if not dna: + raise HTTPException(400, "No valid nucleotides in sequence") + tag = engine.resolve_tag(req.organism, req.tag) + if tag is None: + raise HTTPException(400, f"Unknown organism '{req.organism}' and no custom tag") + full = tag + dna + tag_len = len(tag) + codes = engine.encode(full) # [S, n_features], lock held inside + if codes.shape[0] < tag_len: + tag_len = 0 + if req.mode == "pick": + ids = req.feature_ids or ([req.feature_id] if req.feature_id is not None else []) + if not ids: + raise HTTPException(400, "mode='pick' requires feature_ids") + chosen = [int(i) for i in ids] + else: + k = max(1, min(int(req.k), 64)) + chosen = [ft["feature_id"] for ft in engine.top_features(codes, tag_len=tag_len, k=k)] + feats = [] + for fid in chosen: + col = codes[:, fid] + feats.append( + { + "feature_id": fid, + "label": engine.labels.get(fid), + "max_activation": float(col[tag_len:].max().item()) + if codes.shape[0] > tag_len + else float(col.max().item()), + "activations": [round(float(v), 4) for v in col.tolist()], + } + ) + return { + "sequence": dna, + "organism": req.organism, + "tag": tag, + "tag_len": tag_len, + "bases": list(full), + "n_tokens": codes.shape[0], + "layer": engine.layer, + "features": feats, + } + + @app.post("/generate") + def generate(req: GenerateRequest): + if not engine.ready: + raise HTTPException(503, "Backend not ready") + try: + return engine.generate( + prompt=req.prompt, + organism=req.organism, + tag=req.tag, + features=[f.model_dump() for f in req.features], + n_tokens=req.n_tokens, + temperature=req.temperature, + top_k=req.top_k, + compare_baseline=req.compare_baseline, + ) + except ValueError as e: + raise HTTPException(400, str(e)) + + return app diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py new file mode 100644 index 0000000000..a09a9553d4 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server contract tests — the API the feature-explorer viz consumes. + +A mocked engine (no model, CPU-only) drives the FastAPI app so these run in CI and lock the +response shapes + error codes the dashboard depends on: /health, /features, /annotate (per-base +activations), /generate. Real model inference is covered by test_steering.py. +""" + +import pytest +import torch +from evo2_sae.server import build_app +from fastapi.testclient import TestClient + + +class FakeEngine: + """Minimal stand-in for Evo2SAE exposing only what the server endpoints touch.""" + + def __init__(self): + self.ready = True + self.layer = 19 + self.n_features = 4 + self.labels = {0: "feat0", 1: "feat1"} + self.peaks = {0: 0.5} + self.organism_tags = {"None (raw DNA)": "", "Human": "|tag|"} + self.device = "cpu" + self.sae_ckpt_path = "fake.pt" + + def load(self): + self.ready = True + + def resolve_tag(self, organism, tag): + return tag if tag is not None else self.organism_tags.get(organism) + + def encode(self, full): + codes = torch.zeros(len(full), self.n_features) + codes[:, 0] = 1.0 # feature 0 fires everywhere + return codes + + def top_features(self, codes, tag_len=0, k=8): + return [{"feature_id": 0, "label": self.labels.get(0), "max_activation": 1.0}] + + def generate(self, **kw): + if not kw.get("prompt") and kw.get("organism") == "None (raw DNA)" and not kw.get("tag"): + raise ValueError("need a seed") + return { + "generation": {"sequence": "ACGT", "activations": {0: [1.0, 1.0, 1.0, 1.0]}}, + "baseline": None, + "features": [], + "steered": False, + } + + +@pytest.fixture +def client(): + with TestClient(build_app(FakeEngine())) as c: + yield c + + +def test_health(client): + b = client.get("/health").json() + assert b["ready"] is True and b["layer"] == 19 + assert "None (raw DNA)" in b["organisms"] + + +def test_features(client): + rows = client.get("/features").json() + assert {"id", "label", "natural_peak"} <= set(rows[0]) + + +def test_annotate_returns_per_base_activations(client): + b = client.post("/annotate", json={"sequence": "ACGTACGT", "organism": "None (raw DNA)"}).json() + assert {"sequence", "features", "bases", "tag_len", "layer", "n_tokens"} <= set(b) + assert b["features"][0]["activations"] # the per-base track the viz plots + + +def test_annotate_rejects_non_dna(client): + assert client.post("/annotate", json={"sequence": "ZZZZ"}).status_code == 400 + + +def test_generate_returns_sequence(client): + b = client.post("/generate", json={"prompt": "ACGT", "organism": "None (raw DNA)"}).json() + assert b["generation"]["sequence"] + + +def test_endpoints_503_until_ready(): + eng = FakeEngine() + eng.ready = False + eng.load = lambda: None # startup leaves it not-ready + with TestClient(build_app(eng)) as c: + assert c.get("/features").status_code == 503 + assert c.post("/annotate", json={"sequence": "ACGT"}).status_code == 503 diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py new file mode 100644 index 0000000000..0b93206364 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SAE feature steering during generation. + +`test_clamp_math` (CPU) checks the steering arithmetic the forward hook applies; it needs +no model and runs in CI. `test_steering_*` (GPU, slow, checkpoint-gated) drive the real +`generate()` through the recipe inference engine and assert the steering hook fires on the +continuation only and changes the output. Gated by EVO2_CKPT_DIR + SAE_CKPT_PATH. +""" + +import os + +import pytest +import torch + + +# --------------------------------------------------------------------- CPU: clamp math +def test_clamp_math(): + """The decode-only hook applies h <- h + Σ_f (target_f - relu((h-pre_bias)@enc_f + b_f))·dec_f.""" + from evo2_sae.core import Evo2SAE + + H, F = 4, 2 + torch.manual_seed(0) + pre_bias = torch.randn(H) + enc = torch.randn(F, H) + bias = torch.randn(F) + dec = torch.randn(H, F) + specs = [(enc[f], float(bias[f]), dec[:, f], float(f + 1)) for f in range(F)] + + eng = Evo2SAE.__new__(Evo2SAE) # bare instance — exercise the hook only + hook = eng._clamp_hook(specs, pre_bias) + + h = torch.randn(1, 1, H) # one decode token: [S=1, B=1, H] + out = hook(None, None, h) + + xc = h - pre_bias + expected = h.clone() + for enc_f, b_f, dec_f, target in specs: + a = torch.relu(xc @ enc_f + b_f) + expected = expected + (target - a).unsqueeze(-1) * dec_f + torch.testing.assert_close(out, expected) + + # prefill (S>1) must be left untouched (continuation-only steering) + prefill = torch.randn(5, 1, H) + assert torch.equal(hook(None, None, prefill), prefill) + + +# --------------------------------------------------------------------- GPU: real generation +_CKPT = os.environ.get("EVO2_CKPT_DIR") +_SAE = os.environ.get("SAE_CKPT_PATH") +_LAYER = int(os.environ.get("EMBEDDING_LAYER", "19")) +_PROMPT = "ACGTACGTACGTACGTACGT" + + +@pytest.fixture(scope="module") +def engine(): + """Load the Evo2 + SAE engine once (skips unless CUDA + checkpoints are available).""" + if not torch.cuda.is_available(): + pytest.skip("steering tests require CUDA") + if not (_CKPT and _SAE): + pytest.skip("set EVO2_CKPT_DIR and SAE_CKPT_PATH to run the steering tests") + from evo2_sae import Evo2SAE + + return Evo2SAE(evo2_ckpt_dir=_CKPT, sae_ckpt_path=_SAE, layer=_LAYER).load() + + +def _gen(engine, features): + torch.manual_seed(0) + return engine.generate( + prompt=_PROMPT, organism="None (raw DNA)", features=features, n_tokens=48, temperature=0.0, top_k=1 + ) + + +@pytest.mark.slow +def test_unsteered_is_dna(engine): + """Unsteered generation yields a non-empty ACGT string (Evo2 stays in-distribution).""" + seq = _gen(engine, [])["generation"]["sequence"] + assert seq and set(seq) <= set("ACGTN") + + +@pytest.mark.slow +def test_steering_changes_output(engine): + """A strong clamp changes the generated continuation; an empty clamp is a deterministic no-op.""" + base = _gen(engine, [])["generation"]["sequence"] + steered = _gen(engine, [{"feature_id": 0, "strength": 10.0}])["generation"]["sequence"] + assert steered != base + assert _gen(engine, [])["generation"]["sequence"] == base # determinism / no-op From d6158e510c8469647446044143f0593516a9a207 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 18:12:24 +0000 Subject: [PATCH 02/20] evo2 infer: default launch_inference.sh to the 7B/layer-26 model + its annotations Point the script's defaults at the production config: evo2_7b_mbridge, the layer-26 7B SAE (layer26_7B_ablate_normalize_input), EMBEDDING_LAYER=26, and the l26_7B_normalize feature annotations. Still overridable via env. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index 64f8d89f5d..00925c4773 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -15,10 +15,10 @@ HERE="$(cd "$(dirname "$0")" && pwd)" RECIPE_DIR="$(cd "$HERE/.." && pwd)" # recipes/evo2 — so the evo2_sae package imports VENV="${VENV:-/data/pbinder/bionemo-framework/bionemo-recipes/recipes/evo2_megatron/.venv}" -export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:-/data/interp/evo2/checkpoints/evo2_1b_base_mbridge}" -export SAE_CKPT_PATH="${SAE_CKPT_PATH:-/data/interp/evo2/sae/v2_diverse/layer19_C13_nofilter/checkpoints/checkpoint_final.pt}" -export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-/data/interp/evo2/sae_eval/dashboard_data/l19_C13_nofilter/feature_metadata.parquet}" -export EMBEDDING_LAYER="${EMBEDDING_LAYER:-19}" +export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:-/data/interp/evo2/checkpoints/evo2_7b_mbridge}" +export SAE_CKPT_PATH="${SAE_CKPT_PATH:-/data/interp/evo2/sae/v2_diverse/layer26_7B_ablate_normalize_input/checkpoints/checkpoint_final.pt}" +export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-/data/interp/evo2/sae_eval/dashboard_data/l26_7B_normalize/feature_metadata.parquet}" +export EMBEDDING_LAYER="${EMBEDDING_LAYER:-26}" if [[ ! -x "$VENV/bin/python" ]]; then echo "ERROR: evo2_megatron venv not found at $VENV (build it with the recipe's .ci_build.sh)" >&2 From 2212289f0bef9b0dc250460460e09ec4a32b066c Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 20:16:17 +0000 Subject: [PATCH 03/20] evo2 infer: wrap encode forward in bf16 autocast (TransformerEngine dtype fix) _forward_hidden calls predict._predict_step directly; on a current-main venv the stricter TransformerEngine asserts param/input dtypes match unless inside an autocast region ("Found input dtype: torch.float32 and 'layer_norm_weight' dtype: torch.bfloat16"). Wrap the call in torch.autocast(bf16). Verified: 7B/layer-26 encode now returns labeled features (e.g. motif_ATG fires on an ATG-led sequence). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 789bd9252f..ad3a5fb869 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -260,7 +260,11 @@ def _forward_hidden(self, id_lists: list[list[int]]) -> list[torch.Tensor]: "loss_mask": loss_mask, "seq_idx": torch.arange(b, dtype=torch.long, device=self.device), } - result = P._predict_step(model=self.model, batch=batch, output_embeddings=True) + # Evo2 runs bf16; TransformerEngine asserts param/input dtypes match unless inside an + # autocast region. predict()'s loop relies on this too, so wrap the direct _predict_step. + device_type = "cuda" if self.device.startswith("cuda") else "cpu" + with torch.autocast(device_type=device_type, dtype=torch.bfloat16): + result = P._predict_step(model=self.model, batch=batch, output_embeddings=True) hidden = result["hidden_embeddings"] # [B, S, H] return [hidden[i, : lens[i]].float() for i in range(b)] From 92b7c859dc6485a8f52c09c576cf5447b6367c63 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 20:49:18 +0000 Subject: [PATCH 04/20] evo2 infer: robust steering test (discovered active feature) + encode smoke MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old test_steering_changes_output clamped an arbitrary feature_id=0 (often dead) and asserted tokens != base — too weak: greedy decoding is unchanged when the clamped feature has no signal, so it failed even though steering works. Replace with test_steering_changes_continuation, which discovers the most-active feature on the prompt (SAE-agnostic), clamps it to 3x its peak, and asserts the continuation changes + empty-clamp determinism. Add test_encode_smoke to guard the bf16 encode forward (the TransformerEngine dtype path). Verified 4/4 on 7B/layer-26 (H100): clamp_math, encode_smoke, unsteered_is_dna, steering_changes_continuation. Diagnostic confirmed the hook fires once per decode token (48), prefill (seq dim 192) skipped, decode (seq dim 1) steered — continuation-only holds; a hard clamp collapses the output (e.g. -> repetitive C). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/tests/test_steering.py | 53 +++++++++++++------ 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py index 0b93206364..cebceb0f52 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for SAE feature steering during generation. - -`test_clamp_math` (CPU) checks the steering arithmetic the forward hook applies; it needs -no model and runs in CI. `test_steering_*` (GPU, slow, checkpoint-gated) drive the real -`generate()` through the recipe inference engine and assert the steering hook fires on the -continuation only and changes the output. Gated by EVO2_CKPT_DIR + SAE_CKPT_PATH. +"""GPU inference tests: SAE encode + feature steering. + +`test_clamp_math` (CPU) checks the steering arithmetic the forward hook applies; it needs no +model and runs in CI. The GPU tests (slow, checkpoint-gated) drive the real engine: +`test_encode_smoke` guards the bf16 encode forward, and `test_unsteered_is_dna` / +`test_steering_changes_continuation` drive `generate()` through the recipe inference engine +and assert steering (on a discovered active feature) changes the continuation only. Gated by +EVO2_CKPT_DIR + SAE_CKPT_PATH. """ import os @@ -62,7 +64,8 @@ def test_clamp_math(): _CKPT = os.environ.get("EVO2_CKPT_DIR") _SAE = os.environ.get("SAE_CKPT_PATH") _LAYER = int(os.environ.get("EMBEDDING_LAYER", "19")) -_PROMPT = "ACGTACGTACGTACGTACGT" +_PROMPT = "ATGGCCGAATTCGGCACGAGGACGTGCTGAAAGCTAGCTAGGCTAACCGGTTACGTGCAT" +_ORG = "Human" @pytest.fixture(scope="module") @@ -79,9 +82,23 @@ def engine(): def _gen(engine, features): torch.manual_seed(0) - return engine.generate( - prompt=_PROMPT, organism="None (raw DNA)", features=features, n_tokens=48, temperature=0.0, top_k=1 - ) + return engine.generate(prompt=_PROMPT, organism=_ORG, features=features, n_tokens=48, temperature=0.0, top_k=1) + + +def _tag(engine): + return engine.resolve_tag(_ORG, None) or "" + + +@pytest.mark.slow +def test_encode_smoke(engine): + """encode runs the truncated bf16 forward and returns finite per-feature codes (>=1 firing). + + Guards the TransformerEngine bf16/fp32 autocast path: a dtype mismatch would crash here. + """ + codes = engine.encode(_tag(engine) + _PROMPT) + assert codes.ndim == 2 and codes.shape[1] == engine.n_features + assert torch.isfinite(codes).all() + assert (codes > 0).any() @pytest.mark.slow @@ -92,9 +109,15 @@ def test_unsteered_is_dna(engine): @pytest.mark.slow -def test_steering_changes_output(engine): - """A strong clamp changes the generated continuation; an empty clamp is a deterministic no-op.""" +def test_steering_changes_continuation(engine): + """Clamping a KNOWN-ACTIVE feature hard changes the continuation; empty clamp is a no-op. + + Discovers the most-active feature on the prompt (SAE-agnostic) so the clamp has real signal — + an arbitrary/dead feature would leave greedy decoding unchanged and make this test useless. + """ + per = engine.encode(_tag(engine) + _PROMPT).max(dim=0).values + fid, peak = int(per.argmax()), float(per.max()) base = _gen(engine, [])["generation"]["sequence"] - steered = _gen(engine, [{"feature_id": 0, "strength": 10.0}])["generation"]["sequence"] - assert steered != base - assert _gen(engine, [])["generation"]["sequence"] == base # determinism / no-op + steered = _gen(engine, [{"feature_id": fid, "strength": max(peak * 3.0, 50.0)}])["generation"]["sequence"] + assert steered != base # the clamp on an active feature changed the continuation + assert _gen(engine, [])["generation"]["sequence"] == base # determinism + empty-clamp no-op From c5e4d783404c28d1678034aeef0ea951927dc6e7 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 22:06:02 +0000 Subject: [PATCH 05/20] evo2 SAE recipe: add recipe README (run guide) Recipe-level entry point tying together the offline extract/train (#1621), the live inference engine + server + CLI (this PR), and the dashboard (#1623): venv setup, 7B/L26 config, CLI encode/batch, serve, dashboard launch, and the CPU/GPU test commands. codonfm/esm2 have one; evo2 didn't. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/README.md | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md new file mode 100644 index 0000000000..ebefaf9043 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -0,0 +1,86 @@ +# Evo2 SAE recipe + +Sparse Autoencoders for the [Evo2](../../../../recipes/evo2_megatron) DNA language model: +offline activation extraction + SAE training, a **live inference engine** (encode / steered +generation) with an HTTP server, and a feature-explorer **dashboard**. + +``` +recipes/evo2/ +├── scripts/ +│ ├── extract.py · train.py · chunk_fasta.py # offline: activations -> SAE +│ ├── launch_inference.sh # live: serve / encode / batch +│ └── launch_dashboard.py # serve the dashboard on provided data +├── src/evo2_sae/ core.py (engine) · server.py · cli.py +├── feature_explorer/ React/Vite dashboard (4 panels) +└── tests/ +``` + +The recipe reuses the evo2_megatron recipe for the model (`predict.load_model_to_layer`, +`infer.generate`) and the shared `sae` package for the autoencoder; `src/evo2_sae/` is only the +SAE layer (encode, the decode-only feature-clamp hook, feature labels, the serve object). + +## 0. Environment + +The recipe runs inside the **evo2_megatron venv** (provides `bionemo.evo2` + megatron + TE): + +```bash +cd ../../../recipes/evo2_megatron && bash .ci_build.sh # builds ./.venv (~15–30 min) +export VENV=$PWD/.venv +``` + +> The venv must include `predict.load_model_to_layer` and `infer.setup_inference_engine` +> (current `main`). `launch_inference.sh` reads `$VENV`; point it at the venv you built. + +## 1. Config (defaults target 7B / layer 26) + +```bash +export EVO2_CKPT_DIR=/data/interp/evo2/checkpoints/evo2_7b_mbridge +export SAE_CKPT_PATH=.../sae/v2_diverse/layer26_7B_ablate_normalize_input/checkpoints/checkpoint_final.pt +export EMBEDDING_LAYER=26 +export FEATURE_ANNOTATIONS=.../sae_eval/dashboard_data/l26_7B_normalize/feature_metadata.parquet +export CUDA_VISIBLE_DEVICES=0 +``` + +## 2. CLI inference + +```bash +cd scripts +./launch_inference.sh encode --sequence ATGGCC...GTGCAT --organism "Human" --top-k 8 # one seq -> JSON +./launch_inference.sh batch --fasta in.fa --out out.parquet # FASTA -> parquet +``` + +## 3. Inference server (dashboard backend) + +```bash +./launch_inference.sh serve # FastAPI on :8001 — /health /features /annotate /generate +curl localhost:8001/health +``` + +## 4. Dashboard + +The dashboard reads atlas parquets **you provide** (it does not generate them): + +```bash +cd .. # recipes/evo2 +# DIR must hold features_atlas.parquet, feature_metadata.parquet, feature_examples.parquet +"$VENV/bin/python" scripts/launch_dashboard.py --data-dir /path/to/dashboard_data +``` + +The **Feature atlas** tab is static (served from those parquets); **Sequence inspector** and +**Generative steering** call the server from step 3. See `feature_explorer/README.md`. + +## 5. Tests + +```bash +# CPU (no model): +PYTHONPATH=src "$VENV/bin/python" -m pytest tests/test_server.py tests/test_launch_dashboard.py \ + tests/test_steering.py::test_clamp_math -q +# GPU (slow, gated by the step-1 env vars): encode + steering on the real model +PYTHONPATH=src "$VENV/bin/python" -m pytest tests/test_steering.py -q +``` + +## Notes + +- **Two-model design:** a truncated model (encode/inspect, loaded eagerly) + the full inference + engine (generate, lazy). Keeps inspect cheap; the engine loads on first `serve`/`generate`. +- Generating the dashboard atlas parquets is a separate offline step (not yet in-recipe). From d0450ffd0d7e5b46cc031f5d28105c97f4ea4bdf Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 22:13:33 +0000 Subject: [PATCH 06/20] evo2 README: clarify dashboard launch (--data-dir optional; atlas vs live tabs) Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/evo2/README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md index ebefaf9043..d8aed131ca 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -58,16 +58,18 @@ curl localhost:8001/health ## 4. Dashboard -The dashboard reads atlas parquets **you provide** (it does not generate them): - ```bash cd .. # recipes/evo2 -# DIR must hold features_atlas.parquet, feature_metadata.parquet, feature_examples.parquet +# inspector + steering tabs only (use the live server from step 3 — no atlas data needed): +"$VENV/bin/python" scripts/launch_dashboard.py +# + Feature-atlas tab: pass a dir holding features_atlas/feature_metadata/feature_examples.parquet: "$VENV/bin/python" scripts/launch_dashboard.py --data-dir /path/to/dashboard_data ``` -The **Feature atlas** tab is static (served from those parquets); **Sequence inspector** and -**Generative steering** call the server from step 3. See `feature_explorer/README.md`. +**Sequence inspector** and **Generative steering** call the server from step 3, so they work +with no atlas data. The **Feature atlas** tab needs the three parquets via `--data-dir`; +producing them from a corpus is a separate offline step (not yet in-recipe). See +`feature_explorer/README.md`. ## 5. Tests From de81106e0bc861e7235ba9a60f8a649d86613734 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 23:40:07 +0000 Subject: [PATCH 07/20] evo2: drop premature recipe README from the inference PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The recipe-level README documented the dashboard (#1623) and the atlas generator (deferred), neither of which is in this PR — premature here. The inference run instructions live in launch_inference.sh's header. A complete recipe README lands once the dashboard + generator exist. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/README.md | 88 ------------------- 1 file changed, 88 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md deleted file mode 100644 index d8aed131ca..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md +++ /dev/null @@ -1,88 +0,0 @@ -# Evo2 SAE recipe - -Sparse Autoencoders for the [Evo2](../../../../recipes/evo2_megatron) DNA language model: -offline activation extraction + SAE training, a **live inference engine** (encode / steered -generation) with an HTTP server, and a feature-explorer **dashboard**. - -``` -recipes/evo2/ -├── scripts/ -│ ├── extract.py · train.py · chunk_fasta.py # offline: activations -> SAE -│ ├── launch_inference.sh # live: serve / encode / batch -│ └── launch_dashboard.py # serve the dashboard on provided data -├── src/evo2_sae/ core.py (engine) · server.py · cli.py -├── feature_explorer/ React/Vite dashboard (4 panels) -└── tests/ -``` - -The recipe reuses the evo2_megatron recipe for the model (`predict.load_model_to_layer`, -`infer.generate`) and the shared `sae` package for the autoencoder; `src/evo2_sae/` is only the -SAE layer (encode, the decode-only feature-clamp hook, feature labels, the serve object). - -## 0. Environment - -The recipe runs inside the **evo2_megatron venv** (provides `bionemo.evo2` + megatron + TE): - -```bash -cd ../../../recipes/evo2_megatron && bash .ci_build.sh # builds ./.venv (~15–30 min) -export VENV=$PWD/.venv -``` - -> The venv must include `predict.load_model_to_layer` and `infer.setup_inference_engine` -> (current `main`). `launch_inference.sh` reads `$VENV`; point it at the venv you built. - -## 1. Config (defaults target 7B / layer 26) - -```bash -export EVO2_CKPT_DIR=/data/interp/evo2/checkpoints/evo2_7b_mbridge -export SAE_CKPT_PATH=.../sae/v2_diverse/layer26_7B_ablate_normalize_input/checkpoints/checkpoint_final.pt -export EMBEDDING_LAYER=26 -export FEATURE_ANNOTATIONS=.../sae_eval/dashboard_data/l26_7B_normalize/feature_metadata.parquet -export CUDA_VISIBLE_DEVICES=0 -``` - -## 2. CLI inference - -```bash -cd scripts -./launch_inference.sh encode --sequence ATGGCC...GTGCAT --organism "Human" --top-k 8 # one seq -> JSON -./launch_inference.sh batch --fasta in.fa --out out.parquet # FASTA -> parquet -``` - -## 3. Inference server (dashboard backend) - -```bash -./launch_inference.sh serve # FastAPI on :8001 — /health /features /annotate /generate -curl localhost:8001/health -``` - -## 4. Dashboard - -```bash -cd .. # recipes/evo2 -# inspector + steering tabs only (use the live server from step 3 — no atlas data needed): -"$VENV/bin/python" scripts/launch_dashboard.py -# + Feature-atlas tab: pass a dir holding features_atlas/feature_metadata/feature_examples.parquet: -"$VENV/bin/python" scripts/launch_dashboard.py --data-dir /path/to/dashboard_data -``` - -**Sequence inspector** and **Generative steering** call the server from step 3, so they work -with no atlas data. The **Feature atlas** tab needs the three parquets via `--data-dir`; -producing them from a corpus is a separate offline step (not yet in-recipe). See -`feature_explorer/README.md`. - -## 5. Tests - -```bash -# CPU (no model): -PYTHONPATH=src "$VENV/bin/python" -m pytest tests/test_server.py tests/test_launch_dashboard.py \ - tests/test_steering.py::test_clamp_math -q -# GPU (slow, gated by the step-1 env vars): encode + steering on the real model -PYTHONPATH=src "$VENV/bin/python" -m pytest tests/test_steering.py -q -``` - -## Notes - -- **Two-model design:** a truncated model (encode/inspect, loaded eagerly) + the full inference - engine (generate, lazy). Keeps inspect cheap; the engine loads on first `serve`/`generate`. -- Generating the dashboard atlas parquets is a separate offline step (not yet in-recipe). From 374233f534a999b5b07ca3260892301aa5cac82d Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 18:40:46 +0000 Subject: [PATCH 08/20] Address review: required env config, header edge case, mode validation - launch_inference.sh / cli.py: drop hardcoded dev paths; checkpoints now required via env or flag (EVO2_CKPT_DIR / SAE_CKPT_PATH), layer defaults to 26 - cli.py _read_fasta: handle headers with no token after '>' ("> ") instead of IndexError; add docstrings on _add_common / _engine / _read_fasta - server.py: CORS origins configurable via CORS_ORIGINS env (default "*"); reject invalid /annotate mode with 400 instead of silently treating it as topk - test_server.py: assert /generate returns 503 before the engine is ready Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 10 ++-- .../recipes/evo2/src/evo2_sae/cli.py | 55 ++++++++++++------- .../recipes/evo2/src/evo2_sae/server.py | 6 +- .../recipes/evo2/tests/test_server.py | 1 + 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index 00925c4773..8d4ee16ee9 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -14,10 +14,12 @@ set -euo pipefail HERE="$(cd "$(dirname "$0")" && pwd)" RECIPE_DIR="$(cd "$HERE/.." && pwd)" # recipes/evo2 — so the evo2_sae package imports -VENV="${VENV:-/data/pbinder/bionemo-framework/bionemo-recipes/recipes/evo2_megatron/.venv}" -export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:-/data/interp/evo2/checkpoints/evo2_7b_mbridge}" -export SAE_CKPT_PATH="${SAE_CKPT_PATH:-/data/interp/evo2/sae/v2_diverse/layer26_7B_ablate_normalize_input/checkpoints/checkpoint_final.pt}" -export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-/data/interp/evo2/sae_eval/dashboard_data/l26_7B_normalize/feature_metadata.parquet}" +# Required (no hardcoded defaults — supply your own paths via env): +VENV="${VENV:?Set VENV to the evo2_megatron recipe .venv (provides bionemo.evo2 + megatron)}" +export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:?Set EVO2_CKPT_DIR to an Evo2 MBridge checkpoint directory}" +export SAE_CKPT_PATH="${SAE_CKPT_PATH:?Set SAE_CKPT_PATH to a trained SAE checkpoint (.pt)}" +# Optional: feature-label parquet (empty = features are unlabeled). Layer defaults to 26. +export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-}" export EMBEDDING_LAYER="${EMBEDDING_LAYER:-26}" if [[ ! -x "$VENV/bin/python" ]]; then diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py index b4e7ce0acd..375ef8d386 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -32,29 +32,36 @@ def _add_common(p: argparse.ArgumentParser) -> None: - p.add_argument( - "--evo2-ckpt-dir", - default=os.environ.get("EVO2_CKPT_DIR", "/data/interp/evo2/checkpoints/evo2_1b_base_mbridge"), - ) - p.add_argument( - "--sae-ckpt-path", - default=os.environ.get( - "SAE_CKPT_PATH", "/data/interp/evo2/sae/v2_diverse/layer19_C13_nofilter/checkpoints/checkpoint_final.pt" - ), - ) - p.add_argument( - "--feature-annotations", - default=os.environ.get( - "FEATURE_ANNOTATIONS", - "/data/interp/evo2/sae_eval/dashboard_data/l19_C13_nofilter/feature_metadata.parquet", - ), - ) - p.add_argument("--layer", type=int, default=int(os.environ.get("EMBEDDING_LAYER", "19"))) + """Register the shared inference arguments (checkpoints, layer, device) on a parser. + + Defaults come from env vars (``EVO2_CKPT_DIR``, ``SAE_CKPT_PATH``, ``FEATURE_ANNOTATIONS``, + ``EMBEDDING_LAYER``, ``DEVICE``, ``MAX_SEQ_LEN``); pass the flags to override. No hardcoded + paths — the checkpoints must be supplied via flag or env. + + Args: + p: The argparse parser (or subparser) to add the shared arguments to. + + Returns: + None. Mutates ``p`` in place. + """ + p.add_argument("--evo2-ckpt-dir", default=os.environ.get("EVO2_CKPT_DIR")) + p.add_argument("--sae-ckpt-path", default=os.environ.get("SAE_CKPT_PATH")) + p.add_argument("--feature-annotations", default=os.environ.get("FEATURE_ANNOTATIONS")) + p.add_argument("--layer", type=int, default=int(os.environ.get("EMBEDDING_LAYER", "26"))) p.add_argument("--device", default=os.environ.get("DEVICE", "cuda")) p.add_argument("--max-seq-len", type=int, default=int(os.environ.get("MAX_SEQ_LEN", "8192"))) def _engine(args): + """Construct an Evo2SAE engine from parsed CLI args. + + Args: + args: Parsed argparse namespace with ``evo2_ckpt_dir``, ``sae_ckpt_path``, ``layer``, + ``device``, ``max_seq_len``, ``feature_annotations``. + + Returns: + An (unloaded) ``Evo2SAE`` instance — call ``.load()`` before use. + """ from .core import Evo2SAE return Evo2SAE( @@ -68,6 +75,15 @@ def _engine(args): def _read_fasta(path: str): + """Read a FASTA file (plain or gzipped) into parallel id/sequence lists. + + Args: + path: Path to a FASTA file; a ``.gz`` suffix is read transparently. + + Returns: + (ids, seqs): the header names and their concatenated sequences. A header with no + token after ``>`` (e.g. ``">"`` or ``"> "``) gets a generated ``seq_`` id. + """ seqs, ids = [], [] name, parts = None, [] opener = gzip.open if str(path).endswith(".gz") else open @@ -78,7 +94,8 @@ def _read_fasta(path: str): if name is not None: seqs.append("".join(parts)) ids.append(name) - name, parts = line[1:].split()[0] if len(line) > 1 else f"seq_{len(ids)}", [] + header = line[1:].strip().split() + name, parts = (header[0] if header else f"seq_{len(ids)}"), [] else: parts.append(line) if name is not None: diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py index 1834a10cd6..bb7e3b391d 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py @@ -23,6 +23,7 @@ from __future__ import annotations import logging +import os from contextlib import asynccontextmanager from typing import Optional @@ -81,7 +82,8 @@ async def lifespan(app: FastAPI): yield app = FastAPI(title="Evo2 SAE inference", lifespan=lifespan) - app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) + allowed_origins = os.getenv("CORS_ORIGINS", "*").split(",") # comma-separated; "*" by default (local backend) + app.add_middleware(CORSMiddleware, allow_origins=allowed_origins, allow_methods=["*"], allow_headers=["*"]) @app.get("/health") def health(): @@ -121,6 +123,8 @@ def annotate(req: AnnotateRequest): codes = engine.encode(full) # [S, n_features], lock held inside if codes.shape[0] < tag_len: tag_len = 0 + if req.mode not in ("pick", "topk"): + raise HTTPException(400, f"Invalid mode {req.mode!r}: must be 'pick' or 'topk'") if req.mode == "pick": ids = req.feature_ids or ([req.feature_id] if req.feature_id is not None else []) if not ids: diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py index a09a9553d4..82b5b0726b 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py @@ -103,3 +103,4 @@ def test_endpoints_503_until_ready(): with TestClient(build_app(eng)) as c: assert c.get("/features").status_code == 503 assert c.post("/annotate", json={"sequence": "ACGT"}).status_code == 503 + assert c.post("/generate", json={"prompt": "ACGT", "organism": "None (raw DNA)"}).status_code == 503 From bb38064f4fa462b2427ce01a5064fe3758edc0cf Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 18:43:44 +0000 Subject: [PATCH 09/20] launch_inference.sh: correct usage header (EVO2_CKPT_DIR/SAE_CKPT_PATH required, no defaults) Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index 8d4ee16ee9..266bbc0669 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -5,8 +5,8 @@ # ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features # ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet # -# Config via env (sensible defaults below): EVO2_CKPT_DIR, SAE_CKPT_PATH, -# FEATURE_ANNOTATIONS, EMBEDDING_LAYER, DEVICE, PORT, CUDA_VISIBLE_DEVICES. +# Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults): +# FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES. # # Requires the evo2_megatron recipe venv (provides bionemo.evo2 + megatron). set -euo pipefail From 4a0de592fc7ab0f0ad70081b9db848a5d628405e Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 18:54:47 +0000 Subject: [PATCH 10/20] Dedupe FASTA parsing into shared evo2_sae.fasta.read_fasta MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cli.py (_read_fasta) and scripts/chunk_fasta.py (parse_fasta) had two copies of the same header/.gz/concat logic; the chunk_fasta copy still IndexError'd on a token-less header ('>' / '> '). Replace both with one streaming read_fasta() in a new stdlib-only evo2_sae/fasta.py (the fixed edge case included). Make evo2_sae/__init__.py lazy (PEP 562 __getattr__) so importing the package — and the lightweight fasta submodule from the standalone chunk_fasta script — no longer pulls in torch via .core; 'from evo2_sae import Evo2SAE' still works. Add tests/test_fasta.py (CPU): multiline, token-less header, .gz, empty file. Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/chunk_fasta.py | 21 +------ .../recipes/evo2/src/evo2_sae/__init__.py | 19 ++++++- .../recipes/evo2/src/evo2_sae/cli.py | 38 ++----------- .../recipes/evo2/src/evo2_sae/fasta.py | 56 +++++++++++++++++++ .../recipes/evo2/tests/test_fasta.py | 49 ++++++++++++++++ 5 files changed, 131 insertions(+), 52 deletions(-) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py index 797034da7b..05d4333408 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py @@ -24,26 +24,9 @@ """ import argparse -import gzip from pathlib import Path - -def parse_fasta(path: Path): - """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz).""" - opener = gzip.open if path.suffix == ".gz" else open - seq_id, parts = None, [] - with opener(path, "rt") as f: - for line in f: - line = line.rstrip() - if line.startswith(">"): - if seq_id is not None: - yield seq_id, "".join(parts) - seq_id = line[1:].split()[0] - parts = [] - else: - parts.append(line) - if seq_id is not None: - yield seq_id, "".join(parts) +from evo2_sae.fasta import read_fasta def main(): @@ -61,7 +44,7 @@ def main(): n_in = n_out = bp_out = 0 args.output.parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as out: - for seq_id, seq in parse_fasta(args.input): + for seq_id, seq in read_fasta(args.input): n_in += 1 for start in range(0, len(seq), args.window): end = min(start + args.window, len(seq)) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py index c79afb9063..ac7255d908 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py @@ -15,7 +15,24 @@ """Evo2 + SAE inference engine — reused by the live server, the batch CLI, and the viz backend.""" -from .core import DEFAULT_ORGANISM_TAGS, Evo2SAE, clean_dna +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: # for type checkers / ruff F822 — runtime access goes through __getattr__ below + from .core import DEFAULT_ORGANISM_TAGS, Evo2SAE, clean_dna __all__ = ["DEFAULT_ORGANISM_TAGS", "Evo2SAE", "clean_dna"] + + +def __getattr__(name: str): + """Lazily pull the heavy engine symbols from ``.core`` (importing ``.core`` loads torch). + + Keeps ``import evo2_sae`` (and lightweight submodules like ``evo2_sae.fasta``) cheap so + stdlib-only callers don't drag in torch, while ``from evo2_sae import Evo2SAE`` still works. + """ + if name in __all__: + from . import core + + return getattr(core, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py index 375ef8d386..b68a06b22d 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -26,7 +26,6 @@ from __future__ import annotations import argparse -import gzip import json import os @@ -74,36 +73,6 @@ def _engine(args): ) -def _read_fasta(path: str): - """Read a FASTA file (plain or gzipped) into parallel id/sequence lists. - - Args: - path: Path to a FASTA file; a ``.gz`` suffix is read transparently. - - Returns: - (ids, seqs): the header names and their concatenated sequences. A header with no - token after ``>`` (e.g. ``">"`` or ``"> "``) gets a generated ``seq_`` id. - """ - seqs, ids = [], [] - name, parts = None, [] - opener = gzip.open if str(path).endswith(".gz") else open - with opener(path, "rt") as f: - for line in f: - line = line.rstrip() - if line.startswith(">"): - if name is not None: - seqs.append("".join(parts)) - ids.append(name) - header = line[1:].strip().split() - name, parts = (header[0] if header else f"seq_{len(ids)}"), [] - else: - parts.append(line) - if name is not None: - seqs.append("".join(parts)) - ids.append(name) - return ids, seqs - - def main(): """Parse args and dispatch to the serve / encode / batch subcommand.""" ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch)") @@ -156,7 +125,12 @@ def main(): elif args.cmd == "batch": import pandas as pd - ids, seqs = _read_fasta(args.fasta) + from .fasta import read_fasta + + ids, seqs = [], [] + for sid, seq in read_fasta(args.fasta): + ids.append(sid) + seqs.append(seq) print(f"[batch] {len(seqs)} sequences from {args.fasta}; encoding (batch_size={args.batch_size})…") codes_list = eng.encode_batch(seqs, batch_size=args.batch_size) rows = [] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py new file mode 100644 index 0000000000..46db93066c --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared FASTA reader for the evo2 SAE recipe (stdlib-only; no torch import). + +One streaming parser reused by the batch CLI (`cli.py`) and the FASTA chunker +(`scripts/chunk_fasta.py`) so the header/`.gz`/concat logic lives in one place. +""" + +from __future__ import annotations + +import gzip +from collections.abc import Iterator +from pathlib import Path + + +def read_fasta(path: str | Path) -> Iterator[tuple[str, str]]: + """Yield ``(seq_id, sequence)`` for each record in a FASTA file. + + Args: + path: Path to a FASTA file; a ``.gz`` suffix is decompressed transparently. + + Yields: + ``(seq_id, sequence)``: the first whitespace-delimited token of the header, + or a generated ``seq_`` when the header carries no token (e.g. ``">"`` or + ``"> "``), paired with the record's concatenated sequence lines. + """ + opener = gzip.open if str(path).endswith(".gz") else open + seq_id: str | None = None + parts: list[str] = [] + n = 0 # records yielded so far — used to name token-less headers + with opener(path, "rt") as f: + for line in f: + line = line.rstrip() + if line.startswith(">"): + if seq_id is not None: + yield seq_id, "".join(parts) + n += 1 + header = line[1:].strip().split() + seq_id, parts = (header[0] if header else f"seq_{n}"), [] + else: + parts.append(line) + if seq_id is not None: + yield seq_id, "".join(parts) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py new file mode 100644 index 0000000000..a7700e49ab --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for the shared FASTA reader (no torch / no GPU).""" + +import gzip + +from evo2_sae.fasta import read_fasta + + +def test_basic_multiline_and_header_token(tmp_path): + """Header keeps only its first token; sequence lines are concatenated.""" + fa = tmp_path / "x.fa" + fa.write_text(">chr1 some description\nACGT\nACGT\n>chr2\nTTTT\n") + assert list(read_fasta(fa)) == [("chr1", "ACGTACGT"), ("chr2", "TTTT")] + + +def test_tokenless_header_gets_generated_id(tmp_path): + """A bare ``>`` / ``"> "`` header must not IndexError — it gets a ``seq_`` id.""" + fa = tmp_path / "x.fa" + fa.write_text(">good\nAAAA\n> \nCCCC\n>\nGGGG\n") + assert list(read_fasta(fa)) == [("good", "AAAA"), ("seq_1", "CCCC"), ("seq_2", "GGGG")] + + +def test_gzip_transparent(tmp_path): + """A ``.gz`` path is decompressed transparently.""" + fa = tmp_path / "x.fa.gz" + with gzip.open(fa, "wt") as f: + f.write(">a\nACGT\n") + assert list(read_fasta(fa)) == [("a", "ACGT")] + + +def test_empty_file(tmp_path): + """An empty file yields nothing (no trailing phantom record).""" + fa = tmp_path / "empty.fa" + fa.write_text("") + assert list(read_fasta(fa)) == [] From 13ff76d01a61cdca364ca00b7c899309a294edb6 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 20:20:08 +0000 Subject: [PATCH 11/20] evo2 infer: drop test_empty_file from test_fasta Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/evo2/tests/test_fasta.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py index a7700e49ab..fb72cbe130 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py @@ -40,10 +40,3 @@ def test_gzip_transparent(tmp_path): with gzip.open(fa, "wt") as f: f.write(">a\nACGT\n") assert list(read_fasta(fa)) == [("a", "ACGT")] - - -def test_empty_file(tmp_path): - """An empty file yields nothing (no trailing phantom record).""" - fa = tmp_path / "empty.fa" - fa.write_text("") - assert list(read_fasta(fa)) == [] From 46dae0d16d82440de76f2c2c5025a605ee3411f3 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 22:11:01 +0000 Subject: [PATCH 12/20] =?UTF-8?q?evo2=20serve:=20clamp=20with=20sae.steeri?= =?UTF-8?q?ng=20(B)=20from=20the=20start=20=E2=80=94=20no=20throwaway=20ho?= =?UTF-8?q?ok?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The engine's steered generate() now uses the faithful sae.steering delta-clamp (encode->topk->decode) directly, instead of an in-core per-feature linear inject that a downstream PR then deleted. Adds sae/steering.py (shared primitive, with decode_only for continuation-only steering) + its CPU tests; generate() registers clamp_hook(self.sae, clamps, decode_only=True). One clamp, written once, in the base. Signed-off-by: Polina Binder --- .../recipes/evo2/src/evo2_sae/core.py | 49 ++--------- .../sae/src/sae/steering.py | 83 +++++++++++++++++++ .../sae/tests/test_steering.py | 76 +++++++++++++++++ 3 files changed, 168 insertions(+), 40 deletions(-) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index ad3a5fb869..444748b591 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -294,30 +294,6 @@ def top_features(self, codes: torch.Tensor, tag_len: int = 0, k: int = 8) -> lis ] # ------------------------------------------------------------------ generate - def _clamp_hook(self, specs, pre_bias): - """Forward hook that clamps SAE features on the residual during DECODE steps only. - - A decode step processes a single new token (sequence dim == 1); the prompt prefill - (sequence dim > 1) is left untouched, giving continuation-only steering through - `infer.generate`: h <- h + Σ_f (t_f - a_f(h)) · d_f - `specs` = list of (enc_f [H], b_f float, dec_f [H], target float). - """ - - def hook(_module, _inp, output): - hs = output[0] if isinstance(output, tuple) else output # [S, B, H] - if hs.shape[0] != 1: # prefill (whole prompt) — leave untouched - return output - x = hs.float() - xc = x - pre_bias - add = torch.zeros_like(x) - for enc_f, b_f, dec_f, target in specs: - a = torch.relu(torch.matmul(xc, enc_f) + b_f) - add = add + (target - a).unsqueeze(-1) * dec_f - new = (x + add).to(hs.dtype) - return (new, *output[1:]) if isinstance(output, tuple) else new - - return hook - def generate( self, prompt="", @@ -354,23 +330,16 @@ def generate( with self._lock: comp = self._ensure_engine() hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] - pre_bias = self.sae.pre_bias.detach().float().to(self.device) - specs, feat_meta = [], [] - for f in features: - fid = int(f["feature_id"]) - specs.append( - ( - self.sae.encoder.weight[fid].detach().float().to(self.device), - float(self.sae.latent_bias[fid].detach()), - self.sae.decoder.weight[:, fid].detach().float().to(self.device), - float(f.get("strength", 1.0)), - ) - ) - feat_meta.append({"id": fid, "label": self.labels.get(fid), "strength": float(f.get("strength", 1.0))}) + from sae.steering import clamp_hook + + clamps = {int(f["feature_id"]): float(f.get("strength", 1.0)) for f in features} + feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] def _run(steer: bool) -> str: handle = ( - hook_layer.register_forward_hook(self._clamp_hook(specs, pre_bias)) if (steer and specs) else None + hook_layer.register_forward_hook(clamp_hook(self.sae, clamps, decode_only=True)) + if (steer and clamps) + else None ) try: out = INF.generate( @@ -382,7 +351,7 @@ def _run(steer: bool) -> str: handle.remove() main_dna = _run(steer=True) - base_dna = _run(steer=False) if (compare_baseline and specs) else None + base_dna = _run(steer=False) if (compare_baseline and clamps) else None resp = { "prompt": dna, @@ -391,7 +360,7 @@ def _run(steer: bool) -> str: "tag_len": len(resolved_tag), "n_tokens": n_tokens, "features": feat_meta, - "steered": bool(specs), + "steered": bool(clamps), "generation": {"sequence": main_dna, "activations": self.feature_tracks(main_dna, fids)}, "baseline": None, } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py new file mode 100644 index 0000000000..d390c9066b --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta. + +A forward hook on the layer the SAE was trained on: it re-encodes the layer output through +the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the +activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the +SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves +the activation. Model-agnostic: needs only the SAE (``encode_pre_act`` / ``decode`` / ``top_k``) +and the module to hook. Measure the effect (e.g. ΔP of a target token) by running the model +with vs. without the hook. +""" + +from contextlib import contextmanager +from typing import Dict + +import torch + + +def clamp_hook(sae, clamps: Dict[int, float], decode_only: bool = False): + """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method. + + The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's + output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative + value reverses its decoder direction. Works whether the module returns a tensor or a tuple + whose first element is the hidden state. + + Args: + sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``, + and ``top_k``. + clamps: Map of feature index -> absolute code value to force at every position. + decode_only: If True, steer only autoregressive *decode* steps and leave the prompt + prefill untouched (continuation-only steering). Assumes a ``(sequence, batch, hidden)`` + layout — the convention for Evo2/megatron decoder layers — and applies the clamp only + when the sequence dimension is 1 (a single new token). + + Returns: + A ``register_forward_hook``-compatible ``hook(module, inputs, output)``. + """ + items = [(int(f), float(v)) for f, v in clamps.items()] + + def hook(module, inputs, output): + h, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None) + if decode_only and h.shape[0] != 1: # prefill (seq dim > 1) — leave untouched + return output + dtype, shape = h.dtype, h.shape + h_flat = h.reshape(-1, h.shape[-1]).float() + with torch.no_grad(): + pre_act, info = sae.encode_pre_act(h_flat) + codes = torch.relu(pre_act) + kvals, kidx = torch.topk(codes, sae.top_k, dim=-1) + codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals) + codes_clamped = codes_orig.clone() + for f, v in items: + codes_clamped[:, f] = v + delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info) + h_out = (h_flat + delta).to(dtype).reshape(shape) + return (h_out, *rest) if rest is not None else h_out + + return hook + + +@contextmanager +def steer(module, sae, clamps: Dict[int, float], decode_only: bool = False): + """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.""" + handle = module.register_forward_hook(clamp_hook(sae, clamps, decode_only=decode_only)) + try: + yield + finally: + handle.remove() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py new file mode 100644 index 0000000000..0b28517c9d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for sae.steering: the delta-clamp adds exactly decode(clamped) - decode(orig).""" + +import torch +from sae.architectures import TopKSAE +from sae.steering import clamp_hook, steer +from torch import nn + + +def _sae(): + torch.manual_seed(0) + return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False) + + +def test_delta_clamp_is_exact_and_cancels_recon(): + """No-op clamp leaves the activation unchanged (recon error cancels); a real clamp shifts + it by exactly decode(clamped) - decode(orig) — the two halves of the delta-clamp contract.""" + sae, m, x = _sae(), nn.Identity(), torch.randn(5, 8) + + # No-op: decode(orig) != x, but the added delta is 0, so the output is unchanged. + with steer(m, sae, {}): + assert torch.allclose(m(x), x, atol=1e-5) + + # Real clamp: output == x + (decode(clamped) - decode(orig)), recon error cancelled. + with torch.no_grad(): + pre, info = sae.encode_pre_act(x.float()) + codes = torch.relu(pre) + kv, ki = torch.topk(codes, sae.top_k, dim=-1) + co = torch.zeros_like(codes).scatter(-1, ki, kv) + cc = co.clone() + cc[:, 3] = 5.0 + expected = x + (sae.decode(cc, info) - sae.decode(co, info)) + with steer(m, sae, {3: 5.0}): + assert torch.allclose(m(x), expected, atol=1e-4) + + +def test_tuple_output_steers_only_hidden_state(): + """When the hooked module returns a tuple, only element 0 is steered; the rest passes through.""" + + class M(nn.Module): + def forward(self, x): + return (x, "meta") + + sae, x = _sae(), torch.randn(3, 8) + m = M() + handle = m.register_forward_hook(clamp_hook(sae, {0: 2.0})) + out = m(x) + handle.remove() + assert isinstance(out, tuple) and out[1] == "meta" + assert out[0].shape == x.shape and not torch.allclose(out[0], x) # clamp moved it + + +def test_decode_only_skips_prefill(): + """decode_only steers single-token decode steps ([1,B,H]) but leaves multi-token prefill alone.""" + sae, m = _sae(), nn.Identity() + prefill = torch.randn(5, 2, 8) # [S=5, B, H] — prompt prefill, must pass through + decode = torch.randn(1, 2, 8) # [S=1, B, H] — a single new token, must be steered + handle = m.register_forward_hook(clamp_hook(sae, {3: 5.0}, decode_only=True)) + out_prefill, out_decode = m(prefill), m(decode) + handle.remove() + assert torch.allclose(out_prefill, prefill, atol=1e-5) # prefill untouched + assert not torch.allclose(out_decode, decode) # decode step steered From 89168b46df4a91ea8127070f222fab3580af06fd Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 22:29:54 +0000 Subject: [PATCH 13/20] evo2 serve: fold in the rest of steering (generate CLI + steer.py harness) #1622 now holds the full steering surface end-to-end: engine + generate (clamp B via sae.steering) + /generate + the generate CLI (--clamp) + steer.py dose-response harness. One PR for all of inspect + generate + steer; supersedes #1634. Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 6 +- .../recipes/evo2/scripts/steer.py | 121 ++++++++++++++++++ .../recipes/evo2/src/evo2_sae/cli.py | 57 ++++++++- .../recipes/evo2/tests/test_cli.py | 33 +++++ 4 files changed, 213 insertions(+), 4 deletions(-) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index 266bbc0669..26768a4c46 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -1,9 +1,13 @@ #!/bin/bash -# Launch the Evo2 SAE inference engine. One engine, three modes: +# Launch the Evo2 SAE inference engine. One engine, four modes: # # ./launch_inference.sh serve # live HTTP server on :8001 (viz backend) # ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features # ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet +# ./launch_inference.sh generate --prompt ATGC... --clamp 29244:300 # steer + generate DNA +# +# Steering loop: `encode` a sequence to find an active feature id, then +# `generate --clamp ID:STRENGTH` (strength ~2-3x the feature's max_activation; repeat --clamp). # # Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults): # FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES. diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py new file mode 100644 index 0000000000..8d54dee633 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation. + +Uses ``sae.steering.clamp_hook`` (the shared delta-clamp) registered on the Evo2 decoder layer +the SAE was trained on. Workflow: encode a sequence to find its active features, then for a +**target** feature sweep the clamp strength (dose-response) and for **control** features apply +the same clamp (selectivity), each time comparing the steered continuation to the baseline. + +GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test. + + python steer.py --evo2-ckpt-dir --sae-checkpoint --layer 26 \ + --sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200 + +Note: ``sae.steering.clamp_hook`` clamps on *every* forward (prefill + decode), so it steers +the prompt as well as the continuation. The decode-only ("continuation-only") variant lives in +``evo2_sae.core.Evo2SAE._clamp_hook``; unifying the two onto ``sae.steering`` (with a +``decode_only`` flag) is a planned follow-up. +""" + +from __future__ import annotations + +import argparse +import sys +from contextlib import nullcontext +from pathlib import Path + + +_HERE = Path(__file__).resolve().parent +sys.path.insert(0, str(_HERE)) +sys.path.insert(0, str(_HERE.parent / "src")) # recipes/evo2/src -> evo2_sae package +sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) + +from sae.steering import steer # noqa: E402 + + +def _divergence(a: str, b: str): + """Return (first differing index, fraction of differing chars) over the shared prefix length.""" + n = min(len(a), len(b)) + first = next((i for i in range(n) if a[i] != b[i]), n) + diff = sum(1 for i in range(n) if a[i] != b[i]) / max(1, n) + return first, diff + + +def main(): + """Encode a sequence, then steer a target feature (dose-response) + control features (selectivity).""" + p = argparse.ArgumentParser(description="Evo2 SAE steering harness (clamp -> continuation effect).") + p.add_argument("--evo2-ckpt-dir", required=True) + p.add_argument("--sae-checkpoint", required=True) + p.add_argument("--layer", type=int, required=True) + p.add_argument("--sequence", required=True) + p.add_argument("--organism", default="None (raw DNA)") + p.add_argument("--feature", type=int, default=None, help="Target feature id (default: top labeled feature).") + p.add_argument("--controls", default="", help="Comma-separated control feature ids (selectivity).") + p.add_argument("--strengths", default="0,50,100,200", help="Comma-separated clamp strengths to sweep.") + p.add_argument("--n-tokens", type=int, default=60) + p.add_argument("--device", default="cuda") + a = p.parse_args() + + from bionemo.evo2.run import infer as INF # noqa: E402, I001, RUF100 + from evo2_sae.core import Evo2SAE, clean_dna # noqa: E402, RUF100 + from megatron.core.utils import unwrap_model # noqa: E402, RUF100 + + eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() + + # 1. Encode -> the sequence's most-active features (pick a target if not given). + codes = eng.encode(a.sequence) + vals, ids = codes.max(0).values.topk(10) + print(f"top features on {a.sequence[:24]}...:") + target = a.feature + for v, i in zip(vals.tolist(), ids.tolist()): + lab = eng.labels.get(int(i)) + print(f" feat {int(i):6d} {str(lab):18s} max_act {v:7.2f}") + if target is None and lab: + target = int(i) + controls = [int(c) for c in a.controls.split(",") if c.strip()] + strengths = [float(s) for s in a.strengths.split(",")] + + # 2. The Evo2 decoder layer the SAE hooks + a clean (tag + DNA) prompt. + comp = eng._ensure_engine() + prompt = (eng.resolve_tag(a.organism, None) or "") + clean_dna(a.sequence) + layer_mod = unwrap_model(comp.model).decoder.layers[a.layer] + + def gen(clamps): + ctx = steer(layer_mod, eng.sae, clamps) if clamps else nullcontext() + with ctx: + out = INF.generate(comp, [prompt], max_new_tokens=a.n_tokens, temperature=0.0, top_k=1) + return clean_dna(INF._unwrap_result(out[0]).generated_text) + + base = gen({}) + print(f"\nbaseline: {base[:60]}") + print(f"\n=== dose-response: feature {target} ({eng.labels.get(target)}) ===") + for s in strengths: + steered = gen({target: s}) + first, diff = _divergence(base, steered) + print(f" strength {s:7.1f}: diverges@{first:3d} {diff:6.1%} changed {steered[:44]}") + + if controls: + s = strengths[-1] + print(f"\n=== selectivity: control features clamped to {s} ===") + for c in controls: + steered = gen({c: s}) + first, diff = _divergence(base, steered) + print(f" control {c:6d} ({str(eng.labels.get(c)):16s}): diverges@{first:3d} {diff:6.1%} changed") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py index b68a06b22d..98185dd51a 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Evo2 SAE inference CLI — one engine, three modes. +"""Evo2 SAE inference CLI — one engine, four modes. serve : start the FastAPI server (one sequence at a time, interactive) encode : annotate ONE sequence -> top features (stdout JSON) batch : run a FASTA of MANY sequences -> parquet of per-sequence top features + generate: generate DNA, optionally steering SAE features (stdout JSON) -All three build the same `Evo2SAE` engine; config comes from flags or env +They all build the same `Evo2SAE` engine; config comes from flags or env (EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER). """ @@ -73,9 +74,21 @@ def _engine(args): ) +def _parse_clamps(clamps: list[str]) -> list[dict]: + """Parse repeated ``--clamp FEATURE_ID[:STRENGTH]`` args into [{feature_id, strength}]. + + Strength defaults to 1.0 if omitted (e.g. ``--clamp 29244:300`` or ``--clamp 29244``). + """ + specs = [] + for c in clamps: + fid, sep, strength = c.partition(":") + specs.append({"feature_id": int(fid), "strength": float(strength) if (sep and strength) else 1.0}) + return specs + + def main(): """Parse args and dispatch to the serve / encode / batch subcommand.""" - ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch)") + ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch | generate)") sub = ap.add_subparsers(dest="cmd", required=True) ps = sub.add_parser("serve", help="start the FastAPI inference server") @@ -96,6 +109,23 @@ def main(): pb.add_argument("--top-k", type=int, default=16) pb.add_argument("--batch-size", type=int, default=8) + pg = sub.add_parser("generate", help="generate DNA, optionally steering SAE features") + _add_common(pg) + pg.add_argument("--prompt", default="", help="DNA to seed; steering applies to the continuation") + pg.add_argument("--organism", default="None (raw DNA)") + pg.add_argument( + "--clamp", + action="append", + default=[], + metavar="FEATURE_ID[:STRENGTH]", + help="clamp a feature on the continuation; repeatable (e.g. --clamp 29244:300). " + "Find feature ids with `encode`.", + ) + pg.add_argument("--n-tokens", type=int, default=120) + pg.add_argument("--temperature", type=float, default=1.0) + pg.add_argument("--top-k", type=int, default=0) + pg.add_argument("--compare-baseline", action="store_true", help="also generate unsteered, for comparison") + args = ap.parse_args() if args.cmd == "serve": @@ -141,6 +171,27 @@ def main(): df.to_parquet(args.out, index=False) print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}") + elif args.cmd == "generate": + out = eng.generate( + prompt=args.prompt, + organism=args.organism, + features=_parse_clamps(args.clamp), + n_tokens=args.n_tokens, + temperature=args.temperature, + top_k=args.top_k, + compare_baseline=args.compare_baseline, + ) + result = { + "prompt": out["prompt"], + "organism": out["organism"], + "steered": out["steered"], + "features": out["features"], + "sequence": out["generation"]["sequence"], + } + if out.get("baseline"): + result["baseline_sequence"] = out["baseline"]["sequence"] + print(json.dumps(result, indent=2)) + if __name__ == "__main__": main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py new file mode 100644 index 0000000000..e3381af409 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU test for the generate CLI's --clamp parsing (no model).""" + +from evo2_sae.cli import _parse_clamps + + +def test_parse_clamps_id_and_strength(): + assert _parse_clamps(["29244:300", "88:1.5"]) == [ + {"feature_id": 29244, "strength": 300.0}, + {"feature_id": 88, "strength": 1.5}, + ] + + +def test_parse_clamps_default_strength(): + assert _parse_clamps(["29244"]) == [{"feature_id": 29244, "strength": 1.0}] + + +def test_parse_clamps_empty(): + assert _parse_clamps([]) == [] From b37d33431b45b83f72874f2d231b2c85d08e6d22 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 22:36:30 +0000 Subject: [PATCH 14/20] evo2 serve: consolidate steer.py onto Evo2SAE.generate; drop fasta/cli tests - steer.py: the dose-response/selectivity harness now calls eng.generate(features=...) instead of re-deriving _ensure_engine + unwrap_model + steer() ctx + INF.generate + _unwrap_result. ~25 fewer lines and the harness now measures the exact production decode-only sae.steering path (no separate every-forward variant). - remove test_fasta.py and test_cli.py (low-value tests of a trivial reader / arg parsing). Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/steer.py | 37 ++++++---------- .../recipes/evo2/tests/test_cli.py | 33 --------------- .../recipes/evo2/tests/test_fasta.py | 42 ------------------- 3 files changed, 13 insertions(+), 99 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py index 8d54dee633..d0af6c9615 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py @@ -15,27 +15,22 @@ r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation. -Uses ``sae.steering.clamp_hook`` (the shared delta-clamp) registered on the Evo2 decoder layer -the SAE was trained on. Workflow: encode a sequence to find its active features, then for a -**target** feature sweep the clamp strength (dose-response) and for **control** features apply -the same clamp (selectivity), each time comparing the steered continuation to the baseline. +Reuses ``Evo2SAE.generate`` — the same production path the server/CLI use — which clamps via +the shared decode-only ``sae.steering`` hook. Workflow: encode a sequence to find its active +features, then for a **target** feature sweep the clamp strength (dose-response) and for +**control** features apply the same clamp (selectivity), comparing each steered continuation to +the baseline. So the harness measures exactly the steering the product ships. GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test. python steer.py --evo2-ckpt-dir --sae-checkpoint --layer 26 \ --sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200 - -Note: ``sae.steering.clamp_hook`` clamps on *every* forward (prefill + decode), so it steers -the prompt as well as the continuation. The decode-only ("continuation-only") variant lives in -``evo2_sae.core.Evo2SAE._clamp_hook``; unifying the two onto ``sae.steering`` (with a -``decode_only`` flag) is a planned follow-up. """ from __future__ import annotations import argparse import sys -from contextlib import nullcontext from pathlib import Path @@ -44,8 +39,6 @@ sys.path.insert(0, str(_HERE.parent / "src")) # recipes/evo2/src -> evo2_sae package sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) -from sae.steering import steer # noqa: E402 - def _divergence(a: str, b: str): """Return (first differing index, fraction of differing chars) over the shared prefix length.""" @@ -70,9 +63,7 @@ def main(): p.add_argument("--device", default="cuda") a = p.parse_args() - from bionemo.evo2.run import infer as INF # noqa: E402, I001, RUF100 - from evo2_sae.core import Evo2SAE, clean_dna # noqa: E402, RUF100 - from megatron.core.utils import unwrap_model # noqa: E402, RUF100 + from evo2_sae.core import Evo2SAE # noqa: E402, RUF100 eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() @@ -89,16 +80,14 @@ def main(): controls = [int(c) for c in a.controls.split(",") if c.strip()] strengths = [float(s) for s in a.strengths.split(",")] - # 2. The Evo2 decoder layer the SAE hooks + a clean (tag + DNA) prompt. - comp = eng._ensure_engine() - prompt = (eng.resolve_tag(a.organism, None) or "") + clean_dna(a.sequence) - layer_mod = unwrap_model(comp.model).decoder.layers[a.layer] - + # 2. Steered generation reuses the production path: Evo2SAE.generate clamps the same + # decode-only sae.steering hook the server/CLI use, so the harness measures the real thing. def gen(clamps): - ctx = steer(layer_mod, eng.sae, clamps) if clamps else nullcontext() - with ctx: - out = INF.generate(comp, [prompt], max_new_tokens=a.n_tokens, temperature=0.0, top_k=1) - return clean_dna(INF._unwrap_result(out[0]).generated_text) + feats = [{"feature_id": f, "strength": v} for f, v in clamps.items()] + out = eng.generate( + prompt=a.sequence, organism=a.organism, features=feats, n_tokens=a.n_tokens, temperature=0.0, top_k=1 + ) + return out["generation"]["sequence"] base = gen({}) print(f"\nbaseline: {base[:60]}") diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py deleted file mode 100644 index e3381af409..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""CPU test for the generate CLI's --clamp parsing (no model).""" - -from evo2_sae.cli import _parse_clamps - - -def test_parse_clamps_id_and_strength(): - assert _parse_clamps(["29244:300", "88:1.5"]) == [ - {"feature_id": 29244, "strength": 300.0}, - {"feature_id": 88, "strength": 1.5}, - ] - - -def test_parse_clamps_default_strength(): - assert _parse_clamps(["29244"]) == [{"feature_id": 29244, "strength": 1.0}] - - -def test_parse_clamps_empty(): - assert _parse_clamps([]) == [] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py deleted file mode 100644 index fb72cbe130..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_fasta.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""CPU unit tests for the shared FASTA reader (no torch / no GPU).""" - -import gzip - -from evo2_sae.fasta import read_fasta - - -def test_basic_multiline_and_header_token(tmp_path): - """Header keeps only its first token; sequence lines are concatenated.""" - fa = tmp_path / "x.fa" - fa.write_text(">chr1 some description\nACGT\nACGT\n>chr2\nTTTT\n") - assert list(read_fasta(fa)) == [("chr1", "ACGTACGT"), ("chr2", "TTTT")] - - -def test_tokenless_header_gets_generated_id(tmp_path): - """A bare ``>`` / ``"> "`` header must not IndexError — it gets a ``seq_`` id.""" - fa = tmp_path / "x.fa" - fa.write_text(">good\nAAAA\n> \nCCCC\n>\nGGGG\n") - assert list(read_fasta(fa)) == [("good", "AAAA"), ("seq_1", "CCCC"), ("seq_2", "GGGG")] - - -def test_gzip_transparent(tmp_path): - """A ``.gz`` path is decompressed transparently.""" - fa = tmp_path / "x.fa.gz" - with gzip.open(fa, "wt") as f: - f.write(">a\nACGT\n") - assert list(read_fasta(fa)) == [("a", "ACGT")] From d1f888d285b58a378c222a44e2c95c461cd357b6 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 03:33:22 +0000 Subject: [PATCH 15/20] evo2 serve: move steer.py to its own steering-analysis PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #1622 keeps the steering *capability* (sae.steering clamp + core.generate + /generate + CLI); the dose-response/selectivity *harness* moves to a dedicated steering-analysis PR stacked on this one, alongside tested effect-metrics. The dashboard (on #1622) is unaffected — it uses /generate, not steer.py. Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/steer.py | 110 ------------------ 1 file changed, 110 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py deleted file mode 100644 index d0af6c9615..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation. - -Reuses ``Evo2SAE.generate`` — the same production path the server/CLI use — which clamps via -the shared decode-only ``sae.steering`` hook. Workflow: encode a sequence to find its active -features, then for a **target** feature sweep the clamp strength (dose-response) and for -**control** features apply the same clamp (selectivity), comparing each steered continuation to -the baseline. So the harness measures exactly the steering the product ships. - -GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test. - - python steer.py --evo2-ckpt-dir --sae-checkpoint --layer 26 \ - --sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200 -""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - - -_HERE = Path(__file__).resolve().parent -sys.path.insert(0, str(_HERE)) -sys.path.insert(0, str(_HERE.parent / "src")) # recipes/evo2/src -> evo2_sae package -sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) - - -def _divergence(a: str, b: str): - """Return (first differing index, fraction of differing chars) over the shared prefix length.""" - n = min(len(a), len(b)) - first = next((i for i in range(n) if a[i] != b[i]), n) - diff = sum(1 for i in range(n) if a[i] != b[i]) / max(1, n) - return first, diff - - -def main(): - """Encode a sequence, then steer a target feature (dose-response) + control features (selectivity).""" - p = argparse.ArgumentParser(description="Evo2 SAE steering harness (clamp -> continuation effect).") - p.add_argument("--evo2-ckpt-dir", required=True) - p.add_argument("--sae-checkpoint", required=True) - p.add_argument("--layer", type=int, required=True) - p.add_argument("--sequence", required=True) - p.add_argument("--organism", default="None (raw DNA)") - p.add_argument("--feature", type=int, default=None, help="Target feature id (default: top labeled feature).") - p.add_argument("--controls", default="", help="Comma-separated control feature ids (selectivity).") - p.add_argument("--strengths", default="0,50,100,200", help="Comma-separated clamp strengths to sweep.") - p.add_argument("--n-tokens", type=int, default=60) - p.add_argument("--device", default="cuda") - a = p.parse_args() - - from evo2_sae.core import Evo2SAE # noqa: E402, RUF100 - - eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() - - # 1. Encode -> the sequence's most-active features (pick a target if not given). - codes = eng.encode(a.sequence) - vals, ids = codes.max(0).values.topk(10) - print(f"top features on {a.sequence[:24]}...:") - target = a.feature - for v, i in zip(vals.tolist(), ids.tolist()): - lab = eng.labels.get(int(i)) - print(f" feat {int(i):6d} {str(lab):18s} max_act {v:7.2f}") - if target is None and lab: - target = int(i) - controls = [int(c) for c in a.controls.split(",") if c.strip()] - strengths = [float(s) for s in a.strengths.split(",")] - - # 2. Steered generation reuses the production path: Evo2SAE.generate clamps the same - # decode-only sae.steering hook the server/CLI use, so the harness measures the real thing. - def gen(clamps): - feats = [{"feature_id": f, "strength": v} for f, v in clamps.items()] - out = eng.generate( - prompt=a.sequence, organism=a.organism, features=feats, n_tokens=a.n_tokens, temperature=0.0, top_k=1 - ) - return out["generation"]["sequence"] - - base = gen({}) - print(f"\nbaseline: {base[:60]}") - print(f"\n=== dose-response: feature {target} ({eng.labels.get(target)}) ===") - for s in strengths: - steered = gen({target: s}) - first, diff = _divergence(base, steered) - print(f" strength {s:7.1f}: diverges@{first:3d} {diff:6.1%} changed {steered[:44]}") - - if controls: - s = strengths[-1] - print(f"\n=== selectivity: control features clamped to {s} ===") - for c in controls: - steered = gen({c: s}) - first, diff = _divergence(base, steered) - print(f" control {c:6d} ({str(eng.labels.get(c)):16s}): diverges@{first:3d} {diff:6.1%} changed") - - -if __name__ == "__main__": - main() From 1489955ffa213f88547c42caf67bf0fa37f4a437 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 03:46:50 +0000 Subject: [PATCH 16/20] evo2 serve: split server/CLI out -> inference engine only This PR is now the importable engine: core.py (Evo2SAE) + sae.steering clamp + fasta + tests. The FastAPI server, the serve/encode/batch/generate CLI, and launch_inference.sh move to a server PR stacked here. Engine is the shared dep (steer.py, eval, dashboard all build on core); the server is only for the dashboard. Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 37 ---- .../recipes/evo2/src/evo2_sae/cli.py | 197 ------------------ .../recipes/evo2/src/evo2_sae/server.py | 178 ---------------- .../recipes/evo2/tests/test_server.py | 106 ---------- 4 files changed, 518 deletions(-) delete mode 100755 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh deleted file mode 100755 index 26768a4c46..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# Launch the Evo2 SAE inference engine. One engine, four modes: -# -# ./launch_inference.sh serve # live HTTP server on :8001 (viz backend) -# ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features -# ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet -# ./launch_inference.sh generate --prompt ATGC... --clamp 29244:300 # steer + generate DNA -# -# Steering loop: `encode` a sequence to find an active feature id, then -# `generate --clamp ID:STRENGTH` (strength ~2-3x the feature's max_activation; repeat --clamp). -# -# Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults): -# FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES. -# -# Requires the evo2_megatron recipe venv (provides bionemo.evo2 + megatron). -set -euo pipefail - -HERE="$(cd "$(dirname "$0")" && pwd)" -RECIPE_DIR="$(cd "$HERE/.." && pwd)" # recipes/evo2 — so the evo2_sae package imports - -# Required (no hardcoded defaults — supply your own paths via env): -VENV="${VENV:?Set VENV to the evo2_megatron recipe .venv (provides bionemo.evo2 + megatron)}" -export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:?Set EVO2_CKPT_DIR to an Evo2 MBridge checkpoint directory}" -export SAE_CKPT_PATH="${SAE_CKPT_PATH:?Set SAE_CKPT_PATH to a trained SAE checkpoint (.pt)}" -# Optional: feature-label parquet (empty = features are unlabeled). Layer defaults to 26. -export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-}" -export EMBEDDING_LAYER="${EMBEDDING_LAYER:-26}" - -if [[ ! -x "$VENV/bin/python" ]]; then - echo "ERROR: evo2_megatron venv not found at $VENV (build it with the recipe's .ci_build.sh)" >&2 - exit 1 -fi - -source "$VENV/bin/activate" -cd "$RECIPE_DIR" -export PYTHONPATH="$RECIPE_DIR/src${PYTHONPATH:+:$PYTHONPATH}" -exec python -m evo2_sae.cli "$@" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py deleted file mode 100644 index 98185dd51a..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py +++ /dev/null @@ -1,197 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Evo2 SAE inference CLI — one engine, four modes. - - serve : start the FastAPI server (one sequence at a time, interactive) - encode : annotate ONE sequence -> top features (stdout JSON) - batch : run a FASTA of MANY sequences -> parquet of per-sequence top features - generate: generate DNA, optionally steering SAE features (stdout JSON) - -They all build the same `Evo2SAE` engine; config comes from flags or env -(EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER). -""" - -from __future__ import annotations - -import argparse -import json -import os - - -def _add_common(p: argparse.ArgumentParser) -> None: - """Register the shared inference arguments (checkpoints, layer, device) on a parser. - - Defaults come from env vars (``EVO2_CKPT_DIR``, ``SAE_CKPT_PATH``, ``FEATURE_ANNOTATIONS``, - ``EMBEDDING_LAYER``, ``DEVICE``, ``MAX_SEQ_LEN``); pass the flags to override. No hardcoded - paths — the checkpoints must be supplied via flag or env. - - Args: - p: The argparse parser (or subparser) to add the shared arguments to. - - Returns: - None. Mutates ``p`` in place. - """ - p.add_argument("--evo2-ckpt-dir", default=os.environ.get("EVO2_CKPT_DIR")) - p.add_argument("--sae-ckpt-path", default=os.environ.get("SAE_CKPT_PATH")) - p.add_argument("--feature-annotations", default=os.environ.get("FEATURE_ANNOTATIONS")) - p.add_argument("--layer", type=int, default=int(os.environ.get("EMBEDDING_LAYER", "26"))) - p.add_argument("--device", default=os.environ.get("DEVICE", "cuda")) - p.add_argument("--max-seq-len", type=int, default=int(os.environ.get("MAX_SEQ_LEN", "8192"))) - - -def _engine(args): - """Construct an Evo2SAE engine from parsed CLI args. - - Args: - args: Parsed argparse namespace with ``evo2_ckpt_dir``, ``sae_ckpt_path``, ``layer``, - ``device``, ``max_seq_len``, ``feature_annotations``. - - Returns: - An (unloaded) ``Evo2SAE`` instance — call ``.load()`` before use. - """ - from .core import Evo2SAE - - return Evo2SAE( - evo2_ckpt_dir=args.evo2_ckpt_dir, - sae_ckpt_path=args.sae_ckpt_path, - layer=args.layer, - device=args.device, - max_seq_len=args.max_seq_len, - feature_annotations=args.feature_annotations, - ) - - -def _parse_clamps(clamps: list[str]) -> list[dict]: - """Parse repeated ``--clamp FEATURE_ID[:STRENGTH]`` args into [{feature_id, strength}]. - - Strength defaults to 1.0 if omitted (e.g. ``--clamp 29244:300`` or ``--clamp 29244``). - """ - specs = [] - for c in clamps: - fid, sep, strength = c.partition(":") - specs.append({"feature_id": int(fid), "strength": float(strength) if (sep and strength) else 1.0}) - return specs - - -def main(): - """Parse args and dispatch to the serve / encode / batch subcommand.""" - ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch | generate)") - sub = ap.add_subparsers(dest="cmd", required=True) - - ps = sub.add_parser("serve", help="start the FastAPI inference server") - _add_common(ps) - ps.add_argument("--host", default="0.0.0.0") - ps.add_argument("--port", type=int, default=int(os.environ.get("PORT", "8001"))) - - pe = sub.add_parser("encode", help="annotate ONE sequence -> top features (JSON)") - _add_common(pe) - pe.add_argument("--sequence", required=True) - pe.add_argument("--organism", default="None (raw DNA)") - pe.add_argument("--top-k", type=int, default=8) - - pb = sub.add_parser("batch", help="MANY sequences (FASTA) -> parquet of per-sequence top features") - _add_common(pb) - pb.add_argument("--fasta", required=True) - pb.add_argument("--out", required=True) - pb.add_argument("--top-k", type=int, default=16) - pb.add_argument("--batch-size", type=int, default=8) - - pg = sub.add_parser("generate", help="generate DNA, optionally steering SAE features") - _add_common(pg) - pg.add_argument("--prompt", default="", help="DNA to seed; steering applies to the continuation") - pg.add_argument("--organism", default="None (raw DNA)") - pg.add_argument( - "--clamp", - action="append", - default=[], - metavar="FEATURE_ID[:STRENGTH]", - help="clamp a feature on the continuation; repeatable (e.g. --clamp 29244:300). " - "Find feature ids with `encode`.", - ) - pg.add_argument("--n-tokens", type=int, default=120) - pg.add_argument("--temperature", type=float, default=1.0) - pg.add_argument("--top-k", type=int, default=0) - pg.add_argument("--compare-baseline", action="store_true", help="also generate unsteered, for comparison") - - args = ap.parse_args() - - if args.cmd == "serve": - import uvicorn - - from .server import build_app - - uvicorn.run(build_app(_engine(args)), host=args.host, port=args.port, log_level="info") - return - - from .core import clean_dna - - eng = _engine(args).load() - - if args.cmd == "encode": - tag = eng.resolve_tag(args.organism, None) or "" - dna = clean_dna(args.sequence) - codes = eng.encode(tag + dna) - tag_len = len(tag) if codes.shape[0] >= len(tag) else 0 - feats = eng.top_features(codes, tag_len=tag_len, k=args.top_k) - print( - json.dumps( - {"sequence": dna, "organism": args.organism, "bases": len(dna), "top_features": feats}, indent=2 - ) - ) - - elif args.cmd == "batch": - import pandas as pd - - from .fasta import read_fasta - - ids, seqs = [], [] - for sid, seq in read_fasta(args.fasta): - ids.append(sid) - seqs.append(seq) - print(f"[batch] {len(seqs)} sequences from {args.fasta}; encoding (batch_size={args.batch_size})…") - codes_list = eng.encode_batch(seqs, batch_size=args.batch_size) - rows = [] - for sid, codes in zip(ids, codes_list): - for rank, ft in enumerate(eng.top_features(codes, k=args.top_k)): - rows.append({"sequence_id": sid, "bp": int(codes.shape[0]), "rank": rank, **ft}) - df = pd.DataFrame(rows) - df.to_parquet(args.out, index=False) - print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}") - - elif args.cmd == "generate": - out = eng.generate( - prompt=args.prompt, - organism=args.organism, - features=_parse_clamps(args.clamp), - n_tokens=args.n_tokens, - temperature=args.temperature, - top_k=args.top_k, - compare_baseline=args.compare_baseline, - ) - result = { - "prompt": out["prompt"], - "organism": out["organism"], - "steered": out["steered"], - "features": out["features"], - "sequence": out["generation"]["sequence"], - } - if out.get("baseline"): - result["baseline_sequence"] = out["baseline"]["sequence"] - print(json.dumps(result, indent=2)) - - -if __name__ == "__main__": - main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py deleted file mode 100644 index bb7e3b391d..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""FastAPI server over the Evo2SAE engine — the live backend the viz talks to. - -Endpoints: /health, /features, /annotate (per-base activations for a pasted -sequence), /generate (autoregressive generation + optional SAE-feature clamp). -This is a thin layer; all model work lives in `core.Evo2SAE`. -""" - -from __future__ import annotations - -import logging -import os -from contextlib import asynccontextmanager -from typing import Optional - -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel - -from .core import Evo2SAE, clean_dna - - -logger = logging.getLogger("evo2_sae_infer.server") - - -class AnnotateRequest(BaseModel): - """Request body for /annotate (top-k feature scan or an explicit feature pick).""" - - sequence: str - organism: str = "None (raw DNA)" - tag: Optional[str] = None - mode: str = "topk" # "topk" | "pick" - k: int = 8 - feature_ids: Optional[list[int]] = None - feature_id: Optional[int] = None - - -class FeatureClamp(BaseModel): - """A single SAE-feature steering clamp (feature id + target strength).""" - - feature_id: int - strength: float = 1.0 - - -class GenerateRequest(BaseModel): - """Request body for /generate (autoregressive generation + optional SAE-feature clamps).""" - - prompt: str = "" - organism: str = "None (raw DNA)" - tag: Optional[str] = None - features: list[FeatureClamp] = [] - n_tokens: int = 120 - temperature: float = 1.0 - top_k: int = 0 - compare_baseline: bool = False - - -def build_app(engine: Evo2SAE) -> FastAPI: - """Build the FastAPI app; the engine is loaded once in the lifespan handler.""" - - @asynccontextmanager - async def lifespan(app: FastAPI): - try: - engine.load() - logger.info("engine ready") - except Exception: - logger.exception("engine startup failed — /health stays not-ready") - yield - - app = FastAPI(title="Evo2 SAE inference", lifespan=lifespan) - allowed_origins = os.getenv("CORS_ORIGINS", "*").split(",") # comma-separated; "*" by default (local backend) - app.add_middleware(CORSMiddleware, allow_origins=allowed_origins, allow_methods=["*"], allow_headers=["*"]) - - @app.get("/health") - def health(): - return { - "ready": bool(engine.ready), - "layer": engine.layer, - "n_features": engine.n_features, - "n_labels": len(engine.labels), - "sae_path": engine.sae_ckpt_path, - "organisms": list(engine.organism_tags.keys()), - "organism_tags": engine.organism_tags, - "device": engine.device, - } - - @app.get("/features") - def features(): - if not engine.ready: - raise HTTPException(503, "Backend not ready") - rows = [ - {"id": int(f), "label": lab, "natural_peak": engine.peaks.get(int(f))} for f, lab in engine.labels.items() - ] - rows.sort(key=lambda r: r["id"]) - return rows - - @app.post("/annotate") - def annotate(req: AnnotateRequest): - if not engine.ready: - raise HTTPException(503, "Backend not ready") - dna = clean_dna(req.sequence) - if not dna: - raise HTTPException(400, "No valid nucleotides in sequence") - tag = engine.resolve_tag(req.organism, req.tag) - if tag is None: - raise HTTPException(400, f"Unknown organism '{req.organism}' and no custom tag") - full = tag + dna - tag_len = len(tag) - codes = engine.encode(full) # [S, n_features], lock held inside - if codes.shape[0] < tag_len: - tag_len = 0 - if req.mode not in ("pick", "topk"): - raise HTTPException(400, f"Invalid mode {req.mode!r}: must be 'pick' or 'topk'") - if req.mode == "pick": - ids = req.feature_ids or ([req.feature_id] if req.feature_id is not None else []) - if not ids: - raise HTTPException(400, "mode='pick' requires feature_ids") - chosen = [int(i) for i in ids] - else: - k = max(1, min(int(req.k), 64)) - chosen = [ft["feature_id"] for ft in engine.top_features(codes, tag_len=tag_len, k=k)] - feats = [] - for fid in chosen: - col = codes[:, fid] - feats.append( - { - "feature_id": fid, - "label": engine.labels.get(fid), - "max_activation": float(col[tag_len:].max().item()) - if codes.shape[0] > tag_len - else float(col.max().item()), - "activations": [round(float(v), 4) for v in col.tolist()], - } - ) - return { - "sequence": dna, - "organism": req.organism, - "tag": tag, - "tag_len": tag_len, - "bases": list(full), - "n_tokens": codes.shape[0], - "layer": engine.layer, - "features": feats, - } - - @app.post("/generate") - def generate(req: GenerateRequest): - if not engine.ready: - raise HTTPException(503, "Backend not ready") - try: - return engine.generate( - prompt=req.prompt, - organism=req.organism, - tag=req.tag, - features=[f.model_dump() for f in req.features], - n_tokens=req.n_tokens, - temperature=req.temperature, - top_k=req.top_k, - compare_baseline=req.compare_baseline, - ) - except ValueError as e: - raise HTTPException(400, str(e)) - - return app diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py deleted file mode 100644 index 82b5b0726b..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Server contract tests — the API the feature-explorer viz consumes. - -A mocked engine (no model, CPU-only) drives the FastAPI app so these run in CI and lock the -response shapes + error codes the dashboard depends on: /health, /features, /annotate (per-base -activations), /generate. Real model inference is covered by test_steering.py. -""" - -import pytest -import torch -from evo2_sae.server import build_app -from fastapi.testclient import TestClient - - -class FakeEngine: - """Minimal stand-in for Evo2SAE exposing only what the server endpoints touch.""" - - def __init__(self): - self.ready = True - self.layer = 19 - self.n_features = 4 - self.labels = {0: "feat0", 1: "feat1"} - self.peaks = {0: 0.5} - self.organism_tags = {"None (raw DNA)": "", "Human": "|tag|"} - self.device = "cpu" - self.sae_ckpt_path = "fake.pt" - - def load(self): - self.ready = True - - def resolve_tag(self, organism, tag): - return tag if tag is not None else self.organism_tags.get(organism) - - def encode(self, full): - codes = torch.zeros(len(full), self.n_features) - codes[:, 0] = 1.0 # feature 0 fires everywhere - return codes - - def top_features(self, codes, tag_len=0, k=8): - return [{"feature_id": 0, "label": self.labels.get(0), "max_activation": 1.0}] - - def generate(self, **kw): - if not kw.get("prompt") and kw.get("organism") == "None (raw DNA)" and not kw.get("tag"): - raise ValueError("need a seed") - return { - "generation": {"sequence": "ACGT", "activations": {0: [1.0, 1.0, 1.0, 1.0]}}, - "baseline": None, - "features": [], - "steered": False, - } - - -@pytest.fixture -def client(): - with TestClient(build_app(FakeEngine())) as c: - yield c - - -def test_health(client): - b = client.get("/health").json() - assert b["ready"] is True and b["layer"] == 19 - assert "None (raw DNA)" in b["organisms"] - - -def test_features(client): - rows = client.get("/features").json() - assert {"id", "label", "natural_peak"} <= set(rows[0]) - - -def test_annotate_returns_per_base_activations(client): - b = client.post("/annotate", json={"sequence": "ACGTACGT", "organism": "None (raw DNA)"}).json() - assert {"sequence", "features", "bases", "tag_len", "layer", "n_tokens"} <= set(b) - assert b["features"][0]["activations"] # the per-base track the viz plots - - -def test_annotate_rejects_non_dna(client): - assert client.post("/annotate", json={"sequence": "ZZZZ"}).status_code == 400 - - -def test_generate_returns_sequence(client): - b = client.post("/generate", json={"prompt": "ACGT", "organism": "None (raw DNA)"}).json() - assert b["generation"]["sequence"] - - -def test_endpoints_503_until_ready(): - eng = FakeEngine() - eng.ready = False - eng.load = lambda: None # startup leaves it not-ready - with TestClient(build_app(eng)) as c: - assert c.get("/features").status_code == 503 - assert c.post("/annotate", json={"sequence": "ACGT"}).status_code == 503 - assert c.post("/generate", json={"prompt": "ACGT", "organism": "None (raw DNA)"}).status_code == 503 From f31028973dcd3def09d26ff02ea32a5609518def Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 03:53:10 +0000 Subject: [PATCH 17/20] =?UTF-8?q?evo2=20serve:=20steering=20safety=20?= =?UTF-8?q?=E2=80=94=20fix=20gen=20double-init=20+=20guard=20bad=20clamps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _ensure_engine: destroy_num_microbatches_calculator() before setup_inference_engine (load_model_to_layer already initialized it; the re-init asserted -> /generate 500). - generate(): reject feature ids outside [0, n_features) (server -> 400) and cap clamp magnitude to MAX_CLAMP_STRENGTH. Both previously triggered CUDA device-side asserts (out-of-range index; NaN logits under sampling) that corrupt the context and wedge the server until restart. Verified on 7B/L26: bad id -> 400, huge strength -> capped 200, no wedge. Co-Authored-By: Claude Opus 4.8 --- .../recipes/evo2/src/evo2_sae/core.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 444748b591..00035d9b7f 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -60,6 +60,11 @@ _VALID_BASES = re.compile(r"[^ACGTN]") +# Steering clamp targets are absolute SAE-code values; this SAE's features peak ~100-300. +# Cap the magnitude so an extreme target can't blow the logits to NaN (which device-asserts +# and wedges the process). Generous headroom for amplification, bounded against runaway. +MAX_CLAMP_STRENGTH = float(os.environ.get("MAX_CLAMP_STRENGTH", "300")) + # Phylogenetic-tag prefixes per organism (Evo2 was trained with lineage tags). DEFAULT_ORGANISM_TAGS = { "None (raw DNA)": "", @@ -137,6 +142,16 @@ def _ensure_engine(self): if self.gen_components is None: from bionemo.evo2.run import infer as INF + # load_model_to_layer (the encode model) already initialized the global + # num-microbatches calculator; setup_inference_engine re-inits it and asserts + # unless we tear the singleton down first. + try: + from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator + + destroy_num_microbatches_calculator() + except Exception: + pass + self.gen_components = INF.setup_inference_engine( Path(self.evo2_ckpt_dir), max_seq_length=self.max_seq_len, cuda_graph_impl="none" ) @@ -206,12 +221,7 @@ def resolve_tag(self, organism: str, tag: Optional[str]) -> Optional[str]: @torch.no_grad() def encode(self, dna: str) -> torch.Tensor: """ONE sequence -> SAE codes [seq_len, n_features] on CPU. No phylo tag.""" - ids = self.tokenize(dna) - if not ids: - return torch.empty(0, self.n_features) - with self._lock: - hidden = self._forward_hidden([ids])[0] # [S, H] - return self.sae.encode(hidden.to(self.device)).detach().cpu() + return self.encode_batch([dna])[0] @torch.no_grad() def encode_batch(self, seqs: list[str], batch_size: int = 8) -> list[torch.Tensor]: @@ -325,14 +335,24 @@ def generate( if not full_prompt: raise ValueError("Provide a prompt or pick an organism (need >=1 token to seed)") n_tokens = max(1, min(int(n_tokens), 400)) - fids = [int(f["feature_id"]) for f in features] + # Guard rails — a bad clamp can wedge the whole process (a CUDA device-side assert + # corrupts the context, so every later request 500s until restart): + # * a feature id outside [0, n_features) indexes off the SAE codes -> assert; + # * an extreme target overflows the logits -> NaN under sampling -> assert. + # Reject out-of-range ids (the server maps ValueError -> 400) and cap the magnitude. + bad = sorted({int(f["feature_id"]) for f in features if not (0 <= int(f["feature_id"]) < self.n_features)}) + if bad: + raise ValueError(f"feature_id(s) {bad} out of range [0, {self.n_features})") + clamps = { + int(f["feature_id"]): max(-MAX_CLAMP_STRENGTH, min(MAX_CLAMP_STRENGTH, float(f.get("strength", 1.0)))) + for f in features + } + fids = list(clamps) with self._lock: comp = self._ensure_engine() hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] from sae.steering import clamp_hook - - clamps = {int(f["feature_id"]): float(f.get("strength", 1.0)) for f in features} feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] def _run(steer: bool) -> str: From 65a243962f37a4bbb7deaab1dc0ddbbcb1667ec6 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 05:36:42 +0000 Subject: [PATCH 18/20] style(core): ruff-format blank line after in-function import (fix pre-commit CI) Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 00035d9b7f..6cc72cc80b 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -353,6 +353,7 @@ def generate( comp = self._ensure_engine() hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] from sae.steering import clamp_hook + feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] def _run(steer: bool) -> str: From a1f1d545ba3331a494f5a881b5c7fe997dd19c1e Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 06:25:59 +0000 Subject: [PATCH 19/20] evo2 serve: fix temperature=0 NaN + factor steering guards into a tested helper temperature<=0 made the recipe sampler divide logits by temperature -> NaN -> multinomial device-side assert (wedged the server). The steering tab defaults to temp 0, so it hit this immediately. Coerce temp<=0 to greedy top-1 (the deterministic path that skips the division). Factored the input guards (id-range reject, clamp-magnitude cap, temp-0 coercion) into a pure _sanitize_steering() and added 4 CPU tests for them (no GPU): out-of-range id -> ValueError, extreme strength -> capped, temp 0 -> top_k 1, valid -> passthrough. Co-Authored-By: Claude Opus 4.8 --- .../recipes/evo2/src/evo2_sae/core.py | 43 +++++++++++++------ .../recipes/evo2/tests/test_steering.py | 40 +++++++++++++++++ 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 6cc72cc80b..f653cc383a 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -79,6 +79,33 @@ def clean_dna(seq: str) -> str: return _VALID_BASES.sub("", (seq or "").upper()) +def _sanitize_steering(features, n_features, temperature, top_k): + """Validate/normalize steering inputs (pure, no GPU); raise on bad input. + + Each guard prevents a CUDA device-side assert that would corrupt the context and wedge + the server until restart: + + * feature id outside [0, n_features) indexes off the SAE codes -> ValueError (server -> 400); + * |strength| beyond MAX_CLAMP_STRENGTH blows the logits to inf/NaN -> capped; + * temperature <= 0 makes the recipe's sampler divide logits by temperature (NaN under + multinomial) -> coerce to greedy top-1, which is deterministic and skips that path. + + Returns ``(clamps: dict[int, float], fids: list[int], temperature: float, top_k: int)``. + """ + bad = sorted({int(f["feature_id"]) for f in features if not (0 <= int(f["feature_id"]) < n_features)}) + if bad: + raise ValueError(f"feature_id(s) {bad} out of range [0, {n_features})") + clamps = { + int(f["feature_id"]): max(-MAX_CLAMP_STRENGTH, min(MAX_CLAMP_STRENGTH, float(f.get("strength", 1.0)))) + for f in features + } + temperature = float(temperature) + top_k = int(top_k) + if temperature <= 0: # greedy — avoid the sampler's logits/temperature division (NaN) + top_k = max(top_k, 1) + return clamps, list(clamps), temperature, top_k + + def _init_single_process_distributed() -> None: """Set the env vars Megatron's distributed init expects for a 1-GPU process.""" os.environ.setdefault("RANK", "0") @@ -335,19 +362,9 @@ def generate( if not full_prompt: raise ValueError("Provide a prompt or pick an organism (need >=1 token to seed)") n_tokens = max(1, min(int(n_tokens), 400)) - # Guard rails — a bad clamp can wedge the whole process (a CUDA device-side assert - # corrupts the context, so every later request 500s until restart): - # * a feature id outside [0, n_features) indexes off the SAE codes -> assert; - # * an extreme target overflows the logits -> NaN under sampling -> assert. - # Reject out-of-range ids (the server maps ValueError -> 400) and cap the magnitude. - bad = sorted({int(f["feature_id"]) for f in features if not (0 <= int(f["feature_id"]) < self.n_features)}) - if bad: - raise ValueError(f"feature_id(s) {bad} out of range [0, {self.n_features})") - clamps = { - int(f["feature_id"]): max(-MAX_CLAMP_STRENGTH, min(MAX_CLAMP_STRENGTH, float(f.get("strength", 1.0)))) - for f in features - } - fids = list(clamps) + # Validate/normalize steering inputs — out-of-range ids, extreme clamps, and temperature + # 0 each trigger CUDA device-side asserts that wedge the server (see _sanitize_steering). + clamps, fids, temperature, top_k = _sanitize_steering(features, self.n_features, temperature, top_k) with self._lock: comp = self._ensure_engine() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py index cebceb0f52..f4368ebbbf 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py @@ -60,6 +60,46 @@ def test_clamp_math(): assert torch.equal(hook(None, None, prefill), prefill) +# --------------------------------------------------- CPU: steering input guards (no model) +# Each guards against a CUDA device-side assert that would corrupt the context and wedge the +# server: out-of-range id (indexes off the SAE codes), extreme clamp (NaN logits), and +# temperature 0 (sampler divides logits by temperature -> NaN under multinomial). +def test_sanitize_rejects_out_of_range_id(): + """feature_id outside [0, n_features) -> ValueError (server maps to 400, not a wedge).""" + from evo2_sae.core import _sanitize_steering + + with pytest.raises(ValueError): + _sanitize_steering([{"feature_id": 70000, "strength": 1.0}], n_features=65536, temperature=1.0, top_k=0) + with pytest.raises(ValueError): + _sanitize_steering([{"feature_id": -1, "strength": 1.0}], n_features=65536, temperature=1.0, top_k=0) + + +def test_sanitize_caps_clamp_magnitude(): + """An extreme target is capped to ±MAX_CLAMP_STRENGTH (prevents NaN logits).""" + from evo2_sae.core import MAX_CLAMP_STRENGTH, _sanitize_steering + + clamps, _, _, _ = _sanitize_steering([{"feature_id": 5, "strength": 99999.0}], 65536, 1.0, 0) + assert clamps[5] == MAX_CLAMP_STRENGTH + clamps, _, _, _ = _sanitize_steering([{"feature_id": 5, "strength": -99999.0}], 65536, 1.0, 0) + assert clamps[5] == -MAX_CLAMP_STRENGTH + + +def test_sanitize_temperature_zero_forces_greedy_topk(): + """temperature<=0 (divide-by-zero in the sampler) is coerced to greedy top-1.""" + from evo2_sae.core import _sanitize_steering + + _, _, temp, top_k = _sanitize_steering([{"feature_id": 5, "strength": 1.0}], 65536, 0.0, 0) + assert top_k == 1 # bumped from 0 so the logits/temperature path is skipped + + +def test_sanitize_passthrough(): + """Valid inputs pass through unchanged — no spurious capping or top_k change.""" + from evo2_sae.core import _sanitize_steering + + clamps, fids, temp, top_k = _sanitize_steering([{"feature_id": 5, "strength": 2.0}], 65536, 0.8, 0) + assert clamps == {5: 2.0} and fids == [5] and temp == 0.8 and top_k == 0 + + # --------------------------------------------------------------------- GPU: real generation _CKPT = os.environ.get("EVO2_CKPT_DIR") _SAE = os.environ.get("SAE_CKPT_PATH") From baf6ddf3f888e1f9fa8bf2a4fc9f64904b37911c Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 06:28:01 +0000 Subject: [PATCH 20/20] evo2 serve: drop unused unpack var in steering guard test (lint) --- .../sparse_autoencoders/recipes/evo2/tests/test_steering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py index f4368ebbbf..51f48485f1 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py @@ -88,7 +88,7 @@ def test_sanitize_temperature_zero_forces_greedy_topk(): """temperature<=0 (divide-by-zero in the sampler) is coerced to greedy top-1.""" from evo2_sae.core import _sanitize_steering - _, _, temp, top_k = _sanitize_steering([{"feature_id": 5, "strength": 1.0}], 65536, 0.0, 0) + _, _, _, top_k = _sanitize_steering([{"feature_id": 5, "strength": 1.0}], 65536, 0.0, 0) assert top_k == 1 # bumped from 0 so the logits/temperature path is skipped