Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions tests/operators/test_cutlass_fp8_fp8_half_block_gemm_fused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_block_gemm_fused

# Edge length of a quantization scale block: one scale value covers 128
# elements along each blocked dimension (see scale_k/scale_n in the tests).
BLOCK_SIZE = 128

# NOTE(review): seeding at module level means the random state is shared
# across test cases, so individual results depend on execution order;
# consider re-seeding inside setUp for per-test reproducibility.
paddle.seed(2025)
np.random.seed(2025)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 随机种子在模块级别设置,各测试用例之间随机状态相互依赖。

当测试用例以不同顺序或单独运行时,每次得到的随机张量可能不同,降低可重复性。建议将 seed 移入 setUp 方法:

def setUp(self):
    paddle.seed(2025)
    np.random.seed(2025)
    paddle.set_device("gpu")
    ...



class TestCutlassFp8Fp8HalfBlockGemmFused(unittest.TestCase):
    """Tests for cutlass_fp8_fp8_half_block_gemm_fused (FP8 block-scaled GEMM)."""

    def setUp(self):
        # Re-seed per test so each case is reproducible regardless of the
        # order in which tests run (module-level seeding alone is order-dependent).
        paddle.seed(2025)
        np.random.seed(2025)
        paddle.set_device("gpu")
        self.prop = paddle.device.cuda.get_device_properties()
        self.sm_version = self.prop.major * 10 + self.prop.minor
        # Auto-tune mode lets the kernel find a valid config for each MNK.
        # Remember any pre-existing value so tearDown can restore it instead
        # of unconditionally deleting the user's setting.
        self._prev_flag = os.environ.get("FLAGS_use_cutlass_device_best_config_path")
        os.environ["FLAGS_use_cutlass_device_best_config_path"] = "tune"

    def tearDown(self):
        # Restore the environment exactly as we found it.
        if self._prev_flag is None:
            os.environ.pop("FLAGS_use_cutlass_device_best_config_path", None)
        else:
            os.environ["FLAGS_use_cutlass_device_best_config_path"] = self._prev_flag

    def _skip_if_not_sm90(self):
        """Skip the current test unless running on SM90 (Hopper) or newer."""
        if self.sm_version < 90:
            self.skipTest(f"Requires SM90+ (current: SM{self.sm_version})")

    def _check_output(self, m, n, k, output_dtype="bfloat16"):
        """Run block GEMM and verify against dequant-matmul reference.

        Args:
            m, n, k: GEMM sizes; x is [m, k], y is [n, k] (multiplied with
                transpose_y=True, so the output is [m, n]).
            output_dtype: "bfloat16" or "float16" output dtype of the kernel.
        """
        # One scale per BLOCK_SIZE-wide block (ceil division for ragged tails).
        scale_k = (k + BLOCK_SIZE - 1) // BLOCK_SIZE
        scale_n = (n + BLOCK_SIZE - 1) // BLOCK_SIZE

        x_fp8 = paddle.rand([m, k], dtype="bfloat16").astype("float8_e4m3fn")
        y_fp8 = paddle.rand([n, k], dtype="bfloat16").astype("float8_e4m3fn")
        # Keep scales in [0.1, 1.0) so the reference never multiplies by ~0.
        x_scale = paddle.rand([scale_k, m], dtype="float32") * 0.9 + 0.1
        y_scale = paddle.rand([scale_n, scale_k], dtype="float32") * 0.9 + 0.1

        # Dequantize: expand block scales to per-element granularity (trimming
        # the ragged tail with [:k] / [:n]), then matmul in fp32.
        x_s = paddle.repeat_interleave(x_scale, BLOCK_SIZE, axis=0)[:k, :].transpose([1, 0])
        y_s = paddle.repeat_interleave(y_scale, BLOCK_SIZE, axis=0)[:n, :]
        y_s = paddle.repeat_interleave(y_s, BLOCK_SIZE, axis=1)[:, :k]
        ref = paddle.matmul(
            x_fp8.astype("float32") * x_s.astype("float32"),
            y_fp8.astype("float32") * y_s.astype("float32"),
            transpose_y=True,
        )
        out_t = paddle.bfloat16 if output_dtype == "bfloat16" else paddle.float16
        ref = ref.astype(out_t)

        result = cutlass_fp8_fp8_half_block_gemm_fused(
            x_fp8,
            y_fp8,
            x_scale,
            y_scale,
            None,
            transpose_x=False,
            transpose_y=True,
            output_dtype=output_dtype,
            act="",
        )

        self.assertEqual(result.shape, [m, n])
        self.assertEqual(result.dtype, out_t)
        # Loose tolerances: FP8 inputs plus bf16/fp16 output accumulate
        # quantization error well above fp32 matmul noise.
        np.testing.assert_allclose(
            ref.astype("float32").numpy(),
            result.astype("float32").numpy(),
            rtol=5e-2,
            atol=5e-2,
        )

    def test_bfloat16_correctness(self):
        """BF16 output correctness with multiple shapes."""
        self._skip_if_not_sm90()
        for m, n, k in [(32, 2048, 2048), (64, 4096, 4096), (128, 5120, 5120)]:
            with self.subTest(m=m, n=n, k=k):
                self._check_output(m, n, k)

    def test_float16_output(self):
        """FP16 output correctness."""
        self._skip_if_not_sm90()
        self._check_output(64, 2048, 2048, output_dtype="float16")

    def test_non_aligned_dimensions(self):
        """K not a multiple of larger tile sizes.

        K=5504 (43 x 128) is still aligned to BLOCK_SIZE=128 but is not a
        multiple of 256/512, exercising configs beyond the power-of-two
        shapes used in the other tests.
        """
        self._skip_if_not_sm90()
        self._check_output(32, 2048, 5504)


# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()
Loading