
Change composite names and refactor int8 - temporary WA for xla#2431

Draft
bkowalskiINTEL wants to merge 5 commits into master from dev/bkowalsk/jax_int8_pr

Conversation

@bkowalskiINTEL
Contributor

Type of Change

others

Description

Removed dtype suffixes from the quantize/dequantize composite names: inc.quantize_int8... becomes inc.quantize, and likewise for dequantize.

Made a minor refactor of the int8 path as a temporary workaround (WA) for xla#2431.

Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
@bkowalskiINTEL bkowalskiINTEL force-pushed the dev/bkowalsk/jax_int8_pr branch from 5380709 to a9bfaf0 Compare April 1, 2026 15:09
Comment on lines +438 to +441
if self._is_int8:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
else:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
Contributor Author


Replace this conditional with a single call.

Contributor

Copilot AI left a comment


Pull request overview

This PR standardizes JAX composite names for quantize/dequantize ops and applies an int8-focused refactor/workaround intended to improve XLA compatibility.

Changes:

  • Renamed jax.lax.composite names for quantize/dequantize to unified inc.quantize / inc.dequantize.
  • Refactored asymmetric dequantization arithmetic and extended get_q_params() with emulate_asymmetric; zero-point is now int32.
  • Updated static-quant layers to store activation zero-points as int32 and adjusted int8/asymmetric wiring in several prepare/convert paths.
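The refactored asymmetric dequantization arithmetic follows the usual (q - zero_point) * scale form. A minimal NumPy sketch with hypothetical names (the actual code lives in utility.py):

```python
import numpy as np

def dequantize_asymmetric(q, scale, zero_point):
    # Hypothetical sketch: subtract the int32 zero-point before scaling,
    # with float32 as the compute dtype.
    return (q.astype(np.int32) - zero_point).astype(np.float32) * scale

x = dequantize_asymmetric(np.array([2, -4], np.int8), np.float32(0.25), np.int32(0))
# x -> array([ 0.5, -1. ], dtype=float32)
```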

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
neural_compressor/jax/utils/utility.py Composite renames; asymmetric dequant arithmetic tweak; get_q_params() API/typing changes (incl. int32 zero-point + emulation flag).
neural_compressor/jax/quantization/layers_static.py Switch activation zero-point weights to int32 and refactor int8/asymmetric paths in static quant layers.


Comment on lines 309 to +313
      int_min = jnp.array(jnp.iinfo(dtype).min).astype(compute_dtype)
      int_max = jnp.array(jnp.iinfo(dtype).max).astype(compute_dtype)
      scale = (orig_max - orig_min) / (int_max - int_min)
      zero_point = jnp.round(int_min - orig_min / scale)
-     return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(compute_dtype)
+     return scale.reshape((1,)).astype(compute_dtype), zero_point.reshape((1,)).astype(jnp.int32)

Copilot AI Apr 9, 2026


In integer_get_q_params, the non-empty path now returns zero_point as jnp.int32, but the empty-tensor early return above still returns a compute_dtype (float) zero_point. With a_zero_point variables now defined as jnp.int32, the empty-tensor case can trigger a dtype mismatch or assignment error. Consider returning an int32 zero_point (e.g., jnp.array(0, dtype=jnp.int32)) consistently in the empty-tensor branch as well.
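A sketch of the suggested fix, with hypothetical names mirroring integer_get_q_params, in which both the empty-tensor branch and the normal path return an int32 zero-point:

```python
import numpy as np

def integer_q_params_sketch(w, int_min=-128, int_max=127):
    # Hypothetical sketch of the review suggestion: the empty-tensor branch
    # returns an int32 zero-point, consistent with the non-empty path.
    if w.size == 0:
        return np.ones((1,), dtype=np.float32), np.zeros((1,), dtype=np.int32)
    scale = np.asarray((w.max() - w.min()) / (int_max - int_min), dtype=np.float32)
    zero_point = np.asarray(np.round(int_min - w.min() / scale), dtype=np.int32)
    return scale.reshape((1,)), zero_point.reshape((1,))
```

With both branches agreeing on dtype, assigning the result into an int32 a_zero_point variable cannot fail on empty calibration tensors.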

Comment on lines +438 to +441
if self._is_int8:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
else:
    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)

Copilot AI Apr 9, 2026


This if self._is_int8 branch is redundant because both branches call self.wquantfun(self.kernel, self.w_scale.value) with identical arguments. Consider removing the conditional (or, if intended as an XLA workaround, make the differing behavior explicit) to avoid dead code paths.

Suggested change
-if self._is_int8:
-    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
-else:
-    _kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)
+_kernel_quant = self.wquantfun(self.kernel, self.w_scale.value)

Comment on lines +277 to +287
def get_q_params(
    orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32, asymmetric=False, emulate_asymmetric=False
):
    """Compute quantization scale and zero-point for a weight tensor.

    Args:
        orig_weight (jnp.ndarray): Weight tensor to analyze.
        dtype (jnp.dtype): Target quantized dtype.
        compute_dtype (jnp.dtype): dtype for scale computation.
        asymmetric (bool): Whether to compute asymmetric quantization parameters.
        emulate_asymmetric (bool): Whether to emulate asymmetric quantization using symmetric quantization with zero_point=0.
Contributor Author


Remove emulate mode

Comment on lines +319 to +320
if emulate_asymmetric:
    return get_scale(orig_weight, dtype, compute_dtype), jnp.array((0,), dtype=jnp.int32)
Contributor Author


Remove emulate mode

Force scales to be fp32 for int8 - temporary

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Revert "Force scales to be fp32 for int8 - temporary"

This reverts commit 3f5aaf6.

Code cleanup

Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
@bkowalskiINTEL bkowalskiINTEL force-pushed the dev/bkowalsk/jax_int8_pr branch from 33c1f69 to 052ffa5 Compare April 9, 2026 13:57
