Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions qwix/_src/core/dot_general_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,10 @@ class DotGeneralQtConfig:
# the gradient clipping is always skipped.
disable_gradient_clipping: bool = False

# By default, Qwix will use the quantized inputs from the fwd pass as the
# residuals during the bwd pass, which is generally more accurate and more
# efficient. Enabling this will allow separate quantization schema for the
# residuals.
use_original_residuals: bool = False

# Quantization for residuals. These only take effect when residuals are not
# already quantized, i.e., either use_original_residuals is True or the
# corresponding qtype in the fwd pass is None.
# already quantized. Qwix will use the quantized inputs from the fwd pass as
# the residuals during the bwd pass only when both the fwd pass and the bwd
# pass are quantized.
dlhs_residual_qtype: jax.typing.DTypeLike | None = None
dlhs_residual_calibration_method: str = 'absmax'
dlhs_residual_disable_channelwise_axes: bool = False
Expand Down Expand Up @@ -201,7 +196,10 @@ def _compute_gradient_for_operand(g: jax.Array, *, for_dlhs: bool):
g_calibration_method = config.dlhs_grad_calibration_method
g_noise_fn = config.dlhs_stochastic_rounding_noise_fn
g_disable_channelwise_axes = config.dlhs_grad_disable_channelwise_axes
y = rhs_in if config.use_original_residuals else rhs
fwd_quantized = config.rhs_qtype is not None
bwd_quantized = config.dlhs_grad_qtype is not None
use_quantized_residual = fwd_quantized and bwd_quantized
y = rhs if use_quantized_residual else rhs_in
y_qtype = config.dlhs_residual_qtype
y_calibration_method = config.dlhs_residual_calibration_method
y_disable_channelwise_axes = config.dlhs_residual_disable_channelwise_axes
Expand All @@ -211,7 +209,10 @@ def _compute_gradient_for_operand(g: jax.Array, *, for_dlhs: bool):
g_calibration_method = config.drhs_grad_calibration_method
g_noise_fn = config.drhs_stochastic_rounding_noise_fn
g_disable_channelwise_axes = config.drhs_grad_disable_channelwise_axes
y = lhs_in if config.use_original_residuals else lhs
fwd_quantized = config.lhs_qtype is not None
bwd_quantized = config.drhs_grad_qtype is not None
use_quantized_residual = fwd_quantized and bwd_quantized
y = lhs if use_quantized_residual else lhs_in
y_qtype = config.drhs_residual_qtype
y_calibration_method = config.drhs_residual_calibration_method
y_disable_channelwise_axes = config.drhs_residual_disable_channelwise_axes
Expand Down
26 changes: 23 additions & 3 deletions tests/_src/core/dot_general_qt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,40 @@ class DotGeneralQtTest(parameterized.TestCase):
testcase_name='int8',
lhs_qtype='int8',
rhs_qtype='int8',
bwd_qtype='int8',
# If we set precision=HIGHEST in fq implementation above, then
# expected_mae_fq_out will become 1e-7 but expected_mae_fq_grads will
# be non-zero.
expected_mae_fq_out=0.02,
expected_mae_fq_grads=0.0,
expected_mae_fq_grads=0.01,
expected_mae_fp_out=0.06,
expected_mae_fp_grads=0.02,
expected_mae_fp_grads=0.03,
),
dict(
testcase_name='int4',
lhs_qtype='int4',
rhs_qtype='int4',
bwd_qtype='int4',
expected_mae_fq_out=0.04,
expected_mae_fq_grads=0.2,
expected_mae_fp_out=0.5,
expected_mae_fp_grads=0.5,
),
dict(
testcase_name='int8_fwd',
lhs_qtype='int8',
rhs_qtype='int8',
expected_mae_fq_out=0.02,
expected_mae_fq_grads=0.03,
expected_mae_fp_out=0.06,
expected_mae_fp_grads=0.02,
),
dict(
testcase_name='int4_fwd',
lhs_qtype='int4',
rhs_qtype='int4',
expected_mae_fq_out=0.04,
expected_mae_fq_grads=0.0,
expected_mae_fq_grads=0.5,
expected_mae_fp_out=0.5,
expected_mae_fp_grads=0.5,
),
Expand Down