Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation.

Uses ``sae.steering.clamp_hook`` (the shared delta-clamp) registered on the Evo2 decoder layer
the SAE was trained on. Workflow: encode a sequence to find its active features, then for a
**target** feature sweep the clamp strength (dose-response) and for **control** features apply
the same clamp (selectivity), each time comparing the steered continuation to the baseline.

GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test.

python steer.py --evo2-ckpt-dir <mbridge> --sae-checkpoint <sae.pt> --layer 26 \
--sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200

Note: ``sae.steering.clamp_hook`` clamps on *every* forward (prefill + decode), so it steers
the prompt as well as the continuation. The decode-only ("continuation-only") variant lives in
``evo2_sae_infer.core.Evo2SAE._clamp_hook``; unifying the two onto ``sae.steering`` (with a
``decode_only`` flag) is a planned follow-up.
"""

from __future__ import annotations

import argparse
import sys
from contextlib import nullcontext
from pathlib import Path


# Make the harness runnable as a plain script: prepend this file's directory, its parent,
# and the ``sae/src`` tree found three levels above this directory to ``sys.path``. Each
# ``insert(0, ...)`` pushes earlier entries down, so the last path inserted is searched
# first. The ``sae.steering`` import below has to come after this setup (hence E402).
_HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(_HERE))
sys.path.insert(0, str(_HERE.parent))
sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src"))

from sae.steering import steer # noqa: E402


def _divergence(a: str, b: str):
"""Return (first differing index, fraction of differing chars) over the shared prefix length."""
n = min(len(a), len(b))
first = next((i for i in range(n) if a[i] != b[i]), n)
diff = sum(1 for i in range(n) if a[i] != b[i]) / max(1, n)
return first, diff


def _parse_csv(text, cast):
    """Split a comma-separated CLI value into ``cast``-converted items, skipping blank entries."""
    return [cast(tok) for tok in text.split(",") if tok.strip()]


def main():
    """Encode a sequence, then steer a target feature (dose-response) + control features (selectivity).

    Steps: parse and validate the CLI arguments, load the engine, print the 10 most active
    features for the sequence (the first labeled one becomes the target when ``--feature`` is
    omitted), generate a greedy baseline continuation, then regenerate with the target clamped
    at each strength and with each control clamped at the last strength, printing how far each
    steered continuation diverges from the baseline.

    Raises:
        SystemExit: If ``--controls`` / ``--strengths`` are malformed, no strength is given, or
            no target feature can be determined.
    """
    p = argparse.ArgumentParser(description="Evo2 SAE steering harness (clamp -> continuation effect).")
    p.add_argument("--evo2-ckpt-dir", required=True)
    p.add_argument("--sae-checkpoint", required=True)
    p.add_argument("--layer", type=int, required=True)
    p.add_argument("--sequence", required=True)
    p.add_argument("--organism", default="None (raw DNA)")
    p.add_argument("--feature", type=int, default=None, help="Target feature id (default: top labeled feature).")
    p.add_argument("--controls", default="", help="Comma-separated control feature ids (selectivity).")
    p.add_argument("--strengths", default="0,50,100,200", help="Comma-separated clamp strengths to sweep.")
    p.add_argument("--n-tokens", type=int, default=60)
    p.add_argument("--device", default="cuda")
    a = p.parse_args()

    # Validate the cheap CLI inputs before the heavy imports / model load so that a typo
    # fails immediately instead of after the checkpoint has been loaded.
    try:
        controls = _parse_csv(a.controls, int)
        strengths = _parse_csv(a.strengths, float)
    except ValueError as err:
        p.error(f"--controls / --strengths must be comma-separated numbers ({err})")
    if not strengths:
        # The selectivity pass reads strengths[-1]; an empty sweep would also print nothing useful.
        p.error("--strengths must contain at least one value")

    # Imported lazily so argument errors and --help do not require the inference stack.
    from bionemo.evo2.run import infer as INF  # noqa: E402, I001, RUF100
    from evo2_sae_infer.core import Evo2SAE, clean_dna  # noqa: E402, RUF100
    from megatron.core.utils import unwrap_model  # noqa: E402, RUF100

    eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load()

    # 1. Encode -> the sequence's most-active features (pick a target if not given).
    codes = eng.encode(a.sequence)
    vals, ids = codes.max(0).values.topk(10)
    print(f"top features on {a.sequence[:24]}...:")
    target = a.feature
    for v, i in zip(vals.tolist(), ids.tolist()):
        lab = eng.labels.get(int(i))
        print(f" feat {int(i):6d} {str(lab):18s} max_act {v:7.2f}")
        # The first labeled feature in descending-activation order wins.
        if target is None and lab:
            target = int(i)
    if target is None:
        # Without this guard a None key would reach the clamp hook and fail on int(None).
        p.error("none of the top features has a label; pass --feature explicitly")

    # 2. The Evo2 decoder layer the SAE hooks + a clean (tag + DNA) prompt.
    comp = eng._ensure_engine()
    prompt = (eng.resolve_tag(a.organism, None) or "") + clean_dna(a.sequence)
    layer_mod = unwrap_model(comp.model).decoder.layers[a.layer]

    def gen(clamps):
        """Generate a greedy continuation of the prompt, steering with ``clamps`` when non-empty."""
        ctx = steer(layer_mod, eng.sae, clamps) if clamps else nullcontext()
        with ctx:
            out = INF.generate(comp, [prompt], max_new_tokens=a.n_tokens, temperature=0.0, top_k=1)
        return clean_dna(INF._unwrap_result(out[0]).generated_text)

    base = gen({})
    print(f"\nbaseline: {base[:60]}")
    print(f"\n=== dose-response: feature {target} ({eng.labels.get(target)}) ===")
    for s in strengths:
        steered = gen({target: s})
        first, diff = _divergence(base, steered)
        print(f" strength {s:7.1f}: diverges@{first:3d} {diff:6.1%} changed {steered[:44]}")

    if controls:
        # Controls are clamped at the last strength of the sweep.
        s = strengths[-1]
        print(f"\n=== selectivity: control features clamped to {s} ===")
        for c in controls:
            steered = gen({c: s})
            first, diff = _divergence(base, steered)
            print(f" control {c:6d} ({str(eng.labels.get(c)):16s}): diverges@{first:3d} {diff:6.1%} changed")


# Script entry point: run the harness only when executed directly, not when imported.
if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta.

A forward hook on the layer the SAE was trained on: it re-encodes the layer output through
the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the
activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the
SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves
the activation. Model-agnostic: needs only the SAE (``encode_pre_act`` / ``decode`` / ``top_k``)
and the module to hook. Measure the effect (e.g. ΔP of a target token) by running the model
with vs. without the hook.
"""

from contextlib import contextmanager
from typing import Dict

import torch


def clamp_hook(sae, clamps: Dict[int, float]):
    """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method.

    The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's
    output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative
    value reverses its decoder direction. Works whether the module returns a tensor or a tuple
    whose first element is the hidden state.

    With an empty ``clamps`` map the delta is identically zero, so the hook returns the module
    output untouched and never calls the SAE.

    Args:
        sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``,
            and ``top_k``.
        clamps: Map of feature index -> absolute code value to force at every position.

    Returns:
        A ``register_forward_hook``-compatible ``hook(module, inputs, output)``.
    """
    # Normalise once at build time so the per-forward loop only deals with plain int/float.
    items = [(int(f), float(v)) for f, v in clamps.items()]

    def hook(module, inputs, output):
        if not items:
            # Nothing to clamp: skip the encode / top-k / double-decode round trip entirely.
            return output

        is_tuple = isinstance(output, tuple)
        h = output[0] if is_tuple else output
        dtype, shape = h.dtype, h.shape
        # Flatten every leading dimension into rows and run the SAE maths in float32.
        h_flat = h.reshape(-1, h.shape[-1]).float()
        with torch.no_grad():
            pre_act, info = sae.encode_pre_act(h_flat)
            codes = torch.relu(pre_act)
            # Keep only the top-k activations per row; everything else is zeroed.
            kvals, kidx = torch.topk(codes, sae.top_k, dim=-1)
            codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals)
            codes_clamped = codes_orig.clone()
            for f, v in items:
                codes_clamped[:, f] = v
            # Subtracting the unclamped reconstruction cancels the SAE's reconstruction error.
            delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info)
        h_out = (h_flat + delta).to(dtype).reshape(shape)
        # Pass any extra tuple elements through unchanged.
        return (h_out, *output[1:]) if is_tuple else h_out

    return hook


@contextmanager
def steer(module, sae, clamps: Dict[int, float]):
    """Context manager that keeps the clamp hook attached to ``module`` only inside the ``with`` body.

    The hook is built from ``sae`` and ``clamps`` via ``clamp_hook`` and registered as a forward
    hook on entry; it is removed on exit, including when the body raises.
    """
    hook_fn = clamp_hook(sae, clamps)
    registration = module.register_forward_hook(hook_fn)
    try:
        yield
    finally:
        # Always detach so later forwards through ``module`` are unsteered.
        registration.remove()
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CPU tests for sae.steering: the delta-clamp hook adds exactly decode(clamped) - decode(orig)."""

import torch
from sae.architectures import TopKSAE
from sae.steering import clamp_hook, steer
from torch import nn


def _sae():
    """Return a small, deterministically initialised TopKSAE (8 -> 16 features, top-4) for the tests."""
    # Seed before construction so every test sees the same SAE.
    torch.manual_seed(0)
    config = {"input_dim": 8, "hidden_dim": 16, "top_k": 4, "normalize_input": False}
    return TopKSAE(**config)


def test_no_clamp_is_a_noop():
    """Steering with an empty clamp map must leave the activation unchanged."""
    # Build the SAE first: ``_sae`` reseeds the RNG, so the order fixes which ``x`` is drawn.
    sae = _sae()
    layer = nn.Identity()
    x = torch.randn(5, 8)
    with steer(layer, sae, {}):
        steered = layer(x)
    assert torch.allclose(steered, x, atol=1e-5)


def test_clamp_adds_decoder_delta():
    """A clamped feature must move the activation by decode(clamped) - decode(orig), nothing else."""
    # Build the SAE first: ``_sae`` reseeds the RNG, so the order fixes which ``x`` is drawn.
    sae = _sae()
    layer = nn.Identity()
    x = torch.randn(5, 8)

    # Recompute the expected delta by hand, mirroring the hook's encode -> top-k -> clamp -> decode.
    with torch.no_grad():
        pre_act, info = sae.encode_pre_act(x.float())
        dense = torch.relu(pre_act)
        top_vals, top_idx = torch.topk(dense, sae.top_k, dim=-1)
        baseline_codes = torch.zeros_like(dense).scatter(-1, top_idx, top_vals)
        clamped_codes = baseline_codes.clone()
        clamped_codes[:, 3] = 5.0
        delta = sae.decode(clamped_codes, info) - sae.decode(baseline_codes, info)
        expected = x + delta

    with steer(layer, sae, {3: 5.0}):
        steered = layer(x)
    assert torch.allclose(steered, expected, atol=1e-4)


def test_tuple_output_first_element_steered_rest_preserved():
    """For a tuple-returning module, element 0 is steered and the remaining elements pass through."""

    class M(nn.Module):
        """Toy module that returns its input alongside a non-tensor payload."""

        def forward(self, x):
            return (x, "meta")

    # Build the SAE first: ``_sae`` reseeds the RNG, so the order fixes which ``x`` is drawn.
    sae = _sae()
    x = torch.randn(3, 8)
    module = M()

    # Register the raw hook directly (rather than via ``steer``) to exercise ``clamp_hook`` itself.
    registration = module.register_forward_hook(clamp_hook(sae, {0: 2.0}))
    result = module(x)
    registration.remove()

    assert isinstance(result, tuple)
    hidden, payload = result[0], result[1]
    assert payload == "meta"
    assert hidden.shape == x.shape
    assert not torch.allclose(hidden, x)  # the clamp moved it
Loading