
Added support for MPS (Apple Silicon)#257

Open
fnachon wants to merge 5 commits into RosettaCommons:production from fnachon:production

Conversation


@fnachon fnachon commented Apr 1, 2026

A few code fixes to make RFD3, RF3, and MNPNN run on macOS with M1-M5 chips.
Mostly a matter of converting bfloat16 -> float32 for MPS.
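A minimal sketch of the kind of dtype fallback this involves (the helper name is hypothetical, not from the PR): select float32 whenever the target device is MPS, since the MPS backend does not support bfloat16.

```python
import torch

def pick_dtype(device: torch.device) -> torch.dtype:
    # bfloat16 is not supported on the MPS backend; fall back to float32 there.
    if device.type == "mps":
        return torch.float32
    return torch.bfloat16

# On CPU/CUDA keep bfloat16; on Apple Silicon ("mps") use float32 instead.
x = torch.randn(4, 4).to(pick_dtype(torch.device("cpu")))
```

Centralizing the choice in one helper keeps the device-specific workaround out of the model code itself.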

fnachon added 5 commits on March 31, 2026 at 16:52:
- bfloat16 is not supported yet on MPS; it is better to avoid it.
- Fixed some code issues related to MPS
@Ubiquinone-dot (Collaborator) left a comment


Generally close to being ready; just a few comments for clarity.

# Handle case when indices and P_LA don't have batch dimensions
B, L, k = P_LK_indices.shape
if P_LA_indices.ndim == 2:
    P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
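For context, a standalone illustration of that batch-broadcast pattern (the shapes here are hypothetical, chosen only for the demo):

```python
import torch

# Hypothetical sizes: B batches, L positions, k neighbors per position.
B, L, k = 2, 5, 3
P_LK_indices = torch.zeros(B, L, k, dtype=torch.long)
P_LA_indices = torch.zeros(L, k, dtype=torch.long)  # no batch dimension yet

if P_LA_indices.ndim == 2:
    # Insert a batch dim and broadcast it to B; .expand() creates a view,
    # so no data is copied.
    P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)

print(tuple(P_LA_indices.shape))  # (2, 5, 3)
```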
Collaborator
are these contiguous calls necessary?

# Use .repeat() instead of .expand() to produce a contiguous tensor; MPS does
# not handle non-contiguous inputs to torch.where correctly.
all_idx_row = torch.arange(L, device=device).repeat(L, 1)
indices = torch.where(mask, all_idx_row, inf)  # sentinel inf if not-forced
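As a side note on the contiguity point: `.expand()` returns a stride-0 broadcast view (non-contiguous), while `.repeat()` materializes a contiguous copy with the same values. A small standalone check (not from the PR):

```python
import torch

L = 4
row = torch.arange(L)
expanded = row.expand(L, L)  # broadcast view: stride 0 on dim 0, non-contiguous
repeated = row.repeat(L, 1)  # materialized copy: contiguous

print(expanded.is_contiguous())         # False
print(repeated.is_contiguous())         # True
print(torch.equal(expanded, repeated))  # True: same values either way
```

The trade-off is memory and a copy: `.repeat()` allocates L*L elements, whereas `.expand()` is free.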
Collaborator

Make this an `if mps` branch, else do the old version; the old version is faster AFAIK.
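A sketch of the suggested device branch (the helper name is hypothetical): materialize a contiguous copy only on MPS, and keep the cheaper `.expand()` view everywhere else.

```python
import torch

def row_indices(L: int, device: torch.device) -> torch.Tensor:
    # Hypothetical helper sketching the reviewer's suggestion: only MPS needs
    # the contiguous .repeat() copy for torch.where; elsewhere .expand() is a
    # zero-copy view and faster.
    row = torch.arange(L, device=device)
    if device.type == "mps":
        return row.repeat(L, 1)  # contiguous copy, safe for torch.where on MPS
    return row.expand(L, L)      # broadcast view, no allocation

all_idx_row = row_indices(3, torch.device("cpu"))
```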

