9 changes: 8 additions & 1 deletion csrc/ir/internal_nodes.cpp
@@ -3440,7 +3440,14 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
    }
  }
  if (attn_bias.defined()) {
    attn_bias = flattenBatchDims(attn_bias);
    // `attn_bias` is of shape [B, N, H, Q, K]. For triangle attention starting
    // nodes, B and N are adjacent in stride order and therefore can be
    // flattened with a `view`. For ending nodes, however, `B` and `N` are no
    // longer adjacent in stride order due to `mask` being transposed (see
    // test_alphafold3.py:test_triangle_attention). `attn_bias` can't be
    // `flattenBatchDims`ed with a `view`. Therefore, `contiguous()` is
    // required.
    attn_bias = flattenBatchDims(attn_bias.contiguous());
  }

  // 4D SDPA
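Note on the `contiguous()` requirement above: here is a minimal PyTorch sketch of the underlying layout issue. The shapes and the permuted layout are made up for illustration; this is not the nvFuser code path, it only shows why flattening two dimensions with a `view` requires them to be adjacent in stride order.

import torch

B, N, H, Q, K = 2, 3, 4, 5, 6

# Starting-node-like case: [B, N, H, Q, K] with default row-major strides.
# B and N are adjacent in stride order, so flattening them is a pure view.
bias = torch.randn(B, N, H, Q, K)
flat = bias.view(B * N, H, Q, K)  # succeeds without copying

# Ending-node-like case: the same logical shape built from a permuted layout,
# so B and N are no longer adjacent in stride order.
bias_perm = torch.randn(B, Q, H, N, K).permute(0, 3, 2, 1, 4)
assert bias_perm.shape == (B, N, H, Q, K)
try:
    bias_perm.view(B * N, H, Q, K)
except RuntimeError:
    print("view cannot merge B and N for this layout")

# contiguous() materializes a row-major copy, after which the view is legal.
flat2 = bias_perm.contiguous().view(B * N, H, Q, K)

The copy is made explicit with `contiguous()` so that `flattenBatchDims` can keep relying on a `view`.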
81 changes: 54 additions & 27 deletions tests/python/direct/test_alphafold3.py
@@ -5,9 +5,10 @@

# This file contains certain building blocks of the AlphaFold3 model.

from dataclasses import dataclass

import pytest
import torch
from dataclasses import dataclass
from enum import Enum, auto

from nvfuser_direct import FusionDefinition, DataType

@@ -16,22 +17,29 @@
class ModelConfig:
    c_z: int = 128
    c_hidden: int = 32
    n_heads: int = 2
    n_heads: int = 4


_DEFAULT_CONFIG = ModelConfig()
Review comment (PR author): @DejunL, what are the sizes people use in practice?

Reply: The Boltz reference model typically has:

  • B or batch_size of 1
  • N or token counts of {inference: however many are in the request, but typically ~100 to ~2000; training: {stage1: 256, stage2: 512, stage3: 768}}
  • c_z or token_z or hidden dimension of pair representation z is 128
  • num_heads is 4
  • head_dim or c_hidden is 32

These vary in other models, but probably stay within the same order of magnitude. Structure prediction models are typically small in weight count but large in activations, so hidden dimensions tend to be small, though I do see some models experiment with larger hidden dimensions.
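To make the "small weights, large activations" point concrete, here is a back-of-the-envelope sketch using the Boltz-like sizes quoted above. The config class name, the choice of the stage-3 token count, and the bf16 assumption are illustrative, not taken from any model code; the logits shape follows the attention_matrix comment in the test below.

from dataclasses import dataclass


@dataclass
class BoltzLikeConfig:
    batch_size: int = 1
    n_tokens: int = 768  # training stage 3
    c_z: int = 128
    n_heads: int = 4
    c_hidden: int = 32


cfg = BoltzLikeConfig()
bytes_per_elem = 2  # assume bf16 activations

# Pair representation z: [b, i, j, c_z]
z_bytes = cfg.batch_size * cfg.n_tokens**2 * cfg.c_z * bytes_per_elem

# Triangle-attention logits: [b, i, h, j, k]
logit_bytes = cfg.batch_size * cfg.n_tokens**3 * cfg.n_heads * bytes_per_elem

print(f"pair representation: {z_bytes / 2**20:.0f} MiB")  # ~144 MiB
print(f"attention logits:    {logit_bytes / 2**30:.1f} GiB")  # ~3.4 GiB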



def test_triangle_updates_outgoing():
    pass
class Direction(Enum):
    INCOMING = auto()  # aka ending node
    OUTGOING = auto()  # aka starting node


def test_triangle_updates_incoming():
@pytest.mark.parametrize(
    "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
)
def test_triangle_updates(direction):
    pass


# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-attention
def test_triangle_attention_starting_node():
@pytest.mark.parametrize(
    "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
)
def test_triangle_attention(direction):
    c_z, c_hidden, h = (
        _DEFAULT_CONFIG.c_z,
        _DEFAULT_CONFIG.c_hidden,
@@ -40,15 +48,33 @@ def test_triangle_attention_starting_node():

    with FusionDefinition() as fd:
        z_in = fd.define_tensor(
            shape=[-1, -1, -1, c_z], dtype=DataType.BFloat16
            shape=[-1, -1, -1, c_z],
            dtype=DataType.BFloat16,
            contiguity=True,
        )  # [b, i, j, c_z]
        w_q = fd.define_tensor(shape=[h * c_hidden, c_z], dtype=DataType.BFloat16)
        w_k = fd.define_tensor(shape=[h * c_hidden, c_z], dtype=DataType.BFloat16)
        w_b = fd.define_tensor(shape=[h, c_z], dtype=DataType.BFloat16)
        mask = fd.define_tensor(shape=[-1, -1, -1], dtype=DataType.Bool)  # [b, i, j]
        w_v = fd.define_tensor(shape=[h * c_hidden, c_z], dtype=DataType.BFloat16)
        w_g = fd.define_tensor(shape=[h * c_hidden, c_z], dtype=DataType.BFloat16)
        w_o = fd.define_tensor(shape=[c_z, h * c_hidden], dtype=DataType.BFloat16)
        if direction == Direction.INCOMING:
            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
        w_q = fd.define_tensor(
            shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_k = fd.define_tensor(
            shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_b = fd.define_tensor(shape=[h, c_z], dtype=DataType.BFloat16, contiguity=True)
        mask = fd.define_tensor(
            shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
        )  # [b, i, j]
        if direction == Direction.INCOMING:
            mask = fd.ops.permute(mask, [0, 2, 1])
        w_v = fd.define_tensor(
            shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_g = fd.define_tensor(
            shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_o = fd.define_tensor(
            shape=[c_z, h * c_hidden], dtype=DataType.BFloat16, contiguity=True
        )

        batch_size = fd.ops.size(z_in, 0)
        n_tokens = fd.ops.size(z_in, 1)
@@ -62,34 +88,37 @@ def test_triangle_attention_starting_node():
        k = fd.ops.linear(z_in, w_k)
        k_h = fd.ops.reshape(
            k, [batch_size, n_tokens, n_tokens, h, -1]
        )  # [b, i, j, h, c_hidden]
        k_h = fd.ops.permute(k_h, [0, 1, 3, 2, 4])  # [b, i, h, j, c_hidden]
        )  # [b, i, k, h, c_hidden]
        k_h = fd.ops.permute(k_h, [0, 1, 3, 2, 4])  # [b, i, h, k, c_hidden]

        b_h = fd.ops.linear(z_in, w_b)  # [b, i, j, h]
        b_h = fd.ops.permute(b_h, [0, 3, 1, 2])  # [b, h, i, j]
        b_h = fd.ops.linear(z_in, w_b)  # [b, j, k, h]
        b_h = fd.ops.permute(b_h, [0, 3, 1, 2])  # [b, h, j, k]
        b_h = fd.ops.broadcast_in_dim(
            b_h,
            shape=[batch_size, 1, h, n_tokens, n_tokens],
            broadcast_dims=[0, 2, 3, 4],
        )  # [b, 1, h, i, j]
        )  # [b, 1, h, j, k]

        mask = fd.ops.broadcast_in_dim(
            mask,
            shape=[batch_size, n_tokens, 1, 1, n_tokens],
            broadcast_dims=[0, 1, 4],
        )  # [b, i, 1, 1, j]
        )  # [b, i, 1, 1, k]

        v = fd.ops.linear(z_in, w_v)
        v_h = fd.ops.reshape(
            v, [batch_size, n_tokens, n_tokens, h, -1]
        )  # [b, i, j, h, c_hidden]
        v_h = fd.ops.permute(v_h, [0, 1, 3, 2, 4])  # [b, i, h, j, c_hidden]
        )  # [b, i, k, h, c_hidden]
        v_h = fd.ops.permute(v_h, [0, 1, 3, 2, 4])  # [b, i, h, k, c_hidden]

        # k_h.T: [b, i, h, c_hidden, k]
        # attention_matrix: [b, i, h, j, k]
        o_h, _, _, _ = fd.ops.sdpfa_fwd(
            q_h, k_h, v_h, bias=b_h, mask=mask, is_causal=False
        )  # [b, i, h, j, c_hidden]

        g = fd.ops.linear(z_in, w_g)
        g = fd.ops.sigmoid(g)
        g_h = fd.ops.reshape(
            g, [batch_size, n_tokens, n_tokens, h, -1]
        )  # [b, i, j, h, c_hidden]
@@ -104,6 +133,8 @@
        )  # [b, i, j, h * c_hidden]

        z_out = fd.ops.linear(o, w_o)  # [b, i, j, c_z]
        if direction == Direction.INCOMING:
            z_out = fd.ops.permute(z_out, [0, 2, 1, 3])
        fd.add_output(z_out)

    batch_size = 3
@@ -132,7 +163,3 @@ def test_triangle_attention_starting_node():
    )
    (z_out,) = fd.execute([z_in, w_q, w_k, w_b, mask, w_v, w_g, w_o])
    assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)


def test_triangle_attention_ending_node():
    pass
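For reference, below is a plain-PyTorch sketch of the computation the parametrized test builds, which may help with the shape bookkeeping. It is an eager-mode illustration only, not part of the PR and not the nvFuser API: the q projection (collapsed out of the diff above) is assumed to mirror k and v, weights follow the F.linear convention of [out_features, in_features], and the boolean mask is folded into the additive bias because F.scaled_dot_product_attention takes a single attn_mask, whereas fd.ops.sdpfa_fwd takes bias and mask separately. The incoming/ending-node case is handled by transposing the pair representation and the mask up front and transposing the output back, mirroring the permutes in the test.

import torch
import torch.nn.functional as F


def triangle_attention_reference(z, w_q, w_k, w_b, mask, w_v, w_g, w_o, h, incoming):
    # z: [b, i, j, c_z]; mask: [b, i, j] (True = keep); h = number of heads.
    if incoming:
        z = z.transpose(1, 2)  # ending nodes attend along the other edge of the triangle
        mask = mask.transpose(1, 2)
    b, n, _, _ = z.shape

    q = F.linear(z, w_q).view(b, n, n, h, -1).permute(0, 1, 3, 2, 4)  # [b, i, h, j, c]
    k = F.linear(z, w_k).view(b, n, n, h, -1).permute(0, 1, 3, 2, 4)  # [b, i, h, k, c]
    v = F.linear(z, w_v).view(b, n, n, h, -1).permute(0, 1, 3, 2, 4)  # [b, i, h, k, c]

    bias = F.linear(z, w_b).permute(0, 3, 1, 2).unsqueeze(1)  # [b, 1, h, j, k]
    # Broadcast the mask to [b, i, 1, 1, k] and fold it into the additive bias.
    attn_bias = bias + torch.where(mask[:, :, None, None, :], 0.0, float("-inf"))

    o = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_bias.to(q.dtype)
    )  # [b, i, h, j, c]

    g = torch.sigmoid(F.linear(z, w_g)).view(b, n, n, h, -1).permute(0, 1, 3, 2, 4)
    o = (g * o).permute(0, 1, 3, 2, 4).reshape(b, n, n, -1)  # [b, i, j, h * c]

    z_out = F.linear(o, w_o)  # [b, i, j, c_z]
    return z_out.transpose(1, 2) if incoming else z_out

In this sketch the mask broadcast to [b, i, 1, 1, k] masks key positions per row of the pair representation, which matches the broadcast_in_dim of `mask` in the fusion definition.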