
Conversation

@wujingyue (Collaborator) commented Dec 21, 2025

Other changes:

  1. Add TensorView.dtype for convenience.
  2. Clean up broadcast_in_dim_fn.
  3. Add layernorm to triangle attention.

cc @DejunL

@wujingyue wujingyue requested a review from Priya2698 December 21, 2025 04:05
@wujingyue (Collaborator, Author): !test

@github-actions bot commented Dec 21, 2025

Review updated until commit 5502896

Description

  • Add TensorView.dtype() convenience method for accessing tensor data types

  • Refactor broadcast_in_dim_fn with improved parameter naming and validation

  • Implement reference triangle updates (outgoing/incoming) with layer normalization

  • Add comprehensive triangle attention implementation with proper masking

Changes walkthrough

Relevant files

Enhancement

python/python_direct/ir.cpp: Add TensorView.dtype convenience method (+18/-0)

  • Added a TensorView.dtype() method that returns PrimDataType
  • Includes proper error checking for non-PrimDataType cases
  • Provides convenient access to tensor data types
python/python_direct/ops.cpp: Refactor broadcast_in_dim_fn implementation (+18/-32)

  • Renamed parameters: arg -> input, broadcast_dims -> nonbroadcast_dims
  • Simplified validation logic, using wrapDim for negative indices
  • Improved error messages and bounds checking
  • Restructured the function flow for better readability
Tests

tests/python/direct/test_alphafold3.py: Implement AlphaFold3 triangle updates and attention (+155/-7)

  • Added a layer_norm helper with proper casting
  • Added a gating helper for linear transformations with sigmoid activation (a hedged reference sketch of both helpers follows this entry)
  • Implemented triangle updates (outgoing/incoming) with einsum operations
  • Enhanced triangle attention with layer normalization and proper masking
  • Added comprehensive test cases with realistic tensor dimensions
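
The two new helpers are easiest to review against a plain PyTorch reference. The sketch below mirrors the steps described above and in the sequence diagram further down (cast to fp32, var_mean, rsqrt, scale and shift for layer_norm; a linear projection gated by the sigmoid of a second projection for gating). It is an illustrative reference, not the fd.ops code in test_alphafold3.py, and parameter names such as w_p and w_g are placeholders.

    import torch

    def layer_norm_ref(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                       eps: float = 1e-5) -> torch.Tensor:
        # Compute statistics in fp32 ("proper casting"), normalize, scale and
        # shift, then cast back to the input dtype.
        x32 = x.float()
        var, mean = torch.var_mean(x32, dim=-1, keepdim=True, unbiased=False)
        y = (x32 - mean) * torch.rsqrt(var + eps)
        return (y * weight + bias).to(x.dtype)

    def gating_ref(z: torch.Tensor, w_p: torch.Tensor,
                   z_gate: torch.Tensor, w_g: torch.Tensor) -> torch.Tensor:
        # A linear projection of z, gated elementwise by the sigmoid of a
        # second linear projection of z_gate: p * sigmoid(g).
        p = torch.nn.functional.linear(z, w_p)
        g = torch.nn.functional.linear(z_gate, w_g)
        return p * torch.sigmoid(g)
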
tests/python/opinfo/opinfo_input_generators.py: Update broadcast_in_dim error test cases (+3/-11)

  • Removed the descending_broadcast_dimensions error test case
  • Updated test cases to match the new broadcast_in_dim_fn behavior
  • Maintains the other validation test cases for broadcast operations

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Removed validation in broadcast_in_dim_fn

    The broadcast_in_dim_fn function had its validation logic significantly simplified. Specifically, the old code checked that broadcast_dims were in ascending order and performed more comprehensive validation. The new code removes this ascending order check and some other validations. This could lead to unexpected behavior or silent failures when non-ascending broadcast dimensions are provided.

    template <class ShapeType>
    TensorView* broadcast_in_dim_fn(
        TensorView* input,
        ShapeType generic_output_shape,
        const std::vector<int64_t>& nonbroadcast_dims) {
      std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
      NVF_CHECK_GE(output_shape.size(), nonbroadcast_dims.size());
    
      const auto input_ndim = std::ranges::distance(
          input->getLogicalDomain() | TensorDomain::kNoReductions);
      NVF_CHECK_GE(std::ssize(output_shape), input_ndim);
      NVF_CHECK_EQ(input_ndim, std::ssize(nonbroadcast_dims));
    
      std::vector<bool> is_broadcast_dim(output_shape.size(), true);
      for (int64_t nonbroadcast_dim : nonbroadcast_dims) {
        nonbroadcast_dim = wrapDim(nonbroadcast_dim, std::ssize(output_shape));
        NVF_CHECK(
            is_broadcast_dim.at(nonbroadcast_dim),
            "nonbroadcast_dim (",
            nonbroadcast_dim,
            ") is specified more than once.");
        is_broadcast_dim.at(nonbroadcast_dim) = false;
      }
    
      TensorView* output = broadcast(input, is_broadcast_dim);
      output = expand(output, output_shape);
      return output;
    }
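
    To make the new contract concrete, here is a reference-semantics sketch in plain PyTorch of what the function above computes. It mirrors the checks shown in the C++ (the size relationships, wrapDim-style wrapping of negative indices, and rejection of duplicates); it is only an illustration, not nvFuser code. Because the positions collapse into the boolean is_broadcast_dim vector, only the set of nonbroadcast_dims matters, which is presumably why the ascending-order check was dropped.

        import torch

        def broadcast_in_dim_ref(input: torch.Tensor,
                                 output_shape: list[int],
                                 nonbroadcast_dims: list[int]) -> torch.Tensor:
            ndim_out = len(output_shape)
            assert ndim_out >= len(nonbroadcast_dims) == input.dim()

            wrapped = []
            for d in nonbroadcast_dims:
                assert -ndim_out <= d < ndim_out, "Invalid nonbroadcast_dim value."
                wrapped.append(d % ndim_out)  # wrapDim: -1 means the last output dim
            assert len(set(wrapped)) == len(wrapped), \
                "nonbroadcast_dim is specified more than once."

            # Input axes fill the non-broadcast output positions in ascending order;
            # every other position becomes a size-1 dimension that is then expanded.
            shape = [1] * ndim_out
            for axis, pos in enumerate(sorted(wrapped)):
                shape[pos] = input.size(axis)
            return input.reshape(shape).expand(output_shape)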
    Removed error test case

    The descending_broadcast_dimensions test case was removed from the error test suite. This test case was validating that broadcast dimensions must be in ascending order. Since the main implementation still uses similar logic in some places, removing this test case could leave a gap in error detection coverage.

    # 3. Each broadcast dimension is within the new shape.
    out_of_bounds_broadcast_dimensions = (
        ([2, 2], [2, 2], [0, 2]),
        RuntimeError,
        "Invalid broadcast_dims value.",
    )

    @greptile-apps bot (Contributor) commented Dec 21, 2025

    Greptile Summary

    This PR implements a reference implementation for AlphaFold3 triangle updates and enhances the existing triangle attention with layer normalization.

    Key Changes:

    • Added TensorView.dtype() method in the Python direct bindings for convenient data type access
    • Refactored broadcast_in_dim_fn to support negative dimension indices via wrapDim and removed the requirement for ascending broadcast dimensions
    • Implemented full test_triangle_updates test with layer_norm and gating helper functions following the AlphaFold paper (Supplementary Methods 1.6.5)
    • Enhanced test_triangle_attention to include layer normalization preprocessing
    • Updated error test generator to remove the now-obsolete ascending broadcast dimensions check

    Confidence Score: 4/5

    • This PR is safe to merge - it adds new test functionality and a minor API addition with proper validation.
    • The changes are well-structured: the new dtype() method has proper type validation, the broadcast_in_dim refactoring maintains existing behavior while adding flexibility, and the new test implementations follow the AlphaFold paper references. No critical issues found.
    • No files require special attention.

    Important Files Changed

    Filename and overview:
    python/python_direct/ir.cpp: Added a dtype() method to TensorView that returns PrimDataType; properly validates the type before returning.
    python/python_direct/ops.cpp: Refactored broadcast_in_dim_fn to support negative indices via wrapDim, removed the ascending-order requirement, and improved variable naming.
    tests/python/direct/test_alphafold3.py: Implemented the triangle updates test with layer_norm and gating helpers; added layer normalization to the triangle attention test.
    tests/python/opinfo/opinfo_input_generators.py: Removed the ascending-order error case for broadcast_dims and renumbered the remaining test cases.

    Sequence Diagram

    sequenceDiagram
        participant Test as test_triangle_updates
        participant LN as layer_norm
        participant G as gating
        participant Ops as fd.ops
    
        Test->>Ops: define_tensor (z_in, weights, mask)
        Test->>LN: layer_norm(z_in, w_norm_in, b_norm_in)
        LN->>Ops: cast, var_mean, sub, rsqrt, mul, add
        LN-->>Test: normalized z_in
        Test->>G: gating(z_in, w_p_in, z_in, w_g_in)
        G->>Ops: linear(z, w_p) → p
        G->>Ops: linear(z_in, w_g) → g
        G->>Ops: sigmoid(g), mul(p, g)
        G-->>Test: z [b, i, j, c_z*2]
        Test->>Ops: broadcast_in_dim(mask)
        Test->>Ops: where(mask, z, 0.0)
        Test->>Ops: slice(z) → a, b
        Test->>Ops: permute (OUTGOING/INCOMING)
        Test->>Ops: matmul(a, b)
        Test->>LN: layer_norm(z, w_norm_out, b_norm_out)
        Test->>G: gating(z, w_p_out, z_in, w_g_out)
        Test->>Ops: add_output(z)
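
    For readers who prefer code to the diagram, the following is a hedged PyTorch sketch of the same flow (AlphaFold Supplementary Methods 1.6.5). Tensor names follow the diagram; the einsum subscripts used for the "permute (OUTGOING/INCOMING)" plus "matmul(a, b)" step come from the paper's triangle multiplicative update and are an assumption about the test, not code lifted from it. Dtypes are assumed to match (e.g., all fp32).

        import torch
        import torch.nn.functional as F

        def triangle_update_ref(z_in, mask, w_norm_in, b_norm_in, w_p_in, w_g_in,
                                w_norm_out, b_norm_out, w_p_out, w_g_out,
                                outgoing: bool = True) -> torch.Tensor:
            # z_in: pair representation [b, i, j, c_z]; mask: boolean [b, i, j].
            z_ln = F.layer_norm(z_in, z_in.shape[-1:], w_norm_in, b_norm_in)
            # gating(z_in, w_p_in, z_in, w_g_in): projection gated by a sigmoid.
            z = F.linear(z_ln, w_p_in) * torch.sigmoid(F.linear(z_ln, w_g_in))
            z = torch.where(mask.unsqueeze(-1), z, torch.zeros_like(z))
            a, b = z.chunk(2, dim=-1)       # slice(z) -> a, b, each [b, i, j, c_z]
            # Triangle multiplicative update: sum over the third node k.
            eq = "bikc,bjkc->bijc" if outgoing else "bkic,bkjc->bijc"
            z = torch.einsum(eq, a, b)
            z = F.layer_norm(z, z.shape[-1:], w_norm_out, b_norm_out)
            # gating(z, w_p_out, z_in, w_g_out): output gate computed from the
            # normalized z_in, per the diagram.
            return F.linear(z, w_p_out) * torch.sigmoid(F.linear(z_ln, w_g_out))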
    

    @greptile-apps bot (Contributor) left a comment: 1 file reviewed, 1 comment

    @Priya2698 (Collaborator) left a comment:

    I am assuming the layernorm is skipped for simplicity.

    Apart from that, the other differences are:

    1. No mask application -- looks like you are going to add it based on above discussion.
    2. No output gating -- is this skipped intentionally?

    @wujingyue wujingyue requested a review from Priya2698 December 24, 2025 05:19
    @wujingyue (Collaborator, Author): !test

    @greptile-apps bot (Contributor) left a comment: 3 files reviewed, 2 comments

    Comment on lines +249 to +251
    -------
    DataType
    The data type of this tensor.

    syntax: Docstring incorrectly states return type as 'DataType' but should be 'PrimDataType' to match the actual return type.

    Suggested change

    Current:
    -------
    DataType
    The data type of this tensor.

    Suggested:
    -------
    PrimDataType
    The data type of this tensor.

    @wujingyue (Collaborator, Author) replied:

    This is fine given

        py::enum_<PrimDataType>(nvfuser, "DataType", py::module_local())

    The Python user expects to use DataType instead of PrimDataType.
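
    A hedged illustration of what this means on the Python side: because PrimDataType is exported under the name DataType, the value returned by the new accessor compares directly against the DataType enum users already pass when defining tensors. Only the DataType naming is confirmed by the binding above; the import path, define_tensor signature, and the exact dtype() spelling below are assumptions for illustration.

        # Assumed module and API spelling; only the DataType enum name is confirmed above.
        from nvfuser_direct import FusionDefinition, DataType

        with FusionDefinition() as fd:
            tv = fd.define_tensor(shape=[8, 16], dtype=DataType.BFloat16)
            # The new convenience accessor returns the Python-facing DataType enum.
            assert tv.dtype() == DataType.BFloat16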

    Base automatically changed from wjy/end to main January 5, 2026 20:11
    @greptile-apps bot (Contributor) left a comment: 4 files reviewed, 1 comment

    @wujingyue (Collaborator, Author): !test

    @wujingyue (Collaborator, Author): !test

    @wujingyue (Collaborator, Author): !test

    @wujingyue wujingyue merged commit a88bfb1 into main Jan 6, 2026
    61 checks passed
    @wujingyue wujingyue deleted the wjy/update branch January 6, 2026 06:36