Skip to content

Commit f58e033

Browse files
chr1sj0nescopybara-github
authored andcommitted
Remove QuantizedArray class.
The is the final step in the transition to using `qwix.QArray`. PiperOrigin-RevId: 834596328
1 parent e7aeea1 commit f58e033

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

qwix/_src/core/qarray.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def take(x: jax.Array) -> jax.Array:
204204
return jax.tree.map(take, array)
205205

206206

207+
_VALID_SCALE_DTYPES = (jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64)
208+
209+
207210
def validate_qarray(array: QArray):
208211
"""Validates the internal consistency of a QArray."""
209212
if not isinstance(array.qvalue, jax.Array):
@@ -220,7 +223,7 @@ def validate_qarray(array: QArray):
220223
)
221224
if array.qvalue.dtype.itemsize > 1:
222225
raise ValueError(f'{array.qvalue.dtype} is not a valid type for qvalue.')
223-
if array.scale.dtype not in (jnp.bfloat16, jnp.float32, jnp.float64):
226+
if array.scale.dtype not in _VALID_SCALE_DTYPES:
224227
raise ValueError(f'{array.scale.dtype} is not a valid type for scale.')
225228
if array.zero_point is not None:
226229
if array.zero_point.ndim != array.qvalue.ndim:

0 commit comments

Comments
 (0)