From c27f744ec0d437527284ebfbe2cb25211d908251 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 19:45:47 +0000 Subject: [PATCH 01/14] evo2 SAE eval: feature-probing harness + label producers The evo2-specific half of the probing eval, stacked on the shared sae.eval.probing primitives (base PR). Produces labels/buffers and drives the shared scoring lenses; all model-agnostic metrics live in sae.eval.probing. - labelers.py: per-token biological labelers (motifs, ORFs, codons, GC, ...) - euk_windows.py: GFF3 -> instance-labeled exon/intron/CDS windows (domain_f1) - evo2_buffer.py: Evo2 engine -> ActivationBuffer (the only model-touching code) - probe.py: the probe CLI (extract/auroc/linear/codon-aa/context/euk/loss-recovered) - probe_loss_recovered.py: Evo2 adapter to sae.eval.loss_recovered Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/euk_windows.py | 248 +++++++++++ .../recipes/evo2/scripts/evo2_buffer.py | 139 ++++++ .../recipes/evo2/scripts/labelers.py | 420 ++++++++++++++++++ .../recipes/evo2/scripts/probe.py | 295 ++++++++++++ .../evo2/scripts/probe_loss_recovered.py | 150 +++++++ 5 files changed, 1252 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py new file mode 100644 index 0000000000..335922d000 --- /dev/null +++ 
b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Build instance-level exon/intron/CDS-labeled windows from a genome FASTA + GFF3. + +Eukaryotic gene-structure annotation for SAE feature probing. Unlike the +sequence-derived labelers, these labels come from real gene models, and crucially +carry *instance IDs* (which exon / which intron / which gene each position belongs +to) so domain-adjusted F1 can compute recall PER ANNOTATION INSTANCE (a feature +"recalls" an exon if it fires anywhere inside it), not per position. + +For each protein-coding gene we take a representative transcript (longest by total +exon length), tile its span ± flank into windows, and label every position: + exon / intron / cds / utr / intergenic (+ per-position instance IDs for + exon, intron, gene) + +`python euk_windows.py --fasta chr21.fa --gff chr21.gff3 --dry-run` prints +coverage stats without building sequences. 
def _attrs(s):
    """Parse a GFF3 attribute column (``key=value;key=value``) into a dict.

    Fields without ``=`` are dropped; only the first ``=`` separates key from value.
    """
    return dict(kv.split("=", 1) for kv in s.strip().split(";") if "=" in kv)


def parse_gff(gff_path):
    """Return {gene_id: {strand, tx: {tx_id: {'exon': [(s,e)], 'cds': [(s,e)]}}}} (protein_coding).

    A transcript is kept only when its own ``biotype`` attribute and its parent gene's
    ``biotype`` are both ``protein_coding`` and it has at least one exon. Coordinates
    are kept exactly as read from columns 4/5 and sorted per transcript. ``gene:`` and
    ``transcript:`` ID prefixes are stripped. Comment lines and rows with fewer than
    nine tab-separated columns are ignored.
    """
    gene_strand, gene_biotype = {}, {}
    tx_gene, tx_biotype = {}, {}
    tx_exon = defaultdict(list)
    tx_cds = defaultdict(list)
    with open(gff_path) as fh:
        for line in fh:
            if line.startswith("#"):
                continue
            f = line.rstrip("\n").split("\t")
            if len(f) < 9:
                continue
            typ, s, e, strand, attr = f[2], int(f[3]), int(f[4]), f[6], f[8]
            a = _attrs(attr)
            if typ == "gene":
                gid = a.get("ID", "").replace("gene:", "")
                gene_strand[gid] = strand
                gene_biotype[gid] = a.get("biotype", "")
            elif typ in ("mRNA", "transcript"):
                tid = a.get("ID", "").replace("transcript:", "")
                tx_gene[tid] = a.get("Parent", "").replace("gene:", "")
                tx_biotype[tid] = a.get("biotype", "")
            elif typ == "exon":
                tid = a.get("Parent", "").replace("transcript:", "")
                tx_exon[tid].append((s, e))
            elif typ == "CDS":
                tid = a.get("Parent", "").replace("transcript:", "")
                tx_cds[tid].append((s, e))
    genes = {}
    for tid, gid in tx_gene.items():
        if gene_biotype.get(gid) != "protein_coding" or tx_biotype.get(tid) != "protein_coding":
            continue
        if not tx_exon.get(tid):
            continue
        genes.setdefault(gid, {"strand": gene_strand.get(gid, "+"), "tx": {}})
        genes[gid]["tx"][tid] = {"exon": sorted(tx_exon[tid]), "cds": sorted(tx_cds.get(tid, []))}
    return genes


def representative_tx(gene):
    """Return ``(tx_id, tx)`` for the transcript with the largest total exon length.

    Ties keep the first transcript encountered (the comparison is strict).
    """
    best, best_len = None, -1
    for tid, t in gene["tx"].items():
        ln = sum(e - s + 1 for s, e in t["exon"])
        if ln > best_len:
            best, best_len = tid, ln
    return best, gene["tx"][best]


def _label_window(chrom, w0, w1, gm, N):
    """Label a window [w0,w1) using one gene model's intervals (central-gene approx).

    ``w0``/``w1`` are 0-based half-open; the gene-model intervals are 1-based inclusive,
    hence the ``s - 1`` on every lower bound. Positions outside the gene span are marked
    intergenic even if another gene overlaps the window. ``N`` is accepted for interface
    compatibility and is not used.

    Returns a dict with ``dna`` (the window sequence), ``labels`` (bool array per class)
    and ``instances`` (int32 instance id per position, -1 where absent).
    """
    L = w1 - w0
    pos = np.arange(w0, w1)
    lab = {k: np.zeros(L, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")}
    inst = {k: np.full(L, -1, np.int32) for k in ("exon", "intron", "gene")}
    g_start, g_end = gm["span"]
    in_tx = (pos >= g_start - 1) & (pos < g_end)
    lab["intergenic"][~in_tx] = True
    inst["gene"][in_tx] = gm["gi"]
    for (s, e), iid in zip(gm["exons"], gm["exon_ids"]):
        m = (pos >= s - 1) & (pos < e)
        lab["exon"][m] = True
        inst["exon"][m] = iid
    for (s, e), iid in zip(gm["introns"], gm["intron_ids"]):
        m = (pos >= s - 1) & (pos < e)
        lab["intron"][m] = True
        inst["intron"][m] = iid
    for s, e in gm["cds"]:
        lab["cds"][(pos >= s - 1) & (pos < e)] = True
    # UTR is defined here as exonic sequence that is not coding.
    lab["utr"] = lab["exon"] & ~lab["cds"]
    return {"dna": chrom[w0:w1], "labels": lab, "instances": inst}


def _read_chromosome(fasta):
    """Concatenate every non-header FASTA line into one uppercase string.

    All records are joined, so the file is expected to hold a single sequence.
    """
    with open(fasta) as fh:
        return "".join(line.strip() for line in fh if not line.startswith(">")).upper()


def build_windows(
    fasta, gff, seq_len=1024, max_tokens=300_000, flank=300, seed=0, intergenic_frac=0.12, dry_run=False
):
    """Build exon-centred and intergenic labeled windows from a FASTA + GFF3 pair.

    Args:
        fasta: path to the chromosome FASTA (all records are concatenated).
        gff: path to the GFF3 annotation.
        seq_len: window length in bases.
        max_tokens: total base budget across all windows.
        flank: margin kept clear around every gene span for intergenic windows.
        seed: seed for the window-sampling RNG.
        intergenic_frac: share of ``max_tokens`` reserved for intergenic windows.
        dry_run: only collect annotation statistics, build no windows.

    Returns:
        ``(windows, stats, total_tokens, chromosome_length)``; ``windows`` is empty and
        ``total_tokens`` is 0 for a dry run.
    """
    chrom = _read_chromosome(fasta)
    N = len(chrom)
    genes = parse_gff(gff)

    exon_id, intron_id, gene_id = {}, {}, {}
    stats = defaultdict(int)
    gene_models, gene_spans = [], []
    for gid, gene in genes.items():
        tid, tx = representative_tx(gene)
        exons, cds = tx["exon"], tx["cds"]
        if not exons:
            continue
        # Exons are sorted by start; take the max end so a nested exon cannot shrink the span.
        g_start, g_end = exons[0][0], max(e for _, e in exons)
        introns = [
            (exons[i][1] + 1, exons[i + 1][0] - 1)
            for i in range(len(exons) - 1)
            if exons[i + 1][0] - 1 >= exons[i][1] + 1
        ]
        gi = gene_id.setdefault(gid, len(gene_id))
        eids = [exon_id.setdefault((tid, i), len(exon_id)) for i in range(len(exons))]
        iids = [intron_id.setdefault((tid, i), len(intron_id)) for i in range(len(introns))]
        gene_models.append(
            {
                "exons": exons,
                "introns": introns,
                "cds": cds,
                "gi": gi,
                "exon_ids": eids,
                "intron_ids": iids,
                "span": (g_start, g_end),
            }
        )
        gene_spans.append((g_start, g_end))
        stats["genes"] += 1
        stats["exons"] += len(exons)
        stats["introns"] += len(introns)
        stats["exon_bp"] += sum(e - s + 1 for s, e in exons)
        stats["intron_bp"] += sum(e - s + 1 for s, e in introns)
        stats["cds_bp"] += sum(e - s + 1 for s, e in cds)
    if dry_run:
        return [], dict(stats), 0, N

    rng = random.Random(seed)
    # exon-centered windows sampled across ALL genes' exons (diverse + exon/intron balanced)
    exon_refs = [(gi, ei) for gi, gm in enumerate(gene_models) for ei in range(len(gm["exons"]))]
    rng.shuffle(exon_refs)
    windows, tot = [], 0
    budget_genic = int(max_tokens * (1 - intergenic_frac))
    for gi, ei in exon_refs:
        if tot >= budget_genic:
            break
        gm = gene_models[gi]
        s, e = gm["exons"][ei]
        center = (s - 1 + e) // 2
        w0 = max(0, center - seq_len // 2)
        w1 = min(N, w0 + seq_len)
        if w1 - w0 < 60:
            continue
        win = _label_window(chrom, w0, w1, gm, N)
        if win["dna"].count("N") > 0.5 * len(win["dna"]):
            continue
        windows.append(win)
        tot += w1 - w0
    # intergenic windows: random spots clear of any gene span (+flank).
    # A chromosome shorter than one window cannot host a full intergenic window, and
    # random.randint would raise ValueError on the empty range, so skip sampling then.
    if N >= seq_len:
        spans = sorted(gene_spans)
        tries = 0
        while tot < max_tokens and tries < 20000:
            tries += 1
            w0 = rng.randint(0, N - seq_len)
            w1 = w0 + seq_len
            if any(not (w1 < gs - flank or w0 > ge + flank) for gs, ge in spans):
                continue
            dna = chrom[w0:w1]
            if dna.count("N") > 0.5 * seq_len:
                continue
            lab = {k: np.zeros(seq_len, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")}
            lab["intergenic"][:] = True
            inst = {k: np.full(seq_len, -1, np.int32) for k in ("exon", "intron", "gene")}
            windows.append({"dna": dna, "labels": lab, "instances": inst})
            tot += seq_len
    return windows, dict(stats), tot, N


def main():
    """CLI entry point: build windows (or dry-run) and print coverage statistics."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--fasta", required=True)
    ap.add_argument("--gff", required=True)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--max-tokens", type=int, default=300_000)
    ap.add_argument("--flank", type=int, default=300)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--intergenic-frac", type=float, default=0.12)
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()
    windows, stats, tot, N = build_windows(
        args.fasta,
        args.gff,
        args.seq_len,
        args.max_tokens,
        args.flank,
        seed=args.seed,
        intergenic_frac=args.intergenic_frac,
        dry_run=args.dry_run,
    )
    print(f"chromosome length: {N:,} bp")
    print(f"protein-coding genes used: {stats.get('genes', 0):,}")
    print(f"exons: {stats.get('exons', 0):,} introns: {stats.get('introns', 0):,}")
    if args.dry_run:
        print(
            f"exon bp: {stats.get('exon_bp', 0):,} intron bp: {stats.get('intron_bp', 0):,} cds bp: {stats.get('cds_bp', 0):,}"
        )
        return
    print(f"windows built: {len(windows):,} total tokens: {tot:,}")
    # coverage over built windows
    cov = defaultdict(int)
    ninst = {k: set() for k in ("exon", "intron", "gene")}
    for w in windows:
        for k, m in w["labels"].items():
            cov[k] += int(m.sum())
        for k in ninst:
            ids = w["instances"][k]
            ninst[k].update(int(x) for x in np.unique(ids) if x >= 0)
    print("per-position coverage (of built windows):")
    for k in ("exon", "intron", "cds", "utr", "intergenic"):
        print(f" {k:11s} {cov[k]:>9,} ({100 * cov[k] / max(1, tot):5.1f}%)")
    print(f"instances: exons={len(ninst['exon']):,} introns={len(ninst['intron']):,} genes={len(ninst['gene']):,}")


if __name__ == "__main__":
    main()
# Phylogenetic tag prepended to every sequence, keyed by kingdom.
KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"}


def read_fasta(path):
    """Yield ``(header, sequence)`` pairs from a FASTA file.

    The header is the text after ``>``; sequence lines are concatenated. Trailing
    whitespace and blank lines are dropped so they never end up inside a header or
    a sequence. Lines before the first header are discarded.
    """
    header, chunks = None, []
    with open(path) as fh:
        for raw in fh:
            line = raw.rstrip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None:
                    yield header, "".join(chunks)
                header, chunks = line[1:], []
            else:
                chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)


def sample_sequences(fasta, max_tokens, seq_len, kingdoms=("prok", "euk"), seed=0):
    """Sample kingdom-balanced ``(kingdom, dna)`` pairs up to roughly ``max_tokens`` tokens.

    A record is "prok" when its header starts with ``prok`` (case-insensitive), else
    "euk". Each sequence is cleaned with ``clean_dna`` and cut to its first ``seq_len``
    bases; sequences shorter than 60 bases are skipped. Reading stops once every
    requested kingdom holds ``max_tokens // seq_len + 50`` sequences. Pools are shuffled,
    interleaved one-per-kingdom until the budget (sequence plus tag length) is reached,
    and the result is shuffled again.
    """
    from evo2_sae_infer.core import clean_dna

    kingdoms = list(kingdoms)
    pools = {k: [] for k in kingdoms}
    need = max_tokens // seq_len + 50
    for header, seq in read_fasta(fasta):
        kg = "prok" if header.lower().startswith("prok") else "euk"
        if kg not in pools:
            continue
        dna = clean_dna(seq)[:seq_len]
        if len(dna) < 60:
            continue
        pools[kg].append(dna)
        if all(len(pools[k]) >= need for k in kingdoms):
            break
    rng = random.Random(seed)
    for k in kingdoms:
        rng.shuffle(pools[k])
    out, tok, i = [], 0, 0
    maxlen = max((len(pools[k]) for k in kingdoms), default=0)
    while tok < max_tokens and i < maxlen:
        for k in kingdoms:
            if i < len(pools[k]):
                out.append((k, pools[k][i]))
                tok += len(pools[k][i]) + len(KINGDOM_TAGS[k])
        i += 1
    rng.shuffle(out)
    return out


@torch.no_grad()
def build_buffer(engine, seqs, label_names, *, subsample, auroc_device, annotate_cds=False, batch_size=8, log=print):
    """Stream seqs through engine -> ActivationBuffer (codes + dense + labels [+ cds instances]).

    Args:
        engine: Evo2SAE-style engine; this function uses ``n_features``, ``device``,
            ``sae.pre_bias``, ``sae.encode``, ``tokenize``, ``_lock`` and ``_forward_hidden``.
        seqs: iterable of ``(kingdom, dna)`` pairs (see ``sample_sequences``).
        label_names: names registered in ``labelers.LABELERS``, one label column each.
        subsample: maximum number of tokens kept in the buffer.
        auroc_device: device holding the preallocated float16/bool buffers.
        annotate_cds: run ``labelers.predict_cds`` on prokaryotic sequences.
        batch_size: sequences per forward pass.
        log: progress callback, called every 10 batches.

    Raises:
        KeyError: if a name in ``label_names`` is not a registered labeler.
    """
    # Fail before any model forward pass instead of part-way through the first batch.
    unknown = [n for n in label_names if n not in L.LABELERS]
    if unknown:
        raise KeyError(f"unknown labelers: {unknown}")
    F = engine.n_features
    Hd = engine.sae.pre_bias.shape[0]
    dev = engine.device
    S = subsample
    code_buf = torch.zeros(S, F, dtype=torch.float16, device=auroc_device)
    dense_buf = torch.zeros(S, Hd, dtype=torch.float16, device=auroc_device)
    lab_buf = torch.zeros(S, len(label_names), dtype=torch.bool, device=auroc_device)
    filled = 0
    for start in range(0, len(seqs), batch_size):
        if filled >= S:
            break
        batch = seqs[start : start + batch_size]
        id_lists, metas = [], []
        for kg, dna in batch:
            tag = KINGDOM_TAGS[kg]
            tids = engine.tokenize(tag)
            id_lists.append(tids + engine.tokenize(dna))
            metas.append((tag, len(tids), kg, dna))
        with engine._lock:
            hiddens = engine._forward_hidden(id_lists)
        for h, (tag, tlen, kg, dna) in zip(hiddens, metas):
            if h.shape[0] == 0 or filled >= S:
                continue
            hd = h.to(dev)
            codes = engine.sae.encode(hd)
            norm = h.float().norm(dim=-1).cpu().numpy()
            T = codes.shape[0]
            cds_mask = cds_frame = gene_starts = None
            if annotate_cds and kg == "prok":
                cds_mask, cds_frame, gene_starts = L.predict_cds(dna)
            ctx = L.SeqContext(
                text=(tag + dna)[:T],
                tag_len=tlen,
                dna=dna,
                kingdom=kg,
                hidden_norm=norm[:T],
                cds_mask=cds_mask,
                cds_frame=cds_frame,
                gene_starts=gene_starts,
            )
            # Labelers return tag_len + len(dna) entries; cut to the T tokens actually encoded.
            # NOTE(review): assumes T <= tag_len + len(dna) (1 char == 1 token) — confirm for the tokenizer in use.
            lab = np.stack([L.LABELERS[n](ctx)[:T] for n in label_names], axis=1)
            take = min(T, S - filled)
            code_buf[filled : filled + take] = codes[:take].to(torch.float16).to(auroc_device)
            dense_buf[filled : filled + take] = hd[:take].to(torch.float16).to(auroc_device)
            lab_buf[filled : filled + take] = torch.from_numpy(lab[:take]).to(auroc_device)
            filled += take
        if (start // batch_size) % 10 == 0:
            log(f" {start + len(batch)}/{len(seqs)} seqs | buf {filled}/{S}")
    return ActivationBuffer(
        codes=code_buf[:filled].cpu().numpy(),
        dense=dense_buf[:filled].cpu().numpy(),
        labels=lab_buf[:filled].cpu().numpy(),
        label_names=list(label_names),
    )
# name -> fn(ctx) -> np.ndarray[bool] of length T
LABELERS: dict[str, callable] = {}
# labelers that are proxies / need real annotations (documented, refine later)
COMPLEX_LABELERS: set[str] = set()

# Sink-token norm threshold (residual L2). Set by the driver from the data
# (Evo2 7B layer-26 sinks sit ~1638 vs a ~21 median, so this cleanly isolates them).
SINK_NORM_THRESHOLD: float = 100.0


def labeler(name: str, complex: bool = False):
    """Register a per-token labeler under `name`.

    With ``complex=True`` the name is also added to ``COMPLEX_LABELERS``. The decorated
    function is returned unchanged.
    """

    def deco(fn):
        LABELERS[name] = fn
        if complex:
            COMPLEX_LABELERS.add(name)
        return fn

    return deco


@dataclass
class SeqContext:
    """Everything a labeler needs about one tokenized sequence."""

    text: str  # tag + dna (1 char == 1 token)
    tag_len: int  # number of leading phylo-tag tokens
    dna: str  # the DNA region (uppercase ACGTN), len == T - tag_len
    kingdom: str  # 'prok' | 'euk'
    hidden_norm: np.ndarray  # [T] residual L2 norm per token
    # Gene-structure annotation over DNA positions (filled by a gene caller; None if absent).
    cds_mask: Optional[np.ndarray] = None  # bool[len(dna)] — within a predicted CDS (either strand)
    cds_frame: Optional[np.ndarray] = None  # int8[len(dna)] — codon position 0/1/2 within CDS, -1 if not
    gene_starts: Optional[np.ndarray] = None  # bool[len(dna)] — predicted translation start positions

    @property
    def T(self) -> int:
        """Total token count: tag tokens plus one token per DNA base."""
        return self.tag_len + len(self.dna)


def _dna_mask(ctx: SeqContext, dna_bool: np.ndarray) -> np.ndarray:
    """Lift a per-DNA-position bool array to a per-token mask (tag tokens False)."""
    out = np.zeros(ctx.T, dtype=bool)
    out[ctx.tag_len : ctx.tag_len + len(dna_bool)] = dna_bool
    return out


def _bytes(dna: str) -> np.ndarray:
    """View the sequence as uint8 ASCII codes (non-ASCII characters become ``?``)."""
    return np.frombuffer(dna.encode("ascii", "replace"), dtype=np.uint8)


# --------------------------------------------------------------------- positional
@labeler("first_100bp")
def _first(ctx):
    """First 100 DNA bases of the sequence."""
    d = np.zeros(len(ctx.dna), bool)
    d[:100] = True
    return _dna_mask(ctx, d)


@labeler("last_100bp")
def _last(ctx):
    """Last 100 DNA bases of the sequence."""
    d = np.zeros(len(ctx.dna), bool)
    if len(d):
        d[-100:] = True
    return _dna_mask(ctx, d)


@labeler("codon_pos_1")
def _c1(ctx):
    """Every third base starting at DNA position 0."""
    # frame-0 proxy (no CDS annotation): position 0 of each codon from seq start
    return _dna_mask(ctx, np.arange(len(ctx.dna)) % 3 == 0)


@labeler("codon_pos_3")
def _c3(ctx):
    """Every third base starting at DNA position 2 (frame-0 proxy, see codon_pos_1)."""
    return _dna_mask(ctx, np.arange(len(ctx.dna)) % 3 == 2)


# --------------------------------------------------------------------- composition
def _gc_window(dna: str, radius: int = 10) -> np.ndarray:
    """GC fraction in a centred window of up to ``2 * radius + 1`` bases (clipped at the ends)."""
    arr = _bytes(dna)
    gc = ((arr == ord("G")) | (arr == ord("C"))).astype(np.float64)
    # Prefix sums give each window total as a difference of two entries.
    csum = np.concatenate([[0.0], np.cumsum(gc)])
    n = len(gc)
    idx = np.arange(n)
    lo = np.maximum(0, idx - radius)
    hi = np.minimum(n, idx + radius + 1)
    return (csum[hi] - csum[lo]) / np.maximum(1, hi - lo)


@labeler("gc_high_window")
def _gch(ctx):
    """Positions whose local GC fraction is at least 0.60."""
    return _dna_mask(ctx, _gc_window(ctx.dna) >= 0.60)


@labeler("gc_low_window")
def _gcl(ctx):
    """Positions whose local GC fraction is at most 0.30."""
    return _dna_mask(ctx, _gc_window(ctx.dna) <= 0.30)


@labeler("homopolymer_window")
def _homo(ctx, k: int = 5):
    """Positions inside a run of at least ``k`` identical characters."""
    d, n = ctx.dna, len(ctx.dna)
    out = np.zeros(n, bool)
    i = 0
    while i < n:
        j = i
        while j + 1 < n and d[j + 1] == d[i]:
            j += 1
        if j - i + 1 >= k:
            out[i : j + 1] = True
        i = j + 1
    return _dna_mask(ctx, out)


@labeler("dinuc_repeat_window")
def _dinuc(ctx, min_reps: int = 3):
    """Positions inside a period-2 repeat of two different bases spanning >= ``2 * min_reps`` bases."""
    d, n = ctx.dna, len(ctx.dna)
    out = np.zeros(n, bool)
    i = 0
    while i < n - 1:
        if d[i] != d[i + 1]:
            j = i
            # extend while the base two positions ahead repeats the current one
            while j + 2 < n and d[j + 2] == d[j]:
                j += 1
            span = j + 2 - i
            if span >= 2 * min_reps:
                out[i : j + 2] = True
            i = max(j + 1, i + 1)
        else:
            i += 1
    return _dna_mask(ctx, out)


# --------------------------------------------------------------------- motifs
# re.finditer only reports non-overlapping matches, which silently drops motif
# occurrences that overlap an earlier one (e.g. the second TATATA in TATATATA).
# Scanning with a zero-width lookahead tests the pattern at every position instead.
def _starts(dna: str, pattern: str) -> np.ndarray:
    """Mark the first base of every (possibly overlapping) occurrence of ``pattern``."""
    out = np.zeros(len(dna), bool)
    for m in re.finditer(f"(?=(?:{pattern}))", dna):
        out[m.start()] = True
    return out


def _spans(dna: str, pattern: str) -> np.ndarray:
    """Mark every base covered by any (possibly overlapping) occurrence of ``pattern``."""
    out = np.zeros(len(dna), bool)
    for m in re.finditer(f"(?=({pattern}))", dna):
        out[m.start(1) : m.end(1)] = True
    return out


@labeler("motif_ATG")
def _atg(ctx):
    """First base of every ATG."""
    return _dna_mask(ctx, _starts(ctx.dna, r"ATG"))


@labeler("motif_stop")
def _stop(ctx):
    """First base of every TAA / TAG / TGA triplet, in any frame."""
    return _dna_mask(ctx, _starts(ctx.dna, r"TAA|TAG|TGA"))


@labeler("motif_TATA")
def _tata(ctx):
    """All bases of every TATA[AT]A match."""
    return _dna_mask(ctx, _spans(ctx.dna, r"TATA[AT]A"))


@labeler("motif_RBS_SD")
def _rbs(ctx):
    """All bases of every AGGAGG match."""
    # Shine-Dalgarno ribosome-binding site
    return _dna_mask(ctx, _spans(ctx.dna, r"AGGAGG"))


# --------------------------------------------------- complex / consensus (refine later)
@labeler("kozak_atg", complex=True)
def _kozak(ctx):
    """The A of each ATG that sits in a Kozak-like context."""
    # Kozak: (A/G)xxATGG — mark the ATG start (match start + 3)
    out = np.zeros(len(ctx.dna), bool)
    for m in re.finditer(r"(?=[AG]..ATGG)", ctx.dna):
        out[m.start() + 3] = True
    return _dna_mask(ctx, out)


@labeler("tss_proxy_tata", complex=True)
def _tss(ctx):
    """All bases of every TATAAA match."""
    # stricter canonical TATA box as a transcription-start proxy
    return _dna_mask(ctx, _spans(ctx.dna, r"TATAAA"))


@labeler("splice_donor", complex=True)
def _sd(ctx):
    """First base (the G of GT) of every GT[AG]AG match."""
    # 5' donor consensus as matched here: GT(A/G)AG — mark the G of GT
    return _dna_mask(ctx, _starts(ctx.dna, r"GT[AG]AG"))


@labeler("splice_acceptor", complex=True)
def _sa(ctx):
    """The AG that follows a run of six pyrimidines (one optional base in between)."""
    # 3' acceptor: polypyrimidine tract then AG — mark the AG
    out = np.zeros(len(ctx.dna), bool)
    for m in re.finditer(r"(?=([CT]{6}[ACGT]?AG))", ctx.dna):
        out[m.end(1) - 2 : m.end(1)] = True
    return _dna_mask(ctx, out)


def _orf(ctx, frame: int, win: int = 60):
    """Mark codon-start positions in ``frame`` with no stop codon in the next ``win`` bases.

    Near the end of the sequence the scanned window is shorter than ``win``; such
    positions are marked when no stop occurs in what remains.
    """
    d, n = ctx.dna, len(ctx.dna)
    stops = {"TAA", "TAG", "TGA"}
    out = np.zeros(n, bool)
    for p in range(frame, n - 2, 3):
        ok = True
        for q in range(p, min(p + win, n - 2), 3):
            if d[q : q + 3] in stops:
                ok = False
                break
        if ok:
            out[p] = True
    return _dna_mask(ctx, out)


@labeler("orf_frame_0_60bp", complex=True)
def _orf0(ctx):
    """Open-frame anchors in frame 0 (see ``_orf``)."""
    return _orf(ctx, 0)


@labeler("orf_frame_1_60bp", complex=True)
def _orf1(ctx):
    """Open-frame anchors in frame 1 (see ``_orf``)."""
    return _orf(ctx, 1)


@labeler("orf_frame_2_60bp", complex=True)
def _orf2(ctx):
    """Open-frame anchors in frame 2 (see ``_orf``)."""
    return _orf(ctx, 2)


# --------------------------------------------------------------- sequence / norm level
@labeler("is_prok")
def _prok(ctx):
    """Every token (tag included) of a prokaryotic sequence."""
    return np.full(ctx.T, ctx.kingdom == "prok", dtype=bool)


@labeler("is_euk_genic", complex=True)
def _eukg(ctx):
    """Every token (tag included) of a eukaryotic sequence."""
    # proxy: eukaryotic token. True "genic" needs a gene model — refine later.
    return np.full(ctx.T, ctx.kingdom == "euk", dtype=bool)


@labeler("is_sink_token", complex=True)
def _sink(ctx):
    """Tokens whose residual norm exceeds ``SINK_NORM_THRESHOLD`` (read at call time)."""
    return ctx.hidden_norm > SINK_NORM_THRESHOLD


# --------------------------------------------- gene structure (real annotation, prok)
# These read a CDS annotation attached to the context by a gene caller (see
# predict_cds, prokaryotes only). They are no-ops when the annotation is absent.
@labeler("cds_coding", complex=True)
def _cds(ctx):
    """Tokens inside a predicted CDS; all False when no annotation is attached."""
    return np.zeros(ctx.T, bool) if ctx.cds_mask is None else _dna_mask(ctx, ctx.cds_mask)


@labeler("cds_start", complex=True)
def _cds_start(ctx):
    """Predicted translation-start tokens; all False when no annotation is attached."""
    return np.zeros(ctx.T, bool) if ctx.gene_starts is None else _dna_mask(ctx, ctx.gene_starts)


@labeler("cds_frame_1", complex=True)
def _cds_f1(ctx):
    """Codon position 1 inside a predicted CDS (frame value 0)."""
    # codon position 1 within a REAL predicted CDS (not the frame-0-from-start proxy)
    frame = ctx.cds_frame
    return np.zeros(ctx.T, bool) if frame is None else _dna_mask(ctx, frame == 0)


@labeler("cds_frame_3", complex=True)
def _cds_f3(ctx):
    """Codon position 3 inside a predicted CDS (frame value 2)."""
    frame = ctx.cds_frame
    return np.zeros(ctx.T, bool) if frame is None else _dna_mask(ctx, frame == 2)


# Lazily created pyrodigal gene finder shared by predict_codons and predict_cds.
_GENE_FINDER = None

# Standard genetic code (NCBI translation table 1), codons in TCAG x TCAG x TCAG order.
_BASES = "TCAG"
_AA1 = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
_CODONS_TCAG = [x + y + z for x in _BASES for y in _BASES for z in _BASES]
CODON_TABLE = dict(zip(_CODONS_TCAG, _AA1))
CODON_LIST = sorted(CODON_TABLE)  # 64 codons
CODON_TO_IDX = {c: i for i, c in enumerate(CODON_LIST)}
AA_LIST = sorted(set(CODON_TABLE.values()))  # 20 aa + '*' (stop)
AA_TO_IDX = {a: i for i, a in enumerate(AA_LIST)}
_COMP = str.maketrans("ACGTN", "TGCAN")


def _revcomp(s):
    """Reverse complement of an uppercase ACGTN string (other characters pass through)."""
    return s.translate(_COMP)[::-1]


def _gene_finder():
    """Return the shared pyrodigal ``GeneFinder(meta=True)``, creating it on first use."""
    global _GENE_FINDER
    if _GENE_FINDER is None:
        import pyrodigal

        _GENE_FINDER = pyrodigal.GeneFinder(meta=True)
    return _GENE_FINDER


def predict_codons(dna: str):
    """In-frame codon + amino-acid identity at strand-correct codon anchors (prok genes).

    Returns (codon_id[N], aa_id[N]) over forward DNA coordinates; the anchor is the
    first translated base of each codon (low coord on +strand, high coord on -strand),
    other positions are -1. codon_id in 0..63 (CODON_LIST), aa_id in 0..20 (AA_LIST).
    Sequences shorter than 60 bases are returned unannotated without calling the gene
    finder; codons not present in CODON_TO_IDX (e.g. containing N) are skipped.
    """
    n = len(dna)
    codon_id = np.full(n, -1, dtype=np.int16)
    aa_id = np.full(n, -1, dtype=np.int8)
    if n < 60:
        return codon_id, aa_id
    for g in _gene_finder().find_genes(dna.encode("ascii", "replace")):
        b, e = max(0, g.begin - 1), min(n, g.end)
        forward = g.strand == 1
        # Read the gene 5'->3' on its own strand.
        coding = dna[b:e] if forward else _revcomp(dna[b:e])
        whole = len(coding) - len(coding) % 3  # drop a trailing partial codon
        for off in range(0, whole, 3):
            cod = coding[off : off + 3]
            j = CODON_TO_IDX.get(cod)
            if j is None:
                continue
            # Map the codon's first translated base back to forward coordinates.
            p = b + off if forward else e - 1 - off
            if 0 <= p < n:
                codon_id[p] = j
                aa_id[p] = AA_TO_IDX[CODON_TABLE[cod]]
    return codon_id, aa_id


def predict_cds(dna: str):
    """Prokaryotic gene calling via pyrodigal (meta mode) on a single DNA chunk.

    Returns (cds_mask, cds_frame, gene_starts) over forward DNA coordinates:
        cds_mask[i] True if position i lies within any predicted CDS (either strand)
        cds_frame[i] codon position 0/1/2 relative to that gene's start (strand-aware), else -1
        gene_starts[i] True at predicted translation starts
    Sequences shorter than 60 bases are returned unannotated. Where predicted genes
    overlap, the frame of the gene processed last wins.
    """
    n = len(dna)
    cds_mask = np.zeros(n, dtype=bool)
    cds_frame = np.full(n, -1, dtype=np.int8)
    gene_starts = np.zeros(n, dtype=bool)
    if n < 60:
        return cds_mask, cds_frame, gene_starts
    for g in _gene_finder().find_genes(dna.encode("ascii", "replace")):
        # 0-based half-open, forward coords, clipped to the chunk
        b, e = max(0, g.begin - 1), min(n, g.end)
        if e <= b:
            continue
        cds_mask[b:e] = True
        offsets = np.arange(e - b)
        if g.strand == 1:
            gene_starts[b] = True
            cds_frame[b:e] = offsets % 3
        else:
            # reverse strand: start codon sits at the (forward) end, so count from e - 1 down
            gene_starts[e - 1] = True
            cds_frame[b:e] = offsets[::-1] % 3
    return cds_mask, cds_frame, gene_starts


# Default label set for the probe (order preserved in outputs).
DEFAULT_LABELS = list(LABELERS.keys())
def _z(X, tr):
    """Standardize ``X`` with the train-split statistics returned by ``standardize(X, tr)``."""
    mu, sd = standardize(X, tr)
    return (X - mu) / sd


def _select_labels(spec, buf):
    """Turn a comma-separated ``--labels`` value into the label names present in ``buf``.

    Args:
        spec: comma-separated label names as typed on the command line.
        buf: loaded buffer; only its ``name_idx`` mapping (label name -> column) is read.

    Returns:
        The requested names that exist in ``buf.name_idx``, in the order requested.

    Raises:
        SystemExit: when none of the requested names exists in the buffer.
    """
    wanted = [t.strip() for t in spec.split(",") if t.strip()]
    names = [t for t in wanted if t in buf.name_idx]
    missing = [t for t in wanted if t not in buf.name_idx]
    if missing:
        # Previously these were dropped without a trace, which hides typos in --labels.
        print(f"warning: labels not in buffer, skipped: {', '.join(missing)}", file=sys.stderr)
    if not names:
        raise SystemExit(
            f"none of the requested labels is in the buffer (available: {', '.join(sorted(buf.name_idx))})"
        )
    return names


# ───────────────────────────────────────── buffer-only subcommands (no model)
def cmd_auroc(a):
    """Print, per requested label, the best single-feature AUROC and the feature that achieves it.

    Reads ``a.acts`` (buffer path), ``a.labels`` (comma-separated names) and ``a.device``.
    """
    buf = ActivationBuffer.load(a.acts)
    dev = a.device
    names = _select_labels(a.labels, buf)
    X = torch.from_numpy(buf.codes).to(dev).float()
    Y = torch.stack([torch.from_numpy(buf.labels[:, buf.name_idx[n]]).to(dev) for n in names], 1)
    au = auroc_all(X, Y).cpu().numpy()
    print(f"{'label':18s} {'%pos':>6s} {'best AUROC':>10s} {'feature':>8s}")
    for i, n in enumerate(names):
        col = au[:, i]  # AUROC of every feature against label n
        print(f"{n:18s} {buf.labels[:, buf.name_idx[n]].mean():6.1%} {col.max():10.3f} {int(col.argmax()):8d}")


def _eval_matrix(mat, buf, names, tr, te, dev, steps, wd):
    """Score one activation matrix against every label in ``names``.

    For each label the result is ``(single, multi)``: ``single`` comes from
    ``best_single_train_test`` and ``multi`` is the test-split AUROC of a logistic probe
    fitted on the train split. Both are NaN when either split contains a single class,
    because neither the probe nor an AUROC is meaningful there.

    Args:
        mat: activation matrix (numpy) with one row per buffer position.
        buf: buffer providing ``labels`` and ``name_idx``.
        names: label names to evaluate.
        tr, te: train / test row indices (tensors; ``.numpy()`` is used to index labels).
        dev: torch device for the computation.
        steps, wd: optimisation steps and weight decay forwarded to ``fit_logreg``.

    Returns:
        ``{label_name: (single, multi)}``.
    """
    from sae.eval.probing import fit_logreg

    X = torch.from_numpy(mat).to(dev).float()
    Xz = _z(X, tr)
    out = {}
    for n in names:
        col = buf.name_idx[n]
        ytr = torch.from_numpy(buf.labels[tr.numpy(), col]).to(dev).float()
        yte = torch.from_numpy(buf.labels[te.numpy(), col]).to(dev)
        n_tr_pos, n_te_pos = int(ytr.sum()), int(yte.sum())
        # The all-positive test split was previously let through although AUROC needs negatives too.
        if n_tr_pos in (0, len(ytr)) or n_te_pos in (0, len(yte)):
            out[n] = (float("nan"), float("nan"))
            continue
        w, b = fit_logreg(Xz[tr], ytr, steps=steps, wd=wd)
        out[n] = (best_single_train_test(Xz[tr], ytr, Xz[te], yte), auroc_vec((Xz[te] @ w + b).float(), yte))
    # Drop the device copies before the next (possibly larger) matrix is loaded.
    del X, Xz
    torch.cuda.empty_cache()
    return out


def cmd_linear(a):
    """Print SAE (and, when the buffer has them, dense) single/multi probe scores per label.

    Reads ``a.acts``, ``a.labels``, ``a.device``, ``a.test_frac``, ``a.seed``, ``a.steps``
    and ``a.weight_decay``. The last column is SAE-single minus dense-single.
    """
    buf = ActivationBuffer.load(a.acts)
    dev = a.device
    names = _select_labels(a.labels, buf)
    tr, te = split_indices(buf.codes.shape[0], a.test_frac, a.seed)
    sae = _eval_matrix(buf.codes, buf, names, tr, te, dev, a.steps, a.weight_decay)
    den = None
    if buf.dense is not None:
        den = _eval_matrix(buf.dense, buf, names, tr, te, dev, a.steps, a.weight_decay)
    h = f"{'label':18s} {'%pos':>6s} | {'SAE single':>10s} {'SAE multi':>9s}"
    if den is not None:
        h += f" | {'dense single':>12s} {'dense multi':>11s} | {'Δ':>7s}"
    print(h)
    for n in names:
        pos = buf.labels[:, buf.name_idx[n]].mean()
        ss, sm = sae[n]
        row = f"{n:18s} {pos:6.1%} | {ss:10.3f} {sm:9.3f}"
        if den is not None:
            ds, dm = den[n]
            row += f" | {ds:12.3f} {dm:11.3f} | {ss - ds:+7.3f}"
        print(row)


def cmd_context(a):
    """Compare firing rates of features 29244 / 33918 on a motif inside vs outside its context.

    Reads the ``.npz`` at ``a.acts`` (arrays ``codes``, ``labels``, ``label_names``) and needs
    the labels ``motif_ATG``, ``motif_stop``, ``cds_start`` and ``cds_frame_1``. A feature
    "fires" at a position when its code is > 0.
    """
    # NOTE(review): allow_pickle=True executes pickled objects on load; only open trusted buffers.
    with np.load(a.acts, allow_pickle=True) as z:
        codes, labels = z["codes"], z["labels"]
        idx = {n: i for i, n in enumerate(z["label_names"])}
    P = codes.shape[0]

    def lab(n):
        return labels[:, idx[n]].astype(bool)

    def rate(feat, m):
        # (firing rate inside mask m, mask size); NaN rate when the mask is empty.
        return (float((codes[:, feat][m] > 0).mean()), int(m.sum())) if m.sum() else (float("nan"), 0)

    ATG, STOP, START, INF = lab("motif_ATG"), lab("motif_stop"), lab("cds_start"), lab("cds_frame_1")
    print(f"baseline: 29244={rate(29244, np.ones(P, bool))[0]:.3f} 33918={rate(33918, np.ones(P, bool))[0]:.3f}")
    for nm, f, motif, ctx, cl in [
        ("29244 ATG", 29244, ATG, START, "real start"),
        ("33918 STOP", 33918, STOP, INF, "in-frame"),
    ]:
        ra, na = rate(f, motif & ctx)
        rb, nb = rate(f, motif & ~ctx)
        print(f"{nm}: {cl} {ra:.3f}(n={na}) | other {rb:.3f}(n={nb}) | ratio {ra / max(rb, 1e-9):.2f}")
def cmd_codon_aa(a):
    """Print codon / amino-acid decoding scores and family-disjoint recall for SAE vs dense.

    Reads the ``.npz`` at ``a.acts`` (integer arrays ``codon`` and ``aa`` plus the matrices
    ``sae`` and/or ``dense``), and ``a.device``, ``a.test_frac``, ``a.seed``, ``a.steps``,
    ``a.weight_decay``. For the family-disjoint column a softmax amino-acid probe is fitted
    on every row whose codon is NOT one of the held-out codons, then asked to predict the
    amino acid of the held-out codons; the value in parentheses is that amino acid's
    overall frequency.
    """
    dev = a.device
    ncod, naa = len(L.CODON_LIST), len(L.AA_LIST)
    held = {"L": ["TTA", "TTG"], "S": ["AGT", "AGC"], "R": ["AGA", "AGG"]}
    hidx = [L.CODON_TO_IDX[c] for v in held.values() for c in v]
    # Keep the archive open for the whole run so each matrix is read only when needed,
    # and make sure the handle is closed afterwards.
    with np.load(a.acts) as z:
        codon_np = z["codon"].astype(np.int64)
        codon = torch.from_numpy(codon_np).to(dev)
        aa = torch.from_numpy(z["aa"].astype(np.int64)).to(dev)
        print(f"{'matrix':6s} {'codon mAUROC':>12s} {'AA mAUROC':>10s} | family-disjoint recall L/S/R (chance)")
        for nm in ("sae", "dense"):
            if nm not in z.files:
                continue
            X = torch.from_numpy(z[nm]).to(dev).float()
            # NOTE(review): statistics here use ALL rows, whereas cmd_linear standardizes with
            # train-split statistics via _z — confirm this difference is intended.
            Xz = (X - X.mean(0)) / (X.std(0) + 1e-6)
            tr, te = split_indices(X.shape[0], a.test_frac, a.seed)
            _, ca, _ = decode_eval(Xz[tr], codon[tr], Xz[te], codon[te], ncod, steps=a.steps, wd=a.weight_decay)
            _, aaa, _ = decode_eval(Xz[tr], aa[tr], Xz[te], aa[te], naa, steps=a.steps, wd=a.weight_decay)
            trn = torch.from_numpy(np.nonzero(~np.isin(codon_np, hidx))[0]).to(dev)
            W, b = fit_softmax(Xz[trn], aa[trn], naa, steps=a.steps, wd=a.weight_decay)
            rec = []
            for A, cods in held.items():
                m = np.isin(codon_np, [L.CODON_TO_IDX[c] for c in cods])
                rows = torch.from_numpy(np.nonzero(m)[0]).to(dev)
                pred = (Xz[rows] @ W + b).argmax(1).cpu().numpy()
                # A family with no tokens has no recall; report NaN without a mean-of-empty warning.
                recall = float((pred == L.AA_TO_IDX[A]).mean()) if pred.size else float("nan")
                rec.append(f"{A}={recall:.2f}({float((aa == L.AA_TO_IDX[A]).float().mean()):.2f})")
            del X, Xz
            torch.cuda.empty_cache()
            print(f"{nm:6s} {ca:12.3f} {aaa:10.3f} | {' '.join(rec)}")


# ───────────────────────────────────────── model subcommands (need Evo2)
def cmd_euk(a):
    """Eukaryotic exon/intron/CDS domain-adjusted F1 vs shuffle null (chr21 FASTA+GFF).

    Builds labeled windows with ``euk_windows.build_windows``, encodes them with the SAE,
    and prints for each concept the best per-feature ``domain_f1`` next to the best score
    obtained after jointly permuting the labels and instance ids (the null).

    Raises:
        SystemExit: when no position could be encoded, so there is nothing to score.
    """
    from euk_windows import build_windows
    from evo2_sae_infer.core import DEFAULT_ORGANISM_TAGS, Evo2SAE
    from sae.eval.probing import domain_f1

    eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load()
    windows, stats, tot, _ = build_windows(a.fasta, a.gff, a.seq_len, a.max_tokens, seed=a.seed)
    print(
        f"windows={len(windows)} tokens={tot} genes={stats['genes']} exons={stats['exons']} introns={stats['introns']}"
    )
    F, adev = eng.n_features, a.auroc_device
    if a.organism not in DEFAULT_ORGANISM_TAGS:
        # The lookup below falls back to an empty tag; say so instead of doing it silently.
        print(f"warning: unknown organism {a.organism!r}; encoding without an organism tag", file=sys.stderr)
    tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, ""))
    tlen = len(tag_ids)
    # label key -> instance-id key used for recall (CDS positions are grouped by gene).
    concepts = {"exon": "exon", "intron": "intron", "cds": "gene"}
    code_buf = torch.zeros(tot, F, dtype=torch.float16, device=adev)
    lab = {k: torch.zeros(tot, dtype=torch.bool, device=adev) for k in concepts}
    inst = {k: torch.full((tot,), -1, dtype=torch.long, device=adev) for k in concepts.values()}
    filled = 0
    for s0 in range(0, len(windows), a.batch_size):
        batch = windows[s0 : s0 + a.batch_size]
        with eng._lock:
            hiddens = eng._forward_hidden([tag_ids + eng.tokenize(w["dna"]) for w in batch])
            for h, w in zip(hiddens, batch):
                if h.shape[0] == 0:
                    continue
                codes = eng.sae.encode(h.to(a.device))
                # Skip the tag prefix and never write past the preallocated buffer.
                take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled)
                if take <= 0:
                    continue
                dst = slice(filled, filled + take)
                code_buf[dst] = codes[tlen : tlen + take].to(torch.float16).to(adev)
                for k in lab:
                    lab[k][dst] = torch.from_numpy(w["labels"][k][:take]).to(adev)
                for k in inst:
                    inst[k][dst] = torch.from_numpy(w["instances"][k][:take].astype(np.int64)).to(adev)
                filled += take
    if filled == 0:
        # code_buf.max(0) below is undefined on an empty buffer; stop with a readable message.
        raise SystemExit("euk-f1: no positions were encoded (no usable windows); nothing to score")
    code_buf = code_buf[:filled]
    for d in (lab, inst):
        for k in d:
            d[k] = d[k][:filled]
    fmax = code_buf.max(0).values.float()
    g = torch.Generator(device=adev).manual_seed(a.seed)
    print(f"encoded {filled} positions\n{'concept':8s} {'domF1':>6s} {'null':>6s} {'ratio':>6s} {'%pos':>6s}")
    for c, ic in concepts.items():
        f1, _ = domain_f1(code_buf, fmax, lab[c], inst[ic])
        order = torch.randperm(filled, generator=g, device=adev)
        f1n, _ = domain_f1(code_buf, fmax, lab[c][order], inst[ic][order])
        bf, nl = float(f1.max()), float(f1n.max())
        print(f"{c:8s} {bf:6.3f} {nl:6.3f} {bf / max(nl, 1e-9):6.2f} {float(lab[c].float().mean()):6.1%}")
def cmd_extract(a):
    """Build an activation buffer with the Evo2 engine and save it to ``a.out``.

    Samples sequences from ``a.fasta`` for the requested kingdoms, labels them with every
    registered labeler (``L.LABELERS``), runs ``evo2_buffer.build_buffer`` and writes the
    result with ``buf.save``.
    """
    from evo2_buffer import build_buffer, sample_sequences
    from evo2_sae_infer.core import Evo2SAE

    eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load()
    label_names = list(L.LABELERS.keys())
    kingdoms = [k for k in a.kingdoms.split(",") if k]
    seqs = sample_sequences(a.fasta, a.max_tokens, a.seq_len, kingdoms=kingdoms, seed=a.seed)
    print(f"probe set: {len(seqs)} seqs (kingdoms={kingdoms})")
    buf = build_buffer(
        eng,
        seqs,
        label_names,
        subsample=a.subsample,
        auroc_device=a.auroc_device,
        annotate_cds=a.annotate_cds,
        batch_size=a.batch_size,
        log=print,
    )
    buf.save(a.out)
    # cmd_linear treats buf.dense as optional, so do not crash here (after a successful
    # save) when the buffer carries no dense matrix.
    dense_dim = buf.dense.shape[1] if buf.dense is not None else "none"
    print(f"saved buffer -> {a.out} ({buf.codes.shape[0]} x {buf.codes.shape[1]}, dense {dense_dim})")


def main():
    """Parse the command line and dispatch to the selected ``cmd_*`` function.

    Every subcommand inherits ``--device``, ``--seed``, ``--steps``, ``--weight-decay`` and
    ``--test-frac``. Gradients are disabled globally before the command runs.
    """
    ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    sub = ap.add_subparsers(dest="cmd", required=True)
    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--device", default="cuda:0")
    common.add_argument("--seed", type=int, default=0)
    common.add_argument("--steps", type=int, default=400)
    common.add_argument("--weight-decay", type=float, default=1e-2)
    common.add_argument("--test-frac", type=float, default=0.4)
    # Buffer-only commands: all take --acts, some also need --labels.
    for name, fn, needs_labels in [
        ("auroc", cmd_auroc, True),
        ("linear", cmd_linear, True),
        ("context", cmd_context, False),
        ("codon-aa", cmd_codon_aa, False),
    ]:
        p = sub.add_parser(name, parents=[common])
        p.add_argument("--acts", required=True)
        if needs_labels:
            p.add_argument("--labels", required=True)
        p.set_defaults(func=fn)
    pe = sub.add_parser("extract", parents=[common])
    for arg in ["--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", "--out"]:
        pe.add_argument(arg, required=True)
    pe.add_argument("--layer", type=int, required=True)
    pe.add_argument("--kingdoms", default="prok,euk")
    pe.add_argument("--annotate-cds", action="store_true")
    pe.add_argument("--max-tokens", type=int, default=200_000)
    pe.add_argument("--subsample", type=int, default=50_000)
    pe.add_argument("--seq-len", type=int, default=1024)
    pe.add_argument("--batch-size", type=int, default=8)
    pe.add_argument("--auroc-device", default="cuda:1")
    pe.set_defaults(func=cmd_extract)
    pk = sub.add_parser("euk-f1", parents=[common])
    for arg in ["--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", "--gff"]:
        pk.add_argument(arg, required=True)
    pk.add_argument("--layer", type=int, required=True)
    pk.add_argument("--organism", default="Human")
    pk.add_argument("--max-tokens", type=int, default=160_000)
    pk.add_argument("--seq-len", type=int, default=1024)
    pk.add_argument("--batch-size", type=int, default=8)
    pk.add_argument("--auroc-device", default="cuda:1")
    pk.set_defaults(func=cmd_euk)
    args = ap.parse_args()
    torch.set_grad_enabled(False)
    args.func(args)


if __name__ == "__main__":
    main()
KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"}


class SAEWrap(nn.Module):
    """sae.forward(x[N,H]) -> (recon, codes) in RAW residual space (denormalized)."""

    def __init__(self, sae):
        """Wrap ``sae``; it must expose ``encode``, ``decoder`` and ``pre_bias``."""
        super().__init__()
        self.sae = sae

    def forward(self, x):
        """Return ``(recon, codes)`` for ``x``, undoing per-token normalization if enabled.

        When ``sae.normalize_input`` is truthy the reconstruction is rescaled with the
        mean / std of each row of ``x`` (last dimension) so it lives in the same space as ``x``.
        """
        s = self.sae
        codes = s.encode(x)  # encode normalizes internally if normalize_input
        recon = s.decoder(codes) + s.pre_bias
        if getattr(s, "normalize_input", False):
            # NOTE(review): assumes the SAE normalizes with the same statistics (torch's
            # default std and this epsilon) — confirm against the SAE's encode().
            mu = x.mean(-1, keepdim=True)
            std = x.std(-1, keepdim=True) + 1e-8
            recon = recon * std + mu
        return recon, codes


class L26Hook:
    """Forward hook with three modes: pass through, capture the output, or replace it.

    ``mode`` is ``"off"``, ``"capture"`` or ``"replace"``. In capture mode the (detached)
    hidden state is stored in ``captured``; in replace mode, when ``override`` is set, it is
    returned in place of the layer's hidden state, cast to the hidden state's dtype.
    """

    def __init__(self):
        """Start in pass-through mode with nothing captured or overridden."""
        self.mode = "off"  # off | capture | replace
        self.override = None
        self.captured = None

    def __call__(self, module, inp, output):
        """Hook body; the hidden state is ``output[0]`` for tuple outputs, else ``output``."""
        is_tuple = isinstance(output, tuple)
        hs = output[0] if is_tuple else output
        if self.mode == "replace" and self.override is not None:
            new = self.override.to(hs.dtype)
            # Keep any extra tuple members the layer returned alongside the hidden state.
            return (new, *output[1:]) if is_tuple else new
        if self.mode == "capture":
            self.captured = hs.detach()
        return output


def main():
    """Compute loss recovered for the Evo2 SAE at ``--layer`` and print the result.

    Registers an ``L26Hook`` on the chosen decoder layer, samples sequences from ``--fasta``,
    and hands capture / patched-forward callables to ``evaluate_loss_recovered``. The hook
    is removed again when evaluation finishes or fails.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--evo2-ckpt-dir", required=True)
    ap.add_argument("--sae-checkpoint", required=True)
    ap.add_argument("--layer", type=int, required=True)
    ap.add_argument("--fasta", required=True)
    ap.add_argument("--n-seqs", type=int, default=80)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()
    torch.set_grad_enabled(False)
    dev = args.device

    engine = Evo2SAE(args.evo2_ckpt_dir, args.sae_checkpoint, args.layer, device=dev).load()
    from megatron.core.utils import unwrap_model

    gen = engine._ensure_gen_model()
    layer = unwrap_model(gen).decoder.layers[args.layer]
    hook = L26Hook()
    # Keep the handle: the hook must not stay attached to the shared model afterwards.
    handle = layer.register_forward_hook(hook)
    try:
        pairs = sample_sequences(
            args.fasta, args.n_seqs * args.seq_len, args.seq_len, kingdoms=["prok", "euk"], seed=args.seed
        )[: args.n_seqs]
        batches = []
        for kingdom, dna in pairs:
            ids = engine.tokenize(KINGDOM_TAGS[kingdom] + dna)
            if len(ids) > 4:
                batches.append(torch.tensor([ids], dtype=torch.long, device=dev))
        if not batches:
            raise SystemExit("no usable sequences were sampled from --fasta; nothing to evaluate")

        def fwd(ids):
            return gen(input_ids=ids, position_ids=None, attention_mask=None, labels=None, runtime_gather_output=True)

        def get_hiddens(batch):
            hook.mode = "capture"
            try:
                fwd(batch)
            finally:
                # Never leave the hook armed if the forward pass raises.
                hook.mode = "off"
            return hook.captured  # [S, 1, H]

        def compute_ce(batch, override):
            if override is None:
                hook.mode = "off"
            else:
                hook.mode = "replace"
                hook.override = override
            try:
                logits = fwd(batch)
            finally:
                hook.mode = "off"
                hook.override = None
            lg = logits[0, :-1].float()  # [S-1, V]
            tgt = batch[0, 1:]
            ce = Fn.cross_entropy(lg, tgt, reduction="sum")
            return float(ce), int(tgt.numel())

        with engine._lock:
            res = evaluate_loss_recovered(SAEWrap(engine.sae), batches, get_hiddens, compute_ce, device=dev)
    finally:
        handle.remove()
    print(f"\n==== Evo2 7B layer-{args.layer} SAE — loss recovered ====")
    print(res)
    print(
        f"loss_recovered = {res.loss_recovered:.3f} "
        f"(CE clean={res.ce_original:.3f}, SAE={res.ce_sae:.3f}, zero={res.ce_zero:.3f}, n_tok={res.n_tokens})"
    )


if __name__ == "__main__":
    main()
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/annot_tracks.py | 161 ++++++++++++++++++ .../recipes/evo2/scripts/probe.py | 96 +++++++++++ .../recipes/evo2/tests/test_annot_tracks.py | 76 +++++++++ 3 files changed, 333 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py new file mode 100644 index 0000000000..ca9c2de93f --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Generic interval-track loader for the "user-supplied annotated dataset" eval. + +The user hands in an annotated dataset: a FASTA of sequences + one or more annotation +tracks (BED or GFF) naming intervals — RefSeq genes/exons, Rfam ncRNA, JASPAR TFBS, +ENCODE cCREs, etc. Each interval is one annotation **instance**. 
This module tiles the +sequences into windows and produces, per concept, a per-token boolean mask + per-token +**global** instance IDs (stable across the windows an interval spans) — exactly the +inputs `sae.eval.probing.domain_f1` (recall-per-instance) and `auroc_all` (per-feature) +consume. No model here; the SAE-encode step lives in the probe CLI (`probe.py domain-eval`). + +This is the generic sibling of `euk_windows.py` (which decomposes RefSeq gene models into +exon/intron/cds). Both feed the same shared scorers. +""" + +from __future__ import annotations + +import gzip +from collections import defaultdict + + +def read_fasta_dict(path: str) -> dict[str, str]: + """Read a (multi-record) FASTA into ``{seq_id: sequence}`` (``.gz`` transparent). + + ``seq_id`` is the first whitespace-delimited token of the header, so it matches the + chrom/seqid column of BED/GFF tracks. + """ + opener = gzip.open if str(path).endswith(".gz") else open + seqs: dict[str, str] = {} + name, parts = None, [] + with opener(path, "rt") as fh: + for line in fh: + line = line.rstrip() + if line.startswith(">"): + if name is not None: + seqs[name] = "".join(parts) + header = line[1:].strip().split() + name, parts = (header[0] if header else f"seq_{len(seqs)}"), [] + else: + parts.append(line) + if name is not None: + seqs[name] = "".join(parts) + return seqs + + +def _parse_bed(path): + """Yield (chrom, start0, end0) from a BED file (0-based, half-open — used as-is).""" + opener = gzip.open if str(path).endswith(".gz") else open + with opener(path, "rt") as fh: + for line in fh: + if not line.strip() or line.startswith(("#", "track", "browser")): + continue + f = line.split("\t") + if len(f) < 3: + continue + yield f[0], int(f[1]), int(f[2]) + + +def _parse_gff(path, feature_type=None): + """Yield (seqid, start0, end0) from GFF/GTF (1-based inclusive -> 0-based half-open).""" + opener = gzip.open if str(path).endswith(".gz") else open + with opener(path, "rt") as fh: + for line in fh: + 
if line.startswith("#") or not line.strip(): + continue + f = line.rstrip("\n").split("\t") + if len(f) < 5: + continue + if feature_type and f[2] != feature_type: + continue + yield f[0], int(f[3]) - 1, int(f[4]) + + +def load_track(path: str, feature_type: str | None = None, fmt: str | None = None) -> dict[str, list[tuple[int, int]]]: + """Load one annotation track into ``{seqid: [(start0, end0), ...]}`` (0-based half-open). + + Args: + path: BED or GFF/GTF file (``.gz`` ok). + feature_type: GFF only — keep just this column-3 type (e.g. ``"exon"``, ``"ncRNA"``). + fmt: ``"bed"`` / ``"gff"``; inferred from the extension when omitted. + + Returns: + Intervals grouped by sequence id, each sorted. Every interval is one instance. + """ + fmt = fmt or ("gff" if str(path).endswith((".gff", ".gff3", ".gtf", ".gff.gz", ".gff3.gz")) else "bed") + rows = _parse_gff(path, feature_type) if fmt == "gff" else _parse_bed(path) + by_seq: dict[str, list[tuple[int, int]]] = defaultdict(list) + for chrom, s, e in rows: + if e > s: + by_seq[chrom].append((s, e)) + return {k: sorted(v) for k, v in by_seq.items()} + + +def label_windows(seqs, tracks, seq_len=1024, stride=None, max_tokens=None, min_n_frac=0.5): + """Tile sequences into windows, labeling each position per concept (mask + global instance id). + + Args: + seqs: ``{seqid: dna_str}``. + tracks: ``{concept: {seqid: [(start0, end0), ...]}}`` (e.g. from `load_track`). + seq_len: window length in bp. + stride: step between windows (defaults to non-overlapping = seq_len). + max_tokens: stop once this many positions are emitted (None = all). + min_n_frac: skip windows whose ``N`` fraction exceeds this. + + Returns: + (windows, stats). Each window is ``{"dna": str, "labels": {concept: bool[L]}, + "instances": {concept: int32[L]}}``. Each interval gets one global id, stable across + the windows it spans, so `domain_f1`'s recall-per-instance counts a split interval once. 
+ """ + import numpy as np + + stride = stride or seq_len + concepts = list(tracks.keys()) + # assign a global instance id to every interval, per concept + concept_iv: dict[str, dict[str, list[tuple[int, int, int]]]] = {} + n_inst: dict[str, int] = {} + for concept in concepts: + gid = 0 + cc: dict[str, list[tuple[int, int, int]]] = {} + for seqid, ivs in tracks[concept].items(): + cc[seqid] = [(s, e, (gid := gid + 1) - 1) for (s, e) in ivs] + concept_iv[concept] = cc + n_inst[concept] = gid + + windows, tot = [], 0 + for seqid, dna in seqs.items(): + dna = dna.upper() + N = len(dna) + for w0 in range(0, max(1, N - seq_len + 1), stride): + w1 = min(N, w0 + seq_len) + sub = dna[w0:w1] + L = w1 - w0 + if L < 60 or sub.count("N") > min_n_frac * L: + continue + labels = {c: np.zeros(L, bool) for c in concepts} + inst = {c: np.full(L, -1, np.int32) for c in concepts} + for c in concepts: + for s, e, gid in concept_iv[c].get(seqid, []): + if e <= w0 or s >= w1: + continue + labels[c][max(s, w0) - w0 : min(e, w1) - w0] = True + inst[c][max(s, w0) - w0 : min(e, w1) - w0] = gid + windows.append({"dna": sub, "labels": labels, "instances": inst}) + tot += L + if max_tokens and tot >= max_tokens: + return windows, {"tokens": tot, "n_inst": n_inst, "concepts": concepts} + return windows, {"tokens": tot, "n_inst": n_inst, "concepts": concepts} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index 5e40af5831..c1b66b2941 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -21,6 +21,9 @@ probe.py linear --acts BUF --labels .. 
def _parse_track_spec(spec):
    """Parse a ``--track NAME=PATH[:GFF_FEATURE]`` spec -> (name, path, feature_type|None).

    The text after the last ``:`` is taken as a GFF feature type only when it cannot be part
    of a path (no ``/``, ``\\`` or ``.`` in it); otherwise the whole remainder is the path.

    Raises:
        ValueError: when the spec has no ``=`` or an empty NAME / PATH.
    """
    name, sep, rest = spec.partition("=")
    name = name.strip()
    if not sep or not name or not rest:
        raise ValueError(f"bad --track spec {spec!r}: expected NAME=PATH[:GFF_FEATURE]")
    ftype = None
    head, colon, tail = rest.rpartition(":")
    if colon and head and not any(ch in tail for ch in "/\\."):  # a GFF feature type, not part of a path
        rest, ftype = head, (tail or None)
    return name, rest, ftype


def cmd_domain_eval(a):
    """User-supplied annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC.

    Each ``--track NAME=PATH[:GFF_FEATURE]`` is one concept; its BED/GFF intervals are the
    annotation instances (RefSeq/Rfam/JASPAR/ENCODE, or anything the user supplies). The SAE
    annotates the windows, then per concept we report the best feature by instance-level
    domain-F1 (precision-per-nt, recall-per-annotation) and — threshold-free — by AUROC.

    Raises:
        ValueError: for a malformed or repeated ``--track`` spec (checked before the model
            stack is imported or loaded).
        SystemExit: when no position could be encoded, so there is nothing to score.
    """
    # Validate the cheap CLI input first so a typo fails before any heavy import / model load.
    specs = [_parse_track_spec(spec) for spec in a.track]
    seen = set()
    for name, _, _ in specs:
        if name in seen:
            # A repeated NAME used to overwrite the earlier track without notice.
            raise ValueError(f"duplicate --track name {name!r}: each concept needs a unique NAME")
        seen.add(name)

    from annot_tracks import label_windows, load_track, read_fasta_dict
    from evo2_sae_infer.core import DEFAULT_ORGANISM_TAGS, Evo2SAE
    from sae.eval.probing import auroc_all, domain_f1

    tracks = {name: load_track(path, feature_type=ftype) for name, path, ftype in specs}
    seqs = read_fasta_dict(a.fasta)
    windows, stats = label_windows(seqs, tracks, a.seq_len, max_tokens=a.max_tokens)
    concepts = stats["concepts"]

    eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load()
    F, adev = eng.n_features, a.auroc_device
    if a.organism not in DEFAULT_ORGANISM_TAGS:
        # The lookup below falls back to an empty tag; say so instead of doing it silently.
        print(f"warning: unknown organism {a.organism!r}; encoding without an organism tag", file=sys.stderr)
    tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, ""))
    tlen, tot = len(tag_ids), stats["tokens"]
    code_buf = torch.zeros(tot, F, dtype=torch.float16, device=adev)
    lab = {c: torch.zeros(tot, dtype=torch.bool, device=adev) for c in concepts}
    inst = {c: torch.full((tot,), -1, dtype=torch.long, device=adev) for c in concepts}
    filled = 0
    for s0 in range(0, len(windows), a.batch_size):
        batch = windows[s0 : s0 + a.batch_size]
        with eng._lock:
            hiddens = eng._forward_hidden([tag_ids + eng.tokenize(w["dna"]) for w in batch])
            for h, w in zip(hiddens, batch):
                if h.shape[0] == 0:
                    continue
                codes = eng.sae.encode(h.to(a.device))
                # Skip the tag prefix and never write past the preallocated buffer.
                take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled)
                if take <= 0:
                    continue
                dst = slice(filled, filled + take)
                code_buf[dst] = codes[tlen : tlen + take].to(torch.float16).to(adev)
                for c in concepts:
                    lab[c][dst] = torch.from_numpy(w["labels"][c][:take]).to(adev)
                    inst[c][dst] = torch.from_numpy(w["instances"][c][:take].astype(np.int64)).to(adev)
                filled += take
    if filled == 0:
        # code_buf.max(0) below is undefined on an empty buffer; stop with a readable message.
        raise SystemExit("domain-eval: no positions were encoded (no usable windows); nothing to score")
    code_buf = code_buf[:filled]
    for c in concepts:
        lab[c], inst[c] = lab[c][:filled], inst[c][:filled]
    fmax = code_buf.max(0).values.float()
    au = auroc_all(code_buf.float().to(a.device), torch.stack([lab[c] for c in concepts], 1).to(a.device)).cpu()
    print(f"encoded {filled} positions across {len(concepts)} concept(s)")
    print(
        f"{'concept':14s} {'%pos':>6s} {'#inst':>6s} | "
        f"{'domF1':>6s} {'@thr':>5s} {'feat':>7s} | {'AUROC':>6s} {'feat':>7s}"
    )
    for i, c in enumerate(concepts):
        f1, thr = domain_f1(code_buf, fmax, lab[c], inst[c])
        bi, ai = int(f1.argmax()), int(au[:, i].argmax())
        # NOTE(review): #inst is the interval count of the whole track (stats["n_inst"]), which
        # can exceed the instances actually present in the encoded windows — confirm intended.
        print(
            f"{c:14s} {float(lab[c].float().mean()):6.1%} {stats['n_inst'][c]:6d} | "
            f"{float(f1[bi]):6.3f} {float(thr[bi]):5.2f} {bi:7d} | {float(au[ai, i]):6.3f} {ai:7d}"
        )
--track exon=refseq.gff3:exon --track tfbs=jaspar.bed --track cCRE=encode.bed).", + ) + pd.add_argument("--layer", type=int, required=True) + pd.add_argument("--organism", default="Human") + pd.add_argument("--max-tokens", type=int, default=160_000) + pd.add_argument("--seq-len", type=int, default=1024) + pd.add_argument("--batch-size", type=int, default=8) + pd.add_argument("--auroc-device", default="cuda:1") + pd.set_defaults(func=cmd_domain_eval) args = ap.parse_args() torch.set_grad_enabled(False) args.func(args) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py new file mode 100644 index 0000000000..e927c9620d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""CPU unit tests for the generic interval-track loader (no model / no torch-CUDA).""" + +import sys +from pathlib import Path + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from annot_tracks import label_windows, load_track, read_fasta_dict + + +def test_read_fasta_dict_uses_first_token(tmp_path): + """seq_id is the first header token (so it matches BED/GFF chrom).""" + fa = tmp_path / "g.fa" + fa.write_text(">chr1 Homo sapiens\nACGT\nACGT\n>chr2\nTTTT\n") + assert read_fasta_dict(fa) == {"chr1": "ACGTACGT", "chr2": "TTTT"} + + +def test_load_bed_is_half_open(tmp_path): + """BED is 0-based half-open and used as-is.""" + bed = tmp_path / "t.bed" + bed.write_text("chr1\t2\t5\tsiteA\nchr1\t10\t12\tsiteB\n") + assert load_track(str(bed)) == {"chr1": [(2, 5), (10, 12)]} + + +def test_load_gff_converts_to_half_open_and_filters_type(tmp_path): + """GFF 1-based inclusive -> 0-based half-open; feature_type filters column 3.""" + gff = tmp_path / "t.gff3" + gff.write_text( + "# comment\n" + "chr1\tsrc\texon\t3\t5\t.\t+\t.\tID=e1\n" # 1-based [3,5] -> [2,5) + "chr1\tsrc\tCDS\t3\t5\t.\t+\t.\tID=c1\n" + ) + assert load_track(str(gff), feature_type="exon") == {"chr1": [(2, 5)]} + + +def test_label_windows_mask_and_instance_ids(tmp_path): + """Each interval is one instance; mask + instance id line up with the window positions.""" + seqs = {"chr1": "ACGT" * 25} # 100 bp (above the 60 bp window floor) + tracks = {"site": {"chr1": [(2, 5), (10, 12)]}} # two instances + windows, stats = label_windows(seqs, tracks, seq_len=100) + assert len(windows) == 1 + w = windows[0] + mask, inst = w["labels"]["site"], w["instances"]["site"] + assert list(mask.nonzero()[0]) == [2, 3, 4, 10, 11] + assert set(inst[mask].tolist()) == {0, 1} # two distinct instances + assert (inst[~mask] == -1).all() + assert stats["n_inst"]["site"] == 2 + + +def test_instance_id_stable_across_split_windows(): + """An interval spanning a window boundary keeps ONE global instance 
id (recall counts it once).""" + seqs = {"chr1": "A" * 200} + tracks = {"big": {"chr1": [(90, 110)]}} # straddles the 0-100 / 100-200 boundary + windows, stats = label_windows(seqs, tracks, seq_len=100) + ids = set() + for w in windows: + inst = w["instances"]["big"] + ids.update(int(x) for x in inst[inst >= 0]) + assert ids == {0} # same id in both windows + assert stats["n_inst"]["big"] == 1 From e520186ba943be6cfafba82195ee7edbcdccd69e Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 20:17:02 +0000 Subject: [PATCH 03/14] evo2 probe: trim demo command + superseded proxy labelers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cut dead/superseded surface from the eval harness: - probe.py: drop the 'context' subcommand (a hardcoded-feature-id findings demo for feats 29244/33918 — belongs in a writeup, not the general eval). - labelers.py: drop 7 proxy labelers now covered by real annotations / pyrodigal: codon_pos_1/3 and orf_frame_0/1/2 (frame-from-start proxies -> use cds_frame_*), tss_proxy_tata (-> real TSS tracks via domain-eval), is_euk_genic (-> RefSeq gene track). Registry 26 -> 19 labelers; the real cds_frame_* / gene-structure labels stay. 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/labelers.py | 53 ------------------- .../recipes/evo2/scripts/probe.py | 25 --------- 2 files changed, 78 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index e54e8c05ae..af9d27fe55 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -116,17 +116,6 @@ def _last(ctx): return _dna_mask(ctx, d) -@labeler("codon_pos_1") -def _c1(ctx): - # frame-0 proxy (no CDS annotation): position 0 of each codon from seq start - return _dna_mask(ctx, np.arange(len(ctx.dna)) % 3 == 0) - - -@labeler("codon_pos_3") -def _c3(ctx): - return _dna_mask(ctx, np.arange(len(ctx.dna)) % 3 == 2) - - # --------------------------------------------------------------------- composition def _gc_window(dna: str, radius: int = 10) -> np.ndarray: arr = _bytes(dna) @@ -229,12 +218,6 @@ def _kozak(ctx): return _dna_mask(ctx, out) -@labeler("tss_proxy_tata", complex=True) -def _tss(ctx): - # stricter canonical TATA box as a transcription-start proxy - return _dna_mask(ctx, _spans(ctx.dna, r"TATAAA")) - - @labeler("splice_donor", complex=True) def _sd(ctx): # 5' donor consensus GT(A/G)AGT — mark the GT @@ -250,48 +233,12 @@ def _sa(ctx): return _dna_mask(ctx, out) -def _orf(ctx, frame: int, win: int = 60): - d, n = ctx.dna, len(ctx.dna) - stops = {"TAA", "TAG", "TGA"} - out = np.zeros(n, bool) - for p in range(frame, n - 2, 3): - ok = True - for q in range(p, min(p + win, n - 2), 3): - if d[q : q + 3] in stops: - ok = False - break - if ok: - out[p] = True - return _dna_mask(ctx, out) - - -@labeler("orf_frame_0_60bp", complex=True) -def _orf0(ctx): - return _orf(ctx, 0) - - -@labeler("orf_frame_1_60bp", complex=True) -def _orf1(ctx): - return _orf(ctx, 1) - - 
-@labeler("orf_frame_2_60bp", complex=True) -def _orf2(ctx): - return _orf(ctx, 2) - - # --------------------------------------------------------------- sequence / norm level @labeler("is_prok") def _prok(ctx): return np.full(ctx.T, ctx.kingdom == "prok", dtype=bool) -@labeler("is_euk_genic", complex=True) -def _eukg(ctx): - # proxy: eukaryotic token. True "genic" needs a gene model — refine later. - return np.full(ctx.T, ctx.kingdom == "euk", dtype=bool) - - @labeler("is_sink_token", complex=True) def _sink(ctx): return ctx.hidden_norm > SINK_NORM_THRESHOLD diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index c1b66b2941..4a80a5d832 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -20,7 +20,6 @@ probe.py auroc --acts BUF --labels .. per-feature AUROC table probe.py linear --acts BUF --labels .. SAE-vs-dense single + multi (disentanglement/distributed) probe.py codon-aa --acts CODON_BUF codon/AA decoders + family-disjoint, SAE vs dense - probe.py context --acts BUF biological-context vs string-match firing (feat 29244/33918) probe.py euk-f1 --fasta .. --gff .. RefSeq gene-structure domain-F1 (needs the model) probe.py domain-eval --fasta .. --track .. 
user annotated dataset -> per-feature domain-F1 + AUROC vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) @@ -116,29 +115,6 @@ def cmd_linear(a): # noqa: D103 print(row) -def cmd_context(a): # noqa: D103 - z = np.load(a.acts, allow_pickle=True) - codes, labels = z["codes"], z["labels"] - idx = {n: i for i, n in enumerate(z["label_names"])} - P = codes.shape[0] - - def lab(n): - return labels[:, idx[n]].astype(bool) - - def rate(feat, m): - return (float((codes[:, feat][m] > 0).mean()), int(m.sum())) if m.sum() else (float("nan"), 0) - - ATG, STOP, START, INF = lab("motif_ATG"), lab("motif_stop"), lab("cds_start"), lab("cds_frame_1") - print(f"baseline: 29244={rate(29244, np.ones(P, bool))[0]:.3f} 33918={rate(33918, np.ones(P, bool))[0]:.3f}") - for nm, f, motif, ctx, cl in [ - ("29244 ATG", 29244, ATG, START, "real start"), - ("33918 STOP", 33918, STOP, INF, "in-frame"), - ]: - ra, na = rate(f, motif & ctx) - rb, nb = rate(f, motif & ~ctx) - print(f"{nm}: {cl} {ra:.3f}(n={na}) | other {rb:.3f}(n={nb}) | ratio {ra / max(rb, 1e-9):.2f}") - - def cmd_codon_aa(a): # noqa: D103 z = np.load(a.acts) dev = a.device @@ -334,7 +310,6 @@ def main(): # noqa: D103 for name, fn, needs_labels in [ ("auroc", cmd_auroc, True), ("linear", cmd_linear, True), - ("context", cmd_context, False), ("codon-aa", cmd_codon_aa, False), ]: p = sub.add_parser(name, parents=[common]) From e7d5a05905d31f0839c84e1975dc915f121f03cc Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 22:57:05 +0000 Subject: [PATCH 04/14] evo2 eval: use Biopython for the genetic code in labelers (drop hand-rolled table) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the cryptic hand-rolled NCBI table-1 string (_AA1) + codon comprehension and the str.translate revcomp with Bio.Data.CodonTable + Bio.Seq — verified byte-identical to the old 64-codon table and revcomp. Declares biopython (+ bioframe, used next for interval ops). 
Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 2 ++ .../recipes/evo2/scripts/labelers.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index f7b23d8eff..f4c81cf3ff 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -13,6 +13,8 @@ dependencies = [ "torch>=2.0", "numpy>=1.20", "pyarrow>=23.0.0", + "biopython>=1.80", # genetic code / translation in labelers.py + "bioframe>=0.4", # genomic-interval ops (GFF/BED parse, intron/UTR/overlap) ] # No package code lives here yet — the recipe is just an entry-point for diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index af9d27fe55..cf89801c16 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -46,6 +46,8 @@ def _my(ctx): from typing import Optional import numpy as np +from Bio.Data import CodonTable +from Bio.Seq import Seq # name -> fn(ctx) -> np.ndarray[bool] of length T @@ -278,21 +280,17 @@ def _cds_f3(ctx): _GENE_FINDER = None -# Standard genetic code (NCBI translation table 1), codons in TCAG x TCAG x TCAG order. -_BASES = "TCAG" -_AA1 = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG" -CODON_TABLE = { - a + b + c: _AA1[i] for i, (a, b, c) in enumerate((x, y, z) for x in _BASES for y in _BASES for z in _BASES) -} +# Standard genetic code (NCBI translation table 1) via Biopython; codon -> amino acid ('*' = stop). 
+_STD_CODE = CodonTable.unambiguous_dna_by_id[1] +CODON_TABLE = {**_STD_CODE.forward_table, **dict.fromkeys(_STD_CODE.stop_codons, "*")} CODON_LIST = sorted(CODON_TABLE) # 64 codons CODON_TO_IDX = {c: i for i, c in enumerate(CODON_LIST)} AA_LIST = sorted(set(CODON_TABLE.values())) # 20 aa + '*' (stop) AA_TO_IDX = {a: i for i, a in enumerate(AA_LIST)} -_COMP = str.maketrans("ACGTN", "TGCAN") def _revcomp(s): - return s.translate(_COMP)[::-1] + return str(Seq(s).reverse_complement()) def predict_codons(dna: str): From 1d2e56faf8e02a5f3eec8176e04f7906eaf6bf6e Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 22:57:59 +0000 Subject: [PATCH 05/14] =?UTF-8?q?evo2=20eval:=20drop=20bioframe=20dep=20?= =?UTF-8?q?=E2=80=94=20investigation=20showed=20no=20real=20reduction?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit read_table keeps GFF 1-based (coordinate handling stays ours), doesn't parse GFF attributes (gene-model assembly stays ours), and complement/subtract only trade short loops for comparable DataFrame code. Not worth a dependency against tested interval code; revisit if/when JASPAR/ENCODE/Rfam interval tracks are added. 
Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/evo2/pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index f4c81cf3ff..ba794c5e74 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -14,7 +14,6 @@ dependencies = [ "numpy>=1.20", "pyarrow>=23.0.0", "biopython>=1.80", # genetic code / translation in labelers.py - "bioframe>=0.4", # genomic-interval ops (GFF/BED parse, intron/UTR/overlap) ] # No package code lives here yet — the recipe is just an entry-point for From efbbf11028f4fa79bee6218d72f84af15619dea9 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 23:33:23 +0000 Subject: [PATCH 06/14] evo2 probe: add end-to-end usage examples + factor shared subcommand args - docstring: concrete example flow (extract -> auroc/linear -> domain-eval -> loss-recovered) with real flags, so it's clear how the CLI is actually invoked. - main(): collapse the repeated model/encoding args (--evo2-ckpt-dir/--sae-checkpoint/ --fasta/--layer/--max-tokens/--seq-len/--batch-size/--auroc-device) across extract/ euk-f1/domain-eval into one _add_model_args helper. All flags verified intact via --help. 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/probe.py | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index 4a80a5d832..dd19c0c33d 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -24,6 +24,24 @@ probe.py domain-eval --fasta .. --track .. user annotated dataset -> per-feature domain-F1 + AUROC vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) probe.py loss-recovered [...] fidelity via sae.eval.loss_recovered (needs the model) + +Example end-to-end flow (7B / layer 26; $CKPT = MBridge dir, $SAE = trained SAE .pt): + + # 1. Build the probing buffer once: SAE codes + dense twin + per-token labels (needs the model) + python probe.py extract --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ + --fasta probe_set.fa --out buf.npz + + # 2-3. Score the buffer (no model): per-feature AUROC, then SAE-vs-dense linear probes + python probe.py auroc --acts buf.npz --labels motif_ATG,motif_stop,cds_coding,is_prok + python probe.py linear --acts buf.npz --labels cds_coding,is_prok + + # 4. User annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC, + # vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) + python probe.py domain-eval --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ + --fasta GRCh38_chr20.fa --track exon=refseq.gff3:exon --track cCRE=encode_ccre.bed + + # 5. 
SAE fidelity (loss recovered) — separate script, needs the model + python probe_loss_recovered.py --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 --fasta probe_set.fa """ # noqa: D205 from __future__ import annotations @@ -298,6 +316,17 @@ def cmd_extract(a): # noqa: D103 print(f"saved buffer -> {a.out} ({buf.codes.shape[0]} x {buf.codes.shape[1]}, dense {buf.dense.shape[1]})") +def _add_model_args(p, *, required=(), max_tokens=160_000): + """Shared model + encoding args for the model-backed subcommands (extract/euk-f1/domain-eval).""" + for arg in ("--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", *required): + p.add_argument(arg, required=True) + p.add_argument("--layer", type=int, required=True) + p.add_argument("--max-tokens", type=int, default=max_tokens) + p.add_argument("--seq-len", type=int, default=1024) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--auroc-device", default="cuda:1") + + def main(): # noqa: D103 ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) sub = ap.add_subparsers(dest="cmd", required=True) @@ -318,30 +347,17 @@ def main(): # noqa: D103 p.add_argument("--labels", required=True) p.set_defaults(func=fn) pe = sub.add_parser("extract", parents=[common]) - for arg in ["--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", "--out"]: - pe.add_argument(arg, required=True) - pe.add_argument("--layer", type=int, required=True) + _add_model_args(pe, required=("--out",), max_tokens=200_000) pe.add_argument("--kingdoms", default="prok,euk") pe.add_argument("--annotate-cds", action="store_true") - pe.add_argument("--max-tokens", type=int, default=200_000) pe.add_argument("--subsample", type=int, default=50_000) - pe.add_argument("--seq-len", type=int, default=1024) - pe.add_argument("--batch-size", type=int, default=8) - pe.add_argument("--auroc-device", default="cuda:1") pe.set_defaults(func=cmd_extract) pk = sub.add_parser("euk-f1", parents=[common]) - for arg in 
["--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", "--gff"]: - pk.add_argument(arg, required=True) - pk.add_argument("--layer", type=int, required=True) + _add_model_args(pk, required=("--gff",)) pk.add_argument("--organism", default="Human") - pk.add_argument("--max-tokens", type=int, default=160_000) - pk.add_argument("--seq-len", type=int, default=1024) - pk.add_argument("--batch-size", type=int, default=8) - pk.add_argument("--auroc-device", default="cuda:1") pk.set_defaults(func=cmd_euk) pd = sub.add_parser("domain-eval", parents=[common]) - for arg in ["--evo2-ckpt-dir", "--sae-checkpoint", "--fasta"]: - pd.add_argument(arg, required=True) + _add_model_args(pd) pd.add_argument( "--track", action="append", @@ -350,12 +366,7 @@ def main(): # noqa: D103 help="annotation track; BED or GFF intervals = instances of concept NAME. Repeatable " "(e.g. --track exon=refseq.gff3:exon --track tfbs=jaspar.bed --track cCRE=encode.bed).", ) - pd.add_argument("--layer", type=int, required=True) pd.add_argument("--organism", default="Human") - pd.add_argument("--max-tokens", type=int, default=160_000) - pd.add_argument("--seq-len", type=int, default=1024) - pd.add_argument("--batch-size", type=int, default=8) - pd.add_argument("--auroc-device", default="cuda:1") pd.set_defaults(func=cmd_domain_eval) args = ap.parse_args() torch.set_grad_enabled(False) From 316f8d31e96294c3cdadbc883c7a0809530a03c1 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 23:38:04 +0000 Subject: [PATCH 07/14] evo2 probe: factor the shared window-encode loop (euk-f1 + domain-eval) cmd_euk and cmd_domain_eval each carried a ~22-line copy of the same 'stream windows through the SAE -> fill code_buf/lab/inst -> trim' loop, differing only in window source and the label/instance key sets. Extract it to _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a); both commands now call it. Net ~-22 lines, one encode path. 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/probe.py | 80 ++++++++----------- 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index dd19c0c33d..c00b749e51 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -166,24 +166,17 @@ def cmd_codon_aa(a): # noqa: D103 # ───────────────────────────────────────── model subcommands (need Evo2) -def cmd_euk(a): - """Eukaryotic exon/intron/CDS domain-adjusted F1 vs shuffle null (chr21 FASTA+GFF).""" - from euk_windows import build_windows - from evo2_sae_infer.core import DEFAULT_ORGANISM_TAGS, Evo2SAE - from sae.eval.probing import domain_f1 +def _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a): + """Stream tiled windows through the SAE -> (code_buf[filled,F], lab{k:bool}, inst{k:long}, fmax[F]). 
- eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() - windows, stats, tot, _ = build_windows(a.fasta, a.gff, a.seq_len, a.max_tokens, seed=a.seed) - print( - f"windows={len(windows)} tokens={tot} genes={stats['genes']} exons={stats['exons']} introns={stats['introns']}" - ) - F, adev = eng.n_features, a.auroc_device - tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) - tlen = len(tag_ids) - concepts = {"exon": "exon", "intron": "intron", "cds": "gene"} - code_buf = torch.zeros(tot, F, dtype=torch.float16, device=adev) - lab = {k: torch.zeros(tot, dtype=torch.bool, device=adev) for k in ("exon", "intron", "cds")} - inst = {k: torch.full((tot,), -1, dtype=torch.long, device=adev) for k in ("exon", "intron", "gene")} + Shared by euk-f1 and domain-eval: encodes each window (skipping the phylo-tag prefix) and + fills per-concept label masks (lab_keys) + instance ids (inst_keys). Buffers are trimmed to + the number of positions actually filled. + """ + adev, tlen = a.auroc_device, len(tag_ids) + code_buf = torch.zeros(tot, eng.n_features, dtype=torch.float16, device=adev) + lab = {k: torch.zeros(tot, dtype=torch.bool, device=adev) for k in lab_keys} + inst = {k: torch.full((tot,), -1, dtype=torch.long, device=adev) for k in inst_keys} filled = 0 for s0 in range(0, len(windows), a.batch_size): batch = windows[s0 : s0 + a.batch_size] @@ -207,10 +200,29 @@ def cmd_euk(a): for d in (lab, inst): for k in d: d[k] = d[k][:filled] - fmax = code_buf.max(0).values.float() + fmax = code_buf.max(0).values.float() if filled else torch.zeros(eng.n_features, device=adev) + return code_buf, lab, inst, fmax + + +def cmd_euk(a): + """Eukaryotic exon/intron/CDS domain-adjusted F1 vs shuffle null (chr21 FASTA+GFF).""" + from euk_windows import build_windows + from evo2_sae_infer.core import DEFAULT_ORGANISM_TAGS, Evo2SAE + from sae.eval.probing import domain_f1 + + eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, 
device=a.device).load() + windows, stats, tot, _ = build_windows(a.fasta, a.gff, a.seq_len, a.max_tokens, seed=a.seed) + print( + f"windows={len(windows)} tokens={tot} genes={stats['genes']} exons={stats['exons']} introns={stats['introns']}" + ) + tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) + code_buf, lab, inst, fmax = _encode_windows( + eng, windows, tag_ids, ("exon", "intron", "cds"), ("exon", "intron", "gene"), tot, a + ) + filled, adev = code_buf.shape[0], a.auroc_device g = torch.Generator(device=adev).manual_seed(a.seed) print(f"encoded {filled} positions\n{'concept':8s} {'domF1':>6s} {'null':>6s} {'ratio':>6s} {'%pos':>6s}") - for c, ic in concepts.items(): + for c, ic in {"exon": "exon", "intron": "intron", "cds": "gene"}.items(): f1, _ = domain_f1(code_buf, fmax, lab[c], inst[ic]) order = torch.randperm(filled, generator=g, device=adev) f1n, _ = domain_f1(code_buf, fmax, lab[c][order], inst[ic][order]) @@ -250,36 +262,10 @@ def cmd_domain_eval(a): concepts = stats["concepts"] eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() - F, adev = eng.n_features, a.auroc_device tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) - tlen, tot = len(tag_ids), stats["tokens"] - code_buf = torch.zeros(tot, F, dtype=torch.float16, device=adev) - lab = {c: torch.zeros(tot, dtype=torch.bool, device=adev) for c in concepts} - inst = {c: torch.full((tot,), -1, dtype=torch.long, device=adev) for c in concepts} - filled = 0 - for s0 in range(0, len(windows), a.batch_size): - batch = windows[s0 : s0 + a.batch_size] - with eng._lock: - for h, w in zip(eng._forward_hidden([tag_ids + eng.tokenize(w["dna"]) for w in batch]), batch): - if h.shape[0] == 0: - continue - codes = eng.sae.encode(h.to(a.device)) - take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled) - if take <= 0: - continue - code_buf[filled : filled + take] = codes[tlen : tlen + take].to(torch.float16).to(adev) - for c in concepts: - 
lab[c][filled : filled + take] = torch.from_numpy(w["labels"][c][:take]).to(adev) - inst[c][filled : filled + take] = torch.from_numpy(w["instances"][c][:take].astype(np.int64)).to( - adev - ) - filled += take - code_buf = code_buf[:filled] - for c in concepts: - lab[c], inst[c] = lab[c][:filled], inst[c][:filled] - fmax = code_buf.max(0).values.float() + code_buf, lab, inst, fmax = _encode_windows(eng, windows, tag_ids, concepts, concepts, stats["tokens"], a) au = auroc_all(code_buf.float().to(a.device), torch.stack([lab[c] for c in concepts], 1).to(a.device)).cpu() - print(f"encoded {filled} positions across {len(concepts)} concept(s)") + print(f"encoded {code_buf.shape[0]} positions across {len(concepts)} concept(s)") print( f"{'concept':14s} {'%pos':>6s} {'#inst':>6s} | " f"{'domF1':>6s} {'@thr':>5s} {'feat':>7s} | {'AUROC':>6s} {'feat':>7s}" From 721760ed32c549e6e025d495f00f77d715b0d0de Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 00:13:09 +0000 Subject: [PATCH 08/14] evo2 eval: add single-base labelers (base_A/C/G/T) Evo2 is nucleotide-level (1 token = 1 base), so the most atomic feature is a single-base detector. Adds one labeler per nucleotide (fires at every position equal to it) so probing can surface/annotate features that monosemantically track a single base (e.g. base_G). 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/labelers.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index cf89801c16..83d5b6ed51 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -118,6 +118,18 @@ def _last(ctx): return _dna_mask(ctx, d) +# --------------------------------------------------------------------- single base +# Evo2 is nucleotide-level (1 token = 1 base), so the most atomic feature is a single-base +# detector. One labeler per nucleotide — fires at every position whose base equals it — +# so probing can surface features that monosemantically track a single base (e.g. base_G). +def _register_base_labelers(): + for base in "ACGT": + labeler(f"base_{base}")(lambda ctx, b=base: _dna_mask(ctx, _bytes(ctx.dna) == ord(b))) + + +_register_base_labelers() + + # --------------------------------------------------------------------- composition def _gc_window(dna: str, radius: int = 10) -> np.ndarray: arr = _bytes(dna) From dfcd85100d347f5272b35fa63dec16c251a24930 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 00:34:55 +0000 Subject: [PATCH 09/14] =?UTF-8?q?evo2=20probe:=20add=20'annotate'=20?= =?UTF-8?q?=E2=80=94=20persist=20per-feature=20best=20concept=20to=20the?= =?UTF-8?q?=20annotation=20parquet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit probe.py annotate --acts BUF --out P calls sae.eval.probing.annotate_features (best concept per feature by AUROC, above --min-auroc) and writes feature_metadata-style parquet {feature_id, label, auroc, activation_freq, max_activation} — the file the engine/dashboard load via --feature-annotations. 
Picks up base_A/C/G/T (and every labeler). CPU-tested e2e. Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/probe.py | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index c00b749e51..b65b98d8ff 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -17,7 +17,8 @@ this driver only knows how to build/load Evo2 buffers and pick label sets. probe.py extract --out BUF [...] build an ActivationBuffer (needs the model) - probe.py auroc --acts BUF --labels .. per-feature AUROC table + probe.py auroc --acts BUF --labels .. per-feature AUROC table (prints) + probe.py annotate --acts BUF --out P assign each feature its best concept -> annotation parquet probe.py linear --acts BUF --labels .. SAE-vs-dense single + multi (disentanglement/distributed) probe.py codon-aa --acts CODON_BUF codon/AA decoders + family-disjoint, SAE vs dense probe.py euk-f1 --fasta .. --gff .. RefSeq gene-structure domain-F1 (needs the model) @@ -31,10 +32,14 @@ python probe.py extract --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ --fasta probe_set.fa --out buf.npz - # 2-3. Score the buffer (no model): per-feature AUROC, then SAE-vs-dense linear probes + # 2. Score the buffer (no model): per-feature AUROC, then SAE-vs-dense linear probes python probe.py auroc --acts buf.npz --labels motif_ATG,motif_stop,cds_coding,is_prok python probe.py linear --acts buf.npz --labels cds_coding,is_prok + # 3. Persist annotations (no model): each feature's best concept (incl. 
base_A/C/G/T) -> + # the feature-annotation parquet the engine/dashboard load via --feature-annotations + python probe.py annotate --acts buf.npz --out feature_annotations.parquet --min-auroc 0.85 + # 4. User annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC, # vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) python probe.py domain-eval --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ @@ -165,6 +170,35 @@ def cmd_codon_aa(a): # noqa: D103 print(f"{nm:6s} {ca:12.3f} {aaa:10.3f} | {' '.join(rec)}") +def cmd_annotate(a): + """Buffer -> feature-annotation parquet: each feature's best concept by AUROC + activation stats. + + The persist step (uses sae.eval.probing.annotate_features). Writes a feature_metadata-style + parquet — {feature_id, label, auroc, activation_freq, max_activation} — the engine/dashboard + load via --feature-annotations. Concepts default to all labels in the buffer (incl. base_*). + """ + import pyarrow as pa + import pyarrow.parquet as pq + from sae.eval.probing import annotate_features + + buf = ActivationBuffer.load(a.acts) + dev = a.device + names = [t for t in (a.labels.split(",") if a.labels else list(buf.label_names)) if t in buf.name_idx] + X = torch.from_numpy(buf.codes).to(dev).float() + Y = torch.stack([torch.from_numpy(buf.labels[:, buf.name_idx[n]]).to(dev) for n in names], 1) + ann = annotate_features(X, Y, names, min_auroc=a.min_auroc) + cols = {"feature_id": [], "label": [], "auroc": [], "activation_freq": [], "max_activation": []} + for r in ann: + col = X[:, r["feature_id"]] + cols["feature_id"].append(r["feature_id"]) + cols["label"].append(r["label"]) + cols["auroc"].append(r["auroc"]) + cols["activation_freq"].append(round(float((col > 0).float().mean()), 6)) + cols["max_activation"].append(round(float(col.max()), 4)) + pq.write_table(pa.table(cols), a.out, compression="snappy") + print(f"[annotate] {len(ann)} features labeled (AUROC >= {a.min_auroc}) over {len(names)} 
concepts -> {a.out}") + + # ───────────────────────────────────────── model subcommands (need Evo2) def _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a): """Stream tiled windows through the SAE -> (code_buf[filled,F], lab{k:bool}, inst{k:long}, fmax[F]). @@ -332,6 +366,14 @@ def main(): # noqa: D103 if needs_labels: p.add_argument("--labels", required=True) p.set_defaults(func=fn) + pan = sub.add_parser("annotate", parents=[common]) + pan.add_argument("--acts", required=True) + pan.add_argument("--out", required=True) + pan.add_argument( + "--labels", default=None, help="comma-separated concept subset; default = all labels in the buffer" + ) + pan.add_argument("--min-auroc", type=float, default=0.8) + pan.set_defaults(func=cmd_annotate) pe = sub.add_parser("extract", parents=[common]) _add_model_args(pe, required=("--out",), max_tokens=200_000) pe.add_argument("--kingdoms", default="prok,euk") From d3636f4caf8ef569b32bb933a4d67a054df6601d Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 03:31:55 +0000 Subject: [PATCH 10/14] evo2 eval: reconcile engine import to evo2_sae (was evo2_sae_infer) The eval scripts imported a stale package name (evo2_sae_infer) that doesn't exist; the serve recipe's package is evo2_sae. Fix the 3 importers (probe.py, probe_loss_recovered.py, evo2_buffer.py) so the model-backed eval commands resolve against the real engine package. 
Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py | 2 +- .../sparse_autoencoders/recipes/evo2/scripts/probe.py | 6 +++--- .../recipes/evo2/scripts/probe_loss_recovered.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py index 10a8529b41..8bd14fe525 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py @@ -50,7 +50,7 @@ def read_fasta(path): # noqa: D103 def sample_sequences(fasta, max_tokens, seq_len, kingdoms=("prok", "euk"), seed=0): # noqa: D103 - from evo2_sae_infer.core import clean_dna + from evo2_sae.core import clean_dna kingdoms = list(kingdoms) pools = {k: [] for k in kingdoms} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index b65b98d8ff..e01a949142 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -241,7 +241,7 @@ def _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a): def cmd_euk(a): """Eukaryotic exon/intron/CDS domain-adjusted F1 vs shuffle null (chr21 FASTA+GFF).""" from euk_windows import build_windows - from evo2_sae_infer.core import DEFAULT_ORGANISM_TAGS, Evo2SAE + from evo2_sae.core import DEFAULT_ORGANISM_TAGS, Evo2SAE from sae.eval.probing import domain_f1 eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() @@ -284,7 +284,7 @@ def cmd_domain_eval(a): domain-F1 (precision-per-nt, recall-per-annotation) and — threshold-free — by AUROC. 
""" from annot_tracks import label_windows, load_track, read_fasta_dict - from evo2_sae_infer.core import DEFAULT_ORGANISM_TAGS, Evo2SAE + from evo2_sae.core import DEFAULT_ORGANISM_TAGS, Evo2SAE from sae.eval.probing import auroc_all, domain_f1 tracks = {} @@ -315,7 +315,7 @@ def cmd_domain_eval(a): def cmd_extract(a): # noqa: D103 from evo2_buffer import build_buffer, sample_sequences - from evo2_sae_infer.core import Evo2SAE + from evo2_sae.core import Evo2SAE eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() label_names = list(L.LABELERS.keys()) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py index 490cdb9c26..1994b9c6b9 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py @@ -41,7 +41,7 @@ sys.path.insert(0, str(_HERE.parent)) from evo2_buffer import sample_sequences # noqa: E402 -from evo2_sae_infer.core import Evo2SAE # noqa: E402 +from evo2_sae.core import Evo2SAE # noqa: E402 from sae.eval.loss_recovered import evaluate_loss_recovered # noqa: E402 (Jared's code) From fe6bb146c93e01a8b36fbe288a1dc90c1893df53 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 03:44:29 +0000 Subject: [PATCH 11/14] evo2 eval: split harness out -> label producers only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is now the label-producer half: labelers (motifs/base/ORF/codon), euk_windows, annot_tracks + tests — the 'what to measure' concepts, mostly CPU-testable. The probing harness/CLI (probe.py, evo2_buffer, probe_loss_recovered) moves to a leaf PR stacked here. 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/evo2_buffer.py | 139 ------ .../recipes/evo2/scripts/probe.py | 405 ------------------ .../evo2/scripts/probe_loss_recovered.py | 150 ------- 3 files changed, 694 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py deleted file mode 100644 index 8bd14fe525..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Evo2-specific bit: turn DNA sequences into a probing ActivationBuffer. - -The only model-touching code in the probing pipeline. Streams sequences through -the Evo2SAE engine (Evo2 -> layer-L residual -> SAE.encode), keeps the dense -residual twin, and computes per-token labels (+ instance IDs) from labelers.py. 
-All scoring is done elsewhere by the model-agnostic sae.eval.probing metrics. -""" - -from __future__ import annotations - -import random - -import labelers as L -import numpy as np -import torch -from sae.eval.probing import ActivationBuffer - - -KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"} - - -def read_fasta(path): # noqa: D103 - header, chunks = None, [] - with open(path) as fh: - for line in fh: - line = line.rstrip("\n") - if line.startswith(">"): - if header is not None: - yield header, "".join(chunks) - header, chunks = line[1:], [] - else: - chunks.append(line) - if header is not None: - yield header, "".join(chunks) - - -def sample_sequences(fasta, max_tokens, seq_len, kingdoms=("prok", "euk"), seed=0): # noqa: D103 - from evo2_sae.core import clean_dna - - kingdoms = list(kingdoms) - pools = {k: [] for k in kingdoms} - need = max_tokens // seq_len + 50 - for header, seq in read_fasta(fasta): - kg = "prok" if header.lower().startswith("prok") else "euk" - if kg not in pools: - continue - dna = clean_dna(seq)[:seq_len] - if len(dna) < 60: - continue - pools[kg].append(dna) - if all(len(pools[k]) >= need for k in kingdoms): - break - rng = random.Random(seed) - for k in kingdoms: - rng.shuffle(pools[k]) - out, tok, i = [], 0, 0 - maxlen = max((len(pools[k]) for k in kingdoms), default=0) - while tok < max_tokens and i < maxlen: - for k in kingdoms: - if i < len(pools[k]): - out.append((k, pools[k][i])) - tok += len(pools[k][i]) + len(KINGDOM_TAGS[k]) - i += 1 - rng.shuffle(out) - return out - - -@torch.no_grad() -def build_buffer(engine, seqs, label_names, *, subsample, auroc_device, annotate_cds=False, batch_size=8, log=print): - """Stream seqs through engine -> ActivationBuffer (codes + dense + labels [+ cds instances]).""" - F = engine.n_features - Hd = engine.sae.pre_bias.shape[0] - dev = engine.device - S = subsample - code_buf = torch.zeros(S, F, dtype=torch.float16, device=auroc_device) - dense_buf = torch.zeros(S, Hd, 
dtype=torch.float16, device=auroc_device) - lab_buf = torch.zeros(S, len(label_names), dtype=torch.bool, device=auroc_device) - filled = 0 - for start in range(0, len(seqs), batch_size): - if filled >= S: - break - batch = seqs[start : start + batch_size] - id_lists, metas = [], [] - for kg, dna in batch: - tag = KINGDOM_TAGS[kg] - tids = engine.tokenize(tag) - id_lists.append(tids + engine.tokenize(dna)) - metas.append((tag, len(tids), kg, dna)) - with engine._lock: - hiddens = engine._forward_hidden(id_lists) - for h, (tag, tlen, kg, dna) in zip(hiddens, metas): - if h.shape[0] == 0 or filled >= S: - continue - hd = h.to(dev) - codes = engine.sae.encode(hd) - norm = h.float().norm(dim=-1).cpu().numpy() - T = codes.shape[0] - cds_mask = cds_frame = gene_starts = None - if annotate_cds and kg == "prok": - cds_mask, cds_frame, gene_starts = L.predict_cds(dna) - ctx = L.SeqContext( - text=(tag + dna)[:T], - tag_len=tlen, - dna=dna, - kingdom=kg, - hidden_norm=norm[:T], - cds_mask=cds_mask, - cds_frame=cds_frame, - gene_starts=gene_starts, - ) - lab = np.stack([L.LABELERS[n](ctx)[:T] for n in label_names], axis=1) - take = min(T, S - filled) - code_buf[filled : filled + take] = codes[:take].to(torch.float16).to(auroc_device) - dense_buf[filled : filled + take] = hd[:take].to(torch.float16).to(auroc_device) - lab_buf[filled : filled + take] = torch.from_numpy(lab[:take]).to(auroc_device) - filled += take - if (start // batch_size) % 10 == 0: - log(f" {start + len(batch)}/{len(seqs)} seqs | buf {filled}/{S}") - return ActivationBuffer( - codes=code_buf[:filled].cpu().numpy(), - dense=dense_buf[:filled].cpu().numpy(), - labels=lab_buf[:filled].cpu().numpy(), - label_names=list(label_names), - ) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py deleted file mode 100644 index e01a949142..0000000000 --- 
a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ /dev/null @@ -1,405 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Unified Evo2 SAE probing CLI. All scoring is sae.eval.probing (model-agnostic); -this driver only knows how to build/load Evo2 buffers and pick label sets. - - probe.py extract --out BUF [...] build an ActivationBuffer (needs the model) - probe.py auroc --acts BUF --labels .. per-feature AUROC table (prints) - probe.py annotate --acts BUF --out P assign each feature its best concept -> annotation parquet - probe.py linear --acts BUF --labels .. SAE-vs-dense single + multi (disentanglement/distributed) - probe.py codon-aa --acts CODON_BUF codon/AA decoders + family-disjoint, SAE vs dense - probe.py euk-f1 --fasta .. --gff .. RefSeq gene-structure domain-F1 (needs the model) - probe.py domain-eval --fasta .. --track .. user annotated dataset -> per-feature domain-F1 + AUROC vs - any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) - probe.py loss-recovered [...] fidelity via sae.eval.loss_recovered (needs the model) - -Example end-to-end flow (7B / layer 26; $CKPT = MBridge dir, $SAE = trained SAE .pt): - - # 1. 
Build the probing buffer once: SAE codes + dense twin + per-token labels (needs the model) - python probe.py extract --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ - --fasta probe_set.fa --out buf.npz - - # 2. Score the buffer (no model): per-feature AUROC, then SAE-vs-dense linear probes - python probe.py auroc --acts buf.npz --labels motif_ATG,motif_stop,cds_coding,is_prok - python probe.py linear --acts buf.npz --labels cds_coding,is_prok - - # 3. Persist annotations (no model): each feature's best concept (incl. base_A/C/G/T) -> - # the feature-annotation parquet the engine/dashboard load via --feature-annotations - python probe.py annotate --acts buf.npz --out feature_annotations.parquet --min-auroc 0.85 - - # 4. User annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC, - # vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) - python probe.py domain-eval --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ - --fasta GRCh38_chr20.fa --track exon=refseq.gff3:exon --track cCRE=encode_ccre.bed - - # 5. SAE fidelity (loss recovered) — separate script, needs the model - python probe_loss_recovered.py --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 --fasta probe_set.fa -""" # noqa: D205 - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - -import numpy as np -import torch - - -_HERE = Path(__file__).resolve().parent -sys.path.insert(0, str(_HERE)) -sys.path.insert(0, str(_HERE.parent)) -sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) # sparse_autoencoders/sae/src - -import labelers as L # noqa: E402 -from sae.eval.probing import ( # noqa: E402 - ActivationBuffer, - auroc_all, - auroc_vec, - best_single_train_test, - decode_eval, - fit_softmax, - split_indices, - standardize, -) - - -def _z(X, tr): - # Standardize X by the train-split mean/std (reuses sae.eval.probing.standardize). 
- mu, sd = standardize(X, tr) - return (X - mu) / sd - - -# ───────────────────────────────────────── buffer-only subcommands (no model) -def cmd_auroc(a): # noqa: D103 - buf = ActivationBuffer.load(a.acts) - dev = a.device - X = torch.from_numpy(buf.codes).to(dev).float() - names = [t for t in a.labels.split(",") if t in buf.name_idx] - Y = torch.stack([torch.from_numpy(buf.labels[:, buf.name_idx[n]]).to(dev) for n in names], 1) - au = auroc_all(X, Y).cpu().numpy() - print(f"{'label':18s} {'%pos':>6s} {'best AUROC':>10s} {'feature':>8s}") - for i, n in enumerate(names): - print( - f"{n:18s} {buf.labels[:, buf.name_idx[n]].mean():6.1%} {au[:, i].max():10.3f} {int(au[:, i].argmax()):8d}" - ) - - -def _eval_matrix(mat, buf, names, tr, te, dev, steps, wd): - X = torch.from_numpy(mat).to(dev).float() - Xz = _z(X, tr) - out = {} - from sae.eval.probing import fit_logreg - - for n in names: - ytr = torch.from_numpy(buf.labels[tr.numpy(), buf.name_idx[n]]).to(dev).float() - yte = torch.from_numpy(buf.labels[te.numpy(), buf.name_idx[n]]).to(dev) - if ytr.sum() in (0, len(ytr)) or yte.sum() == 0: - out[n] = (float("nan"), float("nan")) - continue - w, b = fit_logreg(Xz[tr], ytr, steps=steps, wd=wd) - out[n] = (best_single_train_test(Xz[tr], ytr, Xz[te], yte), auroc_vec((Xz[te] @ w + b).float(), yte)) - del X, Xz - torch.cuda.empty_cache() - return out - - -def cmd_linear(a): # noqa: D103 - buf = ActivationBuffer.load(a.acts) - dev = a.device - names = [t for t in a.labels.split(",") if t in buf.name_idx] - tr, te = split_indices(buf.codes.shape[0], a.test_frac, a.seed) - sae = _eval_matrix(buf.codes, buf, names, tr, te, dev, a.steps, a.weight_decay) - den = _eval_matrix(buf.dense, buf, names, tr, te, dev, a.steps, a.weight_decay) if buf.dense is not None else None - h = f"{'label':18s} {'%pos':>6s} | {'SAE single':>10s} {'SAE multi':>9s}" - if den: - h += f" | {'dense single':>12s} {'dense multi':>11s} | {'Δ':>7s}" - print(h) - for n in names: - pos = buf.labels[:, 
buf.name_idx[n]].mean() - ss, sm = sae[n] - row = f"{n:18s} {pos:6.1%} | {ss:10.3f} {sm:9.3f}" - if den: - ds, dm = den[n] - row += f" | {ds:12.3f} {dm:11.3f} | {ss - ds:+7.3f}" - print(row) - - -def cmd_codon_aa(a): # noqa: D103 - z = np.load(a.acts) - dev = a.device - codon = torch.from_numpy(z["codon"].astype(np.int64)).to(dev) - aa = torch.from_numpy(z["aa"].astype(np.int64)).to(dev) - codon_np = z["codon"].astype(np.int64) - ncod, naa = len(L.CODON_LIST), len(L.AA_LIST) - held = {"L": ["TTA", "TTG"], "S": ["AGT", "AGC"], "R": ["AGA", "AGG"]} - hidx = [L.CODON_TO_IDX[c] for v in held.values() for c in v] - print(f"{'matrix':6s} {'codon mAUROC':>12s} {'AA mAUROC':>10s} | family-disjoint recall L/S/R (chance)") - for nm in ("sae", "dense"): - if nm not in z.files: - continue - X = torch.from_numpy(z[nm]).to(dev).float() - Xz = (X - X.mean(0)) / (X.std(0) + 1e-6) - tr, te = split_indices(X.shape[0], a.test_frac, a.seed) - _, ca, _ = decode_eval(Xz[tr], codon[tr], Xz[te], codon[te], ncod, steps=a.steps, wd=a.weight_decay) - _, aaa, _ = decode_eval(Xz[tr], aa[tr], Xz[te], aa[te], naa, steps=a.steps, wd=a.weight_decay) - trn = torch.from_numpy(np.nonzero(~np.isin(codon_np, hidx))[0]).to(dev) - W, b = fit_softmax(Xz[trn], aa[trn], naa, steps=a.steps, wd=a.weight_decay) - rec = [] - for A, cods in held.items(): - m = np.isin(codon_np, [L.CODON_TO_IDX[c] for c in cods]) - pred = (Xz[torch.from_numpy(np.nonzero(m)[0]).to(dev)] @ W + b).argmax(1).cpu().numpy() - rec.append( - f"{A}={float((pred == L.AA_TO_IDX[A]).mean()):.2f}({float((aa == L.AA_TO_IDX[A]).float().mean()):.2f})" - ) - del X, Xz - torch.cuda.empty_cache() - print(f"{nm:6s} {ca:12.3f} {aaa:10.3f} | {' '.join(rec)}") - - -def cmd_annotate(a): - """Buffer -> feature-annotation parquet: each feature's best concept by AUROC + activation stats. - - The persist step (uses sae.eval.probing.annotate_features). 
Writes a feature_metadata-style - parquet — {feature_id, label, auroc, activation_freq, max_activation} — the engine/dashboard - load via --feature-annotations. Concepts default to all labels in the buffer (incl. base_*). - """ - import pyarrow as pa - import pyarrow.parquet as pq - from sae.eval.probing import annotate_features - - buf = ActivationBuffer.load(a.acts) - dev = a.device - names = [t for t in (a.labels.split(",") if a.labels else list(buf.label_names)) if t in buf.name_idx] - X = torch.from_numpy(buf.codes).to(dev).float() - Y = torch.stack([torch.from_numpy(buf.labels[:, buf.name_idx[n]]).to(dev) for n in names], 1) - ann = annotate_features(X, Y, names, min_auroc=a.min_auroc) - cols = {"feature_id": [], "label": [], "auroc": [], "activation_freq": [], "max_activation": []} - for r in ann: - col = X[:, r["feature_id"]] - cols["feature_id"].append(r["feature_id"]) - cols["label"].append(r["label"]) - cols["auroc"].append(r["auroc"]) - cols["activation_freq"].append(round(float((col > 0).float().mean()), 6)) - cols["max_activation"].append(round(float(col.max()), 4)) - pq.write_table(pa.table(cols), a.out, compression="snappy") - print(f"[annotate] {len(ann)} features labeled (AUROC >= {a.min_auroc}) over {len(names)} concepts -> {a.out}") - - -# ───────────────────────────────────────── model subcommands (need Evo2) -def _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a): - """Stream tiled windows through the SAE -> (code_buf[filled,F], lab{k:bool}, inst{k:long}, fmax[F]). - - Shared by euk-f1 and domain-eval: encodes each window (skipping the phylo-tag prefix) and - fills per-concept label masks (lab_keys) + instance ids (inst_keys). Buffers are trimmed to - the number of positions actually filled. 
- """ - adev, tlen = a.auroc_device, len(tag_ids) - code_buf = torch.zeros(tot, eng.n_features, dtype=torch.float16, device=adev) - lab = {k: torch.zeros(tot, dtype=torch.bool, device=adev) for k in lab_keys} - inst = {k: torch.full((tot,), -1, dtype=torch.long, device=adev) for k in inst_keys} - filled = 0 - for s0 in range(0, len(windows), a.batch_size): - batch = windows[s0 : s0 + a.batch_size] - with eng._lock: - for h, w in zip(eng._forward_hidden([tag_ids + eng.tokenize(w["dna"]) for w in batch]), batch): - if h.shape[0] == 0: - continue - codes = eng.sae.encode(h.to(a.device)) - take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled) - if take <= 0: - continue - code_buf[filled : filled + take] = codes[tlen : tlen + take].to(torch.float16).to(adev) - for k in lab: - lab[k][filled : filled + take] = torch.from_numpy(w["labels"][k][:take]).to(adev) - for k in inst: - inst[k][filled : filled + take] = torch.from_numpy(w["instances"][k][:take].astype(np.int64)).to( - adev - ) - filled += take - code_buf = code_buf[:filled] - for d in (lab, inst): - for k in d: - d[k] = d[k][:filled] - fmax = code_buf.max(0).values.float() if filled else torch.zeros(eng.n_features, device=adev) - return code_buf, lab, inst, fmax - - -def cmd_euk(a): - """Eukaryotic exon/intron/CDS domain-adjusted F1 vs shuffle null (chr21 FASTA+GFF).""" - from euk_windows import build_windows - from evo2_sae.core import DEFAULT_ORGANISM_TAGS, Evo2SAE - from sae.eval.probing import domain_f1 - - eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() - windows, stats, tot, _ = build_windows(a.fasta, a.gff, a.seq_len, a.max_tokens, seed=a.seed) - print( - f"windows={len(windows)} tokens={tot} genes={stats['genes']} exons={stats['exons']} introns={stats['introns']}" - ) - tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) - code_buf, lab, inst, fmax = _encode_windows( - eng, windows, tag_ids, ("exon", "intron", "cds"), ("exon", "intron", "gene"), tot, 
a - ) - filled, adev = code_buf.shape[0], a.auroc_device - g = torch.Generator(device=adev).manual_seed(a.seed) - print(f"encoded {filled} positions\n{'concept':8s} {'domF1':>6s} {'null':>6s} {'ratio':>6s} {'%pos':>6s}") - for c, ic in {"exon": "exon", "intron": "intron", "cds": "gene"}.items(): - f1, _ = domain_f1(code_buf, fmax, lab[c], inst[ic]) - order = torch.randperm(filled, generator=g, device=adev) - f1n, _ = domain_f1(code_buf, fmax, lab[c][order], inst[ic][order]) - bf, nl = float(f1.max()), float(f1n.max()) - print(f"{c:8s} {bf:6.3f} {nl:6.3f} {bf / max(nl, 1e-9):6.2f} {float(lab[c].float().mean()):6.1%}") - - -def _parse_track_spec(spec): - """Parse a ``--track NAME=PATH[:GFF_FEATURE]`` spec -> (name, path, feature_type|None).""" - name, rest = spec.split("=", 1) - ftype = None - if ":" in rest: - head, tail = rest.rsplit(":", 1) - if "/" not in tail and "." not in tail: # a GFF feature type, not part of a path - rest, ftype = head, tail - return name, rest, ftype - - -def cmd_domain_eval(a): - """User-supplied annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC. - - Each ``--track NAME=PATH[:GFF_FEATURE]`` is one concept; its BED/GFF intervals are the - annotation instances (RefSeq/Rfam/JASPAR/ENCODE, or anything the user supplies). The SAE - annotates the windows, then per concept we report the best feature by instance-level - domain-F1 (precision-per-nt, recall-per-annotation) and — threshold-free — by AUROC. 
- """ - from annot_tracks import label_windows, load_track, read_fasta_dict - from evo2_sae.core import DEFAULT_ORGANISM_TAGS, Evo2SAE - from sae.eval.probing import auroc_all, domain_f1 - - tracks = {} - for spec in a.track: - name, path, ftype = _parse_track_spec(spec) - tracks[name] = load_track(path, feature_type=ftype) - seqs = read_fasta_dict(a.fasta) - windows, stats = label_windows(seqs, tracks, a.seq_len, max_tokens=a.max_tokens) - concepts = stats["concepts"] - - eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() - tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) - code_buf, lab, inst, fmax = _encode_windows(eng, windows, tag_ids, concepts, concepts, stats["tokens"], a) - au = auroc_all(code_buf.float().to(a.device), torch.stack([lab[c] for c in concepts], 1).to(a.device)).cpu() - print(f"encoded {code_buf.shape[0]} positions across {len(concepts)} concept(s)") - print( - f"{'concept':14s} {'%pos':>6s} {'#inst':>6s} | " - f"{'domF1':>6s} {'@thr':>5s} {'feat':>7s} | {'AUROC':>6s} {'feat':>7s}" - ) - for i, c in enumerate(concepts): - f1, thr = domain_f1(code_buf, fmax, lab[c], inst[c]) - bi, ai = int(f1.argmax()), int(au[:, i].argmax()) - print( - f"{c:14s} {float(lab[c].float().mean()):6.1%} {stats['n_inst'][c]:6d} | " - f"{float(f1[bi]):6.3f} {float(thr[bi]):5.2f} {bi:7d} | {float(au[ai, i]):6.3f} {ai:7d}" - ) - - -def cmd_extract(a): # noqa: D103 - from evo2_buffer import build_buffer, sample_sequences - from evo2_sae.core import Evo2SAE - - eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() - label_names = list(L.LABELERS.keys()) - kingdoms = [k for k in a.kingdoms.split(",") if k] - seqs = sample_sequences(a.fasta, a.max_tokens, a.seq_len, kingdoms=kingdoms, seed=a.seed) - print(f"probe set: {len(seqs)} seqs (kingdoms={kingdoms})") - buf = build_buffer( - eng, - seqs, - label_names, - subsample=a.subsample, - auroc_device=a.auroc_device, - annotate_cds=a.annotate_cds, - 
batch_size=a.batch_size, - log=print, - ) - buf.save(a.out) - print(f"saved buffer -> {a.out} ({buf.codes.shape[0]} x {buf.codes.shape[1]}, dense {buf.dense.shape[1]})") - - -def _add_model_args(p, *, required=(), max_tokens=160_000): - """Shared model + encoding args for the model-backed subcommands (extract/euk-f1/domain-eval).""" - for arg in ("--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", *required): - p.add_argument(arg, required=True) - p.add_argument("--layer", type=int, required=True) - p.add_argument("--max-tokens", type=int, default=max_tokens) - p.add_argument("--seq-len", type=int, default=1024) - p.add_argument("--batch-size", type=int, default=8) - p.add_argument("--auroc-device", default="cuda:1") - - -def main(): # noqa: D103 - ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - sub = ap.add_subparsers(dest="cmd", required=True) - common = argparse.ArgumentParser(add_help=False) - common.add_argument("--device", default="cuda:0") - common.add_argument("--seed", type=int, default=0) - common.add_argument("--steps", type=int, default=400) - common.add_argument("--weight-decay", type=float, default=1e-2) - common.add_argument("--test-frac", type=float, default=0.4) - for name, fn, needs_labels in [ - ("auroc", cmd_auroc, True), - ("linear", cmd_linear, True), - ("codon-aa", cmd_codon_aa, False), - ]: - p = sub.add_parser(name, parents=[common]) - p.add_argument("--acts", required=True) - if needs_labels: - p.add_argument("--labels", required=True) - p.set_defaults(func=fn) - pan = sub.add_parser("annotate", parents=[common]) - pan.add_argument("--acts", required=True) - pan.add_argument("--out", required=True) - pan.add_argument( - "--labels", default=None, help="comma-separated concept subset; default = all labels in the buffer" - ) - pan.add_argument("--min-auroc", type=float, default=0.8) - pan.set_defaults(func=cmd_annotate) - pe = sub.add_parser("extract", parents=[common]) - 
_add_model_args(pe, required=("--out",), max_tokens=200_000) - pe.add_argument("--kingdoms", default="prok,euk") - pe.add_argument("--annotate-cds", action="store_true") - pe.add_argument("--subsample", type=int, default=50_000) - pe.set_defaults(func=cmd_extract) - pk = sub.add_parser("euk-f1", parents=[common]) - _add_model_args(pk, required=("--gff",)) - pk.add_argument("--organism", default="Human") - pk.set_defaults(func=cmd_euk) - pd = sub.add_parser("domain-eval", parents=[common]) - _add_model_args(pd) - pd.add_argument( - "--track", - action="append", - required=True, - metavar="NAME=PATH[:GFF_FEATURE]", - help="annotation track; BED or GFF intervals = instances of concept NAME. Repeatable " - "(e.g. --track exon=refseq.gff3:exon --track tfbs=jaspar.bed --track cCRE=encode.bed).", - ) - pd.add_argument("--organism", default="Human") - pd.set_defaults(func=cmd_domain_eval) - args = ap.parse_args() - torch.set_grad_enabled(False) - args.func(args) - - -if __name__ == "__main__": - main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py deleted file mode 100644 index 1994b9c6b9..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Loss recovered (fidelity) for the Evo2 SAE — reuses sae.eval.loss_recovered (Jared Wilber). - - loss_recovered = 1 - (CE_sae - CE_clean) / (CE_zero - CE_clean) - -We just provide Evo2-specific callables to his generic evaluator: - - get_hiddens(batch): capture the layer-`L` residual via a forward hook - - compute_ce(batch, override): full-model next-token CE, optionally patching the - layer-`L` output with `override` (zero-ablation or SAE reconstruction) -The SAE reconstruction is DENORMALIZED per token (normalize_input) so it is in the -raw residual space the layer actually emits. 
-""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - -import torch -import torch.nn as nn -import torch.nn.functional as Fn - - -_HERE = Path(__file__).resolve().parent -sys.path.insert(0, str(_HERE)) -sys.path.insert(0, str(_HERE.parent)) - -from evo2_buffer import sample_sequences # noqa: E402 -from evo2_sae.core import Evo2SAE # noqa: E402 -from sae.eval.loss_recovered import evaluate_loss_recovered # noqa: E402 (Jared's code) - - -KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"} - - -class SAEWrap(nn.Module): - """sae.forward(x[N,H]) -> (recon, codes) in RAW residual space (denormalized).""" - - def __init__(self, sae): # noqa: D107 - super().__init__() - self.sae = sae - - def forward(self, x): # noqa: D102 - s = self.sae - codes = s.encode(x) # encode normalizes internally if normalize_input - recon = s.decoder(codes) + s.pre_bias - if getattr(s, "normalize_input", False): - mu = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) + 1e-8 - recon = recon * std + mu - return recon, codes - - -class L26Hook: # noqa: D101 - def __init__(self): # noqa: D107 - self.mode = "off" # off | capture | replace - self.override = None - self.captured = None - - def __call__(self, module, inp, output): # noqa: D102 - hs = output[0] if isinstance(output, tuple) else output - if self.mode == "replace" and self.override is not None: - new = self.override.to(hs.dtype) - return (new, *output[1:]) if isinstance(output, tuple) else new - if self.mode == "capture": - self.captured = hs.detach() - return output - - -def main(): # noqa: D103 - ap = argparse.ArgumentParser() - ap.add_argument("--evo2-ckpt-dir", required=True) - ap.add_argument("--sae-checkpoint", required=True) - ap.add_argument("--layer", type=int, required=True) - ap.add_argument("--fasta", required=True) - ap.add_argument("--n-seqs", type=int, default=80) - ap.add_argument("--seq-len", type=int, default=1024) - ap.add_argument("--device", 
default="cuda:0") - ap.add_argument("--seed", type=int, default=0) - args = ap.parse_args() - torch.set_grad_enabled(False) - dev = args.device - - engine = Evo2SAE(args.evo2_ckpt_dir, args.sae_checkpoint, args.layer, device=dev).load() - from megatron.core.utils import unwrap_model - - gen = engine._ensure_gen_model() - layer = unwrap_model(gen).decoder.layers[args.layer] - hook = L26Hook() - layer.register_forward_hook(hook) - - pairs = sample_sequences( - args.fasta, args.n_seqs * args.seq_len, args.seq_len, kingdoms=["prok", "euk"], seed=args.seed - )[: args.n_seqs] - batches = [] - for kingdom, dna in pairs: - ids = engine.tokenize(KINGDOM_TAGS[kingdom] + dna) - if len(ids) > 4: - batches.append(torch.tensor([ids], dtype=torch.long, device=dev)) - - def fwd(ids): - return gen(input_ids=ids, position_ids=None, attention_mask=None, labels=None, runtime_gather_output=True) - - def get_hiddens(batch): - hook.mode = "capture" - fwd(batch) - hook.mode = "off" - return hook.captured # [S, 1, H] - - def compute_ce(batch, override): - if override is None: - hook.mode = "off" - else: - hook.mode = "replace" - hook.override = override - logits = fwd(batch) - hook.mode = "off" - hook.override = None - lg = logits[0, :-1].float() # [S-1, V] - tgt = batch[0, 1:] - ce = Fn.cross_entropy(lg, tgt, reduction="sum") - return float(ce), int(tgt.numel()) - - with engine._lock: - res = evaluate_loss_recovered(SAEWrap(engine.sae), batches, get_hiddens, compute_ce, device=dev) - print("\n==== Evo2 7B layer-%d SAE — loss recovered ====" % args.layer) - print(res) - print( - f"loss_recovered = {res.loss_recovered:.3f} " - f"(CE clean={res.ce_original:.3f}, SAE={res.ce_sae:.3f}, zero={res.ce_zero:.3f}, n_tok={res.n_tokens})" - ) - - -if __name__ == "__main__": - main() From e57d17473fd7a051e9e2cba4df710678424e8710 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 04:13:59 +0000 Subject: [PATCH 12/14] evo2 eval: condense annot_tracks parsers MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Shared _open() (gzip-or-open), and merge _parse_bed + _parse_gff into one _intervals() that picks column indices + the 0/1-based offset by format. read_fasta_dict tightened. Behavior unchanged — test_annot_tracks still green. 161 -> 146 lines. Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/annot_tracks.py | 69 ++++++++----------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py index ca9c2de93f..88994c2113 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py @@ -33,73 +33,58 @@ from collections import defaultdict +def _open(path): + """Open a path for text reading, transparently handling ``.gz``.""" + return (gzip.open if str(path).endswith(".gz") else open)(path, "rt") + + def read_fasta_dict(path: str) -> dict[str, str]: """Read a (multi-record) FASTA into ``{seq_id: sequence}`` (``.gz`` transparent). - ``seq_id`` is the first whitespace-delimited token of the header, so it matches the - chrom/seqid column of BED/GFF tracks. + ``seq_id`` is the first whitespace token of the header — matches the chrom/seqid of BED/GFF. 
""" - opener = gzip.open if str(path).endswith(".gz") else open seqs: dict[str, str] = {} name, parts = None, [] - with opener(path, "rt") as fh: + with _open(path) as fh: for line in fh: line = line.rstrip() if line.startswith(">"): if name is not None: seqs[name] = "".join(parts) - header = line[1:].strip().split() - name, parts = (header[0] if header else f"seq_{len(seqs)}"), [] - else: + tok = line[1:].split() + name, parts = (tok[0] if tok else f"seq_{len(seqs)}"), [] + elif line: parts.append(line) if name is not None: seqs[name] = "".join(parts) return seqs -def _parse_bed(path): - """Yield (chrom, start0, end0) from a BED file (0-based, half-open — used as-is).""" - opener = gzip.open if str(path).endswith(".gz") else open - with opener(path, "rt") as fh: - for line in fh: - if not line.strip() or line.startswith(("#", "track", "browser")): - continue - f = line.split("\t") - if len(f) < 3: - continue - yield f[0], int(f[1]), int(f[2]) - +def _intervals(path, fmt, feature_type=None): + """Yield (seqid, start0, end0) from BED (0-based) or GFF/GTF (1-based -> 0-based half-open). -def _parse_gff(path, feature_type=None): - """Yield (seqid, start0, end0) from GFF/GTF (1-based inclusive -> 0-based half-open).""" - opener = gzip.open if str(path).endswith(".gz") else open - with opener(path, "rt") as fh: + GFF rows are optionally filtered to a single column-3 ``feature_type`` (e.g. ``exon``). 
+ """ + chrom_i, start_i, end_i, off = (0, 1, 2, 0) if fmt == "bed" else (0, 3, 4, 1) + with _open(path) as fh: for line in fh: - if line.startswith("#") or not line.strip(): - continue - f = line.rstrip("\n").split("\t") - if len(f) < 5: + if not line.strip() or line[0] == "#" or line.startswith(("track", "browser")): continue - if feature_type and f[2] != feature_type: + f = line.split("\t") + if len(f) <= end_i or (feature_type and fmt != "bed" and f[2] != feature_type): continue - yield f[0], int(f[3]) - 1, int(f[4]) + yield f[chrom_i], int(f[start_i]) - off, int(f[end_i]) -def load_track(path: str, feature_type: str | None = None, fmt: str | None = None) -> dict[str, list[tuple[int, int]]]: - """Load one annotation track into ``{seqid: [(start0, end0), ...]}`` (0-based half-open). +def load_track(path, feature_type=None, fmt=None): + """Load one annotation track into ``{seqid: [(start0, end0), ...]}`` (0-based half-open, sorted). - Args: - path: BED or GFF/GTF file (``.gz`` ok). - feature_type: GFF only — keep just this column-3 type (e.g. ``"exon"``, ``"ncRNA"``). - fmt: ``"bed"`` / ``"gff"``; inferred from the extension when omitted. - - Returns: - Intervals grouped by sequence id, each sorted. Every interval is one instance. + ``fmt`` (``bed``/``gff``) is inferred from the extension; ``feature_type`` filters GFF column 3. + Every interval is one annotation instance. 
""" - fmt = fmt or ("gff" if str(path).endswith((".gff", ".gff3", ".gtf", ".gff.gz", ".gff3.gz")) else "bed") - rows = _parse_gff(path, feature_type) if fmt == "gff" else _parse_bed(path) - by_seq: dict[str, list[tuple[int, int]]] = defaultdict(list) - for chrom, s, e in rows: + fmt = fmt or ("gff" if str(path).replace(".gz", "").endswith((".gff", ".gff3", ".gtf")) else "bed") + by_seq = defaultdict(list) + for chrom, s, e in _intervals(path, fmt, feature_type): if e > s: by_seq[chrom].append((s, e)) return {k: sorted(v) for k, v in by_seq.items()} From 0c76d3857b2d4ab246a06e3f3cc42b0d38d4e786 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 04:17:54 +0000 Subject: [PATCH 13/14] evo2 eval: table-drive consensus motifs + add labeler CPU tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 4 plain motifs (ATG/stop/TATA/RBS_SD) were near-identical _dna_mask(_starts|_spans(...)) functions -> one _MOTIFS table + loop. Adds test_labelers.py (motif positions, base_A/C/G/T, tag-prefix offset) — labelers had no unit test before. Registry unchanged (23 labelers). 
Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/labelers.py | 28 ++++------ .../recipes/evo2/tests/test_labelers.py | 53 +++++++++++++++++++ 2 files changed, 62 insertions(+), 19 deletions(-) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index 83d5b6ed51..2a92cb86dc 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -201,25 +201,15 @@ def _spans(dna: str, pattern: str) -> np.ndarray: return out -@labeler("motif_ATG") -def _atg(ctx): - return _dna_mask(ctx, _starts(ctx.dna, r"ATG")) - - -@labeler("motif_stop") -def _stop(ctx): - return _dna_mask(ctx, _starts(ctx.dna, r"TAA|TAG|TGA")) - - -@labeler("motif_TATA") -def _tata(ctx): - return _dna_mask(ctx, _spans(ctx.dna, r"TATA[AT]A")) - - -@labeler("motif_RBS_SD") -def _rbs(ctx): - # Shine-Dalgarno ribosome-binding site - return _dna_mask(ctx, _spans(ctx.dna, r"AGGAGG")) +# Consensus motifs: (name, matcher, regex) — `_starts` marks the match start, `_spans` the whole match. 
+_MOTIFS = [ + ("motif_ATG", _starts, r"ATG"), + ("motif_stop", _starts, r"TAA|TAG|TGA"), + ("motif_TATA", _spans, r"TATA[AT]A"), + ("motif_RBS_SD", _spans, r"AGGAGG"), # Shine-Dalgarno ribosome-binding site +] +for _name, _match, _pat in _MOTIFS: + labeler(_name)(lambda ctx, m=_match, p=_pat: _dna_mask(ctx, m(ctx.dna, p))) # --------------------------------------------------- complex / consensus (refine later) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py new file mode 100644 index 0000000000..e99cf236b2 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""CPU tests for the per-token labelers (pure masks, no model).""" + +import sys +from pathlib import Path + +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from labelers import LABELERS, SeqContext + + +def _ctx(dna, tag_len=0): + text = "X" * tag_len + dna + return SeqContext(text=text, tag_len=tag_len, dna=dna, kingdom="prok", hidden_norm=np.zeros(tag_len + len(dna))) + + +def test_consensus_motifs_fire_at_match_positions(): + """The table-driven motifs mark the right positions (ATG/stop = start, TATA = span).""" + ctx = _ctx("ATGTAACGT") # ATG @0 ; TAA (stop) @3 + assert list(LABELERS["motif_ATG"](ctx).nonzero()[0]) == [0] + assert list(LABELERS["motif_stop"](ctx).nonzero()[0]) == [3] + assert list(LABELERS["motif_TATA"](_ctx("TATAAA")).nonzero()[0]) == [0, 1, 2, 3, 4, 5] # spans the match + + +def test_base_labelers_fire_per_nucleotide(): + """base_A/C/G/T each fire exactly on their nucleotide.""" + ctx = _ctx("ACGTAA") + assert list(LABELERS["base_A"](ctx).nonzero()[0]) == [0, 4, 5] + assert list(LABELERS["base_G"](ctx).nonzero()[0]) == [2] + + +def test_tag_prefix_is_unlabeled(): + """Sequence-derived labels are False over the leading phylo-tag tokens.""" + ctx = _ctx("ATG", tag_len=2) # tokens: [tag, tag, A, T, G] + m = LABELERS["motif_ATG"](ctx) + assert len(m) == 5 and not m[:2].any() and m[2] # ATG starts at DNA pos 0 -> token 2 From f794c31e2bde7768ecc6e1502b63f210bf68d867 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 12 Jun 2026 04:51:37 +0000 Subject: [PATCH 14/14] =?UTF-8?q?fix(eval-labels):=20CodeRabbit=20review?= =?UTF-8?q?=20=E2=80=94=20splice=5Fdonor=20consensus=20+=20multi-parent=20?= =?UTF-8?q?GFF=20rows?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - splice_donor regex now requires the terminal T (GT[AG]AGT consensus), not GT[AG]AG - euk_windows: split comma-separated GFF Parent IDs so a shared exon/CDS attaches to every transcript 
Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/euk_windows.py | 10 ++++++---- .../recipes/evo2/scripts/labelers.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py index 335922d000..0f381df84b 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py @@ -67,11 +67,13 @@ def parse_gff(gff_path): tx_gene[tid] = a.get("Parent", "").replace("gene:", "") tx_biotype[tid] = a.get("biotype", "") elif typ == "exon": - tid = a.get("Parent", "").replace("transcript:", "") - tx_exon[tid].append((s, e)) + for tid in a.get("Parent", "").replace("transcript:", "").split(","): + if tid: + tx_exon[tid].append((s, e)) elif typ == "CDS": - tid = a.get("Parent", "").replace("transcript:", "") - tx_cds[tid].append((s, e)) + for tid in a.get("Parent", "").replace("transcript:", "").split(","): + if tid: + tx_cds[tid].append((s, e)) genes = {} for tid, gid in tx_gene.items(): if gene_biotype.get(gid) != "protein_coding" or tx_biotype.get(tid) != "protein_coding": diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index 2a92cb86dc..3243742d75 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -225,7 +225,7 @@ def _kozak(ctx): @labeler("splice_donor", complex=True) def _sd(ctx): # 5' donor consensus GT(A/G)AGT — mark the GT - return _dna_mask(ctx, _starts(ctx.dna, r"GT[AG]AG")) + return _dna_mask(ctx, _starts(ctx.dna, r"GT[AG]AGT")) 
@labeler("splice_acceptor", complex=True)