diff --git a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py index 11998428..c31349e7 100644 --- a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py +++ b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py @@ -499,7 +499,6 @@ def atom_attention(self, A_I, S_I, Z_II, qbatch=32, kbatch=128): Q_IH = self.to_q(A_I) K_IH = self.to_k(A_I) V_IH = self.to_v(A_I) - B_IIH = self.to_b(self.ln_0(Z_II)) G_IH = self.to_g(A_I) if self.kq_norm: @@ -523,12 +522,18 @@ def atom_attention(self, A_I, S_I, Z_II, qbatch=32, kbatch=128): maskK = (indicesK < 0) | (indicesK > L - 1) indicesK = torch.clamp(indicesK, 0, L - 1) + # Compute pair bias within local windows only, to avoid applying + # LayerNorm to the full [1, L, L, c_pair] tensor which can exceed + # 2^32 elements for large systems and silently corrupt CUDA outputs. + Z_local = Z_II[:, indicesQ[:, :, None], indicesK[:, None, :]] + B_local = self.to_b(self.ln_0(Z_local)) + query_subset = Q_IH[:, indicesQ] key_subset = K_IH[:, indicesK] attn = torch.einsum("...ihd,...jhd->...ijh", query_subset, key_subset) attn = attn / (self.c**0.5) - attn += B_IIH[:, indicesQ[:, :, None], indicesK[:, None, :]] - 1e9 * ( + attn += B_local - 1e9 * ( maskQ[None, :, :, None, None] + maskK[None, :, None, :, None] ) attn = torch.softmax(attn, dim=-2)