
Commit e465cfd

Use expand/repeat if-else for MPS backward compat in block_utils
The unconditional .repeat(L, 1) on line 431 allocated an O(L^2) tensor on every backend. .expand() is zero-copy but produces a non-contiguous view that torch.where on MPS cannot handle. Guard with an if-else so that CUDA/CPU get back the original zero-copy expand path and MPS keeps the contiguous repeat.

Also run ruff 0.8.3 format on files that had been formatted with a different ruff version.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2e4f513 commit e465cfd
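As a rough standalone sketch of the pattern the commit message describes (the helper name and call site here are illustrative, not code from block_utils):

import torch

def row_index_grid(L: int, device: torch.device) -> torch.Tensor:
    # Build an (L, L) tensor whose every row is [0, 1, ..., L-1], intended as the
    # "value" argument of a later torch.where. Illustrative helper, not part of the repo.
    row = torch.arange(L, device=device).unsqueeze(0)  # (1, L)
    if device.type == "mps":
        # .repeat materializes a contiguous (L, L) copy; per the commit, MPS torch.where
        # mishandles non-contiguous inputs, so pay the O(L^2) allocation only there.
        return row.repeat(L, 1)
    # .expand broadcasts the single row to (L, L) as a zero-copy, non-contiguous view.
    return row.expand(L, L)

On CUDA and CPU the expand path never allocates the L x L buffer; only MPS pays for the copy.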

4 files changed

Lines changed: 27 additions & 16 deletions


models/rf3/src/rf3/model/layers/pairformer_layers.py

Lines changed: 4 additions & 4 deletions
@@ -202,7 +202,9 @@ def embed_features(C_L, D_LL, V_LL):
         A_I = scatter_mean(
             torch.zeros(A_I_shape, device=Q_L.device, dtype=Q_L.dtype),
             -2,
-            f["atom_to_token_map"].long(),  # [L], mapping from atom index to token index
+            f[
+                "atom_to_token_map"
+            ].long(),  # [L], mapping from atom index to token index
             processed_Q_L,  # (..., L, C_token)
         )  # (..., I, C_token)

@@ -261,9 +263,7 @@ def forward(
         B, L = B_IIH.shape[:2]

         if not self.use_deepspeed_evo or L <= 24:
-            Q_IH = Q_IH / torch.sqrt(
-                torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype)
-            )
+            Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
         # Attention
         A_IIH = torch.softmax(
             torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
models/rfd3/src/rfd3/model/layers/block_utils.py

Lines changed: 15 additions & 8 deletions
@@ -173,10 +173,12 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
     elif not torch.all(matches.sum(dim=-1) <= 1):
         raise ValueError("Did not find a scatter index for every atom")
     k_indices = matches.long().argmax(dim=-1)  # (B, L, a)
-    scatter_indices = k_indices.unsqueeze(-1).expand(
-        -1, -1, -1, P_LK_tgt.shape[-1]
-    ).contiguous()  # (B, L, a, c)
-    P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src.contiguous())
+    scatter_indices = (
+        k_indices.unsqueeze(-1).expand(-1, -1, -1, P_LK_tgt.shape[-1]).contiguous()
+    )  # (B, L, a, c)
+    P_LK_tgt = P_LK_tgt.scatter_add(
+        dim=2, index=scatter_indices, src=P_LA_src.contiguous()
+    )
     return P_LK_tgt


@@ -428,10 +430,15 @@ def extend_index_mask_with_neighbours(
     inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device)

     # 1. Selection of sequence neighbours
-    # Use .repeat() instead of .expand() to produce a contiguous tensor — MPS does
-    # not handle non-contiguous inputs to torch.where correctly.
-    all_idx_row = torch.arange(L, device=device).unsqueeze(0).repeat(L, 1)
-    indices = torch.where(mask.contiguous(), all_idx_row, inf)  # sentinel inf if not-forced
+    # MPS does not handle non-contiguous inputs to torch.where correctly,
+    # so use .repeat() (allocates) there; .expand() (zero-copy view) elsewhere.
+    if device.type == "mps":
+        all_idx_row = torch.arange(L, device=device).unsqueeze(0).repeat(L, 1)
+    else:
+        all_idx_row = torch.arange(L, device=device).unsqueeze(0).expand(L, L)
+    indices = torch.where(
+        mask.contiguous(), all_idx_row, inf
+    )  # sentinel inf if not-forced
     indices = indices.sort(dim=1)[0][:, :k]  # (L, k)

     # 2. Find k-nn excluding forced indices
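For context, the sentinel-inf selection that this hunk feeds into torch.where can be sketched on its own. This is a toy reimplementation for illustration, not the repo's extend_index_mask_with_neighbours.

import torch

def first_k_true_indices(mask: torch.Tensor, k: int) -> torch.Tensor:
    # mask: (L, L) boolean. For each row, return the column indices of the first
    # k True entries; rows with fewer than k True entries are padded with inf.
    L = mask.shape[-1]
    device = mask.device
    inf = torch.tensor(float("inf"), device=device)
    cols = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(0).expand(L, L)
    # Masked-out positions become the inf sentinel, so sorting each row pushes
    # them past the real indices and the first k survive the slice.
    candidates = torch.where(mask, cols, inf)
    return candidates.sort(dim=1)[0][:, :k]  # (L, k)

On MPS the expanded cols view would be swapped for a .repeat copy, exactly as the guarded code in the hunk above does.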

models/rfd3/src/rfd3/model/layers/pairformer_layers.py

Lines changed: 1 addition & 3 deletions
@@ -61,9 +61,7 @@ def forward(
         B, L = B_IIH.shape[:2]

         if not self.use_deepspeed_evo or L <= 24:
-            Q_IH = Q_IH / torch.sqrt(
-                torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype)
-            )
+            Q_IH = Q_IH / torch.sqrt(torch.tensor(self.c).to(Q_IH.device, Q_IH.dtype))
         # Attention
         A_IIH = torch.softmax(
             torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2

src/foundry/utils/torch.py

Lines changed: 7 additions & 1 deletion
@@ -1,6 +1,12 @@
 """General convenience utilities for PyTorch."""

-__all__ = ["map_to", "assert_no_nans", "assert_shape", "assert_same_shape", "scatter_mean"]
+__all__ = [
+    "map_to",
+    "assert_no_nans",
+    "assert_shape",
+    "assert_same_shape",
+    "scatter_mean",
+]

 import time
 import warnings
