From 8ee51675988515020d9ceb1b82f7b71ac299766c Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 11 Jun 2025 22:33:14 +0200 Subject: [PATCH 1/6] removed wrap function --- .../generalized_components/attention.py | 32 ++++--------------- .../generalized_components/block.py | 18 ----------- .../generalized_components/embedding.py | 18 ----------- .../generalized_components/layer_norm.py | 18 ----------- .../generalized_components/mlp.py | 18 ----------- .../generalized_components/moe.py | 21 ++---------- .../generalized_components/unembedding.py | 18 ----------- 7 files changed, 9 insertions(+), 134 deletions(-) diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py index ab9703f25..3d2ffd4b4 100644 --- a/transformer_lens/model_bridge/generalized_components/attention.py +++ b/transformer_lens/model_bridge/generalized_components/attention.py @@ -38,8 +38,8 @@ def __init__( # Add all the hooks from the old attention components self.hook_k = HookPoint() # [batch, pos, head_index, d_head] self.hook_q = HookPoint() # [batch, pos, head_index, d_head] - self.hook_v = HookPoint() # [batch, pos, head_index, d_head] - self.hook_z = HookPoint() # [batch, pos, head_index, d_head] + self.hook_v = HookPoint() # [batch, pos, head_index, d_head] (value vectors) + self.hook_z = HookPoint() # [batch, pos, head_index, d_head] (attention output) self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_result = HookPoint() # [batch, pos, head_index, d_model] @@ -50,31 +50,13 @@ def __init__( self.hook_rot_q = HookPoint() # [batch, pos, head_index, d_head] (for rotary) def forward(self, *args: Any, **kwargs: Any) -> Any: - """Forward pass through the attention layer. - - This method forwards all arguments to the original component and applies hooks - to the output. The arguments should match the original component's forward method. 
+ """Forward pass through the attention bridge. Args: - *args: Input arguments to pass to the original component - **kwargs: Input keyword arguments to pass to the original component + *args: Positional arguments for the original component + **kwargs: Keyword arguments for the original component Returns: - The output from the original component, with hooks applied + Output from the original component """ - # Handle hook_attn_input for shortformer positional embeddings - if "query_input" in kwargs: - # Combine normalized residual stream with positional embeddings - attn_input = kwargs["query_input"] - # Pass through hook_attn_input - attn_input = self.hook_attn_input(attn_input) - # Update query_input with the hooked value - kwargs["query_input"] = attn_input - - # Forward through the original component - output = self.original_component(*args, **kwargs) - - # Execute hooks on the output (for add_hook compatibility) - output = self.execute_hooks("output", output) - - return output + return self.original_component(*args, **kwargs) diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 82737f67c..8da57d30b 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -34,24 +34,6 @@ def __init__( """ super().__init__(original_component, name, architecture_adapter) - @classmethod - def wrap_component( - cls, component: nn.Module, name: str, architecture_adapter: Any | None = None - ) -> nn.Module: - """Wrap a component with this bridge if it's a transformer block. 
- - Args: - component: The component to wrap - name: The name of the component - architecture_adapter: The architecture adapter instance - - Returns: - The wrapped component if it's a transformer block, otherwise the original component - """ - if name.endswith(".block") or name.endswith(".layer"): - return cls(component, name, architecture_adapter) - return component - def forward(self, *args: Any, **kwargs: Any) -> Any: """Forward pass through the block bridge. diff --git a/transformer_lens/model_bridge/generalized_components/embedding.py b/transformer_lens/model_bridge/generalized_components/embedding.py index 394da13d9..95eb3e721 100644 --- a/transformer_lens/model_bridge/generalized_components/embedding.py +++ b/transformer_lens/model_bridge/generalized_components/embedding.py @@ -79,21 +79,3 @@ def forward( self.hook_outputs.update({"output": output}) return output - - @classmethod - def wrap_component( - cls, component: nn.Module, name: str, architecture_adapter: Any | None = None - ) -> nn.Module: - """Wrap a component with this bridge if it's an embedding layer. 
- - Args: - component: The component to wrap - name: The name of the component - architecture_adapter: The architecture adapter instance - - Returns: - The wrapped component if it's an embedding layer, otherwise the original component - """ - if name.endswith(".embed") or name.endswith(".embed_tokens"): - return cls(component, name, architecture_adapter) - return component diff --git a/transformer_lens/model_bridge/generalized_components/layer_norm.py b/transformer_lens/model_bridge/generalized_components/layer_norm.py index 34f32d912..8baaa5e84 100644 --- a/transformer_lens/model_bridge/generalized_components/layer_norm.py +++ b/transformer_lens/model_bridge/generalized_components/layer_norm.py @@ -64,21 +64,3 @@ def forward( self.hook_outputs.update({"output": output}) return output - - @classmethod - def wrap_component( - cls, component: nn.Module, name: str, architecture_adapter: Any | None = None - ) -> nn.Module: - """Wrap a component with this bridge if it's a LayerNorm layer. - - Args: - component: The component to wrap - name: The name of the component - architecture_adapter: The architecture adapter instance - - Returns: - The wrapped component if it's a LayerNorm layer, otherwise the original component - """ - if name.endswith(".ln") or name.endswith(".ln1") or name.endswith(".ln2"): - return cls(component, name, architecture_adapter) - return component diff --git a/transformer_lens/model_bridge/generalized_components/mlp.py b/transformer_lens/model_bridge/generalized_components/mlp.py index 862296d8f..665a37f92 100644 --- a/transformer_lens/model_bridge/generalized_components/mlp.py +++ b/transformer_lens/model_bridge/generalized_components/mlp.py @@ -66,21 +66,3 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.hook_outputs.update({"output": output}) return output - - @classmethod - def wrap_component( - cls, component: nn.Module, name: str, architecture_adapter: Any | None = None - ) -> nn.Module: - """Wrap a component with this 
bridge if it's an MLP layer. - - Args: - component: The component to wrap - name: The name of the component - architecture_adapter: The architecture adapter instance - - Returns: - The wrapped component if it's an MLP layer, otherwise the original component - """ - if name.endswith(".mlp"): - return cls(component, name, architecture_adapter) - return component diff --git a/transformer_lens/model_bridge/generalized_components/moe.py b/transformer_lens/model_bridge/generalized_components/moe.py index 7126de07f..d11cc295b 100644 --- a/transformer_lens/model_bridge/generalized_components/moe.py +++ b/transformer_lens/model_bridge/generalized_components/moe.py @@ -7,6 +7,7 @@ import torch.nn as nn +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components.base import ( GeneralizedComponent, ) @@ -23,7 +24,7 @@ def __init__( self, original_component: nn.Module, name: str, - architecture_adapter: Any | None = None, + architecture_adapter: ArchitectureAdapter, ): """Initialize the MoE bridge. @@ -34,24 +35,6 @@ def __init__( """ super().__init__(original_component, name, architecture_adapter) - @classmethod - def wrap_component( - cls, component: nn.Module, name: str, architecture_adapter: Any | None = None - ) -> nn.Module: - """Wrap a component with this bridge if it's a MoE layer. - - Args: - component: The component to wrap - name: The name of the component - architecture_adapter: The architecture adapter instance - - Returns: - The wrapped component if it's a MoE layer, otherwise the original component - """ - if name.endswith(".moe"): - return cls(component, name, architecture_adapter) - return component - def forward(self, *args: Any, **kwargs: Any) -> Any: """Forward pass through the MoE bridge. 
diff --git a/transformer_lens/model_bridge/generalized_components/unembedding.py b/transformer_lens/model_bridge/generalized_components/unembedding.py index 09529eddf..9d0e08189 100644 --- a/transformer_lens/model_bridge/generalized_components/unembedding.py +++ b/transformer_lens/model_bridge/generalized_components/unembedding.py @@ -71,21 +71,3 @@ def forward( self.hook_outputs.update({"output": output}) return output - - @classmethod - def wrap_component( - cls, component: nn.Module, name: str, architecture_adapter: Any | None = None - ) -> nn.Module: - """Wrap a component with this bridge if it's an unembedding layer. - - Args: - component: The component to wrap - name: The name of the component - architecture_adapter: The architecture adapter instance - - Returns: - The wrapped component if it's an unembedding layer, otherwise the original component - """ - if name.endswith(".unembed") or name.endswith(".lm_head"): - return cls(component, name, architecture_adapter) - return component From 312a8474d1c6fab9b994f09c4176c4d8bda2a780 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 11 Jun 2025 22:33:28 +0200 Subject: [PATCH 2/6] added deep seek architecture --- .../supported_architectures/__init__.py | 4 ++ .../supported_architectures/deepseek.py | 68 +++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 transformer_lens/model_bridge/supported_architectures/deepseek.py diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index a33baa765..67090cc15 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -9,6 +9,9 @@ from transformer_lens.model_bridge.supported_architectures.bloom import ( BloomArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.deepseek import ( + DeepseekArchitectureAdapter, +) from 
transformer_lens.model_bridge.supported_architectures.gemma1 import ( Gemma1ArchitectureAdapter, ) @@ -76,6 +79,7 @@ __all__ = [ "BertArchitectureAdapter", "BloomArchitectureAdapter", + "DeepseekArchitectureAdapter", "Gemma1ArchitectureAdapter", "Gemma2ArchitectureAdapter", "Gemma3ArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek.py b/transformer_lens/model_bridge/supported_architectures/deepseek.py new file mode 100644 index 000000000..cdc3c2e10 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/deepseek.py @@ -0,0 +1,68 @@ +"""DeepSeek architecture adapter.""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.conversion_utils.conversion_steps import ( + WeightConversionSet, +) +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + EmbeddingBridge, + LayerNormBridge, + MLPBridge, + MoEBridge, + UnembeddingBridge, +) + + +class DeepseekArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for DeepSeek models.""" + + def __init__(self, cfg: Any) -> None: + """Initialize the DeepSeek architecture adapter. + + Args: + cfg: The configuration object. 
+ """ + super().__init__(cfg) + + self.conversion_rules = WeightConversionSet( + { + "embed.W_E": "model.embed_tokens.weight", + "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight", + # Attention weights + "blocks.{i}.attn.W_Q": "model.layers.{i}.self_attn.q_proj.weight", + "blocks.{i}.attn.W_K": "model.layers.{i}.self_attn.k_proj.weight", + "blocks.{i}.attn.W_V": "model.layers.{i}.self_attn.v_proj.weight", + "blocks.{i}.attn.W_O": "model.layers.{i}.self_attn.o_proj.weight", + "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight", + # MLP weights for dense layers + "blocks.{i}.mlp.W_gate": "model.layers.{i}.mlp.gate_proj.weight", + "blocks.{i}.mlp.W_in": "model.layers.{i}.mlp.up_proj.weight", + "blocks.{i}.mlp.W_out": "model.layers.{i}.mlp.down_proj.weight", + # MoE weights + "blocks.{i}.moe.gate.w": "model.layers.{i}.mlp.gate.weight", + "blocks.{i}.moe.experts.W_gate.{j}": "model.layers.{i}.mlp.experts.{j}.gate_proj.weight", + "blocks.{i}.moe.experts.W_in.{j}": "model.layers.{i}.mlp.experts.{j}.up_proj.weight", + "blocks.{i}.moe.experts.W_out.{j}": "model.layers.{i}.mlp.experts.{j}.down_proj.weight", + "ln_final.w": "model.norm.weight", + "unembed.W_U": "lm_head.weight", + } + ) + + self.component_mapping = { + "embed": ("model.embed_tokens", EmbeddingBridge), + "blocks": ( + "model.layers", + { + "ln1": ("input_layernorm", LayerNormBridge), + "ln2": ("post_attention_layernorm", LayerNormBridge), + "attn": ("self_attn", AttentionBridge), + "mlp": ("mlp", MLPBridge), + "moe": ("mlp", MoEBridge), + }, + ), + "ln_final": ("model.norm", LayerNormBridge), + "unembed": ("lm_head", UnembeddingBridge), + } \ No newline at end of file From 8f50a22efc729623e42673320ff8944ec77c0e52 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 11 Jun 2025 22:41:11 +0200 Subject: [PATCH 3/6] registered deep seek --- tests/mocks/architecture_adapter.py | 81 +++++++++++++++++++ .../factories/architecture_adapter_factory.py | 2 + 2 files changed, 83 
insertions(+) create mode 100644 tests/mocks/architecture_adapter.py diff --git a/tests/mocks/architecture_adapter.py b/tests/mocks/architecture_adapter.py new file mode 100644 index 000000000..a5a73684a --- /dev/null +++ b/tests/mocks/architecture_adapter.py @@ -0,0 +1,81 @@ +"""Mock architecture adapter for testing.""" +import pytest +import torch.nn as nn + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + BlockBridge, + EmbeddingBridge, + LayerNormBridge, + MLPBridge, +) + + +class MockArchitectureAdapter(ArchitectureAdapter): + """Mock architecture adapter for testing.""" + + def __init__(self, cfg=None): + super().__init__(cfg) + self.component_mapping = { + "embed": ("embed", EmbeddingBridge), + "unembed": ("unembed", EmbeddingBridge), + "ln_final": ("ln_final", LayerNormBridge), + "blocks": ( + "blocks", + BlockBridge, + { + "ln1": ("ln1", LayerNormBridge), + "ln2": ("ln2", LayerNormBridge), + "attn": ("attn", AttentionBridge), + "mlp": ("mlp", MLPBridge), + }, + ), + "outer_blocks": ( + "outer_blocks", + BlockBridge, + { + "inner_blocks": ( + "inner_blocks", + BlockBridge, + {"ln": ("ln", LayerNormBridge)}, + ) + }, + ), + } + + +@pytest.fixture +def mock_adapter() -> MockArchitectureAdapter: + """Create a mock adapter.""" + return MockArchitectureAdapter() + + +@pytest.fixture +def mock_model_adapter() -> nn.Module: + """Create a mock model for testing.""" + model = nn.Module() + + # For embed/unembed + model.embed = nn.Embedding(100, 10) + model.unembed = nn.Linear(10, 100) + + model.ln_final = nn.LayerNorm(10) + model.blocks = nn.ModuleList() + block = nn.Module() + block.ln1 = nn.LayerNorm(10) + block.ln2 = nn.LayerNorm(10) + block.attn = nn.Module() + block.mlp = nn.Module() + model.blocks.append(block) + + # For nested blocks + model.outer_blocks = nn.ModuleList() + outer_block = nn.Module() + outer_block.inner_blocks = 
nn.ModuleList() + inner_block = nn.Module() + inner_block.ln = nn.LayerNorm(10) + outer_block.inner_blocks.append(inner_block) + model.outer_blocks.append(outer_block) + + return model diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 6eff34236..13e276dc4 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -9,6 +9,7 @@ from transformer_lens.model_bridge.supported_architectures import ( BertArchitectureAdapter, BloomArchitectureAdapter, + DeepseekArchitectureAdapter, Gemma1ArchitectureAdapter, Gemma2ArchitectureAdapter, Gemma3ArchitectureAdapter, @@ -35,6 +36,7 @@ SUPPORTED_ARCHITECTURES = { "BertForMaskedLM": BertArchitectureAdapter, "BloomForCausalLM": BloomArchitectureAdapter, + "DeepseekV3ForCausalLM": DeepseekArchitectureAdapter, "GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version "Gemma1ForCausalLM": Gemma1ArchitectureAdapter, "Gemma2ForCausalLM": Gemma2ArchitectureAdapter, From 0c6e9be1e68deaa97aa908caf61bb6716088c376 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 16 Jun 2025 22:12:44 +0200 Subject: [PATCH 4/6] ran format --- .../model_bridge/supported_architectures/deepseek.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek.py b/transformer_lens/model_bridge/supported_architectures/deepseek.py index cdc3c2e10..10ced9552 100644 --- a/transformer_lens/model_bridge/supported_architectures/deepseek.py +++ b/transformer_lens/model_bridge/supported_architectures/deepseek.py @@ -65,4 +65,4 @@ def __init__(self, cfg: Any) -> None: ), "ln_final": ("model.norm", LayerNormBridge), "unembed": ("lm_head", UnembeddingBridge), - } \ No newline at end of file + } From 549860c5c6fb3f95b7f98bbf44d33afd2b21071c Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 16 Jun 
2025 22:51:17 +0200 Subject: [PATCH 5/6] fixed typing --- .../model_bridge/supported_architectures/deepseek.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek.py b/transformer_lens/model_bridge/supported_architectures/deepseek.py index 10ced9552..fbb343aa7 100644 --- a/transformer_lens/model_bridge/supported_architectures/deepseek.py +++ b/transformer_lens/model_bridge/supported_architectures/deepseek.py @@ -8,6 +8,7 @@ ) from transformer_lens.model_bridge.generalized_components import ( AttentionBridge, + BlockBridge, EmbeddingBridge, LayerNormBridge, MLPBridge, @@ -55,6 +56,7 @@ def __init__(self, cfg: Any) -> None: "embed": ("model.embed_tokens", EmbeddingBridge), "blocks": ( "model.layers", + BlockBridge, { "ln1": ("input_layernorm", LayerNormBridge), "ln2": ("post_attention_layernorm", LayerNormBridge), From 5095d3027a59f438a7df52e5d8a62c7a780a8ede Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 17 Jun 2025 19:03:08 +0200 Subject: [PATCH 6/6] updated loading --- transformer_lens/boot.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_lens/boot.py b/transformer_lens/boot.py index b569adc72..cb6ddc95b 100644 --- a/transformer_lens/boot.py +++ b/transformer_lens/boot.py @@ -13,27 +13,27 @@ def boot( model_name: str, - config: dict | None = None, + model_config: dict | None = None, + tokenizer_config: dict | None = None, device: str | torch.device | None = None, dtype: torch.dtype = torch.float32, - **kwargs, ) -> TransformerBridge: """Boot a model from HuggingFace. Args: model_name: The name of the model to load. - config: The config dict to use. If None, will be loaded from HuggingFace. + model_config: Additional configuration parameters to override the default config. + tokenizer_config: The config dict to use for tokenizer loading. If None, will use default settings. device: The device to use. If None, will be determined automatically. 
dtype: The dtype to use for the model. - **kwargs: Additional keyword arguments for from_pretrained. Returns: The bridge to the loaded model. """ - hf_config = AutoConfig.from_pretrained(model_name, **kwargs) + hf_config = AutoConfig.from_pretrained(model_name, **(model_config or {})) adapter = ArchitectureAdapterFactory.select_architecture_adapter(hf_config) default_config = adapter.default_cfg - merged_config = {**default_config, **(config or {})} + merged_config = {**default_config, **(model_config or {})} # Load the model from HuggingFace using the original config hf_model = AutoModelForCausalLM.from_pretrained( @@ -44,7 +44,7 @@ def boot( ) # Load the tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_name, **(tokenizer_config or {})) return TransformerBridge( hf_model,