Change composite names and refactor int8 - temporary WA for xla#2431
bkowalskiINTEL wants to merge 5 commits into `master`
Conversation
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
```python
if self._is_int8:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
else:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
```
Remove this conditional; both branches make the same call, so replace it with a single call.
Pull request overview
This PR standardizes JAX composite names for quantize/dequantize ops and applies an int8-focused refactor/workaround intended to improve XLA compatibility.
Changes:
- Renamed `jax.lax.composite` names for quantize/dequantize to unified `inc.quantize`/`inc.dequantize`.
- Refactored asymmetric dequantization arithmetic and extended `get_q_params()` with `emulate_asymmetric`; zero-point is now `int32`.
- Updated static-quant layers to store activation zero-points as `int32` and adjusted int8/asymmetric wiring in several prepare/convert paths.
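The asymmetric int8 scheme summarized above can be sketched as follows. This is a minimal illustration, not the library's API: the function names (`get_asym_q_params`, `quantize`, `dequantize`) are hypothetical, and only the scale/zero-point arithmetic mirrors the snippets reviewed in this PR, including the `int32` zero-point.

```python
import jax.numpy as jnp


def get_asym_q_params(x, dtype=jnp.int8, compute_dtype=jnp.float32):
    """Compute an asymmetric quantization scale (float) and zero-point (int32)."""
    int_min = jnp.array(jnp.iinfo(dtype).min).astype(compute_dtype)
    int_max = jnp.array(jnp.iinfo(dtype).max).astype(compute_dtype)
    orig_min = jnp.min(x).astype(compute_dtype)
    orig_max = jnp.max(x).astype(compute_dtype)
    scale = (orig_max - orig_min) / (int_max - int_min)
    zero_point = jnp.round(int_min - orig_min / scale)
    # Zero-point is stored as int32, as in this PR's change.
    return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(jnp.int32)


def quantize(x, scale, zero_point, dtype=jnp.int8):
    """Map float values to the integer grid: q = round(x / scale) + zero_point."""
    q = jnp.round(x / scale) + zero_point
    return jnp.clip(q, jnp.iinfo(dtype).min, jnp.iinfo(dtype).max).astype(dtype)


def dequantize(q, scale, zero_point, compute_dtype=jnp.float32):
    """Invert the mapping: x ≈ (q - zero_point) * scale."""
    return (q.astype(compute_dtype) - zero_point.astype(compute_dtype)) * scale
```

Round-tripping a tensor through `quantize`/`dequantize` should reconstruct it to within one quantization step (`scale`).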
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `neural_compressor/jax/utils/utility.py` | Composite renames; asymmetric dequant arithmetic tweak; `get_q_params()` API/typing changes (incl. int32 zero-point + emulation flag). |
| `neural_compressor/jax/quantization/layers_static.py` | Switch activation zero-point weights to int32 and refactor int8/asymmetric paths in static quant layers. |
```diff
 int_min = jnp.array(jnp.iinfo(dtype).min).astype(compute_dtype)
 int_max = jnp.array(jnp.iinfo(dtype).max).astype(compute_dtype)
 scale = (orig_max - orig_min) / (int_max - int_min)
 zero_point = jnp.round(int_min - orig_min / scale)
-return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(compute_dtype)
+return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(jnp.int32)
```
In integer_get_q_params, the non-empty path now returns zero_point as jnp.int32, but the empty-tensor early return above still returns a compute_dtype (float) zero_point. With a_zero_point variables now defined as jnp.int32, the empty-tensor case can trigger a dtype mismatch or assignment error. Consider returning an int32 zero_point (e.g., jnp.array(0, dtype=jnp.int32)) consistently in the empty-tensor branch as well.
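A minimal sketch of the consistent fix this comment suggests, with a hypothetical simplified signature (the real `integer_get_q_params` lives in `neural_compressor/jax/utils/utility.py` and differs in detail): both the empty-tensor branch and the regular branch return an `int32` zero-point.

```python
import jax.numpy as jnp


def integer_get_q_params(orig_weight, dtype=jnp.int8, compute_dtype=jnp.float32):
    """Illustrative sketch: scale and int32 zero-point for an integer dtype."""
    if orig_weight.size == 0:
        # Empty-tensor early return: keep the zero-point int32 here too,
        # so assignment into an int32 a_zero_point variable cannot mismatch.
        return jnp.ones((1,), dtype=compute_dtype), jnp.zeros((1,), dtype=jnp.int32)
    int_min = jnp.array(jnp.iinfo(dtype).min).astype(compute_dtype)
    int_max = jnp.array(jnp.iinfo(dtype).max).astype(compute_dtype)
    orig_min = jnp.min(orig_weight).astype(compute_dtype)
    orig_max = jnp.max(orig_weight).astype(compute_dtype)
    scale = (orig_max - orig_min) / (int_max - int_min)
    zero_point = jnp.round(int_min - orig_min / scale)
    return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(jnp.int32)
```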
```python
if self._is_int8:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
else:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
```
This if self._is_int8 branch is redundant because both branches call self.wquantfun(self.kernel, self.w_scale.value) with identical arguments. Consider removing the conditional (or, if intended as an XLA workaround, make the differing behavior explicit) to avoid dead code paths.
Suggested change:

```diff
-if self._is_int8:
-    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
-else:
-    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
+_kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
```
```python
def get_q_params(
    orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32, asymmetric=False, emulate_asymmetric=False
):
    """Compute quantization scale and zero-point for a weight tensor.

    Args:
        orig_weight (jnp.ndarray): Weight tensor to analyze.
        dtype (jnp.dtype): Target quantized dtype.
        compute_dtype (jnp.dtype): dtype for scale computation.
        asymmetric (bool): Whether to compute asymmetric quantization parameters.
        emulate_asymmetric (bool): Whether to emulate asymmetric quantization using
            symmetric quantization with zero_point=0.
    """
```
Remove emulate mode
```python
if emulate_asymmetric:
    return get_scale(orig_weight, dtype, compute_dtype), jnp.array((0,), dtype=jnp.int32)
```
Remove emulate mode
- Force scales to be fp32 for int8 - temporary
- [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
- Revert "Force scales to be fp32 for int8 - temporary" (this reverts commit 3f5aaf6)
- Code cleanup

Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Type of Change
others
Description
Removed suffixes from the quantize/dequantize composite names, e.g. `inc.quantize_int8...` became `inc.quantize`, etc.
Made a minor refactor of the int8 path as a temporary workaround (WA) for XLA.