Reference implementation for triangle updates #5732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
!test

Review updated until commit 5502896

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: removed validation in broadcast_in_dim_fn
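Since the review flags the removed validation in broadcast_in_dim_fn, here is a hedged NumPy sketch of the XLA-style broadcast_in_dim semantics such validation would guard. The helper name and the checks are illustrative assumptions, not nvFuser's actual implementation:

```python
import numpy as np

def broadcast_in_dim(x, shape, broadcast_dims):
    """Illustrative XLA-style broadcast_in_dim: map input dim i to
    output dim broadcast_dims[i], then broadcast to `shape`.
    Hypothetical helper, not nvFuser's API."""
    assert len(broadcast_dims) == x.ndim, "one mapping per input dim"
    # Build a rank-len(shape) view: size-1 everywhere except where
    # the input dims are mapped.
    out_shape = [1] * len(shape)
    for i, d in enumerate(broadcast_dims):
        assert x.shape[i] in (1, shape[d]), "incompatible dim size"
        out_shape[d] = x.shape[i]
    return np.broadcast_to(x.reshape(out_shape), shape)

# Broadcasting a [b, i, j] pair mask to [b, i, j, c]:
mask = np.ones((2, 4, 4))
m4 = broadcast_in_dim(mask, (2, 4, 4, 8), (0, 1, 2))
print(m4.shape)  # (2, 4, 4, 8)
```

The two asserts are exactly the sort of checks a validation pass would perform: the mapping must cover every input dim, and each mapped size must match the output shape (or be 1).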
Greptile Summary

This PR implements a reference implementation for AlphaFold3 triangle updates and enhances the existing triangle attention with layer normalization.
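The layer normalization added here can be sketched in NumPy, mirroring the cast, var_mean, sub, rsqrt, mul, add op sequence listed in the fusion's sequence diagram. This is an illustrative reference under assumed shapes, not the PR's actual code:

```python
import numpy as np

def layer_norm(x, weight, bias, eps=1e-5):
    # Normalize over the last (channel) dim, following the
    # cast/var_mean/sub/rsqrt/mul/add ops of the fusion definition.
    x32 = x.astype(np.float32)                      # cast
    mean = x32.mean(-1, keepdims=True)              # var_mean
    var = x32.var(-1, keepdims=True)
    y = (x32 - mean) * (1.0 / np.sqrt(var + eps))   # sub, rsqrt, mul
    return y * weight + bias                        # mul, add

# Assumed shapes: z is [b, i, j, c_z] with c_z = 8.
z = np.random.randn(2, 4, 4, 8).astype(np.float16)
w = np.ones(8, dtype=np.float32)
b = np.zeros(8, dtype=np.float32)
out = layer_norm(z, w, b)
print(out.shape)  # (2, 4, 4, 8)
```

With unit weight and zero bias, each channel vector comes out with (approximately) zero mean and unit variance, which is the invariant a test could check the fusion against.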
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_triangle_updates
    participant LN as layer_norm
    participant G as gating
    participant Ops as fd.ops
    Test->>Ops: define_tensor (z_in, weights, mask)
    Test->>LN: layer_norm(z_in, w_norm_in, b_norm_in)
    LN->>Ops: cast, var_mean, sub, rsqrt, mul, add
    LN-->>Test: normalized z_in
    Test->>G: gating(z_in, w_p_in, z_in, w_g_in)
    G->>Ops: linear(z, w_p) → p
    G->>Ops: linear(z_in, w_g) → g
    G->>Ops: sigmoid(g), mul(p, g)
    G-->>Test: z [b, i, j, c_z*2]
    Test->>Ops: broadcast_in_dim(mask)
    Test->>Ops: where(mask, z, 0.0)
    Test->>Ops: slice(z) → a, b
    Test->>Ops: permute (OUTGOING/INCOMING)
    Test->>Ops: matmul(a, b)
    Test->>LN: layer_norm(z, w_norm_out, b_norm_out)
    Test->>G: gating(z, w_p_out, z_in, w_g_out)
    Test->>Ops: add_output(z)
```
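The flow in the sequence diagram can be condensed into a small NumPy reference. The layer norms are omitted for brevity, and the function names and weight shapes are assumptions for illustration, not the PR's exact definition:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gating(z, w_p, z_in, w_g):
    # p = linear(z, w_p); g = linear(z_in, w_g); gated output is
    # sigmoid(g) * p, as in the sequence diagram.
    p = z @ w_p.T
    g = z_in @ w_g.T
    return sigmoid(g) * p

def triangle_update(z_in, w_p_in, w_g_in, w_p_out, w_g_out, mask,
                    direction="outgoing"):
    # z_in: [b, i, j, c]; mask: [b, i, j]; w_p_in/w_g_in project
    # c -> 2c, w_p_out/w_g_out project c -> c. Layer norms omitted.
    z = gating(z_in, w_p_in, z_in, w_g_in)       # [b, i, j, 2c]
    z = np.where(mask[..., None].astype(bool), z, 0.0)
    c = z.shape[-1] // 2
    a, b = z[..., :c], z[..., c:]                # slice(z) -> a, b
    if direction == "outgoing":
        # sum over k of a[i, k] * b[j, k]  (permute + matmul)
        z = np.einsum("bikc,bjkc->bijc", a, b)
    else:
        # sum over k of a[k, i] * b[k, j]
        z = np.einsum("bkic,bkjc->bijc", a, b)
    return gating(z, w_p_out, z_in, w_g_out)     # output gate from z_in

rng = np.random.default_rng(0)
n, c = 3, 4
z_in = rng.standard_normal((1, n, n, c))
mask = np.ones((1, n, n))
out = triangle_update(
    z_in,
    rng.standard_normal((2 * c, c)), rng.standard_normal((2 * c, c)),
    rng.standard_normal((c, c)), rng.standard_normal((c, c)),
    mask,
)
print(out.shape)  # (1, 3, 3, 4)
```

Note the two einsum contractions are the OUTGOING and INCOMING variants the diagram realizes via permute followed by matmul.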
1 file reviewed, 1 comment
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
I am assuming the layernorm is skipped for simplicity.
Apart from that, the other differences are:
- No mask application -- it looks like you are going to add it based on the discussion above.
- No output gating -- is this skipped intentionally?
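For the mask application being discussed, the intended behavior is presumably to broadcast the [b, i, j] pair mask over the channel dim and zero out masked positions before the triangle matmul. A minimal sketch, with illustrative names that are not nvFuser's API:

```python
import numpy as np

def apply_pair_mask(z, mask):
    # z: [b, i, j, c]; mask: [b, i, j] of {0, 1}.
    # Broadcast the mask over the channel dim and zero masked pairs.
    return np.where(mask[..., None].astype(bool), z, 0.0)

z = np.arange(12, dtype=np.float32).reshape(1, 2, 2, 3)
mask = np.array([[[1, 0], [1, 1]]])
masked = apply_pair_mask(z, mask)
print(masked[0, 0, 1])  # [0. 0. 0.]
```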
Force-pushed from 382fd62 to e1f4f08.
!test
3 files reviewed, 2 comments
Current docstring:

    -------
    DataType
        The data type of this tensor.
syntax: The docstring incorrectly states the return type as 'DataType'; it should be 'PrimDataType' to match the actual return type.
Suggested change:

    -------
    PrimDataType
        The data type of this tensor.
This is fine given Fuser/python/python_direct/enum.cpp, line 20 at 21524df:

    py::enum_<PrimDataType>(nvfuser, "DataType", py::module_local())
4 files reviewed, 1 comment
!test

!test

!test
Other changes:
cc @DejunL