From 6ad3ecba845cc631e1a08018f6ba6867700bfb91 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Wed, 4 Dec 2024 15:46:22 -0800 Subject: [PATCH] Update references to JAX's GitHub repo JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886964 --- aqt/jax/v2/aqt_dot_general.py | 6 +++--- aqt/jax/v2/aqt_dot_general_test.py | 2 +- aqt/jax_legacy/jax/README.md | 6 +++--- aqt/jax_legacy/jax/quantization.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index 404991db..2a4ab291 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -899,11 +899,11 @@ def _maybe_dequant( if jax.local_devices()[0].platform == 'cpu': # needed bet lax.dot_general(int4, int4) is illegal on cpu. # TODO(aqt): Remove this platform check once - # https://github.com/google/jax/issues/19682 is fixed. + # https://github.com/jax-ml/jax/issues/19682 is fixed. # TODO(yichizh): It's better to assert False here with the following msg # msg = ( # 'lax.dot_general(int4, int4) is illegal on cpu:' - # ' https://github.com/google/jax/issues/19682. The simple workaround' + # ' https://github.com/jax-ml/jax/issues/19682. The simple workaround' # ' is to upcast to int8, but in that case please directly set the' # ' numerics bits to int8. Please contact the AQT team if you believe' # ' the workaround is needed.' @@ -920,7 +920,7 @@ def _maybe_dequant( precision=lax.Precision.DEFAULT, ) # TODO(lew): Do we have a correct precision above? - # Relevant: https://github.com/google/jax/issues/14022 + # Relevant: https://github.com/jax-ml/jax/issues/14022 out = aqt_tensor.QTensor( qvalue=out, scale=[], diff --git a/aqt/jax/v2/aqt_dot_general_test.py b/aqt/jax/v2/aqt_dot_general_test.py index e79b260c..3461256f 100644 --- a/aqt/jax/v2/aqt_dot_general_test.py +++ b/aqt/jax/v2/aqt_dot_general_test.py @@ -487,7 +487,7 @@ def test_fake_quant( @parameterized.parameters([ dict( # TODO(aqt): Change dlhs_bits to 4bit once - # https://github.com/google/jax/issues/19682 is fixed. + # https://github.com/jax-ml/jax/issues/19682 is fixed. dg=config.config_v3( fwd_bits=3, dlhs_bits=6, diff --git a/aqt/jax_legacy/jax/README.md b/aqt/jax_legacy/jax/README.md index e33f2b11..7eaf0b71 100644 --- a/aqt/jax_legacy/jax/README.md +++ b/aqt/jax_legacy/jax/README.md @@ -7,10 +7,10 @@ quantization for convolution and matmul. ### Jax Libraries -- **quantization.quantized_dot**: [LAX.dot](https://github.com/google/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L632) with optionally quantized weights and activations. -- **quantization.quantized_dynamic_dot_general**: [LAX.dot general](https://github.com/google/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L667) with optionally quantized dynamic inputs. +- **quantization.quantized_dot**: [LAX.dot](https://github.com/jax-ml/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L632) with optionally quantized weights and activations. +- **quantization.quantized_dynamic_dot_general**: [LAX.dot general](https://github.com/jax-ml/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L667) with optionally quantized dynamic inputs. - **quantization.quantized_sum**: Sums a tensor while quantizing intermediate accumulations. -- **quantization.dot_general_aqt**: Adds quantization to [LAX.dot_general](https://github.com/google/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L667) with option to use integer dot. +- **quantization.dot_general_aqt**: Adds quantization to [LAX.dot_general](https://github.com/jax-ml/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L667) with option to use integer dot. ### Flax Libraries diff --git a/aqt/jax_legacy/jax/quantization.py b/aqt/jax_legacy/jax/quantization.py index 366404f0..1cc9852c 100644 --- a/aqt/jax_legacy/jax/quantization.py +++ b/aqt/jax_legacy/jax/quantization.py @@ -703,7 +703,7 @@ def quantized_dot( """LAX dot with optionally quantized weights and activations. Wraps LAX's `Dot -