Description
I found several critical issues in the int8_gemm_triton implementation in lightx2v/common/ops/mm/triton_kernels.py. These bugs cause the Int8 quantization operator to produce incorrect results (integer overflow) and an unusable output dtype.
1. Accumulator Overflow (K >= 128)
When the hidden dimension K is 128 or larger, multiplying all-ones int8 inputs with unit scales should produce 128 per output element, but the kernel returns -128. This suggests that the internal accumulation or the write-back logic is incorrectly using int8 instead of int32/float32.
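The value -128 is exactly what int8 wrap-around produces, since 128 exceeds the int8 maximum of 127. A standalone PyTorch sketch of the arithmetic, independent of the kernel itself:
import torch
# Summing 128 ones gives 128, which does not fit in int8.
acc = torch.ones(128, dtype=torch.int8).sum()
print(acc.item())                   # 128
# Casting the result back to int8 wraps around to -128, matching the observed output.
print(acc.to(torch.int8).item())    # -128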
2. Incorrect Output Dtype
The int8_gemm_triton function returns a tensor with torch.int8 (Char) dtype instead of the expected floating-point dtype (fp16 or bf16). This causes downstream errors in PyTorch (e.g., mean() or other floating-point operations fail).
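To illustrate the downstream failure, a floating-point reduction on an int8 tensor raises an error in PyTorch (a standalone example, unrelated to the kernel):
import torch
x = torch.ones(4, dtype=torch.int8)
# Raises RuntimeError: mean() requires a floating point or complex dtype (got Char).
x.mean()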
3. Broken Scale Indexing
The scale loading logic A_SCALES + pid_m * BLOCK_M assumes scales are stored in blocks, which is incompatible with standard per-token/per-tensor quantization scales. This leads to incorrect de-quantization or zeroed outputs.
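To make the indexing difference concrete, here is a small PyTorch illustration (the names pid_m and BLOCK_M mirror the kernel's, but the snippet itself is hypothetical) of block-start indexing versus the per-row indexing that per-token scales require:
import torch
M, BLOCK_M = 128, 64
a_scales = torch.arange(M, dtype=torch.float32)   # distinct per-token (per-row) scales
pid_m = 1
# Block-start indexing, as the kernel appears to do: one scalar for the whole block.
block_scale = a_scales[pid_m * BLOCK_M]            # only row 64's scale
# Per-token indexing: one scale per row covered by this block.
offs_m = pid_m * BLOCK_M + torch.arange(BLOCK_M)
token_scales = a_scales[offs_m]                    # scales for rows 64..127
print(block_scale.item(), token_scales.shape)      # 64.0 torch.Size([64])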
Reproduction Code
import torch
# Import path assumed from the file location given above (lightx2v/common/ops/mm/triton_kernels.py).
from lightx2v.common.ops.mm.triton_kernels import int8_gemm_triton
device = "cuda"
M, N, K = 128, 128, 128
# All ones input
a = torch.ones((M, K), device=device, dtype=torch.int8)
b = torch.ones((N, K), device=device, dtype=torch.int8) # Linear weight shape (out, in)
a_scales = torch.ones(M, device=device, dtype=torch.float32)
b_scales = torch.ones(N, device=device, dtype=torch.float32)
# Run kernel
output = int8_gemm_triton(a, b, a_scales, b_scales)
print(f"Output Dtype: {output.dtype}") # Should be float, but got int8
print(f"Output Max: {output.max().item()}") # Expected 128.0, but got -128.0Metadata