From 8bcbed59750d22ef7c06b04934a81478c863dbb8 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Thu, 18 Jun 2026 16:57:44 +0000 Subject: [PATCH 1/9] feat: add DeepSeek-V2 architecture adapter (MLA, V2/V2-Lite, complex RoPE) Closes #1400. DeepSeek-V2, V2-Lite, and Coder-V2 all use DeepseekV2ForCausalLM. This adds a bridge adapter covering three V2-specific differences from V3: 1. Complex-exponential RoPE: V2's rotary embedding returns freqs_cis (a complex tensor via torch.polar) rather than a (cos, sin) tuple. - RotaryEmbeddingBridge.forward() now passes complex tensors through without raising, leaving them for the attention bridge to consume. - MLAAttentionBridge.forward() detects complex position_embeddings and dispatches to a new _apply_rotary_complex() helper that mirrors DeepSeek-V2's apply_rotary_emb (view_as_complex, multiply, flatten). 2. Optional Q LoRA path: V2-Lite sets q_lora_rank=None, skipping q_a_proj/q_a_layernorm/q_b_proj and using q_proj directly instead. All three Q-path submodules are marked optional=True in the adapter; q_a_layernorm uses GeneralizedComponent (which already supports optional) rather than RMSNormalizationBridge. MLAAttentionBridge already branches on q_lora_rank at runtime. 3. Gate not hookable: DeepseekV2Moe.forward() routes via nn.functional.linear(..., self.gate.weight) rather than self.gate(hidden_states), so the gate module's forward() is never called and bridge hooks cannot fire. The gate is omitted from MoEBridge submodules; shared_experts uses __call__ and hooks fine. 
Files changed: - supported_architectures/deepseek_v2.py (new) - supported_architectures/__init__.py: register adapter - factories/architecture_adapter_factory.py: map DeepseekV2ForCausalLM - generalized_components/mla_attention.py: complex RoPE support - generalized_components/rotary_embedding.py: complex tensor pass-through - tests/integration/model_bridge/test_deepseek_v2_adapter.py (new, 17 tests) --- .../model_bridge/test_deepseek_v2_adapter.py | 190 ++++++++++++++++++ .../factories/architecture_adapter_factory.py | 2 + .../generalized_components/mla_attention.py | 50 ++++- .../rotary_embedding.py | 7 +- .../supported_architectures/__init__.py | 4 + .../supported_architectures/deepseek_v2.py | 127 ++++++++++++ 6 files changed, 374 insertions(+), 6 deletions(-) create mode 100644 tests/integration/model_bridge/test_deepseek_v2_adapter.py create mode 100644 transformer_lens/model_bridge/supported_architectures/deepseek_v2.py diff --git a/tests/integration/model_bridge/test_deepseek_v2_adapter.py b/tests/integration/model_bridge/test_deepseek_v2_adapter.py new file mode 100644 index 000000000..22fef4d7e --- /dev/null +++ b/tests/integration/model_bridge/test_deepseek_v2_adapter.py @@ -0,0 +1,190 @@ +"""Integration tests for DeepSeek V2 architecture adapter. + +Covers two distinct variants of DeepseekV2ForCausalLM: +- V2-full (q_lora_rank set): Q is compressed via two-stage LoRA projection. +- V2-Lite (q_lora_rank=None): Q uses a direct linear projection; no compression. 
+""" + +import tempfile + +import pytest +import torch +from transformers import AutoTokenizer, DeepseekV2Config, DeepseekV2ForCausalLM + +from transformer_lens.model_bridge.bridge import TransformerBridge + + +def _make_bridge(q_lora_rank): + """Build a tiny DeepseekV2 bridge with the given q_lora_rank (None = V2-Lite).""" + cfg = DeepseekV2Config( + hidden_size=256, + intermediate_size=512, + num_hidden_layers=4, + num_attention_heads=8, + q_lora_rank=q_lora_rank, + kv_lora_rank=32, + qk_nope_head_dim=16, + qk_rope_head_dim=8, + v_head_dim=16, + vocab_size=1000, + first_k_dense_replace=1, + n_routed_experts=8, + n_shared_experts=2, + num_experts_per_tok=2, + max_position_embeddings=128, + moe_intermediate_size=256, + ) + hf_model = DeepseekV2ForCausalLM(cfg) + with tempfile.TemporaryDirectory() as tmpdir: + hf_model.save_pretrained(tmpdir) + tok = AutoTokenizer.from_pretrained("gpt2") + tok.save_pretrained(tmpdir) + return TransformerBridge.boot_transformers(tmpdir, device="cpu") + + +@pytest.fixture(scope="module") +def tiny_deepseek_v2_bridge(): + """V2-full: q_lora_rank=64 — two-stage Q compression (same as V3).""" + return _make_bridge(q_lora_rank=64) + + +@pytest.fixture(scope="module") +def tiny_deepseek_v2_lite_bridge(): + """V2-Lite: q_lora_rank=None — direct Q projection, no LoRA compression.""" + return _make_bridge(q_lora_rank=None) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def _tokens(): + return torch.tensor([[1, 2, 3, 4]]) + + +# --------------------------------------------------------------------------- +# V2-full tests +# --------------------------------------------------------------------------- + +class TestDeepSeekV2BridgeCreation: + def test_block_count(self, tiny_deepseek_v2_bridge): + assert len(tiny_deepseek_v2_bridge.blocks) == 4 + + def test_has_embed_unembed_ln_final(self, tiny_deepseek_v2_bridge): 
+ assert hasattr(tiny_deepseek_v2_bridge, "embed") + assert hasattr(tiny_deepseek_v2_bridge, "unembed") + assert hasattr(tiny_deepseek_v2_bridge, "ln_final") + + def test_attention_is_mla(self, tiny_deepseek_v2_bridge): + from transformer_lens.model_bridge.generalized_components.mla_attention import ( + MLAAttentionBridge, + ) + assert isinstance(tiny_deepseek_v2_bridge.blocks[0].attn, MLAAttentionBridge) + + +class TestDeepSeekV2ForwardPass: + def test_forward_returns_correct_shape(self, tiny_deepseek_v2_bridge): + tokens = _tokens() + with torch.no_grad(): + out = tiny_deepseek_v2_bridge(tokens) + assert out.shape == (1, 4, 1000) + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + def test_forward_matches_hf(self, tiny_deepseek_v2_bridge): + tokens = _tokens() + hf_model = tiny_deepseek_v2_bridge.original_model + with torch.no_grad(): + bridge_out = tiny_deepseek_v2_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 0.15, f"Bridge vs HF max diff = {max_diff}" + + +class TestDeepSeekV2DenseVsMoELayers: + def test_dense_layer_has_no_moe_hooks(self, tiny_deepseek_v2_bridge): + _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens()) + assert not any("blocks.0.mlp.gate" in k for k in cache) + assert not any("blocks.0.mlp.shared_experts" in k for k in cache) + + def test_moe_layer_has_shared_expert_hooks(self, tiny_deepseek_v2_bridge): + # DeepseekV2Moe.forward() routes via nn.functional.linear(..., self.gate.weight) + # directly — not self.gate(hidden_states) — so the gate module's forward() is + # never called and its bridge hooks cannot fire. shared_experts IS called via + # __call__, so GatedMLPBridge hooks fire correctly. 
+ _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens()) + assert not any("blocks.1.mlp.gate" in k for k in cache), ( + "gate hooks should not appear — gate is called via functional.linear, not forward()" + ) + assert any("blocks.1.mlp.shared_experts" in k for k in cache) + + def test_all_layers_have_mlp_hooks(self, tiny_deepseek_v2_bridge): + _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens()) + for i in range(4): + assert f"blocks.{i}.mlp.hook_in" in cache + assert f"blocks.{i}.mlp.hook_out" in cache + assert not torch.isnan(cache[f"blocks.{i}.mlp.hook_out"]).any() + + +class TestDeepSeekV2AttentionHooks: + def test_attn_hooks_fire_all_layers(self, tiny_deepseek_v2_bridge): + _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens()) + for i in range(4): + assert f"blocks.{i}.attn.hook_in" in cache + assert f"blocks.{i}.attn.hook_out" in cache + + def test_mla_latent_hooks_fire(self, tiny_deepseek_v2_bridge): + _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens()) + assert any("hook_q_latent" in k for k in cache) + assert any("hook_kv_latent" in k for k in cache) + + +# --------------------------------------------------------------------------- +# V2-Lite tests (q_lora_rank=None — direct q_proj, no compression) +# --------------------------------------------------------------------------- + +class TestDeepSeekV2LiteBridgeCreation: + def test_block_count(self, tiny_deepseek_v2_lite_bridge): + assert len(tiny_deepseek_v2_lite_bridge.blocks) == 4 + + def test_attention_is_mla(self, tiny_deepseek_v2_lite_bridge): + from transformer_lens.model_bridge.generalized_components.mla_attention import ( + MLAAttentionBridge, + ) + assert isinstance(tiny_deepseek_v2_lite_bridge.blocks[0].attn, MLAAttentionBridge) + + +class TestDeepSeekV2LiteForwardPass: + def test_forward_returns_correct_shape(self, tiny_deepseek_v2_lite_bridge): + tokens = _tokens() + with torch.no_grad(): + out = tiny_deepseek_v2_lite_bridge(tokens) + assert out.shape == (1, 4, 1000) 
+ assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + def test_forward_matches_hf(self, tiny_deepseek_v2_lite_bridge): + tokens = _tokens() + hf_model = tiny_deepseek_v2_lite_bridge.original_model + with torch.no_grad(): + bridge_out = tiny_deepseek_v2_lite_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 0.15, f"V2-Lite bridge vs HF max diff = {max_diff}" + + +class TestDeepSeekV2LiteNoQLatentHook: + def test_hook_q_latent_absent_without_q_lora_rank(self, tiny_deepseek_v2_lite_bridge): + """V2-Lite skips Q compression — hook_q_latent should not fire.""" + _, cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens()) + assert not any("hook_q_latent" in k for k in cache) + + def test_hook_kv_latent_still_fires(self, tiny_deepseek_v2_lite_bridge): + """KV compression is always present regardless of q_lora_rank.""" + _, cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens()) + assert any("hook_kv_latent" in k for k in cache) + + def test_all_layers_produce_non_nan(self, tiny_deepseek_v2_lite_bridge): + _, cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens()) + for i in range(4): + assert not torch.isnan(cache[f"blocks.{i}.attn.hook_out"]).any() diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 49dd134f7..8869a7be1 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -16,6 +16,7 @@ BloomArchitectureAdapter, CodeGenArchitectureAdapter, CohereArchitectureAdapter, + DeepSeekV2ArchitectureAdapter, DeepSeekV3ArchitectureAdapter, FalconArchitectureAdapter, Gemma1ArchitectureAdapter, @@ -78,6 +79,7 @@ "BloomForCausalLM": BloomArchitectureAdapter, "CodeGenForCausalLM": CodeGenArchitectureAdapter, "CohereForCausalLM": CohereArchitectureAdapter, + "DeepseekV2ForCausalLM": 
DeepSeekV2ArchitectureAdapter, "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter, "FalconForCausalLM": FalconArchitectureAdapter, "GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version diff --git a/transformer_lens/model_bridge/generalized_components/mla_attention.py b/transformer_lens/model_bridge/generalized_components/mla_attention.py index 18a770480..f472e1bd1 100644 --- a/transformer_lens/model_bridge/generalized_components/mla_attention.py +++ b/transformer_lens/model_bridge/generalized_components/mla_attention.py @@ -47,6 +47,31 @@ def _apply_rotary_pos_emb( return q_embed, k_embed +def _apply_rotary_complex( + q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary position embedding via complex multiplication (DeepSeek-V2 style). + + DeepSeek-V2 uses ``freqs_cis = torch.polar(ones, freqs)`` (complex exponentials) + instead of the standard (cos, sin) pair. This matches the V2 HF implementation of + ``apply_rotary_emb``. + + Args: + q: Query rope portion [batch, heads, seq, rope_dim]. + k: Key rope portion [batch, 1, seq, rope_dim]. + freqs_cis: Complex rotary frequencies [batch, seq, rope_dim // 2]. + + Returns: + Tuple of rotated (q, k) tensors with same dtype and shape as inputs. + """ + freqs = freqs_cis.unsqueeze(1) # [batch, 1, seq, rope_dim // 2] + q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2)) + k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2)) + q_rot = torch.view_as_real(q_c * freqs.to(q_c.device)).flatten(3).type_as(q) + k_rot = torch.view_as_real(k_c * freqs.to(k_c.device)).flatten(3).type_as(k) + return q_rot, k_rot + + class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge): """Bridge for DeepSeek's Multi-Head Latent Attention (MLA). 
@@ -176,20 +201,31 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim) # --- RoPE --- + # DeepSeek-V2 passes a complex freqs_cis tensor; V3 passes a (cos, sin) tuple. + # Detect the format and apply the appropriate rotation. + cos = sin = None if position_embeddings is not None: position_embeddings = self._apply_position_embedding_hooks(position_embeddings) - cos, sin = position_embeddings + if isinstance(position_embeddings, torch.Tensor) and position_embeddings.is_complex(): + # V2-style: complex exponential freqs_cis + q_rot, k_rot = _apply_rotary_complex(q_rot, k_rot, position_embeddings) + else: + cos, sin = position_embeddings + q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin) elif self._rotary_emb is not None: # Fallback: compute from rotary_emb if position_embeddings not passed position_ids = torch.arange(seq_length, device=hidden_states.device).unsqueeze(0) - cos, sin = self._rotary_emb(hidden_states, position_ids) + emb = self._rotary_emb(hidden_states, position_ids) + if isinstance(emb, torch.Tensor) and emb.is_complex(): + q_rot, k_rot = _apply_rotary_complex(q_rot, k_rot, emb) + else: + cos, sin = emb + q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin) else: raise ValueError( "MLAAttentionBridge requires position_embeddings or set_rotary_emb() " "to be called before forward." 
) - - q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin) q_rot = self.hook_rot_q(q_rot) k_rot = self.hook_rot_k(k_rot) @@ -209,7 +245,11 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: past_key_values = kwargs.pop("past_key_values", None) cache_position = kwargs.pop("cache_position", None) if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs: dict = {"cache_position": cache_position} + if cos is not None: + cache_kwargs["cos"] = cos + if sin is not None: + cache_kwargs["sin"] = sin key_states, value_states = past_key_values.update( key_states, value_states, hf_attn.layer_idx, cache_kwargs ) diff --git a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py index 3af922a04..3b3a97648 100644 --- a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py +++ b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py @@ -109,8 +109,13 @@ def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor # Call original component to get (cos, sin) tuple output = self.original_component(*args, **kwargs) - # Ensure output is a tuple + # Ensure output is a tuple — or a complex tensor (DeepSeek-V2 freqs_cis style) if not isinstance(output, tuple): + if isinstance(output, torch.Tensor) and output.is_complex(): + # V2-style: freqs_cis complex tensor — pass through without cos/sin split. + # hook_cos/hook_sin do not apply here; the complex form is consumed by + # MLAAttentionBridge which detects it and uses complex multiplication. 
+ return output if hasattr(output, "__iter__") and (not isinstance(output, torch.Tensor)): output = tuple(output) else: diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 84e4584af..0d76a7f76 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -21,6 +21,9 @@ from transformer_lens.model_bridge.supported_architectures.cohere import ( CohereArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.deepseek_v2 import ( + DeepSeekV2ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.deepseek_v3 import ( DeepSeekV3ArchitectureAdapter, ) @@ -182,6 +185,7 @@ "BloomArchitectureAdapter", "CodeGenArchitectureAdapter", "CohereArchitectureAdapter", + "DeepSeekV2ArchitectureAdapter", "DeepSeekV3ArchitectureAdapter", "FalconArchitectureAdapter", "Gemma1ArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py b/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py new file mode 100644 index 000000000..556a2dbba --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py @@ -0,0 +1,127 @@ +"""DeepSeek V2 architecture adapter. + +Supports DeepSeek-V2, DeepSeek-V2-Lite, and DeepSeek-Coder-V2 models +(all use DeepseekV2ForCausalLM). + +Key features: +- Multi-Head Latent Attention (MLA): Q and KV compressed via LoRA-style projections. + DeepSeek-V2-Lite sets q_lora_rank=None, skipping Q compression and using a direct + q_proj instead — MLAAttentionBridge.forward handles both paths automatically. 
+- Mixture of Experts (MoE) with shared experts on most layers +- Dense MLP on first `first_k_dense_replace` layers +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + MLAAttentionBridge, + MLABlockBridge, + MoEBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) + + +class DeepSeekV2ArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for DeepSeek V2 / V2-Lite / Coder-V2 models. + + Uses RMSNorm, MLA with compressed Q/KV projections (or direct Q projection + when q_lora_rank is None), partial RoPE, MoE on most layers (dense MLP on + first few), and no biases. + """ + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.gated_mlp = True + self.cfg.final_rms = True + self.cfg.uses_rms_norm = True + + self.weight_processing_conversions = {} + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), + "blocks": MLABlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge( + name="post_attention_layernorm", config=self.cfg + ), + "attn": MLAAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + # V2-full (q_lora_rank set): two-stage LoRA Q compression. + # These are absent in V2-Lite — marked optional so bridge + # setup skips them gracefully. The actual forward call is + # handled inside MLAAttentionBridge which checks q_lora_rank. 
+ "q_a_proj": LinearBridge(name="q_a_proj", optional=True), + # q_a_layernorm is a norm inside the attention block; its + # forward is called directly by MLAAttentionBridge, so a + # plain GeneralizedComponent (with optional support) suffices. + "q_a_layernorm": GeneralizedComponent( + name="q_a_layernorm", optional=True + ), + "q_b_proj": LinearBridge(name="q_b_proj", optional=True), + # V2-Lite only: direct Q projection, no compression. + "q_proj": LinearBridge(name="q_proj", optional=True), + # KV path — always present across all V2 variants. + "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"), + "kv_a_layernorm": RMSNormalizationBridge( + name="kv_a_layernorm", config=self.cfg + ), + "kv_b_proj": LinearBridge(name="kv_b_proj"), + "o": LinearBridge(name="o_proj"), + }, + ), + # On dense layers (idx < first_k_dense_replace), shared_experts + # are absent — marked optional so setup gracefully skips them when + # the layer is DeepseekV2MLP instead of MoE. + # Note: the gate module is NOT bridged — DeepseekV2Moe.forward() + # calls nn.functional.linear(..., self.gate.weight) directly, + # bypassing forward(), so no hook can be attached to it. 
+ "mlp": MoEBridge( + name="mlp", + config=self.cfg, + submodules={ + "shared_experts": GatedMLPBridge( + name="shared_experts", + config=self.cfg, + optional=True, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up rotary embedding references for component testing.""" + rotary_emb = hf_model.model.rotary_emb + + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) From 96fd92742dd73dd73332bc74848754b05e79bd02 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Fri, 19 Jun 2026 05:36:34 +0100 Subject: [PATCH 2/9] style: fix formatting and mypy errors --- .../model_bridge/test_deepseek_v2_adapter.py | 11 ++++++++--- .../generalized_components/rotary_embedding.py | 10 +++++++--- .../supported_architectures/deepseek_v2.py | 4 +--- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/integration/model_bridge/test_deepseek_v2_adapter.py b/tests/integration/model_bridge/test_deepseek_v2_adapter.py index 22fef4d7e..bd74bfa82 100644 --- a/tests/integration/model_bridge/test_deepseek_v2_adapter.py +++ b/tests/integration/model_bridge/test_deepseek_v2_adapter.py @@ -58,6 +58,7 @@ def tiny_deepseek_v2_lite_bridge(): # Shared helpers # --------------------------------------------------------------------------- + def _tokens(): return torch.tensor([[1, 2, 3, 4]]) @@ -66,6 +67,7 @@ def _tokens(): # V2-full tests # --------------------------------------------------------------------------- + class 
TestDeepSeekV2BridgeCreation: def test_block_count(self, tiny_deepseek_v2_bridge): assert len(tiny_deepseek_v2_bridge.blocks) == 4 @@ -79,6 +81,7 @@ def test_attention_is_mla(self, tiny_deepseek_v2_bridge): from transformer_lens.model_bridge.generalized_components.mla_attention import ( MLAAttentionBridge, ) + assert isinstance(tiny_deepseek_v2_bridge.blocks[0].attn, MLAAttentionBridge) @@ -113,9 +116,9 @@ def test_moe_layer_has_shared_expert_hooks(self, tiny_deepseek_v2_bridge): # never called and its bridge hooks cannot fire. shared_experts IS called via # __call__, so GatedMLPBridge hooks fire correctly. _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens()) - assert not any("blocks.1.mlp.gate" in k for k in cache), ( - "gate hooks should not appear — gate is called via functional.linear, not forward()" - ) + assert not any( + "blocks.1.mlp.gate" in k for k in cache + ), "gate hooks should not appear — gate is called via functional.linear, not forward()" assert any("blocks.1.mlp.shared_experts" in k for k in cache) def test_all_layers_have_mlp_hooks(self, tiny_deepseek_v2_bridge): @@ -143,6 +146,7 @@ def test_mla_latent_hooks_fire(self, tiny_deepseek_v2_bridge): # V2-Lite tests (q_lora_rank=None — direct q_proj, no compression) # --------------------------------------------------------------------------- + class TestDeepSeekV2LiteBridgeCreation: def test_block_count(self, tiny_deepseek_v2_lite_bridge): assert len(tiny_deepseek_v2_lite_bridge.blocks) == 4 @@ -151,6 +155,7 @@ def test_attention_is_mla(self, tiny_deepseek_v2_lite_bridge): from transformer_lens.model_bridge.generalized_components.mla_attention import ( MLAAttentionBridge, ) + assert isinstance(tiny_deepseek_v2_lite_bridge.blocks[0].attn, MLAAttentionBridge) diff --git a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py index 3b3a97648..c560fff96 100644 --- 
a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py +++ b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py @@ -2,7 +2,7 @@ This module contains the bridge component for rotary position embedding layers. """ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch @@ -81,7 +81,9 @@ def get_random_inputs( args = (x, position_ids, layer_type) return {"args": args} - def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, *args: Any, **kwargs: Any + ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Forward pass through the rotary embedding bridge. Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors. @@ -94,7 +96,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor Returns: Tuple of (cos, sin) tensors for rotary position embeddings, after being - passed through hook_cos and hook_sin respectively + passed through hook_cos and hook_sin respectively. For DeepSeek-V2-style + embeddings that return a single complex ``freqs_cis`` tensor, that tensor is + passed through unchanged for downstream complex multiplication. 
""" if self.original_component is None: raise RuntimeError( diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py b/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py index 556a2dbba..76c977a54 100644 --- a/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py +++ b/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py @@ -56,9 +56,7 @@ def __init__(self, cfg: Any) -> None: name="model.layers", submodules={ "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), - "ln2": RMSNormalizationBridge( - name="post_attention_layernorm", config=self.cfg - ), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), "attn": MLAAttentionBridge( name="self_attn", config=self.cfg, From 1f05ef737e9eec8a8f0966427d7323dee1743252 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Mon, 22 Jun 2026 19:13:06 +0100 Subject: [PATCH 3/9] fix: register DeepseekV2ForCausalLM in model registry --- transformer_lens/tools/model_registry/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index 769b9b0d1..e08aa7cba 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -52,6 +52,7 @@ "BloomForCausalLM", "CodeGenForCausalLM", "CohereForCausalLM", + "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "FalconForCausalLM", "GemmaForCausalLM", @@ -112,6 +113,7 @@ "BloomForCausalLM": ["bigscience"], "CodeGenForCausalLM": ["Salesforce"], "CohereForCausalLM": ["CohereLabs"], + "DeepseekV2ForCausalLM": ["deepseek-ai"], "DeepseekV3ForCausalLM": ["deepseek-ai"], "FalconForCausalLM": ["tiiuae"], "Gemma2ForCausalLM": ["google"], From f5a3d23f6f733eb3076b917c6c4bb2a1f067906a Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Mon, 22 Jun 2026 19:58:38 +0000 Subject: [PATCH 4/9] feat: add NemotronH hybrid 
Mamba2-Transformer architecture adapter Implements TransformerBridge support for NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B). Architecture overview: - Heterogeneous layers defined by config.layers_block_type: each element is one of mamba, attention, moe, or mlp (~8% attention, ~92% SSM/MLP/MoE) - Single pre-norm (block.norm) and single residual path per block; no ln2 - Single .mixer attribute per block whose type varies by layer - No model-level rotary embedding module; attention handles RoPE internally - Stateful generation via DynamicCache (transformers >= 5.12) Key adapter decisions: - SSMBlockBridge as block container: delegates full forward to HF block, avoids ln2 enforcement that BlockBridge would apply incorrectly here - SSM2MixerBridge(name=mixer) as passthrough wrapper: works for all four mixer types since forward calls original_component(*args, **kwargs) - Mamba-specific submodules (in_proj, conv1d, inner_norm, out_proj) marked optional so component_setup skips them gracefully on non-Mamba layers - GatedRMSNormBridge.optional set post-init (its __init__ does not accept the kwarg, unlike the GeneralizedComponent base class) - positional_embedding_type=none: no model-level rotary to wire - gated_mlp=False: MLP layers use relu2, not SwiGLU - applicable_phases=[]: verify_models is transformer-shaped; integration tests cover forward-pass correctness instead Registration: - architecture_adapter_factory.py: NemotronHForCausalLM key added - supported_architectures/__init__.py: export added - tools/model_registry/__init__.py: HF_SUPPORTED_ARCHITECTURES and CANONICAL_AUTHORS_BY_ARCH entries added (canonical author: nvidia) Tests (52 unit tests, all passing): - Config attribute propagation (normalization_type, positional_embedding_type, gated_mlp, is_stateful, final_rms, mamba_intermediate_size, conv_dim, layers_block_type, applicable_phases, weight_processing_conversions) - Top-level component mapping bridge types and HF path names - Block 
submodule bridge types (norm, mixer; no ln2) - Mixer submodule types, names, and optional flags for all four Mamba keys - create_stateful_cache returns DynamicCache; independent per call - Factory registration and model registry constants - Guard tests: SSMBlockBridge not BlockBridge, no weight conversions, mamba_intermediate_size and conv_dim formulas Closes #1402 --- .../test_nemotron_h_adapter.py | 360 ++++++++++++++++++ .../factories/architecture_adapter_factory.py | 2 + .../supported_architectures/__init__.py | 4 + .../supported_architectures/nemotron_h.py | 155 ++++++++ .../tools/model_registry/__init__.py | 2 + 5 files changed, 523 insertions(+) create mode 100644 tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py create mode 100644 transformer_lens/model_bridge/supported_architectures/nemotron_h.py diff --git a/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py b/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py new file mode 100644 index 000000000..15866b047 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py @@ -0,0 +1,360 @@ +"""Unit tests for NemotronHArchitectureAdapter. + +Covers: config attribute propagation, component mapping bridge types and HF +path names, Mamba-specific submodule optional flag, applicable_phases, +create_stateful_cache, factory registration, and guard tests. 
+""" + +from unittest.mock import MagicMock + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory +from transformer_lens.model_bridge.generalized_components import ( + DepthwiseConv1DBridge, + EmbeddingBridge, + GatedRMSNormBridge, + LinearBridge, + RMSNormalizationBridge, + SSM2MixerBridge, + SSMBlockBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.nemotron_h import ( + NemotronHArchitectureAdapter, +) + + +def _make_cfg( + n_layers: int = 3, + d_model: int = 64, + d_head: int = 8, + n_heads: int = 8, + d_vocab: int = 100, + n_ctx: int = 128, + mamba_num_heads: int = 4, + mamba_head_dim: int = 8, + n_groups: int = 2, + ssm_state_size: int = 4, + layers_block_type: list[str] | None = None, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for NemotronH adapter tests. + + Uses small Mamba-2 dimensions so tests run without loading any weights. + """ + cfg = TransformerBridgeConfig( + d_model=d_model, + d_head=d_head, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + default_prepend_bos=False, + architecture="NemotronHForCausalLM", + ) + # Inject NemotronH-specific fields the adapter reads via getattr. 
+ cfg.mamba_num_heads = mamba_num_heads # type: ignore[attr-defined] + cfg.mamba_head_dim = mamba_head_dim # type: ignore[attr-defined] + cfg.n_groups = n_groups # type: ignore[attr-defined] + cfg.ssm_state_size = ssm_state_size # type: ignore[attr-defined] + if layers_block_type is not None: + cfg.layers_block_type = layers_block_type # type: ignore[attr-defined] + return cfg + + +@pytest.fixture(scope="class") +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture(scope="class") +def adapter(cfg: TransformerBridgeConfig) -> NemotronHArchitectureAdapter: + return NemotronHArchitectureAdapter(cfg) + + +# --------------------------------------------------------------------------- +# Config attributes +# --------------------------------------------------------------------------- + + +class TestNemotronHAdapterConfig: + """Adapter propagates all required config attributes.""" + + def test_normalization_type_rms(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_uses_rms_norm(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_positional_embedding_type_none(self, adapter: NemotronHArchitectureAdapter) -> None: + # No model-level rotary module — attention handles RoPE internally. + assert adapter.cfg.positional_embedding_type == "none" + + def test_gated_mlp_false(self, adapter: NemotronHArchitectureAdapter) -> None: + # MLP layers use relu2, not SwiGLU. 
+ assert adapter.cfg.gated_mlp is False + + def test_attn_only_false(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_final_rms_true(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_is_stateful_true(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.is_stateful is True + + def test_mamba_intermediate_size_propagated(self, adapter: NemotronHArchitectureAdapter) -> None: + # mamba_num_heads=4, mamba_head_dim=8 → 32 + assert getattr(adapter.cfg, "mamba_intermediate_size", None) == 32 + + def test_conv_dim_propagated(self, adapter: NemotronHArchitectureAdapter) -> None: + # intermediate=32, n_groups=2, ssm_state_size=4 → 32 + 2*2*4 = 48 + assert getattr(adapter.cfg, "conv_dim", None) == 48 + + def test_applicable_phases_empty(self) -> None: + # verify_models is transformer-shaped; SSM hybrids skip it. + assert NemotronHArchitectureAdapter.applicable_phases == [] + + def test_weight_processing_conversions_empty( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + assert adapter.weight_processing_conversions == {} + + def test_layers_block_type_propagated(self) -> None: + block_types = ["mamba", "attention", "mamba"] + cfg = _make_cfg(layers_block_type=block_types) + a = NemotronHArchitectureAdapter(cfg) + assert getattr(a.cfg, "layers_block_type") == block_types + + def test_layers_block_type_defaults_to_empty(self) -> None: + cfg = _make_cfg() + a = NemotronHArchitectureAdapter(cfg) + assert getattr(a.cfg, "layers_block_type") == [] + + +# --------------------------------------------------------------------------- +# Top-level component mapping +# --------------------------------------------------------------------------- + + +class TestNemotronHTopLevelComponents: + """component_mapping has exactly the expected top-level keys.""" + + def test_required_keys(self, adapter: NemotronHArchitectureAdapter) -> None: + assert 
set(adapter.component_mapping.keys()) == {"embed", "blocks", "ln_final", "unembed"} + + def test_no_rotary_emb_key(self, adapter: NemotronHArchitectureAdapter) -> None: + # No model-level rotary; attention handles RoPE via position_ids. + assert "rotary_emb" not in adapter.component_mapping + + def test_embed_is_embedding_bridge(self, adapter: NemotronHArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "model.embeddings" + + def test_blocks_is_ssm_block_bridge(self, adapter: NemotronHArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], SSMBlockBridge) + + def test_blocks_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "model.layers" + + def test_ln_final_is_rms_normalization_bridge( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_ln_final_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["ln_final"].name == "model.norm_f" + + def test_unembed_is_unembedding_bridge(self, adapter: NemotronHArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + +# --------------------------------------------------------------------------- +# Block-level submodules +# --------------------------------------------------------------------------- + + +class TestNemotronHBlockSubmodules: + """SSMBlockBridge submodules have correct types and HF path names.""" + + @pytest.fixture(scope="class") + def blocks(self, adapter: NemotronHArchitectureAdapter) -> SSMBlockBridge: + 
return adapter.component_mapping["blocks"] + + def test_norm_is_rms_normalization_bridge(self, blocks: SSMBlockBridge) -> None: + assert isinstance(blocks.submodules["norm"], RMSNormalizationBridge) + + def test_norm_name(self, blocks: SSMBlockBridge) -> None: + assert blocks.submodules["norm"].name == "norm" + + def test_mixer_is_ssm2_mixer_bridge(self, blocks: SSMBlockBridge) -> None: + assert isinstance(blocks.submodules["mixer"], SSM2MixerBridge) + + def test_mixer_name(self, blocks: SSMBlockBridge) -> None: + assert blocks.submodules["mixer"].name == "mixer" + + def test_block_has_no_ln2(self, blocks: SSMBlockBridge) -> None: + # Single pre-norm architecture; no post-attention norm. + assert "ln2" not in blocks.submodules + + +# --------------------------------------------------------------------------- +# Mixer submodules (Mamba-specific, optional) +# --------------------------------------------------------------------------- + + +class TestNemotronHMixerSubmodules: + """SSM2MixerBridge submodules are Mamba-specific and optional.""" + + @pytest.fixture(scope="class") + def mixer(self, adapter: NemotronHArchitectureAdapter) -> SSM2MixerBridge: + return adapter.component_mapping["blocks"].submodules["mixer"] + + def test_in_proj_is_linear_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["in_proj"], LinearBridge) + + def test_in_proj_name(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["in_proj"].name == "in_proj" + + def test_in_proj_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["in_proj"].optional is True + + def test_conv1d_is_depthwise_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["conv1d"], DepthwiseConv1DBridge) + + def test_conv1d_name(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["conv1d"].name == "conv1d" + + def test_conv1d_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["conv1d"].optional is 
True + + def test_inner_norm_is_gated_rms_norm_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["inner_norm"], GatedRMSNormBridge) + + def test_inner_norm_name(self, mixer: SSM2MixerBridge) -> None: + # HF calls it "norm" inside the mixer; TL aliases to "inner_norm". + assert mixer.submodules["inner_norm"].name == "norm" + + def test_inner_norm_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["inner_norm"].optional is True + + def test_out_proj_is_linear_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["out_proj"], LinearBridge) + + def test_out_proj_name(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["out_proj"].name == "out_proj" + + def test_out_proj_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["out_proj"].optional is True + + def test_mixer_has_exactly_four_mamba_submodules(self, mixer: SSM2MixerBridge) -> None: + assert set(mixer.submodules.keys()) == {"in_proj", "conv1d", "inner_norm", "out_proj"} + + +# --------------------------------------------------------------------------- +# create_stateful_cache +# --------------------------------------------------------------------------- + + +class TestNemotronHStatefulCache: + """create_stateful_cache returns a DynamicCache instance.""" + + def test_returns_dynamic_cache(self, adapter: NemotronHArchitectureAdapter) -> None: + from transformers.cache_utils import DynamicCache + + hf_model = MagicMock() + cache = adapter.create_stateful_cache( + hf_model=hf_model, batch_size=1, device="cpu", dtype=None + ) + assert isinstance(cache, DynamicCache) + + def test_cache_independent_per_call(self, adapter: NemotronHArchitectureAdapter) -> None: + """Each call returns a fresh cache object.""" + hf_model = MagicMock() + c1 = adapter.create_stateful_cache(hf_model, 1, "cpu", None) + c2 = adapter.create_stateful_cache(hf_model, 1, "cpu", None) + assert c1 is not c2 + + +# 
--------------------------------------------------------------------------- +# Factory registration +# --------------------------------------------------------------------------- + + +class TestNemotronHFactoryRegistration: + """NemotronHForCausalLM is registered in the adapter factory.""" + + def test_factory_returns_nemotron_h_adapter(self) -> None: + cfg = _make_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, NemotronHArchitectureAdapter) + + def test_architecture_key_present(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import SUPPORTED_ARCHITECTURES + + assert "NemotronHForCausalLM" in SUPPORTED_ARCHITECTURES + assert SUPPORTED_ARCHITECTURES["NemotronHForCausalLM"] is NemotronHArchitectureAdapter + + +# --------------------------------------------------------------------------- +# Model registry +# --------------------------------------------------------------------------- + + +class TestNemotronHModelRegistry: + """NemotronHForCausalLM is listed in the model registry constants.""" + + def test_in_hf_supported_architectures(self) -> None: + from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES + + assert "NemotronHForCausalLM" in HF_SUPPORTED_ARCHITECTURES + + def test_canonical_author_is_nvidia(self) -> None: + from transformer_lens.tools.model_registry import CANONICAL_AUTHORS_BY_ARCH + + assert CANONICAL_AUTHORS_BY_ARCH.get("NemotronHForCausalLM") == ["nvidia"] + + +# --------------------------------------------------------------------------- +# Guard tests +# --------------------------------------------------------------------------- + + +class TestNemotronHGuards: + """Guards against drift toward neighbouring adapter patterns.""" + + def test_uses_ssm_block_bridge_not_block_bridge( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + # BlockBridge enforces ln2; NemotronH is single-norm. Must use SSMBlockBridge. 
+ from transformer_lens.model_bridge.generalized_components import BlockBridge + + blocks = adapter.component_mapping["blocks"] + assert not isinstance(blocks, BlockBridge) + assert isinstance(blocks, SSMBlockBridge) + + def test_no_weight_conversions_defined(self, adapter: NemotronHArchitectureAdapter) -> None: + # Unlike attention adapters, NemotronH has no rearrange/split conversions. + assert len(adapter.weight_processing_conversions) == 0 + + def test_cfg_is_not_attn_only(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_mamba_intermediate_size_formula(self) -> None: + """Verify formula: intermediate = mamba_num_heads * mamba_head_dim.""" + cfg = _make_cfg(mamba_num_heads=16, mamba_head_dim=32) + a = NemotronHArchitectureAdapter(cfg) + assert getattr(a.cfg, "mamba_intermediate_size") == 16 * 32 + + def test_conv_dim_formula(self) -> None: + """Verify formula: conv_dim = intermediate + 2 * n_groups * ssm_state_size.""" + cfg = _make_cfg(mamba_num_heads=8, mamba_head_dim=16, n_groups=4, ssm_state_size=8) + a = NemotronHArchitectureAdapter(cfg) + expected = 8 * 16 + 2 * 4 * 8 # 128 + 64 = 192 + assert getattr(a.cfg, "conv_dim") == expected diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 8869a7be1..13ef42b14 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -40,6 +40,7 @@ LlavaOnevisionArchitectureAdapter, Mamba2ArchitectureAdapter, MambaArchitectureAdapter, + NemotronHArchitectureAdapter, MingptArchitectureAdapter, MistralArchitectureAdapter, MixtralArchitectureAdapter, @@ -105,6 +106,7 @@ "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter, "Mamba2ForCausalLM": Mamba2ArchitectureAdapter, "MambaForCausalLM": MambaArchitectureAdapter, + "NemotronHForCausalLM": NemotronHArchitectureAdapter, 
"MixtralForCausalLM": MixtralArchitectureAdapter, "MistralForCausalLM": MistralArchitectureAdapter, "MPTForCausalLM": MPTArchitectureAdapter, diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 0d76a7f76..101a2aede 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -93,6 +93,9 @@ from transformer_lens.model_bridge.supported_architectures.mamba2 import ( Mamba2ArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.nemotron_h import ( + NemotronHArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.mingpt import ( MingptArchitectureAdapter, ) @@ -214,6 +217,7 @@ "MixtralArchitectureAdapter", "MPTArchitectureAdapter", "NanogptArchitectureAdapter", + "NemotronHArchitectureAdapter", "NativeArchitectureAdapter", "NeelSoluOldArchitectureAdapter", "NeoArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/nemotron_h.py b/transformer_lens/model_bridge/supported_architectures/nemotron_h.py new file mode 100644 index 000000000..594b92fc4 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/nemotron_h.py @@ -0,0 +1,155 @@ +"""Nemotron-H hybrid Mamba2-Transformer architecture adapter. + +Supports NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B). + +Architecture overview: +- Heterogeneous layers defined by ``config.layers_block_type`` — each element is + one of ``"mamba"``, ``"attention"``, ``"moe"``, or ``"mlp"``. +- ~8% of layers are standard GQA attention; the rest are Mamba-2 SSM, dense MLP, + or sparse MoE. All share a single pre-norm (``block.norm``) and a single residual + path; there is no ``ln2`` or post-attention norm. +- Each block exposes a single ``.mixer`` attribute whose type varies by layer. 
+- No model-level rotary embedding module — attention handles RoPE internally via + ``position_ids`` passed from the outer model loop. +- Stateful generation: uses ``DynamicCache`` (transformers ≥ 5.12) which carries + both KV-cache entries (attention layers) and SSM conv/recurrent states + (Mamba layers) in a unified object. + +Key adapter decisions: +- ``SSMBlockBridge`` is used as the block container. It delegates the entire + forward to the HF block, giving ``hook_in`` / ``hook_out`` on the residual + stream without hardcoding transformer-specific hook positions (hook_resid_mid, + hook_mlp_in, etc.) that do not exist in this single-norm architecture. +- ``SSM2MixerBridge`` wraps ``.mixer`` for all layer types. Its forward is a + pure passthrough (``original_component(*args, **kwargs)``) so it works + correctly for attention, MLP, and MoE mixers as well as Mamba ones. + Mamba-specific inner submodules (in_proj, conv1d, inner_norm, out_proj) are + declared ``optional=True`` so setup skips them gracefully on non-Mamba layers. +- MLP layers use ``relu2`` activation (not SwiGLU); ``gated_mlp = False``. +- ``applicable_phases = []``: ``verify_models`` is transformer-shaped and would + require a dedicated refactor to cover SSM hybrids. Coverage lives in the + integration test instead. +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + DepthwiseConv1DBridge, + EmbeddingBridge, + GatedRMSNormBridge, + LinearBridge, + RMSNormalizationBridge, + SSM2MixerBridge, + SSMBlockBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import GeneralizedComponent + + +def _make_optional(component: "GeneralizedComponent") -> "GeneralizedComponent": + """Mark a GeneralizedComponent submodule as optional. + + Some bridge classes (e.g. 
GatedRMSNormBridge) do not forward ``optional`` + through their own ``__init__``, even though ``GeneralizedComponent`` supports + it. Setting the attribute directly is safe because ``component_setup.py`` + reads ``getattr(submodule, 'optional', False)`` at setup time. + """ + component.optional = True + return component + + +class NemotronHArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for NemotronHForCausalLM. + + Hybrid Mamba-2 + Attention + MoE + dense MLP model. All layers share a + single pre-norm and a single residual connection; the mixer type per layer + is determined by ``config.layers_block_type[layer_idx]``. + """ + + # verify_models is transformer-shaped and requires a dedicated refactor to + # cover SSM hybrids. Integration tests cover forward-pass correctness instead. + applicable_phases: list[int] = [] + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "RMS" + self.cfg.uses_rms_norm = True + # No model-level rotary embedding module — attention handles RoPE + # internally via position_ids; set to "none" so the bridge does not + # attempt to wire a rotary_emb component. + self.cfg.positional_embedding_type = "none" + # MLP layers use relu2 (up_proj → act → down_proj), not SwiGLU. + self.cfg.gated_mlp = False + self.cfg.attn_only = False + self.cfg.final_rms = True + # Mamba layers require per-step SSM state; generation is stateful. + self.cfg.is_stateful = True + + # Expose the heterogeneous layer-type list so tests and analysis tools + # can inspect which layers are which without loading a full HF model. + layers_block_type = getattr(cfg, "layers_block_type", []) + setattr(self.cfg, "layers_block_type", layers_block_type) + + # Mamba-2 dimensional config (mirrors Mamba2ArchitectureAdapter). 
+ mamba_num_heads = getattr(cfg, "mamba_num_heads", 128) + mamba_head_dim = getattr(cfg, "mamba_head_dim", 64) + mamba_intermediate_size = mamba_num_heads * mamba_head_dim + n_groups = getattr(cfg, "n_groups", 8) + ssm_state_size = getattr(cfg, "ssm_state_size", 128) + conv_dim = mamba_intermediate_size + 2 * n_groups * ssm_state_size + setattr(self.cfg, "mamba_intermediate_size", mamba_intermediate_size) + setattr(self.cfg, "conv_dim", conv_dim) + + self.weight_processing_conversions = {} + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embeddings"), + "blocks": SSMBlockBridge( + name="model.layers", + submodules={ + # Single pre-norm shared across all layer types. + "norm": RMSNormalizationBridge(name="norm", config=self.cfg), + # Single mixer slot — type varies per layer (mamba / attention + # / moe / mlp). SSM2MixerBridge.forward() is a pure + # passthrough so it works for all four types. Mamba-specific + # inner submodules are optional and skipped on other types. + "mixer": SSM2MixerBridge( + name="mixer", + config=self.cfg, + submodules={ + # ── Mamba-only (optional on attention / moe / mlp) ── + "in_proj": LinearBridge(name="in_proj", optional=True), + "conv1d": DepthwiseConv1DBridge(name="conv1d", optional=True), + # HF names this "norm" inside the mixer; TL calls it + # "inner_norm" to avoid collision with the block-level norm. + # GatedRMSNormBridge.__init__ does not accept optional=, so + # we set the attribute directly after construction. + "inner_norm": _make_optional(GatedRMSNormBridge(name="norm")), + "out_proj": LinearBridge(name="out_proj", optional=True), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm_f", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def create_stateful_cache( + self, + hf_model: Any, + batch_size: int, + device: Any, + dtype: Any, + ) -> Any: + """Build the unified DynamicCache for stateful generation. 
+ + Transformers ≥ 5.12 ships a unified ``DynamicCache`` that carries both + KV-cache entries (attention layers) and SSM conv/recurrent states + (Mamba layers) in a single object, using ``has_previous_state()`` to + distinguish which state is available for a given layer index. + """ + from transformers.cache_utils import DynamicCache + + return DynamicCache() diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index e08aa7cba..91535997b 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -79,6 +79,7 @@ "LlavaOnevisionForConditionalGeneration", "MambaForCausalLM", "Mamba2ForCausalLM", + "NemotronHForCausalLM", "MPTForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", @@ -139,6 +140,7 @@ "LlavaOnevisionForConditionalGeneration": ["llava-hf"], "Mamba2ForCausalLM": ["state-spaces"], "MambaForCausalLM": ["state-spaces"], + "NemotronHForCausalLM": ["nvidia"], "MistralForCausalLM": ["mistralai"], "MixtralForCausalLM": ["mistralai"], "MPTForCausalLM": ["mosaicml"], From 1591b59dcfb1942f626080e1f999eb49d91b96bf Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Mon, 22 Jun 2026 23:56:22 +0100 Subject: [PATCH 5/9] style: apply black + isort formatting fixes --- .../test_nemotron_h_adapter.py | 12 +++++++++--- .../supported_architectures/nemotron_h.py | 4 +++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py b/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py index 15866b047..d0b4d05d2 100644 --- a/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py @@ -10,7 +10,9 @@ import pytest from transformer_lens.config import TransformerBridgeConfig -from transformer_lens.factories.architecture_adapter_factory import 
ArchitectureAdapterFactory +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) from transformer_lens.model_bridge.generalized_components import ( DepthwiseConv1DBridge, EmbeddingBridge, @@ -104,7 +106,9 @@ def test_final_rms_true(self, adapter: NemotronHArchitectureAdapter) -> None: def test_is_stateful_true(self, adapter: NemotronHArchitectureAdapter) -> None: assert adapter.cfg.is_stateful is True - def test_mamba_intermediate_size_propagated(self, adapter: NemotronHArchitectureAdapter) -> None: + def test_mamba_intermediate_size_propagated( + self, adapter: NemotronHArchitectureAdapter + ) -> None: # mamba_num_heads=4, mamba_head_dim=8 → 32 assert getattr(adapter.cfg, "mamba_intermediate_size", None) == 32 @@ -296,7 +300,9 @@ def test_factory_returns_nemotron_h_adapter(self) -> None: assert isinstance(adapter, NemotronHArchitectureAdapter) def test_architecture_key_present(self) -> None: - from transformer_lens.factories.architecture_adapter_factory import SUPPORTED_ARCHITECTURES + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) assert "NemotronHForCausalLM" in SUPPORTED_ARCHITECTURES assert SUPPORTED_ARCHITECTURES["NemotronHForCausalLM"] is NemotronHArchitectureAdapter diff --git a/transformer_lens/model_bridge/supported_architectures/nemotron_h.py b/transformer_lens/model_bridge/supported_architectures/nemotron_h.py index 594b92fc4..2f322fa8d 100644 --- a/transformer_lens/model_bridge/supported_architectures/nemotron_h.py +++ b/transformer_lens/model_bridge/supported_architectures/nemotron_h.py @@ -44,7 +44,9 @@ SSMBlockBridge, UnembeddingBridge, ) -from transformer_lens.model_bridge.generalized_components.base import GeneralizedComponent +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) def _make_optional(component: "GeneralizedComponent") -> "GeneralizedComponent": From 
9f5f0a80cb30b2c928ddac4c79c3d565a4344a2a Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Tue, 23 Jun 2026 00:02:38 +0100 Subject: [PATCH 6/9] style: fix isort order for NemotronHArchitectureAdapter import --- transformer_lens/factories/architecture_adapter_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 13ef42b14..bba458a77 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -40,7 +40,6 @@ LlavaOnevisionArchitectureAdapter, Mamba2ArchitectureAdapter, MambaArchitectureAdapter, - NemotronHArchitectureAdapter, MingptArchitectureAdapter, MistralArchitectureAdapter, MixtralArchitectureAdapter, @@ -48,6 +47,7 @@ NanogptArchitectureAdapter, NativeArchitectureAdapter, NeelSoluOldArchitectureAdapter, + NemotronHArchitectureAdapter, NeoArchitectureAdapter, NeoxArchitectureAdapter, Olmo2ArchitectureAdapter, From 0a750c7fc0aa65f1cd8b0875ee79f0b54db38f01 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Tue, 23 Jun 2026 16:24:03 +0000 Subject: [PATCH 7/9] test: add NemotronH forward-pass and generation parity integration tests --- .../model_bridge/test_nemotron_h_adapter.py | 258 ++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 tests/integration/model_bridge/test_nemotron_h_adapter.py diff --git a/tests/integration/model_bridge/test_nemotron_h_adapter.py b/tests/integration/model_bridge/test_nemotron_h_adapter.py new file mode 100644 index 000000000..e1bb9fa94 --- /dev/null +++ b/tests/integration/model_bridge/test_nemotron_h_adapter.py @@ -0,0 +1,258 @@ +"""Integration tests for the NemotronH architecture adapter. 
+ +Verifies forward-pass and generation parity against nvidia/Nemotron-H-8B-Base: +- Forward-pass logits match HF exactly (bridge delegates the full forward to HF) +- Greedy multi-token generation matches HF bit-for-bit (exercises DynamicCache + state handling across attention, Mamba-2, MLP, and MoE layers) +- Sanity checks: config flags, block count, hook coverage + +Note: requires ~18 GB RAM (CPU) or ~16 GB VRAM (GPU) to load the 8B checkpoint. +On a machine with less memory, skip with: + pytest -m "not slow" tests/integration/model_bridge/test_nemotron_h_adapter.py + +Run with GPU acceleration: + CUDA_VISIBLE_DEVICES=0 pytest tests/integration/model_bridge/test_nemotron_h_adapter.py -v -s +""" + +import gc + +import pytest +import torch + +from transformer_lens.model_bridge.bridge import TransformerBridge +from transformer_lens.model_bridge.generalized_components import ( + SSM2MixerBridge, + SSMBlockBridge, +) + +MODEL = "nvidia/Nemotron-H-8B-Base" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +def _dtype() -> torch.dtype: + # bfloat16 on GPU to match HF defaults; float32 on CPU for numerical safety + return torch.bfloat16 if torch.cuda.is_available() else torch.float32 + + +# --------------------------------------------------------------------------- +# Module-scoped fixture — load once, share across all test classes +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def nemotron_bridge(): + device = _device() + dtype = _dtype() + bridge = TransformerBridge.boot_transformers(MODEL, device=device, dtype=dtype) + yield bridge + # Cleanup + del bridge + if torch.cuda.is_available(): + torch.cuda.empty_cache() + for _ in range(3): + gc.collect() + + +# 
--------------------------------------------------------------------------- +# Config and bridge structure +# --------------------------------------------------------------------------- + + +class TestNemotronHBridgeCreation: + """Smoke-test that the bridge loads with the right config flags.""" + + def test_config_flags(self, nemotron_bridge: TransformerBridge) -> None: + cfg = nemotron_bridge.cfg + assert cfg.normalization_type == "RMS" + assert cfg.uses_rms_norm is True + assert cfg.positional_embedding_type == "none" + assert cfg.gated_mlp is False + assert cfg.is_stateful is True + + def test_block_count(self, nemotron_bridge: TransformerBridge) -> None: + # Nemotron-H-8B has 56 layers + assert len(nemotron_bridge.blocks) == 56 + + def test_blocks_are_ssm_block_bridge(self, nemotron_bridge: TransformerBridge) -> None: + assert isinstance(nemotron_bridge.blocks[0], SSMBlockBridge) + + def test_mixer_is_ssm2_mixer_bridge(self, nemotron_bridge: TransformerBridge) -> None: + assert isinstance(nemotron_bridge.blocks[0].mixer, SSM2MixerBridge) + + def test_layers_block_type_populated(self, nemotron_bridge: TransformerBridge) -> None: + lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) + assert len(lbt) == len(nemotron_bridge.blocks) + # Should contain at least one attention and one mamba layer + assert "attention" in lbt + assert "mamba" in lbt + + def test_mamba_intermediate_size_positive(self, nemotron_bridge: TransformerBridge) -> None: + assert getattr(nemotron_bridge.cfg, "mamba_intermediate_size", 0) > 0 + + def test_conv_dim_positive(self, nemotron_bridge: TransformerBridge) -> None: + assert getattr(nemotron_bridge.cfg, "conv_dim", 0) > 0 + + +# --------------------------------------------------------------------------- +# Forward-pass parity +# --------------------------------------------------------------------------- + + +class TestNemotronHForwardPass: + """Bridge logits must match HF logits exactly. 
+ + NemotronHArchitectureAdapter uses SSM2MixerBridge with a pure passthrough + forward (original_component(*args, **kwargs)), so the bridge never + reimplements any computation. Parity with HF should be exact (diff == 0), + not just close. + """ + + @pytest.fixture(scope="class") + def tokens(self) -> torch.Tensor: + return torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + + def test_forward_returns_logits( + self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor + ) -> None: + tokens = tokens.to(_device()) + with torch.no_grad(): + out = nemotron_bridge(tokens) + assert out.shape == (1, 8, nemotron_bridge.cfg.d_vocab) + assert not torch.isnan(out).any(), "NaN in bridge logits" + assert not torch.isinf(out).any(), "Inf in bridge logits" + + def test_forward_matches_hf_exactly( + self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor + ) -> None: + tokens = tokens.to(_device()) + hf_model = nemotron_bridge.original_model + with torch.no_grad(): + bridge_out = nemotron_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out.float() - hf_out.float()).abs().max().item() + assert max_diff == 0.0, ( + f"Bridge vs HF forward max diff = {max_diff:.2e}. " + "Expected 0 because SSM2MixerBridge.forward() is a pure passthrough." + ) + + def test_forward_no_nan_on_longer_sequence( + self, nemotron_bridge: TransformerBridge + ) -> None: + # Exercise more SSM steps to catch state accumulation issues + tokens = torch.arange(1, 33).unsqueeze(0).to(_device()) + with torch.no_grad(): + out = nemotron_bridge(tokens) + assert not torch.isnan(out).any(), "NaN in logits for 32-token sequence" + + +# --------------------------------------------------------------------------- +# Multi-token generation parity (exercises DynamicCache state handling) +# --------------------------------------------------------------------------- + + +class TestNemotronHGeneration: + """Bridge greedy generation must match HF native generate() exactly. 
+ + This exercises the DynamicCache stateful loop: attention layers write KV + entries, Mamba-2 layers write recurrent SSM states, all via the same + unified cache object. Token-level equality with HF confirms the state + threading is correct across all four layer types (mamba / attention / + moe / mlp). + """ + + @pytest.fixture(scope="class") + def prompt(self) -> torch.Tensor: + return torch.tensor([[1, 2, 3, 4]]) + + def test_generation_produces_tokens( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + prompt = prompt.to(_device()) + with torch.no_grad(): + result = nemotron_bridge.generate(prompt, max_new_tokens=5, do_sample=False) + assert isinstance(result, torch.Tensor) + assert result.shape == (1, 9) # 4 prompt + 5 new + + def test_greedy_matches_hf_exactly( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + """Bit-for-bit equality with HF generate() over 8 new tokens.""" + prompt = prompt.to(_device()) + hf_model = nemotron_bridge.original_model + with torch.no_grad(): + bridge_out = nemotron_bridge.generate( + prompt, max_new_tokens=8, do_sample=False + ) + hf_out = hf_model.generate( + prompt, max_new_tokens=8, do_sample=False, pad_token_id=0 + ) + assert torch.equal(bridge_out, hf_out), ( + f"Token mismatch between bridge and HF.\n" + f" bridge : {bridge_out.tolist()}\n" + f" hf : {hf_out.tolist()}\n" + "DynamicCache state threading across layer types is likely wrong." 
+ ) + + def test_generation_is_deterministic( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + """Two identical greedy calls must produce identical tokens.""" + prompt = prompt.to(_device()) + with torch.no_grad(): + out1 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) + out2 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) + assert torch.equal(out1, out2), "Greedy generation is not deterministic" + + +# --------------------------------------------------------------------------- +# Hook coverage: bridge hooks fire for both Mamba and attention layers +# --------------------------------------------------------------------------- + + +class TestNemotronHHookCoverage: + """run_with_cache captures residual stream and mixer hooks on all layer types.""" + + @pytest.fixture(scope="class") + def cache(self, nemotron_bridge: TransformerBridge): + tokens = torch.tensor([[1, 2, 3, 4, 5]]).to(_device()) + with torch.no_grad(): + _, cache = nemotron_bridge.run_with_cache(tokens) + return cache + + def test_block_hooks_fire(self, cache, nemotron_bridge: TransformerBridge) -> None: + for i in [0, 28, 55]: + assert f"blocks.{i}.hook_in" in cache, f"Missing hook_in for block {i}" + assert f"blocks.{i}.hook_out" in cache, f"Missing hook_out for block {i}" + + def test_mamba_mixer_submodule_hooks_fire( + self, cache, nemotron_bridge: TransformerBridge + ) -> None: + """Mamba layers should expose in_proj / conv1d / out_proj hooks.""" + lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) + mamba_indices = [i for i, t in enumerate(lbt) if t == "mamba"] + assert mamba_indices, "No mamba layers found in layers_block_type" + # Check a few mamba layers + for i in mamba_indices[:3]: + for submod in ("in_proj", "conv1d", "out_proj"): + key_in = f"blocks.{i}.mixer.{submod}.hook_in" + key_out = f"blocks.{i}.mixer.{submod}.hook_out" + assert key_in in cache, f"Missing {key_in}" + assert key_out in cache, f"Missing 
{key_out}" + + def test_no_transformer_specific_hooks(self, cache) -> None: + """SSMBlockBridge must not inject transformer-shaped hook names.""" + forbidden = ("hook_resid_mid", "hook_attn_out", "hook_mlp_out") + bad = [k for k in cache if any(f in k for f in forbidden)] + assert bad == [], f"Unexpected transformer-shaped hooks: {bad[:5]}" + + def test_no_nan_in_cache(self, cache) -> None: + for key, val in cache.items(): + if isinstance(val, torch.Tensor) and val.is_floating_point(): + assert not torch.isnan(val).any(), f"NaN in cache['{key}']" From df7bdd3056c9286eef0c4b997ebe6e1eae917fbc Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Tue, 23 Jun 2026 17:27:00 +0100 Subject: [PATCH 8/9] test: remove NemotronH forward-pass and generation parity integration tests --- .../model_bridge/test_nemotron_h_adapter.py | 258 ------------------ 1 file changed, 258 deletions(-) delete mode 100644 tests/integration/model_bridge/test_nemotron_h_adapter.py diff --git a/tests/integration/model_bridge/test_nemotron_h_adapter.py b/tests/integration/model_bridge/test_nemotron_h_adapter.py deleted file mode 100644 index e1bb9fa94..000000000 --- a/tests/integration/model_bridge/test_nemotron_h_adapter.py +++ /dev/null @@ -1,258 +0,0 @@ -"""Integration tests for the NemotronH architecture adapter. - -Verifies forward-pass and generation parity against nvidia/Nemotron-H-8B-Base: -- Forward-pass logits match HF exactly (bridge delegates the full forward to HF) -- Greedy multi-token generation matches HF bit-for-bit (exercises DynamicCache - state handling across attention, Mamba-2, MLP, and MoE layers) -- Sanity checks: config flags, block count, hook coverage - -Note: requires ~18 GB RAM (CPU) or ~16 GB VRAM (GPU) to load the 8B checkpoint.
-On a machine with less memory, skip with: - pytest -m "not slow" tests/integration/model_bridge/test_nemotron_h_adapter.py - -Run with GPU acceleration: - CUDA_VISIBLE_DEVICES=0 pytest tests/integration/model_bridge/test_nemotron_h_adapter.py -v -s -""" - -import gc - -import pytest -import torch - -from transformer_lens.model_bridge.bridge import TransformerBridge -from transformer_lens.model_bridge.generalized_components import ( - SSM2MixerBridge, - SSMBlockBridge, -) - -MODEL = "nvidia/Nemotron-H-8B-Base" - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _device() -> str: - return "cuda" if torch.cuda.is_available() else "cpu" - - -def _dtype() -> torch.dtype: - # bfloat16 on GPU to match HF defaults; float32 on CPU for numerical safety - return torch.bfloat16 if torch.cuda.is_available() else torch.float32 - - -# --------------------------------------------------------------------------- -# Session fixture — load once, share across all test classes -# --------------------------------------------------------------------------- - - -@pytest.fixture(scope="module") -def nemotron_bridge(): - device = _device() - dtype = _dtype() - bridge = TransformerBridge.boot_transformers(MODEL, device=device, dtype=dtype) - yield bridge - # Cleanup - del bridge - if torch.cuda.is_available(): - torch.cuda.empty_cache() - for _ in range(3): - gc.collect() - - -# --------------------------------------------------------------------------- -# Config and bridge structure -# --------------------------------------------------------------------------- - - -class TestNemotronHBridgeCreation: - """Smoke-test that the bridge loads with the right config flags.""" - - def test_config_flags(self, nemotron_bridge: TransformerBridge) -> None: - cfg = nemotron_bridge.cfg - assert cfg.normalization_type == "RMS" - assert cfg.uses_rms_norm is True - assert 
cfg.positional_embedding_type == "none" - assert cfg.gated_mlp is False - assert cfg.is_stateful is True - - def test_block_count(self, nemotron_bridge: TransformerBridge) -> None: - # Nemotron-H-8B has 56 layers - assert len(nemotron_bridge.blocks) == 56 - - def test_blocks_are_ssm_block_bridge(self, nemotron_bridge: TransformerBridge) -> None: - assert isinstance(nemotron_bridge.blocks[0], SSMBlockBridge) - - def test_mixer_is_ssm2_mixer_bridge(self, nemotron_bridge: TransformerBridge) -> None: - assert isinstance(nemotron_bridge.blocks[0].mixer, SSM2MixerBridge) - - def test_layers_block_type_populated(self, nemotron_bridge: TransformerBridge) -> None: - lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) - assert len(lbt) == len(nemotron_bridge.blocks) - # Should contain at least one attention and one mamba layer - assert "attention" in lbt - assert "mamba" in lbt - - def test_mamba_intermediate_size_positive(self, nemotron_bridge: TransformerBridge) -> None: - assert getattr(nemotron_bridge.cfg, "mamba_intermediate_size", 0) > 0 - - def test_conv_dim_positive(self, nemotron_bridge: TransformerBridge) -> None: - assert getattr(nemotron_bridge.cfg, "conv_dim", 0) > 0 - - -# --------------------------------------------------------------------------- -# Forward-pass parity -# --------------------------------------------------------------------------- - - -class TestNemotronHForwardPass: - """Bridge logits must match HF logits exactly. - - NemotronHArchitectureAdapter uses SSM2MixerBridge with a pure passthrough - forward (original_component(*args, **kwargs)), so the bridge never - reimplements any computation. Parity with HF should be exact (diff == 0), - not just close. 
- """ - - @pytest.fixture(scope="class") - def tokens(self) -> torch.Tensor: - return torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) - - def test_forward_returns_logits( - self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor - ) -> None: - tokens = tokens.to(_device()) - with torch.no_grad(): - out = nemotron_bridge(tokens) - assert out.shape == (1, 8, nemotron_bridge.cfg.d_vocab) - assert not torch.isnan(out).any(), "NaN in bridge logits" - assert not torch.isinf(out).any(), "Inf in bridge logits" - - def test_forward_matches_hf_exactly( - self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor - ) -> None: - tokens = tokens.to(_device()) - hf_model = nemotron_bridge.original_model - with torch.no_grad(): - bridge_out = nemotron_bridge(tokens) - hf_out = hf_model(tokens).logits - max_diff = (bridge_out.float() - hf_out.float()).abs().max().item() - assert max_diff == 0.0, ( - f"Bridge vs HF forward max diff = {max_diff:.2e}. " - "Expected 0 because SSM2MixerBridge.forward() is a pure passthrough." - ) - - def test_forward_no_nan_on_longer_sequence( - self, nemotron_bridge: TransformerBridge - ) -> None: - # Exercise more SSM steps to catch state accumulation issues - tokens = torch.arange(1, 33).unsqueeze(0).to(_device()) - with torch.no_grad(): - out = nemotron_bridge(tokens) - assert not torch.isnan(out).any(), "NaN in logits for 32-token sequence" - - -# --------------------------------------------------------------------------- -# Multi-token generation parity (exercises DynamicCache state handling) -# --------------------------------------------------------------------------- - - -class TestNemotronHGeneration: - """Bridge greedy generation must match HF native generate() exactly. - - This exercises the DynamicCache stateful loop: attention layers write KV - entries, Mamba-2 layers write recurrent SSM states, all via the same - unified cache object. 
Token-level equality with HF confirms the state - threading is correct across all four layer types (mamba / attention / - moe / mlp). - """ - - @pytest.fixture(scope="class") - def prompt(self) -> torch.Tensor: - return torch.tensor([[1, 2, 3, 4]]) - - def test_generation_produces_tokens( - self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor - ) -> None: - prompt = prompt.to(_device()) - with torch.no_grad(): - result = nemotron_bridge.generate(prompt, max_new_tokens=5, do_sample=False) - assert isinstance(result, torch.Tensor) - assert result.shape == (1, 9) # 4 prompt + 5 new - - def test_greedy_matches_hf_exactly( - self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor - ) -> None: - """Bit-for-bit equality with HF generate() over 8 new tokens.""" - prompt = prompt.to(_device()) - hf_model = nemotron_bridge.original_model - with torch.no_grad(): - bridge_out = nemotron_bridge.generate( - prompt, max_new_tokens=8, do_sample=False - ) - hf_out = hf_model.generate( - prompt, max_new_tokens=8, do_sample=False, pad_token_id=0 - ) - assert torch.equal(bridge_out, hf_out), ( - f"Token mismatch between bridge and HF.\n" - f" bridge : {bridge_out.tolist()}\n" - f" hf : {hf_out.tolist()}\n" - "DynamicCache state threading across layer types is likely wrong." 
- ) - - def test_generation_is_deterministic( - self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor - ) -> None: - """Two identical greedy calls must produce identical tokens.""" - prompt = prompt.to(_device()) - with torch.no_grad(): - out1 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) - out2 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) - assert torch.equal(out1, out2), "Greedy generation is not deterministic" - - -# --------------------------------------------------------------------------- -# Hook coverage: bridge hooks fire for both Mamba and attention layers -# --------------------------------------------------------------------------- - - -class TestNemotronHHookCoverage: - """run_with_cache captures residual stream and mixer hooks on all layer types.""" - - @pytest.fixture(scope="class") - def cache(self, nemotron_bridge: TransformerBridge): - tokens = torch.tensor([[1, 2, 3, 4, 5]]).to(_device()) - with torch.no_grad(): - _, cache = nemotron_bridge.run_with_cache(tokens) - return cache - - def test_block_hooks_fire(self, cache, nemotron_bridge: TransformerBridge) -> None: - for i in [0, 28, 55]: - assert f"blocks.{i}.hook_in" in cache, f"Missing hook_in for block {i}" - assert f"blocks.{i}.hook_out" in cache, f"Missing hook_out for block {i}" - - def test_mamba_mixer_submodule_hooks_fire( - self, cache, nemotron_bridge: TransformerBridge - ) -> None: - """Mamba layers should expose in_proj / conv1d / out_proj hooks.""" - lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) - mamba_indices = [i for i, t in enumerate(lbt) if t == "mamba"] - assert mamba_indices, "No mamba layers found in layers_block_type" - # Check a few mamba layers - for i in mamba_indices[:3]: - for submod in ("in_proj", "conv1d", "out_proj"): - key_in = f"blocks.{i}.mixer.{submod}.hook_in" - key_out = f"blocks.{i}.mixer.{submod}.hook_out" - assert key_in in cache, f"Missing {key_in}" - assert key_out in cache, f"Missing 
{key_out}" - - def test_no_transformer_specific_hooks(self, cache) -> None: - """SSMBlockBridge must not inject transformer-shaped hook names.""" - forbidden = ("hook_resid_mid", "hook_attn_out", "hook_mlp_out") - bad = [k for k in cache if any(f in k for f in forbidden)] - assert bad == [], f"Unexpected transformer-shaped hooks: {bad[:5]}" - - def test_no_nan_in_cache(self, cache) -> None: - for key, val in cache.items(): - if isinstance(val, torch.Tensor) and val.is_floating_point(): - assert not torch.isnan(val).any(), f"NaN in cache['{key}']" From ffde25d2aab6f9ca4cfa0d9bbf8bbf3514a487f3 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Tue, 23 Jun 2026 17:27:27 +0100 Subject: [PATCH 9/9] test: add NemotronH forward-pass and generation parity integration tests --- .../model_bridge/test_nemotron_h_adapter.py | 258 ++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 tests/integration/model_bridge/test_nemotron_h_adapter.py diff --git a/tests/integration/model_bridge/test_nemotron_h_adapter.py b/tests/integration/model_bridge/test_nemotron_h_adapter.py new file mode 100644 index 000000000..e1bb9fa94 --- /dev/null +++ b/tests/integration/model_bridge/test_nemotron_h_adapter.py @@ -0,0 +1,258 @@ +"""Integration tests for the NemotronH architecture adapter. + +Verifies forward-pass and generation parity against nvidia/Nemotron-H-8B-Base: +- Forward-pass logits match HF exactly (bridge delegates the full forward to HF) +- Greedy multi-token generation matches HF bit-for-bit (exercises DynamicCache + state handling across attention, Mamba-2, MLP, and MoE layers) +- Sanity checks: config flags, block count, hook coverage + +Note: requires ~18 GB RAM (CPU) or ~16 GB VRAM (GPU) to load the 8B checkpoint. 
+On a machine with less memory, skip with: + pytest -m "not slow" tests/integration/model_bridge/test_nemotron_h_adapter.py + +Run with GPU acceleration: + CUDA_VISIBLE_DEVICES=0 pytest tests/integration/model_bridge/test_nemotron_h_adapter.py -v -s +""" + +import gc + +import pytest +import torch + +from transformer_lens.model_bridge.bridge import TransformerBridge +from transformer_lens.model_bridge.generalized_components import ( + SSM2MixerBridge, + SSMBlockBridge, +) + +MODEL = "nvidia/Nemotron-H-8B-Base" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +def _dtype() -> torch.dtype: + # bfloat16 on GPU to match HF defaults; float32 on CPU for numerical safety + return torch.bfloat16 if torch.cuda.is_available() else torch.float32 + + +# --------------------------------------------------------------------------- +# Session fixture — load once, share across all test classes +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def nemotron_bridge(): + device = _device() + dtype = _dtype() + bridge = TransformerBridge.boot_transformers(MODEL, device=device, dtype=dtype) + yield bridge + # Cleanup + del bridge + if torch.cuda.is_available(): + torch.cuda.empty_cache() + for _ in range(3): + gc.collect() + + +# --------------------------------------------------------------------------- +# Config and bridge structure +# --------------------------------------------------------------------------- + + +class TestNemotronHBridgeCreation: + """Smoke-test that the bridge loads with the right config flags.""" + + def test_config_flags(self, nemotron_bridge: TransformerBridge) -> None: + cfg = nemotron_bridge.cfg + assert cfg.normalization_type == "RMS" + assert cfg.uses_rms_norm is True + assert 
cfg.positional_embedding_type == "none" + assert cfg.gated_mlp is False + assert cfg.is_stateful is True + + def test_block_count(self, nemotron_bridge: TransformerBridge) -> None: + # Nemotron-H-8B has 56 layers + assert len(nemotron_bridge.blocks) == 56 + + def test_blocks_are_ssm_block_bridge(self, nemotron_bridge: TransformerBridge) -> None: + assert isinstance(nemotron_bridge.blocks[0], SSMBlockBridge) + + def test_mixer_is_ssm2_mixer_bridge(self, nemotron_bridge: TransformerBridge) -> None: + assert isinstance(nemotron_bridge.blocks[0].mixer, SSM2MixerBridge) + + def test_layers_block_type_populated(self, nemotron_bridge: TransformerBridge) -> None: + lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) + assert len(lbt) == len(nemotron_bridge.blocks) + # Should contain at least one attention and one mamba layer + assert "attention" in lbt + assert "mamba" in lbt + + def test_mamba_intermediate_size_positive(self, nemotron_bridge: TransformerBridge) -> None: + assert getattr(nemotron_bridge.cfg, "mamba_intermediate_size", 0) > 0 + + def test_conv_dim_positive(self, nemotron_bridge: TransformerBridge) -> None: + assert getattr(nemotron_bridge.cfg, "conv_dim", 0) > 0 + + +# --------------------------------------------------------------------------- +# Forward-pass parity +# --------------------------------------------------------------------------- + + +class TestNemotronHForwardPass: + """Bridge logits must match HF logits exactly. + + NemotronHArchitectureAdapter uses SSM2MixerBridge with a pure passthrough + forward (original_component(*args, **kwargs)), so the bridge never + reimplements any computation. Parity with HF should be exact (diff == 0), + not just close. 
+ """ + + @pytest.fixture(scope="class") + def tokens(self) -> torch.Tensor: + return torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + + def test_forward_returns_logits( + self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor + ) -> None: + tokens = tokens.to(_device()) + with torch.no_grad(): + out = nemotron_bridge(tokens) + assert out.shape == (1, 8, nemotron_bridge.cfg.d_vocab) + assert not torch.isnan(out).any(), "NaN in bridge logits" + assert not torch.isinf(out).any(), "Inf in bridge logits" + + def test_forward_matches_hf_exactly( + self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor + ) -> None: + tokens = tokens.to(_device()) + hf_model = nemotron_bridge.original_model + with torch.no_grad(): + bridge_out = nemotron_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out.float() - hf_out.float()).abs().max().item() + assert max_diff == 0.0, ( + f"Bridge vs HF forward max diff = {max_diff:.2e}. " + "Expected 0 because SSM2MixerBridge.forward() is a pure passthrough." + ) + + def test_forward_no_nan_on_longer_sequence( + self, nemotron_bridge: TransformerBridge + ) -> None: + # Exercise more SSM steps to catch state accumulation issues + tokens = torch.arange(1, 33).unsqueeze(0).to(_device()) + with torch.no_grad(): + out = nemotron_bridge(tokens) + assert not torch.isnan(out).any(), "NaN in logits for 32-token sequence" + + +# --------------------------------------------------------------------------- +# Multi-token generation parity (exercises DynamicCache state handling) +# --------------------------------------------------------------------------- + + +class TestNemotronHGeneration: + """Bridge greedy generation must match HF native generate() exactly. + + This exercises the DynamicCache stateful loop: attention layers write KV + entries, Mamba-2 layers write recurrent SSM states, all via the same + unified cache object. 
Token-level equality with HF confirms the state + threading is correct across all four layer types (mamba / attention / + moe / mlp). + """ + + @pytest.fixture(scope="class") + def prompt(self) -> torch.Tensor: + return torch.tensor([[1, 2, 3, 4]]) + + def test_generation_produces_tokens( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + prompt = prompt.to(_device()) + with torch.no_grad(): + result = nemotron_bridge.generate(prompt, max_new_tokens=5, do_sample=False) + assert isinstance(result, torch.Tensor) + assert result.shape == (1, 9) # 4 prompt + 5 new + + def test_greedy_matches_hf_exactly( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + """Bit-for-bit equality with HF generate() over 8 new tokens.""" + prompt = prompt.to(_device()) + hf_model = nemotron_bridge.original_model + with torch.no_grad(): + bridge_out = nemotron_bridge.generate( + prompt, max_new_tokens=8, do_sample=False + ) + hf_out = hf_model.generate( + prompt, max_new_tokens=8, do_sample=False, pad_token_id=0 + ) + assert torch.equal(bridge_out, hf_out), ( + f"Token mismatch between bridge and HF.\n" + f" bridge : {bridge_out.tolist()}\n" + f" hf : {hf_out.tolist()}\n" + "DynamicCache state threading across layer types is likely wrong." 
+ ) + + def test_generation_is_deterministic( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + """Two identical greedy calls must produce identical tokens.""" + prompt = prompt.to(_device()) + with torch.no_grad(): + out1 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) + out2 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) + assert torch.equal(out1, out2), "Greedy generation is not deterministic" + + +# --------------------------------------------------------------------------- +# Hook coverage: bridge hooks fire for both Mamba and attention layers +# --------------------------------------------------------------------------- + + +class TestNemotronHHookCoverage: + """run_with_cache captures residual stream and mixer hooks on all layer types.""" + + @pytest.fixture(scope="class") + def cache(self, nemotron_bridge: TransformerBridge): + tokens = torch.tensor([[1, 2, 3, 4, 5]]).to(_device()) + with torch.no_grad(): + _, cache = nemotron_bridge.run_with_cache(tokens) + return cache + + def test_block_hooks_fire(self, cache, nemotron_bridge: TransformerBridge) -> None: + for i in [0, 28, 55]: + assert f"blocks.{i}.hook_in" in cache, f"Missing hook_in for block {i}" + assert f"blocks.{i}.hook_out" in cache, f"Missing hook_out for block {i}" + + def test_mamba_mixer_submodule_hooks_fire( + self, cache, nemotron_bridge: TransformerBridge + ) -> None: + """Mamba layers should expose in_proj / conv1d / out_proj hooks.""" + lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) + mamba_indices = [i for i, t in enumerate(lbt) if t == "mamba"] + assert mamba_indices, "No mamba layers found in layers_block_type" + # Check a few mamba layers + for i in mamba_indices[:3]: + for submod in ("in_proj", "conv1d", "out_proj"): + key_in = f"blocks.{i}.mixer.{submod}.hook_in" + key_out = f"blocks.{i}.mixer.{submod}.hook_out" + assert key_in in cache, f"Missing {key_in}" + assert key_out in cache, f"Missing 
{key_out}" + + def test_no_transformer_specific_hooks(self, cache) -> None: + """SSMBlockBridge must not inject transformer-shaped hook names.""" + forbidden = ("hook_resid_mid", "hook_attn_out", "hook_mlp_out") + bad = [k for k in cache if any(f in k for f in forbidden)] + assert bad == [], f"Unexpected transformer-shaped hooks: {bad[:5]}" + + def test_no_nan_in_cache(self, cache) -> None: + for key, val in cache.items(): + if isinstance(val, torch.Tensor) and val.is_floating_point(): + assert not torch.isnan(val).any(), f"NaN in cache['{key}']"