Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 217 additions & 7 deletions lightx2v/models/networks/wan/infer/transformer_infer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial

import torch
from loguru import logger

from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
Expand All @@ -10,6 +11,21 @@
from .triton_ops import fuse_scale_shift_kernel
from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch, apply_wan_rope_with_torch_naive

try:
from lightx2v_kernel.gemm import (
cutlass_scaled_mxfp8_mm,
cutlass_scaled_mxfp8_mm_residual_gate,
scaled_mxfp8_gelu_quant,
scaled_mxfp8_modulate_quant,
)

_WAN_MXFP8_FFN_IMPORT_ERROR = None
except Exception as exc:
cutlass_scaled_mxfp8_mm, cutlass_scaled_mxfp8_mm_residual_gate = None, None
scaled_mxfp8_gelu_quant = None
scaled_mxfp8_modulate_quant = None
_WAN_MXFP8_FFN_IMPORT_ERROR = exc

torch_device_module = getattr(torch, AI_DEVICE)


Expand Down Expand Up @@ -75,6 +91,160 @@ def rope_wrapper(xq, xk, cos_sin_cache):

self.cos_sin = None

self._mxfp8_fuse_available = self._probe_mxfp8_fuse_availability()

def _probe_mxfp8_fuse_availability(self):
"""Probe once whether MXFP8 fused ops can run on this device.

Returns False (with a warning) if the kernel is unavailable or the GPU
is not SM120/SM120a, so the inference falls back to the non-fused path.
"""
if self.config.get("dit_quant_scheme", "Default") != "mxfp8":
return False
if not torch.cuda.is_available():
logger.warning("MXFP8 fused ops require a CUDA device, falling back to non-fused path")
return False
if cutlass_scaled_mxfp8_mm is None or cutlass_scaled_mxfp8_mm_residual_gate is None or scaled_mxfp8_gelu_quant is None or scaled_mxfp8_modulate_quant is None:
detail = f": {type(_WAN_MXFP8_FFN_IMPORT_ERROR).__name__}: {_WAN_MXFP8_FFN_IMPORT_ERROR}" if _WAN_MXFP8_FFN_IMPORT_ERROR is not None else ""
logger.warning(f"MXFP8 fused ops unavailable, falling back to non-fused path{detail}")
return False
major, minor = torch.cuda.get_device_capability()
if major != 12:
logger.warning(f"MXFP8 fused ops require SM120/SM120a, got SM{major}.{minor}, falling back to non-fused path")
return False
return True

def _use_mxfp8_quant_fuse(self):
return self._mxfp8_fuse_available

def _ensure_mxfp8_quant_fuse_ready(self, phase, *tensors, module_names=(), required_module_attrs=("weight", "weight_scale", "alpha")):
if not self._use_mxfp8_quant_fuse():
return
for tensor in tensors:
if tensor is None:
continue
if not tensor.is_cuda:
raise RuntimeError("mxfp8_quant_fuse expects CUDA activations")
device_tensor = next((tensor for tensor in tensors if tensor is not None), None)
if device_tensor is None:
raise RuntimeError("mxfp8_quant_fuse requires at least one CUDA tensor for device validation")
major, _minor = torch.cuda.get_device_capability(device_tensor.device)
if major != 12:
raise RuntimeError("mxfp8_quant_fuse is only enabled on SM120/SM120a GPUs")
if cutlass_scaled_mxfp8_mm is None or cutlass_scaled_mxfp8_mm_residual_gate is None or scaled_mxfp8_gelu_quant is None or scaled_mxfp8_modulate_quant is None:
detail = f": {type(_WAN_MXFP8_FFN_IMPORT_ERROR).__name__}: {_WAN_MXFP8_FFN_IMPORT_ERROR}" if _WAN_MXFP8_FFN_IMPORT_ERROR is not None else ""
raise RuntimeError(f"mxfp8_quant_fuse requires lightx2v_kernel with MXFP8 fused quant ops{detail}")
for name in module_names:
module = getattr(phase, name)
if getattr(module, "has_lora_branch", False) or getattr(module, "has_diff", False):
raise RuntimeError(f"mxfp8_quant_fuse does not support active LoRA/diff on {name}")
if not all(hasattr(module, attr) for attr in required_module_attrs):
raise RuntimeError(f"mxfp8_quant_fuse expects {name} to be an MXFP8 quantized weight module")

def _ensure_mxfp8_quant_ffn_ready(self, phase, norm2_out, residual, c_gate_msa=None, c_scale_msa=None, c_shift_msa=None):
if not self._use_mxfp8_quant_fuse():
return
if (c_scale_msa is None) != (c_shift_msa is None):
raise RuntimeError("MXFP8 FFN modulate-quant readiness requires both c_scale_msa and c_shift_msa")
extra_tensors = []
self._ensure_mxfp8_quant_fuse_ready(
phase,
norm2_out,
residual,
c_scale_msa,
c_shift_msa,
module_names=("ffn_0", "ffn_2"),
required_module_attrs=("act_quant_func", "weight", "weight_scale", "alpha"),
)
if c_gate_msa is None:
raise RuntimeError("mxfp8_quant_fuse requires c_gate_msa for residual-gate fusion")
extra_tensors.append(c_gate_msa)
if extra_tensors:
self._ensure_mxfp8_quant_fuse_ready(phase, *extra_tensors)

def _can_use_mxfp8_modulate_quant(self, norm2_out, c_scale_msa, c_shift_msa):
    """Report whether the fused modulate+quantize kernel accepts these inputs.

    Returns ``True`` only if all of the following hold:

    * ``scaled_mxfp8_modulate_quant`` is not ``None`` and the fused path is
      enabled;
    * ``self.sensitive_layer_dtype`` equals ``self.infer_dtype``;
    * all three tensors are ``torch.bfloat16``, on CUDA, and on one device;
    * ``norm2_out`` is 2-D and contiguous;
    * scale and shift each hold either ``norm2_out.shape[1]`` elements or as
      many elements as ``norm2_out`` itself.
    """
    if scaled_mxfp8_modulate_quant is None:
        return False
    if not self._use_mxfp8_quant_fuse():
        return False
    if self.sensitive_layer_dtype != self.infer_dtype:
        return False

    operands = (norm2_out, c_scale_msa, c_shift_msa)
    if any(operand.dtype != torch.bfloat16 for operand in operands):
        return False
    if not all(operand.is_cuda for operand in operands):
        return False
    if any(norm2_out.device != other.device for other in (c_scale_msa, c_shift_msa)):
        return False
    if norm2_out.dim() != 2 or not norm2_out.is_contiguous():
        return False

    tokens, hidden = norm2_out.shape[0], norm2_out.shape[1]
    accepted_sizes = (hidden, tokens * hidden)
    scale_ok = c_scale_msa.numel() in accepted_sizes
    return scale_ok and c_shift_msa.numel() in accepted_sizes

def _can_reuse_self_attn_mxfp8_quant(self, phase, norm1_out, scale_msa, shift_msa):
    """Report whether one fused modulate+quantize result can feed Q, K and V.

    Requires ``cutlass_scaled_mxfp8_mm`` to be available, the inputs to pass
    ``_can_use_mxfp8_modulate_quant``, and each of ``phase.self_attn_q``,
    ``phase.self_attn_k`` and ``phase.self_attn_v`` to have no truthy
    ``has_lora_branch`` / ``has_diff`` and to expose ``weight``,
    ``weight_scale`` and ``alpha``.
    """
    if cutlass_scaled_mxfp8_mm is None:
        return False
    if not self._can_use_mxfp8_modulate_quant(norm1_out, scale_msa, shift_msa):
        return False

    def _is_plain_mxfp8(module):
        # A module with an active LoRA branch or diff is rejected.
        if getattr(module, "has_lora_branch", False) or getattr(module, "has_diff", False):
            return False
        return all(hasattr(module, attr) for attr in ("weight", "weight_scale", "alpha"))

    projections = ("self_attn_q", "self_attn_k", "self_attn_v")
    return all(_is_plain_mxfp8(getattr(phase, name)) for name in projections)

def _mxfp8_quant_bias(self, module):
if hasattr(module, "_get_actual_bias"):
return module._get_actual_bias()
return module.bias if hasattr(module, "bias") else None

def _mxfp8_apply(self, module, input_tensor):
input_tensor_quant, input_tensor_scale = module.act_quant_func(input_tensor)
return self._mxfp8_apply_quantized(module, input_tensor_quant, input_tensor_scale)

def _mxfp8_apply_quantized(self, module, input_tensor_quant, input_tensor_scale):
    """Run ``cutlass_scaled_mxfp8_mm`` on an already-quantized activation.

    ``module.alpha`` is first moved to the weight's device if the two differ
    (the moved tensor is stored back on ``module``). The bias comes from
    ``_mxfp8_quant_bias``.
    """
    weight = module.weight
    if module.alpha.device != weight.device:
        # Store the moved tensor back on the module so later calls see it.
        module.alpha = module.alpha.to(weight.device)

    weight_scale = module.weight_scale
    alpha = module.alpha
    bias = self._mxfp8_quant_bias(module)
    return cutlass_scaled_mxfp8_mm(
        input_tensor_quant,
        weight,
        input_tensor_scale,
        weight_scale,
        alpha=alpha,
        bias=bias,
    )

def _mxfp8_apply_residual_gate(self, module, input_tensor, residual, gate):
input_tensor_quant, input_tensor_scale = module.act_quant_func(input_tensor)
return self._mxfp8_apply_residual_gate_quantized(module, input_tensor_quant, input_tensor_scale, residual, gate)

def _mxfp8_apply_residual_gate_quantized(self, module, input_tensor_quant, input_tensor_scale, residual, gate):
    """Run ``cutlass_scaled_mxfp8_mm_residual_gate`` on a quantized activation.

    ``module.alpha`` is first moved to the weight's device if the two differ
    (the moved tensor is stored back on ``module``). ``residual``, ``gate``
    and the bias from ``_mxfp8_quant_bias`` are passed as keyword arguments.
    """
    weight = module.weight
    if module.alpha.device != weight.device:
        # Store the moved tensor back on the module so later calls see it.
        module.alpha = module.alpha.to(weight.device)

    weight_scale = module.weight_scale
    alpha = module.alpha
    bias = self._mxfp8_quant_bias(module)
    return cutlass_scaled_mxfp8_mm_residual_gate(
        input_tensor_quant,
        weight,
        input_tensor_scale,
        weight_scale,
        alpha=alpha,
        residual=residual,
        gate=gate,
        bias=bias,
    )

def _infer_ffn_with_mxfp8_quant_fuse(self, phase, norm2_out, residual, c_gate_msa=None, c_scale_msa=None, c_shift_msa=None):
    """Run the FFN through the fused MXFP8 kernels.

    Steps:

    1. Validate inputs via ``_ensure_mxfp8_quant_ffn_ready``.
    2. First projection (``phase.ffn_0``): when both scale and shift are given
       and ``_can_use_mxfp8_modulate_quant`` agrees, modulate and quantize
       ``norm2_out`` in one kernel; otherwise quantize ``norm2_out`` as-is.
    3. Apply the fused GELU+quantize kernel to the projection output.
    4. Second projection (``phase.ffn_2``) through the residual-gate GEMM with
       ``residual`` and the squeezed ``c_gate_msa``.

    Returns:
        Always ``None``; the residual-gate GEMM's return value is discarded.
    """
    self._ensure_mxfp8_quant_ffn_ready(phase, norm2_out, residual, c_gate_msa, c_scale_msa, c_shift_msa)

    has_modulation = c_scale_msa is not None and c_shift_msa is not None
    if has_modulation and self._can_use_mxfp8_modulate_quant(norm2_out, c_scale_msa, c_shift_msa):
        act_quant, act_scale = scaled_mxfp8_modulate_quant(norm2_out, c_scale_msa, c_shift_msa)
        hidden = self._mxfp8_apply_quantized(phase.ffn_0, act_quant, act_scale)
    else:
        hidden = self._mxfp8_apply(phase.ffn_0, norm2_out)

    hidden_quant, hidden_scale = scaled_mxfp8_gelu_quant(hidden)

    # NOTE(review): the result is dropped, so the kernel presumably updates
    # `residual` in place - confirm against the kernel implementation.
    out_proj = phase.ffn_2
    gate = c_gate_msa.squeeze()
    self._mxfp8_apply_residual_gate_quantized(out_proj, hidden_quant, hidden_scale, residual, gate)
    return None

@torch.no_grad()
def reset_post_adapter_states(self):
    """No-op: this implementation performs no work and returns ``None``."""
Expand Down Expand Up @@ -149,7 +319,7 @@ def infer_block(self, block, x, pre_infer_out):
y_out,
gate_msa,
)
y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa, c_gate_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
if hasattr(block.compute_phases[2], "after_proj"):
pre_infer_out.adapter_args["hints"].append(block.compute_phases[2].after_proj.apply(x))
Expand All @@ -175,6 +345,8 @@ def pre_process(self, modulation, embed0):

def infer_self_attn(self, phase, x, shift_msa, scale_msa):
cos_sin = self.cos_sin
norm1_quant = None
norm1_scale = None
if hasattr(phase, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
Expand All @@ -186,22 +358,40 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa):
norm1_out = phase.norm1.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.sensitive_layer_dtype)
norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze()
if self._use_mxfp8_quant_fuse():
self._ensure_mxfp8_quant_fuse_ready(
phase,
norm1_out,
scale_msa,
shift_msa,
module_names=("self_attn_q", "self_attn_k", "self_attn_v"),
)
if self._can_reuse_self_attn_mxfp8_quant(phase, norm1_out, scale_msa, shift_msa):
norm1_quant, norm1_scale = scaled_mxfp8_modulate_quant(norm1_out, scale_msa, shift_msa)
else:
norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze()

if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.infer_dtype)

s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d)
k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
if norm1_quant is not None:
q = phase.self_attn_norm_q.apply(self._mxfp8_apply_quantized(phase.self_attn_q, norm1_quant, norm1_scale)).view(s, n, d)
k = phase.self_attn_norm_k.apply(self._mxfp8_apply_quantized(phase.self_attn_k, norm1_quant, norm1_scale)).view(s, n, d)
v = self._mxfp8_apply_quantized(phase.self_attn_v, norm1_quant, norm1_scale).view(s, n, d)
else:
q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d)
k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
q, k = self.apply_rope_func(q, k, cos_sin)
img_qkv_len = q.shape[0]
if self.self_attn_cu_seqlens_qkv is None:
self.self_attn_cu_seqlens_qkv = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32)

if self.clean_cuda_cache:
del norm1_out, shift_msa, scale_msa
if norm1_quant is not None:
del norm1_quant, norm1_scale
torch_device_module.empty_cache()

attn_running_args = {
Expand Down Expand Up @@ -310,13 +500,15 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa):
torch_device_module.empty_cache()
return x, attn_out

def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa, c_gate_msa=None):
x.add_(attn_out)

if self.clean_cuda_cache:
del attn_out
torch_device_module.empty_cache()

mxfp8_modulate_scale = None
mxfp8_modulate_shift = None
if hasattr(phase, "smooth_norm2_weight"):
norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor
norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor
Expand All @@ -328,11 +520,27 @@ def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
norm2_out = phase.norm2.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
norm2_out = norm2_out.to(self.sensitive_layer_dtype)
norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze()
if self._use_mxfp8_quant_fuse():
self._ensure_mxfp8_quant_ffn_ready(phase, norm2_out, x, c_gate_msa, c_scale_msa, c_shift_msa)
if self._can_use_mxfp8_modulate_quant(norm2_out, c_scale_msa, c_shift_msa):
mxfp8_modulate_scale = c_scale_msa
mxfp8_modulate_shift = c_shift_msa
else:
norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze()

if self.sensitive_layer_dtype != self.infer_dtype:
norm2_out = norm2_out.to(self.infer_dtype)

if self._use_mxfp8_quant_fuse():
return self._infer_ffn_with_mxfp8_quant_fuse(
phase,
norm2_out,
x,
c_gate_msa,
c_scale_msa=mxfp8_modulate_scale,
c_shift_msa=mxfp8_modulate_shift,
)

y = phase.ffn_0.apply(norm2_out)
if self.clean_cuda_cache:
del norm2_out, x
Expand All @@ -345,6 +553,8 @@ def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
return y

def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
if y is None:
return x
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
else:
Expand Down
15 changes: 15 additions & 0 deletions lightx2v_kernel/csrc/common_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
" Tensor! output_scale) -> ()");
m.impl("scaled_mxfp8_quant_sm120", torch::kCUDA, &scaled_mxfp8_quant_sm120);

m.def(
"scaled_mxfp8_gelu_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
m.impl("scaled_mxfp8_gelu_quant_sm120", torch::kCUDA, &scaled_mxfp8_gelu_quant_sm120);

m.def(
"scaled_mxfp8_modulate_quant_sm120(Tensor! output, Tensor! input, Tensor scale, Tensor shift,"
" Tensor! output_scale) -> ()");
m.impl("scaled_mxfp8_modulate_quant_sm120", torch::kCUDA, &scaled_mxfp8_modulate_quant_sm120);

m.def(
"scaled_mxfp6_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
Expand All @@ -46,6 +56,11 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_sm120);

m.def(
"cutlass_scaled_mxfp8_mm_residual_gate_sm120(Tensor! residual, Tensor mat_a, Tensor mat_b, Tensor scales_a, "
"Tensor scales_b, Tensor alpha, Tensor? bias, Tensor gate) -> ()");
m.impl("cutlass_scaled_mxfp8_mm_residual_gate_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_residual_gate_sm120);

}

REGISTER_EXTENSION(common_ops)
Loading
Loading