Add MPS (Apple Silicon) support for RFD3, RF3, and ProteinMPNN#260
Ubiquinone-dot merged 1 commit into `production`
Conversation
Adds support for running inference on Apple Silicon MPS devices. Key changes:

- Handle bfloat16 -> float32 fallback on MPS (bfloat16 unsupported)
- Add `scatter_mean` utility with MPS fallback (`index_reduce` unsupported)
- Guard `masked_scatter_`/boolean indexing in `block_utils` with MPS paths
- Use `expand` (zero-copy) on CUDA/CPU, `repeat` (contiguous) on MPS for `torch.where` compatibility
- Add `.contiguous()` calls for scatter/gather ops that require it on MPS
- Replace hardcoded `"cuda"` in autocast dtype queries with `device_of()`
- Store `torch.linalg.det` result before `torch.sign` to avoid MPS in-place op issues on autograd graph leaves
- Auto-detect MPS accelerator and enforce float32 precision
- Add MPS installation instructions to README

Based on work by @fnachon in PR #257 with additional fixes:

- expand/repeat if-else to avoid O(L²) allocation on non-MPS backends
- ruff 0.8.3 formatting to pass CI lint checks

Co-Authored-By: fnachon
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
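The autocast-dtype change listed above (replacing a hardcoded `"cuda"` query) can be sketched as follows. This is a minimal illustration, not the PR's code; `device_of` is reconstructed here as an assumed one-line helper:

```python
import torch
from torch import Tensor


def device_of(t: Tensor) -> str:
    """Assumed minimal version of the PR's helper: the tensor's device type."""
    return t.device.type


def autocast_dtype(t: Tensor) -> torch.dtype:
    # bfloat16 is unsupported on MPS, so fall back to float32 there.
    # Querying the tensor's own device (instead of hardcoding "cuda")
    # keeps the choice correct on CPU, CUDA, and MPS alike.
    return torch.float32 if device_of(t) == "mps" else torch.bfloat16


x = torch.zeros(2, 3)  # CPU tensor in this sketch
dtype = autocast_dtype(x)
```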
Pull request overview
Adds Apple Silicon (MPS) support for inference across RFD3, RF3, and ProteinMPNN by introducing MPS-safe tensor ops/paths and auto-selecting the MPS accelerator with float32 precision where applicable.
Changes:
- Add `scatter_mean` utility to replace `index_reduce(..., reduce="mean")` where MPS lacks support, and wire it into RFD3/RF3 embedding code.
- Add MPS-aware precision/autocast and contiguity workarounds (avoid bf16 on MPS, fix non-contiguous tensors for scatter/where, etc.).
- Expand device selection/docs/tests utilities to recognize MPS (accelerator auto-detection, fixtures, README, debug tooling, MPNN inference engine).
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| src/foundry/utils/torch.py | Adds scatter_mean with MPS fallback path. |
| src/foundry/utils/ddp.py | Auto-detects MPS and sets Fabric accelerator/precision. |
| src/foundry/utils/alignment.py | Refactors determinant/sign computation for rigid alignment. |
| src/foundry/testing/fixtures.py | Updates GPU fixture to treat MPS as an accelerator. |
| README.md | Adds macOS (Apple Silicon) installation + MPS notes. |
| models/rfd3/src/rfd3/testing/debug.py | Selects CUDA/MPS/CPU device for debug forward pass. |
| models/rfd3/src/rfd3/model/RFD3_diffusion_module.py | Removes leftover debug/comment-only lines. |
| models/rfd3/src/rfd3/model/layers/pairformer_layers.py | Avoids bf16/autocast on MPS; adjusts scaling tensor dtype. |
| models/rfd3/src/rfd3/model/layers/blocks.py | Replaces index_reduce token pooling with scatter_mean. |
| models/rfd3/src/rfd3/model/layers/block_utils.py | Adds MPS-safe scatter/gather replacements and contiguity fixes. |
| models/rfd3/src/rfd3/inference/symmetry/frames.py | Refactors determinant/sign computation for frame construction. |
| models/rf3/src/rf3/model/RF3.py | Uses module device type for autocast dtype selection. |
| models/rf3/src/rf3/model/layers/pairformer_layers.py | Replaces index_reduce token pooling with scatter_mean; avoids bf16 on MPS. |
| models/rf3/src/rf3/model/layers/attention.py | Uses tensor device type for autocast dtype selection. |
| models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py | Uses scatter_mean and imports it alongside device_of. |
| models/mpnn/src/mpnn/model/mpnn.py | Extends AMP dtype handling to include MPS device type. |
| models/mpnn/src/mpnn/inference_engines/mpnn.py | Adds MPS device selection and MPS seeding path. |
```python
# Expand 1D index (N,) to match source shape (..., N, C)
shape = [1] * ndim
shape[dim] = index.shape[0]
idx = index.view(shape).expand_as(source)

# Sum source values into output positions
result = zeros.scatter_add(dim, idx, source)

# Count how many source values land in each output position.
# Take a single slice along the last dim to avoid allocating a full (N, C) ones tensor.
idx_count = idx[..., :1]  # (..., N, 1)
```
On the MPS fallback path, `idx` (and `idx_count`) are created via `.expand_as(...)` / slicing, which produces non-contiguous tensors. Elsewhere in this PR you explicitly call `.contiguous()` before `scatter_add` to avoid MPS issues; doing the same here would make the MPS fallback more reliable (or use `.repeat`/`clone` for indices on MPS).
Suggested change:

```diff
-# Expand 1D index (N,) to match source shape (..., N, C)
+# Expand 1D index (N,) to match source shape (..., N, C).
+# Materialize as contiguous for MPS scatter_add reliability.
 shape = [1] * ndim
 shape[dim] = index.shape[0]
-idx = index.view(shape).expand_as(source)
+idx = index.view(shape).expand_as(source).contiguous()
 # Sum source values into output positions
 result = zeros.scatter_add(dim, idx, source)
 # Count how many source values land in each output position.
 # Take a single slice along the last dim to avoid allocating a full (N, C) ones tensor.
-idx_count = idx[..., :1]  # (..., N, 1)
+idx_count = idx[..., :1].contiguous()  # (..., N, 1)
```
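As a quick illustration of the contiguity point (a minimal sketch, not from the PR): `expand_as` returns a strided view, slicing it stays non-contiguous, and `.contiguous()` materializes a real copy.

```python
import torch

source = torch.randn(2, 5, 3)          # (..., N, C)
index = torch.tensor([0, 1, 0, 2, 1])  # (N,)

idx = index.view(1, 5, 1).expand_as(source)
assert not idx.is_contiguous()            # expand yields a 0-stride view
assert not idx[..., :1].is_contiguous()   # slicing keeps it non-contiguous

idx_c = idx.contiguous()                  # materializes a real (2, 5, 3) tensor
assert idx_c.is_contiguous()
```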
```python
def scatter_mean(zeros: Tensor, dim: int, index: Tensor, source: Tensor) -> Tensor:
    """Scatter-mean aggregation, with an MPS-compatible fallback.

    On non-MPS devices uses index_reduce (faster, in-place kernel).
    On MPS, index_reduce is not implemented so falls back to scatter_add + count.

    Equivalent to: zeros.index_reduce(dim, index, source, 'mean', include_self=False)

    Args:
        zeros: Pre-allocated zero tensor, shape (..., I, C). Will not be modified in-place.
        dim: Dimension to scatter along. Must not be the last dimension.
        index: 1D index tensor of shape (N,) mapping source positions to output positions.
        source: Source tensor where size at `dim` equals N.

    Returns:
        Tensor of same shape as zeros.
    """
    if zeros.device.type != "mps":
        return zeros.index_reduce(dim, index, source, "mean", include_self=False)
```
`scatter_mean` is a new public utility, but there are existing unit tests for `foundry.utils.torch` (e.g. `tests/test_torch_utils.py`). Adding coverage that validates `scatter_mean` matches `index_reduce(..., reduce='mean')` on CPU (and that it doesn't mutate the `zeros` input) would help prevent regressions.
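A CPU test along these lines might look like the sketch below. `scatter_mean` is reconstructed here from the fragments quoted in this review so the example is self-contained (the real test would instead import it from the package); `force_fallback` is an added testing hook, not part of the PR, that exercises the MPS fallback path on CPU.

```python
import torch
from torch import Tensor


def scatter_mean(
    zeros: Tensor, dim: int, index: Tensor, source: Tensor, force_fallback: bool = False
) -> Tensor:
    """Reconstructed from the fragments quoted in this review.

    force_fallback is an addition for testing only: it runs the MPS
    fallback path on CPU so it can be compared against index_reduce.
    """
    if zeros.device.type != "mps" and not force_fallback:
        return zeros.index_reduce(dim, index, source, "mean", include_self=False)
    # MPS fallback: scatter_add the values, then divide by per-position counts.
    shape = [1] * source.ndim
    shape[dim] = index.shape[0]
    idx = index.view(shape).expand_as(source).contiguous()
    result = zeros.scatter_add(dim, idx, source)
    idx_count = idx[..., :1].contiguous()  # (..., N, 1)
    ones = torch.ones_like(idx_count, dtype=zeros.dtype)
    counts = torch.zeros_like(zeros[..., :1]).scatter_add(dim, idx_count, ones)
    return result / counts.clamp(min=1)


def test_scatter_mean_matches_index_reduce():
    zeros = torch.zeros(4, 3)
    index = torch.tensor([0, 1, 1, 3, 0])
    source = torch.randn(5, 3)

    expected = zeros.clone().index_reduce(0, index, source, "mean", include_self=False)
    result = scatter_mean(zeros, 0, index, source, force_fallback=True)

    torch.testing.assert_close(result, expected)
    # The docstring promises zeros is not modified in-place.
    assert torch.equal(zeros, torch.zeros(4, 3))
```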
```diff
-Q_IH = Q_IH / torch.sqrt(
-    torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
-)
+Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
```
This scaling computes `torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype)` every forward pass, which allocates and transfers a tensor to the device each time. Since `self.c` is a Python scalar, consider using a Python/math scalar (or a cached buffer created once in `__init__`) to avoid per-step allocation and device sync overhead.
Suggested change:

```diff
-Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
+Q_IH = Q_IH / (self.c ** 0.5)
```
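For reference, the cached-scalar alternative the review suggests could look roughly like this. It is a sketch with assumed names (`ScaledQuery`, `inv_sqrt_c` are illustrative, not the PR's code):

```python
import math

import torch


class ScaledQuery(torch.nn.Module):
    """Hypothetical module illustrating the scaling pattern under review."""

    def __init__(self, c: int):
        super().__init__()
        # Computed once at construction: no per-forward tensor allocation
        # or host-to-device transfer.
        self.inv_sqrt_c = 1.0 / math.sqrt(c)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # Multiplying by a Python float works on CPU, CUDA, and MPS,
        # and follows q's dtype automatically.
        return q * self.inv_sqrt_c


scale = ScaledQuery(c=64)
q = torch.randn(2, 8, 64)
out = scale(q)  # same values as q / sqrt(64)
```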
```diff
 if not self.use_deepspeed_evo or L <= 24:
-    Q_IH = Q_IH / torch.sqrt(
-        torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
-    )
+    Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
 # Attention
```
This scaling computes `torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype)` every forward pass, which allocates and transfers a tensor to the device each time. Since `self.c` is a Python scalar, consider using a Python/math scalar (or a cached buffer created once in `__init__`) to avoid per-step allocation and device sync overhead.
**macOS (Apple Silicon) Installation**

MPS support is available via a community fork. Install PyTorch first, then install directly from the fork:

```bash
pip install torch
pip install "rc-foundry[all] @ git+https://github.com/fnachon/foundry.git"
```
The README currently states that macOS/MPS support is only available "via a community fork" and instructs installing `rc-foundry` from `fnachon/foundry.git`. Since this PR adds MPS support directly, these instructions will become misleading after merge; consider updating this section to point to the official package/repo (or explicitly mark the fork install as a temporary pre-release option).
```diff
-if not torch.cuda.is_available():
+"""Fixture to check GPU availability for tests that require CUDA or MPS."""
+if not torch.cuda.is_available() and not torch.backends.mps.is_available():
 pytest.skip("GPU not available")
```
This fixture now skips when neither CUDA nor MPS is available, but the skip reason still says “GPU not available”. Consider updating the message to mention CUDA/MPS (or “accelerator”) to make test skips clearer on macOS runners.
Suggested change:

```diff
-pytest.skip("GPU not available")
+pytest.skip("CUDA or MPS not available")
```
Summary

- Works around ops unsupported on MPS (`index_reduce`, `masked_scatter_`) and non-contiguous tensor quirks
- Adds a `scatter_mean` utility with an MPS-compatible fallback path

Based on @fnachon's work in #257 with the following additional fixes:

- `block_utils.py` line 431: Added an if-else guard so CUDA/CPU keep `.expand()` (zero-copy) while MPS uses `.repeat()` (contiguous), avoiding an unnecessary O(L²) allocation on non-MPS backends

Supersedes #257.
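The expand/repeat trade-off described in the summary can be illustrated with a small sketch (not the PR's code):

```python
import torch

L = 256
row = torch.randn(L)

# expand: zero-copy view with a 0-stride batch dim (CUDA/CPU path)
expanded = row.unsqueeze(0).expand(L, L)

# repeat: materializes a full, contiguous (L, L) copy (MPS path, needed
# because some MPS kernels reject non-contiguous inputs)
repeated = row.unsqueeze(0).repeat(L, 1)

assert expanded.data_ptr() == row.data_ptr()  # shares storage: O(L) memory
assert not expanded.is_contiguous()
assert repeated.is_contiguous()               # fresh O(L^2) allocation
```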
Test plan

- Verified CUDA/CPU behavior is unchanged (`.contiguous()` is a no-op on contiguous tensors, expand path preserved)

🤖 Generated with Claude Code