From 169b7b51bb7291ca7fd54814546eade980e88fc1 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:55:04 +0000 Subject: [PATCH 1/4] Fixes sharded DiT example with FSDP --- physicsnemo/domain_parallel/__init__.py | 11 +- physicsnemo/domain_parallel/shard_tensor.py | 589 ++++++++++-------- .../domain_parallel/shard_utils/__init__.py | 2 +- .../domain_parallel/shard_utils/view_ops.py | 4 +- 4 files changed, 335 insertions(+), 271 deletions(-) diff --git a/physicsnemo/domain_parallel/__init__.py b/physicsnemo/domain_parallel/__init__.py index c3c16d14fa..af43281c2d 100644 --- a/physicsnemo/domain_parallel/__init__.py +++ b/physicsnemo/domain_parallel/__init__.py @@ -47,7 +47,13 @@ # In minumum versions are met, we can import the shard tensor and spec. from ._shard_tensor_spec import ShardTensorSpec - from .shard_tensor import ShardTensor, scatter_tensor + from .shard_tensor import ( + FSDPOutputTensorAdapter, + ShardTensor, + distribute_over_domain_for_fsdp, + scatter_tensor, + wrap_for_fsdp, + ) def register_custom_ops(): # These imports will register the custom ops with the ShardTensor class. @@ -69,3 +75,6 @@ def register_custom_ops(): ShardTensor = None ShardTensorSpec = None scatter_tensor = None + distribute_over_domain_for_fsdp = None + FSDPOutputTensorAdapter = None + wrap_for_fsdp = None diff --git a/physicsnemo/domain_parallel/shard_tensor.py b/physicsnemo/domain_parallel/shard_tensor.py index 0a31ca7253..74d7a84e50 100644 --- a/physicsnemo/domain_parallel/shard_tensor.py +++ b/physicsnemo/domain_parallel/shard_tensor.py @@ -16,14 +16,17 @@ from __future__ import annotations +import threading from collections.abc import Iterable, Mapping +from contextlib import contextmanager from typing import Callable, Sequence, cast from warnings import warn import torch import torch.distributed as dist +from torch import nn from torch.distributed.device_mesh import DeviceMesh, _mesh_resources -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, distribute_module from torch.distributed.tensor._dtensor_spec import ( TensorMeta, ) @@ -42,26 +45,22 @@ _infer_shard_tensor_spec_from_local_chunks, _stride_from_contiguous_shape_C_style, ) -from physicsnemo.utils.profiling import annotate, profile aten = torch.ops.aten -def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: - r"""Convert a ShardTensor to a plain DTensor for dispatch. +# ====================================================================== - Creates a DTensor with the same internal state as the ShardTensor, - which allows DTensor's dispatch to handle it correctly. +# ============================================================================ +# Layer 1 -- Semi-private conversions (no autograd, no spec inference) +# ============================================================================ - Parameters - ---------- - st : ShardTensor - The ShardTensor to convert. - Returns - ------- - DTensor - A DTensor sharing the same ``_local_tensor`` and ``_spec``. +def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: + r"""Convert a ShardTensor to a plain DTensor (no autograd). + + Creates a DTensor sharing the same ``_local_tensor`` and ``_spec``. + Use for dispatch or inside backward when building a DTensor gradient. """ dtensor = torch.Tensor._make_wrapper_subclass( DTensor, @@ -77,31 +76,250 @@ def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: return dtensor -def _convert_args_to_dtensor(arg: object) -> object: - r"""Recursively convert ShardTensors in args to DTensors. +def _dtensor_to_shard_tensor(dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor": + r"""Promote a DTensor to a ShardTensor (no autograd). - Parameters - ---------- - arg : object - A single argument that may be a ShardTensor, an iterable of - arguments (e.g. list, tuple), a mapping (e.g. dict) whose - values are converted, or any other value. + Callers must supply a resolved ``spec``. Use inside backward (with spec + from ctx) or after resolving a spec via :func:`_resolve_spec_for_dtensor`. + """ + if isinstance(dtensor, ShardTensor): + # Shortcut if we're already a ShardTensor: + return dtensor + st = ShardTensor.__new__( + ShardTensor, + local_tensor=dtensor._local_tensor, + spec=spec, + requires_grad=dtensor.requires_grad, + ) + return st - Returns - ------- - object - The argument with any ShardTensors replaced by DTensors. + +# ============================================================================ +# Layer 2 -- Autograd Functions (use Layer 1 inside fwd / bwd) +# ============================================================================ + + +class _DTensorToShardTensor(torch.autograd.Function): + r"""Differentiable promotion: DTensor -> ShardTensor. + + This is to always connect the graphs for the backward pass + when we have to use a fallback option. + + Forward: :func:`_dtensor_to_shard_tensor`. + Backward: :func:`_shard_tensor_to_dtensor`. + """ + + @staticmethod + def forward(ctx, dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor": + return _dtensor_to_shard_tensor(dtensor, spec) + + @staticmethod + def backward(ctx, grad_output: "ShardTensor"): + return _shard_tensor_to_dtensor(grad_output), None + + +class _ShardTensorToDTensor(torch.autograd.Function): + r"""Differentiable conversion: ShardTensor -> DTensor. + + This is to always connect the graphs for the backward pass + when we have to use a fallback option. + + Forward: :func:`_shard_tensor_to_dtensor` (caches spec). + Backward: :func:`_dtensor_to_shard_tensor` (reuses cached spec). """ - # ShardTensor is defined later in this module; the isinstance check - # is safe because this function is only called at runtime. - if isinstance(arg, ShardTensor): - return _shard_tensor_to_dtensor(arg) - elif isinstance(arg, Mapping): - return type(arg)({k: _convert_args_to_dtensor(v) for k, v in arg.items()}) - elif isinstance(arg, Iterable) and not isinstance(arg, (str, bytes)): - converted = [_convert_args_to_dtensor(a) for a in arg] - return type(arg)(converted) - return arg + + @staticmethod + def forward(ctx, st: "ShardTensor") -> DTensor: + ctx.shard_tensor_spec = st._spec + return _shard_tensor_to_dtensor(st) + + @staticmethod + def backward(ctx, grad_output: DTensor): + return (_dtensor_to_shard_tensor(grad_output, ctx.shard_tensor_spec),) + + +# ============================================================================ +# Layer 3 -- Smart single-tensor converters (auto-diff when grad_fn present) +# ============================================================================ + + +def _resolve_spec_for_dtensor( + dtensor: DTensor, input_args: tuple = () +) -> ShardTensorSpec: + r"""Resolve a ShardTensorSpec for *dtensor*. + + Tries to reuse a spec from a ShardTensor in *input_args* whose + ``tensor_meta`` and ``placements`` match. Falls back to chunk-based + inference (no communication). + """ + for arg in input_args: + if ( + isinstance(arg, ShardTensor) + and dtensor._spec.tensor_meta == arg._spec.tensor_meta + and dtensor._spec.placements == arg._spec.placements + ): + return arg._spec + return _infer_shard_tensor_spec_from_local_chunks( + dtensor._local_tensor, + dtensor._spec.mesh, + dtensor._spec.placements, + sharding_shapes="chunk", + global_shape=dtensor.shape, + ) + + +# This is a thread-safe reentry guard. +# Goal is to prevent recursion into the fall back conversion paths. +# Here's the scenario we're preventing: +# 1. A ShardTensor needs to use the DTensor path in a torch_function level call. +# This will enter torch_function for ShardTensor and trigger the fallback path. +# 2. Because that builds the autograd graph, the conversion from ShardTensor to DTensor +# must be differentiable. +# 3. The conversion path itself will call _ShardTensorToDTensor, which will enter +# torch_function for ShardTensor. +# 4. There is no overload for converting ShardTensor to DTensor, so it will +# enter the fallback conversion path +# 5. Infinite recursion / profit. +_conversion_guard = threading.local() + + +def _conversion_active() -> bool: + r"""Return whether ShardTensor<->DTensor conversion is currently active.""" + return getattr(_conversion_guard, "depth", 0) > 0 + + +@contextmanager +def _conversion_scope(): + r"""Re-entrant conversion guard for cast-down/cast-up paths.""" + previous_depth = getattr(_conversion_guard, "depth", 0) + _conversion_guard.depth = previous_depth + 1 + try: + yield + finally: + if previous_depth == 0: + delattr(_conversion_guard, "depth") + else: + _conversion_guard.depth = previous_depth + + +def _convert_st_to_dt(st: "ShardTensor") -> DTensor: + r"""ShardTensor -> DTensor; differentiable when *st* is non-leaf.""" + with _conversion_scope(): + if st.requires_grad and st.grad_fn is not None: + return _ShardTensorToDTensor.apply(st) + return _shard_tensor_to_dtensor(st) + + +def _convert_dt_to_st(dtensor: DTensor, input_args: tuple = ()) -> "ShardTensor": + r"""DTensor -> ShardTensor; differentiable when *dtensor* is non-leaf. + + Resolves spec, then uses Layer 2 or Layer 1 depending on whether the + DTensor carries a ``grad_fn``. + """ + if isinstance(dtensor, ShardTensor): + return dtensor + with _conversion_scope(): + spec = _resolve_spec_for_dtensor(dtensor, input_args) + if dtensor.grad_fn is not None: + return _DTensorToShardTensor.apply(dtensor, spec) + res = _dtensor_to_shard_tensor(dtensor, spec) + return res + + +def _dispatch_fallback_via_dtensor( + func: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object] | None = None, +) -> object: + r"""Execute an ATen op through DTensor fallback and promote results back.""" + with _conversion_scope(): + converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) + converted_kwargs = { + k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() + } + dispatch_res = DTensor._op_dispatcher.dispatch( + func, converted_args, converted_kwargs + ) + with _conversion_scope(): + return _convert_results_to_shard_tensor(dispatch_res, args) + + +def _torch_function_fallback_via_dtensor( + func: Callable, + args: tuple[object, ...], + kwargs: dict[str, object] | None = None, +) -> object: + r"""Execute a ``__torch_function__`` fallback through DTensor safely. + + The fallback call itself is wrapped in ``DisableTorchFunctionSubclass`` to + avoid re-entering tensor-subclass ``__torch_function__`` while still + allowing autograd to record DTensor ops. + """ + + with _conversion_scope(): + # Here, we take args and kwargs and push all ShardTensors to DTensors. + # Other args are left as is. + converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) + converted_kwargs = { + k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() + } + with torch._C.DisableTorchFunctionSubclass(): + result = func(*converted_args, **converted_kwargs) + with _conversion_scope(): + # The output results promote any DTensor results back to ShardTensors + converted_result = _convert_results_to_shard_tensor(result, args) + return converted_result + + +# ============================================================================ +# Layer 4 -- Recurse utilities (walk args / kwargs / results) +# ============================================================================ + + +def _convert_args_to_dtensor(arg: object) -> object: + r"""Recursively replace ShardTensors with DTensors in a single arg. + + Walks mappings, tuples, and lists. Each ShardTensor is converted via + :func:`_convert_st_to_dt`. + """ + match arg: + case ShardTensor(): + return _convert_st_to_dt(arg) + case DTensor(): + # DTensor can be iterable; exit early deliberatly + return arg + case Mapping(): + return type(arg)({k: _convert_args_to_dtensor(v) for k, v in arg.items()}) + case tuple(): + return tuple(_convert_args_to_dtensor(a) for a in arg) + case list(): + return [_convert_args_to_dtensor(a) for a in arg] + case _: + return arg + + +def _convert_results_to_shard_tensor(result: object, input_args: tuple) -> object: + r"""Recursively replace DTensors with ShardTensors in an op result. + + Walks single tensor, mappings, and iterables (excluding str/bytes). + Each DTensor is converted via :func:`_convert_dt_to_st`. + """ + if isinstance(result, DTensor): + res = _convert_dt_to_st(result, input_args) + return res + if isinstance(result, Mapping): + return type(result)( + { + k: _convert_dt_to_st(v, input_args) if isinstance(v, DTensor) else v + for k, v in result.items() + } + ) + if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): + return type(result)( + _convert_dt_to_st(d, input_args) if isinstance(d, DTensor) else d + for d in result + ) + return result class _ToTorchTensor(torch.autograd.Function): @@ -296,73 +514,6 @@ def backward( return grad_output.to_local(), None, None, None -class _PromoteDTensorToShardTensor(torch.autograd.Function): - r"""Autograd function to promote a DTensor to a ShardTensor while preserving ``grad_fn``. - - When DTensor's ``__torch_function__`` returns a non-leaf DTensor (one that - has a ``grad_fn``), creating a new ShardTensor via ``_make_wrapper_subclass`` - always produces a leaf — disconnecting it from the autograd graph. - - This function bridges that gap: the forward creates the ShardTensor wrapper, - and ``apply`` attaches a ``grad_fn`` that connects it back to the original - DTensor's graph. The backward simply passes gradients through unchanged. - - This is only used at the ``__torch_function__`` level where the DTensor - result already carries autograd state. At the ``__torch_dispatch__`` level, - promotion is safe without this because autograd wraps the result afterwards. - """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - dtensor: DTensor, - spec: "ShardTensorSpec", - ) -> "ShardTensor": - r"""Create a ShardTensor from a DTensor, preserving autograd via ``apply``. - - Parameters - ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context (unused — no state needed for backward). - dtensor : DTensor - The DTensor to promote. - spec : ShardTensorSpec - The ShardTensorSpec to use for the new ShardTensor. - - Returns - ------- - ShardTensor - A new ShardTensor wrapping the same local data. - """ - return ShardTensor.__new__( - ShardTensor, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=False, # autograd.Function.apply handles this - ) - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - grad_output: "ShardTensor", - ) -> tuple[DTensor, None]: - r"""Pass gradient through unchanged. - - Parameters - ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context (unused). - grad_output : ShardTensor - Gradient with respect to the ShardTensor output. - - Returns - ------- - Tuple[DTensor, None] - The gradient for the DTensor input, and ``None`` for the spec. - """ - return grad_output, None - - class ShardTensor(DTensor): r"""A distributed tensor class with support for uneven data sharding. @@ -549,18 +700,28 @@ def __new__( return ret def __repr__(self) -> str: - return f"ShardTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + return ( + "ShardTensor(" + f"local_tensor={repr(self._local_tensor)}, " + f"device_mesh={repr(self._spec.mesh)}, " + f"placements={repr(self._spec.placements)}" + ")" + ) + + def __str__(self) -> str: + # Avoid Tensor/DTensor string formatting paths that can re-enter dispatch. + return self.__repr__() + + def __format__(self, format_spec: str) -> str: + # Format as plain Python string to bypass tensor formatting internals. + return format(str(self), format_spec) @classmethod def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": r"""Convert a DTensor to a ShardTensor. - Assumes the DTensor is properly constructed. Since DTensor is locked - to sharding a tensor according to chunk format, the sharding sizes - can be inferred with no communication. - - If the DTensor is a non-leaf (has a ``grad_fn``), the autograd graph - is preserved via :class:`_PromoteDTensorToShardTensor`. + Differentiable when *dtensor* is non-leaf (has a ``grad_fn``). + Spec is inferred from the DTensor (chunk-based, no communication). Parameters ---------- @@ -572,144 +733,24 @@ def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": ShardTensor Equivalent ShardTensor with the same local tensor and inferred spec. """ - return cls._maybe_promote_dtensor(dtensor, ()) - - @staticmethod - def _maybe_promote_dtensor(dtensor, input_args): - r"""Promote a single DTensor back to ShardTensor if it matches input criteria. - - If ``dtensor`` is already a ShardTensor, it is returned as-is. Otherwise, - determines a ``ShardTensorSpec`` (reusing an input's spec when possible, - otherwise inferring one) and creates a new ShardTensor. - - When the DTensor is a non-leaf (has a ``grad_fn``), the promotion goes - through :class:`_PromoteDTensorToShardTensor` so that the autograd graph - is preserved. For leaf DTensors, direct construction is used since there - is no graph to preserve. - - Parameters - ---------- - dtensor : DTensor - The DTensor result to promote. - input_args : tuple - Original input arguments to search for matching ShardTensors. - - Returns - ------- - ShardTensor - Promoted ShardTensor (or the original if already a ShardTensor). - """ - if isinstance(dtensor, ShardTensor): - return dtensor - - # Determine the ShardTensorSpec — reuse an input's spec when the - # tensor_meta and placements match (avoids communication). - spec = None - for arg in input_args: - if ( - isinstance(arg, ShardTensor) - and dtensor._spec.tensor_meta == arg._spec.tensor_meta - and dtensor._spec.placements == arg._spec.placements - ): - spec = arg._spec - break - - if spec is None: - # Infer from DTensor (no communication for chunk-based sharding). - spec = _infer_shard_tensor_spec_from_local_chunks( - dtensor._local_tensor, - dtensor._spec.mesh, - dtensor._spec.placements, - sharding_shapes="chunk", - global_shape=dtensor.shape, - ) - - # Non-leaf DTensors carry a grad_fn from the operation that produced - # them. Creating a new ShardTensor via _make_wrapper_subclass would - # discard that grad_fn (producing a leaf). Go through the autograd - # function so that apply() connects the new ShardTensor back to the - # original graph. - if dtensor.grad_fn is not None: - return _PromoteDTensorToShardTensor.apply(dtensor, spec) - - # Leaf DTensors (parameters, buffers, detached tensors) can be - # constructed directly — there is no autograd graph to preserve. - return ShardTensor.__new__( - ShardTensor, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=dtensor.requires_grad, - ) - - @staticmethod - def _promote_dtensor_results(result, input_args): - r"""Promote DTensor(s) in a dispatch/function result back to ShardTensor. - - Handles four cases: - - 1. Single DTensor — promoted via :meth:`_maybe_promote_dtensor`. - 2. Mapping (e.g. dict) — each value is promoted if it is a DTensor. - 3. Iterable of results — each DTensor element is promoted individually. - 4. Anything else — returned as-is. - - Parameters - ---------- - result : object - The result returned by DTensor dispatch or ``__torch_function__``. - input_args : tuple - Original input arguments used for matching specs. - - Returns - ------- - object - The result with any DTensors promoted to ShardTensors. - """ - if isinstance(result, DTensor): - return ShardTensor._maybe_promote_dtensor(result, input_args) - - if isinstance(result, Mapping): - return type(result)( - { - k: ShardTensor._maybe_promote_dtensor(v, input_args) - if isinstance(v, DTensor) - else v - for k, v in result.items() - } - ) - - # Exclude str/bytes so we don't iterate over characters. - if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): - return type(result)( - ShardTensor._maybe_promote_dtensor(d, input_args) - if isinstance(d, DTensor) - else d - for d in result - ) - - return result + return _convert_dt_to_st(dtensor) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - with annotate(f"__torch_function___{func.__name__}"): - # Check for overrides: - if func in cls._function_registry and cls._enable_shard_patches: - res = cls._function_registry[func](func, types, args, kwargs) - return res - elif ( - str(func) in cls._named_function_registry and cls._enable_shard_patches - ): - res = cls._named_function_registry[str(func)](func, types, args, kwargs) - return res - # Fall back to the default behavior, but promote any DTensor - # results back to ShardTensor (matching dispatch behavior): - result = super().__torch_function__(func, types, args, kwargs) - return cls._promote_dtensor_results(result, args) + if _conversion_active(): + # When converting shard tensor to dtensor, or dtensor to shard tensor, + # we just skip this function entirely. + return super().__torch_function__(func, types, args, kwargs) + if func in cls._function_registry and cls._enable_shard_patches: + return cls._function_registry[func](func, types, args, kwargs) + if str(func) in cls._named_function_registry and cls._enable_shard_patches: + return cls._named_function_registry[str(func)](func, types, args, kwargs) + res = _torch_function_fallback_via_dtensor(func, args, kwargs) + return res @classmethod - @torch._disable_dynamo - @profile def __torch_dispatch__( cls, func: torch._ops.OpOverload, @@ -717,33 +758,14 @@ def __torch_dispatch__( args: tuple[object, ...] = (), kwargs: dict[str, object] | None = None, ) -> "ShardTensor" | Iterable["ShardTensor"] | object: - with annotate(f"__torch_dispatch___{func.__name__}"): - # Leverage DTensor Dispatch as much as possible, but, enable - # the ability to operate on this output in the future: - handler = cls._dispatch_registry.get(func) - if handler is None: - handler = cls._dispatch_registry_by_name.get(str(func)) - if handler is not None: - res = handler(*args, **kwargs) - return res - - # We assume that if we reach this point, the operator has not been - # intercepted by a wrapper or in the registry. So the DTensor - # default behavior is likely to be correct. - - # Convert ShardTensors to DTensors so DTensor's dispatcher - # receives the types it expects. - converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) - converted_kwargs = { - k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() - } - - dispatch_res = DTensor._op_dispatcher.dispatch( - func, converted_args, converted_kwargs - ) - - # Promote any DTensor results back to ShardTensor. - return cls._promote_dtensor_results(dispatch_res, args) + # Use a handler, if we have one: + handler = cls._dispatch_registry.get(func) + if handler is None: + handler = cls._dispatch_registry_by_name.get(str(func)) + if handler is not None: + return handler(*args, **kwargs) + # Otherwise, try the dtensor route: + return _dispatch_fallback_via_dtensor(func, args, kwargs) @staticmethod def from_local( @@ -965,6 +987,37 @@ def backward(self, *args, **kwargs): return self.to_local().backward(*args, **kwargs) +class FSDPOutputTensorAdapter(nn.Module): + """Wrap a module and convert ShardTensor outputs to torch.Tensor.""" + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + out = self.module(*args, **kwargs) + return out.to_local() if isinstance(out, ShardTensor) else out + + +def wrap_for_fsdp(module: nn.Module) -> nn.Module: + """Return a module wrapper that exposes tensor outputs for FSDP hooks.""" + return FSDPOutputTensorAdapter(module) + + +def distribute_over_domain_for_fsdp( + module: nn.Module, + device_mesh: DeviceMesh, + partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None, +) -> nn.Module: + """Distribute a module over a domain mesh and adapt outputs for FSDP.""" + distributed_module = distribute_module( + module, + device_mesh=device_mesh, + partition_fn=partition_fn, + ) + return wrap_for_fsdp(distributed_module) + + def scatter_tensor( tensor: torch.Tensor, global_src: int, diff --git a/physicsnemo/domain_parallel/shard_utils/__init__.py b/physicsnemo/domain_parallel/shard_utils/__init__.py index 16e2bb0cb2..bba5ceb7c5 100644 --- a/physicsnemo/domain_parallel/shard_utils/__init__.py +++ b/physicsnemo/domain_parallel/shard_utils/__init__.py @@ -36,7 +36,7 @@ def register_shard_wrappers(): from .mesh_ops import sharded_signed_distance_field_wrapper # Currently disabled until wrapt is removed - # from .natten_patches import na2d_wrapper + from .natten_patches import na2d_wrapper from .normalization_patches import group_norm_wrapper from .padding import generic_pad_nd_wrapper from .point_cloud_ops import radius_search_wrapper diff --git a/physicsnemo/domain_parallel/shard_utils/view_ops.py b/physicsnemo/domain_parallel/shard_utils/view_ops.py index 0b73a4e762..bb18081871 100644 --- a/physicsnemo/domain_parallel/shard_utils/view_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/view_ops.py @@ -608,7 +608,8 @@ def forward( Viewed ShardTensor. """ ctx.input_global_shape = tuple(tensor.shape) - return _sharded_view_forward(tensor, target_shape) + out = _sharded_view_forward(tensor, target_shape) + return out @staticmethod def backward( @@ -629,6 +630,7 @@ def backward( tuple[ShardTensor, None] Gradient for the input tensor, and ``None`` for ``target_shape``. """ + return ( _sharded_view_forward(grad_output, ctx.input_global_shape), None, From 15118af844d7d259510cec8dd84346a32833caa7 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Thu, 9 Apr 2026 09:18:03 -0700 Subject: [PATCH 2/4] Refactor unbind to implement shard tensor interface only; prevent torch api shifts from breaking for us. Also increase conv image size for better test stability --- physicsnemo/domain_parallel/__init__.py | 2 +- .../domain_parallel/custom_ops/__init__.py | 2 +- .../domain_parallel/custom_ops/_reductions.py | 110 +++-- .../domain_parallel/custom_ops/_tensor_ops.py | 249 ++++++---- physicsnemo/domain_parallel/shard_tensor.py | 444 ++++++++++++------ .../shard_utils/normalization_patches.py | 4 +- test/domain_parallel/ops/test_convolution.py | 12 +- test/domain_parallel/ops/test_unbind.py | 113 +++++ test/domain_parallel/ops/test_view_ops.py | 19 +- test/domain_parallel/ops/utils.py | 2 +- test/domain_parallel/test_grad_sharding.py | 2 +- test/domain_parallel/test_initialization.py | 93 +++- test/domain_parallel/test_reductions.py | 4 + 13 files changed, 773 insertions(+), 283 deletions(-) create mode 100644 test/domain_parallel/ops/test_unbind.py diff --git a/physicsnemo/domain_parallel/__init__.py b/physicsnemo/domain_parallel/__init__.py index af43281c2d..1da99e02f1 100644 --- a/physicsnemo/domain_parallel/__init__.py +++ b/physicsnemo/domain_parallel/__init__.py @@ -61,7 +61,7 @@ def register_custom_ops(): from .custom_ops import ( mean_wrapper, sum_wrapper, - unbind_rules, + unbind_wrapper, ) from .shard_utils import register_shard_wrappers diff --git a/physicsnemo/domain_parallel/custom_ops/__init__.py b/physicsnemo/domain_parallel/custom_ops/__init__.py index 2dc160d6de..bfd7e1c641 100644 --- a/physicsnemo/domain_parallel/custom_ops/__init__.py +++ b/physicsnemo/domain_parallel/custom_ops/__init__.py @@ -21,4 +21,4 @@ if ST_AVAILABLE: from ._reductions import mean_wrapper, sum_wrapper - from ._tensor_ops import unbind_rules + from ._tensor_ops import unbind_wrapper diff --git a/physicsnemo/domain_parallel/custom_ops/_reductions.py b/physicsnemo/domain_parallel/custom_ops/_reductions.py index b673e6bec4..7c509c0eb3 100644 --- a/physicsnemo/domain_parallel/custom_ops/_reductions.py +++ b/physicsnemo/domain_parallel/custom_ops/_reductions.py @@ -44,12 +44,17 @@ ) import torch +from torch.distributed.tensor._dtensor_spec import TensorMeta from torch.distributed.tensor.placement_types import ( Partial, Shard, ) # noqa: E402 +from physicsnemo.domain_parallel._shard_tensor_spec import ( + ShardTensorSpec, + _stride_from_contiguous_shape_C_style, +) from physicsnemo.domain_parallel.shard_tensor import ShardTensor aten = torch.ops.aten @@ -248,6 +253,62 @@ def compute_result_sharding_shapes( return result_sharding_shapes +def build_reduction_result( + local_result: torch.Tensor, + input_tensor: ShardTensor, + placements: list[Partial | Shard], + sharding_shapes: dict[int, list[torch.Size]], +) -> ShardTensor: + r"""Construct a ShardTensor result from a local reduction output. + + Builds the ``ShardTensorSpec`` directly from the already-computed placements + and sharding shapes, avoiding the overhead and autograd side-effects of + ``ShardTensor.from_local``. + + Parameters + ---------- + local_result : torch.Tensor + The locally-computed reduction result. + input_tensor : ShardTensor + The original input ShardTensor (used for device mesh). + placements : List[Union[Partial, Shard]] + Result placements from :func:`compute_result_placements`. + sharding_shapes : Dict[int, List[torch.Size]] + Result sharding shapes from :func:`compute_result_sharding_shapes`. + + Returns + ------- + ShardTensor + Wrapped result with correct sharding metadata. + """ + global_shape = list(local_result.shape) + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + tensor_dim = placement.dim + global_shape[tensor_dim] = sum( + s[tensor_dim] for s in sharding_shapes[mesh_dim] + ) + + stride = _stride_from_contiguous_shape_C_style(global_shape) + spec = ShardTensorSpec( + mesh=input_tensor.device_mesh, + placements=tuple(placements), + tensor_meta=TensorMeta( + shape=tuple(global_shape), + stride=stride, + dtype=local_result.dtype, + ), + _local_shape=local_result.shape, + _sharding_shapes={dim: tuple(s) for dim, s in sharding_shapes.items()}, + ) + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_result, + spec=spec, + requires_grad=input_tensor.requires_grad, + ) + + def create_sharded_grad_input( local_grad_input: torch.Tensor, original_spec: Any ) -> ShardTensor: @@ -265,11 +326,15 @@ def create_sharded_grad_input( ShardTensor A distributed tensor with the same sharding as the original input. """ - return ShardTensor.from_local( - local_grad_input, - device_mesh=original_spec.mesh, - placements=original_spec.placements, - sharding_shapes=original_spec.sharding_shapes(), + # In custom autograd backward, return the input gradient directly as a + # ShardTensor value. Avoid ``from_local`` here (which routes through a + # separate autograd Function) so the gradient is attached unambiguously to + # the original ShardTensor input. + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_grad_input, + spec=original_spec, + requires_grad=False, ) @@ -361,24 +426,14 @@ def forward( """ dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) - # Get local tensor - local_tensor = tensor._local_tensor - # Perform local sum - local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) + local_result = aten.sum( + tensor._local_tensor, dim=dim, keepdim=keepdim, dtype=dtype + ) - # Compute placements for the result placements = compute_result_placements(tensor, dim, "sum") - output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - - # Create result ShardTensor - result = ShardTensor.from_local( - local_result, - tensor.device_mesh, - placements, - sharding_shapes=output_sharding_shapes, - ) + sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - return result + return build_reduction_result(local_result, tensor, placements, sharding_shapes) @staticmethod def backward( @@ -495,23 +550,14 @@ def forward( for d in reduction_dims: weight *= local_shape[d] / global_shape[d] - # Perform local mean + # Perform local mean and apply weighting for uneven shards local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) - # Apply weighting local_result = local_result * weight placements = compute_result_placements(tensor, dim, "sum") - output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - - # Create result ShardTensor - result = ShardTensor.from_local( - local_result, - tensor.device_mesh, - placements, - sharding_shapes=output_sharding_shapes, - ) + sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - return result + return build_reduction_result(local_result, tensor, placements, sharding_shapes) @staticmethod def backward( diff --git a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py index d83a92c757..3f0580b921 100644 --- a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py +++ b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py @@ -16,130 +16,215 @@ r"""Custom tensor operations for ShardTensor dispatch. -This module provides propagation rules for tensor operations that need -special handling when applied to ``ShardTensor`` objects. These rules -are registered with PyTorch's DTensor operation dispatch system. +This module provides dispatch and function handlers for tensor operations +that need special handling when applied to ``ShardTensor`` objects. Handlers +are registered with both ``__torch_dispatch__`` (ATen level) and +``__torch_function__`` (Python level) on :class:`ShardTensor`. """ +from __future__ import annotations + +from typing import Any, Callable + import torch -from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import ( - OpSchema, - OutputSharding, - RuntimeSchemaInfo, -) +from torch.distributed.tensor._dtensor_spec import TensorMeta from torch.distributed.tensor.placement_types import ( Partial, Replicate, Shard, ) -from physicsnemo.core.version_check import check_version_spec +from physicsnemo.domain_parallel import ShardTensor from physicsnemo.domain_parallel._shard_tensor_spec import ( + ShardTensorSpec, _stride_from_contiguous_shape_C_style, ) -if check_version_spec("torch", "2.10.0a"): - from torch.distributed.tensor._ops.registration import ( - register_prop_rule, - ) -else: - from torch.distributed.tensor._ops.utils import ( - register_prop_rule, - ) - aten = torch.ops.aten -@register_prop_rule(aten.unbind.int, schema_info=RuntimeSchemaInfo(1)) -def unbind_rules(op_schema: OpSchema) -> OutputSharding: - r"""Propagation rule for ``torch.unbind`` on ShardTensor. +def _unbind_output_metadata( + input_spec: ShardTensorSpec, dim: int +) -> tuple[int, list, dict[int, list[torch.Size]]]: + r"""Compute the normalized dim, output placements, and sharding shapes for unbind. - Computes the output sharding specification when unbinding a sharded tensor - along a specified dimension. The unbind operation removes one dimension - from the tensor and returns a tuple of tensors. + Validates that the unbind dimension is not sharded and does not use + ``Partial`` placement, then returns the metadata needed to construct + the output ``ShardTensor`` objects. Parameters ---------- - op_schema : OpSchema - The operation schema containing input specifications and arguments. - Expected to contain: - - - ``args_schema[0]``: Input tensor specification (DTensorSpec) - - ``args_schema[1]``: Dimension to unbind along (int), defaults to 0 + input_spec : ShardTensorSpec + Specification of the input sharded tensor. + dim : int + Dimension along which to unbind (may be negative). Returns ------- - OutputSharding - Output sharding specification containing a list of DTensorSpec objects, - one for each tensor in the unbind result. + tuple[int, list, dict[int, list[torch.Size]]] + - Normalized (non-negative) ``dim``. + - Output placements (shard dims above ``dim`` shifted down by 1). + - Output sharding shapes with the unbind dimension removed. Raises ------ - Exception + RuntimeError If attempting to unbind along a sharded dimension (not yet implemented). - If attempting to unbind with Partial placement (not yet supported). + If attempting to unbind with ``Partial`` placement (not yet supported). + """ + ndim = len(input_spec.shape) + if dim < 0: + dim = dim % ndim + + # if the unbind dimension is along a dimension that is sharded, we have to handle that. + # If it's along an unsharded dimension, there is nearly nothing to do. + input_placements = input_spec.placements + shards = [s for s in input_placements if isinstance(s, Shard)] + + if dim in [i.dim for i in shards]: + raise RuntimeError( + "No implementation for unbinding along sharding axis yet." + ) + + new_placements: list = [] + for p in input_placements: + if p.is_replicate(): + new_placements.append(p) + elif p.is_shard(): + if p.dim > dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + elif p.is_partial(): + raise RuntimeError("Partial placement not supported yet for unbind") + + out_sharding_shapes: dict[int, list[torch.Size]] = { + mesh_dim: [ + torch.Size(list(cs[:dim]) + list(cs[dim + 1 :])) + for cs in shard_shapes + ] + for mesh_dim, shard_shapes in input_spec.sharding_shapes().items() + } + + return dim, new_placements, out_sharding_shapes + + +def _unbind_dispatch( + tensor: ShardTensor, dim: int = 0 +) -> tuple[ShardTensor, ...]: + r"""Dispatch handler for ``aten.unbind.int`` on :class:`ShardTensor`. + + Called at the ``__torch_dispatch__`` level (below autograd). Operates + directly on the local tensor and constructs output ``ShardTensor`` + objects with the correct metadata; the autograd engine above handles + gradient tracking. + + Parameters + ---------- + tensor : ShardTensor + Input sharded tensor. + dim : int, default=0 + Dimension along which to unbind. + + Returns + ------- + tuple[ShardTensor, ...] + Tuple of ShardTensors, one per slice along ``dim``. Note ---- - This rule is needed for operations like attention in Stormcast and other + This handler is needed for operations like attention in Stormcast and other models that unbind tensors along non-sharded dimensions. """ + input_spec = tensor._spec + dim, new_placements, out_sharding_shapes = _unbind_output_metadata( + input_spec, dim + ) - # We need to get the dimension of the slice. 0 is default. + # We are reducing tensor rank and returning one tensor per slice + original_shape = list(input_spec.shape) + original_shape.pop(dim) - args_schema = op_schema.args_schema + output_spec = ShardTensorSpec( + mesh=input_spec.mesh, + placements=tuple(new_placements), + tensor_meta=TensorMeta( + torch.Size(tuple(original_shape)), + stride=_stride_from_contiguous_shape_C_style(original_shape), + dtype=input_spec.tensor_meta.dtype, + ), + _sharding_shapes={ + k: tuple(v) for k, v in out_sharding_shapes.items() + }, + ) - if len(args_schema) > 1: - dim = args_schema[-1] - else: - dim = 0 + local_results = aten.unbind.int(tensor._local_tensor, dim) - # if the chunking dimension is along a dimension that is sharded, we have to handle that. - # If it's along an unsharded dimension, there is nearly nothing to do. + return tuple( + ShardTensor( + local_result, + output_spec, + requires_grad=False, # Adjusted after the dispatcher + ) + for local_result in local_results + ) - input_spec = args_schema[0] - input_placements = input_spec.placements +def unbind_wrapper( + func: Callable, + types: tuple[Any, ...], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[ShardTensor, ...]: + r"""Functional-level wrapper for ``torch.unbind`` on ShardTensor. - shards = [s for s in input_placements if isinstance(s, Shard)] + This is a ``__torch_function__``-level intercept (above autograd). It + uses ``to_local()`` / ``from_local()`` so that the autograd graph is + preserved through the unbind operation. - if dim in [i.dim for i in shards]: - raise Exception("No implementation for unbinding along sharding axis yet.") + Parameters + ---------- + func : Callable + The original function being wrapped (``torch.unbind`` or + ``torch.Tensor.unbind``). + types : tuple[Any, ...] + Types of the input arguments (unused). + args : tuple[Any, ...] + Positional arguments. Expected ``(input,)`` or ``(input, dim)``. + kwargs : dict[str, Any] + Keyword arguments (may contain ``dim``). - else: - # We are reducing tensor rank and returning one sharding per tensor: - original_shape = list(input_spec.shape) - unbind_dim_shape = original_shape.pop(dim) + Returns + ------- + tuple[ShardTensor, ...] + Tuple of ShardTensors, one per slice along the unbind dimension. + """ + input_tensor: ShardTensor = args[0] + dim: int = args[1] if len(args) > 1 else kwargs.get("dim", 0) - output_stride = _stride_from_contiguous_shape_C_style(original_shape) + input_spec = input_tensor._spec + dim, new_placements, out_sharding_shapes = _unbind_output_metadata( + input_spec, dim + ) - # Need to create a new global meta: - new_meta = TensorMeta( - torch.Size(tuple(original_shape)), - stride=output_stride, - dtype=input_spec.tensor_meta.dtype, + # to_local() / from_local() preserve the autograd graph + local_input = input_tensor.to_local() + local_results = torch.unbind(local_input, dim) + + return tuple( + ShardTensor.from_local( + local_result, + input_spec.mesh, + new_placements, + out_sharding_shapes, ) + for local_result in local_results + ) - # The placements get adjusted too - new_placements = [] - for p in input_spec.placements: - if isinstance(p, Replicate): - new_placements.append(p) - elif isinstance(p, Shard): - if p.dim > dim: - new_placements.append(Shard(p.dim - 1)) - else: - new_placements.append(p) - elif isinstance(p, Partial): - raise Exception("Partial placement not supported yet for unbind") - - output_spec_list = [ - DTensorSpec( - mesh=input_spec.mesh, - placements=tuple(new_placements), - tensor_meta=new_meta, - ) - for _ in range(unbind_dim_shape) - ] - return OutputSharding(output_spec_list) + +# Python-level function handlers (__torch_function__). +ShardTensor.register_function_handler(torch.unbind, unbind_wrapper) +ShardTensor.register_function_handler(torch.Tensor.unbind, unbind_wrapper) + +# ATen-level dispatch handler (__torch_dispatch__). +ShardTensor.register_dispatch_handler(aten.unbind.int, _unbind_dispatch) diff --git a/physicsnemo/domain_parallel/shard_tensor.py b/physicsnemo/domain_parallel/shard_tensor.py index 74d7a84e50..b1bea3d1cc 100644 --- a/physicsnemo/domain_parallel/shard_tensor.py +++ b/physicsnemo/domain_parallel/shard_tensor.py @@ -62,15 +62,20 @@ def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: Creates a DTensor sharing the same ``_local_tensor`` and ``_spec``. Use for dispatch or inside backward when building a DTensor gradient. """ - dtensor = torch.Tensor._make_wrapper_subclass( - DTensor, - st._spec.tensor_meta.shape, - strides=st._spec.tensor_meta.stride, - dtype=st.dtype, - device=st.device, - layout=st.layout, - requires_grad=st.requires_grad, - ) + if hasattr(torch.Tensor, "_dtensor__new__"): + dtensor = torch.Tensor._dtensor__new__( + DTensor, st._local_tensor, st._spec, requires_grad=st.requires_grad + ) + else: + dtensor = torch.Tensor._make_wrapper_subclass( + DTensor, + st._spec.tensor_meta.shape, + strides=st._spec.tensor_meta.stride, + dtype=st.dtype, + device=st.device, + layout=st.layout, + requires_grad=st.requires_grad, + ) dtensor._local_tensor = st._local_tensor dtensor._spec = st._spec return dtensor @@ -137,12 +142,10 @@ def forward(ctx, st: "ShardTensor") -> DTensor: def backward(ctx, grad_output: DTensor): return (_dtensor_to_shard_tensor(grad_output, ctx.shard_tensor_spec),) - # ============================================================================ # Layer 3 -- Smart single-tensor converters (auto-diff when grad_fn present) # ============================================================================ - def _resolve_spec_for_dtensor( dtensor: DTensor, input_args: tuple = () ) -> ShardTensorSpec: @@ -169,25 +172,13 @@ def _resolve_spec_for_dtensor( # This is a thread-safe reentry guard. -# Goal is to prevent recursion into the fall back conversion paths. -# Here's the scenario we're preventing: -# 1. A ShardTensor needs to use the DTensor path in a torch_function level call. -# This will enter torch_function for ShardTensor and trigger the fallback path. -# 2. Because that builds the autograd graph, the conversion from ShardTensor to DTensor -# must be differentiable. -# 3. The conversion path itself will call _ShardTensorToDTensor, which will enter -# torch_function for ShardTensor. -# 4. There is no overload for converting ShardTensor to DTensor, so it will -# enter the fallback conversion path -# 5. Infinite recursion / profit. +# Goal is to prevent recursion into the fallback conversion paths. _conversion_guard = threading.local() - def _conversion_active() -> bool: r"""Return whether ShardTensor<->DTensor conversion is currently active.""" return getattr(_conversion_guard, "depth", 0) > 0 - @contextmanager def _conversion_scope(): r"""Re-entrant conversion guard for cast-down/cast-up paths.""" @@ -202,46 +193,26 @@ def _conversion_scope(): _conversion_guard.depth = previous_depth -def _convert_st_to_dt(st: "ShardTensor") -> DTensor: - r"""ShardTensor -> DTensor; differentiable when *st* is non-leaf.""" - with _conversion_scope(): - if st.requires_grad and st.grad_fn is not None: - return _ShardTensorToDTensor.apply(st) - return _shard_tensor_to_dtensor(st) - - -def _convert_dt_to_st(dtensor: DTensor, input_args: tuple = ()) -> "ShardTensor": - r"""DTensor -> ShardTensor; differentiable when *dtensor* is non-leaf. - - Resolves spec, then uses Layer 2 or Layer 1 depending on whether the - DTensor carries a ``grad_fn``. - """ - if isinstance(dtensor, ShardTensor): - return dtensor - with _conversion_scope(): - spec = _resolve_spec_for_dtensor(dtensor, input_args) - if dtensor.grad_fn is not None: - return _DTensorToShardTensor.apply(dtensor, spec) - res = _dtensor_to_shard_tensor(dtensor, spec) - return res - - def _dispatch_fallback_via_dtensor( func: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object] | None = None, ) -> object: - r"""Execute an ATen op through DTensor fallback and promote results back.""" + r"""Execute an ATen op through DTensor fallback using PURE data conversion. + + Native Autograd wraps this hook, so we must NOT build an internal graph + using .apply(). We just do the math and let PyTorch track the outer graph. + """ with _conversion_scope(): - converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) + converted_args = tuple(_convert_args_to_dtensor(arg, use_autograd=False) for arg in args) converted_kwargs = { - k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() + k: _convert_args_to_dtensor(v, use_autograd=False) for k, v in (kwargs or {}).items() } - dispatch_res = DTensor._op_dispatcher.dispatch( - func, converted_args, converted_kwargs - ) + + dispatch_res = func(*converted_args, **(converted_kwargs or {})) + with _conversion_scope(): - return _convert_results_to_shard_tensor(dispatch_res, args) + return _convert_results_to_shard_tensor(dispatch_res, args, use_autograd=False) def _torch_function_fallback_via_dtensor( @@ -249,79 +220,88 @@ def _torch_function_fallback_via_dtensor( args: tuple[object, ...], kwargs: dict[str, object] | None = None, ) -> object: - r"""Execute a ``__torch_function__`` fallback through DTensor safely. + r"""Execute a __torch_function__ fallback through DTensor safely. - The fallback call itself is wrapped in ``DisableTorchFunctionSubclass`` to - avoid re-entering tensor-subclass ``__torch_function__`` while still - allowing autograd to record DTensor ops. + Because this executes at the Python API level (above Autograd), we MUST + use autograd functions (.apply) to bridge the tracking manually. """ - with _conversion_scope(): - # Here, we take args and kwargs and push all ShardTensors to DTensors. - # Other args are left as is. - converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) + converted_args = tuple(_convert_args_to_dtensor(arg, use_autograd=True) for arg in args) converted_kwargs = { - k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() + k: _convert_args_to_dtensor(v, use_autograd=True) for k, v in (kwargs or {}).items() } + with torch._C.DisableTorchFunctionSubclass(): result = func(*converted_args, **converted_kwargs) + with _conversion_scope(): - # The output results promote any DTensor results back to ShardTensors - converted_result = _convert_results_to_shard_tensor(result, args) - return converted_result + return _convert_results_to_shard_tensor(result, args, use_autograd=True) # ============================================================================ # Layer 4 -- Recurse utilities (walk args / kwargs / results) # ============================================================================ - -def _convert_args_to_dtensor(arg: object) -> object: - r"""Recursively replace ShardTensors with DTensors in a single arg. - - Walks mappings, tuples, and lists. Each ShardTensor is converted via - :func:`_convert_st_to_dt`. +def _convert_args_to_dtensor(arg: object, use_autograd: bool = False) -> object: + r"""Recursively replace ShardTensors with DTensors. + + If use_autograd is True, uses Layer 2 to preserve the graph connection. """ match arg: case ShardTensor(): - return _convert_st_to_dt(arg) + if use_autograd and arg.requires_grad and torch.is_grad_enabled(): + return _ShardTensorToDTensor.apply(arg) + return _shard_tensor_to_dtensor(arg) case DTensor(): - # DTensor can be iterable; exit early deliberatly + # DTensor can be iterable; exit early deliberately return arg case Mapping(): - return type(arg)({k: _convert_args_to_dtensor(v) for k, v in arg.items()}) + return type(arg)({k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()}) case tuple(): - return tuple(_convert_args_to_dtensor(a) for a in arg) + return tuple(_convert_args_to_dtensor(a, use_autograd) for a in arg) case list(): - return [_convert_args_to_dtensor(a) for a in arg] + return [_convert_args_to_dtensor(a, use_autograd) for a in arg] case _: return arg -def _convert_results_to_shard_tensor(result: object, input_args: tuple) -> object: +def _convert_results_to_shard_tensor( + result: object, input_args: tuple, use_autograd: bool = False +) -> object: r"""Recursively replace DTensors with ShardTensors in an op result. - - Walks single tensor, mappings, and iterables (excluding str/bytes). - Each DTensor is converted via :func:`_convert_dt_to_st`. + + If use_autograd is True, uses Layer 2 to preserve the graph connection. + Handles None returns gracefully for inplace ATen operations. """ + if result is None: + return None + if isinstance(result, DTensor): - res = _convert_dt_to_st(result, input_args) - return res + spec = _resolve_spec_for_dtensor(result, input_args) + + # If autograd graph connection is requested AND the DTensor actually + # requires tracking (it has a grad_fn or requires_grad is active) + if use_autograd and torch.is_grad_enabled() and (result.grad_fn is not None or result.requires_grad): + return _DTensorToShardTensor.apply(result, spec) + + return _dtensor_to_shard_tensor(result, spec) + if isinstance(result, Mapping): return type(result)( { - k: _convert_dt_to_st(v, input_args) if isinstance(v, DTensor) else v + k: _convert_results_to_shard_tensor(v, input_args, use_autograd) for k, v in result.items() } ) + if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): return type(result)( - _convert_dt_to_st(d, input_args) if isinstance(d, DTensor) else d + _convert_results_to_shard_tensor(d, input_args, use_autograd) for d in result ) + return result - class _ToTorchTensor(torch.autograd.Function): r"""Autograd function to convert a ShardTensor to a regular PyTorch tensor. @@ -356,11 +336,17 @@ def forward( ctx.grad_placements = grad_placements local_tensor = input._local_tensor - # JUST LIKE DTENSOR: - # We need to return a fresh Tensor object there as autograd metadata - # will be inplaced into it. So we don't want to pollute the Tensor - # object stored in the _local_tensor of this ShardTensor. - return local_tensor.view_as(local_tensor) + # # JUST LIKE DTENSOR: + # # We need to return a fresh Tensor object there as autograd metadata + # # will be inplaced into it. So we don't want to pollute the Tensor + # # object stored in the _local_tensor of this ShardTensor. + # return local_tensor.view_as(local_tensor) + + # Force the local view to inherit the requires_grad state of the ShardTensor + local_tensor = input._local_tensor + res = local_tensor.view_as(local_tensor) + res.requires_grad_(input.requires_grad) + return res @staticmethod def backward( @@ -514,7 +500,7 @@ def backward( return grad_output.to_local(), None, None, None -class ShardTensor(DTensor): +class ShardTensor(torch.Tensor): r"""A distributed tensor class with support for uneven data sharding. Similar to PyTorch's native ``DTensor`` but with more flexibility for @@ -639,6 +625,71 @@ def register_named_function_handler(cls, func_name: str, handler: Callable) -> N """ cls._named_function_registry[func_name] = handler + # @staticmethod + # def __new__( + # cls, + # local_tensor: torch.Tensor, + # spec: ShardTensorSpec, + # *, + # requires_grad: bool, + # ) -> "ShardTensor": + # r"""Construct a new ShardTensor from a local tensor and specification. + + # Note that unlike ``DTensor``, ShardTensor will automatically collect + # the shard size information from all participating devices. This enables + # uneven and dynamic sharding. + + # Parameters + # ---------- + # local_tensor : torch.Tensor + # Local tensor to use as the data. + # spec : ShardTensorSpec + # ShardTensorSpec defining the sharding scheme. + # requires_grad : bool + # Whether the tensor requires gradients. + + # Returns + # ------- + # ShardTensor + # A new ShardTensor instance. + + # Note + # ---- + # This implementation is heavily derived from ``torch.distributed.tensor.DTensor``. + # """ + # if local_tensor.requires_grad and not requires_grad: + # warn( + # "To construct a new ShardTensor from torch.Tensor, " + # "it's recommended to use local_tensor.detach() and " + # "make requires_grad consistent." + # ) + + # if spec.tensor_meta is None: + # raise ValueError("TensorMeta should not be None!") + + # ret = torch.Tensor._make_wrapper_subclass( + # cls, + # spec.tensor_meta.shape, + # strides=spec.tensor_meta.stride, + # dtype=local_tensor.dtype, + # device=local_tensor.device, + # layout=local_tensor.layout, + # requires_grad=False, + # ) + + # ret._spec = spec + # ret._local_tensor = local_tensor + + # # Set requires_grad AFTER _spec/_local_tensor are assigned, using + # # the C-level setter directly (bypassing our Python property + # # override) so the autograd engine sees the correct flag. + # if requires_grad: + # torch.Tensor.requires_grad.__set__(ret, True) + + # cls._enable_shard_patches = True + + # return ret + @staticmethod def __new__( cls, @@ -647,41 +698,6 @@ def __new__( *, requires_grad: bool, ) -> "ShardTensor": - r"""Construct a new ShardTensor from a local tensor and specification. - - Note that unlike ``DTensor``, ShardTensor will automatically collect - the shard size information from all participating devices. This enables - uneven and dynamic sharding. - - Parameters - ---------- - local_tensor : torch.Tensor - Local tensor to use as the data. - spec : ShardTensorSpec - ShardTensorSpec defining the sharding scheme. - requires_grad : bool - Whether the tensor requires gradients. - - Returns - ------- - ShardTensor - A new ShardTensor instance. - - Note - ---- - This implementation is heavily derived from ``torch.distributed.tensor.DTensor``. - """ - if local_tensor.requires_grad and not requires_grad: - warn( - "To construct a new ShardTensor from torch.Tensor, " - "it's recommended to use local_tensor.detach() and " - "make requires_grad consistent." - ) - - if spec.tensor_meta is None: - raise ValueError("TensorMeta should not be None!") - - # Check the sharding information is known: ret = torch.Tensor._make_wrapper_subclass( cls, spec.tensor_meta.shape, @@ -689,14 +705,20 @@ def __new__( dtype=local_tensor.dtype, device=local_tensor.device, layout=local_tensor.layout, - requires_grad=requires_grad, + requires_grad=False, ) ret._spec = spec ret._local_tensor = local_tensor - cls._enable_shard_patches = True + # Set requires_grad AFTER _spec/_local_tensor are assigned, using + # the C-level setter directly (bypassing __torch_function__ which + # would convert to DTensor and set on a temporary). + if requires_grad: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(ret, True) + cls._enable_shard_patches = True return ret def __repr__(self) -> str: @@ -716,6 +738,109 @@ def __format__(self, format_spec: str) -> str: # Format as plain Python string to bypass tensor formatting internals. return format(str(self), format_spec) + @property + def device_mesh(self) -> DeviceMesh: + return self._spec.mesh + + @property + def placements(self) -> tuple[Placement, ...]: + return self._spec.placements + + def __tensor_flatten__(self): + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + spec, requires_grad = flatten_spec + local_tensor = inner_tensors["_local_tensor"] + unflatten_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = ShardTensorSpec( + mesh=spec.mesh, + placements=spec.placements, + tensor_meta=unflatten_meta, + _local_shape=local_tensor.shape, + _sharding_shapes=spec._sharding_shapes, + ) + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_tensor.requires_grad_(requires_grad), + spec=unflatten_spec, + requires_grad=requires_grad, + ) + + # -- Autograd property overrides ------------------------------------------- + # The C-level requires_grad is authoritative for autograd engine + # decisions; we read it first and fall back to _local_tensor for the + # case where _make_wrapper_subclass didn't propagate it correctly. + # For grad, the autograd engine accumulates at the C level, so we + # check there first then fall back to _local_tensor.grad. + + @property # type: ignore[override] + def requires_grad(self) -> bool: # type: ignore[override] + with torch._C.DisableTorchFunctionSubclass(): + if torch.Tensor.requires_grad.__get__(self): + return True + return self._local_tensor.requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(self, value) + self._local_tensor.requires_grad = value + + def requires_grad_(self, requires_grad: bool = True) -> "ShardTensor": + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(self, requires_grad) + self._local_tensor.requires_grad_(requires_grad) + return self + + @property # type: ignore[override] + def is_leaf(self) -> bool: # type: ignore[override] + with torch._C.DisableTorchFunctionSubclass(): + return torch.Tensor.is_leaf.__get__(self) + + @property # type: ignore[override] + def grad(self) -> "ShardTensor | None": # type: ignore[override] + with torch._C.DisableTorchFunctionSubclass(): + c_grad = torch.Tensor.grad.__get__(self) + if c_grad is not None: + if isinstance(c_grad, ShardTensor): + return c_grad + return ShardTensor.__new__( + ShardTensor, + local_tensor=c_grad._local_tensor if isinstance(c_grad, DTensor) else c_grad, + spec=self._spec, + requires_grad=False, + ) + local_grad = self._local_tensor.grad + if local_grad is None: + return None + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_grad, + spec=self._spec, + requires_grad=False, + ) + + @grad.setter + def grad(self, value: "ShardTensor | torch.Tensor | None") -> None: + if value is None: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, None) + self._local_tensor.grad = None + elif isinstance(value, ShardTensor): + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, value) + self._local_tensor.grad = value._local_tensor + else: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, value) + self._local_tensor.grad = value + @classmethod def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": r"""Convert a DTensor to a ShardTensor. @@ -733,7 +858,12 @@ def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": ShardTensor Equivalent ShardTensor with the same local tensor and inferred spec. """ - return _convert_dt_to_st(dtensor) + if isinstance(dtensor, ShardTensor): + return dtensor + spec = _resolve_spec_for_dtensor(dtensor) + if dtensor.grad_fn is not None: + return _DTensorToShardTensor.apply(dtensor, spec) + return _dtensor_to_shard_tensor(dtensor, spec) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -741,8 +871,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} if _conversion_active(): # When converting shard tensor to dtensor, or dtensor to shard tensor, - # we just skip this function entirely. - return super().__torch_function__(func, types, args, kwargs) + # we just run the function without ShardTensor dispatch. + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) if func in cls._function_registry and cls._enable_shard_patches: return cls._function_registry[func](func, types, args, kwargs) if str(func) in cls._named_function_registry and cls._enable_shard_patches: @@ -984,6 +1115,9 @@ def backward(self, *args, **kwargs): if needs_redistribute: self = self.redistribute(placements=new_placements) + if self.grad_fn is not None: + return torch.Tensor.backward(self, *args, **kwargs) + return self.to_local().backward(*args, **kwargs) @@ -1097,12 +1231,15 @@ def scatter_tensor( # scatter along Shard dimensions. BUT, the focus is on performance of full applications # and this is a once-per-iteration cost. - # Broadcast the tensor to all ranks + # Broadcast the tensor to all ranks. + # scatter_tensor is an input-boundary utility; keep internal collectives/layout + # transforms out of autograd and construct the requested leaf explicitly. if tensor is None and not is_src: # Tensor is allowed to be none if not on the root rank tensor = torch.empty(local_meta.shape, dtype=local_meta.dtype, device=dm.device) - dist.broadcast(tensor, src=global_src, group=mesh_group) + with torch.no_grad(): + dist.broadcast(tensor, src=global_src, group=mesh_group) # Create a fully-replicated spec: spec = ShardTensorSpec( @@ -1112,18 +1249,31 @@ def scatter_tensor( _sharding_shapes={}, ) - # Make a "fully-replicated" tensor on all ranks: - st = ShardTensor.__new__( - ShardTensor, - local_tensor=tensor, - spec=spec, - requires_grad=requires_grad, - ) + with torch.no_grad(): + # Build a replicated ShardTensor and redistribute to the requested + # placements without recording autograd history. + st = ShardTensor.__new__( + ShardTensor, + local_tensor=tensor, + spec=spec, + requires_grad=False, + ) + st = st.redistribute(mesh, placements, async_op=False) - # Redistribute the tensor to the desired placements: - st = st.redistribute(mesh, placements, async_op=False) - # This is an unoptimal step but is functional: if requires_grad: - st = st.detach() - st.requires_grad = True + + # 1. Ensure the local data is a clean leaf + local_leaf = st._local_tensor.detach().requires_grad_(True) + + # 2. Create the ShardTensor wrapper + st = ShardTensor.__new__( + ShardTensor, + local_tensor=local_leaf, + spec=st._spec, + requires_grad=True, + ) + + # 3. CRITICAL: Force the wrapper itself to be a leaf in the autograd graph + st = st.detach().requires_grad_(True) + return st diff --git a/physicsnemo/domain_parallel/shard_utils/normalization_patches.py b/physicsnemo/domain_parallel/shard_utils/normalization_patches.py index 4685a9a49f..5593a3d1ad 100644 --- a/physicsnemo/domain_parallel/shard_utils/normalization_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/normalization_patches.py @@ -282,13 +282,13 @@ def backward( grad_weight = None grad_bias = None - if weight is not None and weight.requires_grad: + if weight is not None and ctx.needs_input_grad[3]: # grad_weight_c = sum_{n, spatial} grad_output * y (per-channel) y_c = y.view(N, C, HxW_local) grad_out_c = local_grad_output.view(N, C, HxW_local) grad_weight = (grad_out_c * y_c).sum(dim=(0, 2)) # (C,) - if bias is not None and bias.requires_grad: + if bias is not None and ctx.needs_input_grad[4]: grad_out_c = local_grad_output.view(N, C, HxW_local) grad_bias = grad_out_c.sum(dim=(0, 2)) # (C,) diff --git a/test/domain_parallel/ops/test_convolution.py b/test/domain_parallel/ops/test_convolution.py index a20ea6efe6..7b7af223c7 100644 --- a/test/domain_parallel/ops/test_convolution.py +++ b/test/domain_parallel/ops/test_convolution.py @@ -167,7 +167,7 @@ def test_conv_transpose_1d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -198,7 +198,7 @@ def test_conv2d_1dmesh( dm = DistributedManager() - image = generate_image_like_data(2, C_in, (H, H)).to(dm.device) + image = generate_image_like_data(2, C_in, ( H, H)).to(dm.device) placements = (Shard(2),) @@ -265,7 +265,7 @@ def test_conv_transpose_2d_1dmesh( 2, C_in, ( - H, + 2 * H, H, ), device=dm.device, @@ -293,7 +293,7 @@ def test_conv_transpose_2d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -405,8 +405,8 @@ def test_conv_transpose_2d_2dmesh( 2, C_in, ( - H, - H, + 2 * H, + 2 * H, ), device=dm.device, ) diff --git a/test/domain_parallel/ops/test_unbind.py b/test/domain_parallel/ops/test_unbind.py new file mode 100644 index 0000000000..b658806575 --- /dev/null +++ b/test/domain_parallel/ops/test_unbind.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test unbind operations on ShardTensor. We use a 3D tensor sharded along +dim 2 and test unbinding along non-sharded dimensions. Both forward +correctness and backward gradient flow are verified. +""" + +import pytest +import torch +from torch.distributed.tensor.placement_types import Shard + +from physicsnemo.distributed import DistributedManager +from physicsnemo.domain_parallel import scatter_tensor + +from .utils import numerical_shard_tensor_check + + +class UnbindSelectWrapper(torch.nn.Module): + """ + Wrapper that unbinds a tensor and returns a single element from the + result tuple. This allows reuse of ``numerical_shard_tensor_check`` + which expects a single tensor output. + """ + + def __init__(self, dim: int, index: int): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, tensor: torch.Tensor): + pieces = torch.unbind(tensor, self.dim) + return pieces[self.index] + + +@pytest.mark.multigpu_static +@pytest.mark.parametrize("backward", [False, True]) +@pytest.mark.parametrize( + "unbind_dim,index", [(0, 0), (0, 2), (1, 3), (-3, 0), (-2, 3)] +) +def test_unbind(distributed_mesh, backward, unbind_dim, index): + """Verify forward and backward via ``numerical_shard_tensor_check``.""" + + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (4, 6, 128) + placements = (Shard(2),) + + original_tensor = torch.rand( + shape, device=dm.device, requires_grad=backward + ) + + sharded_tensor = scatter_tensor( + original_tensor, + global_src=0, + mesh=distributed_mesh, + placements=placements, + requires_grad=True, + ) + + module = UnbindSelectWrapper(dim=unbind_dim, index=index) + + numerical_shard_tensor_check( + distributed_mesh, + module, + [sharded_tensor], + {}, + check_grads=backward, + ) + + +# -- Error tests -------------------------------------------------------------- + + +@pytest.mark.multigpu_static +def test_unbind_along_sharded_dim(distributed_mesh): + """Unbinding along the sharded dimension should raise.""" + + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (4, 6, 128) + placements = (Shard(2),) + + original_tensor = torch.rand(shape, device=dm.device) + + sharded_tensor = scatter_tensor( + original_tensor, + global_src=0, + mesh=distributed_mesh, + placements=placements, + requires_grad=False, + ) + + with pytest.raises(RuntimeError, match="unbinding along sharding axis"): + torch.unbind(sharded_tensor, 2) diff --git a/test/domain_parallel/ops/test_view_ops.py b/test/domain_parallel/ops/test_view_ops.py index 7094074f2c..dd1a2a7d8d 100644 --- a/test/domain_parallel/ops/test_view_ops.py +++ b/test/domain_parallel/ops/test_view_ops.py @@ -534,20 +534,23 @@ def test_view_trailing_dims_1d_to_3d( distributed_mesh, backward, ): - """Test view (6,) -> (2, 3, 1) with Shard(0): trailing dim must stay in group. + """Test view (48,) -> (8, 6, 1) with Shard(0): trailing singleton in target. - With the shard on dim 0, each rank has a contiguous chunk of the 1D tensor. - The target shape has a trailing singleton (2, 3, 1). The trailing dimension - must be included in the same dimension group so that the local element - count is correct (product of local shape equals chunk_size). Without that, - the old code produced wrong local shapes (e.g. product 4 instead of 2 or 3). + The 1D tensor is sharded on dim 0. The target shape has a trailing + singleton ``(8, 6, 1)`` that falls outside the dimension group matched + by ``_match_view_dim_groups`` (which pairs ``(48,)`` with ``(8, 6)``). + The trailing ``1`` must be carried through unchanged in the local shape + so that ``product(local_shape) == chunk_size``. + + We use a tensor size (48) that divides cleanly across 2-, 4-, and 8-GPU + meshes so that every rank's chunk aligns to a row boundary in ``(8, 6)``. """ if not torch.cuda.is_available(): pytest.skip("CUDA is not available") dm = DistributedManager() - shape = (6,) - target_shape = (2, 3, 1) + shape = (48,) + target_shape = (8, 6, 1) original_tensor = torch.rand(shape, device=dm.device, requires_grad=backward) diff --git a/test/domain_parallel/ops/utils.py b/test/domain_parallel/ops/utils.py index de8052c93c..b925041fe7 100644 --- a/test/domain_parallel/ops/utils.py +++ b/test/domain_parallel/ops/utils.py @@ -169,7 +169,7 @@ def unparallelize_module(module): This function is for testing purposes only. Do not use in production code. """ for name, param in list(module._parameters.items()): - if isinstance(param, torch.nn.Parameter) and isinstance(param.data, DTensor): + if isinstance(param, torch.nn.Parameter) and isinstance(param.data, (ShardTensor, DTensor)): # gather to replicated then unwrap local_tensor = param.data.full_tensor() # replace with a normal Parameter diff --git a/test/domain_parallel/test_grad_sharding.py b/test/domain_parallel/test_grad_sharding.py index 871782459b..167044159d 100644 --- a/test/domain_parallel/test_grad_sharding.py +++ b/test/domain_parallel/test_grad_sharding.py @@ -275,7 +275,7 @@ def run_dtensor_to_shard_tensor_non_leaf_gradient(mesh): loss_ref.backward() assert dt.grad is not None - assert isinstance(dt.grad, DTensor) + assert isinstance(dt.grad, (ShardTensor, DTensor)) assert torch.allclose(dt.grad.full_tensor(), ref.grad) diff --git a/test/domain_parallel/test_initialization.py b/test/domain_parallel/test_initialization.py index 5c5c8fbf02..d687872f95 100644 --- a/test/domain_parallel/test_initialization.py +++ b/test/domain_parallel/test_initialization.py @@ -121,6 +121,97 @@ def init_from_data_rank_worker(mesh): assert dim == local_data.shape[i] +def scatter_tensor_requires_grad_contract_worker(mesh, requires_grad: bool): + r"""Validate scatter_tensor construction contract for requires_grad modes.""" + dm = DistributedManager() + rank = dm.rank + global_shape, placements = init_global_shape_and_placements(mesh) + source = 0 + + if rank == source: + raw_data = torch.randn(global_shape, device=torch.device(f"cuda:{dm.local_rank}")) + else: + raw_data = None + + st = scatter_tensor( + raw_data, + source, + mesh, + placements, + global_shape=torch.Size(global_shape), + dtype=torch.float32, + requires_grad=requires_grad, + ) + + assert st.requires_grad is requires_grad + if requires_grad: + assert st.is_leaf + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +@pytest.mark.parametrize("requires_grad", [False, True]) +def test_scatter_tensor_requires_grad_contract_1d(distributed_mesh, requires_grad): + scatter_tensor_requires_grad_contract_worker(distributed_mesh, requires_grad) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +@pytest.mark.parametrize("requires_grad", [False, True]) +def test_scatter_tensor_requires_grad_contract_2d(distributed_mesh_2d, requires_grad): + scatter_tensor_requires_grad_contract_worker(distributed_mesh_2d, requires_grad) + + +def scatter_tensor_grad_population_worker(mesh): + r"""Validate that gradients populate for scatter_tensor(..., requires_grad=True).""" + dm = DistributedManager() + rank = dm.rank + global_shape, placements = init_global_shape_and_placements(mesh) + source = 0 + + if rank == source: + raw_data = torch.randn(global_shape, device=torch.device(f"cuda:{dm.local_rank}")) + else: + raw_data = None + + st = scatter_tensor( + raw_data, + source, + mesh, + placements, + global_shape=torch.Size(global_shape), + dtype=torch.float32, + requires_grad=True, + ) + assert st.is_leaf + assert st.requires_grad + + reference = st.full_tensor().detach().requires_grad_(True) + reference_loss = (reference**2).sum() + reference_loss.backward() + + st2 = st**2 + sharded_loss = st2.sum() + sharded_loss.backward() + + assert st.grad is not None + assert st.grad._spec.placements == st._spec.placements + assert st.grad._spec.sharding_shapes() == st._spec.sharding_shapes() + assert torch.allclose(st.grad.full_tensor(), reference.grad) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +def test_scatter_tensor_requires_grad_gradient_1d(distributed_mesh): + scatter_tensor_grad_population_worker(distributed_mesh) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +def test_scatter_tensor_requires_grad_gradient_2d(distributed_mesh_2d): + scatter_tensor_grad_population_worker(distributed_mesh_2d) + + @pytest.mark.timeout(10) @pytest.mark.multigpu_static def test_shard_tensor_initialization_from_data_rank_1d(distributed_mesh, verbose=False): @@ -162,8 +253,6 @@ def shard_tensor_initialization_from_all_dtensor_worker(mesh): st = ShardTensor.from_dtensor(dt) - print(f"Rank {dm.rank} made shard tensors.") - dt_full = dt.full_tensor() st_full = st.full_tensor() diff --git a/test/domain_parallel/test_reductions.py b/test/domain_parallel/test_reductions.py index 8cb8931e45..2145f48af1 100644 --- a/test/domain_parallel/test_reductions.py +++ b/test/domain_parallel/test_reductions.py @@ -118,6 +118,10 @@ def test_shard_tensor_reduction( requires_grad=backward, ) + # if backward: + # assert shard_tensor.is_leaf + # assert shard_tensor.requires_grad + if verbose: print( f"Shard tensor global shape: {shard_tensor.shape} and local shape: {shard_tensor._local_tensor.shape}" From 8879fec8a2e20b2356c8ccb30c34de8b3e63c579 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:16:24 -0500 Subject: [PATCH 3/4] Fix ruff and interrogate issues --- physicsnemo/domain_parallel/__init__.py | 5 + .../domain_parallel/custom_ops/_tensor_ops.py | 25 +-- physicsnemo/domain_parallel/shard_tensor.py | 161 ++++++++---------- .../domain_parallel/shard_utils/__init__.py | 5 + test/domain_parallel/ops/test_convolution.py | 2 +- test/domain_parallel/ops/test_unbind.py | 8 +- test/domain_parallel/ops/utils.py | 4 +- test/domain_parallel/test_initialization.py | 8 +- 8 files changed, 97 insertions(+), 121 deletions(-) diff --git a/physicsnemo/domain_parallel/__init__.py b/physicsnemo/domain_parallel/__init__.py index 1da99e02f1..270ce3f6df 100644 --- a/physicsnemo/domain_parallel/__init__.py +++ b/physicsnemo/domain_parallel/__init__.py @@ -56,6 +56,11 @@ ) def register_custom_ops(): + """Register all custom ShardTensor ops and shard-aware wrappers. + + Imports are deferred to this function to avoid an import cycle between + ``shard_tensor`` and the individual op modules. + """ # These imports will register the custom ops with the ShardTensor class. # It's done here to avoid an import cycle. from .custom_ops import ( diff --git a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py index 3f0580b921..eca70a0a29 100644 --- a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py +++ b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py @@ -29,8 +29,6 @@ import torch from torch.distributed.tensor._dtensor_spec import TensorMeta from torch.distributed.tensor.placement_types import ( - Partial, - Replicate, Shard, ) @@ -82,9 +80,7 @@ def _unbind_output_metadata( shards = [s for s in input_placements if isinstance(s, Shard)] if dim in [i.dim for i in shards]: - raise RuntimeError( - "No implementation for unbinding along sharding axis yet." - ) + raise RuntimeError("No implementation for unbinding along sharding axis yet.") new_placements: list = [] for p in input_placements: @@ -100,8 +96,7 @@ def _unbind_output_metadata( out_sharding_shapes: dict[int, list[torch.Size]] = { mesh_dim: [ - torch.Size(list(cs[:dim]) + list(cs[dim + 1 :])) - for cs in shard_shapes + torch.Size(list(cs[:dim]) + list(cs[dim + 1 :])) for cs in shard_shapes ] for mesh_dim, shard_shapes in input_spec.sharding_shapes().items() } @@ -109,9 +104,7 @@ def _unbind_output_metadata( return dim, new_placements, out_sharding_shapes -def _unbind_dispatch( - tensor: ShardTensor, dim: int = 0 -) -> tuple[ShardTensor, ...]: +def _unbind_dispatch(tensor: ShardTensor, dim: int = 0) -> tuple[ShardTensor, ...]: r"""Dispatch handler for ``aten.unbind.int`` on :class:`ShardTensor`. Called at the ``__torch_dispatch__`` level (below autograd). Operates @@ -137,9 +130,7 @@ def _unbind_dispatch( models that unbind tensors along non-sharded dimensions. """ input_spec = tensor._spec - dim, new_placements, out_sharding_shapes = _unbind_output_metadata( - input_spec, dim - ) + dim, new_placements, out_sharding_shapes = _unbind_output_metadata(input_spec, dim) # We are reducing tensor rank and returning one tensor per slice original_shape = list(input_spec.shape) @@ -153,9 +144,7 @@ def _unbind_dispatch( stride=_stride_from_contiguous_shape_C_style(original_shape), dtype=input_spec.tensor_meta.dtype, ), - _sharding_shapes={ - k: tuple(v) for k, v in out_sharding_shapes.items() - }, + _sharding_shapes={k: tuple(v) for k, v in out_sharding_shapes.items()}, ) local_results = aten.unbind.int(tensor._local_tensor, dim) @@ -203,9 +192,7 @@ def unbind_wrapper( dim: int = args[1] if len(args) > 1 else kwargs.get("dim", 0) input_spec = input_tensor._spec - dim, new_placements, out_sharding_shapes = _unbind_output_metadata( - input_spec, dim - ) + dim, new_placements, out_sharding_shapes = _unbind_output_metadata(input_spec, dim) # to_local() / from_local() preserve the autograd graph local_input = input_tensor.to_local() diff --git a/physicsnemo/domain_parallel/shard_tensor.py b/physicsnemo/domain_parallel/shard_tensor.py index b1bea3d1cc..8ab22f67e1 100644 --- a/physicsnemo/domain_parallel/shard_tensor.py +++ b/physicsnemo/domain_parallel/shard_tensor.py @@ -20,7 +20,6 @@ from collections.abc import Iterable, Mapping from contextlib import contextmanager from typing import Callable, Sequence, cast -from warnings import warn import torch import torch.distributed as dist @@ -142,10 +141,12 @@ def forward(ctx, st: "ShardTensor") -> DTensor: def backward(ctx, grad_output: DTensor): return (_dtensor_to_shard_tensor(grad_output, ctx.shard_tensor_spec),) + # ============================================================================ # Layer 3 -- Smart single-tensor converters (auto-diff when grad_fn present) # ============================================================================ + def _resolve_spec_for_dtensor( dtensor: DTensor, input_args: tuple = () ) -> ShardTensorSpec: @@ -175,10 +176,12 @@ def _resolve_spec_for_dtensor( # Goal is to prevent recursion into the fallback conversion paths. _conversion_guard = threading.local() + def _conversion_active() -> bool: r"""Return whether ShardTensor<->DTensor conversion is currently active.""" return getattr(_conversion_guard, "depth", 0) > 0 + @contextmanager def _conversion_scope(): r"""Re-entrant conversion guard for cast-down/cast-up paths.""" @@ -199,18 +202,21 @@ def _dispatch_fallback_via_dtensor( kwargs: dict[str, object] | None = None, ) -> object: r"""Execute an ATen op through DTensor fallback using PURE data conversion. - - Native Autograd wraps this hook, so we must NOT build an internal graph + + Native Autograd wraps this hook, so we must NOT build an internal graph using .apply(). We just do the math and let PyTorch track the outer graph. """ with _conversion_scope(): - converted_args = tuple(_convert_args_to_dtensor(arg, use_autograd=False) for arg in args) + converted_args = tuple( + _convert_args_to_dtensor(arg, use_autograd=False) for arg in args + ) converted_kwargs = { - k: _convert_args_to_dtensor(v, use_autograd=False) for k, v in (kwargs or {}).items() + k: _convert_args_to_dtensor(v, use_autograd=False) + for k, v in (kwargs or {}).items() } - + dispatch_res = func(*converted_args, **(converted_kwargs or {})) - + with _conversion_scope(): return _convert_results_to_shard_tensor(dispatch_res, args, use_autograd=False) @@ -222,18 +228,21 @@ def _torch_function_fallback_via_dtensor( ) -> object: r"""Execute a __torch_function__ fallback through DTensor safely. - Because this executes at the Python API level (above Autograd), we MUST + Because this executes at the Python API level (above Autograd), we MUST use autograd functions (.apply) to bridge the tracking manually. """ with _conversion_scope(): - converted_args = tuple(_convert_args_to_dtensor(arg, use_autograd=True) for arg in args) + converted_args = tuple( + _convert_args_to_dtensor(arg, use_autograd=True) for arg in args + ) converted_kwargs = { - k: _convert_args_to_dtensor(v, use_autograd=True) for k, v in (kwargs or {}).items() + k: _convert_args_to_dtensor(v, use_autograd=True) + for k, v in (kwargs or {}).items() } - + with torch._C.DisableTorchFunctionSubclass(): result = func(*converted_args, **converted_kwargs) - + with _conversion_scope(): return _convert_results_to_shard_tensor(result, args, use_autograd=True) @@ -242,9 +251,10 @@ def _torch_function_fallback_via_dtensor( # Layer 4 -- Recurse utilities (walk args / kwargs / results) # ============================================================================ + def _convert_args_to_dtensor(arg: object, use_autograd: bool = False) -> object: r"""Recursively replace ShardTensors with DTensors. - + If use_autograd is True, uses Layer 2 to preserve the graph connection. """ match arg: @@ -256,7 +266,9 @@ def _convert_args_to_dtensor(arg: object, use_autograd: bool = False) -> object: # DTensor can be iterable; exit early deliberately return arg case Mapping(): - return type(arg)({k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()}) + return type(arg)( + {k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()} + ) case tuple(): return tuple(_convert_args_to_dtensor(a, use_autograd) for a in arg) case list(): @@ -269,7 +281,7 @@ def _convert_results_to_shard_tensor( result: object, input_args: tuple, use_autograd: bool = False ) -> object: r"""Recursively replace DTensors with ShardTensors in an op result. - + If use_autograd is True, uses Layer 2 to preserve the graph connection. Handles None returns gracefully for inplace ATen operations. """ @@ -278,12 +290,16 @@ def _convert_results_to_shard_tensor( if isinstance(result, DTensor): spec = _resolve_spec_for_dtensor(result, input_args) - + # If autograd graph connection is requested AND the DTensor actually # requires tracking (it has a grad_fn or requires_grad is active) - if use_autograd and torch.is_grad_enabled() and (result.grad_fn is not None or result.requires_grad): + if ( + use_autograd + and torch.is_grad_enabled() + and (result.grad_fn is not None or result.requires_grad) + ): return _DTensorToShardTensor.apply(result, spec) - + return _dtensor_to_shard_tensor(result, spec) if isinstance(result, Mapping): @@ -293,15 +309,16 @@ def _convert_results_to_shard_tensor( for k, v in result.items() } ) - + if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): return type(result)( _convert_results_to_shard_tensor(d, input_args, use_autograd) for d in result ) - + return result + class _ToTorchTensor(torch.autograd.Function): r"""Autograd function to convert a ShardTensor to a regular PyTorch tensor. @@ -334,14 +351,12 @@ def forward( """ ctx.shard_tensor_spec = input._spec ctx.grad_placements = grad_placements - local_tensor = input._local_tensor - # # JUST LIKE DTENSOR: # # We need to return a fresh Tensor object there as autograd metadata # # will be inplaced into it. So we don't want to pollute the Tensor # # object stored in the _local_tensor of this ShardTensor. # return local_tensor.view_as(local_tensor) - + # Force the local view to inherit the requires_grad state of the ShardTensor local_tensor = input._local_tensor res = local_tensor.view_as(local_tensor) @@ -625,71 +640,6 @@ def register_named_function_handler(cls, func_name: str, handler: Callable) -> N """ cls._named_function_registry[func_name] = handler - # @staticmethod - # def __new__( - # cls, - # local_tensor: torch.Tensor, - # spec: ShardTensorSpec, - # *, - # requires_grad: bool, - # ) -> "ShardTensor": - # r"""Construct a new ShardTensor from a local tensor and specification. - - # Note that unlike ``DTensor``, ShardTensor will automatically collect - # the shard size information from all participating devices. This enables - # uneven and dynamic sharding. - - # Parameters - # ---------- - # local_tensor : torch.Tensor - # Local tensor to use as the data. - # spec : ShardTensorSpec - # ShardTensorSpec defining the sharding scheme. - # requires_grad : bool - # Whether the tensor requires gradients. - - # Returns - # ------- - # ShardTensor - # A new ShardTensor instance. - - # Note - # ---- - # This implementation is heavily derived from ``torch.distributed.tensor.DTensor``. - # """ - # if local_tensor.requires_grad and not requires_grad: - # warn( - # "To construct a new ShardTensor from torch.Tensor, " - # "it's recommended to use local_tensor.detach() and " - # "make requires_grad consistent." - # ) - - # if spec.tensor_meta is None: - # raise ValueError("TensorMeta should not be None!") - - # ret = torch.Tensor._make_wrapper_subclass( - # cls, - # spec.tensor_meta.shape, - # strides=spec.tensor_meta.stride, - # dtype=local_tensor.dtype, - # device=local_tensor.device, - # layout=local_tensor.layout, - # requires_grad=False, - # ) - - # ret._spec = spec - # ret._local_tensor = local_tensor - - # # Set requires_grad AFTER _spec/_local_tensor are assigned, using - # # the C-level setter directly (bypassing our Python property - # # override) so the autograd engine sees the correct flag. - # if requires_grad: - # torch.Tensor.requires_grad.__set__(ret, True) - - # cls._enable_shard_patches = True - - # return ret - @staticmethod def __new__( cls, @@ -740,10 +690,12 @@ def __format__(self, format_spec: str) -> str: @property def device_mesh(self) -> DeviceMesh: + """Return the :class:`DeviceMesh` that this tensor is distributed over.""" return self._spec.mesh @property def placements(self) -> tuple[Placement, ...]: + """Return the placement strategy for each mesh dimension.""" return self._spec.placements def __tensor_flatten__(self): @@ -781,6 +733,11 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): @property # type: ignore[override] def requires_grad(self) -> bool: # type: ignore[override] + """Whether this tensor requires gradient computation. + + Returns ``True`` if either the wrapper tensor or the underlying local + tensor has ``requires_grad`` set. + """ with torch._C.DisableTorchFunctionSubclass(): if torch.Tensor.requires_grad.__get__(self): return True @@ -788,11 +745,24 @@ def requires_grad(self) -> bool: # type: ignore[override] @requires_grad.setter def requires_grad(self, value: bool) -> None: + """Set ``requires_grad`` on both the wrapper and the local tensor.""" with torch._C.DisableTorchFunctionSubclass(): torch.Tensor.requires_grad.__set__(self, value) self._local_tensor.requires_grad = value def requires_grad_(self, requires_grad: bool = True) -> "ShardTensor": + """Set ``requires_grad`` in-place on both the wrapper and local tensor. + + Parameters + ---------- + requires_grad : bool, optional + Whether to enable gradient tracking. Default is ``True``. + + Returns + ------- + ShardTensor + ``self``, for method chaining. + """ with torch._C.DisableTorchFunctionSubclass(): torch.Tensor.requires_grad.__set__(self, requires_grad) self._local_tensor.requires_grad_(requires_grad) @@ -800,11 +770,16 @@ def requires_grad_(self, requires_grad: bool = True) -> "ShardTensor": @property # type: ignore[override] def is_leaf(self) -> bool: # type: ignore[override] + """Whether this tensor is a leaf in the autograd graph.""" with torch._C.DisableTorchFunctionSubclass(): return torch.Tensor.is_leaf.__get__(self) @property # type: ignore[override] def grad(self) -> "ShardTensor | None": # type: ignore[override] + """Return the accumulated gradient, wrapped as a :class:`ShardTensor`. + + If no gradient has been accumulated yet, returns ``None``. + """ with torch._C.DisableTorchFunctionSubclass(): c_grad = torch.Tensor.grad.__get__(self) if c_grad is not None: @@ -812,7 +787,9 @@ def grad(self) -> "ShardTensor | None": # type: ignore[override] return c_grad return ShardTensor.__new__( ShardTensor, - local_tensor=c_grad._local_tensor if isinstance(c_grad, DTensor) else c_grad, + local_tensor=c_grad._local_tensor + if isinstance(c_grad, DTensor) + else c_grad, spec=self._spec, requires_grad=False, ) @@ -828,6 +805,7 @@ def grad(self) -> "ShardTensor | None": # type: ignore[override] @grad.setter def grad(self, value: "ShardTensor | torch.Tensor | None") -> None: + """Set or clear the gradient on both the wrapper and local tensor.""" if value is None: with torch._C.DisableTorchFunctionSubclass(): torch.Tensor.grad.__set__(self, None) @@ -1261,10 +1239,9 @@ def scatter_tensor( st = st.redistribute(mesh, placements, async_op=False) if requires_grad: - # 1. Ensure the local data is a clean leaf local_leaf = st._local_tensor.detach().requires_grad_(True) - + # 2. Create the ShardTensor wrapper st = ShardTensor.__new__( ShardTensor, @@ -1272,7 +1249,7 @@ def scatter_tensor( spec=st._spec, requires_grad=True, ) - + # 3. CRITICAL: Force the wrapper itself to be a leaf in the autograd graph st = st.detach().requires_grad_(True) diff --git a/physicsnemo/domain_parallel/shard_utils/__init__.py b/physicsnemo/domain_parallel/shard_utils/__init__.py index 4ce1bfc714..69b7370cb2 100644 --- a/physicsnemo/domain_parallel/shard_utils/__init__.py +++ b/physicsnemo/domain_parallel/shard_utils/__init__.py @@ -25,6 +25,11 @@ from physicsnemo.domain_parallel.shard_tensor import ShardTensor def register_shard_wrappers(): + """Import and register all shard-aware operation wrappers with ShardTensor. + + Each imported module registers its wrapper via + :meth:`ShardTensor.register_op` at import time. + """ from .attention_patches import sdpa_wrapper from .conv_patches import generic_conv_nd_wrapper from .index_ops import ( diff --git a/test/domain_parallel/ops/test_convolution.py b/test/domain_parallel/ops/test_convolution.py index 7b7af223c7..410e6b9134 100644 --- a/test/domain_parallel/ops/test_convolution.py +++ b/test/domain_parallel/ops/test_convolution.py @@ -198,7 +198,7 @@ def test_conv2d_1dmesh( dm = DistributedManager() - image = generate_image_like_data(2, C_in, ( H, H)).to(dm.device) + image = generate_image_like_data(2, C_in, (H, H)).to(dm.device) placements = (Shard(2),) diff --git a/test/domain_parallel/ops/test_unbind.py b/test/domain_parallel/ops/test_unbind.py index b658806575..f7fb45f4c5 100644 --- a/test/domain_parallel/ops/test_unbind.py +++ b/test/domain_parallel/ops/test_unbind.py @@ -49,9 +49,7 @@ def forward(self, tensor: torch.Tensor): @pytest.mark.multigpu_static @pytest.mark.parametrize("backward", [False, True]) -@pytest.mark.parametrize( - "unbind_dim,index", [(0, 0), (0, 2), (1, 3), (-3, 0), (-2, 3)] -) +@pytest.mark.parametrize("unbind_dim,index", [(0, 0), (0, 2), (1, 3), (-3, 0), (-2, 3)]) def test_unbind(distributed_mesh, backward, unbind_dim, index): """Verify forward and backward via ``numerical_shard_tensor_check``.""" @@ -62,9 +60,7 @@ def test_unbind(distributed_mesh, backward, unbind_dim, index): shape = (4, 6, 128) placements = (Shard(2),) - original_tensor = torch.rand( - shape, device=dm.device, requires_grad=backward - ) + original_tensor = torch.rand(shape, device=dm.device, requires_grad=backward) sharded_tensor = scatter_tensor( original_tensor, diff --git a/test/domain_parallel/ops/utils.py b/test/domain_parallel/ops/utils.py index b925041fe7..fcf04f9c92 100644 --- a/test/domain_parallel/ops/utils.py +++ b/test/domain_parallel/ops/utils.py @@ -169,7 +169,9 @@ def unparallelize_module(module): This function is for testing purposes only. Do not use in production code. """ for name, param in list(module._parameters.items()): - if isinstance(param, torch.nn.Parameter) and isinstance(param.data, (ShardTensor, DTensor)): + if isinstance(param, torch.nn.Parameter) and isinstance( + param.data, (ShardTensor, DTensor) + ): # gather to replicated then unwrap local_tensor = param.data.full_tensor() # replace with a normal Parameter diff --git a/test/domain_parallel/test_initialization.py b/test/domain_parallel/test_initialization.py index d687872f95..d6cd7d054b 100644 --- a/test/domain_parallel/test_initialization.py +++ b/test/domain_parallel/test_initialization.py @@ -129,7 +129,9 @@ def scatter_tensor_requires_grad_contract_worker(mesh, requires_grad: bool): source = 0 if rank == source: - raw_data = torch.randn(global_shape, device=torch.device(f"cuda:{dm.local_rank}")) + raw_data = torch.randn( + global_shape, device=torch.device(f"cuda:{dm.local_rank}") + ) else: raw_data = None @@ -170,7 +172,9 @@ def scatter_tensor_grad_population_worker(mesh): source = 0 if rank == source: - raw_data = torch.randn(global_shape, device=torch.device(f"cuda:{dm.local_rank}")) + raw_data = torch.randn( + global_shape, device=torch.device(f"cuda:{dm.local_rank}") + ) else: raw_data = None From 4b23cb8534dc780d1d687b463db608c5f99d706c Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 17 Apr 2026 09:17:37 -0500 Subject: [PATCH 4/4] Remove unbind tensor ops refactor (moved to separate PR) --- .../domain_parallel/custom_ops/_tensor_ops.py | 242 ++++++------------ test/domain_parallel/ops/test_unbind.py | 109 -------- 2 files changed, 85 insertions(+), 266 deletions(-) delete mode 100644 test/domain_parallel/ops/test_unbind.py diff --git a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py index eca70a0a29..d83a92c757 100644 --- a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py +++ b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py @@ -16,202 +16,130 @@ r"""Custom tensor operations for ShardTensor dispatch. -This module provides dispatch and function handlers for tensor operations -that need special handling when applied to ``ShardTensor`` objects. Handlers -are registered with both ``__torch_dispatch__`` (ATen level) and -``__torch_function__`` (Python level) on :class:`ShardTensor`. +This module provides propagation rules for tensor operations that need +special handling when applied to ``ShardTensor`` objects. These rules +are registered with PyTorch's DTensor operation dispatch system. """ -from __future__ import annotations - -from typing import Any, Callable - import torch -from torch.distributed.tensor._dtensor_spec import TensorMeta +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpSchema, + OutputSharding, + RuntimeSchemaInfo, +) from torch.distributed.tensor.placement_types import ( + Partial, + Replicate, Shard, ) -from physicsnemo.domain_parallel import ShardTensor +from physicsnemo.core.version_check import check_version_spec from physicsnemo.domain_parallel._shard_tensor_spec import ( - ShardTensorSpec, _stride_from_contiguous_shape_C_style, ) +if check_version_spec("torch", "2.10.0a"): + from torch.distributed.tensor._ops.registration import ( + register_prop_rule, + ) +else: + from torch.distributed.tensor._ops.utils import ( + register_prop_rule, + ) + aten = torch.ops.aten -def _unbind_output_metadata( - input_spec: ShardTensorSpec, dim: int -) -> tuple[int, list, dict[int, list[torch.Size]]]: - r"""Compute the normalized dim, output placements, and sharding shapes for unbind. +@register_prop_rule(aten.unbind.int, schema_info=RuntimeSchemaInfo(1)) +def unbind_rules(op_schema: OpSchema) -> OutputSharding: + r"""Propagation rule for ``torch.unbind`` on ShardTensor. - Validates that the unbind dimension is not sharded and does not use - ``Partial`` placement, then returns the metadata needed to construct - the output ``ShardTensor`` objects. + Computes the output sharding specification when unbinding a sharded tensor + along a specified dimension. The unbind operation removes one dimension + from the tensor and returns a tuple of tensors. Parameters ---------- - input_spec : ShardTensorSpec - Specification of the input sharded tensor. - dim : int - Dimension along which to unbind (may be negative). + op_schema : OpSchema + The operation schema containing input specifications and arguments. + Expected to contain: + + - ``args_schema[0]``: Input tensor specification (DTensorSpec) + - ``args_schema[1]``: Dimension to unbind along (int), defaults to 0 Returns ------- - tuple[int, list, dict[int, list[torch.Size]]] - - Normalized (non-negative) ``dim``. - - Output placements (shard dims above ``dim`` shifted down by 1). - - Output sharding shapes with the unbind dimension removed. + OutputSharding + Output sharding specification containing a list of DTensorSpec objects, + one for each tensor in the unbind result. Raises ------ - RuntimeError + Exception If attempting to unbind along a sharded dimension (not yet implemented). - If attempting to unbind with ``Partial`` placement (not yet supported). - """ - ndim = len(input_spec.shape) - if dim < 0: - dim = dim % ndim - - # if the unbind dimension is along a dimension that is sharded, we have to handle that. - # If it's along an unsharded dimension, there is nearly nothing to do. - input_placements = input_spec.placements - shards = [s for s in input_placements if isinstance(s, Shard)] - - if dim in [i.dim for i in shards]: - raise RuntimeError("No implementation for unbinding along sharding axis yet.") - - new_placements: list = [] - for p in input_placements: - if p.is_replicate(): - new_placements.append(p) - elif p.is_shard(): - if p.dim > dim: - new_placements.append(Shard(p.dim - 1)) - else: - new_placements.append(p) - elif p.is_partial(): - raise RuntimeError("Partial placement not supported yet for unbind") - - out_sharding_shapes: dict[int, list[torch.Size]] = { - mesh_dim: [ - torch.Size(list(cs[:dim]) + list(cs[dim + 1 :])) for cs in shard_shapes - ] - for mesh_dim, shard_shapes in input_spec.sharding_shapes().items() - } - - return dim, new_placements, out_sharding_shapes - - -def _unbind_dispatch(tensor: ShardTensor, dim: int = 0) -> tuple[ShardTensor, ...]: - r"""Dispatch handler for ``aten.unbind.int`` on :class:`ShardTensor`. - - Called at the ``__torch_dispatch__`` level (below autograd). Operates - directly on the local tensor and constructs output ``ShardTensor`` - objects with the correct metadata; the autograd engine above handles - gradient tracking. - - Parameters - ---------- - tensor : ShardTensor - Input sharded tensor. - dim : int, default=0 - Dimension along which to unbind. - - Returns - ------- - tuple[ShardTensor, ...] - Tuple of ShardTensors, one per slice along ``dim``. + If attempting to unbind with Partial placement (not yet supported). Note ---- - This handler is needed for operations like attention in Stormcast and other + This rule is needed for operations like attention in Stormcast and other models that unbind tensors along non-sharded dimensions. """ - input_spec = tensor._spec - dim, new_placements, out_sharding_shapes = _unbind_output_metadata(input_spec, dim) - # We are reducing tensor rank and returning one tensor per slice - original_shape = list(input_spec.shape) - original_shape.pop(dim) + # We need to get the dimension of the slice. 0 is default. - output_spec = ShardTensorSpec( - mesh=input_spec.mesh, - placements=tuple(new_placements), - tensor_meta=TensorMeta( - torch.Size(tuple(original_shape)), - stride=_stride_from_contiguous_shape_C_style(original_shape), - dtype=input_spec.tensor_meta.dtype, - ), - _sharding_shapes={k: tuple(v) for k, v in out_sharding_shapes.items()}, - ) + args_schema = op_schema.args_schema - local_results = aten.unbind.int(tensor._local_tensor, dim) + if len(args_schema) > 1: + dim = args_schema[-1] + else: + dim = 0 - return tuple( - ShardTensor( - local_result, - output_spec, - requires_grad=False, # Adjusted after the dispatcher - ) - for local_result in local_results - ) + # if the chunking dimension is along a dimension that is sharded, we have to handle that. + # If it's along an unsharded dimension, there is nearly nothing to do. + input_spec = args_schema[0] -def unbind_wrapper( - func: Callable, - types: tuple[Any, ...], - args: tuple[Any, ...], - kwargs: dict[str, Any], -) -> tuple[ShardTensor, ...]: - r"""Functional-level wrapper for ``torch.unbind`` on ShardTensor. + input_placements = input_spec.placements - This is a ``__torch_function__``-level intercept (above autograd). It - uses ``to_local()`` / ``from_local()`` so that the autograd graph is - preserved through the unbind operation. + shards = [s for s in input_placements if isinstance(s, Shard)] - Parameters - ---------- - func : Callable - The original function being wrapped (``torch.unbind`` or - ``torch.Tensor.unbind``). - types : tuple[Any, ...] - Types of the input arguments (unused). - args : tuple[Any, ...] - Positional arguments. Expected ``(input,)`` or ``(input, dim)``. - kwargs : dict[str, Any] - Keyword arguments (may contain ``dim``). + if dim in [i.dim for i in shards]: + raise Exception("No implementation for unbinding along sharding axis yet.") - Returns - ------- - tuple[ShardTensor, ...] - Tuple of ShardTensors, one per slice along the unbind dimension. - """ - input_tensor: ShardTensor = args[0] - dim: int = args[1] if len(args) > 1 else kwargs.get("dim", 0) - - input_spec = input_tensor._spec - dim, new_placements, out_sharding_shapes = _unbind_output_metadata(input_spec, dim) - - # to_local() / from_local() preserve the autograd graph - local_input = input_tensor.to_local() - local_results = torch.unbind(local_input, dim) - - return tuple( - ShardTensor.from_local( - local_result, - input_spec.mesh, - new_placements, - out_sharding_shapes, - ) - for local_result in local_results - ) + else: + # We are reducing tensor rank and returning one sharding per tensor: + original_shape = list(input_spec.shape) + unbind_dim_shape = original_shape.pop(dim) + output_stride = _stride_from_contiguous_shape_C_style(original_shape) -# Python-level function handlers (__torch_function__). -ShardTensor.register_function_handler(torch.unbind, unbind_wrapper) -ShardTensor.register_function_handler(torch.Tensor.unbind, unbind_wrapper) + # Need to create a new global meta: + new_meta = TensorMeta( + torch.Size(tuple(original_shape)), + stride=output_stride, + dtype=input_spec.tensor_meta.dtype, + ) -# ATen-level dispatch handler (__torch_dispatch__). -ShardTensor.register_dispatch_handler(aten.unbind.int, _unbind_dispatch) + # The placements get adjusted too + new_placements = [] + for p in input_spec.placements: + if isinstance(p, Replicate): + new_placements.append(p) + elif isinstance(p, Shard): + if p.dim > dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + elif isinstance(p, Partial): + raise Exception("Partial placement not supported yet for unbind") + + output_spec_list = [ + DTensorSpec( + mesh=input_spec.mesh, + placements=tuple(new_placements), + tensor_meta=new_meta, + ) + for _ in range(unbind_dim_shape) + ] + return OutputSharding(output_spec_list) diff --git a/test/domain_parallel/ops/test_unbind.py b/test/domain_parallel/ops/test_unbind.py deleted file mode 100644 index f7fb45f4c5..0000000000 --- a/test/domain_parallel/ops/test_unbind.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test unbind operations on ShardTensor. We use a 3D tensor sharded along -dim 2 and test unbinding along non-sharded dimensions. Both forward -correctness and backward gradient flow are verified. -""" - -import pytest -import torch -from torch.distributed.tensor.placement_types import Shard - -from physicsnemo.distributed import DistributedManager -from physicsnemo.domain_parallel import scatter_tensor - -from .utils import numerical_shard_tensor_check - - -class UnbindSelectWrapper(torch.nn.Module): - """ - Wrapper that unbinds a tensor and returns a single element from the - result tuple. This allows reuse of ``numerical_shard_tensor_check`` - which expects a single tensor output. - """ - - def __init__(self, dim: int, index: int): - super().__init__() - self.dim = dim - self.index = index - - def forward(self, tensor: torch.Tensor): - pieces = torch.unbind(tensor, self.dim) - return pieces[self.index] - - -@pytest.mark.multigpu_static -@pytest.mark.parametrize("backward", [False, True]) -@pytest.mark.parametrize("unbind_dim,index", [(0, 0), (0, 2), (1, 3), (-3, 0), (-2, 3)]) -def test_unbind(distributed_mesh, backward, unbind_dim, index): - """Verify forward and backward via ``numerical_shard_tensor_check``.""" - - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - dm = DistributedManager() - shape = (4, 6, 128) - placements = (Shard(2),) - - original_tensor = torch.rand(shape, device=dm.device, requires_grad=backward) - - sharded_tensor = scatter_tensor( - original_tensor, - global_src=0, - mesh=distributed_mesh, - placements=placements, - requires_grad=True, - ) - - module = UnbindSelectWrapper(dim=unbind_dim, index=index) - - numerical_shard_tensor_check( - distributed_mesh, - module, - [sharded_tensor], - {}, - check_grads=backward, - ) - - -# -- Error tests -------------------------------------------------------------- - - -@pytest.mark.multigpu_static -def test_unbind_along_sharded_dim(distributed_mesh): - """Unbinding along the sharded dimension should raise.""" - - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - dm = DistributedManager() - shape = (4, 6, 128) - placements = (Shard(2),) - - original_tensor = torch.rand(shape, device=dm.device) - - sharded_tensor = scatter_tensor( - original_tensor, - global_src=0, - mesh=distributed_mesh, - placements=placements, - requires_grad=False, - ) - - with pytest.raises(RuntimeError, match="unbinding along sharding axis"): - torch.unbind(sharded_tensor, 2)