diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index f7b23d8eff..f4a6a32c40 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -13,13 +13,15 @@ dependencies = [ "torch>=2.0", "numpy>=1.20", "pyarrow>=23.0.0", + "fastapi>=0.110", + "uvicorn>=0.29", + "pandas>=1.5", ] -# No package code lives here yet — the recipe is just an entry-point for -# scripts/ that depends on the shared `sae` workspace package. Declare no -# packages so setuptools doesn't try to discover anything. -[tool.setuptools] -packages = [] +# The `evo2_sae` package (src/) holds the live inference engine + server + CLI; +# scripts/ (extract, train) are standalone entry points alongside it. +[tool.setuptools.packages.find] +where = ["src"] [tool.uv.sources] sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py index 797034da7b..05d4333408 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py @@ -24,26 +24,9 @@ """ import argparse -import gzip from pathlib import Path - -def parse_fasta(path: Path): - """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz).""" - opener = gzip.open if path.suffix == ".gz" else open - seq_id, parts = None, [] - with opener(path, "rt") as f: - for line in f: - line = line.rstrip() - if line.startswith(">"): - if seq_id is not None: - yield seq_id, "".join(parts) - seq_id = line[1:].split()[0] - parts = [] - else: - parts.append(line) - if seq_id is not None: - yield seq_id, 
"".join(parts) +from evo2_sae.fasta import read_fasta def main(): @@ -61,7 +44,7 @@ def main(): n_in = n_out = bp_out = 0 args.output.parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as out: - for seq_id, seq in parse_fasta(args.input): + for seq_id, seq in read_fasta(args.input): n_in += 1 for start in range(0, len(seq), args.window): end = min(start + args.window, len(seq)) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py new file mode 100644 index 0000000000..ac7255d908 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 + SAE inference engine — reused by the live server, the batch CLI, and the viz backend.""" + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: # for type checkers / ruff F822 — runtime access goes through __getattr__ below + from .core import DEFAULT_ORGANISM_TAGS, Evo2SAE, clean_dna + + +__all__ = ["DEFAULT_ORGANISM_TAGS", "Evo2SAE", "clean_dna"] + + +def __getattr__(name: str): + """Lazily pull the heavy engine symbols from ``.core`` (importing ``.core`` loads torch). 
+ + Keeps ``import evo2_sae`` (and lightweight submodules like ``evo2_sae.fasta``) cheap so + stdlib-only callers don't drag in torch, while ``from evo2_sae import Evo2SAE`` still works. + """ + if name in __all__: + from . import core + + return getattr(core, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py new file mode 100644 index 0000000000..f653cc383a --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -0,0 +1,407 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 + SAE inference core — one importable engine for live and batch use. + +`Evo2SAE` loads a base Evo2 model and a trained SAE once, then exposes: + + encode(dna) -> codes [S, n_features] # ONE sequence (interactive) + encode_batch(seqs) -> list of codes [S_i, n_features] # MANY sequences (batched on GPU) + feature_tracks(dna, f) -> {feature_id: [per-base activation]} + generate(...) 
-> autoregressive DNA generation with optional additive + SAE-feature clamping on the generated continuation + +It has NO web dependency: the FastAPI server (`server.py`) and the batch CLI +(`cli.py`) are thin wrappers over this class, and the viz backend imports it too. + +The heavy Evo2 machinery is reused from the recipe: model loading via +`predict.load_model_to_layer` and generation via `infer.setup_inference_engine` / +`infer.generate` (run eager, `cuda_graph_impl="none"`, so the residual-stream steering +hook applies). This module only adds the SAE layer: encode, feature labels, and the +decode-only feature-clamp hook. +""" + +from __future__ import annotations + +import logging +import os +import re +import sys +import threading +from pathlib import Path +from typing import Optional + +import torch + + +logger = logging.getLogger("evo2_sae_infer") + +# Disable Inductor CUDA graphs before torch initializes inductor — graph capture +# conflicts with the residual-stream forward hook (which replaces the layer output) +# and with re-feeding a growing sequence each decode step. +os.environ.setdefault("TORCHINDUCTOR_CUDAGRAPHS", "0") + +# Make the local `sae` package importable (sparse_autoencoders/sae/src). +_SAE_SRC = os.environ.get("SAE_SRC", str(Path(__file__).resolve().parents[4] / "sae" / "src")) +if _SAE_SRC not in sys.path: + sys.path.insert(0, _SAE_SRC) + +_VALID_BASES = re.compile(r"[^ACGTN]") + +# Steering clamp targets are absolute SAE-code values; this SAE's features peak ~100-300. +# Cap the magnitude so an extreme target can't blow the logits to NaN (which device-asserts +# and wedges the process). Generous headroom for amplification, bounded against runaway. +MAX_CLAMP_STRENGTH = float(os.environ.get("MAX_CLAMP_STRENGTH", "300")) + +# Phylogenetic-tag prefixes per organism (Evo2 was trained with lineage tags). 
+DEFAULT_ORGANISM_TAGS = { + "None (raw DNA)": "", + "Human": "|d__Eukaryota;p__Chordata;c__Mammalia;o__Primates;f__Hominidae;g__Homo;s__Homo sapiens|", + "E. coli": "|d__Bacteria;p__Pseudomonadota;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Escherichia;s__Escherichia coli|", + "S. cerevisiae": "|d__Eukaryota;p__Ascomycota;c__Saccharomycetes;o__Saccharomycetales;f__Saccharomycetaceae;g__Saccharomyces;s__Saccharomyces cerevisiae|", +} + + +def clean_dna(seq: str) -> str: + """Uppercase and strip everything that isn't a nucleotide.""" + return _VALID_BASES.sub("", (seq or "").upper()) + + +def _sanitize_steering(features, n_features, temperature, top_k): + """Validate/normalize steering inputs (pure, no GPU); raise on bad input. + + Each guard prevents a CUDA device-side assert that would corrupt the context and wedge + the server until restart: + + * feature id outside [0, n_features) indexes off the SAE codes -> ValueError (server -> 400); + * |strength| beyond MAX_CLAMP_STRENGTH blows the logits to inf/NaN -> capped; + * temperature <= 0 makes the recipe's sampler divide logits by temperature (NaN under + multinomial) -> coerce to greedy top-1, which is deterministic and skips that path. + + Returns ``(clamps: dict[int, float], fids: list[int], temperature: float, top_k: int)``. 
+ """ + bad = sorted({int(f["feature_id"]) for f in features if not (0 <= int(f["feature_id"]) < n_features)}) + if bad: + raise ValueError(f"feature_id(s) {bad} out of range [0, {n_features})") + clamps = { + int(f["feature_id"]): max(-MAX_CLAMP_STRENGTH, min(MAX_CLAMP_STRENGTH, float(f.get("strength", 1.0)))) + for f in features + } + temperature = float(temperature) + top_k = int(top_k) + if temperature <= 0: # greedy — avoid the sampler's logits/temperature division (NaN) + top_k = max(top_k, 1) + return clamps, list(clamps), temperature, top_k + + +def _init_single_process_distributed() -> None: + """Set the env vars Megatron's distributed init expects for a 1-GPU process.""" + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29577") + + +class Evo2SAE: + """Persistent Evo2 + SAE inference engine (single-sequence and batched).""" + + def __init__( + self, + evo2_ckpt_dir: str, + sae_ckpt_path: str, + layer: int, + device: str = "cuda", + max_seq_len: int = 8192, + feature_annotations: Optional[str] = None, + organism_tags: Optional[dict] = None, + ): + """Record config; call .load() to actually load the model + SAE onto the GPU.""" + self.evo2_ckpt_dir = evo2_ckpt_dir + self.sae_ckpt_path = sae_ckpt_path + self.layer = int(layer) + self.device = device + self.max_seq_len = int(max_seq_len) + self.feature_annotations = feature_annotations + self.organism_tags = dict(organism_tags) if organism_tags else dict(DEFAULT_ORGANISM_TAGS) + + self.model = None # truncated model (post_process=False) for activations + self.gen_components = None # recipe inference engine (full model) for generation, lazy + self.tokenizer = None + self.sae = None + self.n_features = None + self.labels: dict[int, str] = {} + self.peaks: dict[int, float] = {} + self._lock = threading.Lock() # serialize GPU access (Megatron isn't 
thread-safe) + self.ready = False + + # ------------------------------------------------------------------ loading + def load(self) -> "Evo2SAE": + """Load the truncated Evo2 model + SAE + feature labels (one-time, ~1 min).""" + from bionemo.evo2.run import predict as P + + _init_single_process_distributed() + self.model, self.tokenizer = P.load_model_to_layer(self.evo2_ckpt_dir, self.layer, full=False) + self.sae, self.n_features = self._load_sae() + self.labels, self.peaks = self._load_feature_meta() + self.ready = True + logger.info("Evo2SAE ready: layer=%d n_features=%d n_labels=%d", self.layer, self.n_features, len(self.labels)) + return self + + def _ensure_engine(self): + """Lazily build the recipe's inference engine (eager/hookable) for generation. + + cuda_graph_impl="none" keeps decode eager so the residual-stream steering hook + takes effect (a CUDA-graph-captured decode would replay frozen ops and ignore it). + """ + if self.gen_components is None: + from bionemo.evo2.run import infer as INF + + # load_model_to_layer (the encode model) already initialized the global + # num-microbatches calculator; setup_inference_engine re-inits it and asserts + # unless we tear the singleton down first. 
+ try: + from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator + + destroy_num_microbatches_calculator() + except Exception: + pass + + self.gen_components = INF.setup_inference_engine( + Path(self.evo2_ckpt_dir), max_seq_length=self.max_seq_len, cuda_graph_impl="none" + ) + return self.gen_components + + def _load_sae(self): + ckpt = torch.load(self.sae_ckpt_path, map_location="cpu", weights_only=False) + cfg = dict(ckpt["model_config"]) + state = ckpt["model_state_dict"] + if any(k.startswith("module.") for k in state): + state = {k.removeprefix("module."): v for k, v in state.items()} + train_cfg = ckpt.get("config") + arch = ( + str(train_cfg.get("architecture", train_cfg.get("arch", ""))).lower() + if isinstance(train_cfg, dict) + else "" + ) + hint = (arch + " " + self.sae_ckpt_path).lower() + from sae.architectures import ReLUSAE, TopKSAE + + cls = ReLUSAE if "relu" in hint else TopKSAE + sae = cls(**cfg) + sae.load_state_dict(state) + sae.eval().to(self.device) + logger.info("SAE loaded: %s input_dim=%d n_features=%d", cls.__name__, cfg["input_dim"], cfg["hidden_dim"]) + return sae, int(cfg["hidden_dim"]) + + def _load_feature_meta(self): + """feature_id -> (label, natural peak) from the annotation parquet/tsv/csv/json.""" + labels: dict[int, str] = {} + peaks: dict[int, float] = {} + if not self.feature_annotations: + return labels, peaks + path = Path(self.feature_annotations) + if not path.exists(): + logger.warning("Feature annotations %s not found — features unlabeled", path) + return labels, peaks + if path.suffix.lower() == ".parquet": + import pyarrow.parquet as pq + + tbl = pq.read_table(path).to_pydict() + ids = tbl.get("feature_id", []) + names = tbl.get("label", tbl.get("annotation", [None] * len(ids))) + pk = tbl.get("max_activation", [None] * len(ids)) + for i, n, p in zip(ids, names, pk): + if n is not None: + labels[int(i)] = str(n) + if p is not None: + peaks[int(i)] = float(p) + logger.info("Loaded %d 
@torch.no_grad()
def encode_batch(self, seqs: list[str], batch_size: int = 8) -> list[torch.Tensor]:
    """MANY sequences -> list of SAE codes [S_i, n_features], in input order.

    Each sequence is tokenized with ``self.tokenize``; the token lists are then pushed
    through ``self._forward_hidden`` in micro-batches of ``batch_size`` while holding
    ``self._lock``. Every non-empty hidden-state tensor is SAE-encoded on ``self.device``
    and moved to CPU; a sequence with no tokens yields an empty ``[0, n_features]`` tensor.

    Args:
        seqs: Input strings, one result per entry.
        batch_size: Number of sequences per forward pass; must be >= 1.

    Returns:
        One CPU tensor per input sequence, in the same order as ``seqs``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (a negative step would otherwise
            skip every batch and hand back a list of ``None``).
    """
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    token_lists = [self.tokenize(s) for s in seqs]
    results: list[torch.Tensor] = []
    with self._lock:
        for start in range(0, len(token_lists), batch_size):
            hiddens = self._forward_hidden(token_lists[start : start + batch_size])  # list of [S_i, H]
            for h in hiddens:
                if h.shape[0] > 0:
                    results.append(self.sae.encode(h.to(self.device)).detach().cpu())
                else:
                    results.append(torch.empty(0, self.n_features))
    return results
def top_features(self, codes: torch.Tensor, tag_len: int = 0, k: int = 8) -> list[dict]:
    """Top-k features by per-base max activation over the DNA region (excluding the tag).

    Args:
        codes: ``[S, n_features]`` SAE codes, as produced by ``encode`` / ``encode_batch``.
        tag_len: Number of leading rows (phylo-tag tokens) to skip. Ignored when it would
            drop the whole sequence; negative values are treated as 0 so they cannot
            silently restrict the ranking to the last rows.
        k: Maximum number of features to return; ``k <= 0`` returns ``[]``.

    Returns:
        The strictly positive features, strongest first, as
        ``[{"feature_id", "label", "max_activation"}]``. ``label`` comes from
        ``self.labels`` (``None`` when unlabeled) and ``max_activation`` is rounded to 4
        decimals.
    """
    k = int(k)
    if codes.shape[0] == 0 or k <= 0:
        return []
    tag_len = max(int(tag_len), 0)
    region = codes[tag_len:] if codes.shape[0] > tag_len else codes
    peak = region.max(dim=0).values  # per-feature max over positions
    ranked = peak.topk(min(k, peak.numel())).indices.tolist()
    return [
        {"feature_id": int(i), "label": self.labels.get(int(i)), "max_activation": round(float(peak[i]), 4)}
        for i in ranked
        if peak[i].item() > 0
    ]
+ clamps, fids, temperature, top_k = _sanitize_steering(features, self.n_features, temperature, top_k) + + with self._lock: + comp = self._ensure_engine() + hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] + from sae.steering import clamp_hook + + feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] + + def _run(steer: bool) -> str: + handle = ( + hook_layer.register_forward_hook(clamp_hook(self.sae, clamps, decode_only=True)) + if (steer and clamps) + else None + ) + try: + out = INF.generate( + comp, [full_prompt], max_new_tokens=n_tokens, temperature=temperature, top_k=top_k + ) + return clean_dna(INF._unwrap_result(out[0]).generated_text) + finally: + if handle is not None: + handle.remove() + + main_dna = _run(steer=True) + base_dna = _run(steer=False) if (compare_baseline and clamps) else None + + resp = { + "prompt": dna, + "organism": organism, + "tag": resolved_tag, + "tag_len": len(resolved_tag), + "n_tokens": n_tokens, + "features": feat_meta, + "steered": bool(clamps), + "generation": {"sequence": main_dna, "activations": self.feature_tracks(main_dna, fids)}, + "baseline": None, + } + if base_dna is not None: + resp["baseline"] = {"sequence": base_dna, "activations": self.feature_tracks(base_dna, fids)} + return resp diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py new file mode 100644 index 0000000000..46db93066c --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/fasta.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def read_fasta(path: str | Path) -> Iterator[tuple[str, str]]:
    """Stream ``(seq_id, sequence)`` pairs out of a FASTA file.

    Args:
        path: Location of the FASTA file. When the path ends in ``.gz`` the file is read
            through ``gzip``; otherwise it is opened as plain text.

    Yields:
        One ``(seq_id, sequence)`` tuple per record. ``seq_id`` is the first
        whitespace-delimited token after ``>``; a header with no token (``">"`` or
        ``"> "``) gets the generated id ``seq_{n}``, where ``n`` is the number of records
        yielded so far. ``sequence`` is the record's lines, right-stripped and joined.
        Lines that appear before the first header are not part of any record.
    """
    open_text = gzip.open if str(path).endswith(".gz") else open
    current_id: str | None = None
    chunks: list[str] = []
    emitted = 0  # how many records have been handed to the caller
    with open_text(path, "rt") as handle:
        for raw in handle:
            text = raw.rstrip()
            if not text.startswith(">"):
                chunks.append(text)
                continue
            # A new header closes the record in progress, if there is one.
            if current_id is not None:
                yield current_id, "".join(chunks)
                emitted += 1
            tokens = text[1:].strip().split()
            current_id = tokens[0] if tokens else f"seq_{emitted}"
            chunks = []
    if current_id is not None:
        yield current_id, "".join(chunks)
def test_clamp_math():
    """The decode-only hook yields h + sum_f (target_f - code_f) * dec_f and skips prefill.

    Exercises ``sae.steering.clamp_hook`` (the hook ``Evo2SAE.generate`` registers) against
    a tiny in-memory SAE, so no model, checkpoint, or GPU is needed. ``top_k`` equals the
    feature count, so the hook's top-k step keeps every code and ``code_f`` is simply
    ``relu((h - pre_bias) @ enc_f + b_f)``.
    """
    from sae.steering import clamp_hook

    H, F = 4, 2
    torch.manual_seed(0)
    pre_bias = torch.randn(H)
    enc = torch.randn(F, H)
    bias = torch.randn(F)
    dec = torch.randn(H, F)
    targets = {f: float(f + 1) for f in range(F)}

    class _TinySAE:
        top_k = F  # keep every feature so the sparsified codes equal relu(pre_act)

        def encode_pre_act(self, x):
            return (x - pre_bias) @ enc.T + bias, None

        def decode(self, codes, info):
            return codes @ dec.T + pre_bias

    hook = clamp_hook(_TinySAE(), targets, decode_only=True)

    h = torch.randn(1, 1, H)  # one decode token: [S=1, B=1, H]
    out = hook(None, None, h)

    codes = torch.relu((h - pre_bias) @ enc.T + bias)  # [1, 1, F]
    expected = h.clone()
    for f, target in targets.items():
        expected = expected + (target - codes[..., f]).unsqueeze(-1) * dec[:, f]
    torch.testing.assert_close(out, expected)

    # prefill (S>1) must be left untouched (continuation-only steering)
    prefill = torch.randn(5, 1, H)
    assert torch.equal(hook(None, None, prefill), prefill)
+def test_sanitize_rejects_out_of_range_id(): + """feature_id outside [0, n_features) -> ValueError (server maps to 400, not a wedge).""" + from evo2_sae.core import _sanitize_steering + + with pytest.raises(ValueError): + _sanitize_steering([{"feature_id": 70000, "strength": 1.0}], n_features=65536, temperature=1.0, top_k=0) + with pytest.raises(ValueError): + _sanitize_steering([{"feature_id": -1, "strength": 1.0}], n_features=65536, temperature=1.0, top_k=0) + + +def test_sanitize_caps_clamp_magnitude(): + """An extreme target is capped to ±MAX_CLAMP_STRENGTH (prevents NaN logits).""" + from evo2_sae.core import MAX_CLAMP_STRENGTH, _sanitize_steering + + clamps, _, _, _ = _sanitize_steering([{"feature_id": 5, "strength": 99999.0}], 65536, 1.0, 0) + assert clamps[5] == MAX_CLAMP_STRENGTH + clamps, _, _, _ = _sanitize_steering([{"feature_id": 5, "strength": -99999.0}], 65536, 1.0, 0) + assert clamps[5] == -MAX_CLAMP_STRENGTH + + +def test_sanitize_temperature_zero_forces_greedy_topk(): + """temperature<=0 (divide-by-zero in the sampler) is coerced to greedy top-1.""" + from evo2_sae.core import _sanitize_steering + + _, _, _, top_k = _sanitize_steering([{"feature_id": 5, "strength": 1.0}], 65536, 0.0, 0) + assert top_k == 1 # bumped from 0 so the logits/temperature path is skipped + + +def test_sanitize_passthrough(): + """Valid inputs pass through unchanged — no spurious capping or top_k change.""" + from evo2_sae.core import _sanitize_steering + + clamps, fids, temp, top_k = _sanitize_steering([{"feature_id": 5, "strength": 2.0}], 65536, 0.8, 0) + assert clamps == {5: 2.0} and fids == [5] and temp == 0.8 and top_k == 0 + + +# --------------------------------------------------------------------- GPU: real generation +_CKPT = os.environ.get("EVO2_CKPT_DIR") +_SAE = os.environ.get("SAE_CKPT_PATH") +_LAYER = int(os.environ.get("EMBEDDING_LAYER", "19")) +_PROMPT = "ATGGCCGAATTCGGCACGAGGACGTGCTGAAAGCTAGCTAGGCTAACCGGTTACGTGCAT" +_ORG = "Human" + + 
+@pytest.fixture(scope="module") +def engine(): + """Load the Evo2 + SAE engine once (skips unless CUDA + checkpoints are available).""" + if not torch.cuda.is_available(): + pytest.skip("steering tests require CUDA") + if not (_CKPT and _SAE): + pytest.skip("set EVO2_CKPT_DIR and SAE_CKPT_PATH to run the steering tests") + from evo2_sae import Evo2SAE + + return Evo2SAE(evo2_ckpt_dir=_CKPT, sae_ckpt_path=_SAE, layer=_LAYER).load() + + +def _gen(engine, features): + torch.manual_seed(0) + return engine.generate(prompt=_PROMPT, organism=_ORG, features=features, n_tokens=48, temperature=0.0, top_k=1) + + +def _tag(engine): + return engine.resolve_tag(_ORG, None) or "" + + +@pytest.mark.slow +def test_encode_smoke(engine): + """encode runs the truncated bf16 forward and returns finite per-feature codes (>=1 firing). + + Guards the TransformerEngine bf16/fp32 autocast path: a dtype mismatch would crash here. + """ + codes = engine.encode(_tag(engine) + _PROMPT) + assert codes.ndim == 2 and codes.shape[1] == engine.n_features + assert torch.isfinite(codes).all() + assert (codes > 0).any() + + +@pytest.mark.slow +def test_unsteered_is_dna(engine): + """Unsteered generation yields a non-empty ACGT string (Evo2 stays in-distribution).""" + seq = _gen(engine, [])["generation"]["sequence"] + assert seq and set(seq) <= set("ACGTN") + + +@pytest.mark.slow +def test_steering_changes_continuation(engine): + """Clamping a KNOWN-ACTIVE feature hard changes the continuation; empty clamp is a no-op. + + Discovers the most-active feature on the prompt (SAE-agnostic) so the clamp has real signal — + an arbitrary/dead feature would leave greedy decoding unchanged and make this test useless. 
+ """ + per = engine.encode(_tag(engine) + _PROMPT).max(dim=0).values + fid, peak = int(per.argmax()), float(per.max()) + base = _gen(engine, [])["generation"]["sequence"] + steered = _gen(engine, [{"feature_id": fid, "strength": max(peak * 3.0, 50.0)}])["generation"]["sequence"] + assert steered != base # the clamp on an active feature changed the continuation + assert _gen(engine, [])["generation"]["sequence"] == base # determinism + empty-clamp no-op diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py new file mode 100644 index 0000000000..d390c9066b --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta. + +A forward hook on the layer the SAE was trained on: it re-encodes the layer output through +the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the +activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the +SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves +the activation. 
def clamp_hook(sae, clamps: Dict[int, float], decode_only: bool = False):
    """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method.

    The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's
    output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative
    value reverses its decoder direction. Works whether the module returns a tensor or a tuple
    whose first element is the hidden state.

    An empty ``clamps`` map is a true no-op: the hook returns the module output untouched and
    never runs the SAE.

    Args:
        sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``,
            and ``top_k``.
        clamps: Map of feature index -> absolute code value to force at every position.
        decode_only: If True, steer only autoregressive *decode* steps and leave the prompt
            prefill untouched (continuation-only steering). Assumes a ``(sequence, batch, hidden)``
            layout — the convention for Evo2/megatron decoder layers — and applies the clamp only
            when the sequence dimension is 1 (a single new token).

    Returns:
        A ``register_forward_hook``-compatible ``hook(module, inputs, output)``.
    """
    # Coerce keys/values once, up front, so the hook itself does no conversion work.
    items = [(int(f), float(v)) for f, v in clamps.items()]

    def hook(module, inputs, output):
        if not items:
            # Nothing to clamp: the clamped codes would equal the original codes, so skip the
            # SAE encode/decode (and the float32 round-trip) and pass the output through as-is.
            return output

        # Split off the hidden state; any extra tuple elements are passed through unchanged.
        h, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None)
        if decode_only and h.shape[0] != 1:  # prefill (seq dim > 1) — leave untouched
            return output

        dtype, shape = h.dtype, h.shape
        # Flatten every leading dim into rows of hidden vectors and run the SAE in float32.
        h_flat = h.reshape(-1, h.shape[-1]).float()
        with torch.no_grad():
            pre_act, info = sae.encode_pre_act(h_flat)
            codes = torch.relu(pre_act)
            # Keep only the top-k ReLU codes per row; all other entries are zero.
            kvals, kidx = torch.topk(codes, sae.top_k, dim=-1)
            codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals)
            codes_clamped = codes_orig.clone()
            for f, v in items:
                codes_clamped[:, f] = v
            # Only the difference between the two decodes is injected.
            delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info)

        # Restore the original dtype and shape before handing the activation back.
        h_out = (h_flat + delta).to(dtype).reshape(shape)
        return (h_out, *rest) if rest is not None else h_out

    return hook


@contextmanager
def steer(module, sae, clamps: Dict[int, float], decode_only: bool = False):
    """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.

    Args:
        module: The module whose forward output should be steered.
        sae: SAE passed through to :func:`clamp_hook`.
        clamps: Map of feature index -> code value, passed through to :func:`clamp_hook`.
        decode_only: Passed through to :func:`clamp_hook`.

    Yields:
        None. The hook is removed on exit even if the body raises.
    """
    handle = module.register_forward_hook(clamp_hook(sae, clamps, decode_only=decode_only))
    try:
        yield
    finally:
        handle.remove()
def _sae():
    """Return a small TopKSAE (8 -> 16, top_k=4) built right after seeding the RNG with 0."""
    torch.manual_seed(0)
    return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False)


def test_delta_clamp_is_exact_and_cancels_recon():
    """Check both halves of the delta-clamp contract.

    An empty clamp leaves the activation unchanged, and a real clamp shifts it by exactly
    ``decode(clamped) - decode(orig)``.
    """
    sae = _sae()
    identity = nn.Identity()
    x = torch.randn(5, 8)

    # Empty clamp: the injected delta is zero, so the output matches the input.
    with steer(identity, sae, {}):
        assert torch.allclose(identity(x), x, atol=1e-5)

    # Recompute the expected shift by hand, mirroring the hook's encode -> top-k -> decode path.
    with torch.no_grad():
        pre_act, info = sae.encode_pre_act(x.float())
        active = torch.relu(pre_act)
        top_vals, top_idx = torch.topk(active, sae.top_k, dim=-1)
        codes_orig = torch.zeros_like(active).scatter(-1, top_idx, top_vals)
        codes_clamped = codes_orig.clone()
        codes_clamped[:, 3] = 5.0
        expected = x + (sae.decode(codes_clamped, info) - sae.decode(codes_orig, info))

    with steer(identity, sae, {3: 5.0}):
        assert torch.allclose(identity(x), expected, atol=1e-4)


def test_tuple_output_steers_only_hidden_state():
    """For a module returning a tuple, element 0 is steered and the remaining elements pass through."""

    class TupleModule(nn.Module):
        def forward(self, x):
            return (x, "meta")

    sae, x = _sae(), torch.randn(3, 8)
    module = TupleModule()
    handle = module.register_forward_hook(clamp_hook(sae, {0: 2.0}))
    out = module(x)
    handle.remove()

    assert isinstance(out, tuple)
    assert out[1] == "meta"
    assert out[0].shape == x.shape
    assert not torch.allclose(out[0], x)  # clamp moved it


def test_decode_only_skips_prefill():
    """With ``decode_only=True`` a ``[1, B, H]`` step is steered while a ``[5, B, H]`` prefill is not."""
    sae = _sae()
    identity = nn.Identity()
    prefill = torch.randn(5, 2, 8)  # [S=5, B, H] — prompt prefill, must pass through
    decode = torch.randn(1, 2, 8)  # [S=1, B, H] — a single new token, must be steered

    handle = identity.register_forward_hook(clamp_hook(sae, {3: 5.0}, decode_only=True))
    out_prefill = identity(prefill)
    out_decode = identity(decode)
    handle.remove()

    assert torch.allclose(out_prefill, prefill, atol=1e-5)  # prefill untouched
    assert not torch.allclose(out_decode, decode)  # decode step steered