diff --git a/tests/integration/model_bridge/test_nemotron_h_adapter.py b/tests/integration/model_bridge/test_nemotron_h_adapter.py new file mode 100644 index 000000000..e1bb9fa94 --- /dev/null +++ b/tests/integration/model_bridge/test_nemotron_h_adapter.py @@ -0,0 +1,258 @@ +"""Integration tests for the NemotronH architecture adapter. + +Verifies forward-pass and generation parity against nvidia/Nemotron-H-8B-Base: +- Forward-pass logits match HF exactly (bridge delegates the full forward to HF) +- Greedy multi-token generation matches HF bit-for-bit (exercises DynamicCache + state handling across attention, Mamba-2, MLP, and MoE layers) +- Sanity checks: config flags, block count, hook coverage + +Note: requires ~18 GB RAM (CPU) or ~16 GB VRAM (GPU) to load the 8B checkpoint. +On a machine with less memory, skip with: + pytest -m "not slow" tests/integration/model_bridge/test_nemotron_h_adapter.py + +Run with GPU acceleration: + CUDA_VISIBLE_DEVICES=0 pytest tests/integration/model_bridge/test_nemotron_h_adapter.py -v -s +""" + +import gc + +import pytest +import torch + +from transformer_lens.model_bridge.bridge import TransformerBridge +from transformer_lens.model_bridge.generalized_components import ( + SSM2MixerBridge, + SSMBlockBridge, +) + +MODEL = "nvidia/Nemotron-H-8B-Base" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +def _dtype() -> torch.dtype: + # bfloat16 on GPU to match HF defaults; float32 on CPU for numerical safety + return torch.bfloat16 if torch.cuda.is_available() else torch.float32 + + +# --------------------------------------------------------------------------- +# Session fixture — load once, share across all test classes +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def nemotron_bridge(): + device = _device() + dtype = _dtype() + bridge = TransformerBridge.boot_transformers(MODEL, device=device, dtype=dtype) + yield bridge + # Cleanup + del bridge + if torch.cuda.is_available(): + torch.cuda.empty_cache() + for _ in range(3): + gc.collect() + + +# --------------------------------------------------------------------------- +# Config and bridge structure +# --------------------------------------------------------------------------- + + +class TestNemotronHBridgeCreation: + """Smoke-test that the bridge loads with the right config flags.""" + + def test_config_flags(self, nemotron_bridge: TransformerBridge) -> None: + cfg = nemotron_bridge.cfg + assert cfg.normalization_type == "RMS" + assert cfg.uses_rms_norm is True + assert cfg.positional_embedding_type == "none" + assert cfg.gated_mlp is False + assert cfg.is_stateful is True + + def test_block_count(self, nemotron_bridge: TransformerBridge) -> None: + # Nemotron-H-8B has 56 layers + assert len(nemotron_bridge.blocks) == 56 + + def test_blocks_are_ssm_block_bridge(self, nemotron_bridge: TransformerBridge) -> None: + assert isinstance(nemotron_bridge.blocks[0], SSMBlockBridge) + + def test_mixer_is_ssm2_mixer_bridge(self, nemotron_bridge: TransformerBridge) -> None: + assert isinstance(nemotron_bridge.blocks[0].mixer, SSM2MixerBridge) + + def test_layers_block_type_populated(self, nemotron_bridge: TransformerBridge) -> None: + lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) + assert len(lbt) == len(nemotron_bridge.blocks) + # Should contain at least one attention and one mamba layer + assert "attention" in lbt + assert "mamba" in lbt + + def test_mamba_intermediate_size_positive(self, nemotron_bridge: TransformerBridge) -> None: + assert getattr(nemotron_bridge.cfg, "mamba_intermediate_size", 0) > 0 + + def test_conv_dim_positive(self, nemotron_bridge: TransformerBridge) -> None: + assert getattr(nemotron_bridge.cfg, "conv_dim", 0) > 0 + + +# --------------------------------------------------------------------------- +# Forward-pass parity +# --------------------------------------------------------------------------- + + +class TestNemotronHForwardPass: + """Bridge logits must match HF logits exactly. + + NemotronHArchitectureAdapter uses SSM2MixerBridge with a pure passthrough + forward (original_component(*args, **kwargs)), so the bridge never + reimplements any computation. Parity with HF should be exact (diff == 0), + not just close. + """ + + @pytest.fixture(scope="class") + def tokens(self) -> torch.Tensor: + return torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + + def test_forward_returns_logits( + self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor + ) -> None: + tokens = tokens.to(_device()) + with torch.no_grad(): + out = nemotron_bridge(tokens) + assert out.shape == (1, 8, nemotron_bridge.cfg.d_vocab) + assert not torch.isnan(out).any(), "NaN in bridge logits" + assert not torch.isinf(out).any(), "Inf in bridge logits" + + def test_forward_matches_hf_exactly( + self, nemotron_bridge: TransformerBridge, tokens: torch.Tensor + ) -> None: + tokens = tokens.to(_device()) + hf_model = nemotron_bridge.original_model + with torch.no_grad(): + bridge_out = nemotron_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out.float() - hf_out.float()).abs().max().item() + assert max_diff == 0.0, ( + f"Bridge vs HF forward max diff = {max_diff:.2e}. " + "Expected 0 because SSM2MixerBridge.forward() is a pure passthrough." + ) + + def test_forward_no_nan_on_longer_sequence( + self, nemotron_bridge: TransformerBridge + ) -> None: + # Exercise more SSM steps to catch state accumulation issues + tokens = torch.arange(1, 33).unsqueeze(0).to(_device()) + with torch.no_grad(): + out = nemotron_bridge(tokens) + assert not torch.isnan(out).any(), "NaN in logits for 32-token sequence" + + +# --------------------------------------------------------------------------- +# Multi-token generation parity (exercises DynamicCache state handling) +# --------------------------------------------------------------------------- + + +class TestNemotronHGeneration: + """Bridge greedy generation must match HF native generate() exactly. + + This exercises the DynamicCache stateful loop: attention layers write KV + entries, Mamba-2 layers write recurrent SSM states, all via the same + unified cache object. Token-level equality with HF confirms the state + threading is correct across all four layer types (mamba / attention / + moe / mlp). + """ + + @pytest.fixture(scope="class") + def prompt(self) -> torch.Tensor: + return torch.tensor([[1, 2, 3, 4]]) + + def test_generation_produces_tokens( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + prompt = prompt.to(_device()) + with torch.no_grad(): + result = nemotron_bridge.generate(prompt, max_new_tokens=5, do_sample=False) + assert isinstance(result, torch.Tensor) + assert result.shape == (1, 9) # 4 prompt + 5 new + + def test_greedy_matches_hf_exactly( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + """Bit-for-bit equality with HF generate() over 8 new tokens.""" + prompt = prompt.to(_device()) + hf_model = nemotron_bridge.original_model + with torch.no_grad(): + bridge_out = nemotron_bridge.generate( + prompt, max_new_tokens=8, do_sample=False + ) + hf_out = hf_model.generate( + prompt, max_new_tokens=8, do_sample=False, pad_token_id=0 + ) + assert torch.equal(bridge_out, hf_out), ( + f"Token mismatch between bridge and HF.\n" + f" bridge : {bridge_out.tolist()}\n" + f" hf : {hf_out.tolist()}\n" + "DynamicCache state threading across layer types is likely wrong." + ) + + def test_generation_is_deterministic( + self, nemotron_bridge: TransformerBridge, prompt: torch.Tensor + ) -> None: + """Two identical greedy calls must produce identical tokens.""" + prompt = prompt.to(_device()) + with torch.no_grad(): + out1 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) + out2 = nemotron_bridge.generate(prompt, max_new_tokens=4, do_sample=False) + assert torch.equal(out1, out2), "Greedy generation is not deterministic" + + +# --------------------------------------------------------------------------- +# Hook coverage: bridge hooks fire for both Mamba and attention layers +# --------------------------------------------------------------------------- + + +class TestNemotronHHookCoverage: + """run_with_cache captures residual stream and mixer hooks on all layer types.""" + + @pytest.fixture(scope="class") + def cache(self, nemotron_bridge: TransformerBridge): + tokens = torch.tensor([[1, 2, 3, 4, 5]]).to(_device()) + with torch.no_grad(): + _, cache = nemotron_bridge.run_with_cache(tokens) + return cache + + def test_block_hooks_fire(self, cache, nemotron_bridge: TransformerBridge) -> None: + for i in [0, 28, 55]: + assert f"blocks.{i}.hook_in" in cache, f"Missing hook_in for block {i}" + assert f"blocks.{i}.hook_out" in cache, f"Missing hook_out for block {i}" + + def test_mamba_mixer_submodule_hooks_fire( + self, cache, nemotron_bridge: TransformerBridge + ) -> None: + """Mamba layers should expose in_proj / conv1d / out_proj hooks.""" + lbt = getattr(nemotron_bridge.cfg, "layers_block_type", []) + mamba_indices = [i for i, t in enumerate(lbt) if t == "mamba"] + assert mamba_indices, "No mamba layers found in layers_block_type" + # Check a few mamba layers + for i in mamba_indices[:3]: + for submod in ("in_proj", "conv1d", "out_proj"): + key_in = f"blocks.{i}.mixer.{submod}.hook_in" + key_out = f"blocks.{i}.mixer.{submod}.hook_out" + assert key_in in cache, f"Missing {key_in}" + assert key_out in cache, f"Missing {key_out}" + + def test_no_transformer_specific_hooks(self, cache) -> None: + """SSMBlockBridge must not inject transformer-shaped hook names.""" + forbidden = ("hook_resid_mid", "hook_attn_out", "hook_mlp_out") + bad = [k for k in cache if any(f in k for f in forbidden)] + assert bad == [], f"Unexpected transformer-shaped hooks: {bad[:5]}" + + def test_no_nan_in_cache(self, cache) -> None: + for key, val in cache.items(): + if isinstance(val, torch.Tensor) and val.is_floating_point(): + assert not torch.isnan(val).any(), f"NaN in cache['{key}']" diff --git a/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py b/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py new file mode 100644 index 000000000..d0b4d05d2 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_nemotron_h_adapter.py @@ -0,0 +1,366 @@ +"""Unit tests for NemotronHArchitectureAdapter. + +Covers: config attribute propagation, component mapping bridge types and HF +path names, Mamba-specific submodule optional flag, applicable_phases, +create_stateful_cache, factory registration, and guard tests. +""" + +from unittest.mock import MagicMock + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + DepthwiseConv1DBridge, + EmbeddingBridge, + GatedRMSNormBridge, + LinearBridge, + RMSNormalizationBridge, + SSM2MixerBridge, + SSMBlockBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.nemotron_h import ( + NemotronHArchitectureAdapter, +) + + +def _make_cfg( + n_layers: int = 3, + d_model: int = 64, + d_head: int = 8, + n_heads: int = 8, + d_vocab: int = 100, + n_ctx: int = 128, + mamba_num_heads: int = 4, + mamba_head_dim: int = 8, + n_groups: int = 2, + ssm_state_size: int = 4, + layers_block_type: list[str] | None = None, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for NemotronH adapter tests. + + Uses small Mamba-2 dimensions so tests run without loading any weights. + """ + cfg = TransformerBridgeConfig( + d_model=d_model, + d_head=d_head, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + default_prepend_bos=False, + architecture="NemotronHForCausalLM", + ) + # Inject NemotronH-specific fields the adapter reads via getattr. + cfg.mamba_num_heads = mamba_num_heads # type: ignore[attr-defined] + cfg.mamba_head_dim = mamba_head_dim # type: ignore[attr-defined] + cfg.n_groups = n_groups # type: ignore[attr-defined] + cfg.ssm_state_size = ssm_state_size # type: ignore[attr-defined] + if layers_block_type is not None: + cfg.layers_block_type = layers_block_type # type: ignore[attr-defined] + return cfg + + +@pytest.fixture(scope="class") +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture(scope="class") +def adapter(cfg: TransformerBridgeConfig) -> NemotronHArchitectureAdapter: + return NemotronHArchitectureAdapter(cfg) + + +# --------------------------------------------------------------------------- +# Config attributes +# --------------------------------------------------------------------------- + + +class TestNemotronHAdapterConfig: + """Adapter propagates all required config attributes.""" + + def test_normalization_type_rms(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_uses_rms_norm(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_positional_embedding_type_none(self, adapter: NemotronHArchitectureAdapter) -> None: + # No model-level rotary module — attention handles RoPE internally. + assert adapter.cfg.positional_embedding_type == "none" + + def test_gated_mlp_false(self, adapter: NemotronHArchitectureAdapter) -> None: + # MLP layers use relu2, not SwiGLU. + assert adapter.cfg.gated_mlp is False + + def test_attn_only_false(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_final_rms_true(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_is_stateful_true(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.is_stateful is True + + def test_mamba_intermediate_size_propagated( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + # mamba_num_heads=4, mamba_head_dim=8 → 32 + assert getattr(adapter.cfg, "mamba_intermediate_size", None) == 32 + + def test_conv_dim_propagated(self, adapter: NemotronHArchitectureAdapter) -> None: + # intermediate=32, n_groups=2, ssm_state_size=4 → 32 + 2*2*4 = 48 + assert getattr(adapter.cfg, "conv_dim", None) == 48 + + def test_applicable_phases_empty(self) -> None: + # verify_models is transformer-shaped; SSM hybrids skip it. + assert NemotronHArchitectureAdapter.applicable_phases == [] + + def test_weight_processing_conversions_empty( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + assert adapter.weight_processing_conversions == {} + + def test_layers_block_type_propagated(self) -> None: + block_types = ["mamba", "attention", "mamba"] + cfg = _make_cfg(layers_block_type=block_types) + a = NemotronHArchitectureAdapter(cfg) + assert getattr(a.cfg, "layers_block_type") == block_types + + def test_layers_block_type_defaults_to_empty(self) -> None: + cfg = _make_cfg() + a = NemotronHArchitectureAdapter(cfg) + assert getattr(a.cfg, "layers_block_type") == [] + + +# --------------------------------------------------------------------------- +# Top-level component mapping +# --------------------------------------------------------------------------- + + +class TestNemotronHTopLevelComponents: + """component_mapping has exactly the expected top-level keys.""" + + def test_required_keys(self, adapter: NemotronHArchitectureAdapter) -> None: + assert set(adapter.component_mapping.keys()) == {"embed", "blocks", "ln_final", "unembed"} + + def test_no_rotary_emb_key(self, adapter: NemotronHArchitectureAdapter) -> None: + # No model-level rotary; attention handles RoPE via position_ids. + assert "rotary_emb" not in adapter.component_mapping + + def test_embed_is_embedding_bridge(self, adapter: NemotronHArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "model.embeddings" + + def test_blocks_is_ssm_block_bridge(self, adapter: NemotronHArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], SSMBlockBridge) + + def test_blocks_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "model.layers" + + def test_ln_final_is_rms_normalization_bridge( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_ln_final_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["ln_final"].name == "model.norm_f" + + def test_unembed_is_unembedding_bridge(self, adapter: NemotronHArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + +# --------------------------------------------------------------------------- +# Block-level submodules +# --------------------------------------------------------------------------- + + +class TestNemotronHBlockSubmodules: + """SSMBlockBridge submodules have correct types and HF path names.""" + + @pytest.fixture(scope="class") + def blocks(self, adapter: NemotronHArchitectureAdapter) -> SSMBlockBridge: + return adapter.component_mapping["blocks"] + + def test_norm_is_rms_normalization_bridge(self, blocks: SSMBlockBridge) -> None: + assert isinstance(blocks.submodules["norm"], RMSNormalizationBridge) + + def test_norm_name(self, blocks: SSMBlockBridge) -> None: + assert blocks.submodules["norm"].name == "norm" + + def test_mixer_is_ssm2_mixer_bridge(self, blocks: SSMBlockBridge) -> None: + assert isinstance(blocks.submodules["mixer"], SSM2MixerBridge) + + def test_mixer_name(self, blocks: SSMBlockBridge) -> None: + assert blocks.submodules["mixer"].name == "mixer" + + def test_block_has_no_ln2(self, blocks: SSMBlockBridge) -> None: + # Single pre-norm architecture; no post-attention norm. + assert "ln2" not in blocks.submodules + + +# --------------------------------------------------------------------------- +# Mixer submodules (Mamba-specific, optional) +# --------------------------------------------------------------------------- + + +class TestNemotronHMixerSubmodules: + """SSM2MixerBridge submodules are Mamba-specific and optional.""" + + @pytest.fixture(scope="class") + def mixer(self, adapter: NemotronHArchitectureAdapter) -> SSM2MixerBridge: + return adapter.component_mapping["blocks"].submodules["mixer"] + + def test_in_proj_is_linear_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["in_proj"], LinearBridge) + + def test_in_proj_name(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["in_proj"].name == "in_proj" + + def test_in_proj_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["in_proj"].optional is True + + def test_conv1d_is_depthwise_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["conv1d"], DepthwiseConv1DBridge) + + def test_conv1d_name(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["conv1d"].name == "conv1d" + + def test_conv1d_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["conv1d"].optional is True + + def test_inner_norm_is_gated_rms_norm_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["inner_norm"], GatedRMSNormBridge) + + def test_inner_norm_name(self, mixer: SSM2MixerBridge) -> None: + # HF calls it "norm" inside the mixer; TL aliases to "inner_norm". + assert mixer.submodules["inner_norm"].name == "norm" + + def test_inner_norm_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["inner_norm"].optional is True + + def test_out_proj_is_linear_bridge(self, mixer: SSM2MixerBridge) -> None: + assert isinstance(mixer.submodules["out_proj"], LinearBridge) + + def test_out_proj_name(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["out_proj"].name == "out_proj" + + def test_out_proj_optional(self, mixer: SSM2MixerBridge) -> None: + assert mixer.submodules["out_proj"].optional is True + + def test_mixer_has_exactly_four_mamba_submodules(self, mixer: SSM2MixerBridge) -> None: + assert set(mixer.submodules.keys()) == {"in_proj", "conv1d", "inner_norm", "out_proj"} + + +# --------------------------------------------------------------------------- +# create_stateful_cache +# --------------------------------------------------------------------------- + + +class TestNemotronHStatefulCache: + """create_stateful_cache returns a DynamicCache instance.""" + + def test_returns_dynamic_cache(self, adapter: NemotronHArchitectureAdapter) -> None: + from transformers.cache_utils import DynamicCache + + hf_model = MagicMock() + cache = adapter.create_stateful_cache( + hf_model=hf_model, batch_size=1, device="cpu", dtype=None + ) + assert isinstance(cache, DynamicCache) + + def test_cache_independent_per_call(self, adapter: NemotronHArchitectureAdapter) -> None: + """Each call returns a fresh cache object.""" + hf_model = MagicMock() + c1 = adapter.create_stateful_cache(hf_model, 1, "cpu", None) + c2 = adapter.create_stateful_cache(hf_model, 1, "cpu", None) + assert c1 is not c2 + + +# --------------------------------------------------------------------------- +# Factory registration +# --------------------------------------------------------------------------- + + +class TestNemotronHFactoryRegistration: + """NemotronHForCausalLM is registered in the adapter factory.""" + + def test_factory_returns_nemotron_h_adapter(self) -> None: + cfg = _make_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, NemotronHArchitectureAdapter) + + def test_architecture_key_present(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "NemotronHForCausalLM" in SUPPORTED_ARCHITECTURES + assert SUPPORTED_ARCHITECTURES["NemotronHForCausalLM"] is NemotronHArchitectureAdapter + + +# --------------------------------------------------------------------------- +# Model registry +# --------------------------------------------------------------------------- + + +class TestNemotronHModelRegistry: + """NemotronHForCausalLM is listed in the model registry constants.""" + + def test_in_hf_supported_architectures(self) -> None: + from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES + + assert "NemotronHForCausalLM" in HF_SUPPORTED_ARCHITECTURES + + def test_canonical_author_is_nvidia(self) -> None: + from transformer_lens.tools.model_registry import CANONICAL_AUTHORS_BY_ARCH + + assert CANONICAL_AUTHORS_BY_ARCH.get("NemotronHForCausalLM") == ["nvidia"] + + +# --------------------------------------------------------------------------- +# Guard tests +# --------------------------------------------------------------------------- + + +class TestNemotronHGuards: + """Guards against drift toward neighbouring adapter patterns.""" + + def test_uses_ssm_block_bridge_not_block_bridge( + self, adapter: NemotronHArchitectureAdapter + ) -> None: + # BlockBridge enforces ln2; NemotronH is single-norm. Must use SSMBlockBridge. + from transformer_lens.model_bridge.generalized_components import BlockBridge + + blocks = adapter.component_mapping["blocks"] + assert not isinstance(blocks, BlockBridge) + assert isinstance(blocks, SSMBlockBridge) + + def test_no_weight_conversions_defined(self, adapter: NemotronHArchitectureAdapter) -> None: + # Unlike attention adapters, NemotronH has no rearrange/split conversions. + assert len(adapter.weight_processing_conversions) == 0 + + def test_cfg_is_not_attn_only(self, adapter: NemotronHArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_mamba_intermediate_size_formula(self) -> None: + """Verify formula: intermediate = mamba_num_heads * mamba_head_dim.""" + cfg = _make_cfg(mamba_num_heads=16, mamba_head_dim=32) + a = NemotronHArchitectureAdapter(cfg) + assert getattr(a.cfg, "mamba_intermediate_size") == 16 * 32 + + def test_conv_dim_formula(self) -> None: + """Verify formula: conv_dim = intermediate + 2 * n_groups * ssm_state_size.""" + cfg = _make_cfg(mamba_num_heads=8, mamba_head_dim=16, n_groups=4, ssm_state_size=8) + a = NemotronHArchitectureAdapter(cfg) + expected = 8 * 16 + 2 * 4 * 8 # 128 + 64 = 192 + assert getattr(a.cfg, "conv_dim") == expected diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 9222ef97f..0345506ef 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -49,6 +49,7 @@ NanogptArchitectureAdapter, NativeArchitectureAdapter, NeelSoluOldArchitectureAdapter, + NemotronHArchitectureAdapter, NeoArchitectureAdapter, NeoxArchitectureAdapter, Olmo2ArchitectureAdapter, @@ -111,6 +112,7 @@ "Lfm2MoeForCausalLM": Lfm2MoeArchitectureAdapter, "Mamba2ForCausalLM": Mamba2ArchitectureAdapter, "MambaForCausalLM": MambaArchitectureAdapter, + "NemotronHForCausalLM": NemotronHArchitectureAdapter, "MixtralForCausalLM": MixtralArchitectureAdapter, "MistralForCausalLM": MistralArchitectureAdapter, "MPTForCausalLM": MPTArchitectureAdapter, diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index c975a2229..f170c6ab3 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -99,6 +99,9 @@ from transformer_lens.model_bridge.supported_architectures.mamba2 import ( Mamba2ArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.nemotron_h import ( + NemotronHArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.mingpt import ( MingptArchitectureAdapter, ) @@ -228,6 +231,7 @@ "MixtralArchitectureAdapter", "MPTArchitectureAdapter", "NanogptArchitectureAdapter", + "NemotronHArchitectureAdapter", "NativeArchitectureAdapter", "NeelSoluOldArchitectureAdapter", "NeoArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/nemotron_h.py b/transformer_lens/model_bridge/supported_architectures/nemotron_h.py new file mode 100644 index 000000000..2f322fa8d --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/nemotron_h.py @@ -0,0 +1,157 @@ +"""Nemotron-H hybrid Mamba2-Transformer architecture adapter. + +Supports NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B). + +Architecture overview: +- Heterogeneous layers defined by ``config.layers_block_type`` — each element is + one of ``"mamba"``, ``"attention"``, ``"moe"``, or ``"mlp"``. +- ~8% of layers are standard GQA attention; the rest are Mamba-2 SSM, dense MLP, + or sparse MoE. All share a single pre-norm (``block.norm``) and a single residual + path; there is no ``ln2`` or post-attention norm. +- Each block exposes a single ``.mixer`` attribute whose type varies by layer. +- No model-level rotary embedding module — attention handles RoPE internally via + ``position_ids`` passed from the outer model loop. +- Stateful generation: uses ``DynamicCache`` (transformers ≥ 5.12) which carries + both KV-cache entries (attention layers) and SSM conv/recurrent states + (Mamba layers) in a unified object. + +Key adapter decisions: +- ``SSMBlockBridge`` is used as the block container. It delegates the entire + forward to the HF block, giving ``hook_in`` / ``hook_out`` on the residual + stream without hardcoding transformer-specific hook positions (hook_resid_mid, + hook_mlp_in, etc.) that do not exist in this single-norm architecture. +- ``SSM2MixerBridge`` wraps ``.mixer`` for all layer types. Its forward is a + pure passthrough (``original_component(*args, **kwargs)``) so it works + correctly for attention, MLP, and MoE mixers as well as Mamba ones. + Mamba-specific inner submodules (in_proj, conv1d, inner_norm, out_proj) are + declared ``optional=True`` so setup skips them gracefully on non-Mamba layers. +- MLP layers use ``relu2`` activation (not SwiGLU); ``gated_mlp = False``. +- ``applicable_phases = []``: ``verify_models`` is transformer-shaped and would + require a dedicated refactor to cover SSM hybrids. Coverage lives in the + integration test instead. +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + DepthwiseConv1DBridge, + EmbeddingBridge, + GatedRMSNormBridge, + LinearBridge, + RMSNormalizationBridge, + SSM2MixerBridge, + SSMBlockBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) + + +def _make_optional(component: "GeneralizedComponent") -> "GeneralizedComponent": + """Mark a GeneralizedComponent submodule as optional. + + Some bridge classes (e.g. GatedRMSNormBridge) do not forward ``optional`` + through their own ``__init__``, even though ``GeneralizedComponent`` supports + it. Setting the attribute directly is safe because ``component_setup.py`` + reads ``getattr(submodule, 'optional', False)`` at setup time. + """ + component.optional = True + return component + + +class NemotronHArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for NemotronHForCausalLM. + + Hybrid Mamba-2 + Attention + MoE + dense MLP model. All layers share a + single pre-norm and a single residual connection; the mixer type per layer + is determined by ``config.layers_block_type[layer_idx]``. + """ + + # verify_models is transformer-shaped and requires a dedicated refactor to + # cover SSM hybrids. Integration tests cover forward-pass correctness instead. + applicable_phases: list[int] = [] + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "RMS" + self.cfg.uses_rms_norm = True + # No model-level rotary embedding module — attention handles RoPE + # internally via position_ids; set to "none" so the bridge does not + # attempt to wire a rotary_emb component. + self.cfg.positional_embedding_type = "none" + # MLP layers use relu2 (up_proj → act → down_proj), not SwiGLU. + self.cfg.gated_mlp = False + self.cfg.attn_only = False + self.cfg.final_rms = True + # Mamba layers require per-step SSM state; generation is stateful. + self.cfg.is_stateful = True + + # Expose the heterogeneous layer-type list so tests and analysis tools + # can inspect which layers are which without loading a full HF model. + layers_block_type = getattr(cfg, "layers_block_type", []) + setattr(self.cfg, "layers_block_type", layers_block_type) + + # Mamba-2 dimensional config (mirrors Mamba2ArchitectureAdapter). + mamba_num_heads = getattr(cfg, "mamba_num_heads", 128) + mamba_head_dim = getattr(cfg, "mamba_head_dim", 64) + mamba_intermediate_size = mamba_num_heads * mamba_head_dim + n_groups = getattr(cfg, "n_groups", 8) + ssm_state_size = getattr(cfg, "ssm_state_size", 128) + conv_dim = mamba_intermediate_size + 2 * n_groups * ssm_state_size + setattr(self.cfg, "mamba_intermediate_size", mamba_intermediate_size) + setattr(self.cfg, "conv_dim", conv_dim) + + self.weight_processing_conversions = {} + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embeddings"), + "blocks": SSMBlockBridge( + name="model.layers", + submodules={ + # Single pre-norm shared across all layer types. + "norm": RMSNormalizationBridge(name="norm", config=self.cfg), + # Single mixer slot — type varies per layer (mamba / attention + # / moe / mlp). SSM2MixerBridge.forward() is a pure + # passthrough so it works for all four types. Mamba-specific + # inner submodules are optional and skipped on other types. + "mixer": SSM2MixerBridge( + name="mixer", + config=self.cfg, + submodules={ + # ── Mamba-only (optional on attention / moe / mlp) ── + "in_proj": LinearBridge(name="in_proj", optional=True), + "conv1d": DepthwiseConv1DBridge(name="conv1d", optional=True), + # HF names this "norm" inside the mixer; TL calls it + # "inner_norm" to avoid collision with the block-level norm. + # GatedRMSNormBridge.__init__ does not accept optional=, so + # we set the attribute directly after construction. + "inner_norm": _make_optional(GatedRMSNormBridge(name="norm")), + "out_proj": LinearBridge(name="out_proj", optional=True), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm_f", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def create_stateful_cache( + self, + hf_model: Any, + batch_size: int, + device: Any, + dtype: Any, + ) -> Any: + """Build the unified DynamicCache for stateful generation. + + Transformers ≥ 5.12 ships a unified ``DynamicCache`` that carries both + KV-cache entries (attention layers) and SSM conv/recurrent states + (Mamba layers) in a single object, using ``has_previous_state()`` to + distinguish which state is available for a given layer index. + """ + from transformers.cache_utils import DynamicCache + + return DynamicCache() diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index eb14a2fb0..6f9354c24 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -81,6 +81,7 @@ "Lfm2MoeForCausalLM", "MambaForCausalLM", "Mamba2ForCausalLM", + "NemotronHForCausalLM", "MPTForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", @@ -144,6 +145,7 @@ "Lfm2MoeForCausalLM": ["LiquidAI"], "Mamba2ForCausalLM": ["state-spaces"], "MambaForCausalLM": ["state-spaces"], + "NemotronHForCausalLM": ["nvidia"], "MistralForCausalLM": ["mistralai"], "MixtralForCausalLM": ["mistralai"], "MPTForCausalLM": ["mosaicml"],