Added support for MPS (Apple Silicon) #257
Open
fnachon wants to merge 5 commits into RosettaCommons:production
bfloat16 is not yet supported on MPS, so it is better to avoid it.
Fixed some code issues related to MPS
Ubiquinone-dot (Collaborator) left a comment
generally close to being ready, just a few comments for clarity
    # Handle case when indices and P_LA don't have batch dimensions
    B, L, k = P_LK_indices.shape
    if P_LA_indices.ndim == 2:
        P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
Collaborator
are these contiguous calls necessary?
    all_idx_row = torch.arange(L, device=device).expand(L, L)
    indices = torch.where(mask, all_idx_row, inf)  # sentinel inf if not-forced
    # Use .repeat() instead of .expand() to produce a contiguous tensor, since MPS
    # does not handle non-contiguous inputs to torch.where correctly.
Collaborator
Make this conditional: use the new version only on MPS and keep the old version otherwise; the old version is faster, afaik.
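One way to read this suggestion, as a rough sketch rather than the code in the PR: branch on the device type so that only MPS pays for the contiguous copy. The helper names `index_rows` and `masked_indices` and the integer sentinel are hypothetical and only serve the illustration.

```python
import torch


def index_rows(L: int, device: torch.device) -> torch.Tensor:
    """Return an (L, L) tensor in which every row is 0..L-1."""
    row = torch.arange(L, device=device)
    if device.type == "mps":
        # Workaround path: .repeat() materializes a contiguous copy.
        return row.repeat(L, 1)
    # Default path: .expand() is a zero-copy broadcast view (non-contiguous).
    return row.expand(L, L)


def masked_indices(mask: torch.Tensor, sentinel: int) -> torch.Tensor:
    """Keep the column index where mask is True, else the sentinel.

    `sentinel` is an assumed integer stand-in for the `inf` value in the diff.
    """
    L = mask.shape[-1]
    all_idx_row = index_rows(L, mask.device)
    fill = torch.tensor(sentinel, device=mask.device, dtype=all_idx_row.dtype)
    return torch.where(mask, all_idx_row, fill)
```

Both branches produce the same values, so the change stays invisible to callers on CUDA and CPU while keeping the cheaper view there.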
A few code fixes to make RFD3, RF3, and MPNN run on macOS with M1-M5 chips.
Mostly handling of bfloat16 -> float32 for MPS.
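As an illustration of the kind of dtype handling described here, the following is a minimal sketch and not the PR's actual code. It assumes a hypothetical helper, `resolve_dtype`, that downgrades a requested bfloat16 to float32 only when the target device is MPS.

```python
import torch


def resolve_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
    """Fall back from bfloat16 to float32 on MPS; leave other cases unchanged.

    Hypothetical helper: the name and signature are not taken from the PR.
    """
    if device.type == "mps" and requested == torch.bfloat16:
        return torch.float32
    return requested


# Usage: pick the dtype once, then cast tensors or modules with it.
device = torch.device("cpu")
dtype = resolve_dtype(torch.bfloat16, device)
x = torch.ones(2, 2, device=device).to(dtype)
```

Centralizing the decision in one function keeps the MPS special case out of the model code itself.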