Description
TransformerLens exhibits device/dtype mismatches that cause failures in several testing contexts, including half-precision (float16/bfloat16) inference and multi-GPU scenarios. These issues manifest as type mismatches in attention operations and device synchronization failures when working with activation caches across devices.
Issue Description
1. Half-Precision Type Mismatches in Attention
When using half-precision dtypes (float16/bfloat16), the attention module encounters type mismatches between tensors that should have matching dtypes. The issue occurs because:
- Attention patterns are cast to `self.cfg.dtype` (which remains float16/bfloat16)
- Value tensors (`v`) may be upcast to float32 during operations for numerical stability
- This creates a dtype mismatch when computing the weighted sum: `pattern @ v`
The current code in abstract_attention.py:
```python
pattern = pattern.to(self.cfg.dtype)  # Cast to config dtype (float16/bfloat16)
pattern = pattern.to(v.device)
z = self.calculate_z_scores(v, pattern)  # v may be float32 here!
```
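The mismatch can be reproduced outside the attention module with plain tensors (the shapes below are arbitrary and chosen only to illustrate the dtype clash):

```python
import torch

# pattern stays in the configured half precision while v has been upcast to float32
pattern = torch.softmax(torch.randn(2, 8, 4, 4), dim=-1).to(torch.float16)
v = torch.randn(2, 8, 4, 16, dtype=torch.float32)  # upcast for numerical stability

z = pattern @ v  # raises a dtype-mismatch RuntimeError (Half vs. Float)
```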
2. Multi-Device Cache Synchronization
In multi-GPU contexts, the TransformerBridge.to() method only moves the underlying HuggingFace model; it does not synchronize the configuration or ensure that all bridge components are relocated. This causes:
- `cfg.device` becoming stale and no longer reflecting actual tensor locations
- Activation cache tensors remaining on the wrong devices
- Device mismatch errors in multi-GPU workflows
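A minimal sketch of the kind of synchronization a corrected move could perform; the attribute name `original_model` is an assumption here, and the actual fix may look different:

```python
import torch

def move_bridge(bridge, device: str | torch.device):
    """Relocate a TransformerBridge and keep its config in sync (illustrative sketch only)."""
    # Move the wrapped HuggingFace model, as bridge.to() currently does.
    bridge.original_model.to(device)
    # Also record the new device in the config so downstream code (e.g. activation
    # caching) places tensors where the weights actually live.
    bridge.cfg.device = str(device)
    return bridge
```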
3. Output Projection Device Mismatches
The output projection in attention applies weights that may be on different devices than the input activations, particularly in load balancing scenarios or when components are moved independently.
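A defensive pattern that sidesteps this, sketched with simplified flat shapes rather than TransformerLens's per-head layout:

```python
import torch

def project_out(z: torch.Tensor, W_O: torch.Tensor, b_O: torch.Tensor) -> torch.Tensor:
    # z: [..., d_head_total], W_O: [d_head_total, d_model], b_O: [d_model] (simplified shapes)
    # Move the projection parameters to wherever the activations live before applying them,
    # so independently relocated components cannot cause a device (or dtype) mismatch.
    W_O = W_O.to(device=z.device, dtype=z.dtype)
    b_O = b_O.to(device=z.device, dtype=z.dtype)
    return z @ W_O + b_O
```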
Test Failures
The following test failures were observed on the dev-3.x-folding branch:
FAILED tests/acceptance/test_multi_gpu.py::test_cache_device
FAILED tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
FAILED tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
FAILED tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
FAILED tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
FAILED tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
FAILED tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
FAILED tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
Impact
These synchronization issues affect:
- Mixed-precision workflows: Users working with float16/bfloat16 for memory efficiency (see the snippet after this list)
- Multi-GPU scenarios: Distribution of models across multiple devices for large model inference
- Dynamic device movement: Runtime model relocation between CPU/GPU or across GPUs
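For reference, the mixed-precision path is exercised as soon as a model is loaded in half precision, roughly as the failing half-precision tests do (the exact test code may differ):

```python
import torch
from transformer_lens import HookedTransformer

# Load GPT-2 in float16 for memory efficiency; bfloat16 behaves the same way.
model = HookedTransformer.from_pretrained("gpt2", dtype=torch.float16)

# A forward pass with activation caching hits the attention pattern @ v path described above.
logits, cache = model.run_with_cache("Hello, world!")
```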
System Info
* CUDA:
- GPU:
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 2070 SUPER
- available: True
- version: 12.8
* Packages:
- circuit_tracer: 0.1.0
- datasets: 4.4.1
- finetuning_scheduler: 2.9.1
- interpretune: 0.1.0.dev249+g84f4d5a9a.d20251129
- lightning: 2.5.6
- neuronpedia: 1.2.0
- numpy: 2.3.5
- sae_lens: 6.22.3
- torch: 2.9.1+cu128
- torch_debug_mode: False
- torch_git_version: 5811a8d7da873dd699ff6687092c225caffcf1bb
- tqdm: 4.67.1
- transformer_lens: 0.0.0 # using transformer_lens from source, latest commit on dev-3.x-folding
- transformers: 4.57.1
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.12.8
- version: #13~22.04.1-Ubuntu SMP Wed Jan 24 23:39:40 UTC 2024
Checklist
- I have checked that there is no similar issue in the repo (required)
I'll be submitting a PR shortly to address these issues. 🚀