System Info
- transformers version: 5.0.0.dev0
- Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
- Python version: 3.12.3
- huggingface_hub version: 1.3.2
- safetensors version: 0.7.0
- accelerate version: 1.12.0
- Accelerate config: not installed
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
- GPU type: NVIDIA L4
- NVIDIA driver version: 550.90.07
- CUDA version: 12.4
Information

Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Switch Transformers:

```python
import torch
from transformers import SwitchTransformersModel

try:
    model = SwitchTransformersModel.from_pretrained(
        "google/switch-base-8", torch_dtype=torch.bfloat16
    ).to("cuda").eval()
    input_ids = torch.ones(1, 16, dtype=torch.long, device="cuda")
    output = model(input_ids, decoder_input_ids=input_ids)
    print(output.last_hidden_state.shape)
except Exception as e:
    print(e)
```
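The MoE router that crashes is a linear projection, so the error class can be reproduced in isolation. The sketch below is not the transformers code path; the `router` here is a stand-in `nn.Linear` with bfloat16 weights receiving a float32 hidden state, which is the mismatch the traceback reports:

```python
import torch

# Stand-in for the MoE router: a bfloat16 Linear fed a float32 input.
router = torch.nn.Linear(8, 2).to(torch.bfloat16)
hidden = torch.randn(1, 8)  # float32, matching the dtype seen in the crash

try:
    router(hidden)  # mixed-dtype matmul: raises RuntimeError
except RuntimeError as e:
    print(type(e).__name__, "-", e)

# Casting the input to the weight dtype on the caller side avoids the crash:
print(router(hidden.to(router.weight.dtype)).dtype)
```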
TimmWrapper:

```python
import torch
from transformers import TimmWrapperModel, TimmWrapperConfig

try:
    config = TimmWrapperConfig(architecture="resnet18")
    model = TimmWrapperModel(config).to("cuda", torch.bfloat16).eval()
    pixel_values = torch.randn(1, 3, 224, 224, device="cuda")
    output = model(pixel_values)
    print(output.last_hidden_state.shape)
except Exception as e:
    print(e)
```
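The TimmWrapper failure is the same class of error at the conv stem. Again a standalone sketch, not the timm code itself: a bfloat16 `Conv2d` fed float32 pixel values raises the "Input type ... and weight type ... should be the same" error:

```python
import torch

# Stand-in for the conv stem: bfloat16 weights, float32 pixel_values.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(torch.bfloat16)
pixels = torch.randn(1, 3, 8, 8)  # float32

try:
    conv(pixels)  # input/weight dtype mismatch: raises RuntimeError
except RuntimeError as e:
    print(type(e).__name__, "-", e)

# Casting the pixels to the weight dtype avoids the crash:
print(conv(pixels.to(conv.weight.dtype)).dtype)
```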
Current Repro Output:
→ Loading "google/switch-base-8" in bfloat16 and running a forward pass crashes with a dtype mismatch in the MoE router's linear layer: `got: float != c10::BFloat16`.
→ Instantiating a TimmWrapperModel in bfloat16 on CUDA and passing float32 pixel_values crashes; the first conv layer raises `Input type (torch.cuda.FloatTensor) and weight type (CUDABFloat16Type) should be the same`.
Expected behavior
→ Both models should complete bfloat16 inference successfully.
Note to the Reviewers
I've noticed a few unsolicited attempts to fix this issue, even though a PR had already been linked to it. Thank you!