
Add MPS (Apple Silicon) support for RFD3, RF3, and ProteinMPNN#260

Merged
Ubiquinone-dot merged 1 commit into production from jbutch/mps-support on Apr 4, 2026

Conversation

@Ubiquinone-dot
Collaborator

Summary

  • Adds MPS (Apple Silicon) support for inference across all three models (RFD3, RF3, ProteinMPNN)
  • Handles bfloat16 → float32 fallback, MPS-incompatible ops (index_reduce, masked_scatter_), and non-contiguous tensor quirks
  • Auto-detects MPS accelerator and enforces float32 precision
  • Adds scatter_mean utility with MPS-compatible fallback path
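The bfloat16 → float32 fallback amounts to a small device-aware dtype helper. A minimal sketch (the helper name here is illustrative; the PR's actual code routes this through device_of()):

```python
import torch

def amp_dtype(device: torch.device) -> torch.dtype:
    """Pick an autocast dtype by device type rather than hardcoding "cuda".

    bfloat16 is not supported on MPS, so fall back to float32 there.
    """
    if device.type == "mps":
        return torch.float32
    return torch.bfloat16
```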

Based on @fnachon's work in #257 with the following additional fixes:

  • block_utils.py line 431: Added if-else guard so CUDA/CPU keep .expand() (zero-copy) while MPS uses .repeat() (contiguous), avoiding an unnecessary O(L²) allocation on non-MPS backends
  • ruff 0.8.3 formatting: Reformatted files to match the repo's pinned ruff version so CI lint checks pass

Supersedes #257.
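The expand/repeat guard can be sketched roughly as follows (function and tensor names are hypothetical, not the PR's actual code):

```python
import torch

def broadcast_rows(x: torch.Tensor, L: int) -> torch.Tensor:
    """Broadcast a (1, C) row to (L, C).

    On CUDA/CPU, .expand() returns a zero-copy view. On MPS, some
    downstream ops (e.g. torch.where) require contiguous inputs, so
    .repeat() materializes the copy only where it is actually needed,
    avoiding the O(L*C) allocation on non-MPS backends.
    """
    if x.device.type == "mps":
        return x.repeat(L, 1)  # contiguous copy, MPS-safe
    return x.expand(L, -1)  # zero-copy view
```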

Test plan

  • Verify CUDA/CPU inference is unchanged (no behavioral diff — .contiguous() is a no-op on contiguous tensors, expand path preserved)
  • Run RFD3 inference on MPS device
  • Run RF3 inference on MPS device
  • Run ProteinMPNN inference on MPS device
  • Verify CI lint passes (ruff 0.8.3 format)

🤖 Generated with Claude Code

Adds support for running inference on Apple Silicon MPS devices.

Key changes:
- Handle bfloat16 -> float32 fallback on MPS (bfloat16 unsupported)
- Add scatter_mean utility with MPS fallback (index_reduce unsupported)
- Guard masked_scatter_/boolean indexing in block_utils with MPS paths
- Use expand (zero-copy) on CUDA/CPU, repeat (contiguous) on MPS for
  torch.where compatibility
- Add .contiguous() calls for scatter/gather ops that require it on MPS
- Replace hardcoded "cuda" in autocast dtype queries with device_of()
- Store torch.linalg.det result before torch.sign to avoid MPS in-place
  op issues on autograd graph leaves
- Auto-detect MPS accelerator and enforce float32 precision
- Add MPS installation instructions to README

Based on work by @fnachon in PR #257 with additional fixes:
- expand/repeat if-else to avoid O(L^2) allocation on non-MPS backends
- ruff 0.8.3 formatting to pass CI lint checks

Co-Authored-By: fnachon
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
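The det/sign reordering mentioned in the commit message can be sketched like this (names are illustrative; the PR applies the same pattern in alignment.py and frames.py):

```python
import torch

def rotation_sign(R: torch.Tensor) -> torch.Tensor:
    """Sign of det(R), with the determinant stored first.

    Taking torch.sign directly on the torch.linalg.det result tripped an
    MPS in-place-op error on autograd graph leaves, so the intermediate
    determinant is stored in its own tensor before the sign is taken.
    """
    det = torch.linalg.det(R)
    return torch.sign(det)
```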
Copilot AI review requested due to automatic review settings April 4, 2026 23:32
@Ubiquinone-dot Ubiquinone-dot enabled auto-merge (squash) April 4, 2026 23:36
@Ubiquinone-dot Ubiquinone-dot disabled auto-merge April 4, 2026 23:36
@Ubiquinone-dot Ubiquinone-dot merged commit 9fa9ba2 into production Apr 4, 2026
6 checks passed
@Ubiquinone-dot Ubiquinone-dot deleted the jbutch/mps-support branch April 4, 2026 23:36
Contributor

Copilot AI left a comment


Pull request overview

Adds Apple Silicon (MPS) support for inference across RFD3, RF3, and ProteinMPNN by introducing MPS-safe tensor ops/paths and auto-selecting the MPS accelerator with float32 precision where applicable.

Changes:

  • Add scatter_mean utility to replace index_reduce(..., reduce="mean") where MPS lacks support, and wire it into RFD3/RF3 embedding code.
  • Add MPS-aware precision/autocast and contiguity workarounds (avoid bf16 on MPS, fix non-contiguous tensors for scatter/where, etc.).
  • Expand device selection/docs/tests utilities to recognize MPS (accelerator auto-detection, fixtures, README, debug tooling, MPNN inference engine).
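The accelerator auto-detection with pinned float32 precision might look roughly like this (a sketch under assumed Fabric-style precision strings; the PR wires this into src/foundry/utils/ddp.py):

```python
import torch

def select_accelerator() -> tuple[str, str]:
    """Return an (accelerator, precision) pair for Fabric-style setup.

    MPS lacks bfloat16 support, so precision is pinned to float32 there;
    CUDA keeps mixed precision.
    """
    if torch.cuda.is_available():
        return "cuda", "bf16-mixed"
    if torch.backends.mps.is_available():
        return "mps", "32-true"
    return "cpu", "32-true"
```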

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| src/foundry/utils/torch.py | Adds scatter_mean with MPS fallback path. |
| src/foundry/utils/ddp.py | Auto-detects MPS and sets Fabric accelerator/precision. |
| src/foundry/utils/alignment.py | Refactors determinant/sign computation for rigid alignment. |
| src/foundry/testing/fixtures.py | Updates GPU fixture to treat MPS as an accelerator. |
| README.md | Adds macOS (Apple Silicon) installation + MPS notes. |
| models/rfd3/src/rfd3/testing/debug.py | Selects CUDA/MPS/CPU device for debug forward pass. |
| models/rfd3/src/rfd3/model/RFD3_diffusion_module.py | Removes leftover debug/comment-only lines. |
| models/rfd3/src/rfd3/model/layers/pairformer_layers.py | Avoids bf16/autocast on MPS; adjusts scaling tensor dtype. |
| models/rfd3/src/rfd3/model/layers/blocks.py | Replaces index_reduce token pooling with scatter_mean. |
| models/rfd3/src/rfd3/model/layers/block_utils.py | Adds MPS-safe scatter/gather replacements and contiguity fixes. |
| models/rfd3/src/rfd3/inference/symmetry/frames.py | Refactors determinant/sign computation for frame construction. |
| models/rf3/src/rf3/model/RF3.py | Uses module device type for autocast dtype selection. |
| models/rf3/src/rf3/model/layers/pairformer_layers.py | Replaces index_reduce token pooling with scatter_mean; avoids bf16 on MPS. |
| models/rf3/src/rf3/model/layers/attention.py | Uses tensor device type for autocast dtype selection. |
| models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py | Uses scatter_mean and imports it alongside device_of. |
| models/mpnn/src/mpnn/model/mpnn.py | Extends AMP dtype handling to include MPS device type. |
| models/mpnn/src/mpnn/inference_engines/mpnn.py | Adds MPS device selection and MPS seeding path. |


Comment on lines +231 to +241

```python
# Expand 1D index (N,) to match source shape (..., N, C)
shape = [1] * ndim
shape[dim] = index.shape[0]
idx = index.view(shape).expand_as(source)

# Sum source values into output positions
result = zeros.scatter_add(dim, idx, source)

# Count how many source values land in each output position.
# Take a single slice along the last dim to avoid allocating a full (N, C) ones tensor.
idx_count = idx[..., :1]  # (..., N, 1)
```

Copilot AI Apr 4, 2026


On the MPS fallback path, idx (and idx_count) are created via .expand_as(...) / slicing, which produces non-contiguous tensors. Elsewhere in this PR you explicitly call .contiguous() before scatter_add to avoid MPS issues; doing the same here would make the MPS fallback more reliable (or use .repeat/clone for indices on MPS).

Suggested change

```diff
 # Expand 1D index (N,) to match source shape (..., N, C)
 shape = [1] * ndim
 shape[dim] = index.shape[0]
-idx = index.view(shape).expand_as(source)
+# Materialize as contiguous for MPS scatter_add reliability.
+idx = index.view(shape).expand_as(source).contiguous()

 # Sum source values into output positions
 result = zeros.scatter_add(dim, idx, source)

 # Count how many source values land in each output position.
 # Take a single slice along the last dim to avoid allocating a full (N, C) ones tensor.
-idx_count = idx[..., :1]  # (..., N, 1)
+idx_count = idx[..., :1].contiguous()  # (..., N, 1)
```

Comment on lines +207 to +226
```python
def scatter_mean(zeros: Tensor, dim: int, index: Tensor, source: Tensor) -> Tensor:
    """Scatter-mean aggregation, with an MPS-compatible fallback.

    On non-MPS devices uses index_reduce (faster, in-place kernel).
    On MPS, index_reduce is not implemented so falls back to scatter_add + count.

    Equivalent to: zeros.index_reduce(dim, index, source, 'mean', include_self=False)

    Args:
        zeros: Pre-allocated zero tensor, shape (..., I, C). Will not be modified in-place.
        dim: Dimension to scatter along. Must not be the last dimension.
        index: 1D index tensor of shape (N,) mapping source positions to output positions.
        source: Source tensor where size at `dim` equals N.

    Returns:
        Tensor of same shape as zeros.
    """
    if zeros.device.type != "mps":
        return zeros.index_reduce(dim, index, source, "mean", include_self=False)
```

Copilot AI Apr 4, 2026


scatter_mean is a new public utility but there are existing unit tests for foundry.utils.torch (e.g. tests/test_torch_utils.py). Adding coverage that validates scatter_mean matches index_reduce(..., reduce='mean') on CPU (and that it doesn’t mutate the zeros input) would help prevent regressions.
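A sketch of such a test, using a local re-implementation of the MPS fallback path (scatter_add + count) so the example is self-contained; the real test would import scatter_mean from foundry.utils.torch instead:

```python
import torch

def scatter_mean_fallback(zeros, dim, index, source):
    """Local stand-in for the MPS fallback path (scatter_add + count)."""
    shape = [1] * source.ndim
    shape[dim] = index.shape[0]
    idx = index.view(shape).expand_as(source)
    result = torch.zeros_like(zeros).scatter_add(dim, idx, source)
    counts = torch.zeros_like(zeros).scatter_add(dim, idx, torch.ones_like(source))
    return result / counts.clamp(min=1)

def test_scatter_mean_matches_index_reduce():
    zeros = torch.zeros(4, 3)
    index = torch.tensor([0, 0, 2, 3, 3])
    source = torch.randn(5, 3)
    expected = torch.zeros(4, 3).index_reduce(
        0, index, source, "mean", include_self=False
    )
    before = zeros.clone()
    out = scatter_mean_fallback(zeros, 0, index, source)
    torch.testing.assert_close(out, expected)
    # the pre-allocated zeros tensor must not be mutated
    torch.testing.assert_close(zeros, before)
```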

```diff
-Q_IH = Q_IH / torch.sqrt(
-    torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
-)
+Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
```

Copilot AI Apr 4, 2026


This scaling computes torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype) every forward pass, which allocates and transfers a tensor to the device each time. Since self.c is a Python scalar, consider using a Python/math scalar (or a cached buffer created once in __init__) to avoid per-step allocation and device sync overhead.

Suggested change

```diff
-Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
+Q_IH = Q_IH / (self.c ** 0.5)
```

Comment on lines 265 to 267
```diff
 if not self.use_deepspeed_evo or L <= 24:
-    Q_IH = Q_IH / torch.sqrt(
-        torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
-    )
+    Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
 # Attention
```

Copilot AI Apr 4, 2026


This scaling computes torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype) every forward pass, which allocates and transfers a tensor to the device each time. Since self.c is a Python scalar, consider using a Python/math scalar (or a cached buffer created once in __init__) to avoid per-step allocation and device sync overhead.

Comment on lines +28 to +34
**macOS (Apple Silicon) Installation**

MPS support is available via a community fork. Install PyTorch first, then install directly from the fork:
```bash
pip install torch
pip install "rc-foundry[all] @ git+https://github.com/fnachon/foundry.git"
```

Copilot AI Apr 4, 2026


The README currently states that macOS/MPS support is only available “via a community fork” and instructs installing rc-foundry from fnachon/foundry.git. Since this PR adds MPS support directly, these instructions will become misleading after merge; consider updating this section to point to the official package/repo (or explicitly mark the fork install as a temporary pre-release option).

```diff
+"""Fixture to check GPU availability for tests that require CUDA or MPS."""
-if not torch.cuda.is_available():
+if not torch.cuda.is_available() and not torch.backends.mps.is_available():
     pytest.skip("GPU not available")
```

Copilot AI Apr 4, 2026


This fixture now skips when neither CUDA nor MPS is available, but the skip reason still says “GPU not available”. Consider updating the message to mention CUDA/MPS (or “accelerator”) to make test skips clearer on macOS runners.

Suggested change

```diff
-pytest.skip("GPU not available")
+pytest.skip("CUDA or MPS not available")
```

