Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions models/rfd3/src/rfd3/model/layers/block_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def create_attention_indices(
X_L = torch.randn(
(1, L, 3), device=device, dtype=torch.float
) # [L, 3] - random
D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances
D_LL = torch.cdist(X_L.float(), X_L.float(), p=2) # [B, L, L] - always float32: bfloat16 quantises distances severely at large t, causing topk ties and duplicate attention indices
D_LL = D_LL.nan_to_num(nan=1e8) # bfloat16 overflow at high noise t produces NaN coords → NaN distances → topk undefined behaviour → duplicate indices

# Create attention indices using neighbour distances
base_mask = ~f["unindexing_pair_mask"][
Expand Down Expand Up @@ -404,22 +405,26 @@ def extend_index_mask_with_neighbours(
k = min(k, L)
assert mask.shape == (L, L) and D_LL.shape == (B, L, L)
device = D_LL.device
inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device)
# float32 sentinel for index arithmetic: bfloat16/float16 cannot exactly represent
# large integer indices (e.g. bfloat16 loses precision above ~256), so we keep
# the index tensor in float32 regardless of D_LL.dtype to avoid spurious duplicates.
inf_idx = torch.tensor(float("inf"), dtype=torch.float32, device=device)
inf_dist = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device)

# 1. Selection of sequence neighbours
all_idx_row = torch.arange(L, device=device).expand(L, L)
indices = torch.where(mask, all_idx_row, inf) # sentinel inf if not-forced
indices = indices.sort(dim=1)[0][:, :k] # (L, k)
all_idx_row = torch.arange(L, device=device, dtype=torch.float32).expand(L, L)
indices = torch.where(mask, all_idx_row, inf_idx) # sentinel inf if not-forced
indices = indices.sort(dim=1)[0][:, :k] # (L, k) float32 — exact for any practical L

# 2. Find k-nn excluding forced indices
D_LL = torch.where(mask, inf, D_LL)
D_LL = torch.where(mask, inf_dist, D_LL)
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices

# ... Reverse last axis s.t. best matched indices are last
filler_idx = filler_idx.flip(dims=[-1])

# 3. Fill indices
to_fill = indices == inf
to_fill = indices == inf_idx
to_fill = to_fill.expand_as(filler_idx)
indices = indices.expand_as(filler_idx)
indices = torch.where(to_fill, filler_idx, indices)
Expand Down