diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py
index d63931578..c74e42887 100644
--- a/transformer_lens/factories/architecture_adapter_factory.py
+++ b/transformer_lens/factories/architecture_adapter_factory.py
@@ -8,6 +8,7 @@ from transformer_lens.model_bridge.supported_architectures import (
     BertArchitectureAdapter,
     BloomArchitectureAdapter,
+    DeepseekArchitectureAdapter,
     Gemma1ArchitectureAdapter,
     Gemma2ArchitectureAdapter,
     Gemma3ArchitectureAdapter,
@@ -35,6 +36,7 @@ SUPPORTED_ARCHITECTURES = {
     "BertForMaskedLM": BertArchitectureAdapter,
     "BloomForCausalLM": BloomArchitectureAdapter,
+    "DeepseekV3ForCausalLM": DeepseekArchitectureAdapter,
     "GemmaForCausalLM": Gemma1ArchitectureAdapter,  # Default to Gemma1 as it's the original version
     "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
     "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py
index 5f5e279af..762de393c 100644
--- a/transformer_lens/model_bridge/supported_architectures/__init__.py
+++ b/transformer_lens/model_bridge/supported_architectures/__init__.py
@@ -9,6 +9,9 @@ from transformer_lens.model_bridge.supported_architectures.bloom import (
     BloomArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.deepseek import (
+    DeepseekArchitectureAdapter,
+)
 from transformer_lens.model_bridge.supported_architectures.gemma1 import (
     Gemma1ArchitectureAdapter,
 )
@@ -77,6 +80,7 @@ __all__ = [
     "BertArchitectureAdapter",
     "BloomArchitectureAdapter",
+    "DeepseekArchitectureAdapter",
     "Gemma1ArchitectureAdapter",
     "Gemma2ArchitectureAdapter",
     "Gemma3ArchitectureAdapter",
diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek.py b/transformer_lens/model_bridge/supported_architectures/deepseek.py
new file mode 100644
index 000000000..fbb343aa7
--- /dev/null
+++ b/transformer_lens/model_bridge/supported_architectures/deepseek.py
@@ -0,0 +1,70 @@
+"""DeepSeek architecture adapter."""
+
+from typing import Any
+
+from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+from transformer_lens.model_bridge.conversion_utils.conversion_steps import (
+    WeightConversionSet,
+)
+from transformer_lens.model_bridge.generalized_components import (
+    AttentionBridge,
+    BlockBridge,
+    EmbeddingBridge,
+    LayerNormBridge,
+    MLPBridge,
+    MoEBridge,
+    UnembeddingBridge,
+)
+
+
+class DeepseekArchitectureAdapter(ArchitectureAdapter):
+    """Architecture adapter for DeepSeek models."""
+
+    def __init__(self, cfg: Any) -> None:
+        """Initialize the DeepSeek architecture adapter.
+
+        Args:
+            cfg: The configuration object.
+        """
+        super().__init__(cfg)
+
+        self.conversion_rules = WeightConversionSet(
+            {
+                "embed.W_E": "model.embed_tokens.weight",
+                "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
+                # Attention weights
+                "blocks.{i}.attn.W_Q": "model.layers.{i}.self_attn.q_proj.weight",
+                "blocks.{i}.attn.W_K": "model.layers.{i}.self_attn.k_proj.weight",
+                "blocks.{i}.attn.W_V": "model.layers.{i}.self_attn.v_proj.weight",
+                "blocks.{i}.attn.W_O": "model.layers.{i}.self_attn.o_proj.weight",
+                "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
+                # MLP weights for dense layers
+                "blocks.{i}.mlp.W_gate": "model.layers.{i}.mlp.gate_proj.weight",
+                "blocks.{i}.mlp.W_in": "model.layers.{i}.mlp.up_proj.weight",
+                "blocks.{i}.mlp.W_out": "model.layers.{i}.mlp.down_proj.weight",
+                # MoE weights
+                "blocks.{i}.moe.gate.w": "model.layers.{i}.mlp.gate.weight",
+                "blocks.{i}.moe.experts.W_gate.{j}": "model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
+                "blocks.{i}.moe.experts.W_in.{j}": "model.layers.{i}.mlp.experts.{j}.up_proj.weight",
+                "blocks.{i}.moe.experts.W_out.{j}": "model.layers.{i}.mlp.experts.{j}.down_proj.weight",
+                "ln_final.w": "model.norm.weight",
+                "unembed.W_U": "lm_head.weight",
+            }
+        )
+
+        self.component_mapping = {
+            "embed": ("model.embed_tokens", EmbeddingBridge),
+            "blocks": (
+                "model.layers",
+                BlockBridge,
+                {
+                    "ln1": ("input_layernorm", LayerNormBridge),
+                    "ln2": ("post_attention_layernorm", LayerNormBridge),
+                    "attn": ("self_attn", AttentionBridge),
+                    "mlp": ("mlp", MLPBridge),
+                    "moe": ("mlp", MoEBridge),
+                },
+            ),
+            "ln_final": ("model.norm", LayerNormBridge),
+            "unembed": ("lm_head", UnembeddingBridge),
+        }