
[Bug Report] Device and dtype synchronization issues causing test failures in mixed precision and multi-device contexts #1140

@speediedan

Description

TransformerLens exhibits device/dtype mismatches that cause failures in several testing contexts, including half-precision (float16/bfloat16) inference and multi-GPU scenarios. These issues manifest as type mismatches in attention operations and device synchronization failures when working with activation caches across devices.

Issue Description

1. Half-Precision Type Mismatches in Attention

When using half-precision dtypes (float16/bfloat16), the attention module hits dtype mismatches between tensors that must agree before they can be multiplied. The mismatch arises because:

  • Attention patterns are cast to self.cfg.dtype (which remains float16/bfloat16)
  • Value tensors (v) may be upcast to float32 during operations for numerical stability
  • This creates a dtype mismatch when computing weighted sums: pattern @ v

The current code in abstract_attention.py:

```python
pattern = pattern.to(self.cfg.dtype)  # cast to config dtype (float16/bfloat16)
pattern = pattern.to(v.device)
z = self.calculate_z_scores(v, pattern)  # v may have been upcast to float32 here!
```
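One way the mismatch could be resolved is to align the pattern with the value tensor's dtype and device immediately before the weighted sum, rather than pinning it to `self.cfg.dtype`. A minimal sketch (the `weighted_sum` helper is a hypothetical stand-in for the `pattern @ v` step inside `calculate_z_scores`):

```python
import torch

def weighted_sum(pattern: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Align pattern with v's device and dtype before the matmul so a
    # float16/bfloat16 pattern doesn't clash with a float32-upcast v.
    pattern = pattern.to(device=v.device, dtype=v.dtype)
    return pattern @ v

# pattern in the config's half-precision dtype, v upcast for numerical stability
pattern = torch.rand(2, 4, 4, dtype=torch.float16)
v = torch.rand(2, 4, 8, dtype=torch.float32)
z = weighted_sum(pattern, v)  # no dtype mismatch; z inherits v's float32
```

Casting to `v.dtype` (instead of unconditionally back to `self.cfg.dtype`) keeps the stability benefit of the upcast while guaranteeing the matmul operands agree.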

2. Multi-Device Cache Synchronization

In multi-GPU contexts, the TransformerBridge.to() method only moves the underlying HuggingFace model; it neither updates the bridge's configuration nor ensures all bridge components are relocated. As a result:

  • cfg.device becoming stale and not reflecting actual tensor locations
  • Activation cache tensors remaining on wrong devices
  • Device mismatch errors in multi-GPU workflows
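A sketch of the kind of `to()` override that would keep `cfg.device` in sync with the actual parameter locations. The `Bridge` class below is a minimal hypothetical stand-in for `TransformerBridge` (the `cfg` attribute shape is an assumption):

```python
import torch
import torch.nn as nn

class Bridge(nn.Module):
    """Minimal stand-in for TransformerBridge; cfg layout is an assumption."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 4)  # stands in for the wrapped HF model
        self.cfg = type("Cfg", (), {})()
        self.cfg.device = "cpu"

    def to(self, *args, **kwargs):
        # Move every registered submodule, then mirror the result in cfg so
        # cfg.device tracks where the parameters actually live.
        super().to(*args, **kwargs)
        self.cfg.device = str(next(self.parameters()).device)
        return self

bridge = Bridge().to("cpu")
```

Because `nn.Module.to` already recurses into registered submodules, routing the move through `super().to()` and then refreshing `cfg.device` from a live parameter avoids the stale-config state described above.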

3. Output Projection Device Mismatches

The output projection in attention applies weights that may be on different devices than the input activations, particularly in load balancing scenarios or when components are moved independently.
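A defensive fix for this case is a device guard before applying the projection. This is a hedged sketch, not the actual TransformerLens code; the function name, weight name, and shapes are illustrative:

```python
import torch

def apply_output_projection(z: torch.Tensor, w_o: torch.Tensor) -> torch.Tensor:
    # If the projection weight lives on a different device than the activations
    # (e.g. after components were moved independently), bring it over first.
    if w_o.device != z.device:
        w_o = w_o.to(z.device)
    return z @ w_o

z = torch.rand(2, 4, 8)   # [batch, pos, d_head], illustrative shape
w_o = torch.rand(8, 16)   # hypothetical output projection weight
out = apply_output_projection(z, w_o)
```

The guard is a no-op when devices already match, so it costs nothing in the single-device path while making the multi-device path robust.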

Test Failures

The following test failures were observed on the dev-3.x-folding branch:

FAILED tests/acceptance/test_multi_gpu.py::test_cache_device
FAILED tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
FAILED tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
FAILED tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
FAILED tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
FAILED tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
FAILED tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
FAILED tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]

Impact

These synchronization issues affect:

  • Mixed-precision workflows: Users working with float16/bfloat16 for memory efficiency
  • Multi-GPU scenarios: Distribution of models across multiple devices for large model inference
  • Dynamic device movement: Runtime model relocation between CPU/GPU or across GPUs

System Info

* CUDA:
	- GPU:
		- NVIDIA GeForce RTX 4090
		- NVIDIA GeForce RTX 2070 SUPER
	- available:         True
	- version:           12.8
* Packages:
	- circuit_tracer:    0.1.0
	- datasets:          4.4.1
	- finetuning_scheduler: 2.9.1
	- interpretune:      0.1.0.dev249+g84f4d5a9a.d20251129
	- lightning:         2.5.6
	- neuronpedia:       1.2.0
	- numpy:             2.3.5
	- sae_lens:          6.22.3
	- torch:             2.9.1+cu128
	- torch_debug_mode:  False
	- torch_git_version: 5811a8d7da873dd699ff6687092c225caffcf1bb
	- tqdm:              4.67.1
	- transformer_lens:  0.0.0  # using transformer_lens from source, latest commit on dev-3.x-folding
	- transformers:      4.57.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.12.8
	- version:           #13~22.04.1-Ubuntu SMP Wed Jan 24 23:39:40 UTC 2024

Checklist

  • I have checked that there is no similar issue in the repo (required)

I'll be submitting a PR shortly to address these issues. 🚀
