Finally FLUX NVFP4 quantization working #782
base: main
Conversation
Important: Review skipped. Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI or the repository configuration to change this or to disable this status message.

📝 Walkthrough

Adds support for loading quantized models via .safetensors files with optional local component paths (VAE, text encoders). Introduces a new SafeTensors export workflow with quantization metadata. Extends ModelConfig, ExportConfig, and ExportManager with relevant parameters. Updates the CLI to accept component path arguments.
Sequence Diagrams

```mermaid
sequenceDiagram
    participant CLI as User/CLI
    participant Config as ModelConfig
    participant Loader as Pipeline Loader
    participant HF as HuggingFace
    participant Local as Local Components
    participant Pipeline as Diffusion Pipeline
    CLI->>Config: Parse .safetensors path + component paths
    Config->>Loader: Initialize with paths
    alt .safetensors detected
        Loader->>HF: Load base model
        Loader->>Local: Load VAE from vae_path
        Loader->>Local: Load text_encoder from path
        Loader->>Local: Load text_encoder_2 from path (Flux/SD3)
        Local-->>Loader: Component objects
    else Standard load
        Loader->>HF: Load full pipeline
    end
    Loader->>Pipeline: Assemble components
    Pipeline-->>CLI: Ready pipeline
```
```mermaid
sequenceDiagram
    participant PM as ExportManager
    participant Quantizer as save_quantized_safetensors
    participant StateDict as State Dict Handler
    participant Quant as Quantization Logic
    participant ST as SafeTensors Writer
    participant Output as Output File
    PM->>Quantizer: Call with backbone + quant_format
    Quantizer->>StateDict: Extract quantizable layers
    alt Flux-style detected
        StateDict->>StateDict: Apply Flux key mapping
    end
    StateDict->>Quant: Build quantization targets
    alt nvfp4 format
        Quant->>Quant: Per-tensor/block scaling
    else float8_e4m3fn format
        Quant->>Quant: E4M3FN quantization
    end
    Quant->>StateDict: Collect metadata per layer
    StateDict->>ST: Prepare weights + metadata
    ST->>Output: Write SafeTensors file
    Output-->>PM: Export complete
```
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~70 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/diffusers/quantization/quantize.py (1)
1562-1568: Logic error: comparing a path with itself.

Lines 1565-1566 compare `export_config.restore_from` with itself using `samefile()`, which always returns `True`. The condition `not export_config.restore_from.samefile(export_config.restore_from)` is therefore always `False`, so `save_checkpoint` is never called when restoring. The intent appears to be checking whether the restore path differs from the save path.
🐛 Suggested fix
```diff
 if export_config.restore_from and export_config.restore_from.exists():
     export_manager.restore_checkpoint(backbone)
     if export_config.quantized_torch_ckpt_path and not export_config.restore_from.samefile(
-        export_config.restore_from
+        export_config.quantized_torch_ckpt_path
     ):
         export_manager.save_checkpoint(backbone)
```
🤖 Fix all issues with AI agents
In `@examples/diffusers/quantization/quantize.py`:
- Around lines 856-869: The code ignores provided SD3 text encoder paths
(self.config.text_encoder_path and self.config.text_encoder_2_path) and silently
falls back to HF via StableDiffusion3Pipeline.from_pretrained; update the logic
to either validate and reject unsupported local paths or load them: the quick
fix is to raise a clear exception when model_type is SD3 (ModelType.SD3_MEDIUM
or SD3_3_5) and any of text_encoder_path/text_encoder_2_path is set, e.g. raise
ValueError with a message explaining SD3 requires three text encoders and local
paths are not supported; alternatively implement loading of local text encoders
and pass them into StableDiffusion3Pipeline.from_pretrained (replace passing
transformer only) so the provided paths are honored.
- Around lines 746-773: The state_dict loaded with load_file may not be deleted
if text_encoder_2.load_state_dict(...) raises, leaking memory; wrap the
load_state_dict and del state_dict in a try/finally (or move the deletion to a
single finally after both the cached-load try and the fallback except) so that
del state_dict (or state_dict = None) always runs regardless of exceptions —
apply this to the block that creates/loads text_encoder_2 (references:
state_dict, text_encoder_2, text_encoder_2.load_state_dict,
T5Config.from_pretrained).
🧹 Nitpick comments (7)
examples/diffusers/quantization/save_quantized_safetensors.py (3)
212-218: Use the `F8_E4M3_MAX` constant for consistency.

Line 215 uses the hardcoded value `448.0` instead of the `F8_E4M3_MAX` constant defined at line 18. This inconsistency could lead to maintenance issues if the constant value ever needs to change.

♻️ Suggested fix

```diff
- scale = amax.float() / (maxbound * 448.0)
+ scale = amax.float() / (maxbound * F8_E4M3_MAX)
```
243-244: Use constants instead of hardcoded values.

Line 244 uses the hardcoded `448.0 * 6.0` instead of the defined constants `F8_E4M3_MAX * F4_E2M1_MAX`.

♻️ Suggested fix

```diff
- weight_scale_2 = weight_quantizer.amax.float() / (448.0 * 6.0)
+ weight_scale_2 = weight_quantizer.amax.float() / (F8_E4M3_MAX * F4_E2M1_MAX)
```
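For orientation, here is a minimal sketch of how these two constants enter NVFP4's two-level scaling. The constant values (448.0 and 6.0) and the per-tensor formula are taken from the comments above; the 16-element block size and the helper itself are illustrative assumptions, not the code under review.

```python
import torch

F8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn
F4_E2M1_MAX = 6.0    # largest magnitude in FP4 E2M1 (values 0, 0.5, 1, 1.5, 2, 3, 4, 6)

def nvfp4_two_level_scales(weight: torch.Tensor, block_size: int = 16):
    """Sketch of NVFP4 scaling: one per-tensor FP32 scale plus per-block scales
    stored as FP8. Assumes weight.numel() is divisible by block_size."""
    amax = weight.abs().max().float()
    # Second-level (per-tensor) scale, matching the formula in the fix above.
    weight_scale_2 = amax / (F8_E4M3_MAX * F4_E2M1_MAX)
    # First-level (per-block) scales: one amax per contiguous block of values.
    block_amax = weight.reshape(-1, block_size).abs().amax(dim=1).float()
    block_scales = block_amax / F4_E2M1_MAX
    # Block scales are themselves expressed in FP8 relative to weight_scale_2,
    # which is why both format maxima appear in the denominator above.
    block_scales_fp8 = (block_scales / weight_scale_2).clamp(max=F8_E4M3_MAX)
    return weight_scale_2, block_scales_fp8

weight_scale_2, block_scales = nvfp4_two_level_scales(torch.randn(128, 64))
```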
424-438: Consider clarifying the tensor expansion logic.

When `old_weight` is None at lines 426-427, `torch.empty_like(weight)` creates a tensor matching the incoming weight's shape. However, if this is the first weight being concatenated and the final tensor needs to be larger (as handled by lines 428-435), the initial tensor gets replaced anyway. This works, but could be more efficient by directly allocating the correct size when possible.

examples/diffusers/quantization/quantize.py (4)
627-877: Long method with duplicated imports and potential resource-management issues.

This method is ~250 lines and handles multiple model types. Consider:

- Extracting model-specific loading into separate methods (e.g., `_load_flux_components`, `_load_sd3_components`).
- The `load_file` import is repeated multiple times (lines 660, 692, 749-750). Move it to module level.
- Lines 761-773: the `state_dict` deletion in both try/except branches could be cleaner with a `finally` block.
1214-1220: Use pathlib consistently instead of os.path.

Lines 1214-1216 use `os.path.exists()` and `os.path.getsize()` while the rest of the codebase uses `pathlib.Path`. The `ExportConfig` already stores paths as `Path` objects.

♻️ Suggested fix

```diff
- import os
- if os.path.exists(str(self.config.quantized_torch_ckpt_path)):
-     file_size = os.path.getsize(str(self.config.quantized_torch_ckpt_path)) / (1024**3)
+ if self.config.quantized_torch_ckpt_path.exists():
+     file_size = self.config.quantized_torch_ckpt_path.stat().st_size / (1024**3)
```
1521-1542: Consider using the DEBUG log level for verbose configuration logging.

The configuration logging at lines 1521-1527 and 1537-1542 uses `logger.info()` but appears to be debugging output. Consider using `logger.debug()` so it only appears when `--verbose` is enabled.

♻️ Suggested fix

```diff
- logger.info("=" * 80)
- logger.info("🔍 EXPORT CONFIGURATION DEBUG")
- logger.info(f"  args.quantized_torch_ckpt_save_path = {args.quantized_torch_ckpt_save_path}")
- logger.info(f"  Type: {type(args.quantized_torch_ckpt_save_path)}")
- logger.info(f"  Truthy: {bool(args.quantized_torch_ckpt_save_path)}")
- logger.info("=" * 80)
+ logger.debug("=" * 80)
+ logger.debug("EXPORT CONFIGURATION DEBUG")
+ logger.debug(f"  args.quantized_torch_ckpt_save_path = {args.quantized_torch_ckpt_save_path}")
+ logger.debug(f"  Type: {type(args.quantized_torch_ckpt_save_path)}")
+ logger.debug(f"  Truthy: {bool(args.quantized_torch_ckpt_save_path)}")
+ logger.debug("=" * 80)
```
462-487: Extract the hardcoded FLUX transformer configuration into a shared helper to reduce duplication.

The FLUX transformer configuration appears in three places (lines 467, 563, 647) with identical parameters. While the hardcoded values (attention_head_dim=128, num_layers=19, num_attention_heads=24, etc.) are correct for FLUX.1-dev, extracting them into a shared helper will improve maintainability and keep all usages consistent; a sketch follows below.

Additionally, using `strict=False` in `load_state_dict()` can silently ignore mismatched keys. If this is intentional for partial model loading, add a comment explaining why strict validation is disabled.
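One possible shape for that helper, sketched with an invented name; the constructor arguments are the FLUX.1-dev values already used at the three call sites, and the meta-device trick mirrors the existing code:

```python
import torch
from diffusers import FluxTransformer2DModel

def build_flux_dev_transformer() -> FluxTransformer2DModel:
    """Hypothetical helper: builds a FLUX.1-dev transformer skeleton on the
    meta device so every call site shares one copy of the hardcoded config.
    Weights still have to be loaded afterwards via to_empty() + load_state_dict()."""
    with torch.device("meta"):
        return FluxTransformer2DModel(
            attention_head_dim=128,
            guidance_embeds=True,
            in_channels=64,
            joint_attention_dim=4096,
            num_attention_heads=24,
            num_layers=19,
            num_single_layers=38,
            patch_size=1,
            pooled_projection_dim=768,
        )
```

Each call site would then do `transformer = build_flux_dev_transformer()` followed by the existing to_empty / load_state_dict / dtype-cast sequence.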
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/diffusers/quantization/quantize.py
- examples/diffusers/quantization/save_quantized_safetensors.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/diffusers/quantization/quantize.py (1)
examples/diffusers/quantization/save_quantized_safetensors.py (1)
save_quantized_safetensors(645-732)
🔇 Additional comments (12)
examples/diffusers/quantization/save_quantized_safetensors.py (8)
1-18: LGTM! The imports and constants are appropriate for the quantization functionality. The FP4/FP8 max-value constants are correct.
21-94: LGTM! The helper functions for bit manipulation and float conversion are well implemented for the quantization workflow.
104-123: LGTM! The blocked matrix layout conversion is correctly implemented for NVFP4 Tensor Core compatibility.
126-173: LGTM! The NVFP4 quantization function correctly handles per-tensor and block-level scaling with proper zero-value safeguards.
269-403: LGTM! The FLUX key mapping logic is comprehensive and correctly handles the conversion from diffusers naming to ComfyUI naming conventions.
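To give a rough sense of what such a mapping does, here is a prefix-only sketch; the actual table in save_quantized_safetensors.py also renames and fuses individual projections, which this illustration omits.

```python
# Prefix-only illustration of diffusers -> original BFL/ComfyUI FLUX naming.
# This is not the mapping from the PR, just the general naming relationship.
DIFFUSERS_TO_COMFY_PREFIXES = {
    "transformer_blocks.": "double_blocks.",          # dual-stream blocks
    "single_transformer_blocks.": "single_blocks.",   # single-stream blocks
}

def map_flux_prefix(key: str) -> str:
    for src, dst in DIFFUSERS_TO_COMFY_PREFIXES.items():
        if key.startswith(src):
            return dst + key[len(src):]
    return key

print(map_flux_prefix("transformer_blocks.0.attn.to_q.weight"))
# -> "double_blocks.0.attn.to_q.weight" (per-layer suffix renames not shown)
```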
460-503: LGTM! The quantization target collection logic correctly filters and maps quantizable layers with proper handling for overlapping targets.
506-578: LGTM! The quantization functions have good error handling with appropriate fallbacks and logging for debugging.
604-616: Verify the asymmetric quantization logic between formats.

For `nvfp4`, the code calls `_nvfp4_quantize_comfy`, which re-quantizes from the detached weight. For other formats (fp8), it calls `_quantize_weight`, which may use existing quantizer metadata. This asymmetry might be intentional for ComfyUI compatibility, but should be verified.

examples/diffusers/quantization/quantize.py (4)
48-75: LGTM! The new imports are appropriate for supporting local component loading and SafeTensors export.

365-365: Verify the default value for `save_safetensors`.

The `save_safetensors` field defaults to `True`, meaning SafeTensors export will be enabled by default for all users. This is a behavior change that adds additional file writes. Consider whether it should default to `False` to maintain backward compatibility.
1371-1379: LGTM! The new CLI arguments for local component paths are well documented and correctly integrated.
1555-1560: LGTM! The ExportManager is correctly instantiated with quantization format and algorithm metadata.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
```python
if self.config.model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]:
    # Create transformer with hardcoded config to avoid downloading from HuggingFace
    # Use meta device to reduce RAM usage during initialization
    with torch.device('meta'):
        transformer = FluxTransformer2DModel(
            attention_head_dim=128,
            guidance_embeds=True,
            in_channels=64,
            joint_attention_dim=4096,
            num_attention_heads=24,
            num_layers=19,
            num_single_layers=38,
            patch_size=1,
            pooled_projection_dim=768,
        )

    # Load weights from safetensors file
    from safetensors.torch import load_file
    state_dict = load_file(model_path)
    transformer = transformer.to_empty(device='cpu')
    transformer.load_state_dict(state_dict, strict=False, assign=True)
    del state_dict
    transformer = transformer.to(self.config.model_dtype.get("transformer", self.config.model_dtype.get("default")))

    self.pipe.transformer = transformer
```
Potential KeyError when accessing model_dtype dict.
Line 581 accesses `self.config.model_dtype.get("transformer", self.config.model_dtype.get("default"))`. If `model_dtype` is a dict (per the type hint `dict[str, torch.dtype]`), this works. However, there is no fallback if neither key exists: the result would be `None`, which then fails when passed to `.to()`.
🐛 Suggested fix

```diff
- transformer = transformer.to(self.config.model_dtype.get("transformer", self.config.model_dtype.get("default")))
+ dtype = self.config.model_dtype.get("transformer") or self.config.model_dtype.get("default") or torch.bfloat16
+ transformer = transformer.to(dtype)
```

```python
if self.config.text_encoder_2_path:
    self.logger.info(f"Loading T5 text encoder from: {self.config.text_encoder_2_path}")
    # Load T5 from single file
    from safetensors.torch import load_file
    state_dict = load_file(str(self.config.text_encoder_2_path))

    # Try to load config and create model from cached HF model
    try:
        from transformers import T5Config
        config_t5 = T5Config.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="text_encoder_2",
            local_files_only=True,
        )
        text_encoder_2 = T5EncoderModel(config_t5)
        text_encoder_2.load_state_dict(state_dict, strict=False)
        del state_dict  # Free memory
        text_encoder_2 = text_encoder_2.to(self.config.model_dtype.get("text_encoder_2", dtype_default))
    except Exception as e:
        self.logger.warning(f"Could not load T5 config from cache: {e}. Downloading...")
        config_t5 = T5Config.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="text_encoder_2",
        )
        text_encoder_2 = T5EncoderModel(config_t5)
        text_encoder_2.load_state_dict(state_dict, strict=False)
        del state_dict  # Free memory
        text_encoder_2 = text_encoder_2.to(self.config.model_dtype.get("text_encoder_2", dtype_default))
```
Resource cleanup could fail on exception.
If text_encoder_2.load_state_dict(state_dict, strict=False) at line 761 or 771 raises an exception, the del state_dict statement won't execute, potentially leaking memory. Consider using a try/finally pattern.
🐛 Suggested fix
```diff
 if self.config.text_encoder_2_path:
     self.logger.info(f"Loading T5 text encoder from: {self.config.text_encoder_2_path}")
-    # Load T5 from single file
-    from safetensors.torch import load_file
     state_dict = load_file(str(self.config.text_encoder_2_path))
-    # Try to load config and create model from cached HF model
     try:
-        from transformers import T5Config
-        config_t5 = T5Config.from_pretrained(
-            "black-forest-labs/FLUX.1-dev",
-            subfolder="text_encoder_2",
-            local_files_only=True,
-        )
-        text_encoder_2 = T5EncoderModel(config_t5)
-        text_encoder_2.load_state_dict(state_dict, strict=False)
-        del state_dict  # Free memory
-        text_encoder_2 = text_encoder_2.to(self.config.model_dtype.get("text_encoder_2", dtype_default))
-    except Exception as e:
-        self.logger.warning(f"Could not load T5 config from cache: {e}. Downloading...")
-        config_t5 = T5Config.from_pretrained(
-            "black-forest-labs/FLUX.1-dev",
-            subfolder="text_encoder_2",
-        )
-        text_encoder_2 = T5EncoderModel(config_t5)
-        text_encoder_2.load_state_dict(state_dict, strict=False)
-        del state_dict  # Free memory
-        text_encoder_2 = text_encoder_2.to(self.config.model_dtype.get("text_encoder_2", dtype_default))
+        from transformers import T5Config
+        try:
+            config_t5 = T5Config.from_pretrained(
+                "black-forest-labs/FLUX.1-dev",
+                subfolder="text_encoder_2",
+                local_files_only=True,
+            )
+        except Exception as e:
+            self.logger.warning(f"Could not load T5 config from cache: {e}. Downloading...")
+            config_t5 = T5Config.from_pretrained(
+                "black-forest-labs/FLUX.1-dev",
+                subfolder="text_encoder_2",
+            )
+        text_encoder_2 = T5EncoderModel(config_t5)
+        text_encoder_2.load_state_dict(state_dict, strict=False)
+        text_encoder_2 = text_encoder_2.to(self.config.model_dtype.get("text_encoder_2", dtype_default))
+    finally:
+        del state_dict
```

🤖 Prompt for AI Agents
In `@examples/diffusers/quantization/quantize.py`, around lines 746-773: the
state_dict loaded with load_file may not be deleted if
text_encoder_2.load_state_dict(...) raises, leaking memory; wrap the
load_state_dict and del state_dict in a try/finally (or move the deletion to a
single finally after both the cached-load try and the fallback except) so that
del state_dict (or state_dict = None) always runs regardless of exceptions —
apply this to the block that creates/loads text_encoder_2 (references:
state_dict, text_encoder_2, text_encoder_2.load_state_dict,
T5Config.from_pretrained).
```python
# For SD3, we need 3 text encoders - for now, fall back to HF if not all provided
# This is more complex and users typically keep the full model together
if self.config.text_encoder_path and self.config.text_encoder_2_path:
    self.logger.warning("SD3 requires 3 text encoders. Falling back to HuggingFace for text encoders...")

base_model = "stabilityai/stable-diffusion-3-medium" if self.config.model_type == ModelType.SD3_MEDIUM else "stabilityai/stable-diffusion-3.5-medium"

# Load full pipeline and replace transformer
pipe = StableDiffusion3Pipeline.from_pretrained(
    base_model,
    transformer=transformer,
    vae=vae,
    torch_dtype=self.config.model_dtype,
)
```
SD3 text encoder paths are accepted but ignored.
Lines 858-859 log a warning when text encoder paths are provided for SD3, but the code proceeds to load text encoders from HuggingFace anyway. This could confuse users who expect their local paths to be used. Consider either:
- Raising an error if text encoder paths are provided for unsupported models
- Implementing SD3 text encoder loading from local paths
🤖 Prompt for AI Agents
In `@examples/diffusers/quantization/quantize.py`, around lines 856-869: the code
ignores provided SD3 text encoder paths (self.config.text_encoder_path and
self.config.text_encoder_2_path) and silently falls back to HF via
StableDiffusion3Pipeline.from_pretrained; update the logic to either validate
and reject unsupported local paths or load them: the quick fix is to raise a
clear exception when model_type is SD3 (ModelType.SD3_MEDIUM or SD3_3_5) and any
of text_encoder_path/text_encoder_2_path is set, e.g. raise ValueError with a
message explaining SD3 requires three text encoders and local paths are not
supported; alternatively implement loading of local text encoders and pass them
into StableDiffusion3Pipeline.from_pretrained (replace passing transformer only)
so the provided paths are honored.
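A minimal sketch of the "validate and reject" option, assuming the ModelType member names used above (SD3_MEDIUM, SD3_3_5) and an invented helper name; where exactly it would be called is left open.

```python
def validate_sd3_text_encoder_paths(model_type, text_encoder_path=None,
                                     text_encoder_2_path=None) -> None:
    """Hypothetical guard mirroring the quick fix described above: SD3 needs all
    three text encoders, so local single-encoder paths are rejected up front
    instead of being silently ignored."""
    if getattr(model_type, "name", str(model_type)) in ("SD3_MEDIUM", "SD3_3_5") and (
        text_encoder_path or text_encoder_2_path
    ):
        raise ValueError(
            "SD3 requires three text encoders; text_encoder_path / "
            "text_encoder_2_path are not supported for SD3 models. "
            "Remove them to load the text encoders from HuggingFace."
        )
```

The alternative listed in the comment (actually loading the local encoders and passing them into StableDiffusion3Pipeline.from_pretrained) would make a guard like this unnecessary.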
```python
self.logger.info(f"")
self.logger.info("=" * 80)
self.logger.info("🚀 CLAUDE'S FIX: Starting SafeTensors export with metadata!")
self.logger.info(f"   Fix applied: 2026-01-14 @ 14:20 UTC+3")
self.logger.info(f"   Target file: {safetensors_path.name}")
self.logger.info("=" * 80)
self.logger.info(f"📦 Saving SafeTensors format for ComfyUI compatibility...")
```
Remove hardcoded timestamp from log messages.
Lines 1228-1229 contain a hardcoded timestamp "2026-01-14 @ 14:20 UTC+3" which will immediately become stale. This appears to be debug/development logging that should be removed before merging.
🐛 Suggested fix
```diff
 self.logger.info(f"")
 self.logger.info("=" * 80)
-self.logger.info("🚀 CLAUDE'S FIX: Starting SafeTensors export with metadata!")
-self.logger.info(f"   Fix applied: 2026-01-14 @ 14:20 UTC+3")
+self.logger.info("Starting SafeTensors export with ComfyUI metadata")
 self.logger.info(f"   Target file: {safetensors_path.name}")
 self.logger.info("=" * 80)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
self.logger.info(f"")
self.logger.info("=" * 80)
self.logger.info("Starting SafeTensors export with ComfyUI metadata")
self.logger.info(f"   Target file: {safetensors_path.name}")
self.logger.info("=" * 80)
self.logger.info(f"📦 Saving SafeTensors format for ComfyUI compatibility...")
```
```python
logger.info(f"Saving quantized model to SafeTensors: {output_path}")
save_file(state_dict, str(output_path), metadata=metadata)

if output_path.exists():
    file_size_gb = output_path.stat().st_size / (1024**3)
    logger.info("SafeTensors file saved successfully.")
    logger.info(f"Path: {output_path}")
    logger.info(f"Size: {file_size_gb:.2f} GB")
    logger.info(f"Quantized layers: {len(quant_layers)}")
    return True

logger.error("Failed to save SafeTensors file.")
return False
```
Consider raising an exception on save failure instead of returning False.
The function logs an error and returns `False` when the file fails to save (lines 731-732), but this can be silently ignored by callers. The calling code in quantize.py doesn't appear to check the return value, so a failed save could go unnoticed.
🐛 Suggested fix
```diff
     if output_path.exists():
         file_size_gb = output_path.stat().st_size / (1024**3)
         logger.info("SafeTensors file saved successfully.")
         logger.info(f"Path: {output_path}")
         logger.info(f"Size: {file_size_gb:.2f} GB")
         logger.info(f"Quantized layers: {len(quant_layers)}")
         return True
-    logger.error("Failed to save SafeTensors file.")
-    return False
+    raise IOError(f"Failed to save SafeTensors file: {output_path}")
```
It took me more than a day and a massive amount of money to make this work.
I have quantized FLUX SRPO to NVFP4.
Please fix the other models in the same way so that we can quantize Wan 2.2 and Qwen 2512 and use them in ComfyUI.
Summary by CodeRabbit
Release Notes
✏️ Tip: You can customize this high-level summary in your review settings.