
Fused all-gather matmul on multi GPUs with PyTorch Symmetric Memory#62

Closed
kwen2501 wants to merge 1 commit into NVIDIA:main from kwen2501:ag_matmul

Conversation

Contributor

@kwen2501 kwen2501 commented Jan 20, 2026

Description

Each rank gathers inputs from all peer GPUs and performs a matrix multiplication with its local weight.

Peer inputs are made visible via PyTorch Symmetric Memory, i.e.

import torch.distributed._symmetric_memory as symm_mem

# Allocate a tensor whose storage can be mapped and read directly by peer GPUs.
symm_mem.empty(...)

The fused kernel is equivalent to:

# Gather every rank's input into ag_out, then multiply by the local weight.
dist.all_gather_into_tensor(ag_out, inp, group)
out = ag_out @ w

The fusion overlaps communication and computation at a fine granularity.
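For reference, here is a minimal sketch of the unfused path on each rank, assuming a per-rank input of shape (m_per_rank, k) and a local weight of shape (k, n). The shape names and the rendezvous call are assumptions for illustration and are not taken from this PR's kernel:

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

def reference_ag_matmul(inp, w, group):
    # Unfused reference: gather all per-rank inputs, then do one local matmul.
    world_size = dist.get_world_size(group)
    ag_out = torch.empty(world_size * inp.shape[0], inp.shape[1],
                         dtype=inp.dtype, device=inp.device)
    dist.all_gather_into_tensor(ag_out, inp, group=group)
    return ag_out @ w

# Example setup on each rank (assumed shapes and dtype):
# inp = symm_mem.empty(m_per_rank, k, dtype=torch.bfloat16, device="cuda")
# symm_mem.rendezvous(inp, group=dist.group.WORLD)  # assumed call to establish peer mappings
# w = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
# out_ref = reference_ag_matmul(inp, w, dist.group.WORLD)

The fused kernel is expected to produce the same result as reference_ag_matmul while hiding the gather latency behind the matmul.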

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Argparser

Add test

Signed-off-by: Ke Wen <kwen@nvidia.com>
Collaborator

haijieg commented Feb 13, 2026

Integrated.

@haijieg haijieg closed this Feb 13, 2026
