From 24abfe44a285b90971109aa3d29eca96fe169d0a Mon Sep 17 00:00:00 2001 From: julio carlos Date: Sun, 22 Mar 2026 19:25:29 +0100 Subject: [PATCH] fix: apply per-tensor weight scales during FP8 dequantization The pre-quantized FP8 checkpoint (ltx-2.3-22b-dev-fp8.safetensors) stores weights in float8_e4m3fn format with per-tensor weight_scale and input_scale factors. When using fp8-cast quantization mode, the upcast path performs a naive .to(bfloat16) without applying these scale factors, producing weight values that are ~770x too large and resulting in noise output instead of coherent video. This commit fixes the issue by: 1. Extracting weight_scale tensors from the state dict before load_state_dict(strict=False) discards them (in SingleGPUModelBuilder) 2. Passing the scales through to _replace_fwd_with_upcast via the model 3. Multiplying dequantized weights by their scale factor during inference The fix is backward-compatible: when no weight_scale tensors are present (e.g. when using fp8-cast to quantize a BF16 checkpoint on the fly), the behavior is unchanged. Fixes noise output when running: python -m ltx_pipelines.ti2vid_two_stages \ --checkpoint-path ltx-2.3-22b-dev-fp8.safetensors \ --quantization fp8-cast ... Related: #165 --- .../loader/single_gpu_model_builder.py | 13 +++++ .../src/ltx_core/quantization/fp8_cast.py | 53 ++++++++++++++++--- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py index 30f432fc..a480aca4 100644 --- a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +++ b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py @@ -90,6 +90,19 @@ def build(self, device: torch.device | None = None, dtype: torch.dtype | None = model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path] model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device) + # Extract per-tensor FP8 weight scales before load_state_dict drops them. + # Pre-quantized FP8 checkpoints include weight_scale tensors that are not + # part of the model's parameter schema, so strict=False silently discards + # them. We stash them on the model so that fp8_cast's upcast forward can + # apply them during inference dequantization. + _fp8_weight_scales: dict[str, float] = {} + for key, value in model_state_dict.sd.items(): + if key.endswith(".weight_scale") and value.numel() == 1: + base_name = key[: -len(".weight_scale")] + _fp8_weight_scales[base_name] = value.item() + if _fp8_weight_scales: + meta_model._fp8_weight_scales = _fp8_weight_scales # type: ignore[attr-defined] + lora_strengths = [lora.strength for lora in self.loras] if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0): sd = model_state_dict.sd diff --git a/packages/ltx-core/src/ltx_core/quantization/fp8_cast.py b/packages/ltx-core/src/ltx_core/quantization/fp8_cast.py index 7452d31c..865e8a0d 100644 --- a/packages/ltx-core/src/ltx_core/quantization/fp8_cast.py +++ b/packages/ltx-core/src/ltx_core/quantization/fp8_cast.py @@ -65,11 +65,27 @@ def _upcast_and_round( return _fused_add_round_launch(torch.zeros_like(weight, dtype=dtype), weight, seed) -def _replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None: +def _replace_fwd_with_upcast( + layer: torch.nn.Linear, + with_stochastic_rounding: bool = False, + seed: int = 0, + weight_scale: float | None = None, +) -> None: """ - Replace linear.forward and rms_norm.forward with a version that: + Replace linear.forward with a version that: - upcasts weight and bias to input's dtype - - returns F.linear or F.rms_norm calculated in that dtype + - applies weight_scale if the checkpoint was quantized with per-tensor scaling + - returns F.linear calculated in that dtype + + Args: + layer: The Linear layer to patch. + with_stochastic_rounding: Whether to use stochastic rounding during upcast. + seed: Seed for stochastic rounding. + weight_scale: Per-tensor scale factor from the FP8 checkpoint. When provided, + the dequantized weight is multiplied by this value. This is required for + FP8 checkpoints that were quantized with per-tensor scaling (e.g. + ``ltx-2.3-22b-dev-fp8.safetensors``) where each weight tensor has an + associated ``weight_scale`` stored alongside it. """ layer.original_forward = layer.forward @@ -78,6 +94,13 @@ def new_linear_forward(*args, **_kwargs) -> torch.Tensor: # assume first arg is the input tensor x = args[0] w_up = _upcast_and_round(layer.weight, x.dtype, with_stochastic_rounding, seed) + + # Apply per-tensor weight scale from FP8 checkpoint if available. + # Without this, pre-quantized FP8 checkpoints produce incorrect outputs + # because the raw FP8 values are not in the correct magnitude range. + if weight_scale is not None: + w_up = w_up * weight_scale + b_up = None if layer.bias is not None: @@ -92,12 +115,28 @@ def _amend_forward_with_upcast( model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0 ) -> torch.nn.Module: """ - Replace the forward method of the model's Linear and RMSNorm layers to forward + Replace the forward method of the model's Linear layers to forward with upcast and optional stochastic rounding. + + If the model was loaded from a pre-quantized FP8 checkpoint that includes + per-tensor ``weight_scale`` values (stashed on the model by the builder as + ``_fp8_weight_scales``), those scales are automatically applied during the + upcast to produce correctly-scaled outputs. + + This is necessary because pre-quantized FP8 checkpoints (e.g. + ``ltx-2.3-22b-dev-fp8.safetensors``) store weights in a scaled FP8 format + where the raw FP8 values must be multiplied by their associated + ``weight_scale`` to recover the correct magnitude. Without this, a naive + ``.to(bfloat16)`` produces values that are orders of magnitude too large, + resulting in noise output. """ - for m in model.modules(): - if isinstance(m, (torch.nn.Linear)): - _replace_fwd_with_upcast(m, with_stochastic_rounding, seed) + # Retrieve per-tensor weight scales stashed by the model builder + weight_scales: dict[str, float] = getattr(model, "_fp8_weight_scales", {}) + + for name, m in model.named_modules(): + if isinstance(m, torch.nn.Linear): + scale = weight_scales.get(name, None) + _replace_fwd_with_upcast(m, with_stochastic_rounding, seed, weight_scale=scale) return model