File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -204,6 +204,9 @@ def take(x: jax.Array) -> jax.Array:
204204 return jax .tree .map (take , array )
205205
206206
207+ _VALID_SCALE_DTYPES = (jnp .float16 , jnp .bfloat16 , jnp .float32 , jnp .float64 )
208+
209+
207210def validate_qarray (array : QArray ):
208211 """Validates the internal consistency of a QArray."""
209212 if not isinstance (array .qvalue , jax .Array ):
@@ -220,7 +223,7 @@ def validate_qarray(array: QArray):
220223 )
221224 if array .qvalue .dtype .itemsize > 1 :
222225 raise ValueError (f'{ array .qvalue .dtype } is not a valid type for qvalue.' )
223- if array .scale .dtype not in ( jnp . bfloat16 , jnp . float32 , jnp . float64 ) :
226+ if array .scale .dtype not in _VALID_SCALE_DTYPES :
224227 raise ValueError (f'{ array .scale .dtype } is not a valid type for scale.' )
225228 if array .zero_point is not None :
226229 if array .zero_point .ndim != array .qvalue .ndim :
You can’t perform that action at this time.
0 commit comments