autoparallel/shardings/dtensor_sharding_helpers.py (32 additions, 0 deletions)
@@ -210,6 +210,29 @@ def batch_shard_strategy(
    return output_strategy


def _try_decomp_sharding(
    op: torch._ops.OpOverload, op_schema: OpSchema
) -> Optional[StrategyType]:
    """
    Attempt to derive sharding strategies for an op via decomposition tracing.

    Uses PyTorch's DecompShardingStrategy: traces the op's decomposition on meta
    tensors, propagating placements through sub-ops that already have strategies.
    Returns an OpStrategy if successful, None otherwise.
    """
    from torch.distributed.tensor._decompositions import DecompShardingStrategy

    if not DecompShardingStrategy.has_decomp(op):
        return None

    decomp_strategy = propagator.decomp_strategy
    decomp_strategy.ensure_schema_info(op)
    result = decomp_strategy.propagate_strategy(op_schema)
    if result is not None:
        logger.info(f"derived sharding strategy for `{op}` via decomposition")
    return result
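
A minimal, torch-free sketch of the idea behind _try_decomp_sharding (the rule tables and helper below are hypothetical stand-ins, not the DecompShardingStrategy API): an op with no sharding rule of its own borrows the rules of the sub-ops it decomposes into, and the caller falls back only when no decomposition or sub-op rule exists.

from typing import Callable, Optional

SHARDING_RULES: dict[str, Callable[[str], str]] = {
    # Toy placement-propagation rules for primitive ops.
    "sum": lambda placement: "Partial" if placement == "Shard(0)" else placement,
    "div": lambda placement: placement,
}

DECOMPOSITIONS: dict[str, list[str]] = {
    # "mean" has no rule of its own but decomposes into sum followed by div.
    "mean": ["sum", "div"],
}

def try_decomp_sharding_toy(op: str, placement: str) -> Optional[str]:
    """Propagate a placement through an op's decomposition, if one exists."""
    sub_ops = DECOMPOSITIONS.get(op)
    if sub_ops is None:
        return None  # no decomposition registered; caller falls back
    for sub_op in sub_ops:
        rule = SHARDING_RULES.get(sub_op)
        if rule is None:
            return None  # a sub-op also lacks a rule; give up
        placement = rule(placement)
    return placement

# A row-sharded input yields a Partial result instead of forcing Replicate:
assert try_decomp_sharding_toy("mean", "Shard(0)") == "Partial"
assert try_decomp_sharding_toy("softmax", "Shard(0)") is None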


def _try_single_dim_strategy(
    op: torch._ops.OpOverload, op_schema: OpSchema
) -> Optional[StrategyType]:
@@ -301,6 +324,15 @@ def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
f"Operator {op} does not have a sharding strategy registered."
)
else:
        # First, try to derive strategies via decomposition tracing.
        # This produces richer strategies than the all-Replicate
        # fallback by tracing the op's decomposition into sub-ops that
        # already have sharding rules.
        decomp_result = _try_decomp_sharding(op, op_schema)
        if decomp_result is not None:
            return decomp_result

        # Fall back to all-Replicate.
        # Use the current stack if available.
        if _current_stack is not None:
            _current_stack.enter_context(
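
Taken together, the lookup order get_op_strategy ends up with is: registered strategy, then decomposition-derived strategy, then all-Replicate. A self-contained toy sketch of that control flow (the registry, helper, and result strings below are hypothetical stand-ins, not the real types):

from typing import Any, Callable, Optional

REGISTRY: dict[str, Callable[[Any], str]] = {
    "aten.mm": lambda schema: "registered strategy",
}

def _try_decomp_toy(op: str, schema: Any) -> Optional[str]:
    # Pretend only aten.mean has a usable decomposition.
    return "decomp-derived strategy" if op == "aten.mean" else None

def get_op_strategy_sketch(op: str, schema: Any, strict: bool = False) -> str:
    if op in REGISTRY:
        return REGISTRY[op](schema)  # 1. an explicitly registered strategy wins
    if strict:
        raise NotImplementedError(
            f"Operator {op} does not have a sharding strategy registered."
        )
    decomp = _try_decomp_toy(op, schema)  # 2. try decomposition tracing
    if decomp is not None:
        return decomp
    return "all-Replicate fallback"  # 3. last resort: replicate everything

assert get_op_strategy_sketch("aten.mm", None) == "registered strategy"
assert get_op_strategy_sketch("aten.mean", None) == "decomp-derived strategy"
assert get_op_strategy_sketch("aten.unknown", None) == "all-Replicate fallback"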