diff --git a/transformer_lens/benchmarks/component_outputs.py b/transformer_lens/benchmarks/component_outputs.py
index 1efe59f82..adbb76924 100644
--- a/transformer_lens/benchmarks/component_outputs.py
+++ b/transformer_lens/benchmarks/component_outputs.py
@@ -199,6 +199,26 @@ def print_detailed_analysis(self) -> None:
 class ComponentBenchmarker:
     """Benchmarking utility for testing TransformerBridge components against HuggingFace."""
 
+    def _is_delegated_block(self) -> bool:
+        """Return True if the blocks component uses DelegatedAttentionBlockBridge.
+
+        The check is structural: the ``blocks`` entry of the adapter's
+        ``component_mapping`` counts as delegated when its ``hook_aliases``
+        has no ``hook_q_input`` key. A missing adapter, a missing or None
+        ``component_mapping``, a missing ``blocks`` entry, or a ``blocks``
+        object without ``hook_aliases`` all yield False.
+        """
+        # ``component_mapping`` can be present but None; ``or {}`` keeps the
+        # lookup from raising AttributeError (also covers ``adapter is None``).
+        mapping = getattr(self.adapter, "component_mapping", None) or {}
+        blocks = mapping.get("blocks")
+        if blocks is None:
+            return False
+        # The default contains the key, so a blocks object without
+        # ``hook_aliases`` is conservatively treated as not delegated.
+        aliases = getattr(blocks, "hook_aliases", {"hook_q_input": True})
+        return "hook_q_input" not in aliases
+
     def __init__(
         self,
         bridge_model: nn.Module,
@@ -419,6 +439,23 @@ def _test_component_recursive(
         ):
             return
 
+        # Skip attention and PLE submodules when using DelegatedAttentionBlockBridge.
+        # These architectures delegate all math to HF; the benchmark can't call the HF
+        # attention in isolation (missing position_embeddings, attention_mask, etc.) and
+        # PLE submodules receive per-layer inputs at a different dimension than hidden_states.
+        _is_delegated = self._is_delegated_block()
+        if _is_delegated and "attn" in component_path:
+            return
+        if _is_delegated and any(
+            name in component_path
+            for name in (
+                "per_layer_input_gate",
+                "per_layer_projection",
+                "post_per_layer_input_norm",
+            )
+        ):
+            return
+
         # Skip models whose MLP/attn forward signatures require extra context from the block:
         # - BLOOM: MLP requires residual and alibi bias
         # - T5: requires cache_position for relative position embeddings
@@ -526,6 +563,12 @@ def _test_component(
             ComponentTestResult or None if the component cannot be tested
         """
         try:
+            # Skip rotary_emb for DelegatedAttentionBlockBridge architectures.
+            # Gemma4's RotaryEmbeddingBridge wraps a rotary that returns a set-like
+            # structure which the benchmark comparison can't subscript.
+            if self._is_delegated_block() and component_path == "rotary_emb":
+                return None
+
             # Get bridge component
             # The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent
             bridge_component = cast(
diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py
index 05735095f..d2ccbfa70 100644
--- a/transformer_lens/benchmarks/main_benchmark.py
+++ b/transformer_lens/benchmarks/main_benchmark.py
@@ -935,7 +935,7 @@ def cleanup_model(model, model_name_str: str):
             # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5)
             auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code)
             if verbose and auto_model_class != AutoModelForCausalLM:
-                print(f"Using {auto_model_class.__name__} for encoder-decoder model")
+                print(f"Using {auto_model_class.__name__}")
             # Ensure pad_token_id exists (some models crash without it during init).
             hf_config = AutoConfig.from_pretrained(
                 model_name, trust_remote_code=trust_remote_code, token=_hf_token()
@@ -1209,14 +1209,14 @@ def cleanup_model(model, model_name_str: str):
     # PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed)
     # ========================================================================
     current_phase[0] = 2
-    if verbose:
-        print(f"\n{'='*80}")
-        print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
-        print(f"{'='*80}\n")
 
     # OPTIMIZATION: Run generation benchmarks first (only bridge in memory)
     # Then cleanup bridge before loading HT to reduce peak memory
     if should_run_phase(2) and bridge_unprocessed:
+        if verbose:
+            print(f"\n{'='*80}")
+            print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
+            print(f"{'='*80}\n")
         if verbose:
             print("Running Phase 2 benchmarks...\n")
 