fix: wait for async P2P send before deallocating output tensor#4047

Draft
ZhiyuLi-Nvidia wants to merge 1 commit into NVIDIA:main from ZhiyuLi-Nvidia:zhiyul/fix/p2p-dealloc-race

Conversation

ZhiyuLi-Nvidia (Contributor) commented Mar 28, 2026

Description

isend() copies asynchronously from the source buffer. Calling deallocate_output_tensor() before the copy completes lets the allocator reuse the buffer while NCCL is still reading it, corrupting activations on the next PP stage.

Add an explicit send_next_wait_handle.wait() before deallocate_output_tensor at both the warmup overlap path and pp_post_forward. Guard on deallocate_pipeline_outputs so the wait is only inserted when the buffer is actually freed.

Root cause of NaN grad norm in deterministic mode on GB200 for LLaMA 3.1 405B (PP=16, VPP=8, overlap_p2p_comm=True).

Made-with: Cursor
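To make the intended ordering concrete, here is a minimal, self-contained sketch of the wait-before-deallocate pattern the fix enforces. All names here (`FakeSendHandle`, `post_forward_send_cleanup`) are illustrative stand-ins, not the actual schedules.py code:

```python
class FakeSendHandle:
    """Stand-in for the NCCL work handle returned by an async isend()."""

    def __init__(self):
        self.completed = False

    def wait(self):
        # In production this blocks until the async copy has finished
        # reading the source buffer; here we just record that it ran.
        self.completed = True


def post_forward_send_cleanup(buffer, send_next_wait_handle,
                              deallocate_pipeline_outputs):
    """Models the guarded wait-before-dealloc ordering described above."""
    if deallocate_pipeline_outputs:
        # Only wait when the buffer will actually be freed; if it is never
        # freed, the allocator cannot reuse it and there is no race.
        if send_next_wait_handle is not None:
            send_next_wait_handle.wait()
        buffer.clear()  # stands in for deallocate_output_tensor()
    return buffer


handle = FakeSendHandle()
buf = post_forward_send_cleanup([1.0] * 4, handle,
                                deallocate_pipeline_outputs=True)
# handle.completed is True by the time the buffer is emptied
```

With `deallocate_pipeline_outputs=False`, the function skips both the wait and the free, mirroring the guard added in the PR.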

What does this PR do ?

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my

Not quite possible to create a good unit test: I only reproduced the issue in a large-scale run, i.e. deterministic mode on GB200 for LLaMA 3.1 405B (PP=16, VPP=8, overlap_p2p_comm=True).

I created a standalone script that mimics the race and clearly demonstrates the fix, but it does not seem well suited as a unit test.

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Regression test for the async P2P send + deallocate race fixed in schedules.py.

Root cause: under TORCH_NCCL_AVOID_RECORD_STREAMS=1 (now the PyTorch default),
NCCL does not record the CUDA stream on a tensor involved in an isend(), so the
caching allocator can immediately reuse the buffer once it is freed.  The 1F1B
overlap path in _forward_backward_pipelining_with_interleaving() called
deallocate_output_tensor() without first waiting for the send handle, creating a
window where the next microbatch's GPU kernel could overwrite the buffer while
NCCL's DMA was still reading it.

The test sets NCCL_P2P_DISABLE=1 to force the host-memory fallback path
(~20 GB/s) rather than NVLink5 direct DMA (~900 GB/s).  On GB200 NVLink5,
NCCL can read the entire 64 MB buffer in <0.1 ms — faster than Python can
free and reallocate — so the race window closes before the overwrite lands.
The host-memory path takes ~6 ms, which is long enough for the receiver-delay
trick to expose the race reliably on any hardware.

test_buggy_dealloc_before_wait_corrupts_data: strict test — expects corruption
  when the buggy pattern is used (dealloc before wait).
test_fixed_wait_before_dealloc_preserves_data: strict correctness guarantee —
  with the fix the receiver must always see the original values.

Run: pytest tests/unit_tests/pipeline_parallel/test_p2p_dealloc_race.py -v
"""

import os
import time
import types

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator
from megatron.core.pipeline_parallel.schedules import deallocate_output_tensor

pytestmark = pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="requires at least 2 GPUs"
)

SHAPE = (2048, 16384)  # ~64 MB BF16


def _make_config(deallocate_pipeline_outputs: bool = True):
    return types.SimpleNamespace(
        variable_seq_lengths=False,
        mtp_standalone=False,
        pipeline_dtype=torch.bfloat16,
        use_ring_exchange_p2p=False,
        batch_p2p_comm=False,
        batch_p2p_sync=False,
        timers=None,
        virtual_pipeline_model_parallel_size=None,
        # Mirror the VPP production path: activations are pseudo-deallocated
        # after each send to reclaim GPU memory across pipeline stages.
        deallocate_pipeline_outputs=deallocate_pipeline_outputs,
    )


def _worker(rank: int, result_queue, port: int, deallocate_pipeline_outputs: bool = True) -> None:
    os.environ['TORCH_NCCL_AVOID_RECORD_STREAMS'] = '1'
    # Force host-memory fallback so the ~6 ms transfer gives the race
    # enough time to fire.  NVLink5 at ~900 GB/s finishes in <0.1 ms and
    # closes the race window before Python can overwrite the buffer.
    os.environ['NCCL_P2P_DISABLE'] = '1'
    dist.init_process_group('nccl', init_method=f'tcp://127.0.0.1:{port}',
                            rank=rank, world_size=2)
    torch.cuda.set_device(rank)

    config = _make_config(deallocate_pipeline_outputs=deallocate_pipeline_outputs)
    comm = P2PCommunicator(dist.group.WORLD, config)

    # RECV_DELAY mirrors production pipeline scheduling: in 1F1B with PP=16 / VPP=8
    # the receiver does not post its irecv immediately after the sender issues isend.
    # There is a scheduling gap during which the sender can (incorrectly) free and
    # reuse the buffer before NCCL has started the actual DMA.  Delaying the
    # receiver here creates that same window deterministically.
    RECV_DELAY = 0.3  # seconds

    if rank == 0:  # sender — mirrors pp_post_forward
        output_tensor = torch.full(SHAPE, 1.0, dtype=torch.bfloat16, device='cuda')

        _, wait_handles = comm.send_forward_recv_forward(
            output_tensor, recv_prev=False, tensor_shape=SHAPE, overlap_p2p_comm=True
        )
        send_handle = wait_handles['send_next']

        # The fix in schedules.py should ensure the handle is waited on before deallocation.
        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

        # Reuse the freed buffer — mirrors the next microbatch's matmul in production.
        overwrite = torch.full(SHAPE, 2.0, dtype=torch.bfloat16, device='cuda')
        torch.cuda.synchronize()
        # Keep the overwrite alive long enough for NCCL to read it after the receiver
        # wakes up and posts its irecv.
        time.sleep(RECV_DELAY + 0.1)

        if send_handle is not None:
            send_handle.wait()
        del overwrite

    else:  # receiver — mirrors pp_pre_forward with pipeline scheduling delay
        # Delay before posting irecv so NCCL cannot start the DMA until after the
        # sender has freed and overwritten the buffer.
        time.sleep(RECV_DELAY)
        recv_tensor, wait_handles = comm.send_forward_recv_forward(
            None, recv_prev=True, tensor_shape=SHAPE, overlap_p2p_comm=True
        )
        wait_handles['recv_prev'].wait()
        result_queue.put((recv_tensor == 1.0).all().item())

    dist.destroy_process_group()


def _find_free_port() -> int:
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def _run(deallocate_pipeline_outputs: bool = True) -> bool:
    port = _find_free_port()
    ctx = mp.get_context('spawn')
    q = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(r, q, port, deallocate_pipeline_outputs)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0
    return q.get_nowait()


def test_p2p_dealloc_race_correctness():
    """Verify that P2P communication is safe from deallocation races.
    
    With deallocate_pipeline_outputs=True, the sender must wait for the NCCL
    transfer to complete before freeing the buffer. If it doesn't, the
    receiver will see corrupted data (2.0 instead of 1.0) because the
    caching allocator reuses the buffer for the next microbatch.
    """
    all_correct = _run(deallocate_pipeline_outputs=True)
    assert all_correct, "Data corrupted! The sender likely freed and reused the buffer before NCCL finished reading it."


def test_no_deallocate_flag_skips_wait_safely():
    """When deallocate_pipeline_outputs=False, the guard in schedules.py skips
    the wait: no buffer is freed, so there is no race regardless."""
    all_correct = _run(deallocate_pipeline_outputs=False)
    assert all_correct, "Data corrupted even without deallocation — unexpected"
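The race the script above exercises can also be illustrated without GPUs or NCCL at all. Below is a hedged, pure-Python model (a thread stands in for NCCL's asynchronous copy, and the names `async_send` and `run` are illustrative): waiting on the "send" before reusing the buffer preserves the data, while skipping the wait lets the overwrite corrupt it.

```python
import threading
import time


def async_send(buf, out, done):
    # Simulated async copy: reads the source buffer after a short delay,
    # like NCCL's DMA engine reading the source of an isend().
    time.sleep(0.05)
    out.extend(buf)
    done.set()


def run(wait_before_reuse):
    buf = [1.0] * 4
    out, done = [], threading.Event()
    t = threading.Thread(target=async_send, args=(buf, out, done))
    t.start()
    if wait_before_reuse:
        done.wait()  # the fix: wait for the send to complete before reuse
    # "Deallocate + reuse": overwrite in place, like the caching allocator
    # handing the freed buffer to the next microbatch's kernel.
    for i in range(len(buf)):
        buf[i] = 2.0
    t.join()
    return all(v == 1.0 for v in out)
```

`run(True)` returns True (receiver sees the original 1.0 values); `run(False)` returns False because the overwrite lands while the simulated copy is still pending, which is exactly the corruption mode the PR fixes.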

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@ZhiyuLi-Nvidia ZhiyuLi-Nvidia requested review from a team as code owners March 28, 2026 10:38

copy-pr-bot bot commented Mar 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft March 28, 2026 10:39
github-actions bot commented

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.
