diff --git a/README.md b/README.md index 7f4decc7..f5d7a44e 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,22 @@ pip install "rc-foundry[all]" > [!NOTE] > Use `pip` (not `uv`) for XPU installs since UV re-resolves dependencies and may replace your XPU torch with the standard PyPI version. +**macOS (Apple Silicon) Installation** + +MPS support is available via a community fork. Install PyTorch first, then install directly from the fork: +```bash +pip install torch +pip install "rc-foundry[all] @ git+https://github.com/fnachon/foundry.git" +``` + +All three models — **RFD3**, **RF3**, and **ProteinMPNN/LigandMPNN** — run on Apple Silicon MPS. + +> [!NOTE] +> - The `rf3` extra (cuEquivariance) is Linux-only and is automatically skipped on macOS. +> - Use `float32` precision — `bfloat16` is not supported on MPS. The MPS accelerator is selected and float32 precision is enforced automatically. +> - Inference only; multi-GPU training is not supported on MPS. +> - For `rf3 fold`, pass an absolute path to the input CIF file. + **Downloading weights** Models can be downloaded to a target folder with: ``` foundry install base-models --checkpoint-dir diff --git a/models/mpnn/src/mpnn/inference_engines/mpnn.py b/models/mpnn/src/mpnn/inference_engines/mpnn.py index 8c06b35e..9f53c1d5 100644 --- a/models/mpnn/src/mpnn/inference_engines/mpnn.py +++ b/models/mpnn/src/mpnn/inference_engines/mpnn.py @@ -74,6 +74,8 @@ def __init__( self.device = torch.device("cuda") elif hasattr(torch, "xpu") and torch.xpu.is_available(): self.device = torch.device("xpu") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") else: self.device = torch.device("cpu") @@ -258,6 +260,8 @@ def run( np.random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + elif torch.backends.mps.is_available(): + torch.mps.manual_seed(seed) # Run the batches for this input. for batch_idx in range(inference_input.input_dict["number_of_batches"]): diff --git a/models/mpnn/src/mpnn/model/mpnn.py b/models/mpnn/src/mpnn/model/mpnn.py index 75e84847..19bd0632 100644 --- a/models/mpnn/src/mpnn/model/mpnn.py +++ b/models/mpnn/src/mpnn/model/mpnn.py @@ -1364,7 +1364,7 @@ def decode_auto_regressive( # precision settings. This works because the W_out layer is a linear # layer, which has predictable dtype behavior with AMP. device = input_features["residue_mask"].device - if device.type in ("cuda", "cpu") and torch.is_autocast_enabled( + if device.type in ("cuda", "cpu", "mps") and torch.is_autocast_enabled( device_type=device.type ): output_dtype = torch.get_autocast_dtype(device_type=device.type) diff --git a/models/rf3/src/rf3/model/RF3.py b/models/rf3/src/rf3/model/RF3.py index 28c6d9fa..2fa4281b 100644 --- a/models/rf3/src/rf3/model/RF3.py +++ b/models/rf3/src/rf3/model/RF3.py @@ -16,6 +16,7 @@ from torch import nn from foundry.training.checkpoint import create_custom_forward +from foundry.utils.torch import device_of """ Shape Annotation Glossary: @@ -148,7 +149,7 @@ def forward( """ # Cast features to lower precision if autocast is enabled if torch.is_autocast_enabled(): - autocast_dtype = torch.get_autocast_dtype("cuda") + autocast_dtype = torch.get_autocast_dtype(device_of(self).type) for x in [ "msa_stack", "profile", @@ -382,7 +383,7 @@ def forward( """ # Cast features to lower precision if autocast is enabled if torch.is_autocast_enabled(): - autocast_dtype = torch.get_autocast_dtype("cuda") + autocast_dtype = torch.get_autocast_dtype(device_of(self).type) for x in [ "msa_stack", "profile", diff --git a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py index 11998428..abe0c37a 100644 --- a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py +++ b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py @@ -12,7 +12,7 @@ from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage from foundry.training.checkpoint import activation_checkpointing -from foundry.utils.torch import device_of +from foundry.utils.torch import device_of, scatter_mean class AtomAttentionEncoderDiffusion(nn.Module): @@ -241,16 +241,11 @@ def embed_atom_feats(R_L, C_L, D_LL, V_LL, P_LL, tok_idx): # Ensure dtype consistency for index_reduce processed_Q_L = processed_Q_L.to(Q_L.dtype) - A_I = ( - torch.zeros(A_I_shape, device=Q_L.device, dtype=Q_L.dtype) - .index_reduce( - -2, - f["atom_to_token_map"].long(), - processed_Q_L, - "mean", - include_self=False, - ) - .clone() + A_I = scatter_mean( + torch.zeros(A_I_shape, device=Q_L.device, dtype=Q_L.dtype), + -2, + f["atom_to_token_map"].long(), + processed_Q_L, ) return A_I, Q_L, C_L, P_LL @@ -427,7 +422,7 @@ def forward( # zero out layer norms for the key and query return self.atom_attention(A_I, S_I, Z_II) - if self.use_deepspeed_evo or self.force_bfloat16: + if (self.use_deepspeed_evo or self.force_bfloat16) and A_I.device.type != "mps": A_I = A_I.to(torch.bfloat16) assert len(A_I.shape) == 3, f"(Diffusion batch, I, C_a) but got {A_I.shape}" diff --git a/models/rf3/src/rf3/model/layers/attention.py b/models/rf3/src/rf3/model/layers/attention.py index 2957012c..1fccf190 100644 --- a/models/rf3/src/rf3/model/layers/attention.py +++ b/models/rf3/src/rf3/model/layers/attention.py @@ -101,7 +101,7 @@ def _forward_cuequivariance(self, pair, bias): """cuEquivariance triangle attention implementation.""" # Handle autocast conversion if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(pair.device.type) pair = pair.to(dtype=dtype) bias = bias.to(dtype=dtype) @@ -288,7 +288,7 @@ def _forward_cuequivariance( # Handle autocast conversion # (Use bfloat16 for optimal performance) if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(pair.device.type) pair = pair.to(dtype=dtype) assert ( diff --git a/models/rf3/src/rf3/model/layers/pairformer_layers.py b/models/rf3/src/rf3/model/layers/pairformer_layers.py index 8c2c81c4..4d155932 100644 --- a/models/rf3/src/rf3/model/layers/pairformer_layers.py +++ b/models/rf3/src/rf3/model/layers/pairformer_layers.py @@ -21,6 +21,7 @@ from foundry.model.layers.blocks import Dropout from foundry.training.checkpoint import activation_checkpointing +from foundry.utils.torch import scatter_mean class AtomAttentionEncoderPairformer(nn.Module): @@ -198,17 +199,12 @@ def embed_features(C_L, D_LL, V_LL): # Ensure dtype consistency for index_reduce processed_Q_L = processed_Q_L.to(Q_L.dtype) - A_I = torch.zeros( - A_I_shape, device=Q_L.device, dtype=Q_L.dtype - ).index_reduce( - -2, # Operate on the second-to-last dimension (the atom dimension) - f[ - "atom_to_token_map" - ].long(), # [L], mapping from atom index to token index. Must be a torch.int64 or torch.int32 tensor. - processed_Q_L, # [L, C_atom] -> [L, C_token] - "mean", - include_self=False, # Do not use the original values in A_I (all zeros) when aggregating - ) # [L, C_atom] -> [I, C_token] + A_I = scatter_mean( + torch.zeros(A_I_shape, device=Q_L.device, dtype=Q_L.dtype), + -2, + f["atom_to_token_map"].long(), # [L], mapping from atom index to token index + processed_Q_L, # (..., L, C_token) + ) # (..., I, C_token) return A_I, Q_L, C_L, P_LL @@ -253,7 +249,7 @@ def forward( assert S_I is None A_I = self.ln_1(A_I) - if self.use_deepspeed_evo or self.force_bfloat16: + if (self.use_deepspeed_evo or self.force_bfloat16) and A_I.device.type != "mps": A_I = A_I.to(torch.bfloat16) Q_IH = self.to_q(A_I) # / np.sqrt(self.c) @@ -266,7 +262,7 @@ def forward( if not self.use_deepspeed_evo or L <= 24: Q_IH = Q_IH / torch.sqrt( - torch.tensor(self.c).to(Q_IH.device, torch.bfloat16) + torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype) ) # Attention A_IIH = torch.softmax( diff --git a/models/rfd3/src/rfd3/inference/symmetry/frames.py b/models/rfd3/src/rfd3/inference/symmetry/frames.py index 7907b87f..672d4401 100644 --- a/models/rfd3/src/rfd3/inference/symmetry/frames.py +++ b/models/rfd3/src/rfd3/inference/symmetry/frames.py @@ -178,7 +178,8 @@ def _mean_along_dim(X, dim): R = U @ V if is_torch: F = torch.eye(3, 3, device=R.device).expand(B, 3, 3).clone() - F[..., -1, -1] = torch.sign(torch.linalg.det(R)) + det = torch.linalg.det(R) + F[..., -1, -1] = torch.sign(det) else: F = np.broadcast_to(np.eye(3, 3), (B, 3, 3)).copy() F[..., -1, -1] = np.sign(np.linalg.det(R)) diff --git a/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py index b6512d02..9d94c5da 100644 --- a/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py +++ b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py @@ -239,8 +239,6 @@ def forward( Q_L = self.encoder(Q_L, C_L, P_LL, indices=f["attn_indices"]) A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx) - # Debug chunked parameters - # ... Run forward with recycling recycled_features = self.forward_with_recycle( n_recycle, @@ -340,7 +338,6 @@ def process_( ), full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"), ) - # ... Decoder readout # Check if using chunked P_LL mode diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index aeac08c8..5af459d1 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -54,6 +54,11 @@ def build_valid_mask( return valid_mask +def _atom_flat_idx(valid_mask: torch.Tensor) -> torch.Tensor: + """Return the 1-D indices of valid atoms in the flattened (n_tokens * A) grid.""" + return valid_mask.flatten().nonzero(as_tuple=False).squeeze(1) + + def ungroup_atoms(Q_L, valid_mask): """ Args @@ -67,11 +72,20 @@ def ungroup_atoms(Q_L, valid_mask): """ B, n_atoms, c = Q_L.shape n_tokens, A = valid_mask.shape - Q_IA = torch.zeros(B, n_tokens, A, c, dtype=Q_L.dtype, device=Q_L.device) - mask4d = valid_mask.unsqueeze(0).unsqueeze(-1) # (1, n_tok, A, 1) - mask4d = mask4d.expand(B, -1, -1, c) # (B, n_tok, A, c) - Q_IA.masked_scatter_(mask4d, Q_L) - return Q_IA + if Q_L.device.type == "mps": + # masked_scatter_ with non-contiguous masks is unreliable on MPS; + # use scatter with integer indices instead. + flat_idx = _atom_flat_idx(valid_mask) # (n_atoms,) + idx = flat_idx.view(1, -1, 1).expand(B, -1, c) # (B, n_atoms, c) + Q_IA = torch.zeros(B, n_tokens * A, c, dtype=Q_L.dtype, device=Q_L.device) + Q_IA = Q_IA.scatter(1, idx, Q_L) + return Q_IA.reshape(B, n_tokens, A, c) + else: + Q_IA = torch.zeros(B, n_tokens, A, c, dtype=Q_L.dtype, device=Q_L.device) + mask4d = valid_mask.unsqueeze(0).unsqueeze(-1) # (1, n_tok, A, 1) + mask4d = mask4d.expand(B, -1, -1, c) # (B, n_tok, A, c) + Q_IA.masked_scatter_(mask4d, Q_L) + return Q_IA def group_atoms(Q_IA: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor: @@ -85,10 +99,17 @@ def group_atoms(Q_IA: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor: ------- Q_L : (B, n_atoms, c) flattened real atoms, order preserved """ - B, _, _, c = Q_IA.shape - mask4d = valid_mask.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, c) # (B,n_tok,A,c) - Q_L = Q_IA[mask4d].view(B, -1, c) # restore 2‑D shape - return Q_L + B, n_tok, A, c = Q_IA.shape + if Q_IA.device.type == "mps": + # Boolean indexing with non-contiguous expanded masks is unreliable on MPS; + # use integer index gather instead. + flat_idx = _atom_flat_idx(valid_mask) # (n_atoms,) + Q_L = Q_IA.reshape(B, n_tok * A, c)[:, flat_idx, :] + return Q_L.contiguous() + else: + mask4d = valid_mask.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, c) + Q_L = Q_IA[mask4d].view(B, -1, c) + return Q_L def group_pair(P_IAA, valid_mask): @@ -137,9 +158,9 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices): # Handle case when indices and P_LA don't have batch dimensions B, L, k = P_LK_indices.shape if P_LA_indices.ndim == 2: - P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1) + P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1).contiguous() if P_LA_src.ndim == 3: - P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1) + P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1).contiguous() assert ( P_LA_src.shape[-1] == P_LK_tgt.shape[-1] ), "Channel dims do not match, got: {} vs {}".format( @@ -154,8 +175,8 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices): k_indices = matches.long().argmax(dim=-1) # (B, L, a) scatter_indices = k_indices.unsqueeze(-1).expand( -1, -1, -1, P_LK_tgt.shape[-1] - ) # (B, L, a, c) - P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src) + ).contiguous() # (B, L, a, c) + P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src.contiguous()) return P_LK_tgt @@ -169,8 +190,8 @@ def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: k = idx.shape[-1] # (B, L, 1, C) → stride-0 along k → (B, L, k, C) - src = values.unsqueeze(2).expand(-1, -1, k, -1) - idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (B, L, k, C) + src = values.unsqueeze(2).expand(-1, -1, k, -1).contiguous() + idx = idx.unsqueeze(-1).expand(-1, -1, -1, C).contiguous() # (B, L, k, C) return torch.gather(src, 1, idx) # dim=1 is the L-axis @@ -350,7 +371,7 @@ def build_index_mask( # Exclude tokens which are partially filled (L, I) n_query_per_token = torch.zeros((L, I), device=device).float() n_query_per_token.scatter_add_( - 1, tok_idx.long()[None, :].expand(L, -1), mask.float() + 1, tok_idx.long()[None, :].expand(L, -1).contiguous(), mask.float() ) # Find mask for the atoms for which the number of keys @@ -407,12 +428,14 @@ def extend_index_mask_with_neighbours( inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device) # 1. Selection of sequence neighbours - all_idx_row = torch.arange(L, device=device).expand(L, L) - indices = torch.where(mask, all_idx_row, inf) # sentinel inf if not-forced + # Use .repeat() instead of .expand() to produce a contiguous tensor — MPS does + # not handle non-contiguous inputs to torch.where correctly. + all_idx_row = torch.arange(L, device=device).unsqueeze(0).repeat(L, 1) + indices = torch.where(mask.contiguous(), all_idx_row, inf) # sentinel inf if not-forced indices = indices.sort(dim=1)[0][:, :k] # (L, k) # 2. Find k-nn excluding forced indices - D_LL = torch.where(mask, inf, D_LL) + D_LL = torch.where(mask.contiguous(), inf, D_LL) filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices # ... Reverse last axis s.t. best matched indices are last @@ -420,8 +443,8 @@ def extend_index_mask_with_neighbours( # 3. Fill indices to_fill = indices == inf - to_fill = to_fill.expand_as(filler_idx) - indices = indices.expand_as(filler_idx) + to_fill = to_fill.expand_as(filler_idx).contiguous() + indices = indices.expand_as(filler_idx).contiguous() indices = torch.where(to_fill, filler_idx, indices) return indices.long() # (B, L, k) @@ -437,7 +460,7 @@ def get_sparse_attention_indices( # Sort and assert no duplicates (optional but good practise) indices, _ = torch.sort(indices, dim=-1) - if (indices[..., 1:] == indices[..., :-1]).any(): + if indices.device.type != "mps" and (indices[..., 1:] == indices[..., :-1]).any(): raise AssertionError("Tensor has duplicate elements along the last dimension.") assert ( diff --git a/models/rfd3/src/rfd3/model/layers/blocks.py b/models/rfd3/src/rfd3/model/layers/blocks.py index eaf08093..9963d6ec 100644 --- a/models/rfd3/src/rfd3/model/layers/blocks.py +++ b/models/rfd3/src/rfd3/model/layers/blocks.py @@ -30,6 +30,7 @@ from foundry import DISABLE_CHECKPOINTING from foundry.common import exists +from foundry.utils.torch import scatter_mean logger = logging.getLogger(__name__) @@ -213,16 +214,11 @@ def forward(self, R_L, tok_idx): self.c_token, ) Q_L = self.linear(R_L) - A_I = ( - torch.zeros(A_I_shape, device=R_L.device, dtype=Q_L.dtype) - .index_reduce( - -2, - tok_idx.long(), - Q_L, - "mean", - include_self=False, - ) - .clone() + A_I = scatter_mean( + torch.zeros(A_I_shape, device=R_L.device, dtype=Q_L.dtype), + -2, + tok_idx.long(), + Q_L, ) return A_I diff --git a/models/rfd3/src/rfd3/model/layers/pairformer_layers.py b/models/rfd3/src/rfd3/model/layers/pairformer_layers.py index e7f327e8..eec38618 100644 --- a/models/rfd3/src/rfd3/model/layers/pairformer_layers.py +++ b/models/rfd3/src/rfd3/model/layers/pairformer_layers.py @@ -49,7 +49,7 @@ def forward( assert S_I is None A_I = self.ln_1(A_I) - if self.use_deepspeed_evo or self.force_bfloat16: + if (self.use_deepspeed_evo or self.force_bfloat16) and A_I.device.type != "mps": A_I = A_I.to(torch.bfloat16) Q_IH = self.to_q(A_I) # / np.sqrt(self.c) @@ -62,7 +62,7 @@ def forward( if not self.use_deepspeed_evo or L <= 24: Q_IH = Q_IH / torch.sqrt( - torch.tensor(self.c).to(Q_IH.device, torch.bfloat16) + torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype) ) # Attention A_IIH = torch.softmax( @@ -116,8 +116,10 @@ def __init__( @activation_checkpointing def forward(self, S_I, Z_II): + _device = device_of(self) + _use_autocast = _device.type != "mps" with torch.amp.autocast( - device_type=device_of(self).type, enabled=True, dtype=torch.bfloat16 + device_type=_device.type, enabled=_use_autocast, dtype=torch.bfloat16 ): Z_II = Z_II + self.z_transition(Z_II) if S_I is not None: diff --git a/models/rfd3/src/rfd3/testing/debug.py b/models/rfd3/src/rfd3/testing/debug.py index 856ad890..50d6ee28 100755 --- a/models/rfd3/src/rfd3/testing/debug.py +++ b/models/rfd3/src/rfd3/testing/debug.py @@ -51,7 +51,12 @@ def forward(example, trainer, model, is_inference=is_inference): network_input = trainer._assemble_network_inputs(example) # Forward pass - device = "cuda:0" + if torch.cuda.is_available(): + device = "cuda:0" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" def _inmap(path, x): if hasattr(x, "cpu") and path != ("f", "msa_stack"): diff --git a/src/foundry/testing/fixtures.py b/src/foundry/testing/fixtures.py index fee9a281..9e67f282 100644 --- a/src/foundry/testing/fixtures.py +++ b/src/foundry/testing/fixtures.py @@ -8,8 +8,8 @@ @pytest.fixture(scope="session") def gpu(): - """Fixture to check GPU availability for tests that require CUDA.""" - if not torch.cuda.is_available(): + """Fixture to check GPU availability for tests that require CUDA or MPS.""" + if not torch.cuda.is_available() and not torch.backends.mps.is_available(): pytest.skip("GPU not available") return True diff --git a/src/foundry/utils/alignment.py b/src/foundry/utils/alignment.py index a2e5d740..c5d066dd 100644 --- a/src/foundry/utils/alignment.py +++ b/src/foundry/utils/alignment.py @@ -69,7 +69,8 @@ def weighted_rigid_align( ) ) - F[..., -1, -1] = torch.sign(torch.linalg.det(R)) + det = torch.linalg.det(R) + F[..., -1, -1] = torch.sign(det) R = U @ F @ V X_gt_L = X_gt_L - u_X_gt.unsqueeze(-2) diff --git a/src/foundry/utils/ddp.py b/src/foundry/utils/ddp.py index 53a37bfa..fc9983fa 100644 --- a/src/foundry/utils/ddp.py +++ b/src/foundry/utils/ddp.py @@ -39,9 +39,14 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig): elif hasattr(torch, "xpu") and torch.xpu.is_available(): logger.info("Intel XPU detected - using XPU accelerator") cfg.trainer.accelerator = "xpu" + elif torch.backends.mps.is_available(): + logger.info("Apple MPS detected - using MPS accelerator with float32 precision") + cfg.trainer.accelerator = "mps" + if hasattr(cfg.trainer, "precision"): + cfg.trainer.precision = "32-true" else: logger.error( - "No GPUs/XPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?" + "No GPUs/XPUs/MPS available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?" ) cfg.trainer.accelerator = "cpu" cfg.trainer.devices_per_node = 1 diff --git a/src/foundry/utils/torch.py b/src/foundry/utils/torch.py index ab899eb7..16fbe4b7 100755 --- a/src/foundry/utils/torch.py +++ b/src/foundry/utils/torch.py @@ -1,6 +1,6 @@ """General convenience utilities for PyTorch.""" -__all__ = ["map_to", "assert_no_nans", "assert_shape", "assert_same_shape"] +__all__ = ["map_to", "assert_no_nans", "assert_shape", "assert_same_shape", "scatter_mean"] import time import warnings @@ -198,6 +198,48 @@ def assert_same_shape(tensor: Tensor, ref_tensor: Tensor) -> None: assert_shape(tensor, ref_tensor.shape) +def scatter_mean(zeros: Tensor, dim: int, index: Tensor, source: Tensor) -> Tensor: + """Scatter-mean aggregation, with an MPS-compatible fallback. + + On non-MPS devices uses index_reduce (faster, in-place kernel). + On MPS, index_reduce is not implemented so falls back to scatter_add + count. + + Equivalent to: zeros.index_reduce(dim, index, source, 'mean', include_self=False) + + Args: + zeros: Pre-allocated zero tensor, shape (..., I, C). Will not be modified in-place. + dim: Dimension to scatter along. Must not be the last dimension. + index: 1D index tensor of shape (N,) mapping source positions to output positions. + source: Source tensor where size at `dim` equals N. + + Returns: + Tensor of same shape as zeros. + """ + if zeros.device.type != "mps": + return zeros.index_reduce(dim, index, source, "mean", include_self=False) + + ndim = source.dim() + if dim < 0: + dim = ndim + dim + + # Expand 1D index (N,) to match source shape (..., N, C) + shape = [1] * ndim + shape[dim] = index.shape[0] + idx = index.view(shape).expand_as(source) + + # Sum source values into output positions + result = zeros.scatter_add(dim, idx, source) + + # Count how many source values land in each output position. + # Take a single slice along the last dim to avoid allocating a full (N, C) ones tensor. + idx_count = idx[..., :1] # (..., N, 1) + ones = torch.ones_like(source[..., :1]) # (..., N, 1) + count = torch.zeros(*zeros.shape[:-1], 1, device=zeros.device, dtype=zeros.dtype) + count = count.scatter_add(dim, idx_count, ones) # (..., I, 1) + + return result / count.clamp(min=1) + + def device_of(obj: Any) -> torch.device: """Get the device of a PyTorch object, e.g. a `nn.Module` or a `Tensor`.""" if hasattr(obj, "device"):