-
Notifications
You must be signed in to change notification settings - Fork 133
Add MPS (Apple Silicon) support for RFD3, RF3, and ProteinMPNN #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
|
|
||
| from foundry.model.layers.blocks import Dropout | ||
| from foundry.training.checkpoint import activation_checkpointing | ||
| from foundry.utils.torch import scatter_mean | ||
|
|
||
|
|
||
| class AtomAttentionEncoderPairformer(nn.Module): | ||
|
|
@@ -198,17 +199,14 @@ def embed_features(C_L, D_LL, V_LL): | |
| # Ensure dtype consistency for index_reduce | ||
| processed_Q_L = processed_Q_L.to(Q_L.dtype) | ||
|
|
||
| A_I = torch.zeros( | ||
| A_I_shape, device=Q_L.device, dtype=Q_L.dtype | ||
| ).index_reduce( | ||
| -2, # Operate on the second-to-last dimension (the atom dimension) | ||
| A_I = scatter_mean( | ||
| torch.zeros(A_I_shape, device=Q_L.device, dtype=Q_L.dtype), | ||
| -2, | ||
| f[ | ||
| "atom_to_token_map" | ||
| ].long(), # [L], mapping from atom index to token index. Must be a torch.int64 or torch.int32 tensor. | ||
| processed_Q_L, # [L, C_atom] -> [L, C_token] | ||
| "mean", | ||
| include_self=False, # Do not use the original values in A_I (all zeros) when aggregating | ||
| ) # [L, C_atom] -> [I, C_token] | ||
| ].long(), # [L], mapping from atom index to token index | ||
| processed_Q_L, # (..., L, C_token) | ||
| ) # (..., I, C_token) | ||
|
|
||
| return A_I, Q_L, C_L, P_LL | ||
|
|
||
|
|
@@ -253,7 +251,7 @@ def forward( | |
| assert S_I is None | ||
| A_I = self.ln_1(A_I) | ||
|
|
||
| if self.use_deepspeed_evo or self.force_bfloat16: | ||
| if (self.use_deepspeed_evo or self.force_bfloat16) and A_I.device.type != "mps": | ||
| A_I = A_I.to(torch.bfloat16) | ||
|
|
||
| Q_IH = self.to_q(A_I) # / np.sqrt(self.c) | ||
|
|
@@ -265,9 +263,7 @@ def forward( | |
| B, L = B_IIH.shape[:2] | ||
|
|
||
| if not self.use_deepspeed_evo or L <= 24: | ||
| Q_IH = Q_IH / torch.sqrt( | ||
| torch.tensor(self.c).to(Q_IH.device, torch.bfloat16) | ||
| ) | ||
| Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype)) | ||
| # Attention | ||
|
Comment on lines 265 to 267
|
||
| A_IIH = torch.softmax( | ||
| torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The README currently states that macOS/MPS support is only available “via a community fork” and instructs installing
`rc-foundry` from `fnachon/foundry.git`. Since this PR adds MPS support directly, these instructions will become misleading after merge; consider updating this section to point to the official package/repo (or explicitly mark the fork install as a temporary pre-release option).