Skip to content

Conversation

@polvalente
Copy link
Contributor

This PR adds a very minimal (and sub-optimal wrt to input slicing) sharding implementation for EXLA based on Shardy.

The goal is for us to discuss whether we want these as Nx callbacks, or if we want to add a way for EXLA to declare its own defn symbols for all_reduce/all_gather/all_to_all and related things.

My biggest concern with exposing this to Nx core is that up until now, Nx core doesn't have the concept of devices and XLA's sharding (PyTorch's as well, to what me and @Chapaman researched) is very coupled to devices.

We could very well get away with EXLA providing deftransforms that introduce :metadata Defn.Expr nodes annotating things for the EXLA.Defn to turn into EXLA.MLIR.Value calls.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants