Skip to content

Commit f85ca88

Browse files
chapman20j authored and copybara-github committed
quantized matmul kernel
PiperOrigin-RevId: 891781786
1 parent 0e03eae commit f85ca88

4 files changed

Lines changed: 251 additions & 0 deletions

File tree

qwix/_src/core/dot_general.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
# pylint: disable=line-too-long
1616

1717
from collections.abc import Collection, Sequence
18+
import dataclasses
1819
import itertools
1920
from typing import Any
21+
2022
import jax
2123
from jax import numpy as jnp
2224
from qwix._src.core import numerics
2325
from qwix._src.core import qarray
26+
from qwix.contrib.kernels import quantized_matmul
2427

2528

2629
def get_how_to_quantize(
@@ -97,6 +100,7 @@ def _apply_tiling(
97100
Returns:
98101
A tuple of (new_ca, new_ba, sum_axes).
99102
"""
103+
a = 0
100104
new_ca = [a + sum(t <= a for t in tiled_axes) for a in contracting_axes]
101105
new_ba = [a + sum(t < a for t in tiled_axes) for a in batch_axes]
102106
# We choose to insert the tile_count axes to the end of the batch axes.
@@ -399,6 +403,9 @@ def dot_general(
399403
dimension_numbers: jax.lax.DotDimensionNumbers,
400404
precision: jax.lax.PrecisionLike = None,
401405
preferred_element_type: jax.typing.DTypeLike | None = None,
406+
*,
407+
use_kernel: bool = False,
408+
kernel_config: quantized_matmul.QuantizedMatmulConfig | None = None,
402409
**kwargs,
403410
) -> jax.Array:
404411
"""Computes a general dot product with support for ``QArray`` inputs.
@@ -413,6 +420,8 @@ def dot_general(
413420
dimension_numbers: The dimension numbers passed to dot_general.
414421
precision: The precision for jax.lax.dot_general.
415422
preferred_element_type: The preferred element type for jax.lax.dot_general.
423+
use_kernel: Whether to use the Pallas kernel implementation.
424+
kernel_config: Keyword arguments to pass to the Pallas kernel.
416425
**kwargs: Additional keyword arguments to dot_general.
417426
418427
Returns:
@@ -453,6 +462,19 @@ def dot_general(
453462
use_fast_dot_general = False
454463
break
455464

465+
if (
466+
use_kernel
467+
and isinstance(lhs, qarray.QArray)
468+
and isinstance(rhs, qarray.QArray)
469+
and quantized_matmul.can_use_qmm_in_dot_general(
470+
lhs, rhs, dimension_numbers
471+
)
472+
):
473+
kernel_kwargs = dataclasses.asdict(kernel_config)
474+
return quantized_matmul.q_matmul(
475+
lhs.qvalue, lhs.scale, rhs.qvalue, rhs.scale, **kernel_kwargs
476+
)
477+
456478
if use_fast_dot_general:
457479
return _fast_dot_general(
458480
lhs,
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Implements a quantized matmul kernel."""
15+
16+
import dataclasses
17+
from typing import Any
18+
19+
import jax
20+
import jax.experimental.pallas as pl
21+
import jax.numpy as jnp
22+
from qwix._src.core import qarray
23+
24+
# NOTE(review): True forces every `pl.pallas_call` below into Pallas
# interpreter mode (no compiled kernel). Useful for CPU testing, but looks
# like debug residue for a committed kernel — confirm before shipping.
INTERPRET: bool = True
25+
26+
27+
@dataclasses.dataclass
class QuantizedMatmulConfig:
  """Block-size and output-dtype configuration for `q_matmul`.

  Expanded with `dataclasses.asdict` into `q_matmul`'s keyword arguments,
  so the field names must match that function's parameter names.
  """

  # Block size for the m (lhs rows) dimension.
  bm: int = 128
  # Block size for the k (contracting) dimension.
  bk: int = 128
  # Block size for the n (rhs columns) dimension.
  bn: int = 128
  # Output dtype of the matmul result.
  dtype: jnp.dtype = jnp.float32
33+
34+
35+
def can_use_qmm(x, sx, y, sy, *, bm, bk, bn):
  """Returns whether the quantized matmul kernel can handle these inputs.

  Mirrors the assertions in `q_matmul`: the block sizes must divide the
  matrix dimensions, the k block size must equal the subchannel tile size,
  and each scale count must either match the corresponding grid dimension
  or be 1 (broadcast).

  Args:
    x: The left-hand side matrix, shape (m, k).
    sx: The left-hand side scales, shape (m_tiles, k_tiles).
    y: The right-hand side matrix, shape (k, n).
    sy: The right-hand side scales, shape (k_tiles, n_tiles).
    bm: The block size for the m dimension.
    bk: The block size for the k dimension.
    bn: The block size for the n dimension.

  Returns:
    True if `q_matmul` supports the given shapes and block sizes.
  """
  mdim, kdim = x.shape
  _, ndim = y.shape
  k_tiles = sx.shape[1]

  if mdim % bm != 0 or ndim % bn != 0 or kdim % bk != 0:
    # Block size must divide matrix size.
    return False
  grid = (mdim // bm, ndim // bn, kdim // bk)

  # k information
  k_tile_size = kdim // k_tiles
  if k_tile_size != bk:
    # Block size must match the tile size for the reduction axis.
    return False
  if sx.shape[1] != sy.shape[0]:
    # Number of tiles must match for the scales.
    return False

  # Bug fix: the original used `or` (nearly always true, so valid shapes
  # were rejected) and compared against grid[2] (the k grid) instead of
  # grid[0] — `q_matmul` asserts sx.shape[0] == grid[0] or == 1.
  if sx.shape[0] != grid[0] and sx.shape[0] != 1:
    # Scale size must match grid size or be 1.
    return False

  # Bug fix: likewise `or` -> `and` so that either grid[1] or 1 passes.
  if sy.shape[1] != grid[1] and sy.shape[1] != 1:
    # Scale size must match grid size or be 1.
    return False

  return True
64+
65+
66+
def can_use_qmm_in_dot_general(
    lhs: "qarray.QArray", rhs: "qarray.QArray", dimension_numbers: Any
):
  """Returns whether the quantized matmul can be used in dot_general.

  The kernel only supports a plain 2D matmul: contracting axes (1,) for the
  lhs and (0,) for the rhs, no batch axes, and no zero points.

  Args:
    lhs: The quantized left-hand side.
    rhs: The quantized right-hand side.
    dimension_numbers: The dot_general dimension numbers, i.e.
      ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).

  Returns:
    True if `q_matmul` can implement this dot_general.
  """
  # Check the qarrays: the kernel has no zero-point support.
  if lhs.zero_point is not None or rhs.zero_point is not None:
    return False

  # Check the dimension numbers. Bug fix: the original wrapped this
  # disjunction in `not (...)`, inverting the result — valid matmul
  # dimension numbers were rejected and everything else accepted.
  if (
      len(dimension_numbers) != 2
      or len(dimension_numbers[0]) != 2
      or len(dimension_numbers[1]) != 2
      or tuple(dimension_numbers[0][0]) != (1,)
      or tuple(dimension_numbers[0][1]) != (0,)
      or len(dimension_numbers[1][0]) != 0
      or len(dimension_numbers[1][1]) != 0
  ):
    return False

  return True
87+
88+
89+
def quantized_matmul_kernel(x_ref, sx_ref, y_ref, sy_ref, o_ref):
  """Pallas kernel body: accumulates one (bm, bk) x (bk, bn) partial product.

  Invoked once per (m, n, k) grid cell by `q_matmul`; program_id(2) indexes
  the k reduction axis, so the output block is zeroed on the first k step
  and accumulated on subsequent ones.

  Args:
    x_ref: Block of the lhs values, shape (bm, bk) per `q_matmul`'s specs.
    sx_ref: Block of the lhs scales, shape (1, 1).
    y_ref: Block of the rhs values, shape (bk, bn).
    sy_ref: Block of the rhs scales, shape (1, 1).
    o_ref: Output accumulator block, shape (bm, bn).
  """
  # Zero the accumulator at the start of each k reduction.
  @pl.when(pl.program_id(2) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  # Dequantize the partial product with both scales and accumulate. The
  # matmul result is cast to the scale dtype before scaling.
  o_ref[...] += (
      jnp.matmul(x_ref[...], y_ref[...]).astype(sx_ref.dtype)
      * sx_ref[...]
      * sy_ref[...]
  )
99+
100+
101+
def q_matmul(x, sx, y, sy, *, bm=128, bk=128, bn=128, dtype=jnp.float32):
  """Computes a quantized matmul with support for subchannel quantization.

  Runs `quantized_matmul_kernel` over a (m, n, k) grid via `pl.pallas_call`,
  applying the per-tile scales to each partial product.

  This kernel does not cover all cases. In particular, it requires that
  the block sizes match the tile sizes, and that the scale sizes match the grid
  size or be 1.

  Args:
    x: The left-hand side matrix, shape (m, k).
    sx: The left-hand side scales, shape (m_tiles, k_tiles).
    y: The right-hand side matrix, shape (k, n).
    sy: The right-hand side scales, shape (k_tiles, n_tiles).
    bm: The block size for the m dimension.
    bk: The block size for the k dimension.
    bn: The block size for the n dimension.
    dtype: The data type of the output.

  Returns:
    The quantized matmul, shape (m, n), cast to `dtype`.
  """
  mdim, kdim = x.shape
  _, ndim = y.shape
  k_tiles = sx.shape[1]

  # Block specs for x and y.
  assert mdim % bm == 0, f'Block size must divide matrix size, {mdim=} {bm=}'
  assert ndim % bn == 0, f'Block size must divide matrix size, {ndim=} {bn=}'
  assert kdim % bk == 0, f'Block size must divide matrix size, {kdim=} {bk=}'
  # Grid order is (m, n, k); index 2 (k) is the reduction axis the kernel
  # accumulates over.
  grid = (mdim // bm, ndim // bn, kdim // bk)
  x_blockspec = pl.BlockSpec((bm, bk), lambda a, b, c: (a, c))
  y_blockspec = pl.BlockSpec((bk, bn), lambda a, b, c: (c, b))

  # k information
  k_tile_size = kdim // k_tiles
  assert k_tile_size == bk, (
      'Block size must match the tile size for the reduction axis'
      f' {k_tile_size=} {bk=}'
  )
  assert sx.shape[1] == sy.shape[0], 'Number of tiles must match for the scales'

  # m information
  if sx.shape[0] == 1:
    # Single scale row: broadcast along m, index only by the k step.
    sx_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (0, c))
  else:
    assert (
        sx.shape[0] == grid[0]
    ), f'Scale size must match grid size, {sx.shape[0]=} {grid[0]=}'
    sx_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (a, c))

  # n information
  if sy.shape[1] == 1:
    # Single scale column: broadcast along n, index only by the k step.
    sy_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (c, 0))
  else:
    assert (
        sy.shape[1] == grid[1]
    ), f'Scale size must match grid size, {sy.shape[1]=} {grid[1]=}'
    sy_blockspec = pl.BlockSpec((1, 1), lambda a, b, c: (c, b))

  # NOTE(review): interpret=INTERPRET is module-level True, so this always
  # runs in Pallas interpreter mode — confirm that is intended.
  return pl.pallas_call(
      quantized_matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((mdim, ndim), dtype),
      grid=grid,
      in_specs=(x_blockspec, sx_blockspec, y_blockspec, sy_blockspec),
      out_specs=pl.BlockSpec((bm, bn), lambda a, b, c: (a, b)),
      interpret=INTERPRET,
  )(x, sx, y, sy).astype(dtype)

tests/_src/core/dot_general_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from qwix._src.core import dot_general
2020
from qwix._src.core import einsum
2121
from qwix._src.core import qarray
22+
from qwix.contrib.kernels import quantized_matmul
2223

2324

2425
class DotGeneralTest(parameterized.TestCase):
@@ -146,6 +147,34 @@ def test_fast_dot_general_channelwise_contracting(self):
146147
self.assertEqual(res.shape, (4, 4))
147148
self.assertTrue(jnp.allclose(res, jnp.full((4, 4), 8.0), atol=0.1))
148149

150+
def test_kernel_dot_general(self):
151+
lhs = jnp.ones((4, 8), jnp.float32)
152+
rhs = jnp.ones((8, 16), jnp.float32)
153+
154+
# Channelwise on axis 1 (contracting)
155+
lhs_how = qarray.HowToQuantize(
156+
qtype=jnp.int8,
157+
tiled_axes={1: 1},
158+
)
159+
# Channelwise on axis 0 (contracting)
160+
rhs_how = qarray.HowToQuantize(
161+
qtype=jnp.int8,
162+
tiled_axes={0: 1},
163+
)
164+
165+
q_lhs = qarray.quantize(lhs, lhs_how)
166+
q_rhs = qarray.quantize(rhs, rhs_how)
167+
168+
kernel_config = quantized_matmul.QuantizedMatmulConfig(bm=4, bn=16, bk=1)
169+
170+
_ = dot_general.dot_general(
171+
q_lhs,
172+
q_rhs,
173+
(([1], [0]), ([], [])),
174+
use_kernel=True,
175+
kernel_config=kernel_config,
176+
)
177+
149178

150179
if __name__ == '__main__':
151180
absltest.main()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import jax.numpy as jnp
2+
from qwix._src.core import qarray
3+
from qwix.contrib.kernels import quantized_matmul
4+
5+
from google3.testing.pybase import googletest
6+
7+
8+
class QuantizedMatmulTest(googletest.TestCase):
  """Tests for the quantized matmul Pallas kernel."""

  def test_kernel_dot_general(self):
    """q_matmul on all-ones inputs matches the full-precision product."""
    lhs = jnp.ones((4, 8), jnp.float32)
    rhs = jnp.ones((8, 16), jnp.float32)

    # Channelwise on axis 1 (contracting)
    lhs_how = qarray.HowToQuantize(
        qtype=jnp.int8,
        tiled_axes={1: 1},
    )
    # Channelwise on axis 0 (contracting)
    rhs_how = qarray.HowToQuantize(
        qtype=jnp.int8,
        tiled_axes={0: 1},
    )

    q_lhs = qarray.quantize(lhs, lhs_how)
    q_rhs = qarray.quantize(rhs, rhs_how)

    res = quantized_matmul.q_matmul(
        q_lhs.qvalue, q_lhs.scale, q_rhs.qvalue, q_rhs.scale, bm=4, bn=16, bk=1
    )
    # Bug fix: the original discarded the result and asserted nothing, so
    # the test only checked that the call did not raise. All-ones inputs
    # contract over 8 elements, so every output entry should be ~8.0.
    self.assertEqual(res.shape, (4, 16))
    self.assertTrue(jnp.allclose(res, jnp.full((4, 16), 8.0), atol=0.1))
31+
32+
33+
# Allow running this test file directly as a script.
if __name__ == "__main__":
  googletest.main()

0 commit comments

Comments
 (0)