Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion autoparallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,11 @@

from autoparallel.api import AutoParallel, auto_parallel
from autoparallel.api_pp import AutoParallelPP
from autoparallel.collectives import with_sharding_constraint

__all__ = ["auto_parallel", "AutoParallel", "AutoParallelPP"]
__all__ = [
"auto_parallel",
"AutoParallel",
"AutoParallelPP",
"with_sharding_constraint",
]
48 changes: 46 additions & 2 deletions autoparallel/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,57 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional
from typing import Any, Optional, Tuple

import torch
import torch.distributed.distributed_c10d as c10d
from torch.distributed._tensor.experimental import local_map as _local_map
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.device_mesh import DeviceMesh, _mesh_resources
from torch.distributed.distributed_c10d import GroupName
from torch.distributed.tensor.placement_types import Placement


def with_sharding_constraint(
    x: torch.Tensor,
    shardings: Tuple[Placement, ...],
    device_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    """Constrain the sharding of an intermediate tensor.

    Analogous to JAX's ``with_sharding_constraint``: forces a specific
    placement onto a tensor so that intermediate shardings inside a
    computation can be controlled explicitly.

    Args:
        x: The tensor to constrain.
        shardings: Tuple of placements, one per mesh dimension, describing
            the desired sharding of ``x``.
        device_mesh: Mesh to shard over. When ``None``, the mesh of the
            enclosing local_map region is used.

    Returns:
        The tensor with the requested sharding constraint applied.

    Example:
        >>> from torch.distributed.tensor.placement_types import Shard, Replicate
        >>> # Inside a local_map region or with explicit mesh:
        >>> x = with_sharding_constraint(x, (Shard(0), Replicate()))
    """
    mesh = get_mesh_from_global() if device_mesh is None else device_mesh

    def _identity(t):
        # clone() is required because the local_map HOP doesn't support
        # input-to-output aliasing during dynamo tracing.
        return t.clone()

    # Wrap a no-op in local_map so the only effect is the (re)distribution
    # implied by the requested placements.
    constrained = _local_map(
        out_placements=(shardings,),
        in_placements=(shardings,),
        redistribute_inputs=True,
        device_mesh=mesh,
    )(_identity)
    return constrained(x)


def local_map(*args, **kwargs):
Expand Down
257 changes: 257 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,267 @@

import pytest
import torch
from torch import nn
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel import AutoParallel, with_sharding_constraint
from autoparallel.collectives import local_map
from autoparallel.ops import permutation


def get_local_map_nodes(graph, is_backward=False):
    """Return graph nodes that carry local_map metadata.

    Only nodes whose ``local_map_kwargs`` meta entry is present are
    considered; they are then split by pass, using the ``partitioner_tag``
    meta entry to decide whether a node belongs to the backward graph.
    """
    selected = []
    for candidate in graph.nodes:
        if "local_map_kwargs" not in candidate.meta:
            continue
        tagged_backward = candidate.meta.get("partitioner_tag", "") == "is_backward"
        if tagged_backward == is_backward:
            selected.append(candidate)
    return selected


def verify_local_map_placements(sharding_placement, node, expected_placements):
    """Assert that the solver assigned *expected_placements* to *node*.

    ``output_specs`` may be a single spec or a tuple of specs; only the
    first output's placements are checked.
    """
    strategy = sharding_placement[node]
    specs = strategy.output_specs
    chosen = specs[0] if isinstance(specs, tuple) else specs
    actual = chosen.placements
    assert (
        actual == expected_placements
    ), f"Expected placements {expected_placements}, got {actual}"


class TestWithShardingConstraint:
    """Tests for the with_sharding_constraint operator."""

    def test_with_sharding_constraint_explicit_mesh(self, device_mesh_1d):
        """with_sharding_constraint works with an explicit device mesh."""
        dim = 128

        class Net(nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.linear1 = nn.Linear(dim, dim, bias=False)
                self.linear2 = nn.Linear(dim, dim, bias=False)

            def forward(self, x):
                h = self.linear1(x)
                # Pin the intermediate activation to a sharded layout.
                h = with_sharding_constraint(h, (Shard(0),), device_mesh_1d)
                return self.linear2(h)

        def make_inputs():
            return torch.rand(512, dim, device="cuda")

        with torch.device("meta"):
            net = Net(dim)

        with AutoParallel(net, make_inputs, device_mesh_1d) as autop:
            autop.add_input_constraints([(Shard(0),)])
            autop.add_output_constraints([(Shard(0),)])
            sharding_placement = autop.optimize_placement()

            # The constraint should surface as exactly one forward local_map
            # node carrying the requested placement.
            fwd_nodes = get_local_map_nodes(autop.gm.graph, is_backward=False)
            assert len(fwd_nodes) == 1, "Expected 1 forward local_map node"
            verify_local_map_placements(sharding_placement, fwd_nodes[0], (Shard(0),))

            parallelized = autop.apply_placement(sharding_placement)

            assert parallelized is not None

    def test_with_sharding_constraint_between_local_maps(self, device_mesh_1d):
        """with_sharding_constraint composes with surrounding local_map regions."""
        dim = 128

        @local_map(
            out_placements=((Shard(0),),),
            in_placements=((Shard(0),),),
            redistribute_inputs=True,
            device_mesh=device_mesh_1d,
        )
        def add_one(x):
            return x + 1

        @local_map(
            out_placements=((Shard(0),),),
            in_placements=((Shard(0),),),
            redistribute_inputs=True,
            device_mesh=device_mesh_1d,
        )
        def times_two(x):
            return x * 2

        class Net(nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.linear = nn.Linear(dim, dim, bias=False)

            def forward(self, x):
                h = self.linear(x)
                h = add_one(h)
                # Constraint applied between local_map regions (at DTensor level)
                h = with_sharding_constraint(h, (Shard(0),), device_mesh_1d)
                return times_two(h)

        def make_inputs():
            return torch.rand(512, dim, device="cuda")

        with torch.device("meta"):
            net = Net(dim)

        with AutoParallel(net, make_inputs, device_mesh_1d) as autop:
            autop.add_input_constraints([(Shard(0),)])
            autop.add_output_constraints([(Shard(0),)])
            sharding_placement = autop.optimize_placement()

            # Three forward local_map nodes are expected: add_one, the
            # sharding constraint itself, and times_two.
            fwd_nodes = get_local_map_nodes(autop.gm.graph, is_backward=False)
            assert (
                len(fwd_nodes) == 3
            ), f"Expected 3 forward local_map nodes, got {len(fwd_nodes)}"
            for node in fwd_nodes:
                verify_local_map_placements(sharding_placement, node, (Shard(0),))

            parallelized = autop.apply_placement(sharding_placement)

            assert parallelized is not None

    def test_with_sharding_constraint_replicate(self, device_mesh_1d):
        """with_sharding_constraint can force an intermediate to replicate."""
        dim = 128

        class Net(nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.linear1 = nn.Linear(dim, dim, bias=False)
                self.linear2 = nn.Linear(dim, dim, bias=False)

            def forward(self, x):
                h = self.linear1(x)
                # Force the intermediate activation to be replicated.
                h = with_sharding_constraint(h, (Replicate(),), device_mesh_1d)
                return self.linear2(h)

        def make_inputs():
            return torch.rand(512, dim, device="cuda")

        with torch.device("meta"):
            net = Net(dim)

        with AutoParallel(net, make_inputs, device_mesh_1d) as autop:
            autop.add_input_constraints([(Shard(0),)])
            autop.add_output_constraints([(Shard(0),)])
            sharding_placement = autop.optimize_placement()

            # The solver must honor the Replicate constraint on the node.
            fwd_nodes = get_local_map_nodes(autop.gm.graph, is_backward=False)
            assert len(fwd_nodes) == 1, "Expected 1 forward local_map node"
            verify_local_map_placements(
                sharding_placement, fwd_nodes[0], (Replicate(),)
            )

            parallelized = autop.apply_placement(sharding_placement)

            assert parallelized is not None

    def test_with_sharding_constraint_2d_mesh(self, device_mesh_2d):
        """with_sharding_constraint accepts per-dimension placements on a 2D mesh."""
        dim = 128

        class Net(nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.linear1 = nn.Linear(dim, dim, bias=False)
                self.linear2 = nn.Linear(dim, dim, bias=False)

            def forward(self, x):
                h = self.linear1(x)
                # Shard along batch dim on dp, replicate on tp.
                h = with_sharding_constraint(h, (Shard(0), Replicate()), device_mesh_2d)
                return self.linear2(h)

        def make_inputs():
            return torch.rand(512, dim, device="cuda")

        with torch.device("meta"):
            net = Net(dim)

        with AutoParallel(net, make_inputs, device_mesh_2d) as autop:
            autop.add_input_constraints([(Shard(0), Replicate())])
            autop.add_output_constraints([(Shard(0), Replicate())])
            sharding_placement = autop.optimize_placement()

            # The node must carry the requested 2D placement.
            fwd_nodes = get_local_map_nodes(autop.gm.graph, is_backward=False)
            assert len(fwd_nodes) == 1, "Expected 1 forward local_map node"
            verify_local_map_placements(
                sharding_placement, fwd_nodes[0], (Shard(0), Replicate())
            )

            parallelized = autop.apply_placement(sharding_placement)

            assert parallelized is not None

    def test_with_sharding_constraint_multiple(self, device_mesh_1d):
        """Several with_sharding_constraint calls in sequence are all honored."""
        dim = 128

        class Net(nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.linear1 = nn.Linear(dim, dim, bias=False)
                self.linear2 = nn.Linear(dim, dim, bias=False)
                self.linear3 = nn.Linear(dim, dim, bias=False)

            def forward(self, x):
                h = self.linear1(x)
                h = with_sharding_constraint(h, (Shard(0),), device_mesh_1d)
                h = self.linear2(h)
                h = with_sharding_constraint(h, (Replicate(),), device_mesh_1d)
                h = self.linear3(h)
                return with_sharding_constraint(h, (Shard(0),), device_mesh_1d)

        def make_inputs():
            return torch.rand(512, dim, device="cuda")

        with torch.device("meta"):
            net = Net(dim)

        with AutoParallel(net, make_inputs, device_mesh_1d) as autop:
            autop.add_input_constraints([(Shard(0),)])
            autop.add_output_constraints([(Shard(0),)])
            sharding_placement = autop.optimize_placement()

            fwd_nodes = get_local_map_nodes(autop.gm.graph, is_backward=False)
            assert (
                len(fwd_nodes) == 3
            ), f"Expected 3 forward local_map nodes, got {len(fwd_nodes)}"

            # Nodes appear in program order: Shard(0), Replicate(), Shard(0).
            expected = [(Shard(0),), (Replicate(),), (Shard(0),)]
            for node, placements in zip(fwd_nodes, expected):
                verify_local_map_placements(sharding_placement, node, placements)

            parallelized = autop.apply_placement(sharding_placement)

            assert parallelized is not None

    def test_with_sharding_constraint_no_mesh_outside_local_map_raises(self):
        """Calling with_sharding_constraint with no mesh in scope raises."""
        t = torch.rand(10, 10)
        with pytest.raises(RuntimeError, match="No device mesh is currently active"):
            with_sharding_constraint(t, (Shard(0),))


class TestPermutation:
def test_shape_preserved(self):
"""Permutation should preserve tensor shape."""
Expand Down
Loading