113 changes: 100 additions & 13 deletions torchft/process_group.py
@@ -21,12 +21,11 @@
 import threading
 import time
 import warnings
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
 from multiprocessing.connection import Connection
 from typing import (
-    Any,
     Callable,
     cast,
     Dict,
@@ -84,6 +83,11 @@
 
 T = TypeVar("T")
 
+# Default timeout constants
+_DEFAULT_TIMEOUT_SECONDS = 60
+_DEFAULT_NCCL_TIMEOUT_SECONDS = 60.0
+_DEFAULT_XCCL_TIMEOUT_SECONDS = 60.0
+
 TORCH_NCCL_DEBUG_INFO_PIPE_FILE_ENV_VAR = "TORCH_NCCL_DEBUG_INFO_PIPE_FILE"
 # Used to trigger flight recorder if we trigger abort on the process group
 TORCHFT_TRIGGER_FR_ON_ABORT = "TORCHFT_TRIGGER_FR_ON_ABORT"
@@ -410,10 +414,12 @@ class ProcessGroupWrapper(ProcessGroup):
 
     def __init__(
         self,
-        timeout: timedelta = timedelta(seconds=60),
+        timeout: Optional[timedelta] = None,
         pg: Optional[ProcessGroup] = None,
     ) -> None:
        super().__init__(0, 1)
+        if timeout is None:
+            timeout = timedelta(seconds=_DEFAULT_TIMEOUT_SECONDS)
         self._pg: Optional[BaseProcessGroup] = pg
         self._timeout = timeout
         self._replica_id: str | None = None
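The Optional[timedelta] = None sentinel above resolves the default at call time from a shared named constant, rather than freezing timedelta(seconds=60) into each signature when the module is imported. A minimal sketch of the pattern, using illustrative names rather than torchft's:

from datetime import timedelta
from typing import Optional

_DEFAULT_TIMEOUT_SECONDS = 60

def connect(timeout: Optional[timedelta] = None) -> None:
    # Resolve the sentinel at call time so all call sites share one
    # constant and subclasses can substitute their own default.
    if timeout is None:
        timeout = timedelta(seconds=_DEFAULT_TIMEOUT_SECONDS)
    print(f"connecting with a {timeout.total_seconds()}s timeout")

connect()                              # uses the shared 60s default
connect(timeout=timedelta(seconds=5))  # caller override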
@@ -724,9 +730,8 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
         # In newer versions of PyTorch work may not exist if the call was
         # not async. In these cases we can just schedule the stream timeout
         # and return.
-        if self._work is not None:
-            if not self._work.wait():
-                return False
+        if self._work is not None and not self._work.wait():
+            return False
 
         # Always use cuda stream for timeout to avoid ProcessGroupNCCL
         # watchdog firing and crashing the process.
@@ -796,7 +801,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
         timeout: the timeout to use for NCCL operations.
     """
 
-    def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
+    def __init__(self, timeout: Optional[timedelta] = None) -> None:
+        if timeout is None:
+            timeout = timedelta(seconds=_DEFAULT_NCCL_TIMEOUT_SECONDS)
         super().__init__(timeout)
         self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)
 
@@ -807,7 +814,8 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
             warnings.warn(
                 f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. "
                 "If any nonblocking NCCL operations have already run this may "
-                "result in the default timeout of 30 minutes and hangs on error."
+                "result in the default timeout of 30 minutes and hangs on error.",
+                stacklevel=2,
             )
             os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds())
 
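Passing stacklevel=2 attributes the warning to the caller of __init__ rather than to the warnings.warn line itself, which is where the missing environment variable can actually be fixed. A small self-contained illustration, with hypothetical names:

import warnings

def make_pg() -> None:
    # With stacklevel=2 the reported file and line belong to the frame
    # that called make_pg, not to this function.
    warnings.warn("timeout env var not set, using 60s", stacklevel=2)

make_pg()  # the warning is attributed to this line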
@@ -911,7 +919,9 @@ class ProcessGroupXCCL(ProcessGroupWrapper):
         timeout: the timeout to use for XCCL operations.
     """
 
-    def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
+    def __init__(self, timeout: Optional[timedelta] = None) -> None:
+        if timeout is None:
+            timeout = timedelta(seconds=_DEFAULT_XCCL_TIMEOUT_SECONDS)
         super().__init__(timeout)
         # Check if XPU is available and XCCL is supported
         self._use_abort: bool = torch.xpu.is_available()
@@ -923,7 +933,8 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
             warnings.warn(
                 f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. "
                 "If any nonblocking XCCL operations have already run this may "
-                "result in the default timeout of 30 minutes and hangs on error."
+                "result in the default timeout of 30 minutes and hangs on error.",
+                stacklevel=2,
             )
             os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds())
 
@@ -1344,7 +1355,7 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         if isinstance(opts, AllreduceOptions):
             return self._manager.allreduce(tensors[0], reduce_op=opts.reduceOp)
 
-        assert False, "unreachable"
+        raise AssertionError("unreachable")
 
     def size(self) -> int:
         return self._manager.num_participants()
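Replacing assert False with raise AssertionError matters under python -O, which strips assert statements: the supposedly unreachable branch would then silustrate silently return None. An illustrative standalone example:

def dispatch(opts: object) -> str:
    if isinstance(opts, str):
        return opts
    # An `assert False` here would vanish under `python -O`, letting the
    # function fall through and return None; the raise always fires.
    raise AssertionError("unreachable")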
@@ -1665,7 +1676,11 @@ def _worker(
                     op_id: int = cast(int, op[1])
                     metadata: _OpMetadata = work[op_id]
 
-                    def callback(fut: Future[object], metadata: _OpMetadata) -> None:
+                    def callback(
+                        fut: Future[object],
+                        metadata: _OpMetadata,
+                        op_id: int = op_id,
+                    ) -> None:
                         try:
                             # create an event after the collective has been issued
                             # to wait on this before we call "future"
@@ -1682,7 +1697,7 @@ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
                             future_pipe.send((op_id, _FUTURE_EXCEPTION, e, None))
 
                     metadata.work.get_future().add_done_callback(
-                        lambda fut: callback(fut, metadata)
+                        lambda fut, m=metadata, oid=op_id: callback(fut, m, oid)
                     )
                 elif cmd == "num_active_work":
                     req_pipe.send(len(work))
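The keyword-argument defaults in the lambda (and the op_id default added to callback) work around Python's late-binding closures: a plain lambda fut: callback(fut, metadata) would read metadata and op_id when the future completes, by which point the worker loop may have rebound them to a later operation. A minimal standalone sketch of the pitfall and the fix:

# Late binding: each closure reads op_id when called, after the loop ends.
late = [lambda: op_id for op_id in range(3)]
# Early binding: a default argument captures the value at definition time.
early = [lambda oid=op_id: oid for op_id in range(3)]

print([f() for f in late])   # [2, 2, 2]
print([f() for f in early])  # [0, 1, 2]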
@@ -2116,3 +2131,75 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
 
     def getBackendName(self) -> str:
         return "torchft-baby-xccl"
+
+
+class ProcessGroupAccelerator(ProcessGroupWrapper):
+    """
+    Device-agnostic process group that automatically detects the accelerator type
+    and delegates to the appropriate backend (NCCL for CUDA, XCCL for XPU).
+
+    This allows writing device-agnostic code without explicitly choosing between
+    ProcessGroupNCCL and ProcessGroupXCCL.
+
+    Args:
+        timeout: the timeout to use for operations.
+    """
+
+    def __init__(self, timeout: Optional[timedelta] = None) -> None:
+        if timeout is None:
+            timeout = timedelta(seconds=_DEFAULT_TIMEOUT_SECONDS)
+        # Detect device type and create appropriate backend
+        device = torch.device(torch.accelerator.current_accelerator())
+        backend = dist.get_default_backend_for_device(device)
+
+        if backend == "nccl":
+            pg = ProcessGroupNCCL(timeout=timeout)
+        elif backend == "xccl":
+            pg = ProcessGroupXCCL(timeout=timeout)
+        else:
+            raise RuntimeError(
+                f"ProcessGroupAccelerator does not support backend '{backend}' for device '{device}'"
+            )
+        super().__init__(timeout=timeout, pg=pg)
+
+    def getBackendName(self) -> str:
+        if self._pg is not None:
+            return self._pg.getBackendName()
+        return "torchft-accelerator"
+
+
+class ProcessGroupBabyAccelerator(ProcessGroupBaby):
+    """
+    Device-agnostic baby process group that automatically detects the accelerator type
+    and delegates to the appropriate backend (BabyNCCL for CUDA, BabyXCCL for XPU).
+
+    This runs the underlying process group in a subprocess and allows writing
+    device-agnostic code without explicitly choosing between ProcessGroupBabyNCCL
+    and ProcessGroupBabyXCCL.
+
+    Args:
+        timeout: the timeout to use for operations.
+    """
+
+    @classmethod
+    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        """Create the appropriate backend based on available accelerator."""
+        device = torch.device(torch.accelerator.current_accelerator())
+        backend = dist.get_default_backend_for_device(device)
+
+        if backend == "nccl":
+            return ProcessGroupBabyNCCL._create_pg(store, rank, world_size)
+        elif backend == "xccl":
+            return ProcessGroupBabyXCCL._create_pg(store, rank, world_size)
+        else:
+            raise RuntimeError(
+                f"ProcessGroupBabyAccelerator does not support backend '{backend}' for device '{device}'"
+            )
+
+    def getBackendName(self) -> str:
+        try:
+            device = torch.device(torch.accelerator.current_accelerator())
+            backend = dist.get_default_backend_for_device(device)
+            return f"torchft-baby-{backend}"
+        except Exception:
+            return "torchft-baby-accelerator"