This PR adds a very minimal (and sub-optimal with respect to input slicing) sharding implementation for EXLA based on Shardy.
The goal is for us to discuss whether we want these as Nx callbacks, or whether we want to add a way for EXLA to declare its own defn symbols for all_reduce/all_gather/all_to_all and related operations.
My biggest concern with exposing this to Nx core is that, up until now, Nx core has no concept of devices, while XLA's sharding (and PyTorch's as well, as far as @Chapaman and I researched) is tightly coupled to devices.
We could very well get away with EXLA providing deftransforms that introduce :metadata Defn.Expr nodes, annotating the expression for EXLA.Defn to turn into EXLA.MLIR.Value calls (rough sketch below).
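To make that last point concrete, here is a minimal sketch of what such a deftransform could look like. The module name, the :exla_collective metadata key, and its payload are invented purely for illustration and are not part of this PR; only deftransform and Nx.Defn.Expr.metadata/2 are existing Nx APIs.

```elixir
# Illustrative only: an EXLA-provided deftransform that tags an expression
# with a :metadata Defn.Expr node, which EXLA.Defn could later lower into an
# EXLA.MLIR.Value collective call. Module name and metadata key are made up.
defmodule EXLA.Defn.Sharding do
  import Nx.Defn

  # Wraps the expression in a metadata node. During compilation, EXLA.Defn
  # would pattern-match on the :exla_collective key and emit the corresponding
  # collective op; other compilers could simply ignore the annotation.
  deftransform all_reduce(expr, opts) do
    Nx.Defn.Expr.metadata(expr, %{exla_collective: {:all_reduce, opts}})
  end
end
```

User code inside a defn could then pipe into `EXLA.Defn.Sharding.all_reduce(expr, op: :sum)` (again, a hypothetical call), keeping the device-aware pieces entirely on the EXLA side instead of in Nx core.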