Skip to content

Conversation

@arrdel
Copy link

@arrdel arrdel commented Dec 6, 2025

What does this PR do?

Fixes #12719

This PR fixes a critical issue where using bitsandbytes quantization with device_map='balanced' (or other device_map strategies) on transformers models within diffusers pipelines results in a meta tensor error: NotImplementedError: Cannot copy out of meta tensor; no data!

Root Cause

When loading transformers models with both:

  • quantization_config (bitsandbytes 4-bit/8-bit)
  • device_map (especially 'balanced' for multi-GPU)

The combination of low_cpu_mem_usage=True (default) and device_map causes transformers to use meta tensors for memory-efficient loading. However, bitsandbytes quantization state objects cannot be materialized from meta device.

The error occurs because:

  1. With low_cpu_mem_usage=True and device_map, transformers uses meta tensors as placeholders
  2. During quantization, bitsandbytes creates quantization state (code, absmax tensors) on meta device
  3. When accelerate's AlignDevicesHook tries to move parameters to target devices via quant_state.to(device)
  4. The quantization state's tensors are still meta and cannot be copied/moved

Solution

Disable low_cpu_mem_usage when loading transformers models with bitsandbytes quantization (llm_int8, fp4, nf4) and device_map. This ensures tensors are materialized during loading rather than kept as meta placeholders, allowing quantization state to be properly moved to target devices.

Changes

  • Modified _load_sub_model in pipeline_loading_utils.py to detect bitsandbytes quantization + device_map combinations
  • Added logic to set low_cpu_mem_usage=False for these cases
  • Added informative logging when this workaround is applied
  • Added comprehensive documentation explaining the issue

Testing

This fix allows the exact code from issue #12719 to work correctly:

import diffusers, torch
qwen = diffusers.QwenImagePipeline.from_pretrained(
    'Qwen/Qwen-Image',
    quantization_config=diffusers.PipelineQuantizationConfig(
        quant_backend='bitsandbytes_4bit',
        quant_kwargs={'load_in_4bit':True, 'bnb_4bit_quant_type':'nf4', 'bnb_4bit_compute_dtype':torch.float16},
        components_to_quantize=['transformer', 'text_encoder']
    ),
    torch_dtype=torch.float16,
    device_map='balanced'
)

Impact

  • ✅ Enables multi-GPU quantized inference with device_map strategies
  • ✅ Maintains backward compatibility (only affects bitsandbytes + device_map case)
  • ✅ No performance regression for other quantization methods
  • ⚠️ Slightly higher memory usage during loading for affected cases (necessary tradeoff)

cc @yiyixuxu @DN6

…e_map

Fixes huggingface#12719

When loading transformers models with both bitsandbytes quantization
(via quantization_config) and device_map (especially 'balanced' for
multi-GPU), the combination of low_cpu_mem_usage=True and device_map
causes meta tensors to be used for memory-efficient loading. However,
bitsandbytes quantization state objects (containing code and absmax
tensors) cannot be materialized from meta device, resulting in:
'NotImplementedError: Cannot copy out of meta tensor; no data!'

This occurs because:
1. With low_cpu_mem_usage=True and device_map, transformers uses meta
   tensors as placeholders for lazy weight loading
2. During quantization, bitsandbytes creates quantization state with
   meta tensors
3. When accelerate's AlignDevicesHook tries to move parameters to
   target devices, it calls quant_state.to(device)
4. The quantization state's code/absmax tensors are still meta and
   cannot be copied/moved

The fix: Disable low_cpu_mem_usage when loading transformers models
with bitsandbytes quantization (llm_int8, fp4, nf4) and device_map.
This ensures tensors are materialized during loading, not kept as
meta placeholders, allowing quantization state to be properly moved
to target devices.

This allows users to successfully use quantization with device_map
strategies like 'balanced' or 'auto' for multi-GPU inference without
encountering meta tensor errors.
Copilot AI review requested due to automatic review settings December 6, 2025 02:39
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR attempts to fix a critical issue where using bitsandbytes quantization with device_map on transformers models results in a meta tensor error (NotImplementedError: Cannot copy out of meta tensor; no data!). The solution disables low_cpu_mem_usage when loading transformers models with bitsandbytes quantization and device_map to ensure tensors are materialized during loading.

Key Changes:

  • Added logic to detect bitsandbytes quantization with device_map combinations
  • Sets low_cpu_mem_usage=False for these cases to avoid meta tensor issues
  • Added informative logging and documentation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

and hasattr(model_quant_config, "quant_method")
):
quant_method = getattr(model_quant_config.quant_method, "value", model_quant_config.quant_method)
if quant_method in ["llm_int8", "fp4", "nf4"]: # bitsandbytes quantization methods
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for detecting bitsandbytes quantization is incorrect. The quant_method attribute is set to QuantizationMethod.BITS_AND_BYTES (which has the value "bitsandbytes"), not to the specific quantization method strings ["llm_int8", "fp4", "nf4"].

This condition will always be False, meaning the workaround will never be applied.

The fix should check if quant_method equals "bitsandbytes" (or QuantizationMethod.BITS_AND_BYTES):

quant_method = getattr(model_quant_config.quant_method, "value", model_quant_config.quant_method)
if quant_method == "bitsandbytes":  # or quant_method == QuantizationMethod.BITS_AND_BYTES

Alternatively, if you want to check the specific quantization type, you should call the quantization_method() method instead:

if hasattr(model_quant_config, "quantization_method"):
    quant_method = model_quant_config.quantization_method()
    if quant_method in ["llm_int8", "fp4", "nf4"]:

Reference: The quant_method attribute is defined in BitsAndBytesConfig.__init__ at line 248 of quantization_config.py as self.quant_method = QuantizationMethod.BITS_AND_BYTES. The specific method names are returned by the quantization_method() method (lines 365-377).

Suggested change
if quant_method in ["llm_int8", "fp4", "nf4"]: # bitsandbytes quantization methods
if quant_method == "bitsandbytes": # bitsandbytes quantization

Copilot uses AI. Check for mistakes.
)
if model_quant_config is not None:
loading_kwargs["quantization_config"] = model_quant_config

Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trailing whitespace detected. Please remove the trailing spaces on this line.

Suggested change

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

how to use quantization and device_map=balance to run qwen-image on kaggle T4 * 2

1 participant