Conversation

naoyam (Collaborator) commented Dec 20, 2025

This PR also fixes the initial index assignment: circular-buffer-specific indices must always be allocated when circular buffering is used. This usually doesn't matter, but some tests have trivial patterns where the loop extent is just 1 yet circular buffering is still used; in that case the current code simply assigns zero. https://github.com/NVIDIA/Fuser/blob/main/csrc/id_model/id_model.cpp#L1354

naoyam commented Dec 20, 2025

!test --diff

github-actions bot commented Dec 20, 2025

Review updated until commit 28ce887

Description

  • Fix circular buffer index allocation order in IdModel::allocateLoopIndexVariables()

  • Enable TensorIndexer (IdModel) for matmul scheduler test classes

  • Add error checking for missing circular buffer index variables

  • Ensure circular buffer indexing takes precedence over other index allocation

Changes walkthrough

Bug fix: csrc/id_model/id_model.cpp (+24/-16) — Fix circular buffer index allocation order

  • Move circular buffer index allocation before the other index allocation logic
  • Add error checking in getLoopIndexVariable() for missing circular buffer indices
  • Ensure circular buffer indexing takes precedence as required

Enhancement: tests/cpp/test_matmul_scheduler.cpp (+19/-0) — Enable TensorIndexer for matmul test classes

  • Enable the IdModel option in MatmulSchedulerTest::SetUp()
  • Enable the IdModel option in MatmulSchedulerPluginTest::SetUp()
  • Enable the IdModel option in AllocationDomainTest::SetUp()
  • Enable the IdModel option in HopperPlusMatmulSchedulerTest::SetUp()

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Reordering

The circular buffer check has been moved from after the loop_index != nullptr check to before it. This is a significant logic change that could affect behavior in edge cases where loop_index was previously set while circular buffering is also present. The PR description notes that this fixes the initial index assignment, but the implications of the reordering should be carefully validated.

    // This needs to be done before assigning zero or parallel indices
    // as circular buffer indexing takes precedence.
    if (GpuLower::current()->circularBufferInfo().isCircularBufferedIterDomain(
            loop_group->front()->as<IterDomain>())) {
      // Allocate index variable for each stage of the circular
      // buffered loop.
      auto indices = std::make_unique<CircularBufferIndices>();
      for (auto i :
           arange(static_cast<int>(CircularBufferLoopStage::EndOfStages))) {
        indices->emplace(
            static_cast<CircularBufferLoopStage>(i),
            IrBuilder::create<Val>(DataType::Index));
      }
      circular_buffered_loop_index_variable_map_[loop_group] =
          std::move(indices);
      continue;
    }
Error Handling

New error checking has been added (lines 1435-1440) that will throw an NVF_ERROR if circular buffer index variables are not found. This is good defensive programming, but it means that any code paths that previously worked without circular buffer indices may now fail. The robustness of this change should be verified across all test scenarios.

    NVF_ERROR(
        circular_buffered_loop_index_variable_map_.contains(loop_group),
        "Failed to find circular buffer index var for: ",
        nvfuser::toString(loop_group),
        ", ",
        loop_group->front()->toString());

greptile-apps bot commented Dec 20, 2025

Greptile Summary

Fixed a critical ordering issue in circular buffer index allocation and enabled TensorIndexer for matmul scheduler tests.

Key Changes:

• Moved circular buffer index allocation to occur before zero/parallel index assignment in allocateLoopIndexVariables(), ensuring circular buffer indices are always allocated when needed
• Added an NVF_ERROR check in getLoopIndexVariable() to catch missing circular buffer indices early
• Enabled EnableOption::IdModel for all matmul scheduler test classes (MatmulSchedulerTest, MatmulSchedulerPluginTest, AllocationDomainTest, HopperPlusMatmulSchedulerTest)

Why This Matters:
The previous code had a bug where loops with extent = 1 (trivial loops) that also used circular buffering would be assigned a zero index and skip circular buffer index allocation. Later calls to getLoopIndexVariable() would then fail because no circular buffer indices existed. This PR fixes the ordering so the circular buffer check happens first, ensuring indices are properly allocated even for trivial circular-buffered loops.

Confidence Score: 5/5

• This PR is safe to merge with minimal risk
• The changes are well-contained and logically sound. The circular buffer index allocation fix addresses a clear bug where the order of operations was incorrect. The additional NVF_ERROR check provides better error detection. The test changes simply enable TensorIndexer for the matmul tests, which is the stated goal of the PR. All changes follow existing code patterns and don't introduce new logic vulnerabilities.
• No files require special attention

Important Files Changed

• csrc/id_model/id_model.cpp — Fixed the circular buffer index allocation order so indices are allocated before zero/parallel index assignment, preventing missing-index errors for trivial loops with circular buffering
• tests/cpp/test_matmul_scheduler.cpp — Enabled TensorIndexer (IdModel) for all matmul scheduler tests by setting EnableOption::IdModel in the SetUp methods

Sequence Diagram

    sequenceDiagram
        participant Test as Matmul Test
        participant IdModel as IdModel
        participant CircularBuffer as CircularBufferInfo
        participant IndexMap as Index Variable Maps
    
        Test->>IdModel: Enable TensorIndexer (IdModel "all")
        Test->>IdModel: allocateLoopIndexVariables()
        
        loop For each loop_group
            IdModel->>CircularBuffer: isCircularBufferedIterDomain()?
            
            alt Is Circular Buffered (NEW: Check First)
                CircularBuffer-->>IdModel: true
                IdModel->>IndexMap: Allocate CircularBufferIndices
                Note over IdModel,IndexMap: Create index for each stage<br/>(Main, Prolog, Epilog)
                IdModel->>IdModel: continue (skip other checks)
            else Not Circular Buffered
                CircularBuffer-->>IdModel: false
                IdModel->>IdModel: Check shouldUseZeroIndex()
                
                alt Should use zero (extent=1 or broadcast)
                    IdModel->>IndexMap: Assign fusion_->zeroVal()
                    IdModel->>IdModel: continue
                else Should use parallel index
                    IdModel->>IndexMap: Assign NamedScalar::getParallelIndex()
                    IdModel->>IdModel: continue
                else Regular loop index needed
                    IdModel->>IndexMap: Allocate new Val(DataType::Index)
                end
            end
        end
        
        Note over Test,IndexMap: Later during execution...
        
        Test->>IdModel: getLoopIndexVariable(loop_group, stage)
        IdModel->>CircularBuffer: isCircularBufferedIterDomain()?
        CircularBuffer-->>IdModel: true
        IdModel->>IndexMap: Get circular buffer index for stage
        alt Index exists (FIXED)
            IndexMap-->>IdModel: Return index variable
            IdModel-->>Test: Success
        else Index missing (OLD BUG)
            IndexMap-->>IdModel: Not found
            IdModel->>IdModel: NVF_ERROR (catch bug early)
        end
    


naoyam commented Jan 6, 2026

!test --diff

naoyam requested a review from rdspring1 January 6, 2026 21:26

greptile-apps bot left a comment

2 files reviewed, 1 comment

naoyam commented Jan 6, 2026

!test

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
naoyam commented Jan 6, 2026

!test

greptile-apps bot left a comment

Additional Comments (1)

1. tests/cpp/test_matmul_scheduler.cpp, line 3333 (link)

   logic: missing parent class SetUp call - other test classes call NVFuserFixtureParamTest::SetUp() first

2 files reviewed, 1 comment

naoyam commented Jan 6, 2026

!test

naoyam merged commit e583b8a into main Jan 7, 2026
59 of 60 checks passed
naoyam deleted the tensorindexer_matmul_scheduler branch January 7, 2026 02:12