Reference implementation for Linear+Allreduce with overlapping #5742
Conversation
Review updated until commit 776a102

Relevant files: Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests

⚡ Recommended focus areas for review:

Stream Synchronization
The wait_stream function uses torch.cuda.current_stream().wait_event(event), which synchronizes the current stream with the worker stream. However, this approach might not properly handle the case where multiple worker streams are involved, as each worker stream should be synchronized individually with the main stream before proceeding.
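To make the concern concrete, here is a minimal, hypothetical sketch of the event-based pattern in question (run_on_workers, fns, and worker_streams are illustrative names, not the PR's code): the main stream records one event that every worker waits on, and the main stream then waits on a completion event from each worker stream individually.

```python
import torch

def run_on_workers(fns, worker_streams):
    """Illustrative only: launch one callable per worker stream and make the
    main stream wait on each worker stream individually."""
    main_stream = torch.cuda.current_stream()

    start = torch.cuda.Event()
    start.record(main_stream)            # everything enqueued so far on the main stream

    done_events = []
    for fn, stream in zip(fns, worker_streams):
        stream.wait_event(start)         # worker waits for the main stream
        with torch.cuda.stream(stream):
            fn()                         # e.g. matmul + allreduce for one chunk
        done = torch.cuda.Event()
        done.record(stream)              # mark this worker's completion
        done_events.append(done)

    # The main stream must wait for *each* worker stream, not just one of them.
    for done in done_events:
        main_stream.wait_event(done)
```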
Greptile Summary

Adds a PyTorch reference implementation for row-parallel linear with overlapped matmul and allreduce operations. The implementation chunks input tensors and processes them concurrently across multiple CUDA streams to overlap computation and communication.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Main as Main Stream
participant W1 as Worker Stream 1
participant W2 as Worker Stream 2
participant WN as Worker Stream N
participant GPU as GPU Memory
Note over Main,GPU: Input and weight shards loaded
Main->>GPU: Allocate output tensor
Main->>GPU: Chunk input and output
par Chunk 1
Main->>W1: Launch worker stream
W1->>W1: Wait for main stream (sync)
W1->>W1: matmul(inp_chunk[0], weight_shard.T)
W1->>W1: allreduce(out_chunk[0], async_op=True)
W1->>W1: work.wait()
and Chunk 2
Main->>W2: Launch worker stream
W2->>W2: Wait for main stream (sync)
W2->>W2: matmul(inp_chunk[1], weight_shard.T)
W2->>W2: allreduce(out_chunk[1], async_op=True)
W2->>W2: work.wait()
and Chunk N
Main->>WN: Launch worker stream
WN->>WN: Wait for main stream (sync)
WN->>WN: matmul(inp_chunk[N-1], weight_shard.T)
WN->>WN: allreduce(out_chunk[N-1], async_op=True)
WN->>WN: work.wait()
end
W1->>Main: Record event
W2->>Main: Record event
WN->>Main: Record event
Main->>Main: Wait for all worker streams
Main->>Main: Return output tensor
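The flow in the summary and the sequence diagram above can be written in plain PyTorch roughly as follows. This is an illustrative sketch, not the PR's actual code; the function name row_parallel_linear_overlapped, the chunk count, and the shapes are assumptions.

```python
import torch
import torch.distributed as dist

def row_parallel_linear_overlapped(inp, weight_shard, num_chunks=4):
    """Illustrative sketch: row-parallel linear where each chunk's matmul and
    allreduce run on a worker stream so chunks can overlap with each other."""
    main_stream = torch.cuda.current_stream()
    out = torch.empty(inp.shape[0], weight_shard.shape[0],
                      dtype=inp.dtype, device=inp.device)

    inp_chunks = inp.chunk(num_chunks, dim=0)
    out_chunks = out.chunk(num_chunks, dim=0)

    worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)]
    ready = torch.cuda.Event()
    ready.record(main_stream)                  # input/weight/output are valid after this point

    done_events = []
    for inp_chunk, out_chunk, stream in zip(inp_chunks, out_chunks, worker_streams):
        stream.wait_event(ready)               # don't start before the main stream is ready
        with torch.cuda.stream(stream):
            torch.matmul(inp_chunk, weight_shard.t(), out=out_chunk)  # partial result
            work = dist.all_reduce(out_chunk, async_op=True)          # sum across ranks
            work.wait()                        # ordered within this worker stream
        done = torch.cuda.Event()
        done.record(stream)
        done_events.append(done)

    for done in done_events:
        main_stream.wait_event(done)           # main stream waits for every worker
    return out
```

With NCCL, work.wait() only inserts a dependency on the worker stream rather than blocking the host, so the CPU keeps launching the next chunk; that is what lets one chunk's matmul overlap with another chunk's allreduce.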
!test
@pytest.mark.mpi
def test_row_parallel_linear_forward_reference(setup_default_process_group):
Instead of separate functions, how about using a parameter to indicate if it is a validation/benchmarking run?
I considered that but ended up preferring DAMP to DRY for this particular case. I found fewer things to share than I expected between the test and the benchmark -- the input tensors have to be created differently, execution needs to be measured for the benchmark, and validation is skipped for the benchmark.
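Purely for illustration (none of the names, shapes, or tolerances below come from the PR; the forward reuses the row_parallel_linear_overlapped sketch from earlier), the two entry points might diverge roughly like this, which is why there is little left to share:

```python
import torch
import torch.distributed as dist

def test_forward():
    torch.manual_seed(0)                                     # deterministic inputs for validation
    inp = torch.randn(128, 256, device="cuda")
    weight_shard = torch.randn(512, 256, device="cuda")
    out = row_parallel_linear_overlapped(inp, weight_shard)  # overlapped forward
    ref = inp @ weight_shard.t()
    dist.all_reduce(ref)                                     # unfused reference
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)  # loose: reduction order differs

def benchmark_forward():
    inp = torch.randn(8192, 4096, device="cuda")             # larger, benchmark-sized inputs
    weight_shard = torch.randn(4096, 4096, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    row_parallel_linear_overlapped(inp, weight_shard)        # timed, no validation
    end.record()
    torch.cuda.synchronize()
    print(f"forward: {start.elapsed_time(end):.3f} ms")
```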
Oh that's right, I overlooked the initialization difference.
I am not sure we can call this overlap. It might just be a kernel being launched while the previous one is winding down. Could you produce a larger nsys profile that clearly demonstrates overlap? Also, could you add the different streams to the screenshot so we can check where the different compute and comm operations are posted?
Might be!
Looks like streams didn't get reused, leading to cudaMalloc. Let me take a look at that: #5767
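One common way to avoid that, shown here only as an illustration (get_worker_streams is a made-up helper, not necessarily what #5767 does), is to create the worker streams once per device and reuse them across calls instead of constructing fresh streams on every forward:

```python
import functools
import torch

# Illustrative sketch: cache the worker streams per device so repeated forward
# calls reuse the same streams (and their cached allocator blocks) instead of
# creating fresh streams, which can trigger new cudaMalloc calls.
@functools.lru_cache(maxsize=None)
def get_worker_streams(device_index, num_streams):
    return tuple(torch.cuda.Stream(device=device_index) for _ in range(num_streams))

# Usage inside the forward:
#   worker_streams = get_worker_streams(torch.cuda.current_device(), num_chunks)
```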

For #5308
Stream assignments and overlap are confirmed by nsys.
The overlap is quite limited, likely because the gemm and ncclAllReduce kernels compete for SMs.