
[BUG][CI] SwitchTransformers and TimmWrapperModel dtype mismatches in bfloat16 inference #45072

@harshaljanjani

Description


System Info

  • transformers version: 5.0.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • huggingface_hub version: 1.3.2
  • safetensors version: 0.7.0
  • accelerate version: 1.12.0
  • Accelerate config: not installed
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
  • GPU type: NVIDIA L4
  • NVIDIA driver version: 550.90.07
  • CUDA version: 12.4

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Switch Transformers:

import torch
from transformers import SwitchTransformersModel

try:
    model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to("cuda").eval()
    input_ids = torch.ones(1, 16, dtype=torch.long, device="cuda")
    output = model(input_ids, decoder_input_ids=input_ids)
    print(output.last_hidden_state.shape)
except Exception as e:
    print(e)

TimmWrapper:

import torch
from transformers import TimmWrapperModel, TimmWrapperConfig

try:
    config = TimmWrapperConfig(architecture="resnet18")
    model = TimmWrapperModel(config).to("cuda", torch.bfloat16).eval()
    pixel_values = torch.randn(1, 3, 224, 224, device="cuda")
    output = model(pixel_values)
    print(output.last_hidden_state.shape)
except Exception as e:
    print(e)

→ Loading "google/switch-base-8" in bfloat16 and running a forward pass crashes with a dtype mismatch in the MoE router's linear layer: "got: float != c10::BFloat16".
→ Instantiating a TimmWrapperModel in bfloat16 on CUDA and passing float32 pixel_values crashes in the first conv layer: "Input type (torch.cuda.FloatTensor) and weight type (CUDABFloat16Type) should be the same".
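Both crashes appear to reduce to the same mechanism. A minimal CPU sketch (an assumption on my part, not the actual SwitchTransformers/Timm code): a float32 activation reaching a bfloat16 layer raises, and casting the activation to the layer's dtype resolves it.

```python
import torch

# A float32 activation hitting a bfloat16 layer: mixed-dtype matmul raises.
layer = torch.nn.Linear(4, 4).to(torch.bfloat16)
x = torch.randn(1, 4)  # float32, as produced upstream

try:
    layer(x)
    raised = False
except RuntimeError:
    raised = True

# Casting the activation to the layer's parameter dtype resolves it.
out = layer(x.to(layer.weight.dtype))
print(raised, out.dtype)
```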

Current Repro Output:

[Screenshots of the two error tracebacks.]

Expected behavior

→ Both models should complete bfloat16 inference successfully.
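Until a proper fix lands inside the models, a possible user-side workaround (a sketch only, with a hypothetical helper `cast_inputs_like`) is to align floating-point inputs with the model's parameter dtype and device before the forward pass:

```python
import torch

# Hypothetical helper: cast floating-point inputs to the model's
# parameter dtype/device; integer tensors (e.g. input_ids) pass through.
def cast_inputs_like(model: torch.nn.Module, *tensors):
    param = next(model.parameters())
    return tuple(
        t.to(device=param.device, dtype=param.dtype) if t.is_floating_point() else t
        for t in tensors
    )

# Stand-in for TimmWrapperModel's first conv layer, on CPU for portability.
model = torch.nn.Conv2d(3, 8, kernel_size=3).to(torch.bfloat16)
(pixel_values,) = cast_inputs_like(model, torch.randn(1, 3, 8, 8))
output = model(pixel_values)
print(pixel_values.dtype, output.shape)
```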

Note to the Reviewers

I see a few unsolicited attempts to fix this issue, even though a PR had already been linked to it. Thank you!
