diff --git a/CMakeLists.txt b/CMakeLists.txt index 61006e83cbf..a0b9a67e4a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -142,6 +142,7 @@ if(BUILD_CUTLASS) set(NVFUSER_CUTLASS_SRCS) list(APPEND NVFUSER_CUTLASS_SRCS ${NVFUSER_CUTLASS}/group_mm.cu + ${NVFUSER_CUTLASS}/mxfp8_scaled_mm.cu ${NVFUSER_CUTLASS}/nvfp4_scaled_mm.cu ${NVFUSER_CUTLASS}/nvfp4_scaled_mm_blockscale.cu ${NVFUSER_CUTLASS}/nvfp4_scaled_group_mm.cu diff --git a/cutlass/mxfp8_scaled_mm.cu b/cutlass/mxfp8_scaled_mm.cu new file mode 100644 index 00000000000..603c4600120 --- /dev/null +++ b/cutlass/mxfp8_scaled_mm.cu @@ -0,0 +1,316 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" + +namespace nvfuser::cutlass_kernels { + +namespace { + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +// Kernel configuration traits for different output data types +// Defines tile shapes and cluster configurations. +template +struct KernelTraits; + +// Kernel traits for FP16 output +template <> +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_2, _4, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; +}; + +// Kernel traits for BF16 output +template <> +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_2, _4, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; +}; + +// Main GEMM configuration for MXFP8 scaled matrix multiplication on SM100+ +// Defines all the types, layouts, and configurations needed for the CUTLASS +// kernel +template +struct MxFp8GemmSm100 { + // A matrix configuration + using ElementA = cutlass::mx_float8_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int kAlignmentA = 16; + + // B matrix configuration + using ElementB = cutlass::mx_float8_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int kAlignmentB = 16; + + // C/D matrix configuration + using ElementD = T; + using ElementC = T; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int kAlignmentD = + 128 / cutlass::sizeof_bits::value; + static constexpr int kAlignmentC = + 128 / cutlass::sizeof_bits::value; + // Kernel functional config + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + // Kernel Perf config + using MmaTileShape = typename KernelTraits::MmaTileShape; + using ClusterShape = typename KernelTraits::ClusterShape; + using PerSmTileShape_MNK = typename KernelTraits::PerSmTileShape_MNK; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + PerSmTileShape_MNK, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutCTag, + kAlignmentC, + ElementD, + LayoutDTag, + kAlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using 
CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + kAlignmentA, + ElementB, + LayoutBTag, + kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); + // Scale Factor tensors have an interleaved layout. Bring Layout instead of + // stride. + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); + // Scale Factor tensors have an interleaved layout. Bring Layout instead of + // stride. + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); +}; + +// Constructs CUTLASS GEMM arguments from PyTorch tensors and dimensions +// +// This function converts PyTorch tensor data and metadata into the format +// expected by CUTLASS GEMM kernels, including proper stride calculations +// and layout configurations for the scaled matrix multiplication. 
+// +// Parameters: +// output: Output tensor for storing results +// a: Input matrix A in MXFP8 format +// b: Input matrix B in MXFP8 format +// scales_a: Per-block scaling factors for matrix A +// scales_b: Per-block scaling factors for matrix B +// alpha: Global scaling factor +// M, N, K: Matrix dimensions +// +// Returns: CUTLASS GEMM arguments structure ready for kernel execution +template +typename T::Gemm::Arguments args_from_options( + at::Tensor& output, + const at::Tensor& a, + const at::Tensor& b, + const at::Tensor& scales_a, + const at::Tensor& scales_b, + const at::Tensor& alpha, + int64_t M, + int64_t N, + int64_t K) { + using ElementA = typename T::Gemm::ElementA; + using ElementB = typename T::Gemm::ElementB; + using ElementSFA = cutlass::float_ue8m0_t; + using ElementSFB = cutlass::float_ue8m0_t; + using ElementD = typename T::Gemm::ElementD; + using ElementCompute = float; + using StrideA = typename T::StrideA; + using StrideB = typename T::StrideB; + using StrideD = typename T::StrideD; + using Sm1xxBlkScaledConfig = + typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(m, n, k, 1)); + + typename T::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(a.data_ptr()), + stride_A, + static_cast(b.data_ptr()), + stride_B, + static_cast(scales_a.data_ptr()), + layout_SFA, + static_cast(scales_b.data_ptr()), + layout_SFB}, + {// Epilogue arguments + {}, // epilogue.thread + static_cast(output.data_ptr()), + stride_D, + static_cast(output.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + return arguments; +} + +// Executes the MXFP8 scaled matrix multiplication using CUTLASS kernels +// +// This function orchestrates the GEMM operation by setting up the kernel, +// allocating workspace memory, and running the computation on the GPU. +// It handles the complete lifecycle from kernel initialization to execution. 
+// +// Parameters: +// output: Output tensor to store the result +// a, b: Input matrices in MXFP8 format +// scales_a, scales_b: Per-block scaling factors +// alpha: Global scaling factor +// m, n, k: Matrix dimensions +// stream: CUDA stream for asynchronous execution +template +void runGemm( + at::Tensor& output, + const at::Tensor& a, + const at::Tensor& b, + const at::Tensor& scales_a, + const at::Tensor& scales_b, + const at::Tensor& alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename MxFp8GemmSm100::Gemm gemm; + + auto arguments = args_from_options>( + output, a, b, scales_a, scales_b, alpha, m, n, k); + + size_t workspace_size = + MxFp8GemmSm100::Gemm::get_workspace_size(arguments); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto can_implement_status = gemm.can_implement(arguments); + NVF_CHECK( + can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); + + auto status = gemm.initialize(arguments, workspace.data_ptr(), stream); + NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm.run(arguments, workspace.data_ptr(), stream); + NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} +#else +// Fallback implementation for unsupported CUTLASS versions +// Throws an error when SM100+ CUTLASS support is not available +template +void runGemm( + at::Tensor& output, + at::Tensor const& a, + at::Tensor const& b, + at::Tensor const& scales_a, + at::Tensor const& scales_b, + at::Tensor const& alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + NVF_THROW("Unsupported CUTLASS version."); +} +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +} // namespace + +torch::Tensor mxfp8_scaled_mm( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& alpha, + const at::ScalarType out_dtype, + bool skip_checks) { + // Validate all inputs and get matrix dimensions + auto [m, n, k] = + validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks); + + at::cuda::CUDAGuard device_guard{(int8_t)a.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto options = + at::TensorOptions().dtype(out_dtype).device(at::kCUDA, a.get_device()); + torch::Tensor output = at::empty({a.sizes()[0], b.sizes()[0]}, options); + + if (out_dtype == at::ScalarType::Half) { + runGemm( + output, a, b, scales_a, scales_b, alpha, m, n, k, stream); + } else if (out_dtype == at::ScalarType::BFloat16) { + runGemm( + output, a, b, scales_a, scales_b, alpha, m, n, k, stream); + } else { + NVF_THROW("Unsupported output data type of mxfp8 scaled_mm."); + } + return output; +} + +} // namespace nvfuser::cutlass_kernels diff --git a/cutlass/nvf_cutlass.cpp b/cutlass/nvf_cutlass.cpp index 9a85f2b6032..b1fe4be874c 100644 --- a/cutlass/nvf_cutlass.cpp +++ b/cutlass/nvf_cutlass.cpp @@ -121,4 +121,116 @@ std::tuple validateInputsNvfp4ScaledMm( return ret; } +std::tuple validateInputsMxFp8ScaledMm( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& alpha, + bool skip_checks) { + // Validate matrix dimensions + NVF_CHECK(a.dim() == 2, "Operand A must be a matrix."); + NVF_CHECK(b.dim() == 2, "Operand B must be a matrix."); + NVF_CHECK( + a.sizes()[1] == 
b.sizes()[1], + "A and B shapes cannot be multiplied (", + a.sizes()[0], + ",", + a.sizes()[1], + " and ", + b.sizes()[0], + ",", + b.sizes()[1], + ")"); + + const int64_t m = a.sizes()[0]; + const int64_t n = b.sizes()[0]; + const int64_t k = a.sizes()[1]; + + std::tuple ret = {m, n, k}; + + if (skip_checks) { + return ret; + } + + // Check CUDA device and contiguity for all input tensors + for (const torch::Tensor& t : {a, b, scales_a, scales_b, alpha}) { + NVF_CHECK( + t.is_cuda() && t.is_contiguous(), + "Input argument must be a CUDA tensor and contiguous.") + } + + // Validate data types + NVF_CHECK( + a.scalar_type() == at::ScalarType::Float8_e4m3fn, + "Expected Float8_e4m3fn for Operand A.") + NVF_CHECK( + b.scalar_type() == at::ScalarType::Float8_e4m3fn, + "Expected Float8_e4m3fn for Operand B.") + NVF_CHECK( + scales_a.scalar_type() == at::ScalarType::Float8_e8m0fnu, + "Expected FP8_e8m0fnu for Blockscale scale_a.") + NVF_CHECK( + scales_b.scalar_type() == at::ScalarType::Float8_e8m0fnu, + "Expected FP8_e8m0fnu for Blockscale scale_b.") + NVF_CHECK( + alpha.scalar_type() == at::ScalarType::Float, + "Expected FP32 for alpha scalar.") + + // Check alignment requirements + constexpr int64_t kAlignment = 16; + NVF_CHECK_EQ( + k % kAlignment, + 0, + "The K dimension", + k, + "is not divisible by ", + kAlignment) + NVF_CHECK_EQ( + n % kAlignment, + 0, + "The N dimension", + n, + "is not divisible by ", + kAlignment) + + // Calculate rounded dimensions for scale matrix validation + int64_t rounded_m = roundUp(m, 128); + int64_t rounded_n = roundUp(n, 128); + constexpr int64_t BLOCK_SCALE = 32; + int64_t rounded_k = roundUp(k / BLOCK_SCALE, 4); + + // Validate scale matrix properties + NVF_CHECK(scales_a.dim() == 2, "Blockscale scale_a must be a matrix."); + NVF_CHECK(scales_b.dim() == 2, "Blockscale scale_b must be a matrix."); + NVF_CHECK( + scales_a.sizes()[1] == scales_b.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied because the inner-most " + "dimensions are not equal.") + NVF_CHECK( + scales_a.sizes()[0] == rounded_m && scales_a.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", + rounded_m, + ",", + rounded_k, + "), but got a shape (", + scales_a.sizes()[0], + ",", + scales_a.sizes()[1], + ")"); + NVF_CHECK( + scales_b.sizes()[0] == rounded_n && scales_b.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", + rounded_n, + ",", + rounded_k, + "), but got a shape (", + scales_b.sizes()[0], + ",", + scales_b.sizes()[1], + ")"); + + return ret; +} + } // namespace nvfuser::cutlass_kernels diff --git a/cutlass/nvf_cutlass.h b/cutlass/nvf_cutlass.h index 50331a26fbe..8e96634b049 100644 --- a/cutlass/nvf_cutlass.h +++ b/cutlass/nvf_cutlass.h @@ -62,7 +62,7 @@ NVF_API std::tuple validateInputsNvfp4ScaledMm( // scales_a: Per-block scaling factors for matrix A in FP8_E4M3 format // scales_b: Per-block scaling factors for matrix B in FP8_E4M3 format // alpha: Combined global scaling factor for operands A and B in FP32 format -// out_dtype: Output data type (Half, BFloat16, or Float) +// out_dtype: Output data type (Half or BFloat16) // // Returns: Matrix C = alpha * (A @ B) in the specified output dtype NVF_API torch::Tensor nvfp4_scaled_mm( @@ -74,6 +74,57 @@ NVF_API torch::Tensor nvfp4_scaled_mm( at::ScalarType out_dtype, bool skip_checks = false); +// Validates all input parameters and tensor properties for MXFP8 scaled matrix +// multiplication +// +// This function performs comprehensive validation of input tensors 
including: +// - CUDA device and contiguity checks +// - Data type validation for all inputs +// - Matrix dimension and shape compatibility +// - Alignment requirements for optimal performance +// - Scale matrix shape validation +// +// Parameters: +// a, b: Input matrices to validate +// scales_a, scales_b: Scale matrices to validate +// alpha: Alpha scaling factor to validate +// +// Returns: Tuple of (m, n, k) dimensions for the GEMM operation +// +// Throws: NVF_CHECK exceptions for any validation failures +NVF_API std::tuple validateInputsMxFp8ScaledMm( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& alpha, + bool skip_checks = false); + +// Performs scaled matrix multiplication using MXFP8 format. +// +// This function implements a scaled matrix multiplication C = alpha * (A @ B) +// where A and B are matrices in MXFP8 format with per-block scaling factors. +// The function uses CUTLASS kernels optimized for NVIDIA GPUs with SM100+ +// architecture. +// +// Parameters: +// a: Input matrix A in Float8_e4m3fn format +// b: Input matrix B in Float8_e4m3fn format +// scales_a: Per-block scaling factors for matrix A in FP8_E8M0fnu format +// scales_b: Per-block scaling factors for matrix B in FP8_E8M0fnu format +// alpha: Combined global scaling factor for operands A and B in FP32 format +// out_dtype: Output data type (Half or BFloat16) +// +// Returns: Matrix C = alpha * (A @ B) in the specified output dtype +NVF_API torch::Tensor mxfp8_scaled_mm( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& alpha, + at::ScalarType out_dtype, + bool skip_checks = false); + // Performs scaled matrix multiplication using NVFP4 format with fused epilogue // blockscale quantization. // diff --git a/python/python_direct/cutlass.cpp b/python/python_direct/cutlass.cpp index 6030b5c72ca..0b5f8347f44 100644 --- a/python/python_direct/cutlass.cpp +++ b/python/python_direct/cutlass.cpp @@ -15,6 +15,26 @@ namespace nvfuser::python { namespace { void bindGemm(py::module_& cutlass) { + cutlass.def( + "mxfp8_scaled_mm", + [](const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& alpha, + at::ScalarType out_dtype) -> torch::Tensor { + return cutlass_kernels::mxfp8_scaled_mm( + a, b, scales_a, scales_b, alpha, out_dtype); + }, + R"(Computes mxfp8 matmul and returns bf16 or fp16 output tensor. + mxfp8_scaled_mm(Tensor a, + Tensor b, + Tensor scales_a, + Tensor scales_b, + Tensor alpha, + DataType out_dtype) + -> Tensor output)"); + cutlass.def( "nvfp4_scaled_mm", [](const torch::Tensor& a, @@ -26,7 +46,7 @@ void bindGemm(py::module_& cutlass) { return cutlass_kernels::nvfp4_scaled_mm( a, b, scales_a, scales_b, alpha, out_dtype); }, - R"(Computes nvfp4 matmul and returns bf16, fp16, or fp32 output tensor. + R"(Computes nvfp4 matmul and returns bf16 or fp16 output tensor. nvfp4_scaled_mm(Tensor a, Tensor b, Tensor scales_a, diff --git a/tests/python/direct/test_cutlass_mxfp8_gemm.py b/tests/python/direct/test_cutlass_mxfp8_gemm.py new file mode 100644 index 00000000000..5890febded0 --- /dev/null +++ b/tests/python/direct/test_cutlass_mxfp8_gemm.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause
+# Owner(s): ["module: nvfuser"]
+
+import pytest
+import torch
+from nvfuser_direct import nvf_cutlass
+
+compute_cap = torch.cuda.get_device_capability()
+if compute_cap < (10, 0) or compute_cap >= (12, 0):
+    pytest.skip(
+        reason="MxFp8 requires compute capability 10.x.",
+        allow_module_level=True,
+    )
+
+from python.direct_utils import (
+    linear_to_swizzled_128_4,
+    swizzled_to_linear_128_4,
+)
+
+
+def dequantize_mxfp8(tensor_fp8, tensor_sf):
+    """Dequantize the fp8 tensor back to high precision."""
+    m, k = tensor_fp8.shape
+    BLOCK_SIZE = 32
+    tensor_sf_linear = swizzled_to_linear_128_4(tensor_sf, m, k)
+    # Apply scale factor to all elements in the same block
+    sf = tensor_sf_linear.repeat_interleave(BLOCK_SIZE, dim=1).to(torch.float32)
+    dqx = tensor_fp8.to(torch.float32)
+    # Account for padding of scale factor
+    sf = sf[: dqx.shape[0], : dqx.shape[1]]
+    dequant = dqx * sf
+    return dequant.reshape(m, k)
+
+
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
+    # The fn suffix means that fp8 is a finite type without infinity support.
+    # Clamp values above 448 (the e4m3fn max) to avoid casting values to NaN.
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
+        dtype=torch.float8_e4m3fn
+    )
+
+
+def pytorch_mxfp8_quantize(a):
+    BLOCK_SIZE = 32
+    assert (
+        a.size(-1) % BLOCK_SIZE == 0
+    ), "The inner-most dim must be divisible by block_size; Padding is not implemented."
+    assert a.is_contiguous(), "Only contiguous tensors are supported."
+
+    # Find absolute maximum along blockwise dimension
+    original_shape = a.shape
+    a_fp32 = a.float().reshape(original_shape[0], -1, BLOCK_SIZE)
+    max_abs = torch.amax(torch.abs(a_fp32), dim=-1)
+
+    # Get fp32 block scale factor for fp8
+    FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+    block_scale_fp32 = (max_abs / FLOAT8_E4M3_MAX).float()
+
+    # Clamp scale factor within UE8M0
+    FLOAT8_UE8M0_EPS = torch.finfo(torch.float8_e8m0fnu).tiny
+    FLOAT8_UE8M0_MAX = torch.finfo(torch.float8_e8m0fnu).max
+    block_scale_fp32 = torch.clamp(
+        block_scale_fp32, min=FLOAT8_UE8M0_EPS, max=FLOAT8_UE8M0_MAX
+    )
+
+    # Apply block conversion factor
+    a_scaled = a_fp32 / block_scale_fp32.unsqueeze(-1)
+    a_scaled = a_scaled.view(original_shape)
+
+    return to_fp8(a_scaled), block_scale_fp32.to(torch.float8_e8m0fnu)
+
+
+def get_ref_results(
+    a_fp8,
+    b_fp8,
+    a_sf,
+    b_sf,
+    m,
+    n,
+):
+    _, m_k = a_fp8.shape
+    _, n_k = b_fp8.shape
+    assert m_k == n_k
+    a_in_dtype = dequantize_mxfp8(a_fp8, a_sf)
+    b_in_dtype = dequantize_mxfp8(b_fp8, b_sf)
+    return torch.matmul(a_in_dtype, b_in_dtype.t())
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "shape", [(128, 128, 128), (128, 128, 256), (256, 128, 128), (128, 256, 256)]
+)
+@torch.inference_mode()
+def test_mxfp8_gemm(
+    dtype: torch.dtype,
+    shape: tuple[int, int, int],
+) -> None:
+    m, n, k = shape
+    block_size = 32
+    a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
+    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
+
+    alpha = torch.tensor(1.0, device="cuda")
+    a_fp8, a_scale_linear = pytorch_mxfp8_quantize(a_dtype)
+    b_fp8, b_scale_linear = pytorch_mxfp8_quantize(b_dtype)
+    a_scale_interleaved = linear_to_swizzled_128_4(a_scale_linear)
+    b_scale_interleaved = linear_to_swizzled_128_4(b_scale_linear)
+
+    expected_out = get_ref_results(
+        a_fp8,
+        b_fp8,
+        a_scale_interleaved,
+        b_scale_interleaved,
+        m,
+        n,
+    )
+    out = nvf_cutlass.mxfp8_scaled_mm(
+        a_fp8, b_fp8, a_scale_interleaved, b_scale_interleaved, alpha, dtype
+    )
+
+    torch.testing.assert_close(out, expected_out.to(dtype=dtype))
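
Note on the scale-factor layout: `validateInputsMxFp8ScaledMm` requires `scales_a`/`scales_b` to arrive padded to `(round_up(m, 128), round_up(k / 32, 4))` and swizzled, and the test relies on `linear_to_swizzled_128_4` / `swizzled_to_linear_128_4` from `python.direct_utils` to produce that layout, but neither file spells the interleaving out. The sketch below is a minimal, self-contained illustration of that swizzle, assuming the usual SM100 block-scaled convention of a `(32, 4, 4)` atom per 128x4 tile; `swizzle_scale_factors` is a hypothetical name, not the helper the tests import, and the atom ordering is an assumption rather than something stated in this diff.

```python
import torch


def swizzle_scale_factors(sf_linear: torch.Tensor) -> torch.Tensor:
    """Sketch of the assumed 128x4 scale-factor swizzle.

    Pads a (m, k // 32) UE8M0 scale matrix to
    (round_up(m, 128), round_up(k // 32, 4)) and reorders it so that, within
    each 128x4 tile, element (r, c) lands at offset
    (r % 32) * 16 + (r // 32) * 4 + c.
    """
    m, s = sf_linear.shape
    m_pad = (m + 127) // 128 * 128
    s_pad = (s + 3) // 4 * 4
    padded = torch.zeros(m_pad, s_pad, dtype=sf_linear.dtype, device=sf_linear.device)
    padded[:m, :s] = sf_linear
    # Rows split as (row_block, r_hi, r_lo) with r = r_hi * 32 + r_lo;
    # columns split as (col_block, c). Each tile is then stored in
    # (r_lo, r_hi, c) order, i.e. the (32, 4, 4) atom.
    tiles = padded.view(m_pad // 128, 4, 32, s_pad // 4, 4)
    return tiles.permute(0, 3, 2, 1, 4).reshape(m_pad, s_pad)
```

Under these assumptions, `swizzle_scale_factors(a_scale_linear)` yields the same `(rounded_m, rounded_k)` shape that the C++ validation checks for; the tests themselves should keep using the `python.direct_utils` helpers, which remain the source of truth for the layout.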