Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions transformer_lens/benchmarks/component_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,17 @@ def print_detailed_analysis(self) -> None:
class ComponentBenchmarker:
"""Benchmarking utility for testing TransformerBridge components against HuggingFace."""

def _is_delegated_block(self) -> bool:
    """Return True if the blocks component uses DelegatedAttentionBlockBridge.

    Detection is structural: the adapter's ``component_mapping["blocks"]``
    entry is treated as delegated when it exposes a ``hook_aliases``
    container that does not contain ``"hook_q_input"``.

    Returns:
        False when there is no adapter, no component mapping, no ``blocks``
        entry, or no usable ``hook_aliases`` on the blocks component.
        Otherwise True exactly when ``"hook_q_input"`` is absent from
        ``hook_aliases``.
    """
    if self.adapter is None:
        return False

    # The attribute may be missing *or* present-but-None; the original
    # ``getattr(..., {})`` default only covered the missing case and then
    # crashed on ``None.get``.
    mapping = getattr(self.adapter, "component_mapping", None)
    if mapping is None:
        return False

    blocks = mapping.get("blocks")
    if blocks is None:
        return False

    # A blocks component without hook aliases cannot be identified as
    # delegated, so it is treated as a regular block (same outcome as the
    # original's ``{"hook_q_input": True}`` fallback).
    aliases = getattr(blocks, "hook_aliases", None)
    if aliases is None:
        return False

    return "hook_q_input" not in aliases

def __init__(
self,
bridge_model: nn.Module,
Expand Down Expand Up @@ -419,6 +430,23 @@ def _test_component_recursive(
):
return

# Skip attention and PLE submodules when using DelegatedAttentionBlockBridge.
# These architectures delegate all math to HF; the benchmark can't call the HF
# attention in isolation (missing position_embeddings, attention_mask, etc.) and
# PLE submodules receive per-layer inputs at a different dimension than hidden_states.
_is_delegated = self._is_delegated_block()
if _is_delegated and "attn" in component_path:
return
if _is_delegated and any(
name in component_path
for name in (
"per_layer_input_gate",
"per_layer_projection",
"post_per_layer_input_norm",
)
):
return

# Skip models whose MLP/attn forward signatures require extra context from the block:
# - BLOOM: MLP requires residual and alibi bias
# - T5: requires cache_position for relative position embeddings
Expand Down Expand Up @@ -526,6 +554,12 @@ def _test_component(
ComponentTestResult or None if the component cannot be tested
"""
try:
# Skip rotary_emb for DelegatedAttentionBlockBridge architectures.
# Gemma4's RotaryEmbeddingBridge wraps a rotary that returns a set-like
# structure which the benchmark comparison can't subscript.
if self._is_delegated_block() and component_path == "rotary_emb":
return None

# Get bridge component
# The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent
bridge_component = cast(
Expand Down
10 changes: 5 additions & 5 deletions transformer_lens/benchmarks/main_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def cleanup_model(model, model_name_str: str):
# Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5)
auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code)
if verbose and auto_model_class != AutoModelForCausalLM:
print(f"Using {auto_model_class.__name__} for encoder-decoder model")
print(f"Using {auto_model_class.__name__}")
# Ensure pad_token_id exists (some models crash without it during init).
hf_config = AutoConfig.from_pretrained(
model_name, trust_remote_code=trust_remote_code, token=_hf_token()
Expand Down Expand Up @@ -1209,14 +1209,14 @@ def cleanup_model(model, model_name_str: str):
# PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed)
# ========================================================================
current_phase[0] = 2
if verbose:
print(f"\n{'='*80}")
print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
print(f"{'='*80}\n")

# OPTIMIZATION: Run generation benchmarks first (only bridge in memory)
# Then cleanup bridge before loading HT to reduce peak memory
if should_run_phase(2) and bridge_unprocessed:
if verbose:
print(f"\n{'='*80}")
print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
print(f"{'='*80}\n")
if verbose:
print("Running Phase 2 benchmarks...\n")

Expand Down
Loading