diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index aeac08c8..d8ff1637 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -196,7 +196,8 @@ def create_attention_indices( X_L = torch.randn( (1, L, 3), device=device, dtype=torch.float ) # [L, 3] - random - D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances + D_LL = torch.cdist(X_L.float(), X_L.float(), p=2) # [B, L, L] - always float32: bfloat16 quantises distances severely at large t, causing topk ties and duplicate attention indices + D_LL = D_LL.nan_to_num(nan=1e8) # bfloat16 overflow at high noise t produces NaN coords → NaN distances → topk undefined behaviour → duplicate indices # Create attention indices using neighbour distances base_mask = ~f["unindexing_pair_mask"][ @@ -404,22 +405,26 @@ def extend_index_mask_with_neighbours( k = min(k, L) assert mask.shape == (L, L) and D_LL.shape == (B, L, L) device = D_LL.device - inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device) + # float32 sentinel for index arithmetic: bfloat16/float16 cannot exactly represent + # large integer indices (e.g. bfloat16 loses precision above ~256), so we keep + # the index tensor in float32 regardless of D_LL.dtype to avoid spurious duplicates. + inf_idx = torch.tensor(float("inf"), dtype=torch.float32, device=device) + inf_dist = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device) # 1. Selection of sequence neighbours - all_idx_row = torch.arange(L, device=device).expand(L, L) - indices = torch.where(mask, all_idx_row, inf) # sentinel inf if not-forced - indices = indices.sort(dim=1)[0][:, :k] # (L, k) + all_idx_row = torch.arange(L, device=device, dtype=torch.float32).expand(L, L) + indices = torch.where(mask, all_idx_row, inf_idx) # sentinel inf if not-forced + indices = indices.sort(dim=1)[0][:, :k] # (L, k) float32 — exact for any practical L # 2. Find k-nn excluding forced indices - D_LL = torch.where(mask, inf, D_LL) + D_LL = torch.where(mask, inf_dist, D_LL) filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices # ... Reverse last axis s.t. best matched indices are last filler_idx = filler_idx.flip(dims=[-1]) # 3. Fill indices - to_fill = indices == inf + to_fill = indices == inf_idx to_fill = to_fill.expand_as(filler_idx) indices = indices.expand_as(filler_idx) indices = torch.where(to_fill, filler_idx, indices)