Skip to content

Commit 1eaff08

Browse files
Qwix Developers authored and copybara-github committed
multiplatform tests for qwix.
PiperOrigin-RevId: 816464117
1 parent f54f56d commit 1eaff08

3 files changed

Lines changed: 136 additions & 202 deletions

File tree

tests/core/ragged_dot_qt_test.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,8 @@ class RaggedDotQtTest(parameterized.TestCase):
5757
@parameterized.named_parameters(
5858
dict(
5959
testcase_name="fp8",
60-
lhs_qtype=jnp.float8_e4m3,
61-
rhs_qtype=jnp.float8_e4m3,
60+
lhs_qtype=jnp.float8_e4m3fn,
61+
rhs_qtype=jnp.float8_e4m3fn,
6262
expected_mae_fq_out=1e-6,
6363
expected_mae_fq_dlhs=1e-6,
6464
expected_mae_fq_drhs=1e-6,
@@ -68,9 +68,9 @@ class RaggedDotQtTest(parameterized.TestCase):
6868
),
6969
dict(
7070
testcase_name="fp8_bwd",
71-
lhs_qtype=jnp.float8_e4m3,
72-
rhs_qtype=jnp.float8_e4m3,
73-
bwd_qtype=jnp.float8_e4m3,
71+
lhs_qtype=jnp.float8_e4m3fn,
72+
rhs_qtype=jnp.float8_e4m3fn,
73+
bwd_qtype=jnp.float8_e4m3fn,
7474
expected_mae_fq_out=1e-6,
7575
expected_mae_fq_dlhs=0.03,
7676
expected_mae_fq_drhs=0.03,
@@ -154,8 +154,8 @@ def test_traced_group_sizes(self):
154154
lhs = jax.random.normal(jax.random.key(0), (256, 64), jnp.float32)
155155
rhs = jax.random.normal(jax.random.key(1), (8, 64, 128), jnp.float32)
156156
config = ragged_dot_qt.RaggedDotQtConfig(
157-
lhs_qtype=jnp.float8_e4m3,
158-
rhs_qtype=jnp.float8_e4m3,
157+
lhs_qtype=jnp.float8_e4m3fn,
158+
rhs_qtype=jnp.float8_e4m3fn,
159159
)
160160

161161
@jax.jit

tests/core/ragged_dot_test.py

Lines changed: 129 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
1516
from unittest import mock
1617
from absl.testing import absltest
1718
from absl.testing import parameterized
@@ -21,51 +22,156 @@
2122
from qwix._src.core import ragged_dot
2223

2324

24-
def mae(a, b):
25+
def rel_mae(a, b):
2526
assert a.dtype == b.dtype and a.shape == b.shape
2627
return jnp.abs(a - b).mean() / jnp.abs(a).mean()
2728

2829

2930
class RaggedDotTest(parameterized.TestCase):
3031

32+
def setUp(self):
33+
super().setUp()
34+
self._random_key = jax.random.key(42)
35+
36+
def _make_array(self, shape, asymmetric=False):
37+
self._random_key, key = jax.random.split(self._random_key)
38+
if asymmetric:
39+
return jax.random.uniform(key, shape, jnp.float32)
40+
return jax.random.normal(key, shape, jnp.float32)
41+
3142
@parameterized.named_parameters(
3243
dict(
33-
testcase_name='no_channelwise',
34-
lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[]),
35-
rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[]),
44+
testcase_name='int8',
45+
lhs_shape=(128, 256),
46+
lhs_how=qarray.HowToQuantize(qtype=jnp.int8),
47+
rhs_shape=(4, 256, 64),
48+
rhs_how=qarray.HowToQuantize(qtype=jnp.int8),
49+
group_sizes=(64, 32, 16, 16),
50+
expected_mae=0.03,
51+
),
52+
dict(
53+
testcase_name='lhs_asymmetric',
54+
lhs_shape=(128, 256),
55+
lhs_how=qarray.HowToQuantize(
56+
qtype=jnp.int8,
57+
calibration_method='minmax',
58+
),
59+
rhs_shape=(4, 256, 64),
60+
rhs_how=qarray.HowToQuantize(
61+
qtype=jnp.int8,
62+
calibration_method='absmax',
63+
),
64+
group_sizes=(50, 50, 28, 0),
65+
expected_mae=0.07,
66+
disable_fast_ragged_dot=True,
67+
),
68+
dict(
69+
testcase_name='rhs_group_channelwise',
70+
lhs_shape=(128, 256),
71+
lhs_how=qarray.HowToQuantize(
72+
qtype=jnp.int8,
73+
calibration_method='absmax',
74+
),
75+
rhs_shape=(4, 256, 64),
76+
rhs_how=qarray.HowToQuantize(
77+
qtype=jnp.int8,
78+
channelwise_axes=(0,),
79+
calibration_method='absmax',
80+
),
81+
group_sizes=(128, 0, 0, 0),
82+
expected_mae=0.03,
83+
disable_fast_ragged_dot=True,
84+
),
85+
dict(
86+
testcase_name='rhs_contracting_tiled',
87+
lhs_shape=(128, 256),
88+
lhs_how=qarray.HowToQuantize(
89+
qtype=jnp.int8,
90+
calibration_method='absmax',
91+
),
92+
rhs_shape=(4, 256, 64),
93+
rhs_how=qarray.HowToQuantize(
94+
qtype=jnp.int8,
95+
tiled_axes={1: 128},
96+
calibration_method='absmax',
97+
),
98+
group_sizes=(10, 20, 30, 68),
99+
expected_mae=0.03,
100+
disable_fast_ragged_dot=True,
36101
),
37102
dict(
38103
testcase_name='channelwise',
39-
lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0]),
40-
rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[2]),
104+
lhs_shape=(128, 256),
105+
lhs_how=qarray.HowToQuantize(
106+
qtype=jnp.float8_e5m2,
107+
channelwise_axes=(0,),
108+
),
109+
rhs_shape=(4, 256, 64),
110+
rhs_how=qarray.HowToQuantize(
111+
qtype=jnp.float8_e5m2,
112+
channelwise_axes=(2,),
113+
),
114+
group_sizes=(128, 100, 0, 28),
115+
expected_mae=0.08,
41116
),
42117
dict(
43-
testcase_name='more_channelwise',
44-
lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0]),
45-
rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0, 2]),
118+
testcase_name='rhs_group_and_out_channelwise',
119+
lhs_shape=(128, 256),
120+
lhs_how=qarray.HowToQuantize(
121+
qtype=jnp.float8_e5m2,
122+
channelwise_axes=(0,),
123+
),
124+
rhs_shape=(4, 256, 64),
125+
rhs_how=qarray.HowToQuantize(
126+
qtype=jnp.float8_e5m2,
127+
channelwise_axes=(0, 2),
128+
),
129+
group_sizes=(128, 100, 0, 28),
130+
expected_mae=0.08,
46131
),
47132
)
48133
def test_ragged_dot(
49134
self,
50-
lhs_how,
51-
rhs_how,
52-
disable_fast_path=False,
135+
*,
136+
lhs_shape: tuple[int, ...],
137+
lhs_how: qarray.HowToQuantize | None,
138+
rhs_shape: tuple[int, ...],
139+
rhs_how: qarray.HowToQuantize | None,
140+
group_sizes: tuple[int, ...],
141+
expected_mae: float,
142+
disable_fast_ragged_dot: bool = False,
53143
):
54-
lhs = jax.random.normal(jax.random.key(0), (256, 16), jnp.bfloat16)
55-
rhs = jax.random.normal(jax.random.key(1), (10, 16, 64), jnp.bfloat16)
56-
group_sizes = jnp.array([10, 20, 30, 40, 0, 115, 6, 7, 1, 27], jnp.int32)
144+
lhs_asymmetric = (
145+
lhs_how.calibration_method == 'minmax' if lhs_how else False
146+
)
147+
rhs_asymmetric = (
148+
rhs_how.calibration_method == 'minmax' if rhs_how else False
149+
)
150+
lhs = self._make_array(lhs_shape, lhs_asymmetric)
151+
rhs = self._make_array(rhs_shape, rhs_asymmetric)
152+
group_sizes = jnp.array(group_sizes)
57153

58-
fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes)
154+
q_lhs = qarray.quantize(lhs, lhs_how) if lhs_how else lhs
155+
q_rhs = qarray.quantize(rhs, rhs_how) if rhs_how else rhs
59156

60-
qlhs = qarray.quantize(lhs, lhs_how)
61-
qrhs = qarray.quantize(rhs, rhs_how)
157+
@jax.jit
158+
def _multi_ragged_dot(lhs, rhs, fp_res):
159+
slow_res = ragged_dot._slow_ragged_dot(lhs, rhs, group_sizes)
160+
if disable_fast_ragged_dot:
161+
fast_res = slow_res
162+
else:
163+
fast_res = ragged_dot._fast_ragged_dot(lhs, rhs, group_sizes)
164+
return (
165+
rel_mae(slow_res, fp_res),
166+
rel_mae(slow_res, fast_res),
167+
)
62168

63-
slow_res = ragged_dot._slow_ragged_dot(qlhs, qrhs, group_sizes)
64-
self.assertLess(mae(slow_res, fp_res), 0.02)
169+
fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes)
170+
fp_mae, fast_mae = _multi_ragged_dot(q_lhs, q_rhs, fp_res)
65171

66-
if not disable_fast_path:
67-
fast_res = ragged_dot._fast_ragged_dot(qlhs, qrhs, group_sizes)
68-
self.assertLess(mae(fast_res, slow_res), 0.005)
172+
logging.info('fp_mae=%s fast_mae=%s', fp_mae, fast_mae)
173+
self.assertLessEqual(fp_mae, expected_mae)
174+
self.assertLessEqual(fast_mae, 0.003)
69175

70176
@parameterized.named_parameters(
71177
dict(

tests/core/ragged_dot_tpu_test.py

Lines changed: 0 additions & 172 deletions
This file was deleted.

0 commit comments

Comments
 (0)