diff --git a/pyproject.toml b/pyproject.toml index ae2f2c8fc..e8921a727 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ "flax==0.10.7", - "optax==0.2.2", + "optax==0.2.3", "chex==0.1.86", "ml_dtypes==0.5.1", "protobuf==4.25.5",