From 94816cc2d760ca06f11601adb22b104e1b4b6552 Mon Sep 17 00:00:00 2001 From: xlycae Date: Sat, 23 May 2026 11:55:31 +0800 Subject: [PATCH] feat: add MXFP8 fused operators for Wan transformer inference on SM120 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement three fused CUDA kernels for MXFP8 quantized inference on Blackwell (SM120): 1. scaled_mxfp8_gelu_quant: fuse GELU activation + E8M0 quantization 2. scaled_mxfp8_modulate_quant: fuse scale/shift modulation + quantization 3. cutlass_scaled_mxfp8_mm_residual_gate: fuse GEMM + residual + gate in CUTLASS 3.x epilogue Performance on RTX 5090 (Wan 5B FFN, m=4096, hidden=1536, ffn=8960): - GELU+Quant: 1.30× faster (27.8μs → 21.3μs) - Modulate+Quant: 3.26× faster (92.7μs → 28.5μs) - GEMM+Residual+Gate: 1.40× faster (194.7μs → 138.9μs) - End-to-end FFN: 1.20× faster (608μs → 505μs, -103μs per block) - Reduces kernel launches from 7 to 3 per FFN block Features: - Supports all Wan tasks (t2v/i2v/flf2v/animate/s2v/rs2v) - Auto-fallback on non-SM120 GPUs (H100/A100/RTX4090) with warning - Handles FP16/BF16 activations (kernel auto-detects dtype) - One-time device capability probe at init (eliminates ~4000 redundant checks per inference) Tested: 10/10 unit tests pass, 6/6 fallback scenarios verified Address review feedback (PR #1090): - Skip alpha device move when already on target device - Extract check_sm120_or_throw to shared header sm120_utils.h - Replace std::cerr with TORCH_CHECK in dtype switch fallbacks - Avoid intermediate BF16 round in residual_gate kernel - Apply ruff-format Co-Authored-By: Claude Opus 4.7 (1M context) --- .../networks/wan/infer/transformer_infer.py | 224 +++++++- lightx2v_kernel/csrc/common_extension.cc | 15 + .../csrc/gemm/mxfp8_quant_kernels_sm120.cu | 506 +++++++++++++++++- .../gemm/mxfp8_scaled_mm_kernels_sm120.cu | 441 +++++++++++++-- lightx2v_kernel/csrc/gemm/sm120_utils.h | 80 +++ lightx2v_kernel/include/lightx2v_kernel_ops.h | 20 + .../python/lightx2v_kernel/gemm.py | 50 ++ .../test/mxfp8_mxfp8/test_fused_ffn.py | 402 ++++++++++++++ 8 files changed, 1681 insertions(+), 57 deletions(-) create mode 100644 lightx2v_kernel/csrc/gemm/sm120_utils.h create mode 100644 lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py index 8daba4808..bff600a5c 100755 --- a/lightx2v/models/networks/wan/infer/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -1,6 +1,7 @@ from functools import partial import torch +from loguru import logger from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.utils.envs import * @@ -10,6 +11,21 @@ from .triton_ops import fuse_scale_shift_kernel from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch, apply_wan_rope_with_torch_naive +try: + from lightx2v_kernel.gemm import ( + cutlass_scaled_mxfp8_mm, + cutlass_scaled_mxfp8_mm_residual_gate, + scaled_mxfp8_gelu_quant, + scaled_mxfp8_modulate_quant, + ) + + _WAN_MXFP8_FFN_IMPORT_ERROR = None +except Exception as exc: + cutlass_scaled_mxfp8_mm, cutlass_scaled_mxfp8_mm_residual_gate = None, None + scaled_mxfp8_gelu_quant = None + scaled_mxfp8_modulate_quant = None + _WAN_MXFP8_FFN_IMPORT_ERROR = exc + torch_device_module = getattr(torch, AI_DEVICE) @@ -75,6 +91,160 @@ def rope_wrapper(xq, xk, cos_sin_cache): self.cos_sin = None + self._mxfp8_fuse_available = self._probe_mxfp8_fuse_availability() + + def _probe_mxfp8_fuse_availability(self): + """Probe once whether MXFP8 fused ops can run on this device. + + Returns False (with a warning) if the kernel is unavailable or the GPU + is not SM120/SM120a, so the inference falls back to the non-fused path. + """ + if self.config.get("dit_quant_scheme", "Default") != "mxfp8": + return False + if not torch.cuda.is_available(): + logger.warning("MXFP8 fused ops require a CUDA device, falling back to non-fused path") + return False + if cutlass_scaled_mxfp8_mm is None or cutlass_scaled_mxfp8_mm_residual_gate is None or scaled_mxfp8_gelu_quant is None or scaled_mxfp8_modulate_quant is None: + detail = f": {type(_WAN_MXFP8_FFN_IMPORT_ERROR).__name__}: {_WAN_MXFP8_FFN_IMPORT_ERROR}" if _WAN_MXFP8_FFN_IMPORT_ERROR is not None else "" + logger.warning(f"MXFP8 fused ops unavailable, falling back to non-fused path{detail}") + return False + major, minor = torch.cuda.get_device_capability() + if major != 12: + logger.warning(f"MXFP8 fused ops require SM120/SM120a, got SM{major}.{minor}, falling back to non-fused path") + return False + return True + + def _use_mxfp8_quant_fuse(self): + return self._mxfp8_fuse_available + + def _ensure_mxfp8_quant_fuse_ready(self, phase, *tensors, module_names=(), required_module_attrs=("weight", "weight_scale", "alpha")): + if not self._use_mxfp8_quant_fuse(): + return + for tensor in tensors: + if tensor is None: + continue + if not tensor.is_cuda: + raise RuntimeError("mxfp8_quant_fuse expects CUDA activations") + device_tensor = next((tensor for tensor in tensors if tensor is not None), None) + if device_tensor is None: + raise RuntimeError("mxfp8_quant_fuse requires at least one CUDA tensor for device validation") + major, _minor = torch.cuda.get_device_capability(device_tensor.device) + if major != 12: + raise RuntimeError("mxfp8_quant_fuse is only enabled on SM120/SM120a GPUs") + if cutlass_scaled_mxfp8_mm is None or cutlass_scaled_mxfp8_mm_residual_gate is None or scaled_mxfp8_gelu_quant is None or scaled_mxfp8_modulate_quant is None: + detail = f": {type(_WAN_MXFP8_FFN_IMPORT_ERROR).__name__}: {_WAN_MXFP8_FFN_IMPORT_ERROR}" if _WAN_MXFP8_FFN_IMPORT_ERROR is not None else "" + raise RuntimeError(f"mxfp8_quant_fuse requires lightx2v_kernel with MXFP8 fused quant ops{detail}") + for name in module_names: + module = getattr(phase, name) + if getattr(module, "has_lora_branch", False) or getattr(module, "has_diff", False): + raise RuntimeError(f"mxfp8_quant_fuse does not support active LoRA/diff on {name}") + if not all(hasattr(module, attr) for attr in required_module_attrs): + raise RuntimeError(f"mxfp8_quant_fuse expects {name} to be an MXFP8 quantized weight module") + + def _ensure_mxfp8_quant_ffn_ready(self, phase, norm2_out, residual, c_gate_msa=None, c_scale_msa=None, c_shift_msa=None): + if not self._use_mxfp8_quant_fuse(): + return + if (c_scale_msa is None) != (c_shift_msa is None): + raise RuntimeError("MXFP8 FFN modulate-quant readiness requires both c_scale_msa and c_shift_msa") + extra_tensors = [] + self._ensure_mxfp8_quant_fuse_ready( + phase, + norm2_out, + residual, + c_scale_msa, + c_shift_msa, + module_names=("ffn_0", "ffn_2"), + required_module_attrs=("act_quant_func", "weight", "weight_scale", "alpha"), + ) + if c_gate_msa is None: + raise RuntimeError("mxfp8_quant_fuse requires c_gate_msa for residual-gate fusion") + extra_tensors.append(c_gate_msa) + if extra_tensors: + self._ensure_mxfp8_quant_fuse_ready(phase, *extra_tensors) + + def _can_use_mxfp8_modulate_quant(self, norm2_out, c_scale_msa, c_shift_msa): + if scaled_mxfp8_modulate_quant is None: + return False + if not self._use_mxfp8_quant_fuse(): + return False + if self.sensitive_layer_dtype != self.infer_dtype: + return False + if norm2_out.dtype != torch.bfloat16 or c_scale_msa.dtype != torch.bfloat16 or c_shift_msa.dtype != torch.bfloat16: + return False + if not (norm2_out.is_cuda and c_scale_msa.is_cuda and c_shift_msa.is_cuda): + return False + if norm2_out.device != c_scale_msa.device or norm2_out.device != c_shift_msa.device: + return False + if norm2_out.dim() != 2 or not norm2_out.is_contiguous(): + return False + hidden = norm2_out.shape[1] + tokens = norm2_out.shape[0] + valid_numel = (hidden, tokens * hidden) + return c_scale_msa.numel() in valid_numel and c_shift_msa.numel() in valid_numel + + def _can_reuse_self_attn_mxfp8_quant(self, phase, norm1_out, scale_msa, shift_msa): + if cutlass_scaled_mxfp8_mm is None: + return False + if not self._can_use_mxfp8_modulate_quant(norm1_out, scale_msa, shift_msa): + return False + for name in ("self_attn_q", "self_attn_k", "self_attn_v"): + module = getattr(phase, name) + if getattr(module, "has_lora_branch", False) or getattr(module, "has_diff", False): + return False + if not all(hasattr(module, attr) for attr in ("weight", "weight_scale", "alpha")): + return False + return True + + def _mxfp8_quant_bias(self, module): + if hasattr(module, "_get_actual_bias"): + return module._get_actual_bias() + return module.bias if hasattr(module, "bias") else None + + def _mxfp8_apply(self, module, input_tensor): + input_tensor_quant, input_tensor_scale = module.act_quant_func(input_tensor) + return self._mxfp8_apply_quantized(module, input_tensor_quant, input_tensor_scale) + + def _mxfp8_apply_quantized(self, module, input_tensor_quant, input_tensor_scale): + if module.alpha.device != module.weight.device: + module.alpha = module.alpha.to(module.weight.device) + return cutlass_scaled_mxfp8_mm( + input_tensor_quant, + module.weight, + input_tensor_scale, + module.weight_scale, + alpha=module.alpha, + bias=self._mxfp8_quant_bias(module), + ) + + def _mxfp8_apply_residual_gate(self, module, input_tensor, residual, gate): + input_tensor_quant, input_tensor_scale = module.act_quant_func(input_tensor) + return self._mxfp8_apply_residual_gate_quantized(module, input_tensor_quant, input_tensor_scale, residual, gate) + + def _mxfp8_apply_residual_gate_quantized(self, module, input_tensor_quant, input_tensor_scale, residual, gate): + if module.alpha.device != module.weight.device: + module.alpha = module.alpha.to(module.weight.device) + return cutlass_scaled_mxfp8_mm_residual_gate( + input_tensor_quant, + module.weight, + input_tensor_scale, + module.weight_scale, + alpha=module.alpha, + residual=residual, + gate=gate, + bias=self._mxfp8_quant_bias(module), + ) + + def _infer_ffn_with_mxfp8_quant_fuse(self, phase, norm2_out, residual, c_gate_msa=None, c_scale_msa=None, c_shift_msa=None): + self._ensure_mxfp8_quant_ffn_ready(phase, norm2_out, residual, c_gate_msa, c_scale_msa, c_shift_msa) + if c_scale_msa is not None and c_shift_msa is not None and self._can_use_mxfp8_modulate_quant(norm2_out, c_scale_msa, c_shift_msa): + norm2_quant, norm2_scale = scaled_mxfp8_modulate_quant(norm2_out, c_scale_msa, c_shift_msa) + y = self._mxfp8_apply_quantized(phase.ffn_0, norm2_quant, norm2_scale) + else: + y = self._mxfp8_apply(phase.ffn_0, norm2_out) + y_quant, y_scale = scaled_mxfp8_gelu_quant(y) + self._mxfp8_apply_residual_gate_quantized(phase.ffn_2, y_quant, y_scale, residual, c_gate_msa.squeeze()) + return None + @torch.no_grad() def reset_post_adapter_states(self): pass @@ -149,7 +319,7 @@ def infer_block(self, block, x, pre_infer_out): y_out, gate_msa, ) - y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa) + y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa, c_gate_msa) x = self.post_process(x, y, c_gate_msa, pre_infer_out) if hasattr(block.compute_phases[2], "after_proj"): pre_infer_out.adapter_args["hints"].append(block.compute_phases[2].after_proj.apply(x)) @@ -175,6 +345,8 @@ def pre_process(self, modulation, embed0): def infer_self_attn(self, phase, x, shift_msa, scale_msa): cos_sin = self.cos_sin + norm1_quant = None + norm1_scale = None if hasattr(phase, "smooth_norm1_weight"): norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor @@ -186,15 +358,31 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): norm1_out = phase.norm1.apply(x) if self.sensitive_layer_dtype != self.infer_dtype: norm1_out = norm1_out.to(self.sensitive_layer_dtype) - norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze() + if self._use_mxfp8_quant_fuse(): + self._ensure_mxfp8_quant_fuse_ready( + phase, + norm1_out, + scale_msa, + shift_msa, + module_names=("self_attn_q", "self_attn_k", "self_attn_v"), + ) + if self._can_reuse_self_attn_mxfp8_quant(phase, norm1_out, scale_msa, shift_msa): + norm1_quant, norm1_scale = scaled_mxfp8_modulate_quant(norm1_out, scale_msa, shift_msa) + else: + norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze() if self.sensitive_layer_dtype != self.infer_dtype: norm1_out = norm1_out.to(self.infer_dtype) s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim - q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) - k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) - v = phase.self_attn_v.apply(norm1_out).view(s, n, d) + if norm1_quant is not None: + q = phase.self_attn_norm_q.apply(self._mxfp8_apply_quantized(phase.self_attn_q, norm1_quant, norm1_scale)).view(s, n, d) + k = phase.self_attn_norm_k.apply(self._mxfp8_apply_quantized(phase.self_attn_k, norm1_quant, norm1_scale)).view(s, n, d) + v = self._mxfp8_apply_quantized(phase.self_attn_v, norm1_quant, norm1_scale).view(s, n, d) + else: + q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) + k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) + v = phase.self_attn_v.apply(norm1_out).view(s, n, d) q, k = self.apply_rope_func(q, k, cos_sin) img_qkv_len = q.shape[0] if self.self_attn_cu_seqlens_qkv is None: @@ -202,6 +390,8 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): if self.clean_cuda_cache: del norm1_out, shift_msa, scale_msa + if norm1_quant is not None: + del norm1_quant, norm1_scale torch_device_module.empty_cache() attn_running_args = { @@ -310,13 +500,15 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa): torch_device_module.empty_cache() return x, attn_out - def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa): + def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa, c_gate_msa=None): x.add_(attn_out) if self.clean_cuda_cache: del attn_out torch_device_module.empty_cache() + mxfp8_modulate_scale = None + mxfp8_modulate_shift = None if hasattr(phase, "smooth_norm2_weight"): norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor @@ -328,11 +520,27 @@ def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa): norm2_out = phase.norm2.apply(x) if self.sensitive_layer_dtype != self.infer_dtype: norm2_out = norm2_out.to(self.sensitive_layer_dtype) - norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze() + if self._use_mxfp8_quant_fuse(): + self._ensure_mxfp8_quant_ffn_ready(phase, norm2_out, x, c_gate_msa, c_scale_msa, c_shift_msa) + if self._can_use_mxfp8_modulate_quant(norm2_out, c_scale_msa, c_shift_msa): + mxfp8_modulate_scale = c_scale_msa + mxfp8_modulate_shift = c_shift_msa + else: + norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze() if self.sensitive_layer_dtype != self.infer_dtype: norm2_out = norm2_out.to(self.infer_dtype) + if self._use_mxfp8_quant_fuse(): + return self._infer_ffn_with_mxfp8_quant_fuse( + phase, + norm2_out, + x, + c_gate_msa, + c_scale_msa=mxfp8_modulate_scale, + c_shift_msa=mxfp8_modulate_shift, + ) + y = phase.ffn_0.apply(norm2_out) if self.clean_cuda_cache: del norm2_out, x @@ -345,6 +553,8 @@ def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa): return y def post_process(self, x, y, c_gate_msa, pre_infer_out=None): + if y is None: + return x if self.sensitive_layer_dtype != self.infer_dtype: x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze() else: diff --git a/lightx2v_kernel/csrc/common_extension.cc b/lightx2v_kernel/csrc/common_extension.cc index 18b15679a..a198bb284 100644 --- a/lightx2v_kernel/csrc/common_extension.cc +++ b/lightx2v_kernel/csrc/common_extension.cc @@ -26,6 +26,16 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) { " Tensor! output_scale) -> ()"); m.impl("scaled_mxfp8_quant_sm120", torch::kCUDA, &scaled_mxfp8_quant_sm120); + m.def( + "scaled_mxfp8_gelu_quant_sm120(Tensor! output, Tensor! input," + " Tensor! output_scale) -> ()"); + m.impl("scaled_mxfp8_gelu_quant_sm120", torch::kCUDA, &scaled_mxfp8_gelu_quant_sm120); + + m.def( + "scaled_mxfp8_modulate_quant_sm120(Tensor! output, Tensor! input, Tensor scale, Tensor shift," + " Tensor! output_scale) -> ()"); + m.impl("scaled_mxfp8_modulate_quant_sm120", torch::kCUDA, &scaled_mxfp8_modulate_quant_sm120); + m.def( "scaled_mxfp6_quant_sm120(Tensor! output, Tensor! input," " Tensor! output_scale) -> ()"); @@ -46,6 +56,11 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) { "alpha, Tensor? bias) -> ()"); m.impl("cutlass_scaled_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_sm120); + m.def( + "cutlass_scaled_mxfp8_mm_residual_gate_sm120(Tensor! residual, Tensor mat_a, Tensor mat_b, Tensor scales_a, " + "Tensor scales_b, Tensor alpha, Tensor? bias, Tensor gate) -> ()"); + m.impl("cutlass_scaled_mxfp8_mm_residual_gate_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_residual_gate_sm120); + } REGISTER_EXTENSION(common_ops) diff --git a/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu index 2c8cd8106..fa360a4e2 100644 --- a/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu +++ b/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu @@ -6,6 +6,10 @@ #include #include +#include +#include + +#include "sm120_utils.h" #include "utils.h" // Get type2 from type or vice versa (applied to half and bfloat16) @@ -74,6 +78,22 @@ inline __device__ float reciprocal_approximate_ftz(float a) { return b; } +__device__ __forceinline__ float gelu_tanh_approx_mxfp8_quant(float x) { + constexpr float kSqrt2OverPi = 0.7978845608028654f; + constexpr float kCoeff = 0.044715f; + float x3 = x * x * x; + return 0.5f * x * (1.0f + tanhf(kSqrt2OverPi * (x + kCoeff * x3))); +} + +template +__device__ __forceinline__ float round_gelu_output_for_dtype(float x) { + if constexpr (std::is_same_v) { + return __half2float(__float2half(x)); + } else { + return __bfloat162float(__float2bfloat16(x)); + } +} + template __device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) { // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -191,6 +211,149 @@ __device__ uint64_t cvt_warp_fp16_to_fp8(PackedVec& vec, uint8_t* SFout) { return e4m3Vec; } +template // Type can be half or bfloat16 +__device__ uint64_t cvt_warp_gelu_fp16_to_fp8(PackedVec& vec, uint8_t* SFout) { + float2 fp2Vals[CVT_FP8_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x = round_gelu_output_for_dtype(gelu_tanh_approx_mxfp8_quant(fp2Vals[i].x)); + fp2Vals[i].y = round_gelu_output_for_dtype(gelu_tanh_approx_mxfp8_quant(fp2Vals[i].y)); + } + + float2 localMax2; + localMax2.x = fmaxf(fabsf(fp2Vals[0].x), fabsf(fp2Vals[0].y)); + localMax2.y = fmaxf(fabsf(fp2Vals[1].x), fabsf(fp2Vals[1].y)); + +#pragma unroll + for (int i = 2; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + localMax2.x = fmaxf(localMax2.x, fabsf(fp2Vals[i].x)); + localMax2.y = fmaxf(localMax2.y, fabsf(fp2Vals[i].y)); + } + + localMax2.x = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.x, 1), localMax2.x); + localMax2.y = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.y, 1), localMax2.y); + localMax2.x = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.x, 2), localMax2.x); + localMax2.y = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.y, 2), localMax2.y); + float vecMax = fmaxf(localMax2.x, localMax2.y); + + float SFValue = vecMax / 448.0f; + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + + if (SFout) { + *SFout = tmp.__x; + } + + float outputScale = SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f; +#pragma unroll + for (int i = 0; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + return fp32_vec_to_e4m3(fp2Vals); +} + +template +__device__ __forceinline__ float convert_modulate_input(float x) { + if constexpr (std::is_same_v) { + return __half2float(__float2half(x)); + } else { + return __bfloat162float(__float2bfloat16(x)); + } +} + +template +__device__ uint64_t cvt_warp_modulate_fp16_to_fp8( + PackedVec& vec, + Type const* scale, + Type const* shift, + int rowIdx, + int colIdx, + int numCols, + bool scale_is_2d, + bool shift_is_2d, + uint8_t* SFout) { + float2 fp2Vals[CVT_FP8_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + + const int col0 = colIdx * CVT_FP8_ELTS_PER_THREAD + i * 2; + const int col1 = col0 + 1; + const int64_t row_offset = static_cast(rowIdx) * numCols; + const int64_t scale_offset0 = (scale_is_2d ? row_offset : 0) + col0; + const int64_t scale_offset1 = (scale_is_2d ? row_offset : 0) + col1; + const int64_t shift_offset0 = (shift_is_2d ? row_offset : 0) + col0; + const int64_t shift_offset1 = (shift_is_2d ? row_offset : 0) + col1; + + float scale0; + float scale1; + float shift0; + float shift1; + if constexpr (std::is_same_v) { + scale0 = __half2float(scale[scale_offset0]); + scale1 = __half2float(scale[scale_offset1]); + shift0 = __half2float(shift[shift_offset0]); + shift1 = __half2float(shift[shift_offset1]); + } else { + scale0 = __bfloat162float(scale[scale_offset0]); + scale1 = __bfloat162float(scale[scale_offset1]); + shift0 = __bfloat162float(shift[shift_offset0]); + shift1 = __bfloat162float(shift[shift_offset1]); + } + + fp2Vals[i].x = convert_modulate_input(fp2Vals[i].x * (1.0f + scale0) + shift0); + fp2Vals[i].y = convert_modulate_input(fp2Vals[i].y * (1.0f + scale1) + shift1); + } + + float2 localMax2; + localMax2.x = fmaxf(fabsf(fp2Vals[0].x), fabsf(fp2Vals[0].y)); + localMax2.y = fmaxf(fabsf(fp2Vals[1].x), fabsf(fp2Vals[1].y)); + +#pragma unroll + for (int i = 2; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + localMax2.x = fmaxf(localMax2.x, fabsf(fp2Vals[i].x)); + localMax2.y = fmaxf(localMax2.y, fabsf(fp2Vals[i].y)); + } + + localMax2.x = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.x, 1), localMax2.x); + localMax2.y = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.y, 1), localMax2.y); + localMax2.x = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.x, 2), localMax2.x); + localMax2.y = fmaxf(__shfl_xor_sync(uint32_t(-1), localMax2.y, 2), localMax2.y); + float vecMax = fmaxf(localMax2.x, localMax2.y); + + float SFValue = vecMax / 448.0f; + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + + if (SFout) { + *SFout = tmp.__x; + } + + float outputScale = SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f; +#pragma unroll + for (int i = 0; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + return fp32_vec_to_e4m3(fp2Vals); +} + template // Type can be half or bfloat16 __global__ void @@ -224,6 +387,59 @@ __launch_bounds__(256, 6) cvt_fp16_to_fp8( // #endif } +template // Type can be half or bfloat16 +__global__ void __launch_bounds__(256, 6) cvt_gelu_fp16_to_fp8( + int32_t numRows, int32_t numCols, Type const* in, uint64_t* out, uint32_t* SFout) { + using PackedVec = PackedVec; + static constexpr int CVT_FP8_NUM_THREADS_PER_SF = (CVT_FP8_SF_VEC_SIZE / CVT_FP8_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP8_ELTS_PER_THREAD, "Vec size is not matched."); + + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP8_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP8_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + get_sf_out_address(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_gelu_fp16_to_fp8(in_vec, sf_out); + } + } +} + +template // Type can be half or bfloat16 +__global__ void __launch_bounds__(256, 6) cvt_modulate_fp16_to_fp8( + int32_t numRows, + int32_t numCols, + Type const* in, + Type const* scale, + Type const* shift, + bool scale_is_2d, + bool shift_is_2d, + uint64_t* out, + uint32_t* SFout) { + using PackedVec = PackedVec; + static constexpr int CVT_FP8_NUM_THREADS_PER_SF = (CVT_FP8_SF_VEC_SIZE / CVT_FP8_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP8_ELTS_PER_THREAD, "Vec size is not matched."); + + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP8_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP8_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + get_sf_out_address(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_modulate_fp16_to_fp8( + in_vec, scale, shift, rowIdx, colIdx, numCols, scale_is_2d, shift_is_2d, sf_out); + } + } +} + template void invokeFP8Quantization( int m, @@ -246,6 +462,54 @@ void invokeFP8Quantization( m, n, input, reinterpret_cast(output), reinterpret_cast(SFOuput)); } +template +void invokeGeluFP8Quantization( + int m, + int n, + T const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream) { + dim3 block(std::min(int(n / ELTS_PER_THREAD), 256)); + int const numBlocksPerSM = 1536 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + cvt_gelu_fp16_to_fp8 + <<>>( + m, n, input, reinterpret_cast(output), reinterpret_cast(SFOuput)); +} + +template +void invokeModulateFP8Quantization( + int m, + int n, + T const* input, + T const* scale, + T const* shift, + bool scale_is_2d, + bool shift_is_2d, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream) { + dim3 block(std::min(int(n / ELTS_PER_THREAD), 256)); + int const numBlocksPerSM = 1536 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + cvt_modulate_fp16_to_fp8 + <<>>( + m, + n, + input, + scale, + shift, + scale_is_2d, + shift_is_2d, + reinterpret_cast(output), + reinterpret_cast(SFOuput)); +} + // Instantiate the function. template void invokeFP8Quantization( int m, @@ -256,7 +520,16 @@ template void invokeFP8Quantization( int multiProcessorCount, cudaStream_t stream); -template void invokeFP8Quantization( +template void invokeGeluFP8Quantization( + int m, + int n, + half const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeGeluFP8Quantization( int m, int n, __nv_bfloat16 const* input, @@ -265,35 +538,127 @@ template void invokeFP8Quantization( int multiProcessorCount, cudaStream_t stream); -inline int getMultiProcessorCount() { - static int multi_processor_count = []() { - int device_id = 0; - int count = 0; +template void invokeModulateFP8Quantization( + int m, + int n, + half const* input, + half const* scale, + half const* shift, + bool scale_is_2d, + bool shift_is_2d, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeModulateFP8Quantization( + int m, + int n, + __nv_bfloat16 const* input, + __nv_bfloat16 const* scale, + __nv_bfloat16 const* shift, + bool scale_is_2d, + bool shift_is_2d, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); - // Get the current CUDA device ID - CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); +template void invokeFP8Quantization( + int m, + int n, + __nv_bfloat16 const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); - // Get the number of multiprocessors for the current device - CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id)); +namespace { - return count; // Initialize the static variable - }(); +inline int64_t round_up_mxfp8_m_tile(int64_t m) { + return ((m + 127) / 128) * 128; +} - return multi_processor_count; // Return the cached value on subsequent calls +void check_mxfp8_quant_io( + torch::Tensor const& output, + torch::Tensor const& input, + torch::Tensor const& output_sf, + char const* op_name) { + TORCH_CHECK(input.dim() == 2, op_name, " expects a 2D input tensor."); + TORCH_CHECK(input.is_cuda(), op_name, " input must be a CUDA tensor."); + TORCH_CHECK(input.is_contiguous(), op_name, " input must be contiguous."); + TORCH_CHECK( + input.scalar_type() == torch::kHalf || input.scalar_type() == torch::kBFloat16, + op_name, + " input dtype must be float16 or bfloat16, got ", + input.scalar_type()); + + TORCH_CHECK(output.dim() == 2, op_name, " output must be a 2D tensor."); + TORCH_CHECK(output.is_cuda(), op_name, " output must be a CUDA tensor."); + TORCH_CHECK(output.is_contiguous(), op_name, " output must be contiguous."); + TORCH_CHECK(output.scalar_type() == torch::kUInt8, op_name, " output dtype must be uint8, got ", output.scalar_type()); + TORCH_CHECK(output.get_device() == input.get_device(), op_name, " output must be on the same CUDA device as input."); + TORCH_CHECK( + output.size(0) == input.size(0) && output.size(1) == input.size(1), + op_name, + " output shape must match input shape, got output=(", + output.size(0), + ", ", + output.size(1), + ") input=(", + input.size(0), + ", ", + input.size(1), + ")"); + + int64_t m = input.size(0); + int64_t n = input.size(1); + TORCH_CHECK(n % 32 == 0, op_name, " N dimension must be multiple of 32."); + + int64_t expected_sf_m = round_up_mxfp8_m_tile(m); + int64_t expected_sf_n = (n / 32 + 3) / 4; + TORCH_CHECK(output_sf.dim() == 2, op_name, " output_sf must be a 2D tensor."); + TORCH_CHECK(output_sf.is_cuda(), op_name, " output_sf must be a CUDA tensor."); + TORCH_CHECK(output_sf.is_contiguous(), op_name, " output_sf must be contiguous."); + TORCH_CHECK( + output_sf.scalar_type() == torch::kInt32, + op_name, + " output_sf dtype must be int32 storage, got ", + output_sf.scalar_type()); + TORCH_CHECK( + output_sf.get_device() == input.get_device(), + op_name, + " output_sf must be on the same CUDA device as input."); + TORCH_CHECK( + output_sf.size(0) == expected_sf_m && output_sf.size(1) == expected_sf_n, + op_name, + " output_sf shape must be (", + expected_sf_m, + ", ", + expected_sf_n, + "), got (", + output_sf.size(0), + ", ", + output_sf.size(1), + ")"); } +} // namespace + void scaled_mxfp8_quant_sm120( torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) { + char const* op_name = "scaled_mxfp8_quant_sm120"; + check_mxfp8_quant_io(output, input, output_sf, op_name); + c10::cuda::CUDAGuard device_guard(input.device()); + lightx2v_kernel::check_sm120_or_throw(input, op_name); + int32_t m = input.size(0); int32_t n = input.size(1); - TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32."); - - int multiProcessorCount = getMultiProcessorCount(); + int multiProcessorCount = lightx2v_kernel::getMultiProcessorCount(input.get_device()); auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); - at::cuda::CUDAGuard device_guard{(char)input.get_device()}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); switch (input.scalar_type()) { @@ -308,8 +673,113 @@ void scaled_mxfp8_quant_sm120( break; } default: { - std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid"; - throw std::runtime_error("Unsupported input data type for quantize_to_fp8."); + TORCH_CHECK(false, "Unsupported input data type for quantize_to_fp8: ", input.scalar_type()); + } + } +} + +void scaled_mxfp8_gelu_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) { + char const* op_name = "scaled_mxfp8_gelu_quant_sm120"; + check_mxfp8_quant_io(output, input, output_sf, op_name); + c10::cuda::CUDAGuard device_guard(input.device()); + lightx2v_kernel::check_sm120_or_throw(input, op_name); + + int32_t m = input.size(0); + int32_t n = input.size(1); + + int multiProcessorCount = lightx2v_kernel::getMultiProcessorCount(input.get_device()); + + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeGeluFP8Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeGeluFP8Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + default: { + TORCH_CHECK(false, "Unsupported input data type for gelu_quantize_to_fp8: ", input.scalar_type()); + } + } +} + +namespace { + +bool is_mxfp8_modulate_param_2d(torch::Tensor const& tensor, int32_t m, int32_t n, char const* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor."); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous."); + if (tensor.dim() == 1) { + TORCH_CHECK(tensor.size(0) == n, name, " must have shape (N,) for per-column MXFP8 modulate quant."); + return false; + } + if (tensor.dim() == 2) { + TORCH_CHECK( + tensor.size(0) == m && tensor.size(1) == n, + name, + " must have shape (M, N) for per-token MXFP8 modulate quant."); + return true; + } + TORCH_CHECK(false, name, " must be either 1D (N,) or 2D (M, N) for MXFP8 modulate quant."); + return false; +} + +} // namespace + +void scaled_mxfp8_modulate_quant_sm120( + torch::Tensor& output, + torch::Tensor const& input, + torch::Tensor const& scale, + torch::Tensor const& shift, + torch::Tensor& output_sf) { + char const* op_name = "scaled_mxfp8_modulate_quant_sm120"; + check_mxfp8_quant_io(output, input, output_sf, op_name); + c10::cuda::CUDAGuard device_guard(input.device()); + lightx2v_kernel::check_sm120_or_throw(input, op_name); + + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(scale.scalar_type() == input.scalar_type(), "scale dtype must match input dtype."); + TORCH_CHECK(shift.scalar_type() == input.scalar_type(), "shift dtype must match input dtype."); + TORCH_CHECK(scale.get_device() == input.get_device(), "scale must be on the same CUDA device as input."); + TORCH_CHECK(shift.get_device() == input.get_device(), "shift must be on the same CUDA device as input."); + + bool scale_is_2d = is_mxfp8_modulate_param_2d(scale, m, n, "scale"); + bool shift_is_2d = is_mxfp8_modulate_param_2d(shift, m, n, "shift"); + + int multiProcessorCount = lightx2v_kernel::getMultiProcessorCount(input.get_device()); + + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + auto scale_ptr = reinterpret_cast(scale.data_ptr()); + auto shift_ptr = reinterpret_cast(shift.data_ptr()); + invokeModulateFP8Quantization( + m, n, input_ptr, scale_ptr, shift_ptr, scale_is_2d, shift_is_2d, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + auto scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(scale.data_ptr()); + auto shift_ptr = reinterpret_cast<__nv_bfloat16 const*>(shift.data_ptr()); + invokeModulateFP8Quantization( + m, n, input_ptr, scale_ptr, shift_ptr, scale_is_2d, shift_is_2d, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + default: { + TORCH_CHECK(false, "Unsupported input data type for modulate_quantize_to_fp8: ", input.scalar_type()); } } } diff --git a/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu index 8414295d0..fb95fd53c 100644 --- a/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu +++ b/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu @@ -1,10 +1,16 @@ #include +#include #include +#include #include +#include +#include + // clang-format off #include "cutlass/cutlass.h" #include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -12,6 +18,8 @@ #include "cutlass/util/packed_stride.hpp" // clang-format on +#include "sm120_utils.h" + #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ @@ -29,6 +37,125 @@ using namespace cute; +namespace cutlass::epilogue::fusion { + +template< + class ElementOutput_, + class ElementCompute_, + class ElementGate_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentGate_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasPerColGateResidual : LinearCombination { + using ElementGate = ElementGate_; + using ElementBias = ElementBias_; + static constexpr int AlignmentGate = AlignmentGate_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; +}; + +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementGate = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentGate = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasPerColGateResidual = + // Wan 1D gate fast path: apply gate/residual in the CUTLASS epilogue using + // accumulator values, then round once to ElementOutput. This is intentionally + // not bitwise identical to materializing BF16 GEMM output before gate. + Sm90EVT, // gate * (alpha * acc + bias) + C + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementGate, ElementCompute, Stride<_0,_1,int64_t>, AlignmentGate>, // gate + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + >, + Sm90SrcFetch // C / residual + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementGate, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentGate, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasPerColGateResidual< + ElementOutput, ElementCompute, ElementGate, ElementBias, ElementSource, ElementScalar, AlignmentGate, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBiasPerColGateResidual< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementGate, ElementBias, ElementSource, ElementScalar, AlignmentGate, AlignmentBias, RoundStyle + > { + + using Impl = Sm90LinCombPerColBiasPerColGateResidual< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementGate, ElementBias, ElementSource, ElementScalar, AlignmentGate, AlignmentBias, RoundStyle + >; + using Operation = fusion::LinCombPerColBiasPerColGateResidual< + ElementOutput, ElementCompute, ElementGate, ElementBias, ElementSource, ElementScalar, AlignmentGate, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + + using StrideGate = Stride<_0,_1,int64_t>; + ElementGate const* gate_ptr = nullptr; + StrideGate dGate = {}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : gate * (alpha * acc + bias) + C + {gate_ptr, ElementGate(0), dGate}, // leaf args : gate + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, + {}, // leaf args : C + {} // ternary args : multiply_add + }; + } + }; + + using Impl::Impl; +}; + +} // namespace cutlass::epilogue::fusion + struct Mxfp8GemmSm120 { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -106,6 +233,72 @@ struct Mxfp8GemmSm120 { using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); }; +struct Mxfp8GemmResidualGateSm120 { + using ElementA = cutlass::mx_float8_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 16; + + using ElementB = cutlass::mx_float8_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128; + + using ElementD = cutlass::bfloat16_t; + using ElementC = cutlass::bfloat16_t; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using ThreadBlockShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + + using EVTOp = cutlass::epilogue::fusion::LinCombPerColBiasPerColGateResidual< + ElementD, ElementAccumulator, ElementD, ElementD, ElementC, ElementAccumulator>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + EVTOp + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +}; + // Populates a Gemm::Arguments structure from the given commandline options typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8( @@ -154,8 +347,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8( stride_D}}; auto& fusion_args = arguments.epilogue.thread; fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - // static const float beta_zero = 0.0f; - // fusion_args.beta_ptr = &beta_zero; + fusion_args.beta = 0.0f; + fusion_args.beta_ptr = nullptr; fusion_args.bias_ptr = static_cast(bias->data_ptr()); fusion_args.dBias = StrideBias{}; return arguments; @@ -180,12 +373,64 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8( stride_D}}; auto& fusion_args = arguments.epilogue.thread; fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - // static const float beta_zero = 0.0f; - // fusion_args.beta_ptr = &beta_zero; + fusion_args.beta = 0.0f; + fusion_args.beta_ptr = nullptr; return arguments; } } +typename Mxfp8GemmResidualGateSm120::Gemm::Arguments args_from_options_mxfp8_residual_gate( + at::Tensor& residual, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + at::Tensor const& gate, + int64_t M, + int64_t N, + int64_t K) { + using Sm1xxBlkScaledConfig = typename Mxfp8GemmResidualGateSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(Mxfp8GemmResidualGateSm120::StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(Mxfp8GemmResidualGateSm120::StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(Mxfp8GemmResidualGateSm120::StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + typename Mxfp8GemmResidualGateSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(residual.data_ptr()), + stride_D, + static_cast(residual.data_ptr()), + stride_D}}; + + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + fusion_args.gate_ptr = static_cast(gate.data_ptr()); + if (bias) { + fusion_args.bias_ptr = static_cast(bias->data_ptr()); + } + return arguments; +} + void runGemmMxfp8Sm120( at::Tensor& D, @@ -202,11 +447,6 @@ void runGemmMxfp8Sm120( typename Mxfp8GemmSm120::Gemm gemm; auto arguments = args_from_options_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k); - auto beta_dev = torch::zeros({1}, torch::TensorOptions() - .dtype(torch::kFloat32) - .device(A.device())); - arguments.epilogue.thread.beta_ptr = - static_cast(beta_dev.data_ptr()); size_t workspace_size = Mxfp8GemmSm120::Gemm::get_workspace_size(arguments); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); auto workspace = torch::empty(workspace_size, workspace_options); @@ -216,33 +456,75 @@ void runGemmMxfp8Sm120( CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); } +void runGemmMxfp8ResidualGateSm120( + at::Tensor& residual, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + at::Tensor const& gate, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename Mxfp8GemmResidualGateSm120::Gemm gemm; + + auto arguments = args_from_options_mxfp8_residual_gate(residual, A, B, A_sf, B_sf, alpha, bias, gate, m, n, k); + size_t workspace_size = Mxfp8GemmResidualGateSm120::Gemm::get_workspace_size(arguments); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); +} + constexpr auto FP6_FP8_TYPE = at::ScalarType::Byte; constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu; -void cutlass_scaled_mxfp8_mm_sm120( - torch::Tensor& D, +namespace lightx2v_mxfp8_fused { + +constexpr int kFusedThreads = 256; + +struct Mxfp8GemmMeta { + int64_t m; + int64_t n; + int64_t k; +}; + +int64_t round_up(int64_t x, int64_t y) { + return (x + y - 1) / y * y; +} + +Mxfp8GemmMeta check_mxfp8_gemm_inputs( + torch::Tensor const& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, torch::Tensor const& B_sf, torch::Tensor const& alpha, - c10::optional const& bias) { - - CHECK_INPUT(A, FP6_FP8_TYPE, "a"); - CHECK_INPUT(B, FP6_FP8_TYPE, "b"); - + c10::optional const& bias, + char const* op_name) { + CHECK_INPUT(D, at::ScalarType::BFloat16, "out"); + CHECK_INPUT(A, FP6_FP8_TYPE, "mat_a"); + CHECK_INPUT(B, FP6_FP8_TYPE, "mat_b"); CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); - - - TORCH_CHECK(A.dim() == 2, "a must be a matrix"); - TORCH_CHECK(B.dim() == 2, "b must be a matrix"); - + TORCH_CHECK( + D.device() == A.device() && D.device() == B.device() && D.device() == A_sf.device() && + D.device() == B_sf.device() && D.device() == alpha.device(), + op_name, + " expects output, mat_a, mat_b, scale_a, scale_b, and alpha on the same CUDA device"); + TORCH_CHECK(D.dim() == 2, "out must be a matrix"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2, "mat_a and mat_b must be matrices"); + TORCH_CHECK(alpha.numel() == 1, "alpha must contain exactly one scalar"); TORCH_CHECK( A.sizes()[1] == B.sizes()[1], - "a and b shapes cannot be multiplied (", + "mat_a and mat_b shapes cannot be multiplied (", A.sizes()[0], "x", A.sizes()[1], @@ -256,13 +538,15 @@ void cutlass_scaled_mxfp8_mm_sm120( auto const n = B.sizes()[0]; auto const k = A.sizes()[1]; + TORCH_CHECK(D.sizes()[0] == m, "out rows must match mat_a rows"); + TORCH_CHECK(D.sizes()[1] == n, "out cols must match mat_b rows"); constexpr int alignment_a = 16; constexpr int alignment_b = 128; TORCH_CHECK( k % alignment_a == 0, "Expected k to be divisible by ", alignment_a, - ", but got a shape: (", + ", but got mat_a shape: (", A.sizes()[0], "x", A.sizes()[1], @@ -273,18 +557,15 @@ void cutlass_scaled_mxfp8_mm_sm120( n % alignment_b == 0, "Expected n to be divisible by ", alignment_b, - ", but got b shape: (", + ", but got mat_b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); - auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; - int rounded_m = round_up(m, 128); - int rounded_n = round_up(n, 128); - // Since k is divisible by 32 (alignment), k / 32 is guaranteed to be an - // integer. - int rounded_k = round_up(k / 32, 4); + int64_t rounded_m = round_up(m, 128); + int64_t rounded_n = round_up(n, 128); + int64_t rounded_k = round_up(k / 32, 4); TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); @@ -321,10 +602,106 @@ void cutlass_scaled_mxfp8_mm_sm120( "x", B_sf.sizes()[1], ")"); + if (bias) { + auto const& bias_tensor = bias.value(); + CHECK_INPUT(bias_tensor, at::ScalarType::BFloat16, "bias"); + TORCH_CHECK(bias_tensor.device() == A.device(), "bias must be on the same CUDA device"); + TORCH_CHECK(bias_tensor.numel() == n, "bias numel must match output columns"); + } + lightx2v_kernel::check_sm120_or_throw(A, op_name); + return {m, n, k}; +} - auto out_dtype = D.dtype(); +void check_mxfp8_residual_gate( + torch::Tensor const& residual, + torch::Tensor const& gate, + char const* op_name) { + CHECK_INPUT(residual, at::ScalarType::BFloat16, "residual"); + CHECK_INPUT(gate, at::ScalarType::BFloat16, "gate"); + TORCH_CHECK(residual.dim() == 2, "residual must be a matrix"); + TORCH_CHECK(gate.device() == residual.device(), op_name, " expects residual and gate on the same CUDA device"); + TORCH_CHECK(gate.dim() == 1 || gate.dim() == 2, "gate must be 1D or 2D"); + if (gate.dim() == 1) { + TORCH_CHECK(gate.sizes()[0] == residual.sizes()[1], "1D gate size must match residual columns"); + } else { + TORCH_CHECK(gate.sizes() == residual.sizes(), "2D gate shape must match residual shape"); + } +} + +__global__ void mxfp8_residual_gate_bf16_kernel( + __nv_bfloat16* residual, + __nv_bfloat16 const* ffn_out, + __nv_bfloat16 const* gate, + int64_t total, + int64_t cols, + bool gate_per_element) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t stride = static_cast(blockDim.x) * gridDim.x; + for (; idx < total; idx += stride) { + int64_t gate_idx = gate_per_element ? idx : idx % cols; + float product = __bfloat162float(ffn_out[idx]) * __bfloat162float(gate[gate_idx]); + float sum = __bfloat162float(residual[idx]) + product; + residual[idx] = __float2bfloat16(sum); + } +} + +void launch_mxfp8_residual_gate(torch::Tensor& residual, torch::Tensor const& ffn_out, torch::Tensor const& gate, cudaStream_t stream) { + int64_t total = residual.numel(); + int64_t cols = residual.sizes()[1]; + bool gate_per_element = gate.dim() == 2; + int blocks = static_cast((total + kFusedThreads - 1) / kFusedThreads); + mxfp8_residual_gate_bf16_kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(residual.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(ffn_out.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(gate.data_ptr()), + total, + cols, + gate_per_element); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace lightx2v_mxfp8_fused + +void cutlass_scaled_mxfp8_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias) { + + auto const meta = lightx2v_mxfp8_fused::check_mxfp8_gemm_inputs( + D, A, B, A_sf, B_sf, alpha, bias, "cutlass_scaled_mxfp8_mm_sm120"); at::cuda::CUDAGuard device_guard{(char)A.get_device()}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); - runGemmMxfp8Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream); + runGemmMxfp8Sm120(D, A, B, A_sf, B_sf, alpha, bias, meta.m, meta.n, meta.k, stream); +} + +void cutlass_scaled_mxfp8_mm_residual_gate_sm120( + torch::Tensor& residual, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias, + torch::Tensor const& gate) { + auto const meta = lightx2v_mxfp8_fused::check_mxfp8_gemm_inputs( + residual, A, B, A_sf, B_sf, alpha, bias, "cutlass_scaled_mxfp8_mm_residual_gate_sm120"); + lightx2v_mxfp8_fused::check_mxfp8_residual_gate( + residual, gate, "cutlass_scaled_mxfp8_mm_residual_gate_sm120"); + if (gate.dim() == 1) { + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + runGemmMxfp8ResidualGateSm120( + residual, A, B, A_sf, B_sf, alpha, bias, gate, meta.m, meta.n, meta.k, stream); + return; + } + auto ffn_out = torch::empty_like(residual); + cutlass_scaled_mxfp8_mm_sm120(ffn_out, A, B, A_sf, B_sf, alpha, bias); + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + lightx2v_mxfp8_fused::launch_mxfp8_residual_gate(residual, ffn_out, gate, stream); } diff --git a/lightx2v_kernel/csrc/gemm/sm120_utils.h b/lightx2v_kernel/csrc/gemm/sm120_utils.h new file mode 100644 index 000000000..497988f7d --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/sm120_utils.h @@ -0,0 +1,80 @@ +/* Copyright 2025 LightX2V Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace lightx2v_kernel { + +constexpr int kMaxCudaDevices = 64; + +inline void check_valid_cuda_device_index(int device, char const* op_name) { + TORCH_CHECK( + device >= 0 && device < kMaxCudaDevices, + op_name, + " requires CUDA device index in [0, ", + kMaxCudaDevices, + "), got ", + device); +} + +inline std::pair get_cached_device_capability(int device) { + static std::array device_once; + static std::array cached_major{}; + static std::array cached_minor{}; + std::call_once(device_once[device], [device]() { + cudaDeviceProp prop; + C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + cached_major[device] = prop.major; + cached_minor[device] = prop.minor; + }); + return {cached_major[device], cached_minor[device]}; +} + +inline void check_sm120_or_throw(torch::Tensor const& tensor, char const* op_name) { + int device = tensor.get_device(); + check_valid_cuda_device_index(device, op_name); + auto [major, minor] = get_cached_device_capability(device); + TORCH_CHECK( + major == 12, + op_name, + " is only supported on SM120/SM120a GPUs, got CUDA device ", + device, + " with compute capability ", + major, + ".", + minor); +} + +inline int getMultiProcessorCount(int device) { + check_valid_cuda_device_index(device, "getMultiProcessorCount"); + static std::array device_once; + static std::array cached_mp_count{}; + std::call_once(device_once[device], [device]() { + int mp_count = 0; + C10_CUDA_CHECK(cudaDeviceGetAttribute(&mp_count, cudaDevAttrMultiProcessorCount, device)); + cached_mp_count[device] = mp_count; + }); + return cached_mp_count[device]; +} + +} // namespace lightx2v_kernel diff --git a/lightx2v_kernel/include/lightx2v_kernel_ops.h b/lightx2v_kernel/include/lightx2v_kernel_ops.h index b937971a8..c8295e201 100644 --- a/lightx2v_kernel/include/lightx2v_kernel_ops.h +++ b/lightx2v_kernel/include/lightx2v_kernel_ops.h @@ -54,6 +54,16 @@ void scaled_mxfp6_quant_sm120( void scaled_mxfp8_quant_sm120( torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf); +void scaled_mxfp8_gelu_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf); + +void scaled_mxfp8_modulate_quant_sm120( + torch::Tensor& output, + torch::Tensor const& input, + torch::Tensor const& scale, + torch::Tensor const& shift, + torch::Tensor& output_sf); + void cutlass_scaled_nvfp4_mm_sm120( torch::Tensor& D, torch::Tensor const& A, @@ -90,3 +100,13 @@ void cutlass_scaled_mxfp8_mm_sm120( torch::Tensor const& B_sf, torch::Tensor const& alpha, c10::optional const& bias); + +void cutlass_scaled_mxfp8_mm_residual_gate_sm120( + torch::Tensor& residual, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias, + torch::Tensor const& gate); diff --git a/lightx2v_kernel/python/lightx2v_kernel/gemm.py b/lightx2v_kernel/python/lightx2v_kernel/gemm.py index d79ca213f..e7e765e6d 100644 --- a/lightx2v_kernel/python/lightx2v_kernel/gemm.py +++ b/lightx2v_kernel/python/lightx2v_kernel/gemm.py @@ -94,6 +94,43 @@ def scaled_mxfp8_quant(input: torch.Tensor): return output, output_scale +def scaled_mxfp8_gelu_quant(input: torch.Tensor): + m, n = input.shape + block_size = 32 + device = input.device + + output = torch.empty((m, n), device=device, dtype=torch.uint8) + output_scale = torch.empty(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) + + torch.ops.lightx2v_kernel.scaled_mxfp8_gelu_quant_sm120.default(output, input, output_scale) + output_scale = output_scale.view(torch.float8_e8m0fnu) + return output, output_scale + + +def _mxfp8_modulate_param(param: torch.Tensor, m: int, n: int, name: str): + param = param.squeeze() + if param.numel() == n: + return param.reshape(n).contiguous() + if param.numel() == m * n: + return param.reshape(m, n).contiguous() + raise ValueError(f"{name} must have numel {n} or {m * n}, got shape={tuple(param.shape)}") + + +def scaled_mxfp8_modulate_quant(input: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor): + m, n = input.shape + block_size = 32 + device = input.device + + output = torch.empty((m, n), device=device, dtype=torch.uint8) + output_scale = torch.empty(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) + scale = _mxfp8_modulate_param(scale, m, n, "scale") + shift = _mxfp8_modulate_param(shift, m, n, "shift") + + torch.ops.lightx2v_kernel.scaled_mxfp8_modulate_quant_sm120.default(output, input, scale, shift, output_scale) + output_scale = output_scale.view(torch.float8_e8m0fnu) + return output, output_scale + + def cutlass_scaled_mxfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None): m, n = mat_a.shape[0], mat_b.shape[0] out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device) @@ -113,3 +150,16 @@ def cutlass_scaled_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None): out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device) torch.ops.lightx2v_kernel.cutlass_scaled_mxfp8_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias) return out + + +def cutlass_scaled_mxfp8_mm_residual_gate(mat_a, mat_b, scales_a, scales_b, alpha, residual, gate, bias=None): + """Fused residual update for Wan FFN. + + A 1D gate uses the CUTLASS epilogue contract and applies gate/residual to + the GEMM accumulator before the final BF16 store. A 2D gate is a fallback + compatibility path and may differ by BF16 rounding at the GEMM boundary. + """ + torch.ops.lightx2v_kernel.cutlass_scaled_mxfp8_mm_residual_gate_sm120.default( + residual, mat_a, mat_b, scales_a, scales_b, alpha, bias, gate.contiguous() + ) + return residual diff --git a/lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py b/lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py new file mode 100644 index 000000000..e73cbfb7e --- /dev/null +++ b/lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py @@ -0,0 +1,402 @@ +import unittest + +import torch +import torch.nn.functional as F + +try: + from lightx2v_kernel.gemm import ( + cutlass_scaled_mxfp8_mm, + cutlass_scaled_mxfp8_mm_residual_gate, + scaled_mxfp8_gelu_quant, + scaled_mxfp8_modulate_quant, + scaled_mxfp8_quant, + ) + _IMPORT_ERROR = None +except Exception as exc: # noqa: BLE001 - test reports extension availability. + cutlass_scaled_mxfp8_mm = None + cutlass_scaled_mxfp8_mm_residual_gate = None + scaled_mxfp8_gelu_quant = None + scaled_mxfp8_modulate_quant = None + scaled_mxfp8_quant = None + _IMPORT_ERROR = exc + + +def _first_visible_sm120_device(): + if not torch.cuda.is_available(): + return None + for device_index in range(torch.cuda.device_count()): + major, _minor = torch.cuda.get_device_capability(device_index) + if major == 12: + return torch.device("cuda", device_index) + return None + + +def _skip_cuda_unavailable(): + if not torch.cuda.is_available(): + return "CUDA is not available" + if _first_visible_sm120_device() is None: + caps = [torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())] + return f"MXFP8 fused FFN kernels require a visible SM120/SM120a CUDA device, got {caps}" + return None + + +class TestMxfp8FusedFfn(unittest.TestCase): + def setUp(self): + skip_reason = _skip_cuda_unavailable() + if skip_reason is not None: + self.skipTest(skip_reason) + if _IMPORT_ERROR is not None: + raise RuntimeError(f"Failed to import MXFP8 fused FFN symbols: {_IMPORT_ERROR}") from _IMPORT_ERROR + self.device = _first_visible_sm120_device() + torch.cuda.set_device(self.device) + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + + def _quantized_inputs(self, m=257, k=256, n=384): + activation = torch.randn(m, k, dtype=torch.bfloat16, device="cuda") * 0.5 + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") * 0.5 + bias = torch.randn(n, dtype=torch.bfloat16, device="cuda") * 0.1 + activation_quant, activation_scale = scaled_mxfp8_quant(activation) + weight_quant, weight_scale = scaled_mxfp8_quant(weight) + alpha = torch.tensor(1.0, dtype=torch.float32, device="cuda") + return activation_quant, weight_quant, activation_scale, weight_scale, alpha, bias + + def assert_close_enough(self, actual, expected): + actual_f = actual.float().flatten() + expected_f = expected.float().flatten() + cosine = F.cosine_similarity(actual_f, expected_f, dim=0).item() + max_abs = (actual_f - expected_f).abs().max().item() + mean_abs = (actual_f - expected_f).abs().mean().item() + self.assertGreater(cosine, 0.999, f"cosine={cosine}, max_abs={max_abs}, mean_abs={mean_abs}") + self.assertLess(max_abs, 0.08, f"cosine={cosine}, max_abs={max_abs}, mean_abs={mean_abs}") + + def test_mxfp8_gelu_quant_matches_baseline(self): + activation = torch.randn(257, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + baseline_quant, baseline_scale = scaled_mxfp8_quant(F.gelu(activation, approximate="tanh")) + fused_quant, fused_scale = scaled_mxfp8_gelu_quant(activation) + torch.cuda.synchronize() + self.assertTrue(torch.equal(fused_quant, baseline_quant)) + + # Scale factors use the CUTLASS tiled layout, so row-major slicing is not + # meaningful for partial M tiles. Use a full tile for raw scale equality. + activation = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + baseline_quant, baseline_scale = scaled_mxfp8_quant(F.gelu(activation, approximate="tanh")) + fused_quant, fused_scale = scaled_mxfp8_gelu_quant(activation) + torch.cuda.synchronize() + self.assertTrue(torch.equal(fused_quant, baseline_quant)) + self.assertTrue(torch.equal(fused_scale.view(torch.uint8), baseline_scale.view(torch.uint8))) + + def test_mxfp8_quant_ops_validate_explicit_outputs(self): + activation = torch.randn(257, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + m, n = activation.shape + output = torch.empty((m, n), dtype=torch.uint8, device="cuda") + output_sf = torch.empty(((m + 127) // 128 * 128, (n // 32 + 3) // 4), dtype=torch.int32, device="cuda") + scale = torch.randn(n, dtype=torch.bfloat16, device="cuda") + shift = torch.randn(n, dtype=torch.bfloat16, device="cuda") + + with self.assertRaisesRegex(RuntimeError, "output dtype must be uint8"): + torch.ops.lightx2v_kernel.scaled_mxfp8_gelu_quant_sm120.default( + torch.empty_like(activation), + activation, + output_sf, + ) + with self.assertRaisesRegex(RuntimeError, "output_sf dtype must be int32"): + torch.ops.lightx2v_kernel.scaled_mxfp8_modulate_quant_sm120.default( + output, + activation, + scale, + shift, + torch.empty_like(output_sf, dtype=torch.float32), + ) + with self.assertRaisesRegex(RuntimeError, "output_sf shape must be"): + torch.ops.lightx2v_kernel.scaled_mxfp8_quant_sm120.default( + output, + activation, + torch.empty((128, output_sf.shape[1]), dtype=torch.int32, device="cuda"), + ) + with self.assertRaisesRegex(RuntimeError, "scale must have shape"): + torch.ops.lightx2v_kernel.scaled_mxfp8_modulate_quant_sm120.default( + output, + activation, + scale.reshape(n, 1), + shift, + output_sf, + ) + with self.assertRaisesRegex(RuntimeError, "shift must have shape"): + torch.ops.lightx2v_kernel.scaled_mxfp8_modulate_quant_sm120.default( + output, + activation, + scale, + shift.reshape(n, 1), + output_sf, + ) + + def _modulate_reference(self, activation, scale, shift): + return (activation.float() * (1.0 + scale.float()) + shift.float()).to(activation.dtype) + + def test_mxfp8_modulate_quant_matches_1d_baseline(self): + activation = torch.randn(257, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + scale = torch.randn(512, dtype=torch.bfloat16, device="cuda") * 0.1 + shift = torch.randn(512, dtype=torch.bfloat16, device="cuda") * 0.1 + baseline_quant, baseline_scale = scaled_mxfp8_quant(self._modulate_reference(activation, scale, shift)) + fused_quant, fused_scale = scaled_mxfp8_modulate_quant(activation, scale, shift) + torch.cuda.synchronize() + self.assertTrue(torch.equal(fused_quant, baseline_quant)) + + activation = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + baseline_quant, baseline_scale = scaled_mxfp8_quant(self._modulate_reference(activation, scale, shift)) + fused_quant, fused_scale = scaled_mxfp8_modulate_quant(activation, scale.view(1, 1, -1), shift.view(1, 1, -1)) + torch.cuda.synchronize() + self.assertTrue(torch.equal(fused_quant, baseline_quant)) + self.assertTrue(torch.equal(fused_scale.view(torch.uint8), baseline_scale.view(torch.uint8))) + + def test_mxfp8_modulate_quant_matches_2d_baseline(self): + activation = torch.randn(257, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + scale = torch.randn(257, 512, dtype=torch.bfloat16, device="cuda") * 0.1 + shift = torch.randn(257, 512, dtype=torch.bfloat16, device="cuda") * 0.1 + baseline_quant, baseline_scale = scaled_mxfp8_quant(self._modulate_reference(activation, scale, shift)) + fused_quant, fused_scale = scaled_mxfp8_modulate_quant(activation, scale, shift) + torch.cuda.synchronize() + self.assertTrue(torch.equal(fused_quant, baseline_quant)) + + activation = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") * 0.5 + scale = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") * 0.1 + shift = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") * 0.1 + baseline_quant, baseline_scale = scaled_mxfp8_quant(self._modulate_reference(activation, scale, shift)) + fused_quant, fused_scale = scaled_mxfp8_modulate_quant(activation, scale, shift) + torch.cuda.synchronize() + self.assertTrue(torch.equal(fused_quant, baseline_quant)) + self.assertTrue(torch.equal(fused_scale.view(torch.uint8), baseline_scale.view(torch.uint8))) + + def test_mxfp8_gemm_residual_gate_matches_baseline(self): + activation_quant, weight_quant, activation_scale, weight_scale, alpha, bias = self._quantized_inputs() + m, n = activation_quant.shape[0], weight_quant.shape[0] + residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + gate = torch.randn(n, dtype=torch.bfloat16, device="cuda") * 0.25 + gemm_out = cutlass_scaled_mxfp8_mm( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + bias=bias, + ) + baseline = residual + gemm_out * gate + fused = cutlass_scaled_mxfp8_mm_residual_gate( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + residual=residual.clone(), + gate=gate, + bias=bias, + ) + torch.cuda.synchronize() + self.assert_close_enough(fused, baseline) + + def test_mxfp8_gemm_residual_gate_without_bias_matches_baseline(self): + activation_quant, weight_quant, activation_scale, weight_scale, alpha, _ = self._quantized_inputs() + m, n = activation_quant.shape[0], weight_quant.shape[0] + residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + gate = torch.randn(n, dtype=torch.bfloat16, device="cuda") * 0.25 + gemm_out = cutlass_scaled_mxfp8_mm( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + bias=None, + ) + baseline = residual + gemm_out * gate + fused = cutlass_scaled_mxfp8_mm_residual_gate( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + residual=residual.clone(), + gate=gate, + bias=None, + ) + torch.cuda.synchronize() + self.assert_close_enough(fused, baseline) + + def test_mxfp8_gemm_residual_gate_2d_fallback_matches_baseline(self): + activation_quant, weight_quant, activation_scale, weight_scale, alpha, bias = self._quantized_inputs() + m, n = activation_quant.shape[0], weight_quant.shape[0] + residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + gate = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") * 0.25 + gemm_out = cutlass_scaled_mxfp8_mm( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + bias=bias, + ) + baseline = residual + gemm_out * gate + fused = cutlass_scaled_mxfp8_mm_residual_gate( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + residual=residual.clone(), + gate=gate, + bias=bias, + ) + torch.cuda.synchronize() + self.assert_close_enough(fused, baseline) + + def test_mxfp8_gemm_residual_gate_1d_fast_path_matches_2d_fallback_contract(self): + activation_quant, weight_quant, activation_scale, weight_scale, alpha, bias = self._quantized_inputs() + m, n = activation_quant.shape[0], weight_quant.shape[0] + residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + gate = torch.randn(n, dtype=torch.bfloat16, device="cuda") * 0.25 + fused_1d = cutlass_scaled_mxfp8_mm_residual_gate( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + residual=residual.clone(), + gate=gate, + bias=bias, + ) + fused_2d = cutlass_scaled_mxfp8_mm_residual_gate( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + residual=residual.clone(), + gate=gate.expand(m, n).contiguous(), + bias=bias, + ) + torch.cuda.synchronize() + self.assert_close_enough(fused_1d, fused_2d) + + def _assert_residual_gate_rejects(self, mat_a, mat_b, scales_a, scales_b, alpha, residual, gate, pattern, bias=None): + with self.assertRaisesRegex(RuntimeError, pattern): + cutlass_scaled_mxfp8_mm_residual_gate( + mat_a, + mat_b, + scales_a, + scales_b, + alpha, + residual=residual, + gate=gate, + bias=bias, + ) + + def test_mxfp8_gemm_residual_gate_fast_path_validates_gemm_inputs(self): + activation_quant, weight_quant, activation_scale, weight_scale, alpha, bias = self._quantized_inputs() + m, n = activation_quant.shape[0], weight_quant.shape[0] + residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + gate = torch.randn(n, dtype=torch.bfloat16, device="cuda") * 0.25 + + bad_weight = torch.randn(n, activation_quant.shape[1] + 32, dtype=torch.bfloat16, device="cuda") + bad_weight_quant, bad_weight_scale = scaled_mxfp8_quant(bad_weight) + bad_residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + self._assert_residual_gate_rejects( + activation_quant, + bad_weight_quant, + activation_scale, + bad_weight_scale, + alpha, + bad_residual, + gate, + "shapes cannot be multiplied", + bias=bias, + ) + + unaligned_weight = torch.randn(130, activation_quant.shape[1], dtype=torch.bfloat16, device="cuda") + unaligned_weight_quant, unaligned_weight_scale = scaled_mxfp8_quant(unaligned_weight) + unaligned_residual = torch.randn(m, 130, dtype=torch.bfloat16, device="cuda") + unaligned_gate = torch.randn(130, dtype=torch.bfloat16, device="cuda") + self._assert_residual_gate_rejects( + activation_quant, + unaligned_weight_quant, + activation_scale, + unaligned_weight_scale, + alpha, + unaligned_residual, + unaligned_gate, + "Expected n to be divisible by 128", + ) + + self._assert_residual_gate_rejects( + activation_quant, + weight_quant, + activation_scale[:128], + weight_scale, + alpha, + residual.clone(), + gate, + "scale_a must be padded", + bias=bias, + ) + self._assert_residual_gate_rejects( + activation_quant, + weight_quant, + activation_scale, + weight_scale[:128], + alpha, + residual.clone(), + gate, + "scale_b must be padded", + bias=bias, + ) + self._assert_residual_gate_rejects( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha.double(), + residual.clone(), + gate, + "Inconsistency of Tensor type:alpha", + bias=bias, + ) + self._assert_residual_gate_rejects( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + torch.ones(2, dtype=torch.float32, device="cuda"), + residual.clone(), + gate, + "alpha must contain exactly one scalar", + bias=bias, + ) + + def test_mxfp8_gemm_residual_gate_fast_path_validates_device_mismatch(self): + if torch.cuda.device_count() < 2: + self.skipTest("device mismatch test requires at least two visible CUDA devices") + other_device = None + for device_index in range(torch.cuda.device_count()): + if device_index != self.device.index: + other_device = torch.device("cuda", device_index) + break + if other_device is None: + self.skipTest("device mismatch test requires another visible CUDA device") + activation_quant, weight_quant, activation_scale, weight_scale, alpha, bias = self._quantized_inputs() + m, n = activation_quant.shape[0], weight_quant.shape[0] + residual = torch.randn(m, n, dtype=torch.bfloat16, device="cuda") + gate = torch.randn(n, dtype=torch.bfloat16, device=other_device) + self._assert_residual_gate_rejects( + activation_quant, + weight_quant, + activation_scale, + weight_scale, + alpha, + residual, + gate, + "same CUDA device", + bias=bias, + ) + + +if __name__ == "__main__": + unittest.main()