diff --git a/qwix/_src/core/dot_general_qt.py b/qwix/_src/core/dot_general_qt.py index 4933886..d3f55cf 100644 --- a/qwix/_src/core/dot_general_qt.py +++ b/qwix/_src/core/dot_general_qt.py @@ -62,15 +62,10 @@ class DotGeneralQtConfig: # the gradient clipping is always skipped. disable_gradient_clipping: bool = False - # By default, Qwix will use the quantized inputs from the fwd pass as the - # residuals during the bwd pass, which is generally more accurate and more - # efficient. Enabling this will allow separate quantization schema for the - # residuals. - use_original_residuals: bool = False - # Quantization for residuals. These only take effect when residuals are not - # already quantized, i.e., either use_original_residuals is True or the - # corresponding qtype in the fwd pass is None. + # already quantized. Qwix will use the quantized inputs from the fwd pass as + # the residuals during the bwd pass only when both the fwd pass and the bwd + # pass are quantized. dlhs_residual_qtype: jax.typing.DTypeLike | None = None dlhs_residual_calibration_method: str = 'absmax' dlhs_residual_disable_channelwise_axes: bool = False @@ -201,7 +196,10 @@ def _compute_gradient_for_operand(g: jax.Array, *, for_dlhs: bool): g_calibration_method = config.dlhs_grad_calibration_method g_noise_fn = config.dlhs_stochastic_rounding_noise_fn g_disable_channelwise_axes = config.dlhs_grad_disable_channelwise_axes - y = rhs_in if config.use_original_residuals else rhs + fwd_quantized = config.rhs_qtype is not None + bwd_quantized = config.dlhs_grad_qtype is not None + use_quantized_residual = fwd_quantized and bwd_quantized + y = rhs if use_quantized_residual else rhs_in y_qtype = config.dlhs_residual_qtype y_calibration_method = config.dlhs_residual_calibration_method y_disable_channelwise_axes = config.dlhs_residual_disable_channelwise_axes @@ -211,7 +209,10 @@ def _compute_gradient_for_operand(g: jax.Array, *, for_dlhs: bool): g_calibration_method = config.drhs_grad_calibration_method g_noise_fn = config.drhs_stochastic_rounding_noise_fn g_disable_channelwise_axes = config.drhs_grad_disable_channelwise_axes - y = lhs_in if config.use_original_residuals else lhs + fwd_quantized = config.lhs_qtype is not None + bwd_quantized = config.drhs_grad_qtype is not None + use_quantized_residual = fwd_quantized and bwd_quantized + y = lhs if use_quantized_residual else lhs_in y_qtype = config.drhs_residual_qtype y_calibration_method = config.drhs_residual_calibration_method y_disable_channelwise_axes = config.drhs_residual_disable_channelwise_axes diff --git a/tests/_src/core/dot_general_qt_test.py b/tests/_src/core/dot_general_qt_test.py index 6228cea..e42f1f7 100644 --- a/tests/_src/core/dot_general_qt_test.py +++ b/tests/_src/core/dot_general_qt_test.py @@ -76,20 +76,40 @@ class DotGeneralQtTest(parameterized.TestCase): testcase_name='int8', lhs_qtype='int8', rhs_qtype='int8', + bwd_qtype='int8', # If we set precision=HIGHEST in fq implementation above, then # expected_mae_fq_out will become 1e-7 but expected_mae_fq_grads will # be non-zero. expected_mae_fq_out=0.02, - expected_mae_fq_grads=0.0, + expected_mae_fq_grads=0.01, expected_mae_fp_out=0.06, - expected_mae_fp_grads=0.02, + expected_mae_fp_grads=0.03, ), dict( testcase_name='int4', lhs_qtype='int4', rhs_qtype='int4', + bwd_qtype='int4', + expected_mae_fq_out=0.04, + expected_mae_fq_grads=0.2, + expected_mae_fp_out=0.5, + expected_mae_fp_grads=0.5, + ), + dict( + testcase_name='int8_fwd', + lhs_qtype='int8', + rhs_qtype='int8', + expected_mae_fq_out=0.02, + expected_mae_fq_grads=0.03, + expected_mae_fp_out=0.06, + expected_mae_fp_grads=0.02, + ), + dict( + testcase_name='int4_fwd', + lhs_qtype='int4', + rhs_qtype='int4', expected_mae_fq_out=0.04, - expected_mae_fq_grads=0.0, + expected_mae_fq_grads=0.5, expected_mae_fp_out=0.5, expected_mae_fp_grads=0.5, ),