System Info
- `transformers` version: 4.57.1
- Platform: Linux-5.15.0-161-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.6.2
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cpu (NA)
- Tensorflow version (GPU?): 2.18.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", dtype="auto", low_cpu_mem_usage=True)
# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt")
# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
Expected behavior
Title: Cannot run inference on LLaVA-Next with transformers==4.57.1 when loading with dtype="auto"
I am encountering an issue when running inference on LLaVA-Next models (e.g., llava-hf/llava-v1.6-mistral-7b-hf) with transformers==4.57.1 after loading the model with dtype="auto".
The issue stems from the model's config.json having different torch_dtype values for the overall model and the text configuration:
"text_config": {
"_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
// ... other config values
"torch_dtype": "bfloat16",
"vocab_size": 32064
},
"torch_dtype": "float16",
When the model is loaded with dtype="auto", each submodule (the visual model and the text model) seems to load with its respective torch_dtype ("float16" and "bfloat16").
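As a quick sanity check (my own diagnostic, not part of the failing script above), counting the parameter dtypes of the loaded model makes the split visible:
from collections import Counter
# Diagnostic sketch: after loading with dtype="auto" as in the reproduction,
# count how many parameters ended up in each dtype. A consistent load would
# show a single entry; here I would expect a mix of bfloat16 and float16.
dtype_counts = Counter(p.dtype for p in model.parameters())
print(dtype_counts)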
This difference in data types then causes an error during inference, specifically within the forward pass of the LlavaNextForConditionalGeneration model:
File "MY_ENV/.venv/lib/python3.10/site-packages/transformers/models/llava_next/modeling_llava_next.py", line 687, in forward
logits = self.lm_head(hidden_states[:, slice_indices, :])
File "MY_ENV/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "MY_ENV/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "MY_ENV/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected m1 and m2 to have the same dtype, but got: c10::BFloat16 != c10::Half
This RuntimeError indicates a dtype mismatch, likely between the linear layer's weight (from self.lm_head) and the input tensor (hidden_states), which results from the different dtypes loaded by dtype="auto" for self.lm_head and self.model.
Is there a plan to support loading LLaVA-Next models with dtype="auto" given their current configuration structure?
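A possible workaround (my assumption, not verified as the intended fix) is to pin an explicit dtype instead of "auto", which casts all submodules to a single type:
# Workaround sketch, not a confirmed fix: load with an explicit dtype so the
# vision tower, the language model and lm_head all share the same precision.
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    dtype=torch.float16,  # or torch.bfloat16
    low_cpu_mem_usage=True,
)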