Reference implementation for triangle updates#5732
Conversation
|
!test |
|
Review updated until commit 5502896 Description
|
| Relevant files | |||||
|---|---|---|---|---|---|
| Enhancement |
| ||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Removed validation in broadcast_in_dim_fn
|
Greptile SummaryThis PR implements a reference implementation for AlphaFold3 triangle updates and enhances the existing triangle attention with layer normalization. Key Changes:
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Test as test_triangle_updates
participant LN as layer_norm
participant G as gating
participant Ops as fd.ops
Test->>Ops: define_tensor (z_in, weights, mask)
Test->>LN: layer_norm(z_in, w_norm_in, b_norm_in)
LN->>Ops: cast, var_mean, sub, rsqrt, mul, add
LN-->>Test: normalized z_in
Test->>G: gating(z_in, w_p_in, z_in, w_g_in)
G->>Ops: linear(z, w_p) → p
G->>Ops: linear(z_in, w_g) → g
G->>Ops: sigmoid(g), mul(p, g)
G-->>Test: z [b, i, j, c_z*2]
Test->>Ops: broadcast_in_dim(mask)
Test->>Ops: where(mask, z, 0.0)
Test->>Ops: slice(z) → a, b
Test->>Ops: permute (OUTGOING/INCOMING)
Test->>Ops: matmul(a, b)
Test->>LN: layer_norm(z, w_norm_out, b_norm_out)
Test->>G: gating(z, w_p_out, z_in, w_g_out)
Test->>Ops: add_output(z)
|
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
382fd62 to
e1f4f08
Compare
|
!test |
| ------- | ||
| DataType | ||
| The data type of this tensor. |
There was a problem hiding this comment.
syntax: Docstring incorrectly states return type as 'DataType' but should be 'PrimDataType' to match the actual return type.
| ------- | |
| DataType | |
| The data type of this tensor. | |
| ------- | |
| PrimDataType | |
| The data type of this tensor. |
There was a problem hiding this comment.
This is fine given
Fuser/python/python_direct/enum.cpp
Line 20 in 21524df
|
!test |
|
!test |
|
!test |
Other changes:
cc @DejunL