From 81175fa1b7613e73ef82a2e46f10665e42dd3579 Mon Sep 17 00:00:00 2001 From: jonfunk21 Date: Tue, 3 Mar 2026 10:23:59 +0100 Subject: [PATCH 1/3] fix: bfloat16 promotion of index sentinel causes duplicate attention indices for sequences > 2048 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bug in extend_index_mask_with_neighbours (lines 407-412): inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device) # ← takes dtype from D_LL all_idx_row = torch.arange(L, device=device).expand(L, L) # int64 indices 0..2535 indices = torch.where(mask, all_idx_row, inf) # ← int64 + bfloat16 → promotes to bfloat16! When your training runs under bfloat16 AMP, D_LL (from torch.cdist(X_L, X_L)) is bfloat16. The torch.where has to reconcile int64 and bfloat16 → promotes to bfloat16. bfloat16 has only 8 significant bits, so it can only exactly represent integers up to 256. At seq_len > 2048, indices 2048, 2049, ..., 2063 all become 2048.0 in bfloat16. Guaranteed collisions, which now trigger the new assertion. --- models/rfd3/src/rfd3/model/layers/block_utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index aeac08c8..739b9ec3 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -404,22 +404,26 @@ def extend_index_mask_with_neighbours( k = min(k, L) assert mask.shape == (L, L) and D_LL.shape == (B, L, L) device = D_LL.device - inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device) + # float32 sentinel for index arithmetic: bfloat16/float16 cannot exactly represent + # large integer indices (e.g. bfloat16 loses precision above ~256), so we keep + # the index tensor in float32 regardless of D_LL.dtype to avoid spurious duplicates. + inf_idx = torch.tensor(float("inf"), dtype=torch.float32, device=device) + inf_dist = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device) # 1. 
Selection of sequence neighbours - all_idx_row = torch.arange(L, device=device).expand(L, L) - indices = torch.where(mask, all_idx_row, inf) # sentinel inf if not-forced - indices = indices.sort(dim=1)[0][:, :k] # (L, k) + all_idx_row = torch.arange(L, device=device, dtype=torch.float32).expand(L, L) + indices = torch.where(mask, all_idx_row, inf_idx) # sentinel inf if not-forced + indices = indices.sort(dim=1)[0][:, :k] # (L, k) float32 — exact for any practical L # 2. Find k-nn excluding forced indices - D_LL = torch.where(mask, inf, D_LL) + D_LL = torch.where(mask, inf_dist, D_LL) filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices # ... Reverse last axis s.t. best matched indices are last filler_idx = filler_idx.flip(dims=[-1]) # 3. Fill indices - to_fill = indices == inf + to_fill = indices == inf_idx to_fill = to_fill.expand_as(filler_idx) indices = indices.expand_as(filler_idx) indices = torch.where(to_fill, filler_idx, indices) From b37674d03269dbbe751b3e736611f32c710922df Mon Sep 17 00:00:00 2001 From: jonfunk21 Date: Tue, 3 Mar 2026 11:08:45 +0100 Subject: [PATCH 2/3] fix: cast X_L to float32 before cdist in create_attention_indices bfloat16 distances have step size ~16 in the range 2048-4096 (angstroms at high noise timesteps). This creates massive ties in D_LL, causing topk's tie-breaking to behave non-deterministically and potentially return non-unique indices, which then triggers the duplicate assertion in get_sparse_attention_indices. Casting X_L to float32 before cdist ensures stable, unique topk results regardless of the AMP precision used for the model forward pass. 
--- models/rfd3/src/rfd3/model/layers/block_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index 739b9ec3..436f710e 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -196,7 +196,7 @@ def create_attention_indices( X_L = torch.randn( (1, L, 3), device=device, dtype=torch.float ) # [L, 3] - random - D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances + D_LL = torch.cdist(X_L.float(), X_L.float(), p=2) # [B, L, L] - always float32: bfloat16 quantises distances severely at large t, causing topk ties and duplicate attention indices # Create attention indices using neighbour distances base_mask = ~f["unindexing_pair_mask"][ From c6adeaab7202a6255bd70c8ea283efed97bedb88 Mon Sep 17 00:00:00 2001 From: jonfunk21 Date: Tue, 3 Mar 2026 11:52:24 +0100 Subject: [PATCH 3/3] fix: nan_to_num on D_LL to handle NaN coords at high noise timesteps bfloat16 overflow at t=4608 (maximum diffusion noise) produces NaN coordinates in X_L. These propagate through cdist into D_LL. After torch.where masks forced positions to inf, the non-forced positions remain NaN. torch.topk with NaN inputs has undefined CUDA behaviour and returns duplicate indices, triggering the assertion. Replace NaN distances with a large finite sentinel (1e8) so topk always operates on well-defined values. 
--- models/rfd3/src/rfd3/model/layers/block_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index 436f710e..d8ff1637 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -197,6 +197,7 @@ def create_attention_indices( (1, L, 3), device=device, dtype=torch.float ) # [L, 3] - random D_LL = torch.cdist(X_L.float(), X_L.float(), p=2) # [B, L, L] - always float32: bfloat16 quantises distances severely at large t, causing topk ties and duplicate attention indices + D_LL = D_LL.nan_to_num(nan=1e8) # bfloat16 overflow at high noise t produces NaN coords → NaN distances → topk undefined behaviour → duplicate indices # Create attention indices using neighbour distances base_mask = ~f["unindexing_pair_mask"][