
Optimize attention computation with local window bias#254

Open
kaistroh wants to merge 1 commit into RosettaCommons:production from kaistroh:fixlargesystems_error

Conversation

@kaistroh

Refactor attention mechanism to compute pair bias within local windows, avoiding large tensor operations.

For systems with more than 16,384 atoms, RF3 produces partially unphysical structures.

In AttentionPairBiasDiffusion.atom_attention(), the pair representation tensor is processed as:

B_IIH = self.to_b(self.ln_0(Z_II))  # Z_II shape: [1, L, L, 16]

For large systems, the number of elements in Z_II exceeds 2^32. PyTorch's CUDA nn.LayerNorm kernel (at least in some versions) uses 32-bit unsigned integer offsets internally, which produces corrupt output once the number of elements exceeds 2^32.
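A quick back-of-the-envelope check of the threshold: at exactly 16,384 atoms, the full pair-bias input already contains 2^32 elements, which is where a 32-bit unsigned offset wraps around.

```python
# Element count of the full pair representation Z_II with shape [1, L, L, 16]
# at the reported failure threshold of L = 16,384 atoms.
L = 16_384
n_elements = 1 * L * L * 16
assert n_elements == 2**32  # 4,294,967,296 -- first size at which a uint32 offset overflows
```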

The fix gathers the local Z_II window before applying the LayerNorm, instead of computing self.to_b(self.ln_0(Z_II)) on the full [1, L, L, 16] tensor and then indexing into it. This is mathematically identical (LayerNorm normalizes over the last dimension independently per position), but keeps the tensor size at [1, nqbatch, 32, 128, 16].
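A minimal sketch of the equivalence the fix relies on (stand-in module names and shapes, not the actual RF3 code): because both LayerNorm and the subsequent linear projection act only on the last dimension, independently per (i, j) position, gathering the window first and normalizing afterwards gives the same result as normalizing the full tensor and indexing into it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, C = 64, 16                               # toy system size and pair-channel count
ln = nn.LayerNorm(C)                        # stand-in for self.ln_0
to_b = nn.Linear(C, 8)                      # stand-in for self.to_b

Z_II = torch.randn(1, L, L, C)              # full pair representation
starts = torch.arange(0, L, 32)             # hypothetical query-window starts
idx_i = starts[:, None] + torch.arange(32)  # [nq, 32] query indices per window

# Original path: normalize and project the full tensor, then index.
B_full = to_b(ln(Z_II))                     # [1, L, L, 8] -- overflows for huge L
B_ref = B_full[:, idx_i]                    # [1, nq, 32, L, 8]

# Fixed path: gather the local window first, then normalize and project.
Z_win = Z_II[:, idx_i]                      # [1, nq, 32, L, C] -- stays small
B_win = to_b(ln(Z_win))

# Identical up to floating-point noise.
assert torch.allclose(B_ref, B_win, atol=1e-6)
```

In the real fix the key axis is also restricted to the 128-wide local window rather than all L positions, which is what yields the [1, nqbatch, 32, 128, 16] intermediate.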

This fixes #238 and also saves a substantial amount of GPU memory, since the full B_IIH is never materialized.

Collaborator

@Ubiquinone-dot Ubiquinone-dot left a comment


lgtm



Development

Successfully merging this pull request may close these issues.

[BUG] RF3 generates unphysical structures for large complexes
