Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions tests/operators/test_winx_unzip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import winx_unzip

paddle.seed(42)
np.random.seed(42)


def wint25_unzip_ref(zipped_weight_np, super_scale_np):
    """NumPy reference for weight_only_int2.5 decompression.

    Each group of 10 packed uint16 rows expands to 64 weight rows: rows 0-8
    each carry seven 3-bit codes, and the final packed row carries a 13-bit
    local scale plus the 64th 3-bit code in its top bits.
    """
    batch, k_zipped, n_cols = zipped_weight_np.shape
    groups = k_zipped // 10
    out = np.zeros((batch, groups * 64, n_cols), dtype=np.float32)
    # Bit positions of the seven 3-bit codes inside each packed uint16.
    code_shifts = (13, 11, 9, 6, 4, 2, 0)

    for b in range(batch):
        for g in range(groups):
            block = zipped_weight_np[b, g * 10:(g + 1) * 10, :].astype(np.int32)
            tail = block[9]
            # Low 13 bits of the last packed row hold the per-group scale.
            scale = (tail & 0x1FFF).astype(np.float32) * super_scale_np[b]
            base = g * 64
            for row in range(9):
                packed = block[row]
                for slot, shift in enumerate(code_shifts):
                    code = (packed >> shift) & 0x7
                    out[b, base + row * 7 + slot, :] = (code.astype(np.float32) - 4) * scale
            # Code #64 lives in the top 3 bits of the scale row.
            last_code = (tail >> 13) & 0x7
            out[b, base + 63, :] = (last_code.astype(np.float32) - 4) * scale
    return out


def wint2_unzip_ref(zipped_weight_np, local_scale_np, code_scale_np, code_zp_np, super_scale_np):
    """NumPy reference for weight_only_int2 decompression.

    Args:
        zipped_weight_np: uint8 array of shape (batch, k_packed, n); each byte
            is a code that decodes (via code_scale/code_zp) to four packed values.
        local_scale_np: uint8 array of shape (batch, num_ls_rows, n); each byte
            holds two 4-bit per-group scales, one per nibble.
        code_scale_np: float32 array of shape (batch, n), code-book scale.
        code_zp_np: float32 array of shape (batch, n), code-book zero point.
        super_scale_np: optional float32 array of shape (batch, n); multiplied
            into each group's local scale when not None.

    Returns:
        float32 array of shape (batch, k_packed * 4, n) with decompressed weights.
    """
    batch, k_packed, n_cols = zipped_weight_np.shape
    num_rows = k_packed * 4  # each packed row expands to 4 weight rows
    num_groups = num_rows // 64  # one local-scale nibble per 64 weight rows
    weight = np.zeros((batch, num_rows, n_cols), dtype=np.float32)
    shift_bits = np.array([9, 6, 3, 0], dtype=np.int32)

    for b in range(batch):
        for g in range(num_groups):
            block_start_row = b * num_rows + g * 64
            ls_row = g // 2
            # NOTE(review): block_start_row includes the cross-batch offset
            # b * num_rows, so the nibble chosen below can flip between batches
            # when num_groups is odd. If the CUDA kernel counts groups per batch
            # from g=0, this should be ((g + 1) & 1) * 4 — confirm against the
            # kernel's memory layout before relying on batch>1 with odd groups.
            local_scale_shift = ((block_start_row // 64 + 0 + 1) & 1) * 4
            ls_byte = local_scale_np[b, ls_row, :].astype(np.int32)
            ls_float = ((ls_byte >> local_scale_shift) & 0xF).astype(np.float32)
            scale = ls_float * super_scale_np[b] if super_scale_np is not None else ls_float
            for zr in range(16):
                zv = zipped_weight_np[b, g * 16 + zr, :].astype(np.float32)
                # Undo the code-book mapping: round(code * code_scale + code_zp).
                decode_val = np.floor(zv * code_scale_np[b] + code_zp_np[b] + 0.5).astype(np.int32)
                for si in range(4):
                    # NOTE(review): shift_bits steps by 3 bits while the mask
                    # keeps 6 bits, so adjacent slots' windows overlap (e.g.
                    # >>9 & 0x3F and >>6 & 0x3F share bits 9..11). If the format
                    # is 3-bit packed the mask should be 0x7; if it is genuinely
                    # 6-bit (matching the -32 bias below), document the packing
                    # format — confirm against the kernel.
                    shifted = (decode_val >> shift_bits[si]) & 0x3F
                    weight[b, g * 64 + zr * 4 + si, :] = (shifted.astype(np.float32) - 32) * scale
    return weight


@unittest.skipUnless(paddle.is_compiled_with_cuda(), "GPU required for winx_unzip")
class TestWinxUnzip(unittest.TestCase):
    """Correctness tests for the winx_unzip custom op.

    Each test builds random compressed weights, decompresses them with the GPU
    op, and compares against the NumPy reference implementations above.
    """

    def setUp(self):
        """Run every test on the GPU where the custom op is registered."""
        paddle.set_device("gpu")

    def _check_wint25(self, batch, k_zipped, n, seed=42):
        """Run wint2.5 op and compare against NumPy reference.

        Args:
            batch: batch dimension of the zipped weight.
            k_zipped: packed row count (10 packed rows expand to 64 weight rows).
            n: column count.
            seed: seed so op input and reference input match exactly.
        """
        np.random.seed(seed)
        zipped_np = np.random.randint(0, 65536, (batch, k_zipped, n)).astype(np.uint16)
        super_scale_np = np.random.rand(batch, n).astype(np.float32) * 0.1 + 0.01

        expected = wint25_unzip_ref(zipped_np, super_scale_np)
        # Paddle has no uint16 dtype; reinterpret the bits as int16 instead.
        zipped_pd = paddle.to_tensor(zipped_np.view(np.int16), dtype=paddle.int16)
        super_scale_pd = paddle.to_tensor(super_scale_np.astype(np.float16), dtype=paddle.float16)
        out = winx_unzip(zipped_pd, None, None, None, super_scale_pd, "weight_only_int2.5")
        self.assertEqual(list(out.shape), [batch, k_zipped // 10 * 64, n])
        self.assertEqual(out.dtype, paddle.float16)
        np.testing.assert_allclose(out.astype(paddle.float32).numpy(), expected, rtol=5e-3, atol=5e-3)

    def _check_wint2(self, batch, k_packed, n, seed=42):
        """Run wint2 op and compare against NumPy reference.

        Args:
            batch: batch dimension of the zipped weight.
            k_packed: packed row count (each packed row expands to 4 weight rows).
            n: column count.
            seed: seed so op input and reference input match exactly.
        """
        np.random.seed(seed)
        # One local-scale byte covers two 64-row groups (one nibble each).
        num_ls_rows = (k_packed * 4 // 64 + 1) // 2
        zipped_np = np.random.randint(0, 256, (batch, k_packed, n)).astype(np.uint8)
        local_scale_np = np.random.randint(0, 256, (batch, num_ls_rows, n)).astype(np.uint8)
        code_scale_np = np.full((batch, n), 128.0, dtype=np.float32)
        code_zp_np = np.zeros((batch, n), dtype=np.float32)
        super_scale_np = (np.random.rand(batch, n) * 0.1 + 0.01).astype(np.float32)

        expected = wint2_unzip_ref(zipped_np, local_scale_np, code_scale_np, code_zp_np, super_scale_np)
        out = winx_unzip(
            paddle.to_tensor(zipped_np, dtype=paddle.uint8),
            paddle.to_tensor(local_scale_np, dtype=paddle.uint8),
            paddle.to_tensor(code_scale_np, dtype=paddle.float32),
            paddle.to_tensor(code_zp_np, dtype=paddle.float32),
            paddle.to_tensor(super_scale_np.astype(np.float16), dtype=paddle.float16),
            "weight_only_int2",
        )
        self.assertEqual(list(out.shape), [batch, k_packed * 4, n])
        self.assertEqual(out.dtype, paddle.float16)
        np.testing.assert_allclose(out.astype(paddle.float32).numpy(), expected, rtol=5e-2, atol=5e-2)

    def test_wint25_correctness(self):
        """wint2.5 single and multi-group correctness."""
        for batch, k_zipped, n in [(1, 10, 64), (2, 20, 128)]:
            with self.subTest(batch=batch, k_zipped=k_zipped, n=n):
                self._check_wint25(batch, k_zipped, n)

    def test_wint2_correctness(self):
        """wint2 single and multi-group correctness."""
        for batch, k_packed, n in [(1, 16, 256), (2, 32, 256)]:
            with self.subTest(batch=batch, k_packed=k_packed, n=n):
                self._check_wint2(batch, k_packed, n)

    def test_wint2_with_code_zp(self):
        """wint2 correctness with non-zero code_zp."""
        np.random.seed(99)
        batch, k_packed, n = 2, 32, 256
        num_ls_rows = 1
        zipped_np = np.random.randint(0, 256, (batch, k_packed, n)).astype(np.uint8)
        local_scale_np = np.random.randint(0, 256, (batch, num_ls_rows, n)).astype(np.uint8)
        code_scale_np = np.full((batch, n), 128.0, dtype=np.float32)
        code_zp_np = np.full((batch, n), 0.5, dtype=np.float32)
        super_scale_np = (np.random.rand(batch, n) * 0.05 + 0.005).astype(np.float32)

        expected = wint2_unzip_ref(zipped_np, local_scale_np, code_scale_np, code_zp_np, super_scale_np)
        out = winx_unzip(
            paddle.to_tensor(zipped_np, dtype=paddle.uint8),
            paddle.to_tensor(local_scale_np, dtype=paddle.uint8),
            paddle.to_tensor(code_scale_np, dtype=paddle.float32),
            paddle.to_tensor(code_zp_np, dtype=paddle.float32),
            paddle.to_tensor(super_scale_np.astype(np.float16), dtype=paddle.float16),
            "weight_only_int2",
        )
        # Fix: assert shape and dtype here too, matching _check_wint2, so a
        # wrong output layout cannot slip through on the numeric tolerance.
        self.assertEqual(list(out.shape), [batch, k_packed * 4, n])
        self.assertEqual(out.dtype, paddle.float16)
        np.testing.assert_allclose(out.astype(paddle.float32).numpy(), expected, rtol=5e-2, atol=5e-2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 test_wint2_with_code_zp 缺少 shape 和 dtype 断言

与 _check_wint2(第 114-115 行)相比,本测试方法只做了数值精度校验,未验证输出形状和数据类型。建议在 np.testing.assert_allclose 前补充:

self.assertEqual(list(out.shape), [batch, k_packed * 4, n])
self.assertEqual(out.dtype, paddle.float16)



# Allow running this test module directly (outside the pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
Loading