|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import logging |
15 | 16 | from unittest import mock |
16 | 17 | from absl.testing import absltest |
17 | 18 | from absl.testing import parameterized |
|
21 | 22 | from qwix._src.core import ragged_dot |
22 | 23 |
|
23 | 24 |
|
24 | | -def mae(a, b): |
| 25 | +def rel_mae(a, b): |
25 | 26 | assert a.dtype == b.dtype and a.shape == b.shape |
26 | 27 | return jnp.abs(a - b).mean() / jnp.abs(a).mean() |
27 | 28 |
|
28 | 29 |
|
29 | 30 | class RaggedDotTest(parameterized.TestCase): |
30 | 31 |
|
| 32 | + def setUp(self): |
| 33 | + super().setUp() |
| 34 | + self._random_key = jax.random.key(42) |
| 35 | + |
| 36 | + def _make_array(self, shape, asymmetric=False): |
| 37 | + self._random_key, key = jax.random.split(self._random_key) |
| 38 | + if asymmetric: |
| 39 | + return jax.random.uniform(key, shape, jnp.float32) |
| 40 | + return jax.random.normal(key, shape, jnp.float32) |
| 41 | + |
31 | 42 | @parameterized.named_parameters( |
32 | 43 | dict( |
33 | | - testcase_name='no_channelwise', |
34 | | - lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[]), |
35 | | - rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[]), |
| 44 | + testcase_name='int8', |
| 45 | + lhs_shape=(128, 256), |
| 46 | + lhs_how=qarray.HowToQuantize(qtype=jnp.int8), |
| 47 | + rhs_shape=(4, 256, 64), |
| 48 | + rhs_how=qarray.HowToQuantize(qtype=jnp.int8), |
| 49 | + group_sizes=(64, 32, 16, 16), |
| 50 | + expected_mae=0.03, |
| 51 | + ), |
| 52 | + dict( |
| 53 | + testcase_name='lhs_asymmetric', |
| 54 | + lhs_shape=(128, 256), |
| 55 | + lhs_how=qarray.HowToQuantize( |
| 56 | + qtype=jnp.int8, |
| 57 | + calibration_method='minmax', |
| 58 | + ), |
| 59 | + rhs_shape=(4, 256, 64), |
| 60 | + rhs_how=qarray.HowToQuantize( |
| 61 | + qtype=jnp.int8, |
| 62 | + calibration_method='absmax', |
| 63 | + ), |
| 64 | + group_sizes=(50, 50, 28, 0), |
| 65 | + expected_mae=0.07, |
| 66 | + disable_fast_ragged_dot=True, |
| 67 | + ), |
| 68 | + dict( |
| 69 | + testcase_name='rhs_group_channelwise', |
| 70 | + lhs_shape=(128, 256), |
| 71 | + lhs_how=qarray.HowToQuantize( |
| 72 | + qtype=jnp.int8, |
| 73 | + calibration_method='absmax', |
| 74 | + ), |
| 75 | + rhs_shape=(4, 256, 64), |
| 76 | + rhs_how=qarray.HowToQuantize( |
| 77 | + qtype=jnp.int8, |
| 78 | + channelwise_axes=(0,), |
| 79 | + calibration_method='absmax', |
| 80 | + ), |
| 81 | + group_sizes=(128, 0, 0, 0), |
| 82 | + expected_mae=0.03, |
| 83 | + disable_fast_ragged_dot=True, |
| 84 | + ), |
| 85 | + dict( |
| 86 | + testcase_name='rhs_contracting_tiled', |
| 87 | + lhs_shape=(128, 256), |
| 88 | + lhs_how=qarray.HowToQuantize( |
| 89 | + qtype=jnp.int8, |
| 90 | + calibration_method='absmax', |
| 91 | + ), |
| 92 | + rhs_shape=(4, 256, 64), |
| 93 | + rhs_how=qarray.HowToQuantize( |
| 94 | + qtype=jnp.int8, |
| 95 | + tiled_axes={1: 128}, |
| 96 | + calibration_method='absmax', |
| 97 | + ), |
| 98 | + group_sizes=(10, 20, 30, 68), |
| 99 | + expected_mae=0.03, |
| 100 | + disable_fast_ragged_dot=True, |
36 | 101 | ), |
37 | 102 | dict( |
38 | 103 | testcase_name='channelwise', |
39 | | - lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0]), |
40 | | - rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[2]), |
| 104 | + lhs_shape=(128, 256), |
| 105 | + lhs_how=qarray.HowToQuantize( |
| 106 | + qtype=jnp.float8_e5m2, |
| 107 | + channelwise_axes=(0,), |
| 108 | + ), |
| 109 | + rhs_shape=(4, 256, 64), |
| 110 | + rhs_how=qarray.HowToQuantize( |
| 111 | + qtype=jnp.float8_e5m2, |
| 112 | + channelwise_axes=(2,), |
| 113 | + ), |
| 114 | + group_sizes=(128, 100, 0, 28), |
| 115 | + expected_mae=0.08, |
41 | 116 | ), |
42 | 117 | dict( |
43 | | - testcase_name='more_channelwise', |
44 | | - lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0]), |
45 | | - rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0, 2]), |
| 118 | + testcase_name='rhs_group_and_out_channelwise', |
| 119 | + lhs_shape=(128, 256), |
| 120 | + lhs_how=qarray.HowToQuantize( |
| 121 | + qtype=jnp.float8_e5m2, |
| 122 | + channelwise_axes=(0,), |
| 123 | + ), |
| 124 | + rhs_shape=(4, 256, 64), |
| 125 | + rhs_how=qarray.HowToQuantize( |
| 126 | + qtype=jnp.float8_e5m2, |
| 127 | + channelwise_axes=(0, 2), |
| 128 | + ), |
| 129 | + group_sizes=(128, 100, 0, 28), |
| 130 | + expected_mae=0.08, |
46 | 131 | ), |
47 | 132 | ) |
48 | 133 | def test_ragged_dot( |
49 | 134 | self, |
50 | | - lhs_how, |
51 | | - rhs_how, |
52 | | - disable_fast_path=False, |
| 135 | + *, |
| 136 | + lhs_shape: tuple[int, ...], |
| 137 | + lhs_how: qarray.HowToQuantize | None, |
| 138 | + rhs_shape: tuple[int, ...], |
| 139 | + rhs_how: qarray.HowToQuantize | None, |
| 140 | + group_sizes: tuple[int, ...], |
| 141 | + expected_mae: float, |
| 142 | + disable_fast_ragged_dot: bool = False, |
53 | 143 | ): |
54 | | - lhs = jax.random.normal(jax.random.key(0), (256, 16), jnp.bfloat16) |
55 | | - rhs = jax.random.normal(jax.random.key(1), (10, 16, 64), jnp.bfloat16) |
56 | | - group_sizes = jnp.array([10, 20, 30, 40, 0, 115, 6, 7, 1, 27], jnp.int32) |
| 144 | + lhs_asymmetric = ( |
| 145 | + lhs_how.calibration_method == 'minmax' if lhs_how else False |
| 146 | + ) |
| 147 | + rhs_asymmetric = ( |
| 148 | + rhs_how.calibration_method == 'minmax' if rhs_how else False |
| 149 | + ) |
| 150 | + lhs = self._make_array(lhs_shape, lhs_asymmetric) |
| 151 | + rhs = self._make_array(rhs_shape, rhs_asymmetric) |
| 152 | + group_sizes = jnp.array(group_sizes) |
57 | 153 |
|
58 | | - fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes) |
| 154 | + q_lhs = qarray.quantize(lhs, lhs_how) if lhs_how else lhs |
| 155 | + q_rhs = qarray.quantize(rhs, rhs_how) if rhs_how else rhs |
59 | 156 |
|
60 | | - qlhs = qarray.quantize(lhs, lhs_how) |
61 | | - qrhs = qarray.quantize(rhs, rhs_how) |
| 157 | + @jax.jit |
| 158 | + def _multi_ragged_dot(lhs, rhs, fp_res): |
| 159 | + slow_res = ragged_dot._slow_ragged_dot(lhs, rhs, group_sizes) |
| 160 | + if disable_fast_ragged_dot: |
| 161 | + fast_res = slow_res |
| 162 | + else: |
| 163 | + fast_res = ragged_dot._fast_ragged_dot(lhs, rhs, group_sizes) |
| 164 | + return ( |
| 165 | + rel_mae(slow_res, fp_res), |
| 166 | + rel_mae(slow_res, fast_res), |
| 167 | + ) |
62 | 168 |
|
63 | | - slow_res = ragged_dot._slow_ragged_dot(qlhs, qrhs, group_sizes) |
64 | | - self.assertLess(mae(slow_res, fp_res), 0.02) |
| 169 | + fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes) |
| 170 | + fp_mae, fast_mae = _multi_ragged_dot(q_lhs, q_rhs, fp_res) |
65 | 171 |
|
66 | | - if not disable_fast_path: |
67 | | - fast_res = ragged_dot._fast_ragged_dot(qlhs, qrhs, group_sizes) |
68 | | - self.assertLess(mae(fast_res, slow_res), 0.005) |
| 172 | + logging.info('fp_mae=%s fast_mae=%s', fp_mae, fast_mae) |
| 173 | + self.assertLessEqual(fp_mae, expected_mae) |
| 174 | + self.assertLessEqual(fast_mae, 0.003) |
69 | 175 |
|
70 | 176 | @parameterized.named_parameters( |
71 | 177 | dict( |
|
0 commit comments