Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from torch._functorch.aot_autograd import (
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
boxed_nop_preserve_node_meta,
)
from torch._inductor.compile_fx import compile_fx_inner
from torch._logging import trace_structured
Expand Down Expand Up @@ -62,6 +61,29 @@
logger = logging.getLogger(__name__)


def _boxed_nop_preserve_node_meta(fx_g, example_inputs):
if torch._inductor.config.aten_distributed_optimizations.enable_overlap_scheduling:
from torch._inductor.fx_passes.overlap_scheduling import (
schedule_overlap_bucketing_from_inductor_configs,
)

# disable flags which are inductor-specific
with torch._inductor.config.patch(
{
"aten_distributed_optimizations.insert_overlap_deps": False,
"aten_distributed_optimizations.enable_fusion_regions": False,
}
):
schedule_overlap_bucketing_from_inductor_configs(fx_g)

def run(args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(fx_g).boxed_run(args)

run._boxed_call = True
return run


@contextmanager
def _suppress_wait_tensor_side_effect():
"""Temporarily remove wait_tensor from the side-effectful set.
Expand Down Expand Up @@ -200,7 +222,7 @@ def __init__(
debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
)
else:
self.compiler_fn = boxed_nop_preserve_node_meta # type: ignore[assignment]
self.compiler_fn = _boxed_nop_preserve_node_meta # type: ignore[assignment]
self.enable_ac = enable_ac
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
self.reshard_after_forward = reshard_after_forward
Expand Down
126 changes: 126 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
import torch._inductor.config
import torch.fx.traceback as fx_traceback
from torch import nn
from torch.distributed.tensor import DTensor
Expand Down Expand Up @@ -757,3 +758,128 @@ def forward(self, x):
assert isinstance(parallel_mod, Model)
assert isinstance(parallel_mod, BaseModel)
assert parallel_mod.get_num_params() > 0


# Tests for overlap scheduling in compile=False path


def test_overlap_scheduling_called_when_enabled(device_mesh_1d):
"""Test that schedule_overlap_bucketing_from_inductor_configs is called
when compile=False and enable_overlap_scheduling=True.
"""
from unittest.mock import patch

dim = 128

class Model(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)

def forward(self, x):
return self.linear(x)

def input_fn():
b = 512
return (torch.rand(b, dim, device="cuda"),)

with torch.device("meta"):
model = Model(dim)

overlap_bucketing_called = []

with patch(
"torch._inductor.fx_passes.overlap_scheduling.schedule_overlap_bucketing_from_inductor_configs",
side_effect=lambda fx_g: overlap_bucketing_called.append(fx_g) or fx_g,
), torch._inductor.config.patch(
{"aten_distributed_optimizations.enable_overlap_scheduling": True}
):
with AutoParallel(
model,
input_fn,
device_mesh_1d,
compile=False,
) as autop:
autop.add_input_constraints([(Shard(0),)])
sharding_placement = autop.optimize_placement()
_ = autop.apply_placement(sharding_placement)

assert (
len(overlap_bucketing_called) > 0
), "schedule_overlap_bucketing_from_inductor_configs should be called"

for fx_g in overlap_bucketing_called:
assert isinstance(fx_g, torch.fx.GraphModule)


def test_overlap_scheduling_not_called_when_disabled(device_mesh_1d):
"""Test that overlap scheduling is skipped when enable_overlap_scheduling=False (default)."""
from unittest.mock import patch

dim = 128

class Model(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)

def forward(self, x):
return self.linear(x)

def input_fn():
b = 512
return (torch.rand(b, dim, device="cuda"),)

with torch.device("meta"):
model = Model(dim)

overlap_bucketing_called = []

with patch(
"torch._inductor.fx_passes.overlap_scheduling.schedule_overlap_bucketing_from_inductor_configs",
side_effect=lambda fx_g: overlap_bucketing_called.append(fx_g) or fx_g,
):
with AutoParallel(
model,
input_fn,
device_mesh_1d,
compile=False,
) as autop:
autop.add_input_constraints([(Shard(0),)])
sharding_placement = autop.optimize_placement()
_ = autop.apply_placement(sharding_placement)

assert (
len(overlap_bucketing_called) == 0
), "schedule_overlap_bucketing_from_inductor_configs should NOT be called by default"


def test_compile_true_uses_compile_fx_inner(device_mesh_1d):
"""When compile=True, the compiler_fn should be compile_fx_inner."""
dim = 128

class Model(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)

def forward(self, x):
return self.linear(x)

def input_fn():
b = 512
return (torch.rand(b, dim, device="cuda"),)

with torch.device("meta"):
model = Model(dim)

from torch._inductor.compile_fx import compile_fx_inner

autop = AutoParallel(
model,
input_fn,
device_mesh_1d,
compile=True,
)

assert autop.compiler_fn is compile_fx_inner
Loading