Update JAX API usage to latest version#1317
Conversation
GitOrigin-RevId: d143490
Make sure sink's contribution is added once. Also added tests. GitOrigin-RevId: 8de870c
GitOrigin-RevId: 56cf7e8
* pin nccl version * empty commit * add actual pacakge * trigger new build to address flaky test * Update pyproject.toml GitOrigin-RevId: b8653ad
…ункции flatten/unflatten в utils
|
Hi, the workflows need approval to run (GitHub Actions are pending). Can someone with write access approve and run them? @ruomingp pls |
GitOrigin-RevId: 65af801
…рать лишний пробел в _enable_numeric_checks
…мый код в dataclass
There was a problem hiding this comment.
Some of these changes have the potential to break things and don't seem to be necessary, as @changlan mentioned. Could you explain for every change, why it is necessary? Also, please do not mark comments as resolved yourself. To streamline reviewing, we only have PR reviewers mark comments as resolved.
|
Also please resolve any merge conflicts. |
|
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the |
|
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the |
|
This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue. |
We need to update the JAX API usage across the codebase to use the latest stable versions.
Changes Required
jax.tree_utilwithjax.tree:register_pytree_with_keysinstead ofregister_pytree_nodetree_mapwith new API versionFiles to Modify
Key files that need updates:
axlearn/common/struct.pyaxlearn/common/utils.pyaxlearn/common/metrics.pyaxlearn/common/learner.pyImplementation Details
Success Criteria