Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions crates/uzu/src/backends/common/kernel/quant_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,13 @@ impl<B: Backend> QuantizedMatmulKernelEncodable<B> {
// Matrix-matrix
let aligned_n_64 = configuration.output_dim % 64 == 0;
let is_bf16 = configuration.data_type == DataType::BF16;
let is_zero_point = configuration.quantization_type == QuantizedMatmulType::ZeroPoint;

let matrix_matrix = if aligned_n_64 && is_bf16 && matches!(configuration.group_size, 64 | 128) {
let matrix_matrix = if aligned_n_64
&& is_bf16
&& is_zero_point
&& matches!(configuration.group_size, 64 | 128)
{
MatrixMatrixKernel::QmmTransposed64x64(
<B::Kernels as Kernels>::QuantizedMatmulQmmTransposed64x64Kernel::new(
context,
Expand All @@ -129,7 +134,7 @@ impl<B: Backend> QuantizedMatmulKernelEncodable<B> {
)
.map_err(QuantizedMatmulError::BackendError)?,
)
} else if aligned_n_64 && is_bf16 {
} else if aligned_n_64 && is_bf16 && is_zero_point {
MatrixMatrixKernel::QmmTransposedWide(
<B::Kernels as Kernels>::QuantizedMatmulQmmTransposedWideKernel::new(
context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ PUBLIC KERNEL(QuantizedMatmulQmmTransposed)(
const constant uint& in_vec_size,
const constant uint& out_vec_size,
const constant uint& batch_size,
threadgroup T Xs[32 * (32 + 16 / sizeof(T))],
threadgroup T Xs[160 * (32 + 16 / sizeof(T))],
threadgroup T Ws[32 * (32 + 16 / sizeof(T))],
const bool use_zero_points SPECIALIZE,
const bool use_mlx_quant SPECIALIZE,
const bool aligned_n SPECIALIZE,
const uint out_block_idx GROUPS(out_vec_size.div_ceil(32)),
const uint batch_block_idx GROUPS(batch_size.div_ceil(32)),
const uint batch_block_idx GROUPS(batch_size.div_ceil(160)),
const uint simd_lane THREADS(32),
const uint simd_group THREADS(4)
) {
if (use_mlx_quant) {
if (aligned_n) {
qmm_transposed_impl<T, GROUP_SIZE, BITS, true, 32, 32, 32, true>(
qmm_transposed_impl<T, GROUP_SIZE, BITS, true, 160, 32, 32, true>(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep BM at 32 until BlockLoader handles wide row spans

Setting qmm_transposed_impl to BM=160 makes loader_x_t use BlockLoader with n_reads=(BM*BK)/(WM*WN*32)=40 (for BK=32), which is larger than BK. In BlockLoader::load_unsafe/load_safe (quant_matmul.h via mma.h), writes are linearized from (bi,bj) using BK while the destination uses BK_padded stride, so each thread spills into padding and skips real matrix elements. For matrix-matrix calls that hit QuantizedMatmulQmmTransposed (e.g., MLX path and any non-specialized zero-point path), this feeds incorrect Xs tiles to MMA and produces wrong outputs.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to either keep BM at 32 here or fix BlockLoader, since without that fix the perf gains break the correctness tests.

weights,
scales,
zero_points,
Expand All @@ -46,7 +46,7 @@ PUBLIC KERNEL(QuantizedMatmulQmmTransposed)(
simd_lane
);
} else {
qmm_transposed_impl<T, GROUP_SIZE, BITS, false, 32, 32, 32, true>(
qmm_transposed_impl<T, GROUP_SIZE, BITS, false, 160, 32, 32, true>(
weights,
scales,
zero_points,
Expand All @@ -66,7 +66,7 @@ PUBLIC KERNEL(QuantizedMatmulQmmTransposed)(
}
} else {
if (aligned_n) {
qmm_transposed_impl<T, GROUP_SIZE, BITS, true, 32, 32, 32, false>(
qmm_transposed_impl<T, GROUP_SIZE, BITS, true, 160, 32, 32, false>(
weights,
scales,
zero_points,
Expand All @@ -84,7 +84,7 @@ PUBLIC KERNEL(QuantizedMatmulQmmTransposed)(
simd_lane
);
} else {
qmm_transposed_impl<T, GROUP_SIZE, BITS, false, 32, 32, 32, false>(
qmm_transposed_impl<T, GROUP_SIZE, BITS, false, 160, 32, 32, false>(
weights,
scales,
zero_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,8 @@ void qmm_transposed_impl(
static_assert(BK >= 32, "BK should be larger than METAL_SIMD_SIZE");
static_assert(BK % 32 == 0, "BK should be divisible by METAL_SIMD_SIZE");

constexpr int WM = 2;
constexpr int WN = 2;
constexpr int WM = use_mlx_quant && group_size == 64 && bits == 4 ? 4 : 2;
constexpr int WN = use_mlx_quant && group_size == 64 && bits == 4 ? 1 : 2;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
Expand Down
Loading