|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Implements a quantized matmul kernel.""" |
| 15 | + |
| 16 | +import dataclasses |
| 17 | +from typing import Any |
| 18 | + |
| 19 | +import jax |
| 20 | +import jax.experimental.pallas as pl |
| 21 | +import jax.numpy as jnp |
| 22 | +from qwix._src.core import qarray |
| 23 | + |
| 24 | +INTERPRET: bool = True |
| 25 | + |
| 26 | + |
| 27 | +@dataclasses.dataclass |
| 28 | +class QuantizedMatmulConfig: |
| 29 | + bm: int = 128 |
| 30 | + bk: int = 128 |
| 31 | + bn: int = 128 |
| 32 | + dtype: jnp.dtype = jnp.float32 |
| 33 | + |
| 34 | + |
| 35 | +def can_use_qmm(x, sx, y, sy, *, bm, bk, bn): |
| 36 | + """Returns whether the quantized matmul can be used.""" |
| 37 | + mdim, kdim = x.shape |
| 38 | + _, ndim = y.shape |
| 39 | + k_tiles = sx.shape[1] |
| 40 | + |
| 41 | + if mdim % bm != 0 or ndim % bn != 0 or kdim % bk != 0: |
| 42 | + # Block size must divide matrix size. |
| 43 | + return False |
| 44 | + grid = (mdim // bm, ndim // bn, kdim // bk) |
| 45 | + |
| 46 | + # k information |
| 47 | + k_tile_size = kdim // k_tiles |
| 48 | + if k_tile_size != bk: |
| 49 | + # Block size must match the tile size for the reduction axis. |
| 50 | + return False |
| 51 | + if sx.shape[1] != sy.shape[0]: |
| 52 | + # Number of tiles must match for the scales. |
| 53 | + return False |
| 54 | + |
| 55 | + if sx.shape[0] != grid[2] or sx.shape[0] != 1: |
| 56 | + # Scale size must match grid size or be 1. |
| 57 | + return False |
| 58 | + |
| 59 | + if sy.shape[1] != grid[1] or sy.shape[1] != 1: |
| 60 | + # Scale size must match grid size or be 1. |
| 61 | + return False |
| 62 | + |
| 63 | + return True |
| 64 | + |
| 65 | + |
| 66 | +def can_use_qmm_in_dot_general( |
| 67 | + lhs: qarray.QArray, rhs: qarray.QArray, dimension_numbers: Any |
| 68 | +): |
| 69 | + """Returns whether the quantized matmul can be used in dot_general.""" |
| 70 | + # Check the qarrays. |
| 71 | + if lhs.zero_point is not None or rhs.zero_point is not None: |
| 72 | + return False |
| 73 | + |
| 74 | + # Check the dimension numbers. |
| 75 | + if not ( |
| 76 | + len(dimension_numbers) != 2 |
| 77 | + or len(dimension_numbers[0]) != 2 |
| 78 | + or len(dimension_numbers[1]) != 2 |
| 79 | + or tuple(dimension_numbers[0][0]) != (1,) |
| 80 | + or tuple(dimension_numbers[0][1]) != (0,) |
| 81 | + or len(dimension_numbers[1][0]) != 0 |
| 82 | + or len(dimension_numbers[1][1]) != 0 |
| 83 | + ): |
| 84 | + return False |
| 85 | + |
| 86 | + return True |
| 87 | + |
| 88 | + |
| 89 | +def quantized_matmul_kernel(x_ref, sx_ref, y_ref, sy_ref, o_ref): |
| 90 | + @pl.when(pl.program_id(2) == 0) |
| 91 | + def _(): |
| 92 | + o_ref[...] = jnp.zeros_like(o_ref) |
| 93 | + |
| 94 | + o_ref[...] += ( |
| 95 | + jnp.matmul(x_ref[...], y_ref[...]).astype(sx_ref.dtype) |
| 96 | + * sx_ref[...] |
| 97 | + * sy_ref[...] |
| 98 | + ) |
| 99 | + |
| 100 | + |
| 101 | +def q_matmul(x, sx, y, sy, *, bm=128, bk=128, bn=128, dtype=jnp.float32): |
| 102 | + """Computes a quantized matmul with support for subchannel quantization. |
| 103 | +
|
| 104 | + This kernel does not cover all cases. In particular, it requires that |
| 105 | + the block sizes match the tile sizes, and that the scale sizes match the grid |
| 106 | + size or be 1. |
| 107 | +
|
| 108 | + Args: |
| 109 | + x: The left-hand side matrix. |
| 110 | + sx: The left-hand side scales. |
| 111 | + y: The right-hand side matrix. |
| 112 | + sy: The right-hand side scales. |
| 113 | + bm: The block size for the m dimension. |
| 114 | + bk: The block size for the k dimension. |
| 115 | + bn: The block size for the n dimension. |
| 116 | + dtype: The data type of the output. |
| 117 | +
|
| 118 | + Returns: |
| 119 | + The quantized matmul. |
| 120 | + """ |
| 121 | + mdim, kdim = x.shape |
| 122 | + _, ndim = y.shape |
| 123 | + k_tiles = sx.shape[1] |
| 124 | + |
| 125 | + # Block specs for x and y. |
| 126 | + assert mdim % bm == 0, f'Block size must divide matrix size, {mdim=} {bm=}' |
| 127 | + assert ndim % bn == 0, f'Block size must divide matrix size, {ndim=} {bn=}' |
| 128 | + assert kdim % bk == 0, f'Block size must divide matrix size, {kdim=} {bk=}' |
| 129 | + grid = (mdim // bm, ndim // bn, kdim // bk) |
| 130 | + x_blockspec = pl.BlockSpec((bm, bk), lambda a, b, c: (a, c)) |
| 131 | + y_blockspec = pl.BlockSpec((bk, bn), lambda a, b, c: (c, b)) |
| 132 | + |
| 133 | + # k information |
| 134 | + k_tile_size = kdim // k_tiles |
| 135 | + assert k_tile_size == bk, ( |
| 136 | + 'Block size must match the tile size for the reduction axis' |
| 137 | + f' {k_tile_size=} {bk=}' |
| 138 | + ) |
| 139 | + assert sx.shape[1] == sy.shape[0], 'Number of tiles must match for the scales' |
| 140 | + |
| 141 | + # m information |
| 142 | + if sx.shape[0] == 1: |
| 143 | + sx_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (0, c)) |
| 144 | + else: |
| 145 | + assert ( |
| 146 | + sx.shape[0] == grid[0] |
| 147 | + ), f'Scale size must match grid size, {sx.shape[0]=} {grid[0]=}' |
| 148 | + sx_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (a, c)) |
| 149 | + |
| 150 | + # n information |
| 151 | + if sy.shape[1] == 1: |
| 152 | + sy_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (c, 0)) |
| 153 | + else: |
| 154 | + assert ( |
| 155 | + sy.shape[1] == grid[1] |
| 156 | + ), f'Scale size must match grid size, {sy.shape[1]=} {grid[1]=}' |
| 157 | + sy_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (c, b)) |
| 158 | + |
| 159 | + return pl.pallas_call( |
| 160 | + quantized_matmul_kernel, |
| 161 | + out_shape=jax.ShapeDtypeStruct((mdim, ndim), dtype), |
| 162 | + grid=grid, |
| 163 | + in_specs=(x_blockspec, sx_blockspec, y_blockspec, sy_blockspec), |
| 164 | + out_specs=pl.BlockSpec((bm, bn), lambda a, b, c: (a, b)), |
| 165 | + interpret=INTERPRET, |
| 166 | + )(x, sx, y, sy).astype(dtype) |
0 commit comments