
Conversation


@wujingyue wujingyue commented Jan 1, 2026

For #5308

Stream assignments and overlapping are confirmed by nsys.

$ nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 mpirun -np 2 -x NCCL_NVLS_ENABLE=0 pytest tests/python/multidevice/test_overlap.py -k 'forward_reference_benchmark' --only-mpi -vs

$ nsys stats report3.nsys-rep --report cuda_gpu_trace | grep '(0)'
   35515238           1088     249                                                                                0.000              3.676  Device              NVIDIA H100 80GB HBM3 (0)    1              35  [CUDA memset]
   35564678        1350935     250     2    66     1   384     1     1      168         0.000         0.213                                                     NVIDIA H100 80GB HBM3 (0)    1              35  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
   36408640           1760     345                                                                                0.000              2.273  Device              NVIDIA H100 80GB HBM3 (0)    1              39  [CUDA memset]
   36747070        1556022     346     2    66     1   384     1     1      168         0.000         0.213                                                     NVIDIA H100 80GB HBM3 (0)    1              39  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
   38141845         299614     278    24     1     1   544     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1              23  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   38443891         247166     374    24     1     1   544     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1              23  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)

The overlapping is quite limited, likely because the GEMM and ncclAllReduce compete for SMs.

[screenshot: nsys timeline]


github-actions bot commented Jan 1, 2026

Review updated until commit 776a102

Description

  • Add reference implementation for row-parallel linear forward with all-reduce

  • Implement multi-stream overlapping computation and communication

  • Add test function comparing reference vs torch.nn.functional.linear

  • Add benchmark test for performance measurement

Changes walkthrough

Relevant files

Tests: tests/python/multidevice/test_overlap.py
Add reference implementation and tests for row-parallel linear

  • Add torch.distributed imports and benchmark utilities
  • Implement row_parallel_linear_forward_reference function with
    multi-stream execution
  • Add test_row_parallel_linear_forward_reference for correctness
    validation
  • Add test_row_parallel_linear_forward_reference_benchmark for
    performance testing
  • +98/-1   

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Stream Synchronization

    The stream synchronization logic in the reference implementation may have issues. The wait_stream helper uses torch.cuda.current_stream().wait_event(event), which makes the current stream wait on the worker stream. However, this might not properly handle the case where multiple worker streams are involved: each worker stream should be synchronized individually with the main stream before proceeding.

    def wait_stream(stream: torch.cuda.Stream) -> None:
        event = torch.cuda.Event()
        stream.record_event(event)
        torch.cuda.current_stream().wait_event(event)
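
    One way to cover the multi-stream case is to call such a helper once per worker stream before the output is consumed. A minimal sketch, where worker_streams is a hypothetical list of the worker streams used above:

    # Make the main stream wait on each worker stream in turn.
    for stream in worker_streams:
        wait_stream(stream)
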
    All-reduce Operation

    The all-reduce operation is performed with work.wait() immediately after the async call, which effectively makes it a synchronous operation. This might not be the intended behavior for a reference implementation that should demonstrate overlapping computation and communication. Consider if the wait should be deferred or if this is the intended behavior.

    work = dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()
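
    For comparison, a deferred-wait variant would enqueue every chunk first and only wait on the collected handles at the end, so a collective can overlap with the matmuls of later chunks. A rough sketch, assuming the test file's torch / torch.distributed as dist imports and hypothetical inp_chunks, out_chunks, worker_streams, and weight_shard names:

    works = []
    for inp_chunk, out_chunk, stream in zip(inp_chunks, out_chunks, worker_streams):
        with torch.cuda.stream(stream):
            # Enqueue the GEMM and the collective on the worker stream without blocking.
            torch.matmul(inp_chunk, weight_shard.t(), out=out_chunk)
            works.append(dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True))
    for work in works:
        # Only now make the calling stream wait, after every chunk has been enqueued.
        work.wait()

    Whether this changes the achieved overlap in practice is exactly what the nsys profiles in this thread are probing.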

    @wujingyue wujingyue requested a review from Priya2698 January 3, 2026 07:36
    @wujingyue wujingyue marked this pull request as ready for review January 3, 2026 07:36
    @wujingyue wujingyue changed the title from "Reference implementation for Linear+Allreduce" to "Reference implementation for Linear+Allreduce with overlapping" Jan 3, 2026

    greptile-apps bot commented Jan 3, 2026

    Greptile Summary

    Adds a PyTorch reference implementation for row-parallel linear with overlapped matmul and allreduce operations. The implementation chunks input tensors and processes them concurrently across multiple CUDA streams to overlap computation and communication.

    • Implements row_parallel_linear_forward_reference() that splits input into chunks, assigns each to a worker stream, and performs matmul followed by allreduce
    • Adds correctness test test_row_parallel_linear_forward_reference() that validates output against PyTorch's linear operation
    • Adds benchmark test test_row_parallel_linear_forward_reference_benchmark() for performance measurement
    • Imports necessary torch.distributed modules and benchmark utilities

    Confidence Score: 4/5

    • This PR is safe to merge with low risk
    • The implementation adds test infrastructure with proper synchronization patterns. The code follows existing patterns in the file, includes both correctness and benchmark tests, and uses appropriate PyTorch distributed primitives. Minor deduction because synchronous work.wait() inside the stream loop may not achieve optimal overlap, but this is acceptable for a reference implementation
    • No files require special attention

    Important Files Changed

    Filename Overview
    tests/python/multidevice/test_overlap.py Adds PyTorch reference implementation for row-parallel linear with matmul+allreduce overlap, plus correctness and benchmark tests

    Sequence Diagram

    sequenceDiagram
        participant Main as Main Stream
        participant W1 as Worker Stream 1
        participant W2 as Worker Stream 2
        participant WN as Worker Stream N
        participant GPU as GPU Memory
        
        Note over Main,GPU: Input and weight shards loaded
        Main->>GPU: Allocate output tensor
        Main->>GPU: Chunk input and output
        
        par Chunk 1
            Main->>W1: Launch worker stream
            W1->>W1: Wait for main stream (sync)
            W1->>W1: matmul(inp_chunk[0], weight_shard.T)
            W1->>W1: allreduce(out_chunk[0], async_op=True)
            W1->>W1: work.wait()
        and Chunk 2
            Main->>W2: Launch worker stream
            W2->>W2: Wait for main stream (sync)
            W2->>W2: matmul(inp_chunk[1], weight_shard.T)
            W2->>W2: allreduce(out_chunk[1], async_op=True)
            W2->>W2: work.wait()
        and Chunk N
            Main->>WN: Launch worker stream
            WN->>WN: Wait for main stream (sync)
            WN->>WN: matmul(inp_chunk[N-1], weight_shard.T)
            WN->>WN: allreduce(out_chunk[N-1], async_op=True)
            WN->>WN: work.wait()
        end
        
        W1->>Main: Record event
        W2->>Main: Record event
        WN->>Main: Record event
        Main->>Main: Wait for all worker streams
        Main->>Main: Return output tensor
    
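    Put together, the sequence above corresponds roughly to the following PyTorch pattern. This is a simplified sketch rather than the PR's exact code: the function name, the num_chunks parameter, and the use of Stream.wait_stream in place of the event-based helper quoted earlier are illustrative assumptions.

    import torch
    import torch.distributed as dist

    def row_parallel_linear_forward_sketch(
        inp_shard: torch.Tensor,     # [batch, in_features / world_size]
        weight_shard: torch.Tensor,  # [out_features, in_features / world_size]
        num_chunks: int,
    ) -> torch.Tensor:
        main_stream = torch.cuda.current_stream()
        out = inp_shard.new_empty(inp_shard.shape[:-1] + (weight_shard.shape[0],))

        inp_chunks = inp_shard.chunk(num_chunks)
        out_chunks = out.chunk(num_chunks)
        worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)]

        for inp_chunk, out_chunk, stream in zip(inp_chunks, out_chunks, worker_streams):
            # Each worker stream waits for the main stream so its inputs are ready.
            stream.wait_stream(main_stream)
            with torch.cuda.stream(stream):
                # Partial GEMM on this rank's shard, then sum the partial
                # results across ranks with an all-reduce.
                torch.matmul(inp_chunk, weight_shard.t(), out=out_chunk)
                work = dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True)
                work.wait()

        # The main stream waits for every worker stream before the output is used.
        for stream in worker_streams:
            main_stream.wait_stream(stream)
        return out

    Keeping work.wait() inside the per-chunk loop mirrors the snippet quoted in the review above; whether to defer those waits is the trade-off discussed there.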

    @wujingyue
    Collaborator Author

    !test

    Comment on lines +142 to +143
    @pytest.mark.mpi
    def test_row_parallel_linear_forward_reference(setup_default_process_group):
    Collaborator


    Instead of separate functions, how about using a parameter to indicate if it is a validation/benchmarking run?

    Collaborator Author


    I considered that but ended up preferring DAMP to DRY for this particular case. I found fewer things to share than I expected between the test and the benchmark -- the input tensors have to be created differently, execution needs to be measured for the benchmark, and validation is skipped for the benchmark.

    Collaborator


    Oh that's right, I overlooked the initialization difference.

    @wujingyue wujingyue requested a review from Priya2698 January 5, 2026 23:34
    @wujingyue wujingyue merged commit 8a9ea0b into main Jan 6, 2026
    64 of 65 checks passed
    @wujingyue wujingyue deleted the wjy/ref branch January 6, 2026 07:10
    @samnordmann
    Collaborator

    Stream assignments and overlapping are confirmed by nsys.

    I am not sure we can call this overlap. It might just be a kernel being launched while the previous one is winding down. Could you produce a larger nsys profile that clearly demonstrates overlap?

    Also, could you include the different streams in the screenshot, so we can check where the compute and communication kernels are posted?

    The overlapping is quite limited, likely because the GEMM and ncclAllReduce compete for SMs.

    Might be!

    @wujingyue
    Collaborator Author

    Good questions. s=4 seems to show a slightly better overlapping pattern:

    [screenshot: nsys timeline with s=4]

    I vaguely remember that you ran into a similar issue when testing overlapping using NCCL as is, and that you managed to see better results with UCC using the NCCL TL. Is there a way to do that in Python without nvFuser?


    wujingyue commented Jan 6, 2026

    Looks like streams didn't get reused, leading to cudaMalloc. Let me take a look at that: #5767
