99 changes: 98 additions & 1 deletion tests/python/multidevice/test_overlap.py
@@ -3,11 +3,15 @@
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import torch
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor, Shard

import nvfuser_direct as nvfuser
from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView
from benchmark_utils import get_benchmark_fns


@pytest.mark.mpi
@@ -101,6 +105,99 @@ def test_row_parallel_linear_forward(multidevice_test):
assert event.input_shapes == [[m, k], [k, n], [m, n]]


def row_parallel_linear_forward_reference(
    inp_shard: torch.Tensor, weight_shard: torch.Tensor, num_chunks: int
) -> torch.Tensor:
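    """Chunked reference implementation of a row-parallel linear forward.

    The local input shard is split into `num_chunks` row blocks. Each block is
    multiplied by the local weight shard and all-reduced on its own CUDA
    stream, so communication for one block can overlap with computation for
    the next.
    """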
    out = torch.empty(
        inp_shard.size(0),
        weight_shard.size(0),
        device="cuda",
        dtype=inp_shard.dtype,
    )
    inp_chunks = inp_shard.chunk(num_chunks)
    out_chunks = out.chunk(num_chunks)

    def wait_stream(stream: torch.cuda.Stream) -> None:
        event = torch.cuda.Event()
        stream.record_event(event)
        torch.cuda.current_stream().wait_event(event)

    main_stream = torch.cuda.current_stream()
    worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)]
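    # Each chunk's matmul and all-reduce are issued on their own worker stream,
    # so the all-reduce of one chunk can overlap with the matmul of the next.
    # wait_stream makes the worker stream wait for work already queued on the
    # main stream before the chunk's matmul is launched.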
    for inp_chunk, out_chunk, worker_stream in zip(
        inp_chunks, out_chunks, worker_streams
    ):
        with torch.cuda.stream(worker_stream):
            wait_stream(main_stream)
            torch.matmul(inp_chunk, weight_shard.T, out=out_chunk)
            work = dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True)
            work.wait()

    for worker_stream in worker_streams:
        wait_stream(worker_stream)

    return out


@pytest.mark.mpi
def test_row_parallel_linear_forward_reference(setup_default_process_group):
Comment on lines +142 to +143
Collaborator:
Instead of separate functions, how about using a parameter to indicate if it is a validation/benchmarking run?

Collaborator (Author):
I considered that but ended up preferring DAMP to DRY for this particular case. I found fewer things to share than I expected between the test and the benchmark -- the input tensors have to be created differently, execution needs to be measured for the benchmark, and validation is skipped for the benchmark.

Collaborator:
Oh that's right, I overlooked the initialization difference.

    h, s, t = 2, 3, 6
    d = dist.get_world_size()
    if (h * 4) % d != 0:
        pytest.skip(
            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
        )
    assert t % s == 0

    torch.manual_seed(0)
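    # Integer-valued tensors cast to bfloat16 keep the matmul exact, so the
    # sharded result can be compared against the unsharded reference without
    # floating-point mismatch.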
    inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
        torch.bfloat16
    )
    weight_ref = torch.testing.make_tensor(
        h, h * 4, dtype=torch.int32, device="cpu"
    ).to(torch.bfloat16)
    out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu()

    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
    weight_shard = distribute_tensor(
        weight_ref, mesh, placements=[Shard(-1)]
    ).to_local()
    out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s)

    torch.testing.assert_close(out.cpu(), out_ref)


@pytest.mark.mpi
@pytest.mark.benchmark
def test_row_parallel_linear_forward_reference_benchmark(
    setup_default_process_group, benchmark
):
    h, s, t = 8192, 2, 8192
    d = dist.get_world_size()
    if (h * 4) % d != 0:
        pytest.skip(
            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
        )
    assert t % s == 0

    torch.manual_seed(0)
    inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16)
    weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16)

    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
    weight_shard = distribute_tensor(
        weight_ref, mesh, placements=[Shard(-1)]
    ).to_local()

    warmup_fn, benchmark_fn = get_benchmark_fns(
        lambda: row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
    )
    warmup_fn()
    benchmark.pedantic(benchmark_fn, rounds=5)
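
The thread above discusses merging the validation test and the benchmark into one parameterized function. Purely as a hypothetical sketch of that alternative (not part of this PR; the function name and its benchmark parameter are invented for illustration), the three differences the author cites -- input creation, measurement, and validation -- would become branches:

def _row_parallel_linear_forward_case(benchmark=None):
    is_benchmark = benchmark is not None
    h, s, t = (8192, 2, 8192) if is_benchmark else (2, 3, 6)
    d = dist.get_world_size()
    if (h * 4) % d != 0:
        pytest.skip(
            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
        )
    assert t % s == 0

    torch.manual_seed(0)
    if is_benchmark:
        # Large random inputs; exactness is irrelevant because validation is skipped.
        inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16)
        weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16)
    else:
        # Small integer-valued inputs so the result can be compared exactly.
        inp_ref = torch.testing.make_tensor(
            t, h * 4, dtype=torch.int32, device="cpu"
        ).to(torch.bfloat16)
        weight_ref = torch.testing.make_tensor(
            h, h * 4, dtype=torch.int32, device="cpu"
        ).to(torch.bfloat16)

    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
    weight_shard = distribute_tensor(
        weight_ref, mesh, placements=[Shard(-1)]
    ).to_local()

    if is_benchmark:
        warmup_fn, benchmark_fn = get_benchmark_fns(
            lambda: row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
        )
        warmup_fn()
        benchmark.pedantic(benchmark_fn, rounds=5)
    else:
        out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
        out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu()
        torch.testing.assert_close(out.cpu(), out_ref)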


@pytest.mark.mpi
@pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
@pytest.mark.parametrize("s", [1, 8])