Skip to content
22 changes: 10 additions & 12 deletions neural_compressor/jax/quantization/layers_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def add_variables(self):
initializer="zeros",
trainable=False,
autocast=False,
dtype=self.compute_dtype,
dtype=jnp.int32,
)
self.a_scale = self.add_weight(
name="a_scale",
Expand Down Expand Up @@ -374,7 +374,7 @@ def add_variables(self):
initializer="zeros",
trainable=False,
autocast=False,
dtype=self.compute_dtype,
dtype=jnp.int32,
)
self.a_scale = self.add_weight(
name="a_scale",
Expand Down Expand Up @@ -435,7 +435,6 @@ def convert(self):

w_scale, _ = get_q_params(self.kernel, self.weight_dtype, self.compute_dtype, asymmetric=False)
self.w_scale.assign(w_scale)

_kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
self._kernel_quant.assign(_kernel_quant)
self._tracker.lock()
Expand Down Expand Up @@ -585,14 +584,10 @@ def prepare(cls, orig, weight_dtype, activation_dtype, const_scale=False, const_
orig._tracker.unlock()
orig.__class__ = cls
orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer)
orig.q_qdq = StaticQDQLayer(
"q_qdq", activation_dtype, orig.dtype_policy, False, const_scale
) # the second argument of einsum has to be quantized symmetrically for onednn to work
orig.q_qdq = StaticQDQLayer("q_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.k_qdq = StaticQDQLayer("k_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.a_qdq = StaticQDQLayer("a_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.v_qdq = StaticQDQLayer(
"v_qdq", activation_dtype, orig.dtype_policy, False, const_scale
) # the second argument of einsum has to be quantized symmetrically for onednn to work
orig.v_qdq = StaticQDQLayer("v_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig._is_quantized = False
orig._tracker.lock()
return orig
Expand Down Expand Up @@ -1066,9 +1061,12 @@ def prepare(cls, orig, weight_dtype, activation_dtype, const_scale=False, const_
"""
orig._tracker.unlock()
orig.__class__ = cls
orig.positions_qdq = StaticQDQLayer("positions_qdq", activation_dtype, orig.dtype_policy, False, const_scale)
orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer)
orig.positions_qdq = StaticQDQLayer(
"positions_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale
)
orig.inverse_freq_qdq = StaticQDQLayer(
"inverse_freq_qdq", activation_dtype, orig.dtype_policy, False, const_scale
"inverse_freq_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale
)
orig._is_quantized = False
orig._tracker.lock()
Expand Down Expand Up @@ -1177,7 +1175,7 @@ def prepare(cls, orig, weight_dtype, activation_dtype, const_scale=False, const_
orig.__class__ = cls
orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer)
orig.inputs_qdq = StaticQDQLayer("inputs_qdq", activation_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.kernel_qdq = StaticQDQLayer("kernel_qdq", weight_dtype, orig.dtype_policy, False, const_scale)
orig.kernel_qdq = StaticQDQLayer("kernel_qdq", weight_dtype, orig.dtype_policy, orig._is_int8, const_scale)
orig.const_scale = const_scale
Comment thread
bkowalskiINTEL marked this conversation as resolved.
orig.const_weight = const_weight
orig._is_quantized = False
Expand Down
15 changes: 8 additions & 7 deletions neural_compressor/jax/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def get_quantize_fun(dtype=ml_dtypes.float8_e4m3, asymmetric=False):
Callable: Quantization function that maps tensors to the target dtype.
"""

@partial(jax.lax.composite, name="inc.quantize_fp8")
@partial(jax.lax.composite, name="inc.quantize")
def quantize_tensor_float(x, scale):
"""Quantize floating-point tensors using clamping.

Expand All @@ -134,7 +134,7 @@ def quantize_tensor_float(x, scale):
jnp.finfo(dtype).min.astype(x.dtype), x / scale, jnp.finfo(dtype).max.astype(x.dtype)
).astype(dtype)

@partial(jax.lax.composite, name="inc.quantize_int8")
@partial(jax.lax.composite, name="inc.quantize")
def quantize_tensor_int(x, scale):
"""Quantize integer tensors using symmetric scaling.

Expand All @@ -149,7 +149,7 @@ def quantize_tensor_int(x, scale):
val = jnp.clip(val, jnp.iinfo(dtype).min, jnp.iinfo(dtype).max)
return val.astype(dtype)

@partial(jax.lax.composite, name="inc.quantize_int8_asymmetric")
@partial(jax.lax.composite, name="inc.quantize")
def quantize_tensor_int_asymmetric(x, scale, zero_point):
"""Quantize integer tensors using asymmetric scaling.

Expand Down Expand Up @@ -197,7 +197,7 @@ def dequantize(x, scale):
"""
return x.astype(dtype) * scale

@partial(jax.lax.composite, name="inc.dequantize_asymmetric")
@partial(jax.lax.composite, name="inc.dequantize")
def dequantize_asymmetric(x, scale, zero_point=jnp.array(0, dtype=dtype)):
"""Dequantize a tensor with asymmetric scaling.

Expand All @@ -209,7 +209,8 @@ def dequantize_asymmetric(x, scale, zero_point=jnp.array(0, dtype=dtype)):
Returns:
jnp.ndarray: Dequantized tensor.
"""
return (x.astype(dtype) - zero_point) * scale
negated_zero_point = -zero_point
return (x.astype(dtype) + negated_zero_point) * scale

return dequantize_asymmetric if asymmetric else dequantize

Expand Down Expand Up @@ -299,14 +300,14 @@ def integer_get_q_params(orig_weight):
"""
if 0 in orig_weight.shape:
# For empty tensor, return scale as 1.0
return jnp.array(1.0, dtype=compute_dtype), jnp.array(0.0, dtype=compute_dtype)
return jnp.array(1.0, dtype=compute_dtype), jnp.array(0.0, dtype=jnp.int32)
orig_min = jnp.min(orig_weight).astype(compute_dtype)
orig_max = jnp.max(orig_weight).astype(compute_dtype)
int_min = jnp.array(jnp.iinfo(dtype).min).astype(compute_dtype)
int_max = jnp.array(jnp.iinfo(dtype).max).astype(compute_dtype)
scale = (orig_max - orig_min) / (int_max - int_min)
zero_point = jnp.round(int_min - orig_min / scale)
return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(compute_dtype)
return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(jnp.int32)

if jnp.issubdtype(dtype, jnp.floating):
return get_scale(orig_weight, dtype, compute_dtype), None
Expand Down
Loading