From 1c9a414ae3517212da8267c5348e1d599b3ee0d4 Mon Sep 17 00:00:00 2001
From: arrdel
Date: Fri, 5 Dec 2025 21:29:15 -0500
Subject: [PATCH] Fix meta tensor error when using bitsandbytes quantization
 with device_map

Fixes #12719

When loading transformers models with both bitsandbytes quantization (via
quantization_config) and a device_map (especially 'balanced' for multi-GPU),
the combination of low_cpu_mem_usage=True and device_map causes meta tensors
to be used for memory-efficient loading. However, bitsandbytes quantization
state objects (containing the code and absmax tensors) cannot be
materialized from the meta device, resulting in:

    NotImplementedError: Cannot copy out of meta tensor; no data!

This occurs because:
1. With low_cpu_mem_usage=True and a device_map, transformers uses meta
   tensors as placeholders for lazy weight loading.
2. During quantization, bitsandbytes creates its quantization state with
   meta tensors.
3. When accelerate's AlignDevicesHook tries to move parameters to their
   target devices, it calls quant_state.to(device).
4. The quantization state's code/absmax tensors are still on the meta
   device and cannot be copied or moved.

The fix: disable low_cpu_mem_usage when loading transformers models with
bitsandbytes quantization (llm_int8, fp4, nf4) and a device_map. This
ensures tensors are materialized during loading rather than kept as meta
placeholders, so the quantization state can be moved to the target devices.

This lets users combine quantization with device_map strategies such as
'balanced' or 'auto' for multi-GPU inference without hitting meta tensor
errors.
---
 .../pipelines/pipeline_loading_utils.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 8868e942ce3d..13d69a34eef6 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -867,6 +867,26 @@ def load_sub_model(
             )
             if model_quant_config is not None:
                 loading_kwargs["quantization_config"] = model_quant_config
+
+            # When using bitsandbytes quantization with a device_map on transformers
+            # models, we must disable low_cpu_mem_usage to avoid meta tensors. Meta
+            # tensors cannot be materialized when bitsandbytes tries to move the
+            # quantization state (which includes tensors like code and absmax) to
+            # the target device: the state is created during model loading and
+            # needs actual tensors, not meta placeholders.
+            # See: https://github.com/huggingface/diffusers/issues/12719
+            if (
+                is_transformers_model
+                and device_map is not None
+                and hasattr(model_quant_config, "quant_method")
+            ):
+                quant_method = getattr(model_quant_config.quant_method, "value", model_quant_config.quant_method)
+                if quant_method in ["llm_int8", "fp4", "nf4"]:  # bitsandbytes quantization methods
+                    loading_kwargs["low_cpu_mem_usage"] = False
+                    logger.info(
+                        f"Disabling low_cpu_mem_usage for {name} because bitsandbytes quantization "
+                        "with device_map requires materialized tensors, not meta tensors."
+                    )
 
     # check if the module is in a subdirectory
     if dduf_entries:
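
Note for reviewers: below is a minimal repro sketch of the failure mode this
patch addresses, runnable on a multi-GPU machine. The model id and the
components_to_quantize list are illustrative choices, not taken from the
issue; any pipeline with a transformers-backed component quantized through
bitsandbytes and loaded with a device_map should behave the same way.

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.quantizers import PipelineQuantizationConfig

    # Quantize transformers-backed components with bitsandbytes nf4 while
    # sharding the pipeline across GPUs. Without this patch, accelerate's
    # AlignDevicesHook raises "NotImplementedError: Cannot copy out of meta
    # tensor; no data!" when it tries to move the bitsandbytes quant_state
    # (code/absmax tensors) off the meta device.
    quant_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer", "text_encoder_2"],  # illustrative
    )

    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative model id
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
        device_map="balanced",  # the multi-GPU placement that triggers the error
    )

With the patch applied, load_sub_model forces low_cpu_mem_usage=False for the
quantized transformers components, so their quantization state tensors are
materialized before AlignDevicesHook moves them to their target devices.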