From efee3ef7822b0b16e7cf348e947acc9a67ad2373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 15 Dec 2025 11:34:13 +0800 Subject: [PATCH 01/23] Introduce Megatron-style parallel state management Signed-off-by: Jikang Mo Signed-off-by: Junjie Mao --- deepspeed/utils/parallel_state.py | 1037 ++++++++++++ deepspeed/utils/parallel_state_deepspeed.py | 555 ++++++ tests/unit/utils/test_mpu.py | 1692 +++++++++++++++++++ 3 files changed, 3284 insertions(+) create mode 100644 deepspeed/utils/parallel_state.py create mode 100644 deepspeed/utils/parallel_state_deepspeed.py create mode 100644 tests/unit/utils/test_mpu.py diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py new file mode 100644 index 000000000000..df9906d2fcee --- /dev/null +++ b/deepspeed/utils/parallel_state.py @@ -0,0 +1,1037 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Refactored Model and data parallel groups with class-based design.""" + +import logging +from datetime import timedelta +from typing import Callable, List, Optional + +import numpy as np +import torch + +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +logger = logging.getLogger(__name__) + +try: + import einops + HAVE_EINOPS = True +except ImportError: + HAVE_EINOPS = False + + +def is_torch_min_version(version: str, check_equality: bool = True) -> bool: + """Check if PyTorch version meets minimum requirement. + + Args: + version: Version string to check (e.g., "2.4.0") + check_equality: If True, also check for equality + + Returns: + True if version requirement is met + """ + try: + from packaging.version import Version as PkgVersion + torch_version = PkgVersion(torch.__version__) + required_version = PkgVersion(version) + if check_equality: + return torch_version >= required_version + return torch_version > required_version + except Exception: + return False + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name, mem_alloc_context=None): + """Returns a sub-tensor from the buffer for the given shape.""" + from functools import reduce + import operator + + required_len = reduce(operator.mul, tensor_shape, 1) + if (self.buffer.get((name, dtype), None) is None or self.buffer[(name, dtype)].numel() < required_len): + from contextlib import nullcontext + mem_alloc_context = mem_alloc_context if mem_alloc_context else nullcontext + with mem_alloc_context(): + self.buffer[(name, dtype)] = torch.empty( + required_len, + dtype=dtype, + device=get_accelerator().current_device(), + requires_grad=False, + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], + mask: List[bool]) -> List[List[int]]: + r"""Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. + + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + local_rank satisfy the following equation: + global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """Solve: index = sum(idx[i] * stride[i])""" + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + assert (sum([x * y for x, y in zip(idx, stride[:-1])]) == index), f"idx {index} with shape {shape} mismatch" + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride)) + ranks.append(rank) + return ranks + + +class RankGenerator: + """A class for generating rank groups for different modes of parallelism.""" + + def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0) -> None: + assert (ep == 1 or cp == 1), "Both EP and CP > 1 is not allowed in one rank generator." + + self.tp = tp + self.ep = ep + self.dp = dp + self.pp = pp + self.cp = cp + self.rank_offset = rank_offset + self.world_size = tp * dp * pp * cp * ep + + self.name_to_size = { + "tp": self.tp, + "pp": self.pp, + "dp": self.dp, + "ep": self.ep, + "cp": self.cp, + } + self.order = order + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError(f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't" + f"specified the order ({self.order}).") + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + """Create a mask for the specified tokens based on the given order.""" + ordered_token = order.split("-") + token_list = token.split("-") + mask = [False] * len(ordered_token) + for t in token_list: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Args: + token (str): Specify the ranks type (e.g., 'tp-dp') + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks + + +class ParallelState: + """Encapsulates all parallel state and operations. + + This class replaces the global variables and functions from the original + parallel_state.py, providing a cleaner, more maintainable interface. + """ + + def __init__(self): + # Process groups + self.tensor_model_parallel_group = None + self.pipeline_model_parallel_group = None + self.model_parallel_group = None + self.embedding_group = None + self.position_embedding_group = None + self.data_parallel_group = None + self.data_parallel_group_gloo = None + self.tensor_and_data_parallel_group = None + self.context_parallel_group = None + self.tensor_and_context_parallel_group = None + self.tensor_and_data_parallel_group_with_cp = None + self.data_parallel_group_with_cp = None + self.data_parallel_group_with_cp_gloo = None + + # Expert-related groups + self.expert_model_parallel_group = None + self.expert_tensor_parallel_group = None + self.expert_tensor_and_model_parallel_group = None + self.expert_tensor_model_pipeline_parallel_group = None + self.expert_data_parallel_group = None + self.expert_data_parallel_group_gloo = None + self.intra_partial_expert_data_parallel_group = None + self.intra_partial_expert_data_parallel_group_gloo = None + self.inter_partial_expert_data_parallel_group = None + + # Global ranks lists + self.embedding_global_ranks = None + self.position_embedding_global_ranks = None + self.pipeline_global_ranks = None + self.data_parallel_global_ranks = None + self.tensor_model_parallel_global_ranks = None + self.model_parallel_global_ranks = None + self.context_parallel_global_ranks = None + self.data_parallel_global_ranks_with_cp = None + self.hierarchical_context_parallel_groups = None + + # Parallel state values + self.virtual_pipeline_model_parallel_rank = None + self.virtual_pipeline_model_parallel_world_size = None + self.mpu_tensor_model_parallel_world_size = None + self.mpu_pipeline_model_parallel_world_size = None + self.mpu_data_parallel_world_size = None + self.mpu_data_parallel_rank = None + self.mpu_tensor_model_parallel_rank = None + self.mpu_pipeline_model_parallel_rank = None + + # Expert parallel state values + self.mpu_expert_model_parallel_world_size = None + self.mpu_expert_model_parallel_rank = None + self.mpu_expert_tensor_parallel_world_size = None + self.mpu_expert_tensor_parallel_rank = None + + # Other + self.global_memory_buffer = None + self.global_process_group_list = None + self.intra_partial_data_parallel_group_with_cp = None + self.intra_partial_data_parallel_group_with_cp_gloo = None + self.intra_distributed_optimizer_instance_group = None + + # Rank generators + self.decoder_rank_generator = None + self.expert_decoder_rank_generator = None + + def _get_nccl_options(self, pg_name: str, nccl_comm_cfgs: dict): + """Set the NCCL process group options.""" + if pg_name in nccl_comm_cfgs: + # FIXME: deepspeed.comm does not provide a way to set NCCL options yet. + nccl_options = torch.distributed.ProcessGroupNCCL.Options( + is_high_priority_stream=nccl_comm_cfgs[pg_name].get("is_high_priority_stream", False)) + if "cga_cluster_size" in nccl_comm_cfgs[pg_name]: + nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name]["cga_cluster_size"] + if "max_ctas" in nccl_comm_cfgs[pg_name]: + nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name]["max_ctas"] + if "min_ctas" in nccl_comm_cfgs[pg_name]: + nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name]["min_ctas"] + if "net_name" in nccl_comm_cfgs[pg_name]: + nccl_options.config.net_name = nccl_comm_cfgs[pg_name]["net_name"] + if nccl_options.config.net_name.lower() not in ["ib", "socket"]: + raise RuntimeError(f"net_name ({nccl_options.config.net_name}) is not supported." + f"Accepted values: 'IB' or 'socket'.") + return nccl_options + return None + + def _create_group( + self, + ranks, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, + ): + """Creates a ProcessGroup.""" + kwargs = { + "ranks": ranks, + "timeout": timeout, + "backend": backend, + "pg_options": pg_options, + "use_local_synchronization": use_local_synchronization, + "group_desc": group_desc, + } + if not is_torch_min_version("2.4.0"): + kwargs.pop("group_desc") + if timeout is None: + kwargs.pop("timeout") + + group = dist.new_group(**kwargs) + if self.global_process_group_list is None: + self.global_process_group_list = [None] + if dist.get_rank() in ranks: + self.global_process_group_list.append(group) + return group + + def _create_hierarchical_groups( + self, + rank, + ranks, + hierarchical_group_sizes, + create_gloo_process_groups=False, + pg_options=None, + timeout=None, + group_desc=None, + ): + """Create hierarchical groups for a set of ranks.""" + if not HAVE_EINOPS: + raise ImportError("einops is not installed. Please install it with `pip install einops`.") + + hierarchical_groups = [] + hierarchical_groups_gloo = [] + if not isinstance(pg_options, list): + pg_options = [pg_options] * len(hierarchical_group_sizes) + + for level in range(len(hierarchical_group_sizes)): + rearranged_ranks = einops.rearrange( + np.array(ranks), + "(l s u) -> (l u) s", + u=int(np.prod(hierarchical_group_sizes[:level])), + s=hierarchical_group_sizes[level], + l=int(np.prod(hierarchical_group_sizes[level + 1:])), + ).tolist() + for sub_ranks in rearranged_ranks: + sub_group = self._create_group( + sub_ranks, + timeout=timeout, + pg_options=pg_options[level], + group_desc=f"HIERARCHICAL_{group_desc}_L{level}", + ) + if create_gloo_process_groups: + sub_group_gloo = self._create_group( + sub_ranks, + timeout=timeout, + backend="gloo", + pg_options=pg_options[level], + group_desc=f"HIERARCHICAL_{group_desc}_GLOO_L{level}", + ) + else: + sub_group_gloo = None + if rank in sub_ranks: + hierarchical_groups.append(sub_group) + hierarchical_groups_gloo.append(sub_group_gloo) + + assert rank not in ranks or len(hierarchical_groups) == len(hierarchical_group_sizes) + assert rank not in ranks or len(hierarchical_groups_gloo) == len(hierarchical_group_sizes) + return hierarchical_groups, hierarchical_groups_gloo + + def initialize_model_parallel( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_comm_backend: Optional[str] = None, + context_parallel_size: int = 1, + hierarchical_context_parallel_sizes: Optional[List[int]] = None, + expert_model_parallel_size: int = 1, + num_distributed_optimizer_instances: int = 1, + expert_tensor_parallel_size: Optional[int] = None, + nccl_communicator_config_path: Optional[str] = None, + distributed_timeout_minutes: int = 30, + order: str = "tp-cp-ep-dp-pp", + get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, + get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, + create_gloo_process_groups: bool = True, + high_priority_stream_groups: Optional[List[str]] = None, + ) -> None: + """Initialize model data parallel groups. + + This is the main initialization method that sets up all parallel groups. + """ + + def default_embedding_ranks(pp_ranks): + """Return the default ranks that constitute the stages on which the word embeddings live.""" + if len(pp_ranks) == 1: + return [pp_ranks[0]] + else: + return [pp_ranks[0], pp_ranks[-1]] + + def default_position_embedding_ranks(pp_ranks): + """Return the default ranks that constitute the stages on which the position embeddings live.""" + return [pp_ranks[0]] + + if get_embedding_ranks is None: + get_embedding_ranks = default_embedding_ranks + if get_position_embedding_ranks is None: + get_position_embedding_ranks = default_position_embedding_ranks + + # Get world size and rank + assert dist.is_initialized() + world_size: int = dist.get_world_size() + rank = dist.get_rank() + + model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + if world_size % model_size != 0: + raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}") + + data_parallel_size: int = world_size // model_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 1: + raise RuntimeError("pipeline-model-parallel size should be greater than 1 with interleaved schedule") + self.virtual_pipeline_model_parallel_rank = 0 + self.virtual_pipeline_model_parallel_world_size = virtual_pipeline_model_parallel_size + + # Load NCCL configs + nccl_comm_cfgs = {} + if nccl_communicator_config_path is not None: + try: + import yaml + except ImportError: + raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " + "requires the yaml package.") + with open(nccl_communicator_config_path, "r") as stream: + nccl_comm_cfgs = yaml.safe_load(stream) + + # Set high priority stream groups + high_priority_stream_groups = high_priority_stream_groups or [] + for pg_name in high_priority_stream_groups: + if pg_name not in nccl_comm_cfgs: + nccl_comm_cfgs[pg_name] = {} + nccl_comm_cfgs[pg_name]["is_high_priority_stream"] = True + + # Create rank generators + self.decoder_rank_generator = RankGenerator( + tp=tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=pipeline_model_parallel_size, + cp=context_parallel_size, + order=order, + rank_offset=0, + ) + + # Build expert rank generator + if expert_tensor_parallel_size is None: + expert_tensor_parallel_size = tensor_model_parallel_size + expert_tensor_model_pipeline_parallel_size = (expert_tensor_parallel_size * expert_model_parallel_size * + pipeline_model_parallel_size) + expert_data_parallel_size = world_size // expert_tensor_model_pipeline_parallel_size + if world_size % expert_tensor_model_pipeline_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})" + ) + + self.expert_decoder_rank_generator = RankGenerator( + tp=expert_tensor_parallel_size, + ep=expert_model_parallel_size, + dp=expert_data_parallel_size, + pp=pipeline_model_parallel_size, + cp=1, + order=order, + rank_offset=0, + ) + + timeout = timedelta(minutes=distributed_timeout_minutes) + + # Build data-parallel groups with context parallel + assert self.data_parallel_group is None, "data parallel group is already initialized" + assert (data_parallel_size * context_parallel_size) % num_distributed_optimizer_instances == 0, ( + "Data parallel size should be divisible by partial DistOpt shard factor") + intra_partial_data_parallel_size = (data_parallel_size * + context_parallel_size) // num_distributed_optimizer_instances + + for ranks_with_cp in self.decoder_rank_generator.get_ranks('dp-cp'): + group_with_cp = self._create_group( + ranks_with_cp, + timeout=timeout, + pg_options=self._get_nccl_options("dp_cp", nccl_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP_WITH_CP", + ) + if create_gloo_process_groups: + group_with_cp_gloo = self._create_group( + ranks_with_cp, + timeout=timeout, + backend="gloo", + group_desc="DATA_PARALLEL_GROUP_WITH_CP_GLOO", + ) + else: + group_with_cp_gloo = None + if rank in ranks_with_cp: + self.data_parallel_group_with_cp = group_with_cp + self.data_parallel_group_with_cp_gloo = group_with_cp_gloo + self.data_parallel_global_ranks_with_cp = ranks_with_cp + + if num_distributed_optimizer_instances > 1: + for i in range(num_distributed_optimizer_instances): + intra_partial_dp_ranks_with_cp = ranks_with_cp[( + i * intra_partial_data_parallel_size):((i + 1) * intra_partial_data_parallel_size)] + intra_partial_dp_group_with_cp = self._create_group( + intra_partial_dp_ranks_with_cp, + timeout=timeout, + pg_options=self._get_nccl_options("intra_dp_cp", nccl_comm_cfgs), + group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP", + ) + if create_gloo_process_groups: + intra_partial_dp_group_with_cp_gloo = self._create_group( + intra_partial_dp_ranks_with_cp, + timeout=timeout, + backend="gloo", + group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO", + ) + else: + intra_partial_dp_group_with_cp_gloo = None + if rank in intra_partial_dp_ranks_with_cp: + self.intra_partial_data_parallel_group_with_cp = intra_partial_dp_group_with_cp + self.intra_partial_data_parallel_group_with_cp_gloo = (intra_partial_dp_group_with_cp_gloo) + else: + self.intra_partial_data_parallel_group_with_cp = self.data_parallel_group_with_cp + self.intra_partial_data_parallel_group_with_cp_gloo = self.data_parallel_group_with_cp_gloo + + # Build data-parallel groups + for ranks in self.decoder_rank_generator.get_ranks('dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("dp", nccl_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP", + ) + if create_gloo_process_groups: + group_gloo = self._create_group(ranks, + timeout=timeout, + backend="gloo", + group_desc="DATA_PARALLEL_GROUP_GLOO") + else: + group_gloo = None + if rank in ranks: + self.data_parallel_group = group + self.data_parallel_group_gloo = group_gloo + self.data_parallel_global_ranks = ranks + + # Build context-parallel groups + assert self.context_parallel_group is None, 'context parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("cp", nccl_comm_cfgs), + group_desc="CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.context_parallel_group = group + self.context_parallel_global_ranks = ranks + if hierarchical_context_parallel_sizes: + assert np.prod(hierarchical_context_parallel_sizes) == context_parallel_size + hierarchical_groups, _ = self._create_hierarchical_groups( + rank, + ranks, + hierarchical_context_parallel_sizes, + create_gloo_process_groups=False, + pg_options=self._get_nccl_options("hcp", nccl_comm_cfgs), + timeout=timeout, + group_desc="CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.hierarchical_context_parallel_groups = hierarchical_groups + + # Build model-parallel groups + assert self.model_parallel_group is None, 'model parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-pp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("mp", nccl_comm_cfgs), + group_desc="MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.model_parallel_group = group + self.model_parallel_global_ranks = ranks + + # Build tensor model-parallel groups + assert self.tensor_model_parallel_group is None, 'tensor model parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp", nccl_comm_cfgs), + group_desc="TENSOR_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_model_parallel_group = group + self.tensor_model_parallel_global_ranks = ranks + + # Build pipeline model-parallel groups and embedding groups + assert self.pipeline_model_parallel_group is None, "pipeline model parallel group is already initialized" + assert self.embedding_group is None, "embedding group is already initialized" + assert self.position_embedding_group is None, "position embedding group is already initialized" + + for ranks in self.decoder_rank_generator.get_ranks('pp'): + group = self._create_group( + ranks, + timeout=timeout, + backend=pipeline_model_parallel_comm_backend, + pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_nccl_options( + "pp", nccl_comm_cfgs)), + group_desc="PIPELINE_MODEL_PARALLEL_GROUP", + ) + assert ( + pipeline_model_parallel_comm_backend == None or pipeline_model_parallel_comm_backend == "nccl" + or pipeline_model_parallel_comm_backend == "ucc" + ), f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported' + + if rank in ranks: + if self.pipeline_model_parallel_group is None: + self.pipeline_model_parallel_group = group + self.pipeline_global_ranks = ranks + elif isinstance(self.pipeline_global_ranks[0], list): + if not isinstance(self.pipeline_model_parallel_group, list): + self.pipeline_model_parallel_group = [self.pipeline_model_parallel_group] + self.pipeline_model_parallel_group.append(group) + self.pipeline_global_ranks.append(ranks) + else: + self.pipeline_model_parallel_group = [self.pipeline_model_parallel_group, group] + self.pipeline_global_ranks = [self.pipeline_global_ranks, ranks] + + embedding_ranks = get_embedding_ranks(ranks) + group = self._create_group( + embedding_ranks, + timeout=timeout, + pg_options=self._get_nccl_options("embd", nccl_comm_cfgs), + group_desc="EMBEDDING_GROUP", + ) + if rank in embedding_ranks: + self.embedding_group = group + self.embedding_global_ranks = embedding_ranks + + position_embedding_ranks = get_position_embedding_ranks(ranks) + group = self._create_group( + position_embedding_ranks, + timeout=timeout, + pg_options=self._get_nccl_options("pos_embd", nccl_comm_cfgs), + group_desc="POSITION_EMBEDDING_GROUP", + ) + if rank in position_embedding_ranks: + self.position_embedding_group = group + self.position_embedding_global_ranks = position_embedding_ranks + + # Build tensor + data parallel groups + assert self.tensor_and_data_parallel_group is None, 'Tensor + data parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-dp-cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_dp_cp", nccl_comm_cfgs), + group_desc="TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP", + ) + if rank in ranks: + self.tensor_and_data_parallel_group_with_cp = group + for ranks in self.decoder_rank_generator.get_ranks('tp-dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_dp", nccl_comm_cfgs), + group_desc="TENSOR_AND_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_and_data_parallel_group = group + + assert self.tensor_and_context_parallel_group is None, 'Tensor + context parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_cp", nccl_comm_cfgs), + group_desc="TENSOR_AND_CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_and_context_parallel_group = group + + # Build expert-related parallel groups + assert self.expert_model_parallel_group is None, 'Expert parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('ep'): + group = self._create_group( + ranks, + pg_options=self._get_nccl_options("ep", nccl_comm_cfgs), + group_desc="EXPERT_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_model_parallel_group = group + + assert self.expert_tensor_parallel_group is None, 'Expert tensor model parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("ep_tp", nccl_comm_cfgs), + group_desc="EXPERT_TENSOR_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_parallel_group = group + + assert self.expert_tensor_and_model_parallel_group is None, 'Expert tensor + model parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_ep_mp", nccl_comm_cfgs), + group_desc="EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_and_model_parallel_group = group + + assert self.expert_tensor_model_pipeline_parallel_group is None, 'The expert_tensor_model_pipeline parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep-pp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_ep_pp", nccl_comm_cfgs), + group_desc="EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_model_pipeline_parallel_group = group + + assert self.expert_data_parallel_group is None, "Expert data group is already initialized" + assert self.expert_data_parallel_group_gloo is None, "Expert data group-gloo is already initialized" + assert self.intra_partial_expert_data_parallel_group is None, "Intra partial expert data group is already initialized" + assert self.intra_partial_expert_data_parallel_group_gloo is None, "Intra partial expert data group-gloo is already initialized" + assert self.inter_partial_expert_data_parallel_group is None, "Inter partial expert data group is already initialized" + + assert (expert_data_parallel_size % num_distributed_optimizer_instances == 0 + ), "Expert data parallel size should be divisible by partial DistOpt shard factor" + intra_partial_expert_data_parallel_size = (expert_data_parallel_size // num_distributed_optimizer_instances) + + for ranks in self.expert_decoder_rank_generator.get_ranks('dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("ep_dp", nccl_comm_cfgs), + group_desc="EXPERT_DATA_PARALLEL_GROUP", + ) + if create_gloo_process_groups: + group_gloo = self._create_group(ranks, backend="gloo", group_desc="EXPERT_DATA_PARALLEL_GROUP_GLOO") + else: + group_gloo = None + if rank in ranks: + self.expert_data_parallel_group = group + self.expert_data_parallel_group_gloo = group_gloo + + if num_distributed_optimizer_instances > 1: + hierarchical_groups, hierarchical_groups_gloo = self._create_hierarchical_groups( + rank, + ranks, + [intra_partial_expert_data_parallel_size, num_distributed_optimizer_instances], + create_gloo_process_groups=create_gloo_process_groups, + pg_options=[ + self._get_nccl_options("intra_ep_dp", nccl_comm_cfgs), + self._get_nccl_options("inter_ep_dp", nccl_comm_cfgs), + ], + timeout=timeout, + group_desc="EXPERT_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.intra_partial_expert_data_parallel_group = hierarchical_groups[0] + self.intra_partial_expert_data_parallel_group_gloo = hierarchical_groups_gloo[0] + self.inter_partial_expert_data_parallel_group = hierarchical_groups[1] + else: + self.intra_partial_expert_data_parallel_group = self.expert_data_parallel_group + self.intra_partial_expert_data_parallel_group_gloo = self.expert_data_parallel_group_gloo + + # Build intra distributed optimizer instance group + assert self.intra_distributed_optimizer_instance_group is None, "Intra distributed optimizer instance group is already initialized" + model_parallel_group_id = 0 + intra_dist_opt_ranks = [] + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep-pp'): + model_parallel_group_id += 1 + intra_dist_opt_ranks.extend(ranks) + if model_parallel_group_id % intra_partial_expert_data_parallel_size == 0: + intra_dist_opt_instance_group = self._create_group( + intra_dist_opt_ranks, + timeout=timeout, + pg_options=self._get_nccl_options("intra_dist_opt_instance", nccl_comm_cfgs), + group_desc="INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP", + ) + if rank in intra_dist_opt_ranks: + self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group + intra_dist_opt_ranks = [] + + # Initialize global memory buffer + self._set_global_memory_buffer() + + def _set_global_memory_buffer(self): + """Initialize global buffer.""" + assert self.global_memory_buffer is None, "global memory buffer is already initialized" + self.global_memory_buffer = GlobalMemoryBuffer() + + # Getter methods for process groups + def get_model_parallel_group(self, check_initialized=True): + """Get the model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.model_parallel_group is not None, "model parallel group is not initialized" + return self.model_parallel_group + + def get_tensor_model_parallel_group(self, check_initialized=True): + """Get the tensor-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.tensor_model_parallel_group is not None, "tensor model parallel group is not initialized" + return self.tensor_model_parallel_group + + def get_pipeline_model_parallel_group(self, check_initialized=True): + """Get the pipeline-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.pipeline_model_parallel_group is not None, "pipeline_model parallel group is not initialized" + return self.pipeline_model_parallel_group + + def get_data_parallel_group(self, with_context_parallel=False, partial_data_parallel=False): + """Get the data-parallel group the caller rank belongs to.""" + if with_context_parallel: + if partial_data_parallel: + assert self.intra_partial_data_parallel_group_with_cp is not None, "Intra partial data parallel group is not initialized" + return self.intra_partial_data_parallel_group_with_cp + assert self.data_parallel_group_with_cp is not None, "data parallel group with context parallel combined is not initialized" + return self.data_parallel_group_with_cp + else: + assert self.data_parallel_group is not None, "data parallel group is not initialized" + assert partial_data_parallel == False, "Partial DP for Optimizer needs to include CP" + return self.data_parallel_group + + def get_context_parallel_group(self, check_initialized=True): + """Get the context-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.context_parallel_group is not None, "context parallel group is not initialized" + return self.context_parallel_group + + def get_embedding_group(self, check_initialized=True): + """Get the embedding group the caller rank belongs to.""" + if check_initialized: + assert self.embedding_group is not None, "embedding group is not initialized" + return self.embedding_group + + def get_tensor_and_data_parallel_group(self, check_initialized=True, with_context_parallel=False): + """Get the tensor- and data-parallel group the caller rank belongs to.""" + if with_context_parallel: + if check_initialized: + assert self.tensor_and_data_parallel_group_with_cp is not None, 'tensor and data parallel group is not initialized' + return self.tensor_and_data_parallel_group_with_cp + else: + if check_initialized: + assert self.tensor_and_data_parallel_group is not None, 'tensor and data parallel group is not initialized' + return self.tensor_and_data_parallel_group + + def get_tensor_and_context_parallel_group(self, check_initialized=True): + """Get the tensor- and context-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.tensor_and_context_parallel_group is not None, "tensor and context parallel group is not initialized" + return self.tensor_and_context_parallel_group + + # Getter methods for world sizes and ranks + def get_tensor_model_parallel_world_size(self): + """Return world size for the tensor-model-parallel group.""" + if self.mpu_tensor_model_parallel_world_size is not None: + return self.mpu_tensor_model_parallel_world_size + return self.get_tensor_model_parallel_group().size() + + def get_pipeline_model_parallel_world_size(self): + """Return world size for the pipeline-model-parallel group.""" + if self.mpu_pipeline_model_parallel_world_size is not None: + return self.mpu_pipeline_model_parallel_world_size + return self.get_pipeline_model_parallel_group().size() + + def get_tensor_model_parallel_rank(self): + """Return caller's rank for the tensor-model-parallel group.""" + if self.mpu_tensor_model_parallel_rank is not None: + return self.mpu_tensor_model_parallel_rank + return self.get_tensor_model_parallel_group().rank() + + def get_pipeline_model_parallel_rank(self): + """Return caller's rank for the pipeline-model-parallel group.""" + if self.mpu_pipeline_model_parallel_rank is not None: + return self.mpu_pipeline_model_parallel_rank + return dist.get_rank(group=self.get_pipeline_model_parallel_group()) + + def get_data_parallel_world_size(self, with_context_parallel=False, partial_data_parallel=False): + """Return world size for the data parallel group.""" + if self.mpu_data_parallel_world_size is not None: + return self.mpu_data_parallel_world_size + if dist.is_available() and dist.is_initialized(): + return self.get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel).size() + else: + return 0 + + def get_data_parallel_rank(self, with_context_parallel=False, partial_data_parallel=False): + """Return caller's rank in the data-parallel group.""" + if self.mpu_data_parallel_rank is not None: + return self.mpu_data_parallel_rank + if dist.is_available() and dist.is_initialized(): + return self.get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel).rank() + else: + return 0 + + def get_context_parallel_world_size(self): + """Return world size for the context parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_context_parallel_group().size() + else: + return 0 + + def get_context_parallel_rank(self): + """Return caller's rank in the context-parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_context_parallel_group().rank() + else: + return 0 + + def is_initialized(self): + """Check if parallel state has been initialized""" + return self.data_parallel_group is not None + + def get_global_memory_buffer(self): + """Return the global GlobalMemoryBuffer object""" + assert self.global_memory_buffer is not None, "global memory buffer is not initialized" + return self.global_memory_buffer + + # Expert-related getter methods + def get_expert_model_parallel_group(self, check_initialized=True): + """Get the expert-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.expert_model_parallel_group is not None, "expert model parallel group is not initialized" + return self.expert_model_parallel_group + + def get_expert_model_parallel_world_size(self): + """Return world size for the expert-model-parallel group.""" + if self.mpu_expert_model_parallel_world_size is not None: + return self.mpu_expert_model_parallel_world_size + if dist.is_available() and dist.is_initialized(): + return self.get_expert_model_parallel_group().size() + else: + return 0 + + def get_expert_model_parallel_rank(self): + """Return caller's rank in the expert-model-parallel group.""" + if self.mpu_expert_model_parallel_rank is not None: + return self.mpu_expert_model_parallel_rank + if dist.is_available() and dist.is_initialized(): + return self.get_expert_model_parallel_group().rank() + else: + return 0 + + def get_expert_tensor_parallel_group(self, check_initialized=True): + """Get the expert-tensor-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.expert_tensor_parallel_group is not None, "Expert tensor parallel group is not initialized" + return self.expert_tensor_parallel_group + + def get_expert_tensor_parallel_world_size(self): + """Return world size for the expert tensor parallel group.""" + if self.mpu_expert_tensor_parallel_world_size is not None: + return self.mpu_expert_tensor_parallel_world_size + if not self.expert_tensor_parallel_group: + return self.mpu_tensor_model_parallel_world_size + else: + return self.get_expert_tensor_parallel_group().size() + + def get_expert_tensor_parallel_rank(self): + """Return my rank for the expert tensor parallel group.""" + if self.mpu_expert_tensor_parallel_rank is not None: + return self.mpu_expert_tensor_parallel_rank + if not self.expert_tensor_parallel_group: + return self.mpu_tensor_model_parallel_rank + else: + return self.get_expert_tensor_parallel_group().rank() + + def get_expert_data_parallel_group(self, check_initialized=True, partial_expert_data_parallel=False): + """Get expert data parallel group.""" + if partial_expert_data_parallel: + if check_initialized: + assert self.intra_partial_expert_data_parallel_group is not None, "Intra partial expert data parallel group is not initialized" + return self.intra_partial_expert_data_parallel_group + else: + if check_initialized: + assert self.expert_data_parallel_group is not None, "Expert data parallel group is not initialized" + return self.expert_data_parallel_group + + def get_expert_data_parallel_rank(self, partial_expert_data_parallel=False): + """Return caller's rank in the expert data parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel).rank() + else: + return 0 + + def get_expert_data_parallel_world_size(self, partial_expert_data_parallel=False): + """Return world size for the expert data parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel).size() + else: + return 0 + + +# Convenience function to create a singleton instance +_parallel_state_instance = None + + +def get_parallel_state() -> ParallelState: + """Get or create the global ParallelState instance.""" + global _parallel_state_instance + if _parallel_state_instance is None: + _parallel_state_instance = ParallelState() + return _parallel_state_instance diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py new file mode 100644 index 000000000000..bf3a346de194 --- /dev/null +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -0,0 +1,555 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DeepSpeed Compatibility Layer for parallel_state. + +This module provides module-level functions compatible with DeepSpeed's +groups.py API, allowing code written for DeepSpeed to work with the +refactored parallel_state module. + +Key Features: +- Supports multiple parallel state instances (for RL scenarios with different models) +- Backward compatible with single global instance +- Context manager for switching between different parallel configurations + +Usage: + # Basic usage (single global instance): + from parallel_state_deepspeed import get_data_parallel_group + dp_group = get_data_parallel_group() + + # Multi-instance usage (for RL scenarios): + from parallel_state_deepspeed import ( + get_parallel_state_instance, + set_current_parallel_state, + get_data_parallel_group, + ) + + # Create different instances for different models + actor_state = get_parallel_state_instance("actor") + critic_state = get_parallel_state_instance("critic") + + # Initialize with different DP sizes + actor_state.initialize_model_parallel(tensor_model_parallel_size=2, data_parallel_size=4) + critic_state.initialize_model_parallel(tensor_model_parallel_size=1, data_parallel_size=8) + + # Use context manager to switch + with set_current_parallel_state("actor"): + actor_dp_group = get_data_parallel_group() # Uses actor's DP group + + with set_current_parallel_state("critic"): + critic_dp_group = get_data_parallel_group() # Uses critic's DP group +""" + +from contextlib import contextmanager +from typing import Optional +from parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state + +# Registry for multiple parallel state instances +_parallel_state_registry = {} +_default_instance_name = "__default__" + +# Current active instance name (thread-local would be better, but using global for simplicity) +_current_instance_name = _default_instance_name + + +def get_parallel_state_instance(name: Optional[str] = None) -> ParallelState: + """Get or create a named ParallelState instance. + + Args: + name: Name of the instance. If None, returns the default global instance. + Use different names for different models in RL scenarios. + + Returns: + ParallelState instance + + Example: + # For RL with actor and critic models + actor_state = get_parallel_state_instance("actor") + critic_state = get_parallel_state_instance("critic") + """ + if name is None: + return _get_default_parallel_state() + + if name not in _parallel_state_registry: + _parallel_state_registry[name] = ParallelState() + + return _parallel_state_registry[name] + + +def set_current_parallel_state(name: Optional[str] = None): + """Set the current active parallel state instance. + + Args: + name: Name of the instance to activate. If None, uses the default instance. + + Returns: + Context manager for temporarily switching the active instance + + Example: + with set_current_parallel_state("actor"): + dp_group = get_data_parallel_group() # Uses actor's DP group + """ + + @contextmanager + def _context(): + global _current_instance_name + old_name = _current_instance_name + _current_instance_name = name if name is not None else _default_instance_name + try: + yield + finally: + _current_instance_name = old_name + + return _context() + + +def get_current_parallel_state() -> ParallelState: + """Get the currently active parallel state instance. + + Returns: + The currently active ParallelState instance + """ + return get_parallel_state_instance(_current_instance_name) + + +def get_parallel_state(name: Optional[str] = None) -> ParallelState: + """Get parallel state instance (backward compatible). + + If name is provided, returns the named instance. + Otherwise, returns the currently active instance. + + Args: + name: Optional name of the instance. If None, returns current active instance. + + Returns: + ParallelState instance + """ + if name is not None: + return get_parallel_state_instance(name) + return get_current_parallel_state() + + +# ============================================================================ +# Core Tensor/Model/Data Parallel Functions +# ============================================================================ + + +def get_tensor_model_parallel_group(name: Optional[str] = None): + """Get the tensor model parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + Use this in RL scenarios to specify which model's parallel groups to use. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_group() + + +def get_model_parallel_group(name: Optional[str] = None): + """Get the model parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_model_parallel_group() + + +def get_data_parallel_group(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Get the data parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + Use this in RL scenarios to specify which model's DP group to use. + For example, "actor" vs "critic" may have different DP sizes. + with_context_parallel: Whether to include context parallel in the group. + partial_data_parallel: Whether to use partial data parallel group. + + DeepSpeed-compatible interface. + + Example: + # In RL scenario with different DP sizes: + actor_dp = get_data_parallel_group("actor") # Actor's DP group + critic_dp = get_data_parallel_group("critic") # Critic's DP group + + # Or use context manager: + with set_current_parallel_state("actor"): + dp_group = get_data_parallel_group() # Uses actor's DP group + """ + return get_parallel_state(name).get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_tensor_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the tensor model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_world_size() + + +def get_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_world_size() + + +def get_tensor_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the tensor-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_rank() + + +def get_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_rank() + + +def get_data_parallel_world_size(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Return world size for the data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + partial_data_parallel: Whether to use partial data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_data_parallel_world_size(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_data_parallel_rank(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Return caller's rank in the data-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + partial_data_parallel: Whether to use partial data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_data_parallel_rank(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_tensor_model_parallel_src_rank(name: Optional[str] = None): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + import torch.distributed as dist + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size(name) + return (global_rank // local_world_size) * local_world_size + + +def set_tensor_model_parallel_world_size(world_size, name: Optional[str] = None): + """Set the tensor model parallel size. + + Args: + world_size: World size to set. + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + ps = get_parallel_state(name) + ps.mpu_tensor_model_parallel_world_size = world_size + + +def set_tensor_model_parallel_rank(rank, name: Optional[str] = None): + """Set tensor model parallel rank. + + Args: + rank: Rank to set. + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + ps = get_parallel_state(name) + ps.mpu_tensor_model_parallel_rank = rank + + +# ============================================================================ +# Pipeline Parallel Functions +# ============================================================================ + + +def get_pipeline_model_parallel_group(name: Optional[str] = None): + """Get the pipeline-model-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_group() + + +def get_pipeline_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the pipeline-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_world_size() + + +def get_pipeline_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the pipeline-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_rank() + + +# ============================================================================ +# Context Parallel Functions +# ============================================================================ + + +def get_context_parallel_group(name: Optional[str] = None): + """Get the context-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_group() + + +def get_context_parallel_world_size(name: Optional[str] = None): + """Return world size for the context parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_world_size() + + +def get_context_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the context-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_rank() + + +# ============================================================================ +# Expert Parallel Functions +# ============================================================================ + + +def get_expert_model_parallel_group(name: Optional[str] = None): + """Get the expert-model-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_group() + + +def get_expert_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the expert-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_world_size() + + +def get_expert_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the expert-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_rank() + + +def get_expert_tensor_parallel_group(name: Optional[str] = None): + """Get the expert-tensor-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_group() + + +def get_expert_tensor_parallel_world_size(name: Optional[str] = None): + """Return world size for the expert tensor parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_world_size() + + +def get_expert_tensor_parallel_rank(name: Optional[str] = None): + """Return my rank for the expert tensor parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_rank() + + +def get_expert_data_parallel_group(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Get expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel) + + +def get_expert_data_parallel_world_size(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Return world size for the expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_world_size( + partial_expert_data_parallel=partial_expert_data_parallel) + + +def get_expert_data_parallel_rank(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Return caller's rank in the expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_rank( + partial_expert_data_parallel=partial_expert_data_parallel) + + +# ============================================================================ +# Additional Helper Functions +# ============================================================================ + + +def get_embedding_group(name: Optional[str] = None): + """Get the embedding group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_embedding_group() + + +def get_tensor_and_data_parallel_group(name: Optional[str] = None, with_context_parallel: bool = False): + """Get the tensor- and data-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_and_data_parallel_group(with_context_parallel=with_context_parallel) + + +def get_tensor_and_context_parallel_group(name: Optional[str] = None): + """Get the tensor- and context-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_and_context_parallel_group() + + +def is_initialized(name: Optional[str] = None): + """Check if parallel state has been initialized. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).is_initialized() diff --git a/tests/unit/utils/test_mpu.py b/tests/unit/utils/test_mpu.py new file mode 100644 index 000000000000..11ed585c92b3 --- /dev/null +++ b/tests/unit/utils/test_mpu.py @@ -0,0 +1,1692 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team +""" +Automated testing of parallel strategy combinations using random configurations. + +This test automatically generates random parallel configurations and tests +both parallel_state_refactored and DeepSpeed to see if they produce compatible results. +""" + +import pytest +import random +from typing import Dict, List, Tuple, Optional +from collections import defaultdict + +# Try to import both libraries +try: + from deepspeed.utils.parallel_state import RankGenerator + PARALLEL_STATE_AVAILABLE = True +except ImportError as e: + PARALLEL_STATE_AVAILABLE = False + print(f"Warning: Could not import Megatron parallel_state_refactored: {e}") + +try: + from deepspeed.utils import groups as ds_groups + from deepspeed.runtime.sequence_parallel import parallel_state_sp as ds_sp + DEEPSPEED_AVAILABLE = True +except ImportError as e: + DEEPSPEED_AVAILABLE = False + print(f"Warning: Could not import DeepSpeed: {e}") + + +class ParallelConfigGenerator: + """Generate random parallel configurations for testing.""" + + def __init__(self, seed=None): + if seed is not None: + random.seed(seed) + self.tested_configs = [] + self.failed_configs = [] + + def generate_random_config(self, max_size=1024, min_parallel_size=1, max_parallel_size=32): + """Generate a random parallel configuration. + + Args: + max_size: Maximum world size to consider + min_parallel_size: Minimum parallel size for each dimension + max_parallel_size: Maximum parallel size for each dimension + + Returns: + Dict with tp, dp, pp, cp, ep values and order + """ + # Generate random sizes for each dimension + # Don't filter invalid configurations - we want to test and report all cases + tp = random.randint(min_parallel_size, max_parallel_size) + dp = random.randint(min_parallel_size, max_parallel_size) + pp = random.randint(min_parallel_size, max_parallel_size) + cp = random.randint(min_parallel_size, max_parallel_size) + ep = random.randint(min_parallel_size, max_parallel_size) + + # Calculate world size + world_size = tp * dp * pp * cp * ep + + # If world size is too large, scale down proportionally + # But try to keep at least one dimension > 1 + if world_size > max_size: + # Scale down proportionally + scale_factor = (max_size / world_size)**0.25 + tp = max(1, int(tp * scale_factor)) + dp = max(1, int(dp * scale_factor)) + pp = max(1, int(pp * scale_factor)) + cp = max(1, int(cp * scale_factor)) + ep = max(1, int(ep * scale_factor)) + world_size = tp * dp * pp * cp * ep + + # Ensure at least one dimension is > 1 + if world_size == 1: + tp = 2 + world_size = 2 + + # Generate random order (but must include all non-1 dimensions) + dimensions = [] + if tp > 1: + dimensions.append('tp') + if dp > 1: + dimensions.append('dp') + if pp > 1: + dimensions.append('pp') + if cp > 1: + dimensions.append('cp') + if ep > 1: + dimensions.append('ep') + + # Shuffle to get random order + random.shuffle(dimensions) + order = '-'.join(dimensions) if dimensions else 'tp' + + # If no dimensions > 1, use default + if not dimensions: + order = 'tp-dp' + tp = 2 + dp = 2 + + config = { + "tp": tp, + "dp": dp, + "pp": pp, + "cp": cp, + "ep": ep, + "order": order, + "world_size": tp * dp * pp * cp * ep, + } + + return config + + def generate_systematic_configs(self, max_world_size=512): + """Generate systematic configurations covering common cases. + + Args: + max_world_size: Maximum world size to consider + + Returns: + List of configurations + """ + configs = [] + + # Single parallelism - test larger sizes + for size in [2, 4, 8, 16, 32, 64, 128, 256]: + if size <= max_world_size: + configs.append({"tp": size, "dp": 1, "pp": 1, "cp": 1, "ep": 1, "order": "tp", "world_size": size}) + configs.append({"tp": 1, "dp": size, "pp": 1, "cp": 1, "ep": 1, "order": "dp", "world_size": size}) + configs.append({"tp": 1, "dp": 1, "pp": size, "cp": 1, "ep": 1, "order": "pp", "world_size": size}) + + # Two-way combinations - more variations + for tp, dp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4), (2, 16), (16, 2), (4, 8), (8, 4)]: + if tp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp-dp", + "world_size": tp * dp + }) + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp-tp", + "world_size": tp * dp + }) + + for tp, pp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4)]: + if tp * pp <= max_world_size: + configs.append({ + "tp": tp, + "dp": 1, + "pp": pp, + "cp": 1, + "ep": 1, + "order": "tp-pp", + "world_size": tp * pp + }) + + for tp, cp in [(2, 2), (2, 4), (4, 2), (2, 8)]: + if tp * cp <= max_world_size: + configs.append({ + "tp": tp, + "dp": 1, + "pp": 1, + "cp": cp, + "ep": 1, + "order": "tp-cp", + "world_size": tp * cp + }) + + for tp, ep in [(2, 2), (2, 4), (4, 2), (2, 8)]: + if tp * ep <= max_world_size: + configs.append({ + "tp": tp, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": ep, + "order": "tp-ep", + "world_size": tp * ep + }) + + # Three-way combinations - more variations + for tp, pp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2), (4, 2, 2), (2, 2, 8), (2, 4, 4), (4, 4, 2)]: + if tp * pp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": 1, + "ep": 1, + "order": "tp-pp-dp", + "world_size": tp * pp * dp + }) + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": tp * pp * dp + }) + + for tp, cp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: + if tp * cp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": cp, + "ep": 1, + "order": "tp-cp-dp", + "world_size": tp * cp * dp + }) + + for tp, ep, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: + if tp * ep * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": 1, + "ep": ep, + "order": "tp-ep-dp", + "world_size": tp * ep * dp + }) + + # Four-way combinations - more variations + for tp, pp, dp, cp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2), (2, 4, 2, 2)]: + if tp * pp * dp * cp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": cp, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": tp * pp * dp * cp + }) + + for tp, ep, pp, dp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2)]: + if tp * ep * pp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": 1, + "ep": ep, + "order": "tp-ep-pp-dp", + "world_size": tp * ep * pp * dp + }) + + return configs + + def generate_random_configs(self, count=1000, max_size=1024): + """Generate multiple random configurations. + + Args: + count: Number of random configurations to generate + max_size: Maximum world size + + Returns: + List of configurations + """ + configs = [] + seen = set() + + for _ in range(count): + config = self.generate_random_config(max_size=max_size) + # Create a unique key for this configuration + key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) + if key not in seen: + seen.add(key) + configs.append(config) + + return configs + + def generate_random_config_by_dimension(self, + dimension_count: int, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32): + """Generate a random configuration with exactly the specified number of dimensions > 1. + + Args: + dimension_count: Number of dimensions that should be > 1 (1-5) + max_size: Maximum world size + min_parallel_size: Minimum parallel size for each dimension + max_parallel_size: Maximum parallel size for each dimension + + Returns: + Dict with tp, dp, pp, cp, ep values and order + """ + # All possible dimensions + all_dims = ['tp', 'dp', 'pp', 'cp', 'ep'] + + # Randomly select which dimensions to activate + active_dims = random.sample(all_dims, min(dimension_count, len(all_dims))) + + # Initialize all dimensions to 1 + config = { + "tp": 1, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": 1, + } + + # Set active dimensions to random values + for dim in active_dims: + config[dim] = random.randint(min_parallel_size, max_parallel_size) + + # Calculate world size + world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] + + # If world size is too large, scale down proportionally + if world_size > max_size: + scale_factor = (max_size / world_size)**(1.0 / dimension_count) + for dim in active_dims: + config[dim] = max(min_parallel_size, int(config[dim] * scale_factor)) + world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] + + # Generate random order from active dimensions + random.shuffle(active_dims) + order = '-'.join(active_dims) + + config["order"] = order + config["world_size"] = world_size + + return config + + def generate_random_configs_by_dimension(self, + counts_by_dimension: Dict[int, int], + max_size=1024, + min_parallel_size=2, + max_parallel_size=32): + """Generate random configurations for each dimension separately. + + Args: + counts_by_dimension: Dict mapping dimension count (1-5) to number of configs to generate + e.g., {1: 100, 2: 200, 3: 150, 4: 100, 5: 50} + max_size: Maximum world size + min_parallel_size: Minimum parallel size for each dimension + max_parallel_size: Maximum parallel size for each dimension + + Returns: + List of configurations grouped by dimension count + """ + all_configs = [] + seen = set() + + for dim_count, count in counts_by_dimension.items(): + if dim_count < 1 or dim_count > 5: + continue + + dim_configs = [] + attempts = 0 + # Increased max_attempts for larger test sets (20x more configs) + max_attempts = count * 20 # Prevent infinite loops, allow more attempts for uniqueness + + while len(dim_configs) < count and attempts < max_attempts: + attempts += 1 + config = self.generate_random_config_by_dimension(dim_count, max_size, min_parallel_size, + max_parallel_size) + + # Create a unique key for this configuration + key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) + + if key not in seen: + seen.add(key) + dim_configs.append(config) + all_configs.append(config) + + if len(dim_configs) < count: + print( + f"Warning: Only generated {len(dim_configs)}/{count} configs for {dim_count}D combinations (attempted {attempts} times)" + ) + + return all_configs + + +class ErrorCategorizer: + """Categorize and aggregate errors by type.""" + + def __init__(self): + self.error_categories = defaultdict(list) + self.combination_stats = defaultdict(int) + + def categorize_error(self, error_msg: str, config: Dict) -> str: + """Categorize an error message into a category.""" + error_lower = error_msg.lower() + + if "ep and cp cannot both be > 1" in error_lower: + return "EP_CP_CONFLICT" + elif "cp not supported" in error_lower: + return "CP_NOT_SUPPORTED" + elif "pp requires" in error_lower or "pipeline" in error_lower: + return "PP_REQUIRES_MPU" + elif "not divisible" in error_lower: + return "DIVISIBILITY_ERROR" + elif "order" in error_lower and "specified" in error_lower: + return "ORDER_MISMATCH" + elif "not available" in error_lower: + return "FEATURE_NOT_AVAILABLE" + else: + return "OTHER_ERROR" + + def get_combination_type(self, config: Dict) -> str: + """Get the combination type string for a configuration.""" + dims = [] + if config["tp"] > 1: + dims.append("TP") + if config["dp"] > 1: + dims.append("DP") + if config["pp"] > 1: + dims.append("PP") + if config["cp"] > 1: + dims.append("CP") + if config["ep"] > 1: + dims.append("EP") + + if not dims: + return "NONE" + + return "+".join(sorted(dims)) + + def record_error(self, error_msg: str, config: Dict, library: str): + """Record an error with categorization.""" + category = self.categorize_error(error_msg, config) + combo_type = self.get_combination_type(config) + + self.error_categories[category].append({ + "error": error_msg, + "config": config, + "library": library, + "combination": combo_type, + }) + + self.combination_stats[combo_type] += 1 + + def get_error_summary(self) -> Dict: + """Get summary of errors by category.""" + summary = {} + for category, errors in self.error_categories.items(): + summary[category] = { + "count": len(errors), + "examples": errors[:5], # First 5 examples + "unique_combinations": len(set(e["combination"] for e in errors)), + } + return summary + + +class ParallelCompatibilityTester: + """Test compatibility between Megatron and DeepSpeed for parallel configurations.""" + + def __init__(self): + self.results = { + "megatron_success": [], + "megatron_failures": [], + "deepspeed_success": [], + "deepspeed_failures": [], + "compatible": [], + "incompatible": [], + "megatron_only": [], + "deepspeed_only": [], + } + self.error_categorizer = ErrorCategorizer() + self.combination_stats = defaultdict( + lambda: { + "total": 0, + "megatron_success": 0, + "megatron_failures": 0, + "deepspeed_success": 0, + "deepspeed_failures": 0, + "compatible": 0, + "megatron_only": 0, + "deepspeed_only": 0, + "incompatible": 0, + }) + + def test_megatron_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: + """Test if a configuration works with Megatron. + + Returns: + (success, error_message, result_data) + """ + if not PARALLEL_STATE_AVAILABLE: + return False, "Megatron not available", None + + try: + # Check EP and CP constraint + if config["ep"] > 1 and config["cp"] > 1: + return False, "EP and CP cannot both be > 1 in Megatron", None + + # Create RankGenerator + rg = RankGenerator(tp=config["tp"], + ep=config["ep"], + dp=config["dp"], + pp=config["pp"], + cp=config["cp"], + order=config["order"]) + + # Test getting ranks for each dimension + result_data = { + "world_size": rg.world_size, + "tp_groups": rg.get_ranks("tp") if config["tp"] > 1 else [], + "dp_groups": rg.get_ranks("dp") if config["dp"] > 1 else [], + "pp_groups": rg.get_ranks("pp") if config["pp"] > 1 else [], + "cp_groups": rg.get_ranks("cp") if config["cp"] > 1 else [], + "ep_groups": rg.get_ranks("ep") if config["ep"] > 1 else [], + } + + # Test combined groups + if len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) > 1: + combined_token = config["order"] + result_data["combined_groups"] = rg.get_ranks(combined_token) + + return True, None, result_data + + except Exception as e: + return False, str(e), None + + def test_deepspeed_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: + """Test if a configuration is supported by DeepSpeed. + + Returns: + (supported, error_message, support_info) + """ + if not DEEPSPEED_AVAILABLE: + return False, "DeepSpeed not available", None + + support_info = { + "tp_supported": False, + "dp_supported": False, + "pp_supported": False, + "cp_supported": False, + "ep_supported": False, + "sp_supported": False, + "notes": [], + } + + # Check TP support + if config["tp"] > 1: + support_info["tp_supported"] = hasattr(ds_groups, 'get_tensor_model_parallel_group') + + # Check DP support + if config["dp"] > 1: + support_info["dp_supported"] = hasattr(ds_groups, 'get_data_parallel_group') + + # Check PP support + if config["pp"] > 1: + # DeepSpeed supports PP via mpu or pipe module + support_info["pp_supported"] = (hasattr(ds_groups, 'bwc_pipeline_parallel_world_size') + or self._check_module_exists('deepspeed.pipe')) + if not support_info["pp_supported"]: + support_info["notes"].append("PP requires mpu object or deepspeed.pipe module") + + # Check CP support + if config["cp"] > 1: + support_info["cp_supported"] = hasattr(ds_groups, 'get_context_parallel_group') + if not support_info["cp_supported"]: + support_info["notes"].append("CP not supported in DeepSpeed") + + # Check EP support + if config["ep"] > 1: + support_info["ep_supported"] = (hasattr(ds_groups, '_create_expert_and_data_parallel') + or hasattr(ds_groups, '_create_expert_data_and_model_parallel')) + + # Check SP support (DeepSpeed-specific) + support_info["sp_supported"] = hasattr(ds_sp, 'initialize_sequence_parallel') + + # Determine overall support + required_dims = [d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1] + supported_dims = [] + if config["tp"] > 1 and support_info["tp_supported"]: + supported_dims.append("tp") + if config["dp"] > 1 and support_info["dp_supported"]: + supported_dims.append("dp") + if config["pp"] > 1 and support_info["pp_supported"]: + supported_dims.append("pp") + if config["cp"] > 1 and support_info["cp_supported"]: + supported_dims.append("cp") + if config["ep"] > 1 and support_info["ep_supported"]: + supported_dims.append("ep") + + fully_supported = len(supported_dims) == len(required_dims) + + return fully_supported, None, support_info + + def _check_module_exists(self, module_name): + """Check if a module exists.""" + try: + __import__(module_name) + return True + except ImportError: + return False + + def _simulate_deepspeed_rank_generation(self, config: Dict) -> Optional[Dict]: + """Simulate DeepSpeed's rank generation logic based on code analysis. + + This attempts to replicate DeepSpeed's rank assignment logic for comparison. + """ + if not DEEPSPEED_AVAILABLE: + return None + + try: + world_size = config["world_size"] + result = {} + + # For TP+DP: DeepSpeed uses mesh_device which creates groups in a specific way + if config["tp"] > 1 and config["dp"] > 1 and config["pp"] == 1 and config["cp"] == 1 and config["ep"] == 1: + # DeepSpeed's _init_tp_mesh_device creates: + # TP groups: [0,1], [2,3], [4,5], ... (consecutive) + # DP groups: [0,2,4,...], [1,3,5,...] (strided) + tp_size = config["tp"] + dp_size = config["dp"] + + tp_groups = [] + for i in range(world_size // tp_size): + group = list(range(i * tp_size, (i + 1) * tp_size)) + tp_groups.append(group) + + dp_groups = [] + for i in range(tp_size): + group = list(range(i, world_size, tp_size)) + dp_groups.append(group) + + result["tp_groups"] = tp_groups + result["dp_groups"] = dp_groups + result["world_size"] = world_size + return result + + # For other combinations, we can't easily simulate without actual distributed setup + # But we can note that DeepSpeed supports it + return {"supported": True, "note": "Rank generation requires actual distributed setup"} + + except Exception as e: + return {"error": str(e)} + + def _compare_rank_groups(self, megatron_groups: List[List[int]], deepspeed_groups: List[List[int]]) -> Dict: + """Compare rank groups from Megatron and DeepSpeed. + + Returns: + Dict with comparison results + """ + comparison = {"same_structure": False, "same_ranks": False, "differences": []} + + if not megatron_groups or not deepspeed_groups: + return comparison + + # Check if same number of groups + if len(megatron_groups) != len(deepspeed_groups): + comparison["differences"].append( + f"Group count mismatch: Megatron={len(megatron_groups)}, DeepSpeed={len(deepspeed_groups)}") + return comparison + + # Check if same group sizes + megatron_sizes = [len(g) for g in megatron_groups] + deepspeed_sizes = [len(g) for g in deepspeed_groups] + if megatron_sizes != deepspeed_sizes: + comparison["differences"].append( + f"Group size mismatch: Megatron={megatron_sizes}, DeepSpeed={deepspeed_sizes}") + return comparison + + # Check if same ranks (order may differ) + megatron_ranks = set() + for group in megatron_groups: + megatron_ranks.update(group) + + deepspeed_ranks = set() + for group in deepspeed_groups: + deepspeed_ranks.update(group) + + if megatron_ranks != deepspeed_ranks: + comparison["differences"].append( + f"Rank set mismatch: Megatron={sorted(megatron_ranks)}, DeepSpeed={sorted(deepspeed_ranks)}") + return comparison + + # Check if same structure (same groups, possibly different order) + megatron_sets = [set(g) for g in megatron_groups] + deepspeed_sets = [set(g) for g in deepspeed_groups] + + if sorted(megatron_sets, key=lambda x: min(x)) == sorted(deepspeed_sets, key=lambda x: min(x)): + comparison["same_structure"] = True + comparison["same_ranks"] = True + else: + comparison["differences"].append("Group structure differs (same ranks but different grouping)") + + return comparison + + def test_config_compatibility(self, config: Dict): + """Test compatibility of a configuration between both libraries.""" + # Get combination type for statistics + combo_type = self.error_categorizer.get_combination_type(config) + self.combination_stats[combo_type]["total"] += 1 + + # Test Megatron + megatron_success, megatron_error, megatron_result = self.test_megatron_config(config) + + # Test DeepSpeed + deepspeed_success, deepspeed_error, deepspeed_support = self.test_deepspeed_config(config) + + # Record errors in categorizer + if not megatron_success and megatron_error: + self.error_categorizer.record_error(megatron_error, config, "Megatron") + self.combination_stats[combo_type]["megatron_failures"] += 1 + else: + self.combination_stats[combo_type]["megatron_success"] += 1 + + if not deepspeed_success: + # Get error message from support_info notes + error_msg = deepspeed_support.get("notes", ["Not supported"])[0] if deepspeed_support else "Not supported" + self.error_categorizer.record_error(error_msg, config, "DeepSpeed") + self.combination_stats[combo_type]["deepspeed_failures"] += 1 + else: + self.combination_stats[combo_type]["deepspeed_success"] += 1 + + # Try to simulate DeepSpeed rank generation for comparison + deepspeed_simulated = None + if megatron_success and deepspeed_success: + deepspeed_simulated = self._simulate_deepspeed_rank_generation(config) + + # Compare rank generation if both succeeded and we have simulated results + rank_comparison = None + if megatron_success and deepspeed_success and deepspeed_simulated and "tp_groups" in deepspeed_simulated: + # Compare TP groups + if config["tp"] > 1 and "tp_groups" in megatron_result: + rank_comparison = self._compare_rank_groups(megatron_result["tp_groups"], + deepspeed_simulated.get("tp_groups", [])) + # Compare DP groups + if config["dp"] > 1 and "dp_groups" in megatron_result and not rank_comparison: + rank_comparison = self._compare_rank_groups(megatron_result["dp_groups"], + deepspeed_simulated.get("dp_groups", [])) + + # Record results + config_key = f"tp={config['tp']},dp={config['dp']},pp={config['pp']},cp={config['cp']},ep={config['ep']},order={config['order']}" + + if megatron_success: + self.results["megatron_success"].append(config_key) + else: + self.results["megatron_failures"].append({ + "config": config_key, + "error": megatron_error, + "combination": combo_type, + }) + + if deepspeed_success: + self.results["deepspeed_success"].append(config_key) + else: + self.results["deepspeed_failures"].append({ + "config": config_key, + "error": deepspeed_error, + "support_info": deepspeed_support, + "combination": combo_type, + }) + + # Determine compatibility and update stats + if megatron_success and deepspeed_success: + compat_entry = { + "config": config_key, + "megatron_result": megatron_result, + "deepspeed_support": deepspeed_support, + "combination": combo_type, + } + if rank_comparison: + compat_entry["rank_comparison"] = rank_comparison + if rank_comparison.get("same_structure"): + compat_entry["rank_match"] = True + else: + compat_entry["rank_match"] = False + compat_entry["rank_differences"] = rank_comparison.get("differences", []) + self.results["compatible"].append(compat_entry) + self.combination_stats[combo_type]["compatible"] += 1 + elif megatron_success and not deepspeed_success: + self.results["megatron_only"].append({ + "config": + config_key, + "megatron_result": + megatron_result, + "deepspeed_issue": + deepspeed_support.get("notes", []) if deepspeed_support else [], + "combination": + combo_type, + }) + self.combination_stats[combo_type]["megatron_only"] += 1 + elif not megatron_success and deepspeed_success: + self.results["deepspeed_only"].append({ + "config": config_key, + "megatron_error": megatron_error, + "deepspeed_support": deepspeed_support, + "combination": combo_type, + }) + self.combination_stats[combo_type]["deepspeed_only"] += 1 + else: + self.results["incompatible"].append({ + "config": + config_key, + "megatron_error": + megatron_error, + "deepspeed_issue": + deepspeed_support.get("notes", []) if deepspeed_support else [], + "combination": + combo_type, + }) + self.combination_stats[combo_type]["incompatible"] += 1 + + +class TestAutomatedParallelCombinations: + """Automated tests for parallel strategy combinations.""" + + def test_systematic_configurations(self): + """Test systematic configurations covering common cases.""" + generator = ParallelConfigGenerator(seed=42) + tester = ParallelCompatibilityTester() + + configs = generator.generate_systematic_configs(max_world_size=16) + + print("\n" + "=" * 80) + print("SYSTEMATIC CONFIGURATION TESTING") + print("=" * 80) + print(f"\nTesting {len(configs)} systematic configurations...") + + for i, config in enumerate(configs, 1): + print(f"\n[{i}/{len(configs)}] Testing: {config}") + tester.test_config_compatibility(config) + + self._print_results(tester, "Systematic") + self._generate_comprehensive_report(tester, "Systematic") + + def test_random_configurations(self): + """Test random configurations.""" + generator = ParallelConfigGenerator(seed=123) + tester = ParallelCompatibilityTester() + + configs = generator.generate_random_configs(count=1000, max_size=1024) + + print("\n" + "=" * 80) + print("RANDOM CONFIGURATION TESTING") + print("=" * 80) + print(f"\nTesting {len(configs)} random configurations...") + print(f"Max world size: 1024, Max parallel size per dimension: 32") + + for i, config in enumerate(configs, 1): + if i % 100 == 0: + print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") + tester.test_config_compatibility(config) + + self._print_results(tester, "Random") + self._generate_comprehensive_report(tester, "Random") + + def test_random_configurations_by_dimension(self): + """Test random configurations generated separately for each dimension.""" + generator = ParallelConfigGenerator(seed=789) + tester = ParallelCompatibilityTester() + + # Generate configs for each dimension separately + # This ensures balanced coverage across all dimensions + # Increased by 20x for comprehensive testing + counts_by_dimension = { + 1: 4000, # 1D: 4000 configs (200 * 20) + 2: 6000, # 2D: 6000 configs (300 * 20) - more because there are more 2D combinations + 3: 5000, # 3D: 5000 configs (250 * 20) + 4: 3000, # 4D: 3000 configs (150 * 20) + 5: 2000, # 5D: 2000 configs (100 * 20) + } + + print("\n" + "=" * 80) + print("RANDOM CONFIGURATION TESTING BY DIMENSION") + print("=" * 80) + print(f"\nGenerating configurations by dimension:") + for dim, count in counts_by_dimension.items(): + print(f" {dim}D: {count} configurations") + + configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32) + + print(f"\nTotal unique configurations generated: {len(configs)}") + print(f"Max world size: 1024, Parallel size range: 2-32") + + # Count configs by dimension + dim_counts = defaultdict(int) + for config in configs: + dim_count = len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) + dim_counts[dim_count] += 1 + + print("\nActual distribution:") + for dim in sorted(dim_counts.keys()): + print(f" {dim}D: {dim_counts[dim]} configurations") + + print(f"\nTesting {len(configs)} configurations...") + + for i, config in enumerate(configs, 1): + # Update progress more frequently for large test sets + if i % 1000 == 0 or i == len(configs): + print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") + tester.test_config_compatibility(config) + + self._print_results(tester, "Random by Dimension") + self._generate_comprehensive_report(tester, "Random by Dimension") + + def test_edge_cases(self): + """Test edge cases and boundary conditions.""" + generator = ParallelConfigGenerator(seed=456) + tester = ParallelCompatibilityTester() + + # Edge cases - including larger sizes + edge_configs = [ + # Maximum dimensions - larger sizes + { + "tp": 8, + "dp": 8, + "pp": 8, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 512 + }, + { + "tp": 16, + "dp": 16, + "pp": 4, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 1024 + }, + # EP and CP conflict + { + "tp": 2, + "dp": 2, + "pp": 1, + "cp": 2, + "ep": 2, + "order": "tp-ep-dp", + "world_size": 8 + }, + { + "tp": 4, + "dp": 4, + "pp": 1, + "cp": 4, + "ep": 4, + "order": "tp-ep-dp", + "world_size": 64 + }, + # Single dimension - larger sizes + { + "tp": 1, + "dp": 1, + "pp": 64, + "cp": 1, + "ep": 1, + "order": "pp", + "world_size": 64 + }, + { + "tp": 128, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp", + "world_size": 128 + }, + { + "tp": 1, + "dp": 256, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp", + "world_size": 256 + }, + # All dimensions - larger sizes + { + "tp": 2, + "dp": 2, + "pp": 2, + "cp": 2, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 16 + }, + { + "tp": 4, + "dp": 4, + "pp": 4, + "cp": 4, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 256 + }, + # Different orders + { + "tp": 2, + "dp": 4, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp-tp", + "world_size": 8 + }, + { + "tp": 2, + "dp": 4, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp-dp", + "world_size": 8 + }, + { + "tp": 8, + "dp": 16, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp-tp", + "world_size": 128 + }, + { + "tp": 8, + "dp": 16, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp-dp", + "world_size": 128 + }, + # Large multi-dimensional + { + "tp": 8, + "dp": 8, + "pp": 4, + "cp": 1, + "ep": 1, + "order": "tp-pp-dp", + "world_size": 256 + }, + { + "tp": 4, + "dp": 8, + "pp": 8, + "cp": 1, + "ep": 1, + "order": "tp-pp-dp", + "world_size": 256 + }, + ] + + print("\n" + "=" * 80) + print("EDGE CASE TESTING") + print("=" * 80) + print(f"\nTesting {len(edge_configs)} edge case configurations...") + + for i, config in enumerate(edge_configs, 1): + print(f"\n[{i}/{len(edge_configs)}] Testing: {config}") + tester.test_config_compatibility(config) + + self._print_results(tester, "Edge Cases") + self._generate_comprehensive_report(tester, "Edge Cases") + + def _print_results(self, tester: ParallelCompatibilityTester, test_type: str): + """Print test results.""" + results = tester.results + + print("\n" + "=" * 80) + print(f"{test_type} TEST RESULTS") + print("=" * 80) + + print(f"\n✓ Megatron Success: {len(results['megatron_success'])}") + print(f"✗ Megatron Failures: {len(results['megatron_failures'])}") + if results['megatron_failures']: + print("\nMegatron Failures:") + for failure in results['megatron_failures'][:10]: # Show first 10 + print(f" - {failure['config']}: {failure['error']}") + if len(results['megatron_failures']) > 10: + print(f" ... and {len(results['megatron_failures']) - 10} more") + + print(f"\n✓ DeepSpeed Success: {len(results['deepspeed_success'])}") + print(f"✗ DeepSpeed Failures: {len(results['deepspeed_failures'])}") + if results['deepspeed_failures']: + print("\nDeepSpeed Failures:") + for failure in results['deepspeed_failures'][:10]: # Show first 10 + print(f" - {failure['config']}") + if failure.get('support_info'): + notes = failure['support_info'].get('notes', []) + if notes: + print(f" Notes: {', '.join(notes)}") + if len(results['deepspeed_failures']) > 10: + print(f" ... and {len(results['deepspeed_failures']) - 10} more") + + print(f"\n✓ Compatible (Both Support): {len(results['compatible'])}") + if results['compatible']: + print(" Examples:") + rank_matches = 0 + rank_mismatches = 0 + for item in results['compatible'][:10]: + if isinstance(item, dict): + config = item.get('config', 'Unknown') + rank_comp = item.get('rank_comparison') + if rank_comp: + if rank_comp.get('same_structure'): + print(f" - {config} ✓ Rank groups match") + rank_matches += 1 + else: + print(f" - {config} ⚠ Rank groups differ") + rank_mismatches += 1 + if rank_comp.get('differences'): + for diff in rank_comp['differences'][:2]: + print(f" {diff}") + else: + print(f" - {config}") + else: + print(f" - {item}") + if len(results['compatible']) > 10: + print(f" ... and {len(results['compatible']) - 10} more") + + if rank_matches > 0 or rank_mismatches > 0: + print(f"\n Rank Comparison Summary:") + print(f" Matches: {rank_matches}") + print(f" Mismatches: {rank_mismatches}") + print(f" (Note: Comparison only available for TP+DP combinations)") + + print(f"\n⚠ Megatron Only: {len(results['megatron_only'])}") + if results['megatron_only']: + print(" Examples:") + for item in results['megatron_only'][:5]: + print(f" - {item['config']}") + if item.get('deepspeed_issue'): + print(f" DeepSpeed issue: {', '.join(item['deepspeed_issue'])}") + if len(results['megatron_only']) > 5: + print(f" ... and {len(results['megatron_only']) - 5} more") + + print(f"\n→ DeepSpeed Only: {len(results['deepspeed_only'])}") + if results['deepspeed_only']: + print(" Examples:") + for item in results['deepspeed_only'][:5]: + print(f" - {item['config']}") + print(f" Megatron error: {item['megatron_error']}") + if len(results['deepspeed_only']) > 5: + print(f" ... and {len(results['deepspeed_only']) - 5} more") + + print(f"\n✗ Incompatible (Neither Support): {len(results['incompatible'])}") + if results['incompatible']: + print(" Examples:") + for item in results['incompatible'][:5]: + print(f" - {item['config']}") + print(f" Megatron: {item['megatron_error']}") + if len(results['incompatible']) > 5: + print(f" ... and {len(results['incompatible']) - 5} more") + + print("\n" + "=" * 80) + + def _generate_comprehensive_report(self, tester: ParallelCompatibilityTester, test_type: str): + """Generate comprehensive test report with error categorization and combination statistics.""" + results = tester.results + error_summary = tester.error_categorizer.get_error_summary() + combo_stats = tester.combination_stats + + print("\n" + "=" * 80) + print(f"{test_type} COMPREHENSIVE TEST REPORT") + print("=" * 80) + + # Overall statistics + print("\n" + "-" * 80) + print("OVERALL STATISTICS") + print("-" * 80) + total_tested = (len(results['megatron_success']) + len(results['megatron_failures']) + + len(results['deepspeed_success']) + len(results['deepspeed_failures'])) + print(f"Total Configurations Tested: {total_tested}") + print( + f" Megatron Success: {len(results['megatron_success'])} ({len(results['megatron_success'])/total_tested*100:.1f}%)" + ) + print( + f" Megatron Failures: {len(results['megatron_failures'])} ({len(results['megatron_failures'])/total_tested*100:.1f}%)" + ) + print( + f" DeepSpeed Success: {len(results['deepspeed_success'])} ({len(results['deepspeed_success'])/total_tested*100:.1f}%)" + ) + print( + f" DeepSpeed Failures: {len(results['deepspeed_failures'])} ({len(results['deepspeed_failures'])/total_tested*100:.1f}%)" + ) + print(f" Compatible: {len(results['compatible'])} ({len(results['compatible'])/total_tested*100:.1f}%)") + print( + f" Megatron Only: {len(results['megatron_only'])} ({len(results['megatron_only'])/total_tested*100:.1f}%)" + ) + print( + f" DeepSpeed Only: {len(results['deepspeed_only'])} ({len(results['deepspeed_only'])/total_tested*100:.1f}%)" + ) + print(f" Incompatible: {len(results['incompatible'])} ({len(results['incompatible'])/total_tested*100:.1f}%)") + + # Error categorization + print("\n" + "-" * 80) + print("ERROR CATEGORIZATION (Aggregated by Type)") + print("-" * 80) + for category, summary in sorted(error_summary.items(), key=lambda x: x[1]['count'], reverse=True): + print(f"\n{category}: {summary['count']} occurrences") + print(f" Affects {summary['unique_combinations']} unique combination types") + print(f" Examples:") + for example in summary['examples'][:3]: + combo = example.get('combination', 'Unknown') + lib = example.get('library', 'Unknown') + print(f" - {combo} ({lib}): {example['error'][:80]}") + if len(summary['examples']) > 3: + print(f" ... and {len(summary['examples']) - 3} more examples") + + # Combination type statistics + print("\n" + "-" * 80) + print("COMBINATION TYPE STATISTICS") + print("-" * 80) + print( + f"{'Combination':<20} {'Total':<8} {'M-Succ':<8} {'M-Fail':<8} {'DS-Succ':<8} {'DS-Fail':<8} {'Compat':<8} {'M-Only':<8} {'DS-Only':<8} {'Incomp':<8}" + ) + print("-" * 100) + + # Sort by total count + sorted_combos = sorted(combo_stats.items(), key=lambda x: x[1]['total'], reverse=True) + for combo_type, stats in sorted_combos: + if stats['total'] > 0: + print(f"{combo_type:<20} {stats['total']:<8} {stats['megatron_success']:<8} " + f"{stats['megatron_failures']:<8} {stats['deepspeed_success']:<8} " + f"{stats['deepspeed_failures']:<8} {stats['compatible']:<8} " + f"{stats['megatron_only']:<8} {stats['deepspeed_only']:<8} " + f"{stats['incompatible']:<8}") + + # Detailed combination analysis + print("\n" + "-" * 80) + print("DETAILED COMBINATION ANALYSIS") + print("-" * 80) + + # Group by number of dimensions + by_dimension_count = defaultdict(list) + for combo_type, stats in combo_stats.items(): + dim_count = len([c for c in combo_type.split('+') if c != 'NONE']) + by_dimension_count[dim_count].append((combo_type, stats)) + + for dim_count in sorted(by_dimension_count.keys()): + print(f"\n{dim_count}-Dimensional Combinations:") + combos = sorted(by_dimension_count[dim_count], key=lambda x: x[1]['total'], reverse=True) + for combo_type, stats in combos[:10]: # Show top 10 + if stats['total'] > 0: + compat_rate = (stats['compatible'] / stats['total'] * 100) if stats['total'] > 0 else 0 + print(f" {combo_type}:") + print(f" Total: {stats['total']}, Compatible: {stats['compatible']} ({compat_rate:.1f}%)") + print(f" Megatron: {stats['megatron_success']} success, {stats['megatron_failures']} failures") + print( + f" DeepSpeed: {stats['deepspeed_success']} success, {stats['deepspeed_failures']} failures") + if len(combos) > 10: + print(f" ... and {len(combos) - 10} more {dim_count}-dimensional combinations") + + print("\n" + "=" * 80) + + def test_cp_vs_sp_compatibility_by_dimension(self): + """Test CP vs SP compatibility using the same config generation as test_random_configurations_by_dimension. + + This test: + 1. Uses parallel_state_refactored with CP + 2. Uses DeepSpeed with SP + 3. Compares CP rank groups with SP rank groups to see if they match + """ + generator = ParallelConfigGenerator(seed=789) + + # Use the same configuration generation as test_random_configurations_by_dimension + counts_by_dimension = { + 1: 4000, # 1D: 4000 configs + 2: 6000, # 2D: 6000 configs + 3: 5000, # 3D: 5000 configs + 4: 3000, # 4D: 3000 configs + 5: 2000, # 5D: 2000 configs + } + + print("\n" + "=" * 80) + print("CP vs SP COMPATIBILITY TESTING BY DIMENSION") + print("=" * 80) + print(f"\nGenerating configurations by dimension:") + for dim, count in counts_by_dimension.items(): + print(f" {dim}D: {count} configurations") + + configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32) + + # Filter to only include configs with CP > 1 and EP == 1 (EP and CP cannot both be > 1) + configs_with_cp = [c for c in configs if c["cp"] > 1 and c["ep"] == 1] + + print(f"\nTotal unique configurations generated: {len(configs)}") + print(f"Configurations with CP > 1 and EP == 1: {len(configs_with_cp)}") + print(f"Max world size: 1024, Parallel size range: 2-32") + + # Test CP vs SP compatibility + results = { + "total_tested": 0, + "cp_groups_generated": 0, + "sp_groups_generated": 0, + "rank_groups_match": 0, + "rank_groups_differ": 0, + "errors": 0, + "match_details": [], + "differ_details": [], + } + + combination_stats = defaultdict(lambda: { + "total": 0, + "match": 0, + "differ": 0, + "errors": 0, + }) + + print(f"\nTesting {len(configs_with_cp)} configurations for CP vs SP compatibility...") + + for i, config in enumerate(configs_with_cp, 1): + if i % 1000 == 0 or i == len(configs_with_cp): + print(f"Progress: {i}/{len(configs_with_cp)} ({(i/len(configs_with_cp)*100):.1f}%)") + + results["total_tested"] += 1 + + # Get combination type + combo_type = self._get_combination_type_for_cp_sp(config) + combination_stats[combo_type]["total"] += 1 + + try: + # Get CP rank groups from Megatron + if not PARALLEL_STATE_AVAILABLE: + results["errors"] += 1 + combination_stats[combo_type]["errors"] += 1 + continue + + rg = RankGenerator(tp=config["tp"], + ep=config["ep"], + dp=config["dp"], + pp=config["pp"], + cp=config["cp"], + order=config["order"]) + + cp_groups = rg.get_ranks("cp") + if cp_groups: + results["cp_groups_generated"] += 1 + + # Simulate SP rank groups from DeepSpeed + # DeepSpeed SP creates consecutive rank groups + sp_groups = self._simulate_deepspeed_sp_groups(config["world_size"], config["cp"]) + if sp_groups: + results["sp_groups_generated"] += 1 + + # Compare CP and SP groups + if self._compare_cp_sp_groups(cp_groups, sp_groups): + results["rank_groups_match"] += 1 + combination_stats[combo_type]["match"] += 1 + results["match_details"].append(config) + else: + results["rank_groups_differ"] += 1 + combination_stats[combo_type]["differ"] += 1 + results["differ_details"].append({ + "config": config, + "cp_groups": cp_groups, + "sp_groups": sp_groups, + }) + + except Exception as e: + results["errors"] += 1 + combination_stats[combo_type]["errors"] += 1 + + # Generate report + self._generate_cp_vs_sp_report(results, combination_stats) + + def _simulate_deepspeed_sp_groups(self, world_size: int, sp_size: int) -> List[List[int]]: + """Simulate DeepSpeed's SP rank group generation. + + DeepSpeed SP creates groups as consecutive ranks: + - Group 0: [0, 1, ..., sp_size-1] + - Group 1: [sp_size, sp_size+1, ..., 2*sp_size-1] + - etc. + """ + if sp_size <= 1 or world_size % sp_size != 0: + return [] + + num_groups = world_size // sp_size + groups = [] + for i in range(num_groups): + group = list(range(i * sp_size, (i + 1) * sp_size)) + groups.append(group) + + return groups + + def _compare_cp_sp_groups(self, cp_groups: List[List[int]], sp_groups: List[List[int]]) -> bool: + """Compare CP and SP rank groups to see if they match.""" + if not cp_groups and not sp_groups: + return True + + if not cp_groups or not sp_groups: + return False + + if len(cp_groups) != len(sp_groups): + return False + + # Check if all CP groups have a matching SP group (order may differ) + cp_sets = [set(g) for g in cp_groups] + sp_sets = [set(g) for g in sp_groups] + + # Check if all CP groups match SP groups + for cp_set in cp_sets: + found = False + for sp_set in sp_sets: + if cp_set == sp_set: + found = True + break + if not found: + return False + + # Check if all SP groups match CP groups + for sp_set in sp_sets: + found = False + for cp_set in cp_sets: + if sp_set == cp_set: + found = True + break + if not found: + return False + + return True + + def _get_combination_type_for_cp_sp(self, config: Dict) -> str: + """Get combination type string for CP vs SP testing.""" + dims = [] + if config["tp"] > 1: + dims.append("TP") + if config["dp"] > 1: + dims.append("DP") + if config["pp"] > 1: + dims.append("PP") + if config["cp"] > 1: + dims.append("CP") + # Note: EP is always 1 in this test + + if not dims: + return "NONE" + + return "+".join(sorted(dims)) + + def _generate_cp_vs_sp_report(self, results: Dict, combination_stats: Dict): + """Generate comprehensive CP vs SP compatibility report.""" + print("\n" + "=" * 80) + print("CP vs SP COMPATIBILITY TEST REPORT") + print("=" * 80) + + # Overall statistics + print("\n" + "-" * 80) + print("OVERALL STATISTICS") + print("-" * 80) + print(f"Total Configurations Tested: {results['total_tested']}") + print(f" CP Groups Generated: {results['cp_groups_generated']}") + print(f" SP Groups Generated: {results['sp_groups_generated']}") + print(f" Rank Groups Match: {results['rank_groups_match']}") + print(f" Rank Groups Differ: {results['rank_groups_differ']}") + print(f" Errors: {results['errors']}") + + if results['total_tested'] > 0: + match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 + print(f"\n Match Rate: {match_rate:.2f}%") + print(f" CP can replace SP in {match_rate:.2f}% of tested configurations") + + # Combination type statistics + print("\n" + "-" * 80) + print("COMBINATION TYPE STATISTICS") + print("-" * 80) + print(f"{'Combination':<20} {'Total':<8} {'Match':<8} {'Differ':<8} {'Errors':<8} {'Match Rate':<12}") + print("-" * 80) + + sorted_combos = sorted(combination_stats.items(), key=lambda x: x[1]['total'], reverse=True) + for combo_type, stats in sorted_combos: + if stats['total'] > 0: + match_rate = (stats['match'] / stats['total'] * 100) if stats['total'] > 0 else 0 + print(f"{combo_type:<20} {stats['total']:<8} {stats['match']:<8} " + f"{stats['differ']:<8} {stats['errors']:<8} {match_rate:.1f}%") + + # Examples of matching configurations + print("\n" + "-" * 80) + print("EXAMPLES OF MATCHING CONFIGURATIONS (CP can replace SP)") + print("-" * 80) + for i, config in enumerate(results['match_details'][:10], 1): + print(f"{i}. {config}") + print(f" CP size: {config['cp']}, Order: {config['order']}") + + if len(results['match_details']) > 10: + print(f"\n... and {len(results['match_details']) - 10} more matching configurations") + + # Examples of differing configurations + if results['differ_details']: + print("\n" + "-" * 80) + print("EXAMPLES OF DIFFERING CONFIGURATIONS (CP cannot replace SP)") + print("-" * 80) + for i, item in enumerate(results['differ_details'][:10], 1): + config = item['config'] + cp_groups = item['cp_groups'] + sp_groups = item['sp_groups'] + print(f"{i}. {config}") + print(f" CP size: {config['cp']}, Order: {config['order']}") + print(f" CP groups count: {len(cp_groups)}, SP groups count: {len(sp_groups)}") + if cp_groups and sp_groups: + print(f" CP first group: {cp_groups[0]}") + print(f" SP first group: {sp_groups[0]}") + + if len(results['differ_details']) > 10: + print(f"\n... and {len(results['differ_details']) - 10} more differing configurations") + + # Conclusion + print("\n" + "=" * 80) + print("CONCLUSION") + print("=" * 80) + if results['rank_groups_match'] > 0: + match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 + print(f"\n✓ CP can replace SP in {match_rate:.2f}% of tested configurations") + print( + f" - {results['rank_groups_match']} out of {results['total_tested']} configurations have matching rank groups" + ) + else: + print("\n✗ CP cannot replace SP in any of the tested configurations") + + if results['rank_groups_differ'] > 0: + print(f"\n⚠ {results['rank_groups_differ']} configurations have different rank groups") + print(" - These configurations may require special handling when migrating from CP to SP") + + print("\n" + "=" * 80) + + def test_comprehensive_automated_testing(self): + """Comprehensive automated testing with all test types.""" + print("\n" + "=" * 80) + print("COMPREHENSIVE AUTOMATED PARALLEL COMBINATION TESTING") + print("=" * 80) + + # Create a combined tester for overall report + combined_tester = ParallelCompatibilityTester() + + # Run all test types and accumulate results + print("\n[1/3] Running systematic configurations...") + generator1 = ParallelConfigGenerator(seed=42) + configs1 = generator1.generate_systematic_configs(max_world_size=512) + print(f"Testing {len(configs1)} systematic configurations...") + for i, config in enumerate(configs1, 1): + if i % 50 == 0 or i == len(configs1): + print(f" Progress: {i}/{len(configs1)}") + combined_tester.test_config_compatibility(config) + + print("\n[2/4] Running random configurations by dimension...") + generator2 = ParallelConfigGenerator(seed=789) + # Increased by 20x for comprehensive testing + counts_by_dimension = { + 1: 4000, # 1D: 4000 configs (200 * 20) + 2: 6000, # 2D: 6000 configs (300 * 20) + 3: 5000, # 3D: 5000 configs (250 * 20) + 4: 3000, # 4D: 3000 configs (150 * 20) + 5: 2000, # 5D: 2000 configs (100 * 20) + } + configs2 = generator2.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32) + print(f"Testing {len(configs2)} random configurations (balanced by dimension)...") + print(f"Max world size: 1024, Parallel size range: 2-32") + for i, config in enumerate(configs2, 1): + # Update progress more frequently for large test sets + if i % 1000 == 0 or i == len(configs2): + print(f" Progress: {i}/{len(configs2)} ({(i/len(configs2)*100):.1f}%)") + combined_tester.test_config_compatibility(config) + + print("\n[3/4] Running additional random configurations...") + generator3 = ParallelConfigGenerator(seed=123) + # Increased by 20x: 500 * 20 = 10000 + configs3 = generator3.generate_random_configs(count=10000, max_size=1024) + print(f"Testing {len(configs3)} additional random configurations...") + for i, config in enumerate(configs3, 1): + # Update progress more frequently for large test sets + if i % 1000 == 0 or i == len(configs3): + print(f" Progress: {i}/{len(configs3)} ({(i/len(configs3)*100):.1f}%)") + combined_tester.test_config_compatibility(config) + + print("\n[4/4] Running edge cases...") + edge_configs = [ + { + "tp": 8, + "dp": 8, + "pp": 8, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 512 + }, + { + "tp": 16, + "dp": 16, + "pp": 4, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 1024 + }, + { + "tp": 2, + "dp": 2, + "pp": 1, + "cp": 2, + "ep": 2, + "order": "tp-ep-dp", + "world_size": 8 + }, + { + "tp": 4, + "dp": 4, + "pp": 1, + "cp": 4, + "ep": 4, + "order": "tp-ep-dp", + "world_size": 64 + }, + { + "tp": 1, + "dp": 1, + "pp": 64, + "cp": 1, + "ep": 1, + "order": "pp", + "world_size": 64 + }, + { + "tp": 128, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp", + "world_size": 128 + }, + { + "tp": 1, + "dp": 256, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp", + "world_size": 256 + }, + { + "tp": 2, + "dp": 2, + "pp": 2, + "cp": 2, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 16 + }, + { + "tp": 4, + "dp": 4, + "pp": 4, + "cp": 4, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 256 + }, + ] + print(f"Testing {len(edge_configs)} edge case configurations...") + for config in edge_configs: + combined_tester.test_config_compatibility(config) + + # Generate comprehensive report + print("\n" + "=" * 80) + print("COMPREHENSIVE FINAL REPORT") + print("=" * 80) + self._generate_comprehensive_report(combined_tester, "COMPREHENSIVE") + + print("\n" + "=" * 80) + print("ALL TESTS COMPLETED") + print("=" * 80) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From e319923f2807e70cef16da35dce4f827407a17b4 Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Wed, 7 Jan 2026 11:43:37 +0800 Subject: [PATCH 02/23] parallel_state: Cleanup dependency on ProcessGroupNCCL.Options Signed-off-by: Junjie Mao --- deepspeed/utils/parallel_state.py | 98 ++++++++++++------------------- 1 file changed, 39 insertions(+), 59 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index df9906d2fcee..495241daa523 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -280,24 +280,19 @@ def __init__(self): self.decoder_rank_generator = None self.expert_decoder_rank_generator = None - def _get_nccl_options(self, pg_name: str, nccl_comm_cfgs: dict): - """Set the NCCL process group options.""" - if pg_name in nccl_comm_cfgs: - # FIXME: deepspeed.comm does not provide a way to set NCCL options yet. - nccl_options = torch.distributed.ProcessGroupNCCL.Options( - is_high_priority_stream=nccl_comm_cfgs[pg_name].get("is_high_priority_stream", False)) - if "cga_cluster_size" in nccl_comm_cfgs[pg_name]: - nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name]["cga_cluster_size"] - if "max_ctas" in nccl_comm_cfgs[pg_name]: - nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name]["max_ctas"] - if "min_ctas" in nccl_comm_cfgs[pg_name]: - nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name]["min_ctas"] - if "net_name" in nccl_comm_cfgs[pg_name]: - nccl_options.config.net_name = nccl_comm_cfgs[pg_name]["net_name"] - if nccl_options.config.net_name.lower() not in ["ib", "socket"]: - raise RuntimeError(f"net_name ({nccl_options.config.net_name}) is not supported." - f"Accepted values: 'IB' or 'socket'.") - return nccl_options + def _get_pg_options(self, pg_name: str, pg_comm_cfgs: dict): + """Get the options for a specific process group.""" + # TODO: construct process group options from json config + # + # As of PyTorch 2.9, the only backend that supports pg options is nccl, + # and a nccl-specific class, namely ProcessGroupNCCL.Options, is + # required to construct the options. + # + # To enable configuring such options in DeepSpeed, we need to define the + # interface for users to specify them and also figure out whether we + # want to export ProcessGroupNCCL.Options in deepspeed.comm or allow + # using torch distributed for this specific case in check-torchdist.py. + # Those are left as future work. return None def _create_group( @@ -393,13 +388,11 @@ def initialize_model_parallel( expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, - nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: int = 30, order: str = "tp-cp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, create_gloo_process_groups: bool = True, - high_priority_stream_groups: Optional[List[str]] = None, ) -> None: """Initialize model data parallel groups. @@ -439,23 +432,10 @@ def default_position_embedding_ranks(pp_ranks): self.virtual_pipeline_model_parallel_rank = 0 self.virtual_pipeline_model_parallel_world_size = virtual_pipeline_model_parallel_size - # Load NCCL configs - nccl_comm_cfgs = {} - if nccl_communicator_config_path is not None: - try: - import yaml - except ImportError: - raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " - "requires the yaml package.") - with open(nccl_communicator_config_path, "r") as stream: - nccl_comm_cfgs = yaml.safe_load(stream) - - # Set high priority stream groups - high_priority_stream_groups = high_priority_stream_groups or [] - for pg_name in high_priority_stream_groups: - if pg_name not in nccl_comm_cfgs: - nccl_comm_cfgs[pg_name] = {} - nccl_comm_cfgs[pg_name]["is_high_priority_stream"] = True + # TODO: Collect process group options from configs + # + # Check _get_pg_options for details. + pg_comm_cfgs = {} # Create rank generators self.decoder_rank_generator = RankGenerator( @@ -502,7 +482,7 @@ def default_position_embedding_ranks(pp_ranks): group_with_cp = self._create_group( ranks_with_cp, timeout=timeout, - pg_options=self._get_nccl_options("dp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("dp_cp", pg_comm_cfgs), group_desc="DATA_PARALLEL_GROUP_WITH_CP", ) if create_gloo_process_groups: @@ -526,7 +506,7 @@ def default_position_embedding_ranks(pp_ranks): intra_partial_dp_group_with_cp = self._create_group( intra_partial_dp_ranks_with_cp, timeout=timeout, - pg_options=self._get_nccl_options("intra_dp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("intra_dp_cp", pg_comm_cfgs), group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP", ) if create_gloo_process_groups: @@ -550,7 +530,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("dp", nccl_comm_cfgs), + pg_options=self._get_pg_options("dp", pg_comm_cfgs), group_desc="DATA_PARALLEL_GROUP", ) if create_gloo_process_groups: @@ -571,7 +551,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("cp", pg_comm_cfgs), group_desc="CONTEXT_PARALLEL_GROUP", ) if rank in ranks: @@ -584,7 +564,7 @@ def default_position_embedding_ranks(pp_ranks): ranks, hierarchical_context_parallel_sizes, create_gloo_process_groups=False, - pg_options=self._get_nccl_options("hcp", nccl_comm_cfgs), + pg_options=self._get_pg_options("hcp", pg_comm_cfgs), timeout=timeout, group_desc="CONTEXT_PARALLEL_GROUP", ) @@ -597,7 +577,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("mp", nccl_comm_cfgs), + pg_options=self._get_pg_options("mp", pg_comm_cfgs), group_desc="MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -610,7 +590,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp", pg_comm_cfgs), group_desc="TENSOR_MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -627,8 +607,8 @@ def default_position_embedding_ranks(pp_ranks): ranks, timeout=timeout, backend=pipeline_model_parallel_comm_backend, - pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_nccl_options( - "pp", nccl_comm_cfgs)), + pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_pg_options( + "pp", pg_comm_cfgs)), group_desc="PIPELINE_MODEL_PARALLEL_GROUP", ) assert ( @@ -653,7 +633,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( embedding_ranks, timeout=timeout, - pg_options=self._get_nccl_options("embd", nccl_comm_cfgs), + pg_options=self._get_pg_options("embd", pg_comm_cfgs), group_desc="EMBEDDING_GROUP", ) if rank in embedding_ranks: @@ -664,7 +644,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( position_embedding_ranks, timeout=timeout, - pg_options=self._get_nccl_options("pos_embd", nccl_comm_cfgs), + pg_options=self._get_pg_options("pos_embd", pg_comm_cfgs), group_desc="POSITION_EMBEDDING_GROUP", ) if rank in position_embedding_ranks: @@ -677,7 +657,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_dp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_dp_cp", pg_comm_cfgs), group_desc="TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP", ) if rank in ranks: @@ -686,7 +666,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_dp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_dp", pg_comm_cfgs), group_desc="TENSOR_AND_DATA_PARALLEL_GROUP", ) if rank in ranks: @@ -697,7 +677,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_cp", pg_comm_cfgs), group_desc="TENSOR_AND_CONTEXT_PARALLEL_GROUP", ) if rank in ranks: @@ -708,7 +688,7 @@ def default_position_embedding_ranks(pp_ranks): for ranks in self.expert_decoder_rank_generator.get_ranks('ep'): group = self._create_group( ranks, - pg_options=self._get_nccl_options("ep", nccl_comm_cfgs), + pg_options=self._get_pg_options("ep", pg_comm_cfgs), group_desc="EXPERT_MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -719,7 +699,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("ep_tp", nccl_comm_cfgs), + pg_options=self._get_pg_options("ep_tp", pg_comm_cfgs), group_desc="EXPERT_TENSOR_PARALLEL_GROUP", ) if rank in ranks: @@ -730,7 +710,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_ep_mp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_ep_mp", pg_comm_cfgs), group_desc="EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -741,7 +721,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_ep_pp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_ep_pp", pg_comm_cfgs), group_desc="EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP", ) if rank in ranks: @@ -761,7 +741,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("ep_dp", nccl_comm_cfgs), + pg_options=self._get_pg_options("ep_dp", pg_comm_cfgs), group_desc="EXPERT_DATA_PARALLEL_GROUP", ) if create_gloo_process_groups: @@ -779,8 +759,8 @@ def default_position_embedding_ranks(pp_ranks): [intra_partial_expert_data_parallel_size, num_distributed_optimizer_instances], create_gloo_process_groups=create_gloo_process_groups, pg_options=[ - self._get_nccl_options("intra_ep_dp", nccl_comm_cfgs), - self._get_nccl_options("inter_ep_dp", nccl_comm_cfgs), + self._get_pg_options("intra_ep_dp", pg_comm_cfgs), + self._get_pg_options("inter_ep_dp", pg_comm_cfgs), ], timeout=timeout, group_desc="EXPERT_DATA_PARALLEL_GROUP", @@ -804,7 +784,7 @@ def default_position_embedding_ranks(pp_ranks): intra_dist_opt_instance_group = self._create_group( intra_dist_opt_ranks, timeout=timeout, - pg_options=self._get_nccl_options("intra_dist_opt_instance", nccl_comm_cfgs), + pg_options=self._get_pg_options("intra_dist_opt_instance", pg_comm_cfgs), group_desc="INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP", ) if rank in intra_dist_opt_ranks: From 684f09642f565b679d1b8f8fb641002c433cad35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Tue, 23 Dec 2025 20:41:37 +0800 Subject: [PATCH 03/23] feat: add config-based parallel state initialization with validation Add comprehensive config.json support for parallelism configuration with smart priority handling and context parallel validation. Key features: - Support all parallelism dimensions via config.json - Config priority: config file > function params > defaults - Conflict detection with warning logs - Context parallel validation (CP must be 1) - Backward compatible with existing code Changes: - Add 14 optional parameters to initialize_parallel_state_from_config() - Implement 3-tier priority system with conflict detection - Add CP validation: raise NotImplementedError if CP > 1 - Update default order from "tp-cp-ep-dp-pp" to "tp-ep-dp-pp" - Add detailed docstrings and usage examples This allows users to configure all parallel dimensions in config.json instead of reading documentation and manually calling initialize_model_parallel. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 252 +++++++++++++++++++- 1 file changed, 249 insertions(+), 3 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index bf3a346de194..2d4cbf93915a 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -29,6 +29,7 @@ - Supports multiple parallel state instances (for RL scenarios with different models) - Backward compatible with single global instance - Context manager for switching between different parallel configurations +- Configuration-based initialization from config.json Usage: # Basic usage (single global instance): @@ -56,11 +57,16 @@ with set_current_parallel_state("critic"): critic_dp_group = get_data_parallel_group() # Uses critic's DP group + + # Initialize from config.json: + from deepspeed import DeepSpeedConfig + ds_config = DeepSpeedConfig("config.json") + initialize_parallel_state_from_config(ds_config) """ from contextlib import contextmanager -from typing import Optional -from parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state +from typing import Optional, Union, Dict, Any, List +from .parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state # Registry for multiple parallel state instances _parallel_state_registry = {} @@ -287,7 +293,7 @@ def get_tensor_model_parallel_src_rank(name: Optional[str] = None): DeepSpeed-compatible interface. """ - import torch.distributed as dist + import deepspeed.comm as dist global_rank = dist.get_rank() local_world_size = get_tensor_model_parallel_world_size(name) return (global_rank // local_world_size) * local_world_size @@ -553,3 +559,243 @@ def is_initialized(name: Optional[str] = None): DeepSpeed-compatible interface. """ return get_parallel_state(name).is_initialized() + + +# ============================================================================ +# Configuration-based Initialization +# ============================================================================ + + +def initialize_parallel_state_from_config( + config: Union[Dict[str, Any], Any], + name: Optional[str] = None, + config_key: str = "parallelism", + # Optional parameters to override config values + tensor_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_size: Optional[int] = None, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_comm_backend: Optional[str] = None, + context_parallel_size: Optional[int] = None, + hierarchical_context_parallel_sizes: Optional[List[int]] = None, + expert_model_parallel_size: Optional[int] = None, + num_distributed_optimizer_instances: Optional[int] = None, + expert_tensor_parallel_size: Optional[int] = None, + nccl_communicator_config_path: Optional[str] = None, + distributed_timeout_minutes: Optional[int] = None, + order: Optional[str] = None, + create_gloo_process_groups: Optional[bool] = None, + high_priority_stream_groups: Optional[List[str]] = None, +) -> None: + """Initialize parallel state from DeepSpeed config.json with optional parameter overrides. + + This function reads parallelism configuration from the DeepSpeed config file + and automatically initializes the ParallelState instance. This allows users + to configure all parallelism dimensions in a single place (config.json) + rather than having to read documentation and manually call initialize_model_parallel. + + Configuration priority: config file (if explicitly set) > function parameters > default values + + Note: If a value is explicitly set in config file, it takes precedence over function + parameters. A warning will be logged if there's a conflict. To override config file + values, remove them from the config file first. + + Args: + config: Either a DeepSpeedConfig object or a config dictionary. + If DeepSpeedConfig, will access its _param_dict attribute. + If dict, will use it directly. + name: Optional name of the parallel state instance to initialize. + If None, initializes the default global instance. + config_key: Key in the config dictionary where parallelism config is stored. + Default is "parallelism". + + # Parallelism dimension parameters (override config if provided): + tensor_model_parallel_size: Size of tensor model parallel group. Default: 1 + pipeline_model_parallel_size: Size of pipeline model parallel group. Default: 1 + virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size. Default: None + pipeline_model_parallel_comm_backend: Communication backend for pipeline. Default: None + context_parallel_size: Size of context parallel group. Default: 1 (MUST be 1, CP not supported) + hierarchical_context_parallel_sizes: Hierarchical context parallel sizes. Default: None (NOT supported) + expert_model_parallel_size: Size of expert model parallel group. Default: 1 + num_distributed_optimizer_instances: Number of distributed optimizer instances. Default: 1 + expert_tensor_parallel_size: Size of expert tensor parallel group. Default: None + nccl_communicator_config_path: Path to NCCL communicator config. Default: None + distributed_timeout_minutes: Timeout for distributed operations. Default: 30 + order: Order of parallelism dimensions. Default: "tp-ep-dp-pp" + create_gloo_process_groups: Whether to create Gloo process groups. Default: True + high_priority_stream_groups: High priority stream groups. Default: None + + Example config.json: + { + "parallelism": { + "tensor_model_parallel_size": 2, + "pipeline_model_parallel_size": 1, + "expert_model_parallel_size": 1, + "expert_tensor_parallel_size": 1, + "virtual_pipeline_model_parallel_size": null, + "pipeline_model_parallel_comm_backend": null, ##不要加入config中,保留加载逻辑 + "num_distributed_optimizer_instances": 1, + "nccl_communicator_config_path": null, + "distributed_timeout_minutes": 30, + "order": "tp-ep-dp-pp", + "create_gloo_process_groups": true, + "high_priority_stream_groups": null + }, + + // Note: The following parameters are NOT supported in DeepSpeed: + // - "context_parallel_size": must be 1 (default) + // - "hierarchical_context_parallel_sizes": not supported + "train_batch_size": 8, + ... + } + + Example usage: + # Basic usage from config file: + from deepspeed import DeepSpeedConfig + ds_config = DeepSpeedConfig("config.json") + initialize_parallel_state_from_config(ds_config) + + # Override specific parameters: + initialize_parallel_state_from_config( + ds_config, + tensor_model_parallel_size=4, # Override config value + expert_model_parallel_size=2 + ) + + # From config dictionary: + import json + with open("config.json") as f: + config_dict = json.load(f) + initialize_parallel_state_from_config(config_dict) + + # For named instances (RL scenarios): + initialize_parallel_state_from_config(ds_config, name="actor") + initialize_parallel_state_from_config( + ds_config, + name="critic", + tensor_model_parallel_size=2 # Override for critic + ) + """ + # Extract config dictionary + if hasattr(config, '_param_dict'): + # DeepSpeedConfig object + config_dict = config._param_dict + elif isinstance(config, dict): + # Already a dictionary + config_dict = config + else: + raise ValueError(f"config must be a DeepSpeedConfig object or a dict, got {type(config)}") + + # Check if parallelism config exists in config file + parallelism_config = config_dict.get(config_key, {}) + if parallelism_config and not isinstance(parallelism_config, dict): + raise ValueError(f"'{config_key}' in config must be a dictionary, got {type(parallelism_config)}") + + # Get the parallel state instance + ps = get_parallel_state_instance(name) + + # Check if already initialized + if ps.is_initialized(): + # Already initialized, skip + return + + # Import logging + import logging + logger = logging.getLogger(__name__) + + # Helper function to get value with proper priority handling + # Priority: config file (if explicitly set) > function parameter > default + def get_value(param_name, param_value, config_key, default_value): + """ + Get value with priority handling and conflict detection. + + Priority: + 1. If config file explicitly sets the value -> use config value (warn if param differs) + 2. If config file doesn't have the value -> use function parameter + 3. If both are None -> use default value + """ + config_has_key = config_key in parallelism_config + config_value = parallelism_config.get(config_key) + + # Case 1: Config file explicitly sets the value + if config_has_key: + # If function parameter is also provided and differs, warn and use config + if param_value is not None and param_value != config_value: + logger.warning(f"Parameter '{param_name}' conflict detected: " + f"config file specifies {config_value}, but function parameter is {param_value}. " + f"Using config file value ({config_value}). " + f"To override config, remove '{config_key}' from config file.") + return config_value + + # Case 2: Config file doesn't have the key, use function parameter if provided + if param_value is not None: + return param_value + + # Case 3: Neither config nor parameter provided, use default + return default_value + + # Extract parameters with proper priority: config (if set) > function param > default + init_kwargs = { + "tensor_model_parallel_size": + get_value("tensor_model_parallel_size", tensor_model_parallel_size, "tensor_model_parallel_size", 1), + "pipeline_model_parallel_size": + get_value("pipeline_model_parallel_size", pipeline_model_parallel_size, "pipeline_model_parallel_size", 1), + "virtual_pipeline_model_parallel_size": + get_value("virtual_pipeline_model_parallel_size", virtual_pipeline_model_parallel_size, + "virtual_pipeline_model_parallel_size", None), + "pipeline_model_parallel_comm_backend": + get_value("pipeline_model_parallel_comm_backend", pipeline_model_parallel_comm_backend, + "pipeline_model_parallel_comm_backend", None), + "context_parallel_size": + get_value("context_parallel_size", context_parallel_size, "context_parallel_size", 1), + "hierarchical_context_parallel_sizes": + get_value("hierarchical_context_parallel_sizes", hierarchical_context_parallel_sizes, + "hierarchical_context_parallel_sizes", None), + "expert_model_parallel_size": + get_value("expert_model_parallel_size", expert_model_parallel_size, "expert_model_parallel_size", 1), + "num_distributed_optimizer_instances": + get_value("num_distributed_optimizer_instances", num_distributed_optimizer_instances, + "num_distributed_optimizer_instances", 1), + "expert_tensor_parallel_size": + get_value("expert_tensor_parallel_size", expert_tensor_parallel_size, "expert_tensor_parallel_size", None), + "nccl_communicator_config_path": + get_value("nccl_communicator_config_path", nccl_communicator_config_path, "nccl_communicator_config_path", + None), + "distributed_timeout_minutes": + get_value("distributed_timeout_minutes", distributed_timeout_minutes, "distributed_timeout_minutes", 30), + "order": + get_value("order", order, "order", "tp-ep-dp-pp"), + "create_gloo_process_groups": + get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", True), + "high_priority_stream_groups": + get_value("high_priority_stream_groups", high_priority_stream_groups, "high_priority_stream_groups", None), + } + + # Validate context_parallel_size + cp_size = init_kwargs["context_parallel_size"] + if cp_size != 1: + raise NotImplementedError( + f"DeepSpeed currently does not support context_parallel_size > 1. " + f"Got context_parallel_size={cp_size}. Please set context_parallel_size=1 in your config.") + + # Validate hierarchical_context_parallel_sizes + hcp_sizes = init_kwargs["hierarchical_context_parallel_sizes"] + if hcp_sizes is not None: + raise NotImplementedError( + f"DeepSpeed currently does not support hierarchical_context_parallel_sizes. " + f"Got hierarchical_context_parallel_sizes={hcp_sizes}. Please remove this configuration.") + + # Remove None values for optional parameters (except those that can be None) + # Keep None for: virtual_pipeline_model_parallel_size, pipeline_model_parallel_comm_backend, + # hierarchical_context_parallel_sizes, expert_tensor_parallel_size, nccl_communicator_config_path, + # high_priority_stream_groups + filtered_kwargs = {} + for key, value in init_kwargs.items(): + if value is not None or key in [ + "virtual_pipeline_model_parallel_size", "pipeline_model_parallel_comm_backend", + "hierarchical_context_parallel_sizes", "expert_tensor_parallel_size", "nccl_communicator_config_path", + "high_priority_stream_groups" + ]: + filtered_kwargs[key] = value + + # Initialize parallel state + ps.initialize_model_parallel(**filtered_kwargs) From 4b362e48c8d26af72b0e381d6e7b4890e06db8a1 Mon Sep 17 00:00:00 2001 From: yunqing Date: Thu, 15 Jan 2026 10:51:43 +0800 Subject: [PATCH 04/23] Add sequence parallel support to refactored parallel state - Extend RankGenerator to include SP dimension and enforce TP/PP/EP compatibility - Initialize sequence parallel and sequence+data parallel process groups in ParallelState.initialize_model_parallel - Add sequence-parallel accessor stubs in parallel_state_deepspeed for future unified SP interfaces Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 112 +++++++++++++++++++- deepspeed/utils/parallel_state_deepspeed.py | 81 +++++++++++++- 2 files changed, 189 insertions(+), 4 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 495241daa523..0e4793e97a86 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -150,16 +150,32 @@ def decompose(index, shape, stride=None): class RankGenerator: """A class for generating rank groups for different modes of parallelism.""" - def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0) -> None: + def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, sp: int, order: str, rank_offset: int = 0) -> None: assert (ep == 1 or cp == 1), "Both EP and CP > 1 is not allowed in one rank generator." + # Check SP compatibility: SP cannot be used with TP, PP, or EP + if sp > 1: + if tp > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Tensor Parallel (TP). " + f"SP size: {sp}, TP size: {tp}. " + "Please set tp=1 when using SP.") + if pp > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Pipeline Parallel (PP). " + f"SP size: {sp}, PP size: {pp}. " + "Please set pp=1 when using SP.") + if ep > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Expert Parallel (EP). " + f"SP size: {sp}, EP size: {ep}. " + "Please set ep=1 when using SP.") + self.tp = tp self.ep = ep self.dp = dp self.pp = pp self.cp = cp + self.sp = sp self.rank_offset = rank_offset - self.world_size = tp * dp * pp * cp * ep + self.world_size = tp * dp * pp * cp * ep * sp self.name_to_size = { "tp": self.tp, @@ -167,6 +183,7 @@ def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank "dp": self.dp, "ep": self.ep, "cp": self.cp, + "sp": self.sp, } self.order = order order = order.lower() @@ -231,6 +248,10 @@ def __init__(self): self.data_parallel_group_with_cp = None self.data_parallel_group_with_cp_gloo = None + # Sequence parallel groups + self.sequence_parallel_group = None + self.sequence_and_data_parallel_group = None + # Expert-related groups self.expert_model_parallel_group = None self.expert_tensor_parallel_group = None @@ -384,12 +405,13 @@ def initialize_model_parallel( virtual_pipeline_model_parallel_size: Optional[int] = None, pipeline_model_parallel_comm_backend: Optional[str] = None, context_parallel_size: int = 1, + sequence_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[List[int]] = None, expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, distributed_timeout_minutes: int = 30, - order: str = "tp-cp-ep-dp-pp", + order: str = "tp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, create_gloo_process_groups: bool = True, @@ -446,6 +468,7 @@ def default_position_embedding_ranks(pp_ranks): cp=context_parallel_size, order=order, rank_offset=0, + sp=1, ) # Build expert rank generator @@ -467,6 +490,7 @@ def default_position_embedding_ranks(pp_ranks): cp=1, order=order, rank_offset=0, + sp=1, ) timeout = timedelta(minutes=distributed_timeout_minutes) @@ -791,6 +815,48 @@ def default_position_embedding_ranks(pp_ranks): self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group intra_dist_opt_ranks = [] + # Build sequence parallel groups + if sequence_parallel_size > 1: + assert self.sequence_parallel_group is None, "sequence parallel group is already initialized" + assert self.sequence_and_data_parallel_group is None, "sequence and data parallel group is already initialized" + + if world_size < sequence_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is less than sequence_parallel_size ({sequence_parallel_size})") + + if world_size % sequence_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})") + + sp_data_parallel_size = world_size // sequence_parallel_size + sequence_and_data_parallel_size = sequence_parallel_size * sp_data_parallel_size + num_sequence_parallel_groups = world_size // sequence_parallel_size + num_sequence_and_data_parallel_groups = world_size // sequence_and_data_parallel_size + + # Build the sequence parallel groups + for i in range(num_sequence_parallel_groups): + ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)) + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("sp", pg_comm_cfgs), + group_desc="SEQUENCE_PARALLEL_GROUP", + ) + if rank in ranks: + self.sequence_parallel_group = group + + # Build the sequence and data parallel groups + for i in range(num_sequence_and_data_parallel_groups): + ranks = list(range(i * sequence_and_data_parallel_size, (i + 1) * sequence_and_data_parallel_size)) + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("sp_dp", pg_comm_cfgs), + group_desc="SEQUENCE_AND_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.sequence_and_data_parallel_group = group + # Initialize global memory buffer self._set_global_memory_buffer() @@ -837,6 +903,18 @@ def get_context_parallel_group(self, check_initialized=True): assert self.context_parallel_group is not None, "context parallel group is not initialized" return self.context_parallel_group + def get_sequence_parallel_group(self, check_initialized=True): + """Get the sequence-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.sequence_parallel_group is not None, "sequence parallel group is not initialized" + return self.sequence_parallel_group + + def get_sequence_and_data_parallel_group(self, check_initialized=True): + """Get the sequence and data parallel group the caller rank belongs to.""" + if check_initialized: + assert self.sequence_and_data_parallel_group is not None, "sequence and data parallel group is not initialized" + return self.sequence_and_data_parallel_group + def get_embedding_group(self, check_initialized=True): """Get the embedding group the caller rank belongs to.""" if check_initialized: @@ -919,6 +997,34 @@ def get_context_parallel_rank(self): else: return 0 + def get_sequence_parallel_world_size(self): + """Return world size for the sequence parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_parallel_group is not None: + return self.get_sequence_parallel_group().size() + return 1 + + def get_sequence_parallel_rank(self): + """Return caller's rank in the sequence-parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_parallel_group is not None: + return self.get_sequence_parallel_group().rank() + return 0 + + def get_sequence_and_data_parallel_world_size(self): + """Return world size for the sequence and data parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_group().size() + return 0 + + def get_sequence_and_data_parallel_rank(self): + """Return caller's rank in the sequence and data parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_group().rank() + return 0 + def is_initialized(self): """Check if parallel state has been initialized""" return self.data_parallel_group is not None diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 2d4cbf93915a..d0eb21acfa19 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -401,6 +401,77 @@ def get_context_parallel_rank(name: Optional[str] = None): return get_parallel_state(name).get_context_parallel_rank() +# ============================================================================ +# Sequence Parallel Functions +# ============================================================================ + + +def get_sequence_parallel_group(name: Optional[str] = None): + """Get the sequence-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_group() + + +def get_sequence_parallel_world_size(name: Optional[str] = None): + """Return world size for the sequence parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_world_size() + + +def get_sequence_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the sequence-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_rank() + + +def get_sequence_and_data_parallel_group(name: Optional[str] = None): + """Get the sequence and data parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_group() + + +def get_sequence_and_data_parallel_world_size(name: Optional[str] = None): + """Return world size for the sequence and data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_world_size() + + +def get_sequence_and_data_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the sequence and data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_rank() + + # ============================================================================ # Expert Parallel Functions # ============================================================================ @@ -638,12 +709,20 @@ def initialize_parallel_state_from_config( "distributed_timeout_minutes": 30, "order": "tp-ep-dp-pp", "create_gloo_process_groups": true, - "high_priority_stream_groups": null + "high_priority_stream_groups": null, + "sequence_parallel_size": 1 }, // Note: The following parameters are NOT supported in DeepSpeed: // - "context_parallel_size": must be 1 (default) // - "hierarchical_context_parallel_sizes": not supported + + // Sequence Parallel (SP) usage notes: + // - SP cannot be used together with TP, PP, or EP + // - When using SP, set tp=1, pp=1, ep=1 + // - Example SP config: {"sequence_parallel_size": 4, "order": "sp-dp"} + // - SP can be combined with DP: {"sequence_parallel_size": 4, "data_parallel_size": 2, "order": "sp-dp"} + "train_batch_size": 8, ... } From 703c1fe8ca618f4dcbda2f5336593847ab0ab2b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 19 Jan 2026 10:02:23 +0800 Subject: [PATCH 05/23] fix: remove Chinese comment from config example Remove Chinese inline comment from the example config.json docstring to comply with DeepSpeed community coding standards. This ensures all comments and documentation are in English only. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index d0eb21acfa19..a3e46a062519 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -703,7 +703,7 @@ def initialize_parallel_state_from_config( "expert_model_parallel_size": 1, "expert_tensor_parallel_size": 1, "virtual_pipeline_model_parallel_size": null, - "pipeline_model_parallel_comm_backend": null, ##不要加入config中,保留加载逻辑 + "pipeline_model_parallel_comm_backend": null, "num_distributed_optimizer_instances": 1, "nccl_communicator_config_path": null, "distributed_timeout_minutes": 30, From 04f56a9a9212605e4891a60106315f67289c71ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 19 Jan 2026 15:08:35 +0800 Subject: [PATCH 06/23] fix: use torch.distributed.new_group directly in _create_group The deepspeed.comm.new_group() wrapper only accepts 'ranks' parameter, but _create_group() needs to pass additional parameters like timeout, backend, pg_options, etc. to support advanced process group configuration. This fix uses torch.distributed.new_group() directly to support all parameters while still using deepspeed.comm for other operations. Fixes TypeError: new_group() got an unexpected keyword argument 'timeout' Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 0e4793e97a86..12d711f4e64f 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -326,6 +326,9 @@ def _create_group( group_desc=None, ): """Creates a ProcessGroup.""" + # Use torch.distributed directly for new_group as deepspeed.comm.new_group only supports ranks parameter + import torch.distributed as torch_dist + kwargs = { "ranks": ranks, "timeout": timeout, @@ -339,7 +342,7 @@ def _create_group( if timeout is None: kwargs.pop("timeout") - group = dist.new_group(**kwargs) + group = torch_dist.new_group(**kwargs) if self.global_process_group_list is None: self.global_process_group_list = [None] if dist.get_rank() in ranks: From 845bd817a98f968437f5e2de0a60699bcaf99eb2 Mon Sep 17 00:00:00 2001 From: yunqing Date: Mon, 19 Jan 2026 16:05:00 +0800 Subject: [PATCH 07/23] fix: correct SP parallel group creation logic in parallel_state - Include sequence_parallel_size in model_size calculation - Fix SP group count: num_sequence_parallel_groups = data_parallel_size - Use consecutive rank grouping for SP (not RankGenerator) - SP uses different parallelism model than TP/PP/CP/EP Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 12d711f4e64f..c2757587f3b8 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -445,7 +445,7 @@ def default_position_embedding_ranks(pp_ranks): world_size: int = dist.get_world_size() rank = dist.get_rank() - model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * sequence_parallel_size if world_size % model_size != 0: raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}") @@ -471,7 +471,7 @@ def default_position_embedding_ranks(pp_ranks): cp=context_parallel_size, order=order, rank_offset=0, - sp=1, + sp=sequence_parallel_size, ) # Build expert rank generator @@ -831,12 +831,14 @@ def default_position_embedding_ranks(pp_ranks): raise RuntimeError( f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})") - sp_data_parallel_size = world_size // sequence_parallel_size - sequence_and_data_parallel_size = sequence_parallel_size * sp_data_parallel_size - num_sequence_parallel_groups = world_size // sequence_parallel_size - num_sequence_and_data_parallel_groups = world_size // sequence_and_data_parallel_size + # SP groups use consecutive ranks + # Number of SP groups = data_parallel_size (each DP rank has its own SP group) + num_sequence_parallel_groups = data_parallel_size + sequence_and_data_parallel_size = world_size + num_sequence_and_data_parallel_groups = 1 - # Build the sequence parallel groups + # Build the sequence parallel groups using consecutive ranks + # SP uses consecutive rank grouping, not orthogonal grouping like TP/PP/CP for i in range(num_sequence_parallel_groups): ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)) group = self._create_group( From 2ec5c69c9881340b00a91c2e0e2371f107b4fa9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Wed, 21 Jan 2026 10:59:01 +0800 Subject: [PATCH 08/23] refactor: simplify _create_group to use deepspeed.comm interface Updated _create_group() to use deepspeed.comm.new_group() which currently only supports 'ranks' parameter. Other parameters (timeout, backend, pg_options, etc.) are commented out and documented in TODO comments. For non-nccl backends, the function returns None with a warning, as these are not yet supported by the deepspeed.comm interface. These parameters will be enabled once DeepSpeed's comm interface is enhanced to support them. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index c2757587f3b8..8033ebf72a68 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -29,6 +29,8 @@ from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist +from deepspeed.utils.torch import required_torch_version + logger = logging.getLogger(__name__) @@ -324,25 +326,30 @@ def _create_group( pg_options=None, use_local_synchronization=False, group_desc=None, - ): + ): """Creates a ProcessGroup.""" - # Use torch.distributed directly for new_group as deepspeed.comm.new_group only supports ranks parameter - import torch.distributed as torch_dist - + if backend is not None and backend != "nccl": + logger.warning(f"{backend} backend is not supported for new_group. Using torch.distributed directly.") + return None + + # TODO: Currently using deepspeed.comm.new_group() which only supports 'ranks' parameter. + # The following parameters are commented out and will be enabled once DeepSpeed's + # comm interface supports them: + # - timeout: Timeout for process group operations + # - backend: Communication backend (e.g., 'nccl', 'gloo') + # - pg_options: Process group options + # - use_local_synchronization: Enable local synchronization + # - group_desc: Group description for debugging (requires PyTorch >= 2.4) kwargs = { "ranks": ranks, - "timeout": timeout, - "backend": backend, - "pg_options": pg_options, - "use_local_synchronization": use_local_synchronization, - "group_desc": group_desc, + # "timeout": timeout, + # "backend": backend, + # "pg_options": pg_options, + # "use_local_synchronization": use_local_synchronization, + # "group_desc": group_desc, } - if not is_torch_min_version("2.4.0"): - kwargs.pop("group_desc") - if timeout is None: - kwargs.pop("timeout") - group = torch_dist.new_group(**kwargs) + group = dist.new_group(**kwargs) if self.global_process_group_list is None: self.global_process_group_list = [None] if dist.get_rank() in ranks: From 33a3ca8f8eaf45c5cf9f7a2ec67025f956367788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Wed, 21 Jan 2026 11:58:15 +0800 Subject: [PATCH 09/23] feat: migrate All-to-All groups to parallel_state architecture Migrate _get_local_all_to_all_group functionality from groups.py to the new parallel_state architecture to support ZeRO++ quantized gradients. Changes in parallel_state.py: - Add all_to_all_groups and all_to_all_initialized to ParallelState class - Implement initialize_all_to_all_groups() method to create local and global All-to-All groups based on node topology - Implement get_all_to_all_groups() method to retrieve initialized groups Changes in parallel_state_deepspeed.py: - Add initialize_all_to_all_groups() wrapper function - Add get_all_to_all_groups() wrapper function - Add _get_local_all_to_all_group() for backward compatibility with groups.py Benefits: - Supports multi-instance scenarios (e.g., RL with actor/critic models) - Consistent with the new parallel_state architecture - Maintains backward compatibility with existing groups.py interface - Enables future config-based initialization of All-to-All groups Note: This does not remove the implementation from groups.py yet to maintain backward compatibility during the transition period. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 68 ++++++++++++++++++++- deepspeed/utils/parallel_state_deepspeed.py | 64 +++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 8033ebf72a68..a8ef43091096 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -265,6 +265,10 @@ def __init__(self): self.intra_partial_expert_data_parallel_group_gloo = None self.inter_partial_expert_data_parallel_group = None + # All-to-All groups for ZeRO++ quantized gradients + self.all_to_all_groups = {} + self.all_to_all_initialized = False + # Global ranks lists self.embedding_global_ranks = None self.position_embedding_global_ranks = None @@ -326,7 +330,7 @@ def _create_group( pg_options=None, use_local_synchronization=False, group_desc=None, - ): + ): """Creates a ProcessGroup.""" if backend is not None and backend != "nccl": logger.warning(f"{backend} backend is not supported for new_group. Using torch.distributed directly.") @@ -1041,6 +1045,68 @@ def is_initialized(self): """Check if parallel state has been initialized""" return self.data_parallel_group is not None + def initialize_all_to_all_groups(self): + """Initialize All-to-All groups for quantized gradient communication. + + Creates local and global All-to-All groups based on node topology: + - Local groups: intra-node communication (NVLink/NVSwitch) + - Global groups: inter-node communication (cross-node) + + Used by ZeRO++ when zero_quantized_gradients is enabled. + + Returns: + Dictionary of All-to-All groups + """ + if self.all_to_all_initialized: + return self.all_to_all_groups + + assert dist.is_initialized(), 'dist is not initialized' + + device_per_node = get_accelerator().device_count() + world_size = dist.get_world_size() + num_nodes = world_size // device_per_node + + if num_nodes == 0 and world_size > 0: + # Single incomplete node + assert world_size >= 1, 'num_gpus must >=1, cannot initialize All-To-All' + ranks = list(range(world_size)) + self.all_to_all_groups['local_0'] = self._create_group(ranks) + + elif num_nodes == 1: + # Exactly one node + assert world_size == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All' + ranks = list(range(device_per_node)) + self.all_to_all_groups['local_0'] = self._create_group(ranks) + + else: + # Multiple nodes: create both local and global groups + assert world_size > device_per_node, 'num_nodes<2 cannot initialize All-To-All' + + # Local groups (intra-node) + for node_id in range(num_nodes): + local_ranks = [j + device_per_node * node_id for j in range(device_per_node)] + self.all_to_all_groups[f"local_{node_id}"] = self._create_group(local_ranks) + + # Global groups (inter-node) + for device_id in range(device_per_node): + global_ranks = [device_id + j * device_per_node for j in range(num_nodes)] + self.all_to_all_groups[f"global_{device_id}"] = self._create_group(global_ranks) + + self.all_to_all_initialized = True + return self.all_to_all_groups + + def get_all_to_all_groups(self): + """Get All-to-All groups dictionary. + + Initializes the groups if not already initialized. + + Returns: + Dictionary of All-to-All groups + """ + if not self.all_to_all_initialized: + self.initialize_all_to_all_groups() + return self.all_to_all_groups + def get_global_memory_buffer(self): """Return the global GlobalMemoryBuffer object""" assert self.global_memory_buffer is not None, "global memory buffer is not initialized" diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index a3e46a062519..3d65a8b70e9b 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -632,6 +632,70 @@ def is_initialized(name: Optional[str] = None): return get_parallel_state(name).is_initialized() +# ============================================================================ +# All-to-All Groups for ZeRO++ Quantized Gradients +# ============================================================================ + + +def initialize_all_to_all_groups(name: Optional[str] = None): + """Initialize All-to-All groups for quantized gradient communication. + + Creates local and global All-to-All groups based on node topology. + Used by ZeRO++ when zero_quantized_gradients is enabled. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + Example: + # Initialize for default instance + all_to_all_groups = initialize_all_to_all_groups() + + # Initialize for named instance (RL scenario) + actor_groups = initialize_all_to_all_groups("actor") + critic_groups = initialize_all_to_all_groups("critic") + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).initialize_all_to_all_groups() + + +def get_all_to_all_groups(name: Optional[str] = None): + """Get All-to-All groups dictionary. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_all_to_all_groups() + + +def _get_local_all_to_all_group(name: Optional[str] = None): + """Get All-to-All groups for current rank (backward compatible with groups.py). + + This function provides backward compatibility with the groups.py interface. + It returns all All-to-All groups (both local and global). + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + Note: + This is a compatibility wrapper. New code should use get_all_to_all_groups() instead. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_all_to_all_groups() + + # ============================================================================ # Configuration-based Initialization # ============================================================================ From 138fbebdb6f22fac17a42c9f645bf21bee9df847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Fri, 23 Jan 2026 15:21:48 +0800 Subject: [PATCH 10/23] fix: disable gloo process groups by default DeepSpeed's comm interface does not support gloo backend, so set create_gloo_process_groups default to False. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index a8ef43091096..106155f3c846 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -428,7 +428,7 @@ def initialize_model_parallel( order: str = "tp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, - create_gloo_process_groups: bool = True, + create_gloo_process_groups: bool = False, ) -> None: """Initialize model data parallel groups. From 2aab5b85bf90b16082ca3ee3fc15142690ac8fac Mon Sep 17 00:00:00 2001 From: yunqing Date: Fri, 23 Jan 2026 15:31:36 +0800 Subject: [PATCH 11/23] refactor: simplify SP group creation using RankGenerator - Replace manual consecutive rank grouping with RankGenerator.get_ranks('sp') - Remove redundant world_size validation logic (handled by RankGenerator) - Reduce SP group creation code from 41 lines to 26 lines - Maintain same SP group topology: consecutive ranks [0,1], [2,3] for sp_size=2 - Fix code style issues: remove unused import, update warning message This change unifies process group creation by leveraging RankGenerator's orthogonal parallelism algorithm, which naturally produces consecutive rank grouping when order='sp-dp'. Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 106155f3c846..52719a807986 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -29,8 +29,6 @@ from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist -from deepspeed.utils.torch import required_torch_version - logger = logging.getLogger(__name__) @@ -333,7 +331,7 @@ def _create_group( ): """Creates a ProcessGroup.""" if backend is not None and backend != "nccl": - logger.warning(f"{backend} backend is not supported for new_group. Using torch.distributed directly.") + logger.warning(f"{backend} backend is not supported for new_group. Using deepspeed.comm directly.") return None # TODO: Currently using deepspeed.comm.new_group() which only supports 'ranks' parameter. @@ -829,29 +827,13 @@ def default_position_embedding_ranks(pp_ranks): self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group intra_dist_opt_ranks = [] - # Build sequence parallel groups + # Build sequence parallel groups using RankGenerator if sequence_parallel_size > 1: assert self.sequence_parallel_group is None, "sequence parallel group is already initialized" assert self.sequence_and_data_parallel_group is None, "sequence and data parallel group is already initialized" - if world_size < sequence_parallel_size: - raise RuntimeError( - f"world_size ({world_size}) is less than sequence_parallel_size ({sequence_parallel_size})") - - if world_size % sequence_parallel_size != 0: - raise RuntimeError( - f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})") - - # SP groups use consecutive ranks - # Number of SP groups = data_parallel_size (each DP rank has its own SP group) - num_sequence_parallel_groups = data_parallel_size - sequence_and_data_parallel_size = world_size - num_sequence_and_data_parallel_groups = 1 - - # Build the sequence parallel groups using consecutive ranks - # SP uses consecutive rank grouping, not orthogonal grouping like TP/PP/CP - for i in range(num_sequence_parallel_groups): - ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)) + # Build SP groups using RankGenerator + for ranks in self.decoder_rank_generator.get_ranks('sp'): group = self._create_group( ranks, timeout=timeout, @@ -861,9 +843,8 @@ def default_position_embedding_ranks(pp_ranks): if rank in ranks: self.sequence_parallel_group = group - # Build the sequence and data parallel groups - for i in range(num_sequence_and_data_parallel_groups): - ranks = list(range(i * sequence_and_data_parallel_size, (i + 1) * sequence_and_data_parallel_size)) + # Build SP+DP combined groups using RankGenerator + for ranks in self.decoder_rank_generator.get_ranks('sp-dp'): group = self._create_group( ranks, timeout=timeout, From a0975c7e53644d7389ec35f19e31269d6c37ce3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 26 Jan 2026 09:48:36 +0800 Subject: [PATCH 12/23] docs: fix config example and SP usage notes 1. Change create_gloo_process_groups from true to false - Aligns with default value change in previous commit - DeepSpeed comm interface does not support gloo backend 2. Correct Sequence Parallel usage description - SP is included in model_size calculation (tp * pp * cp * sp) - SP can be used together with TP/PP/EP - Number of SP groups equals data_parallel_size - SP uses consecutive rank grouping (not orthogonal like TP/PP/CP/EP) Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 3d65a8b70e9b..23bbe61b0b6d 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -772,7 +772,7 @@ def initialize_parallel_state_from_config( "nccl_communicator_config_path": null, "distributed_timeout_minutes": 30, "order": "tp-ep-dp-pp", - "create_gloo_process_groups": true, + "create_gloo_process_groups": false, "high_priority_stream_groups": null, "sequence_parallel_size": 1 }, @@ -782,10 +782,10 @@ def initialize_parallel_state_from_config( // - "hierarchical_context_parallel_sizes": not supported // Sequence Parallel (SP) usage notes: - // - SP cannot be used together with TP, PP, or EP - // - When using SP, set tp=1, pp=1, ep=1 - // - Example SP config: {"sequence_parallel_size": 4, "order": "sp-dp"} - // - SP can be combined with DP: {"sequence_parallel_size": 4, "data_parallel_size": 2, "order": "sp-dp"} + // - SP dimension is included in model_size calculation: model_size = tp * pp * cp * sp + // - Number of SP groups = data_parallel_size (each DP rank has its own SP group) + // - SP uses consecutive rank grouping, different from TP/PP/CP/EP orthogonal grouping + // - Example: world_size=16, tp=2, sp=2, pp=1, ep=1 => dp=4, and 4 SP groups "train_batch_size": 8, ... From 58f7f51addeed15f69138d0c3f12026bc8aa7b7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 26 Jan 2026 09:59:14 +0800 Subject: [PATCH 13/23] refactor: remove unused is_torch_min_version function Remove is_torch_min_version function that is never called in the codebase. This reduces code complexity and removes unnecessary dependencies. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 52719a807986..93bf297d0ba0 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -39,27 +39,6 @@ HAVE_EINOPS = False -def is_torch_min_version(version: str, check_equality: bool = True) -> bool: - """Check if PyTorch version meets minimum requirement. - - Args: - version: Version string to check (e.g., "2.4.0") - check_equality: If True, also check for equality - - Returns: - True if version requirement is met - """ - try: - from packaging.version import Version as PkgVersion - torch_version = PkgVersion(torch.__version__) - required_version = PkgVersion(version) - if check_equality: - return torch_version >= required_version - return torch_version > required_version - except Exception: - return False - - class GlobalMemoryBuffer: """Global buffer to avoid dynamic memory allocations.""" From ff67ec59246bedf6b19dcf44c5fc271842a8366b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 26 Jan 2026 14:52:14 +0800 Subject: [PATCH 14/23] refactor: simplify config-based initialization to use top-level fields Remove the 'parallelism' config block concept and directly read from existing DeepSpeed config fields. This avoids adding new top-level config structure which requires changes throughout the codebase. Changes: - Remove 'parallelism' nested config block from examples - Read 'sequence_parallel_size' directly from top-level config - Change priority: function params > config values > defaults (was: config > params > defaults) - Update create_gloo_process_groups default from True to False - Simplify documentation to reflect current implementation This makes the config-based initialization fully backward compatible without requiring any new config schema validation or parsing logic. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 99 +++++++-------------- 1 file changed, 32 insertions(+), 67 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 23bbe61b0b6d..0a2d71b112c6 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -704,7 +704,6 @@ def _get_local_all_to_all_group(name: Optional[str] = None): def initialize_parallel_state_from_config( config: Union[Dict[str, Any], Any], name: Optional[str] = None, - config_key: str = "parallelism", # Optional parameters to override config values tensor_model_parallel_size: Optional[int] = None, pipeline_model_parallel_size: Optional[int] = None, @@ -715,6 +714,7 @@ def initialize_parallel_state_from_config( expert_model_parallel_size: Optional[int] = None, num_distributed_optimizer_instances: Optional[int] = None, expert_tensor_parallel_size: Optional[int] = None, + sequence_parallel_size: Optional[int] = None, nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: Optional[int] = None, order: Optional[str] = None, @@ -724,15 +724,10 @@ def initialize_parallel_state_from_config( """Initialize parallel state from DeepSpeed config.json with optional parameter overrides. This function reads parallelism configuration from the DeepSpeed config file - and automatically initializes the ParallelState instance. This allows users - to configure all parallelism dimensions in a single place (config.json) - rather than having to read documentation and manually call initialize_model_parallel. + (top-level fields) and automatically initializes the ParallelState instance. + This allows code to work with both explicit initialization and config-based initialization. - Configuration priority: config file (if explicitly set) > function parameters > default values - - Note: If a value is explicitly set in config file, it takes precedence over function - parameters. A warning will be logged if there's a conflict. To override config file - values, remove them from the config file first. + Configuration priority: function parameters > config file values > default values (1) Args: config: Either a DeepSpeedConfig object or a config dictionary. @@ -740,8 +735,6 @@ def initialize_parallel_state_from_config( If dict, will use it directly. name: Optional name of the parallel state instance to initialize. If None, initializes the default global instance. - config_key: Key in the config dictionary where parallelism config is stored. - Default is "parallelism". # Parallelism dimension parameters (override config if provided): tensor_model_parallel_size: Size of tensor model parallel group. Default: 1 @@ -753,44 +746,27 @@ def initialize_parallel_state_from_config( expert_model_parallel_size: Size of expert model parallel group. Default: 1 num_distributed_optimizer_instances: Number of distributed optimizer instances. Default: 1 expert_tensor_parallel_size: Size of expert tensor parallel group. Default: None + sequence_parallel_size: Size of sequence parallel group. Default: 1 nccl_communicator_config_path: Path to NCCL communicator config. Default: None distributed_timeout_minutes: Timeout for distributed operations. Default: 30 order: Order of parallelism dimensions. Default: "tp-ep-dp-pp" - create_gloo_process_groups: Whether to create Gloo process groups. Default: True + create_gloo_process_groups: Whether to create Gloo process groups. Default: False high_priority_stream_groups: High priority stream groups. Default: None - Example config.json: + Example config.json (using existing DeepSpeed config fields): { - "parallelism": { - "tensor_model_parallel_size": 2, - "pipeline_model_parallel_size": 1, - "expert_model_parallel_size": 1, - "expert_tensor_parallel_size": 1, - "virtual_pipeline_model_parallel_size": null, - "pipeline_model_parallel_comm_backend": null, - "num_distributed_optimizer_instances": 1, - "nccl_communicator_config_path": null, - "distributed_timeout_minutes": 30, - "order": "tp-ep-dp-pp", - "create_gloo_process_groups": false, - "high_priority_stream_groups": null, - "sequence_parallel_size": 1 - }, - - // Note: The following parameters are NOT supported in DeepSpeed: - // - "context_parallel_size": must be 1 (default) - // - "hierarchical_context_parallel_sizes": not supported - - // Sequence Parallel (SP) usage notes: - // - SP dimension is included in model_size calculation: model_size = tp * pp * cp * sp - // - Number of SP groups = data_parallel_size (each DP rank has its own SP group) - // - SP uses consecutive rank grouping, different from TP/PP/CP/EP orthogonal grouping - // - Example: world_size=16, tp=2, sp=2, pp=1, ep=1 => dp=4, and 4 SP groups - "train_batch_size": 8, - ... + "sequence_parallel_size": 1, + "zero_optimization": { + "stage": 1 + } } + Note: + - Currently only "sequence_parallel_size" can be read from config (existing field) + - Other parallelism parameters must be passed via function parameters or use defaults + - Context Parallel is NOT supported (cp must be 1) + Example usage: # Basic usage from config file: from deepspeed import DeepSpeedConfig @@ -828,11 +804,6 @@ def initialize_parallel_state_from_config( else: raise ValueError(f"config must be a DeepSpeedConfig object or a dict, got {type(config)}") - # Check if parallelism config exists in config file - parallelism_config = config_dict.get(config_key, {}) - if parallelism_config and not isinstance(parallelism_config, dict): - raise ValueError(f"'{config_key}' in config must be a dictionary, got {type(parallelism_config)}") - # Get the parallel state instance ps = get_parallel_state_instance(name) @@ -846,37 +817,29 @@ def initialize_parallel_state_from_config( logger = logging.getLogger(__name__) # Helper function to get value with proper priority handling - # Priority: config file (if explicitly set) > function parameter > default + # Priority: function parameter > config file value > default value def get_value(param_name, param_value, config_key, default_value): """ - Get value with priority handling and conflict detection. + Get value with priority handling. Priority: - 1. If config file explicitly sets the value -> use config value (warn if param differs) - 2. If config file doesn't have the value -> use function parameter - 3. If both are None -> use default value + 1. If function parameter is provided -> use parameter value + 2. If config file has the value -> use config value + 3. Otherwise -> use default value """ - config_has_key = config_key in parallelism_config - config_value = parallelism_config.get(config_key) - - # Case 1: Config file explicitly sets the value - if config_has_key: - # If function parameter is also provided and differs, warn and use config - if param_value is not None and param_value != config_value: - logger.warning(f"Parameter '{param_name}' conflict detected: " - f"config file specifies {config_value}, but function parameter is {param_value}. " - f"Using config file value ({config_value}). " - f"To override config, remove '{config_key}' from config file.") - return config_value - - # Case 2: Config file doesn't have the key, use function parameter if provided + # Case 1: Function parameter provided if param_value is not None: return param_value - # Case 3: Neither config nor parameter provided, use default + # Case 2: Config file has the key + if config_key in config_dict: + config_value = config_dict[config_key] + return config_value + + # Case 3: Use default return default_value - # Extract parameters with proper priority: config (if set) > function param > default + # Extract parameters with proper priority: function param > config value > default init_kwargs = { "tensor_model_parallel_size": get_value("tensor_model_parallel_size", tensor_model_parallel_size, "tensor_model_parallel_size", 1), @@ -890,6 +853,8 @@ def get_value(param_name, param_value, config_key, default_value): "pipeline_model_parallel_comm_backend", None), "context_parallel_size": get_value("context_parallel_size", context_parallel_size, "context_parallel_size", 1), + "sequence_parallel_size": + get_value("sequence_parallel_size", sequence_parallel_size, "sequence_parallel_size", 1), "hierarchical_context_parallel_sizes": get_value("hierarchical_context_parallel_sizes", hierarchical_context_parallel_sizes, "hierarchical_context_parallel_sizes", None), @@ -908,7 +873,7 @@ def get_value(param_name, param_value, config_key, default_value): "order": get_value("order", order, "order", "tp-ep-dp-pp"), "create_gloo_process_groups": - get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", True), + get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", False), "high_priority_stream_groups": get_value("high_priority_stream_groups", high_priority_stream_groups, "high_priority_stream_groups", None), } From 11e5a2c1b016a6996fc1efd6348f5e3749d1569a Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Mon, 26 Jan 2026 18:08:41 +0800 Subject: [PATCH 15/23] tests: Drop test_mpu.py from the PR The test_mpu.py script is used to verify the equivalence between existing process group management facilities and the proposed, unified ParallelState. It is meant to be an temporary helper and will not be useful after we switch existing implementations to the new interfaces. Thus remove it from the current PR. The test is still available at https://gist.github.com/eternalNight/b76c72216b4be84832b615b76465396f. Signed-off-by: Junjie Mao --- tests/unit/utils/test_mpu.py | 1692 ---------------------------------- 1 file changed, 1692 deletions(-) delete mode 100644 tests/unit/utils/test_mpu.py diff --git a/tests/unit/utils/test_mpu.py b/tests/unit/utils/test_mpu.py deleted file mode 100644 index 11ed585c92b3..000000000000 --- a/tests/unit/utils/test_mpu.py +++ /dev/null @@ -1,1692 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) DeepSpeed Team - -# DeepSpeed Team -""" -Automated testing of parallel strategy combinations using random configurations. - -This test automatically generates random parallel configurations and tests -both parallel_state_refactored and DeepSpeed to see if they produce compatible results. -""" - -import pytest -import random -from typing import Dict, List, Tuple, Optional -from collections import defaultdict - -# Try to import both libraries -try: - from deepspeed.utils.parallel_state import RankGenerator - PARALLEL_STATE_AVAILABLE = True -except ImportError as e: - PARALLEL_STATE_AVAILABLE = False - print(f"Warning: Could not import Megatron parallel_state_refactored: {e}") - -try: - from deepspeed.utils import groups as ds_groups - from deepspeed.runtime.sequence_parallel import parallel_state_sp as ds_sp - DEEPSPEED_AVAILABLE = True -except ImportError as e: - DEEPSPEED_AVAILABLE = False - print(f"Warning: Could not import DeepSpeed: {e}") - - -class ParallelConfigGenerator: - """Generate random parallel configurations for testing.""" - - def __init__(self, seed=None): - if seed is not None: - random.seed(seed) - self.tested_configs = [] - self.failed_configs = [] - - def generate_random_config(self, max_size=1024, min_parallel_size=1, max_parallel_size=32): - """Generate a random parallel configuration. - - Args: - max_size: Maximum world size to consider - min_parallel_size: Minimum parallel size for each dimension - max_parallel_size: Maximum parallel size for each dimension - - Returns: - Dict with tp, dp, pp, cp, ep values and order - """ - # Generate random sizes for each dimension - # Don't filter invalid configurations - we want to test and report all cases - tp = random.randint(min_parallel_size, max_parallel_size) - dp = random.randint(min_parallel_size, max_parallel_size) - pp = random.randint(min_parallel_size, max_parallel_size) - cp = random.randint(min_parallel_size, max_parallel_size) - ep = random.randint(min_parallel_size, max_parallel_size) - - # Calculate world size - world_size = tp * dp * pp * cp * ep - - # If world size is too large, scale down proportionally - # But try to keep at least one dimension > 1 - if world_size > max_size: - # Scale down proportionally - scale_factor = (max_size / world_size)**0.25 - tp = max(1, int(tp * scale_factor)) - dp = max(1, int(dp * scale_factor)) - pp = max(1, int(pp * scale_factor)) - cp = max(1, int(cp * scale_factor)) - ep = max(1, int(ep * scale_factor)) - world_size = tp * dp * pp * cp * ep - - # Ensure at least one dimension is > 1 - if world_size == 1: - tp = 2 - world_size = 2 - - # Generate random order (but must include all non-1 dimensions) - dimensions = [] - if tp > 1: - dimensions.append('tp') - if dp > 1: - dimensions.append('dp') - if pp > 1: - dimensions.append('pp') - if cp > 1: - dimensions.append('cp') - if ep > 1: - dimensions.append('ep') - - # Shuffle to get random order - random.shuffle(dimensions) - order = '-'.join(dimensions) if dimensions else 'tp' - - # If no dimensions > 1, use default - if not dimensions: - order = 'tp-dp' - tp = 2 - dp = 2 - - config = { - "tp": tp, - "dp": dp, - "pp": pp, - "cp": cp, - "ep": ep, - "order": order, - "world_size": tp * dp * pp * cp * ep, - } - - return config - - def generate_systematic_configs(self, max_world_size=512): - """Generate systematic configurations covering common cases. - - Args: - max_world_size: Maximum world size to consider - - Returns: - List of configurations - """ - configs = [] - - # Single parallelism - test larger sizes - for size in [2, 4, 8, 16, 32, 64, 128, 256]: - if size <= max_world_size: - configs.append({"tp": size, "dp": 1, "pp": 1, "cp": 1, "ep": 1, "order": "tp", "world_size": size}) - configs.append({"tp": 1, "dp": size, "pp": 1, "cp": 1, "ep": 1, "order": "dp", "world_size": size}) - configs.append({"tp": 1, "dp": 1, "pp": size, "cp": 1, "ep": 1, "order": "pp", "world_size": size}) - - # Two-way combinations - more variations - for tp, dp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4), (2, 16), (16, 2), (4, 8), (8, 4)]: - if tp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp-dp", - "world_size": tp * dp - }) - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp-tp", - "world_size": tp * dp - }) - - for tp, pp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4)]: - if tp * pp <= max_world_size: - configs.append({ - "tp": tp, - "dp": 1, - "pp": pp, - "cp": 1, - "ep": 1, - "order": "tp-pp", - "world_size": tp * pp - }) - - for tp, cp in [(2, 2), (2, 4), (4, 2), (2, 8)]: - if tp * cp <= max_world_size: - configs.append({ - "tp": tp, - "dp": 1, - "pp": 1, - "cp": cp, - "ep": 1, - "order": "tp-cp", - "world_size": tp * cp - }) - - for tp, ep in [(2, 2), (2, 4), (4, 2), (2, 8)]: - if tp * ep <= max_world_size: - configs.append({ - "tp": tp, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": ep, - "order": "tp-ep", - "world_size": tp * ep - }) - - # Three-way combinations - more variations - for tp, pp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2), (4, 2, 2), (2, 2, 8), (2, 4, 4), (4, 4, 2)]: - if tp * pp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": 1, - "ep": 1, - "order": "tp-pp-dp", - "world_size": tp * pp * dp - }) - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": tp * pp * dp - }) - - for tp, cp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: - if tp * cp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": cp, - "ep": 1, - "order": "tp-cp-dp", - "world_size": tp * cp * dp - }) - - for tp, ep, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: - if tp * ep * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": 1, - "ep": ep, - "order": "tp-ep-dp", - "world_size": tp * ep * dp - }) - - # Four-way combinations - more variations - for tp, pp, dp, cp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2), (2, 4, 2, 2)]: - if tp * pp * dp * cp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": cp, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": tp * pp * dp * cp - }) - - for tp, ep, pp, dp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2)]: - if tp * ep * pp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": 1, - "ep": ep, - "order": "tp-ep-pp-dp", - "world_size": tp * ep * pp * dp - }) - - return configs - - def generate_random_configs(self, count=1000, max_size=1024): - """Generate multiple random configurations. - - Args: - count: Number of random configurations to generate - max_size: Maximum world size - - Returns: - List of configurations - """ - configs = [] - seen = set() - - for _ in range(count): - config = self.generate_random_config(max_size=max_size) - # Create a unique key for this configuration - key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) - if key not in seen: - seen.add(key) - configs.append(config) - - return configs - - def generate_random_config_by_dimension(self, - dimension_count: int, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32): - """Generate a random configuration with exactly the specified number of dimensions > 1. - - Args: - dimension_count: Number of dimensions that should be > 1 (1-5) - max_size: Maximum world size - min_parallel_size: Minimum parallel size for each dimension - max_parallel_size: Maximum parallel size for each dimension - - Returns: - Dict with tp, dp, pp, cp, ep values and order - """ - # All possible dimensions - all_dims = ['tp', 'dp', 'pp', 'cp', 'ep'] - - # Randomly select which dimensions to activate - active_dims = random.sample(all_dims, min(dimension_count, len(all_dims))) - - # Initialize all dimensions to 1 - config = { - "tp": 1, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": 1, - } - - # Set active dimensions to random values - for dim in active_dims: - config[dim] = random.randint(min_parallel_size, max_parallel_size) - - # Calculate world size - world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] - - # If world size is too large, scale down proportionally - if world_size > max_size: - scale_factor = (max_size / world_size)**(1.0 / dimension_count) - for dim in active_dims: - config[dim] = max(min_parallel_size, int(config[dim] * scale_factor)) - world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] - - # Generate random order from active dimensions - random.shuffle(active_dims) - order = '-'.join(active_dims) - - config["order"] = order - config["world_size"] = world_size - - return config - - def generate_random_configs_by_dimension(self, - counts_by_dimension: Dict[int, int], - max_size=1024, - min_parallel_size=2, - max_parallel_size=32): - """Generate random configurations for each dimension separately. - - Args: - counts_by_dimension: Dict mapping dimension count (1-5) to number of configs to generate - e.g., {1: 100, 2: 200, 3: 150, 4: 100, 5: 50} - max_size: Maximum world size - min_parallel_size: Minimum parallel size for each dimension - max_parallel_size: Maximum parallel size for each dimension - - Returns: - List of configurations grouped by dimension count - """ - all_configs = [] - seen = set() - - for dim_count, count in counts_by_dimension.items(): - if dim_count < 1 or dim_count > 5: - continue - - dim_configs = [] - attempts = 0 - # Increased max_attempts for larger test sets (20x more configs) - max_attempts = count * 20 # Prevent infinite loops, allow more attempts for uniqueness - - while len(dim_configs) < count and attempts < max_attempts: - attempts += 1 - config = self.generate_random_config_by_dimension(dim_count, max_size, min_parallel_size, - max_parallel_size) - - # Create a unique key for this configuration - key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) - - if key not in seen: - seen.add(key) - dim_configs.append(config) - all_configs.append(config) - - if len(dim_configs) < count: - print( - f"Warning: Only generated {len(dim_configs)}/{count} configs for {dim_count}D combinations (attempted {attempts} times)" - ) - - return all_configs - - -class ErrorCategorizer: - """Categorize and aggregate errors by type.""" - - def __init__(self): - self.error_categories = defaultdict(list) - self.combination_stats = defaultdict(int) - - def categorize_error(self, error_msg: str, config: Dict) -> str: - """Categorize an error message into a category.""" - error_lower = error_msg.lower() - - if "ep and cp cannot both be > 1" in error_lower: - return "EP_CP_CONFLICT" - elif "cp not supported" in error_lower: - return "CP_NOT_SUPPORTED" - elif "pp requires" in error_lower or "pipeline" in error_lower: - return "PP_REQUIRES_MPU" - elif "not divisible" in error_lower: - return "DIVISIBILITY_ERROR" - elif "order" in error_lower and "specified" in error_lower: - return "ORDER_MISMATCH" - elif "not available" in error_lower: - return "FEATURE_NOT_AVAILABLE" - else: - return "OTHER_ERROR" - - def get_combination_type(self, config: Dict) -> str: - """Get the combination type string for a configuration.""" - dims = [] - if config["tp"] > 1: - dims.append("TP") - if config["dp"] > 1: - dims.append("DP") - if config["pp"] > 1: - dims.append("PP") - if config["cp"] > 1: - dims.append("CP") - if config["ep"] > 1: - dims.append("EP") - - if not dims: - return "NONE" - - return "+".join(sorted(dims)) - - def record_error(self, error_msg: str, config: Dict, library: str): - """Record an error with categorization.""" - category = self.categorize_error(error_msg, config) - combo_type = self.get_combination_type(config) - - self.error_categories[category].append({ - "error": error_msg, - "config": config, - "library": library, - "combination": combo_type, - }) - - self.combination_stats[combo_type] += 1 - - def get_error_summary(self) -> Dict: - """Get summary of errors by category.""" - summary = {} - for category, errors in self.error_categories.items(): - summary[category] = { - "count": len(errors), - "examples": errors[:5], # First 5 examples - "unique_combinations": len(set(e["combination"] for e in errors)), - } - return summary - - -class ParallelCompatibilityTester: - """Test compatibility between Megatron and DeepSpeed for parallel configurations.""" - - def __init__(self): - self.results = { - "megatron_success": [], - "megatron_failures": [], - "deepspeed_success": [], - "deepspeed_failures": [], - "compatible": [], - "incompatible": [], - "megatron_only": [], - "deepspeed_only": [], - } - self.error_categorizer = ErrorCategorizer() - self.combination_stats = defaultdict( - lambda: { - "total": 0, - "megatron_success": 0, - "megatron_failures": 0, - "deepspeed_success": 0, - "deepspeed_failures": 0, - "compatible": 0, - "megatron_only": 0, - "deepspeed_only": 0, - "incompatible": 0, - }) - - def test_megatron_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: - """Test if a configuration works with Megatron. - - Returns: - (success, error_message, result_data) - """ - if not PARALLEL_STATE_AVAILABLE: - return False, "Megatron not available", None - - try: - # Check EP and CP constraint - if config["ep"] > 1 and config["cp"] > 1: - return False, "EP and CP cannot both be > 1 in Megatron", None - - # Create RankGenerator - rg = RankGenerator(tp=config["tp"], - ep=config["ep"], - dp=config["dp"], - pp=config["pp"], - cp=config["cp"], - order=config["order"]) - - # Test getting ranks for each dimension - result_data = { - "world_size": rg.world_size, - "tp_groups": rg.get_ranks("tp") if config["tp"] > 1 else [], - "dp_groups": rg.get_ranks("dp") if config["dp"] > 1 else [], - "pp_groups": rg.get_ranks("pp") if config["pp"] > 1 else [], - "cp_groups": rg.get_ranks("cp") if config["cp"] > 1 else [], - "ep_groups": rg.get_ranks("ep") if config["ep"] > 1 else [], - } - - # Test combined groups - if len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) > 1: - combined_token = config["order"] - result_data["combined_groups"] = rg.get_ranks(combined_token) - - return True, None, result_data - - except Exception as e: - return False, str(e), None - - def test_deepspeed_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: - """Test if a configuration is supported by DeepSpeed. - - Returns: - (supported, error_message, support_info) - """ - if not DEEPSPEED_AVAILABLE: - return False, "DeepSpeed not available", None - - support_info = { - "tp_supported": False, - "dp_supported": False, - "pp_supported": False, - "cp_supported": False, - "ep_supported": False, - "sp_supported": False, - "notes": [], - } - - # Check TP support - if config["tp"] > 1: - support_info["tp_supported"] = hasattr(ds_groups, 'get_tensor_model_parallel_group') - - # Check DP support - if config["dp"] > 1: - support_info["dp_supported"] = hasattr(ds_groups, 'get_data_parallel_group') - - # Check PP support - if config["pp"] > 1: - # DeepSpeed supports PP via mpu or pipe module - support_info["pp_supported"] = (hasattr(ds_groups, 'bwc_pipeline_parallel_world_size') - or self._check_module_exists('deepspeed.pipe')) - if not support_info["pp_supported"]: - support_info["notes"].append("PP requires mpu object or deepspeed.pipe module") - - # Check CP support - if config["cp"] > 1: - support_info["cp_supported"] = hasattr(ds_groups, 'get_context_parallel_group') - if not support_info["cp_supported"]: - support_info["notes"].append("CP not supported in DeepSpeed") - - # Check EP support - if config["ep"] > 1: - support_info["ep_supported"] = (hasattr(ds_groups, '_create_expert_and_data_parallel') - or hasattr(ds_groups, '_create_expert_data_and_model_parallel')) - - # Check SP support (DeepSpeed-specific) - support_info["sp_supported"] = hasattr(ds_sp, 'initialize_sequence_parallel') - - # Determine overall support - required_dims = [d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1] - supported_dims = [] - if config["tp"] > 1 and support_info["tp_supported"]: - supported_dims.append("tp") - if config["dp"] > 1 and support_info["dp_supported"]: - supported_dims.append("dp") - if config["pp"] > 1 and support_info["pp_supported"]: - supported_dims.append("pp") - if config["cp"] > 1 and support_info["cp_supported"]: - supported_dims.append("cp") - if config["ep"] > 1 and support_info["ep_supported"]: - supported_dims.append("ep") - - fully_supported = len(supported_dims) == len(required_dims) - - return fully_supported, None, support_info - - def _check_module_exists(self, module_name): - """Check if a module exists.""" - try: - __import__(module_name) - return True - except ImportError: - return False - - def _simulate_deepspeed_rank_generation(self, config: Dict) -> Optional[Dict]: - """Simulate DeepSpeed's rank generation logic based on code analysis. - - This attempts to replicate DeepSpeed's rank assignment logic for comparison. - """ - if not DEEPSPEED_AVAILABLE: - return None - - try: - world_size = config["world_size"] - result = {} - - # For TP+DP: DeepSpeed uses mesh_device which creates groups in a specific way - if config["tp"] > 1 and config["dp"] > 1 and config["pp"] == 1 and config["cp"] == 1 and config["ep"] == 1: - # DeepSpeed's _init_tp_mesh_device creates: - # TP groups: [0,1], [2,3], [4,5], ... (consecutive) - # DP groups: [0,2,4,...], [1,3,5,...] (strided) - tp_size = config["tp"] - dp_size = config["dp"] - - tp_groups = [] - for i in range(world_size // tp_size): - group = list(range(i * tp_size, (i + 1) * tp_size)) - tp_groups.append(group) - - dp_groups = [] - for i in range(tp_size): - group = list(range(i, world_size, tp_size)) - dp_groups.append(group) - - result["tp_groups"] = tp_groups - result["dp_groups"] = dp_groups - result["world_size"] = world_size - return result - - # For other combinations, we can't easily simulate without actual distributed setup - # But we can note that DeepSpeed supports it - return {"supported": True, "note": "Rank generation requires actual distributed setup"} - - except Exception as e: - return {"error": str(e)} - - def _compare_rank_groups(self, megatron_groups: List[List[int]], deepspeed_groups: List[List[int]]) -> Dict: - """Compare rank groups from Megatron and DeepSpeed. - - Returns: - Dict with comparison results - """ - comparison = {"same_structure": False, "same_ranks": False, "differences": []} - - if not megatron_groups or not deepspeed_groups: - return comparison - - # Check if same number of groups - if len(megatron_groups) != len(deepspeed_groups): - comparison["differences"].append( - f"Group count mismatch: Megatron={len(megatron_groups)}, DeepSpeed={len(deepspeed_groups)}") - return comparison - - # Check if same group sizes - megatron_sizes = [len(g) for g in megatron_groups] - deepspeed_sizes = [len(g) for g in deepspeed_groups] - if megatron_sizes != deepspeed_sizes: - comparison["differences"].append( - f"Group size mismatch: Megatron={megatron_sizes}, DeepSpeed={deepspeed_sizes}") - return comparison - - # Check if same ranks (order may differ) - megatron_ranks = set() - for group in megatron_groups: - megatron_ranks.update(group) - - deepspeed_ranks = set() - for group in deepspeed_groups: - deepspeed_ranks.update(group) - - if megatron_ranks != deepspeed_ranks: - comparison["differences"].append( - f"Rank set mismatch: Megatron={sorted(megatron_ranks)}, DeepSpeed={sorted(deepspeed_ranks)}") - return comparison - - # Check if same structure (same groups, possibly different order) - megatron_sets = [set(g) for g in megatron_groups] - deepspeed_sets = [set(g) for g in deepspeed_groups] - - if sorted(megatron_sets, key=lambda x: min(x)) == sorted(deepspeed_sets, key=lambda x: min(x)): - comparison["same_structure"] = True - comparison["same_ranks"] = True - else: - comparison["differences"].append("Group structure differs (same ranks but different grouping)") - - return comparison - - def test_config_compatibility(self, config: Dict): - """Test compatibility of a configuration between both libraries.""" - # Get combination type for statistics - combo_type = self.error_categorizer.get_combination_type(config) - self.combination_stats[combo_type]["total"] += 1 - - # Test Megatron - megatron_success, megatron_error, megatron_result = self.test_megatron_config(config) - - # Test DeepSpeed - deepspeed_success, deepspeed_error, deepspeed_support = self.test_deepspeed_config(config) - - # Record errors in categorizer - if not megatron_success and megatron_error: - self.error_categorizer.record_error(megatron_error, config, "Megatron") - self.combination_stats[combo_type]["megatron_failures"] += 1 - else: - self.combination_stats[combo_type]["megatron_success"] += 1 - - if not deepspeed_success: - # Get error message from support_info notes - error_msg = deepspeed_support.get("notes", ["Not supported"])[0] if deepspeed_support else "Not supported" - self.error_categorizer.record_error(error_msg, config, "DeepSpeed") - self.combination_stats[combo_type]["deepspeed_failures"] += 1 - else: - self.combination_stats[combo_type]["deepspeed_success"] += 1 - - # Try to simulate DeepSpeed rank generation for comparison - deepspeed_simulated = None - if megatron_success and deepspeed_success: - deepspeed_simulated = self._simulate_deepspeed_rank_generation(config) - - # Compare rank generation if both succeeded and we have simulated results - rank_comparison = None - if megatron_success and deepspeed_success and deepspeed_simulated and "tp_groups" in deepspeed_simulated: - # Compare TP groups - if config["tp"] > 1 and "tp_groups" in megatron_result: - rank_comparison = self._compare_rank_groups(megatron_result["tp_groups"], - deepspeed_simulated.get("tp_groups", [])) - # Compare DP groups - if config["dp"] > 1 and "dp_groups" in megatron_result and not rank_comparison: - rank_comparison = self._compare_rank_groups(megatron_result["dp_groups"], - deepspeed_simulated.get("dp_groups", [])) - - # Record results - config_key = f"tp={config['tp']},dp={config['dp']},pp={config['pp']},cp={config['cp']},ep={config['ep']},order={config['order']}" - - if megatron_success: - self.results["megatron_success"].append(config_key) - else: - self.results["megatron_failures"].append({ - "config": config_key, - "error": megatron_error, - "combination": combo_type, - }) - - if deepspeed_success: - self.results["deepspeed_success"].append(config_key) - else: - self.results["deepspeed_failures"].append({ - "config": config_key, - "error": deepspeed_error, - "support_info": deepspeed_support, - "combination": combo_type, - }) - - # Determine compatibility and update stats - if megatron_success and deepspeed_success: - compat_entry = { - "config": config_key, - "megatron_result": megatron_result, - "deepspeed_support": deepspeed_support, - "combination": combo_type, - } - if rank_comparison: - compat_entry["rank_comparison"] = rank_comparison - if rank_comparison.get("same_structure"): - compat_entry["rank_match"] = True - else: - compat_entry["rank_match"] = False - compat_entry["rank_differences"] = rank_comparison.get("differences", []) - self.results["compatible"].append(compat_entry) - self.combination_stats[combo_type]["compatible"] += 1 - elif megatron_success and not deepspeed_success: - self.results["megatron_only"].append({ - "config": - config_key, - "megatron_result": - megatron_result, - "deepspeed_issue": - deepspeed_support.get("notes", []) if deepspeed_support else [], - "combination": - combo_type, - }) - self.combination_stats[combo_type]["megatron_only"] += 1 - elif not megatron_success and deepspeed_success: - self.results["deepspeed_only"].append({ - "config": config_key, - "megatron_error": megatron_error, - "deepspeed_support": deepspeed_support, - "combination": combo_type, - }) - self.combination_stats[combo_type]["deepspeed_only"] += 1 - else: - self.results["incompatible"].append({ - "config": - config_key, - "megatron_error": - megatron_error, - "deepspeed_issue": - deepspeed_support.get("notes", []) if deepspeed_support else [], - "combination": - combo_type, - }) - self.combination_stats[combo_type]["incompatible"] += 1 - - -class TestAutomatedParallelCombinations: - """Automated tests for parallel strategy combinations.""" - - def test_systematic_configurations(self): - """Test systematic configurations covering common cases.""" - generator = ParallelConfigGenerator(seed=42) - tester = ParallelCompatibilityTester() - - configs = generator.generate_systematic_configs(max_world_size=16) - - print("\n" + "=" * 80) - print("SYSTEMATIC CONFIGURATION TESTING") - print("=" * 80) - print(f"\nTesting {len(configs)} systematic configurations...") - - for i, config in enumerate(configs, 1): - print(f"\n[{i}/{len(configs)}] Testing: {config}") - tester.test_config_compatibility(config) - - self._print_results(tester, "Systematic") - self._generate_comprehensive_report(tester, "Systematic") - - def test_random_configurations(self): - """Test random configurations.""" - generator = ParallelConfigGenerator(seed=123) - tester = ParallelCompatibilityTester() - - configs = generator.generate_random_configs(count=1000, max_size=1024) - - print("\n" + "=" * 80) - print("RANDOM CONFIGURATION TESTING") - print("=" * 80) - print(f"\nTesting {len(configs)} random configurations...") - print(f"Max world size: 1024, Max parallel size per dimension: 32") - - for i, config in enumerate(configs, 1): - if i % 100 == 0: - print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") - tester.test_config_compatibility(config) - - self._print_results(tester, "Random") - self._generate_comprehensive_report(tester, "Random") - - def test_random_configurations_by_dimension(self): - """Test random configurations generated separately for each dimension.""" - generator = ParallelConfigGenerator(seed=789) - tester = ParallelCompatibilityTester() - - # Generate configs for each dimension separately - # This ensures balanced coverage across all dimensions - # Increased by 20x for comprehensive testing - counts_by_dimension = { - 1: 4000, # 1D: 4000 configs (200 * 20) - 2: 6000, # 2D: 6000 configs (300 * 20) - more because there are more 2D combinations - 3: 5000, # 3D: 5000 configs (250 * 20) - 4: 3000, # 4D: 3000 configs (150 * 20) - 5: 2000, # 5D: 2000 configs (100 * 20) - } - - print("\n" + "=" * 80) - print("RANDOM CONFIGURATION TESTING BY DIMENSION") - print("=" * 80) - print(f"\nGenerating configurations by dimension:") - for dim, count in counts_by_dimension.items(): - print(f" {dim}D: {count} configurations") - - configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32) - - print(f"\nTotal unique configurations generated: {len(configs)}") - print(f"Max world size: 1024, Parallel size range: 2-32") - - # Count configs by dimension - dim_counts = defaultdict(int) - for config in configs: - dim_count = len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) - dim_counts[dim_count] += 1 - - print("\nActual distribution:") - for dim in sorted(dim_counts.keys()): - print(f" {dim}D: {dim_counts[dim]} configurations") - - print(f"\nTesting {len(configs)} configurations...") - - for i, config in enumerate(configs, 1): - # Update progress more frequently for large test sets - if i % 1000 == 0 or i == len(configs): - print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") - tester.test_config_compatibility(config) - - self._print_results(tester, "Random by Dimension") - self._generate_comprehensive_report(tester, "Random by Dimension") - - def test_edge_cases(self): - """Test edge cases and boundary conditions.""" - generator = ParallelConfigGenerator(seed=456) - tester = ParallelCompatibilityTester() - - # Edge cases - including larger sizes - edge_configs = [ - # Maximum dimensions - larger sizes - { - "tp": 8, - "dp": 8, - "pp": 8, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 512 - }, - { - "tp": 16, - "dp": 16, - "pp": 4, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 1024 - }, - # EP and CP conflict - { - "tp": 2, - "dp": 2, - "pp": 1, - "cp": 2, - "ep": 2, - "order": "tp-ep-dp", - "world_size": 8 - }, - { - "tp": 4, - "dp": 4, - "pp": 1, - "cp": 4, - "ep": 4, - "order": "tp-ep-dp", - "world_size": 64 - }, - # Single dimension - larger sizes - { - "tp": 1, - "dp": 1, - "pp": 64, - "cp": 1, - "ep": 1, - "order": "pp", - "world_size": 64 - }, - { - "tp": 128, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp", - "world_size": 128 - }, - { - "tp": 1, - "dp": 256, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp", - "world_size": 256 - }, - # All dimensions - larger sizes - { - "tp": 2, - "dp": 2, - "pp": 2, - "cp": 2, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 16 - }, - { - "tp": 4, - "dp": 4, - "pp": 4, - "cp": 4, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 256 - }, - # Different orders - { - "tp": 2, - "dp": 4, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp-tp", - "world_size": 8 - }, - { - "tp": 2, - "dp": 4, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp-dp", - "world_size": 8 - }, - { - "tp": 8, - "dp": 16, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp-tp", - "world_size": 128 - }, - { - "tp": 8, - "dp": 16, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp-dp", - "world_size": 128 - }, - # Large multi-dimensional - { - "tp": 8, - "dp": 8, - "pp": 4, - "cp": 1, - "ep": 1, - "order": "tp-pp-dp", - "world_size": 256 - }, - { - "tp": 4, - "dp": 8, - "pp": 8, - "cp": 1, - "ep": 1, - "order": "tp-pp-dp", - "world_size": 256 - }, - ] - - print("\n" + "=" * 80) - print("EDGE CASE TESTING") - print("=" * 80) - print(f"\nTesting {len(edge_configs)} edge case configurations...") - - for i, config in enumerate(edge_configs, 1): - print(f"\n[{i}/{len(edge_configs)}] Testing: {config}") - tester.test_config_compatibility(config) - - self._print_results(tester, "Edge Cases") - self._generate_comprehensive_report(tester, "Edge Cases") - - def _print_results(self, tester: ParallelCompatibilityTester, test_type: str): - """Print test results.""" - results = tester.results - - print("\n" + "=" * 80) - print(f"{test_type} TEST RESULTS") - print("=" * 80) - - print(f"\n✓ Megatron Success: {len(results['megatron_success'])}") - print(f"✗ Megatron Failures: {len(results['megatron_failures'])}") - if results['megatron_failures']: - print("\nMegatron Failures:") - for failure in results['megatron_failures'][:10]: # Show first 10 - print(f" - {failure['config']}: {failure['error']}") - if len(results['megatron_failures']) > 10: - print(f" ... and {len(results['megatron_failures']) - 10} more") - - print(f"\n✓ DeepSpeed Success: {len(results['deepspeed_success'])}") - print(f"✗ DeepSpeed Failures: {len(results['deepspeed_failures'])}") - if results['deepspeed_failures']: - print("\nDeepSpeed Failures:") - for failure in results['deepspeed_failures'][:10]: # Show first 10 - print(f" - {failure['config']}") - if failure.get('support_info'): - notes = failure['support_info'].get('notes', []) - if notes: - print(f" Notes: {', '.join(notes)}") - if len(results['deepspeed_failures']) > 10: - print(f" ... and {len(results['deepspeed_failures']) - 10} more") - - print(f"\n✓ Compatible (Both Support): {len(results['compatible'])}") - if results['compatible']: - print(" Examples:") - rank_matches = 0 - rank_mismatches = 0 - for item in results['compatible'][:10]: - if isinstance(item, dict): - config = item.get('config', 'Unknown') - rank_comp = item.get('rank_comparison') - if rank_comp: - if rank_comp.get('same_structure'): - print(f" - {config} ✓ Rank groups match") - rank_matches += 1 - else: - print(f" - {config} ⚠ Rank groups differ") - rank_mismatches += 1 - if rank_comp.get('differences'): - for diff in rank_comp['differences'][:2]: - print(f" {diff}") - else: - print(f" - {config}") - else: - print(f" - {item}") - if len(results['compatible']) > 10: - print(f" ... and {len(results['compatible']) - 10} more") - - if rank_matches > 0 or rank_mismatches > 0: - print(f"\n Rank Comparison Summary:") - print(f" Matches: {rank_matches}") - print(f" Mismatches: {rank_mismatches}") - print(f" (Note: Comparison only available for TP+DP combinations)") - - print(f"\n⚠ Megatron Only: {len(results['megatron_only'])}") - if results['megatron_only']: - print(" Examples:") - for item in results['megatron_only'][:5]: - print(f" - {item['config']}") - if item.get('deepspeed_issue'): - print(f" DeepSpeed issue: {', '.join(item['deepspeed_issue'])}") - if len(results['megatron_only']) > 5: - print(f" ... and {len(results['megatron_only']) - 5} more") - - print(f"\n→ DeepSpeed Only: {len(results['deepspeed_only'])}") - if results['deepspeed_only']: - print(" Examples:") - for item in results['deepspeed_only'][:5]: - print(f" - {item['config']}") - print(f" Megatron error: {item['megatron_error']}") - if len(results['deepspeed_only']) > 5: - print(f" ... and {len(results['deepspeed_only']) - 5} more") - - print(f"\n✗ Incompatible (Neither Support): {len(results['incompatible'])}") - if results['incompatible']: - print(" Examples:") - for item in results['incompatible'][:5]: - print(f" - {item['config']}") - print(f" Megatron: {item['megatron_error']}") - if len(results['incompatible']) > 5: - print(f" ... and {len(results['incompatible']) - 5} more") - - print("\n" + "=" * 80) - - def _generate_comprehensive_report(self, tester: ParallelCompatibilityTester, test_type: str): - """Generate comprehensive test report with error categorization and combination statistics.""" - results = tester.results - error_summary = tester.error_categorizer.get_error_summary() - combo_stats = tester.combination_stats - - print("\n" + "=" * 80) - print(f"{test_type} COMPREHENSIVE TEST REPORT") - print("=" * 80) - - # Overall statistics - print("\n" + "-" * 80) - print("OVERALL STATISTICS") - print("-" * 80) - total_tested = (len(results['megatron_success']) + len(results['megatron_failures']) + - len(results['deepspeed_success']) + len(results['deepspeed_failures'])) - print(f"Total Configurations Tested: {total_tested}") - print( - f" Megatron Success: {len(results['megatron_success'])} ({len(results['megatron_success'])/total_tested*100:.1f}%)" - ) - print( - f" Megatron Failures: {len(results['megatron_failures'])} ({len(results['megatron_failures'])/total_tested*100:.1f}%)" - ) - print( - f" DeepSpeed Success: {len(results['deepspeed_success'])} ({len(results['deepspeed_success'])/total_tested*100:.1f}%)" - ) - print( - f" DeepSpeed Failures: {len(results['deepspeed_failures'])} ({len(results['deepspeed_failures'])/total_tested*100:.1f}%)" - ) - print(f" Compatible: {len(results['compatible'])} ({len(results['compatible'])/total_tested*100:.1f}%)") - print( - f" Megatron Only: {len(results['megatron_only'])} ({len(results['megatron_only'])/total_tested*100:.1f}%)" - ) - print( - f" DeepSpeed Only: {len(results['deepspeed_only'])} ({len(results['deepspeed_only'])/total_tested*100:.1f}%)" - ) - print(f" Incompatible: {len(results['incompatible'])} ({len(results['incompatible'])/total_tested*100:.1f}%)") - - # Error categorization - print("\n" + "-" * 80) - print("ERROR CATEGORIZATION (Aggregated by Type)") - print("-" * 80) - for category, summary in sorted(error_summary.items(), key=lambda x: x[1]['count'], reverse=True): - print(f"\n{category}: {summary['count']} occurrences") - print(f" Affects {summary['unique_combinations']} unique combination types") - print(f" Examples:") - for example in summary['examples'][:3]: - combo = example.get('combination', 'Unknown') - lib = example.get('library', 'Unknown') - print(f" - {combo} ({lib}): {example['error'][:80]}") - if len(summary['examples']) > 3: - print(f" ... and {len(summary['examples']) - 3} more examples") - - # Combination type statistics - print("\n" + "-" * 80) - print("COMBINATION TYPE STATISTICS") - print("-" * 80) - print( - f"{'Combination':<20} {'Total':<8} {'M-Succ':<8} {'M-Fail':<8} {'DS-Succ':<8} {'DS-Fail':<8} {'Compat':<8} {'M-Only':<8} {'DS-Only':<8} {'Incomp':<8}" - ) - print("-" * 100) - - # Sort by total count - sorted_combos = sorted(combo_stats.items(), key=lambda x: x[1]['total'], reverse=True) - for combo_type, stats in sorted_combos: - if stats['total'] > 0: - print(f"{combo_type:<20} {stats['total']:<8} {stats['megatron_success']:<8} " - f"{stats['megatron_failures']:<8} {stats['deepspeed_success']:<8} " - f"{stats['deepspeed_failures']:<8} {stats['compatible']:<8} " - f"{stats['megatron_only']:<8} {stats['deepspeed_only']:<8} " - f"{stats['incompatible']:<8}") - - # Detailed combination analysis - print("\n" + "-" * 80) - print("DETAILED COMBINATION ANALYSIS") - print("-" * 80) - - # Group by number of dimensions - by_dimension_count = defaultdict(list) - for combo_type, stats in combo_stats.items(): - dim_count = len([c for c in combo_type.split('+') if c != 'NONE']) - by_dimension_count[dim_count].append((combo_type, stats)) - - for dim_count in sorted(by_dimension_count.keys()): - print(f"\n{dim_count}-Dimensional Combinations:") - combos = sorted(by_dimension_count[dim_count], key=lambda x: x[1]['total'], reverse=True) - for combo_type, stats in combos[:10]: # Show top 10 - if stats['total'] > 0: - compat_rate = (stats['compatible'] / stats['total'] * 100) if stats['total'] > 0 else 0 - print(f" {combo_type}:") - print(f" Total: {stats['total']}, Compatible: {stats['compatible']} ({compat_rate:.1f}%)") - print(f" Megatron: {stats['megatron_success']} success, {stats['megatron_failures']} failures") - print( - f" DeepSpeed: {stats['deepspeed_success']} success, {stats['deepspeed_failures']} failures") - if len(combos) > 10: - print(f" ... and {len(combos) - 10} more {dim_count}-dimensional combinations") - - print("\n" + "=" * 80) - - def test_cp_vs_sp_compatibility_by_dimension(self): - """Test CP vs SP compatibility using the same config generation as test_random_configurations_by_dimension. - - This test: - 1. Uses parallel_state_refactored with CP - 2. Uses DeepSpeed with SP - 3. Compares CP rank groups with SP rank groups to see if they match - """ - generator = ParallelConfigGenerator(seed=789) - - # Use the same configuration generation as test_random_configurations_by_dimension - counts_by_dimension = { - 1: 4000, # 1D: 4000 configs - 2: 6000, # 2D: 6000 configs - 3: 5000, # 3D: 5000 configs - 4: 3000, # 4D: 3000 configs - 5: 2000, # 5D: 2000 configs - } - - print("\n" + "=" * 80) - print("CP vs SP COMPATIBILITY TESTING BY DIMENSION") - print("=" * 80) - print(f"\nGenerating configurations by dimension:") - for dim, count in counts_by_dimension.items(): - print(f" {dim}D: {count} configurations") - - configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32) - - # Filter to only include configs with CP > 1 and EP == 1 (EP and CP cannot both be > 1) - configs_with_cp = [c for c in configs if c["cp"] > 1 and c["ep"] == 1] - - print(f"\nTotal unique configurations generated: {len(configs)}") - print(f"Configurations with CP > 1 and EP == 1: {len(configs_with_cp)}") - print(f"Max world size: 1024, Parallel size range: 2-32") - - # Test CP vs SP compatibility - results = { - "total_tested": 0, - "cp_groups_generated": 0, - "sp_groups_generated": 0, - "rank_groups_match": 0, - "rank_groups_differ": 0, - "errors": 0, - "match_details": [], - "differ_details": [], - } - - combination_stats = defaultdict(lambda: { - "total": 0, - "match": 0, - "differ": 0, - "errors": 0, - }) - - print(f"\nTesting {len(configs_with_cp)} configurations for CP vs SP compatibility...") - - for i, config in enumerate(configs_with_cp, 1): - if i % 1000 == 0 or i == len(configs_with_cp): - print(f"Progress: {i}/{len(configs_with_cp)} ({(i/len(configs_with_cp)*100):.1f}%)") - - results["total_tested"] += 1 - - # Get combination type - combo_type = self._get_combination_type_for_cp_sp(config) - combination_stats[combo_type]["total"] += 1 - - try: - # Get CP rank groups from Megatron - if not PARALLEL_STATE_AVAILABLE: - results["errors"] += 1 - combination_stats[combo_type]["errors"] += 1 - continue - - rg = RankGenerator(tp=config["tp"], - ep=config["ep"], - dp=config["dp"], - pp=config["pp"], - cp=config["cp"], - order=config["order"]) - - cp_groups = rg.get_ranks("cp") - if cp_groups: - results["cp_groups_generated"] += 1 - - # Simulate SP rank groups from DeepSpeed - # DeepSpeed SP creates consecutive rank groups - sp_groups = self._simulate_deepspeed_sp_groups(config["world_size"], config["cp"]) - if sp_groups: - results["sp_groups_generated"] += 1 - - # Compare CP and SP groups - if self._compare_cp_sp_groups(cp_groups, sp_groups): - results["rank_groups_match"] += 1 - combination_stats[combo_type]["match"] += 1 - results["match_details"].append(config) - else: - results["rank_groups_differ"] += 1 - combination_stats[combo_type]["differ"] += 1 - results["differ_details"].append({ - "config": config, - "cp_groups": cp_groups, - "sp_groups": sp_groups, - }) - - except Exception as e: - results["errors"] += 1 - combination_stats[combo_type]["errors"] += 1 - - # Generate report - self._generate_cp_vs_sp_report(results, combination_stats) - - def _simulate_deepspeed_sp_groups(self, world_size: int, sp_size: int) -> List[List[int]]: - """Simulate DeepSpeed's SP rank group generation. - - DeepSpeed SP creates groups as consecutive ranks: - - Group 0: [0, 1, ..., sp_size-1] - - Group 1: [sp_size, sp_size+1, ..., 2*sp_size-1] - - etc. - """ - if sp_size <= 1 or world_size % sp_size != 0: - return [] - - num_groups = world_size // sp_size - groups = [] - for i in range(num_groups): - group = list(range(i * sp_size, (i + 1) * sp_size)) - groups.append(group) - - return groups - - def _compare_cp_sp_groups(self, cp_groups: List[List[int]], sp_groups: List[List[int]]) -> bool: - """Compare CP and SP rank groups to see if they match.""" - if not cp_groups and not sp_groups: - return True - - if not cp_groups or not sp_groups: - return False - - if len(cp_groups) != len(sp_groups): - return False - - # Check if all CP groups have a matching SP group (order may differ) - cp_sets = [set(g) for g in cp_groups] - sp_sets = [set(g) for g in sp_groups] - - # Check if all CP groups match SP groups - for cp_set in cp_sets: - found = False - for sp_set in sp_sets: - if cp_set == sp_set: - found = True - break - if not found: - return False - - # Check if all SP groups match CP groups - for sp_set in sp_sets: - found = False - for cp_set in cp_sets: - if sp_set == cp_set: - found = True - break - if not found: - return False - - return True - - def _get_combination_type_for_cp_sp(self, config: Dict) -> str: - """Get combination type string for CP vs SP testing.""" - dims = [] - if config["tp"] > 1: - dims.append("TP") - if config["dp"] > 1: - dims.append("DP") - if config["pp"] > 1: - dims.append("PP") - if config["cp"] > 1: - dims.append("CP") - # Note: EP is always 1 in this test - - if not dims: - return "NONE" - - return "+".join(sorted(dims)) - - def _generate_cp_vs_sp_report(self, results: Dict, combination_stats: Dict): - """Generate comprehensive CP vs SP compatibility report.""" - print("\n" + "=" * 80) - print("CP vs SP COMPATIBILITY TEST REPORT") - print("=" * 80) - - # Overall statistics - print("\n" + "-" * 80) - print("OVERALL STATISTICS") - print("-" * 80) - print(f"Total Configurations Tested: {results['total_tested']}") - print(f" CP Groups Generated: {results['cp_groups_generated']}") - print(f" SP Groups Generated: {results['sp_groups_generated']}") - print(f" Rank Groups Match: {results['rank_groups_match']}") - print(f" Rank Groups Differ: {results['rank_groups_differ']}") - print(f" Errors: {results['errors']}") - - if results['total_tested'] > 0: - match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 - print(f"\n Match Rate: {match_rate:.2f}%") - print(f" CP can replace SP in {match_rate:.2f}% of tested configurations") - - # Combination type statistics - print("\n" + "-" * 80) - print("COMBINATION TYPE STATISTICS") - print("-" * 80) - print(f"{'Combination':<20} {'Total':<8} {'Match':<8} {'Differ':<8} {'Errors':<8} {'Match Rate':<12}") - print("-" * 80) - - sorted_combos = sorted(combination_stats.items(), key=lambda x: x[1]['total'], reverse=True) - for combo_type, stats in sorted_combos: - if stats['total'] > 0: - match_rate = (stats['match'] / stats['total'] * 100) if stats['total'] > 0 else 0 - print(f"{combo_type:<20} {stats['total']:<8} {stats['match']:<8} " - f"{stats['differ']:<8} {stats['errors']:<8} {match_rate:.1f}%") - - # Examples of matching configurations - print("\n" + "-" * 80) - print("EXAMPLES OF MATCHING CONFIGURATIONS (CP can replace SP)") - print("-" * 80) - for i, config in enumerate(results['match_details'][:10], 1): - print(f"{i}. {config}") - print(f" CP size: {config['cp']}, Order: {config['order']}") - - if len(results['match_details']) > 10: - print(f"\n... and {len(results['match_details']) - 10} more matching configurations") - - # Examples of differing configurations - if results['differ_details']: - print("\n" + "-" * 80) - print("EXAMPLES OF DIFFERING CONFIGURATIONS (CP cannot replace SP)") - print("-" * 80) - for i, item in enumerate(results['differ_details'][:10], 1): - config = item['config'] - cp_groups = item['cp_groups'] - sp_groups = item['sp_groups'] - print(f"{i}. {config}") - print(f" CP size: {config['cp']}, Order: {config['order']}") - print(f" CP groups count: {len(cp_groups)}, SP groups count: {len(sp_groups)}") - if cp_groups and sp_groups: - print(f" CP first group: {cp_groups[0]}") - print(f" SP first group: {sp_groups[0]}") - - if len(results['differ_details']) > 10: - print(f"\n... and {len(results['differ_details']) - 10} more differing configurations") - - # Conclusion - print("\n" + "=" * 80) - print("CONCLUSION") - print("=" * 80) - if results['rank_groups_match'] > 0: - match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 - print(f"\n✓ CP can replace SP in {match_rate:.2f}% of tested configurations") - print( - f" - {results['rank_groups_match']} out of {results['total_tested']} configurations have matching rank groups" - ) - else: - print("\n✗ CP cannot replace SP in any of the tested configurations") - - if results['rank_groups_differ'] > 0: - print(f"\n⚠ {results['rank_groups_differ']} configurations have different rank groups") - print(" - These configurations may require special handling when migrating from CP to SP") - - print("\n" + "=" * 80) - - def test_comprehensive_automated_testing(self): - """Comprehensive automated testing with all test types.""" - print("\n" + "=" * 80) - print("COMPREHENSIVE AUTOMATED PARALLEL COMBINATION TESTING") - print("=" * 80) - - # Create a combined tester for overall report - combined_tester = ParallelCompatibilityTester() - - # Run all test types and accumulate results - print("\n[1/3] Running systematic configurations...") - generator1 = ParallelConfigGenerator(seed=42) - configs1 = generator1.generate_systematic_configs(max_world_size=512) - print(f"Testing {len(configs1)} systematic configurations...") - for i, config in enumerate(configs1, 1): - if i % 50 == 0 or i == len(configs1): - print(f" Progress: {i}/{len(configs1)}") - combined_tester.test_config_compatibility(config) - - print("\n[2/4] Running random configurations by dimension...") - generator2 = ParallelConfigGenerator(seed=789) - # Increased by 20x for comprehensive testing - counts_by_dimension = { - 1: 4000, # 1D: 4000 configs (200 * 20) - 2: 6000, # 2D: 6000 configs (300 * 20) - 3: 5000, # 3D: 5000 configs (250 * 20) - 4: 3000, # 4D: 3000 configs (150 * 20) - 5: 2000, # 5D: 2000 configs (100 * 20) - } - configs2 = generator2.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32) - print(f"Testing {len(configs2)} random configurations (balanced by dimension)...") - print(f"Max world size: 1024, Parallel size range: 2-32") - for i, config in enumerate(configs2, 1): - # Update progress more frequently for large test sets - if i % 1000 == 0 or i == len(configs2): - print(f" Progress: {i}/{len(configs2)} ({(i/len(configs2)*100):.1f}%)") - combined_tester.test_config_compatibility(config) - - print("\n[3/4] Running additional random configurations...") - generator3 = ParallelConfigGenerator(seed=123) - # Increased by 20x: 500 * 20 = 10000 - configs3 = generator3.generate_random_configs(count=10000, max_size=1024) - print(f"Testing {len(configs3)} additional random configurations...") - for i, config in enumerate(configs3, 1): - # Update progress more frequently for large test sets - if i % 1000 == 0 or i == len(configs3): - print(f" Progress: {i}/{len(configs3)} ({(i/len(configs3)*100):.1f}%)") - combined_tester.test_config_compatibility(config) - - print("\n[4/4] Running edge cases...") - edge_configs = [ - { - "tp": 8, - "dp": 8, - "pp": 8, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 512 - }, - { - "tp": 16, - "dp": 16, - "pp": 4, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 1024 - }, - { - "tp": 2, - "dp": 2, - "pp": 1, - "cp": 2, - "ep": 2, - "order": "tp-ep-dp", - "world_size": 8 - }, - { - "tp": 4, - "dp": 4, - "pp": 1, - "cp": 4, - "ep": 4, - "order": "tp-ep-dp", - "world_size": 64 - }, - { - "tp": 1, - "dp": 1, - "pp": 64, - "cp": 1, - "ep": 1, - "order": "pp", - "world_size": 64 - }, - { - "tp": 128, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp", - "world_size": 128 - }, - { - "tp": 1, - "dp": 256, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp", - "world_size": 256 - }, - { - "tp": 2, - "dp": 2, - "pp": 2, - "cp": 2, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 16 - }, - { - "tp": 4, - "dp": 4, - "pp": 4, - "cp": 4, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 256 - }, - ] - print(f"Testing {len(edge_configs)} edge case configurations...") - for config in edge_configs: - combined_tester.test_config_compatibility(config) - - # Generate comprehensive report - print("\n" + "=" * 80) - print("COMPREHENSIVE FINAL REPORT") - print("=" * 80) - self._generate_comprehensive_report(combined_tester, "COMPREHENSIVE") - - print("\n" + "=" * 80) - print("ALL TESTS COMPLETED") - print("=" * 80) - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) From f7bd2dcde4224cac8145229cee605f75764fc661 Mon Sep 17 00:00:00 2001 From: yunqing Date: Wed, 11 Feb 2026 11:55:07 +0800 Subject: [PATCH 16/23] fix: filter unsupported params in initialize_parallel_state_from_config and add integration tests Add a supported_params whitelist to prevent unsupported parameters (nccl_communicator_config_path, high_priority_stream_groups) from being passed to initialize_model_parallel. Also add comprehensive integration tests for ParallelState as mpu with 5-batch training loops. Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state_deepspeed.py | 18 +- .../utils/test_parallel_state_deepspeed.py | 459 ++++++++++++++++++ 2 files changed, 473 insertions(+), 4 deletions(-) create mode 100644 tests/unit/utils/test_parallel_state_deepspeed.py diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 0a2d71b112c6..e1342021ee5c 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -894,14 +894,24 @@ def get_value(param_name, param_value, config_key, default_value): # Remove None values for optional parameters (except those that can be None) # Keep None for: virtual_pipeline_model_parallel_size, pipeline_model_parallel_comm_backend, - # hierarchical_context_parallel_sizes, expert_tensor_parallel_size, nccl_communicator_config_path, - # high_priority_stream_groups + # hierarchical_context_parallel_sizes, expert_tensor_parallel_size + # Note: nccl_communicator_config_path and high_priority_stream_groups are not supported by initialize_model_parallel filtered_kwargs = {} + supported_params = { + "tensor_model_parallel_size", "pipeline_model_parallel_size", "virtual_pipeline_model_parallel_size", + "pipeline_model_parallel_comm_backend", "context_parallel_size", "sequence_parallel_size", + "hierarchical_context_parallel_sizes", "expert_model_parallel_size", "num_distributed_optimizer_instances", + "expert_tensor_parallel_size", "distributed_timeout_minutes", "order", "create_gloo_process_groups" + } + for key, value in init_kwargs.items(): + # Skip unsupported parameters + if key not in supported_params: + continue + # Keep None for parameters that can be None if value is not None or key in [ "virtual_pipeline_model_parallel_size", "pipeline_model_parallel_comm_backend", - "hierarchical_context_parallel_sizes", "expert_tensor_parallel_size", "nccl_communicator_config_path", - "high_priority_stream_groups" + "hierarchical_context_parallel_sizes", "expert_tensor_parallel_size" ]: filtered_kwargs[key] = value diff --git a/tests/unit/utils/test_parallel_state_deepspeed.py b/tests/unit/utils/test_parallel_state_deepspeed.py new file mode 100644 index 000000000000..d77793716382 --- /dev/null +++ b/tests/unit/utils/test_parallel_state_deepspeed.py @@ -0,0 +1,459 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Integration tests for using ParallelState as mpu in deepspeed.initialize() + +Tests the full workflow: +1. Initialize parallel_state_deepspeed with parallel configurations +2. Pass it as mpu parameter to deepspeed.initialize() +3. Verify DeepSpeed Engine correctly uses the parallel state +""" + +import pytest +import torch +import deepspeed +import deepspeed.comm as dist +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + + +class TestParallelStateAsMPU(DistributedTest): + """Test parallel_state_deepspeed as mpu parameter in deepspeed.initialize()""" + + world_size = 8 + + def _get_base_config(self): + """Get base DeepSpeed config""" + return {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} + + def _verify_mpu_integration(self, engine, mpu, expected_tp=1, expected_pp=1, expected_sp=1): + """Verify mpu is correctly integrated in engine""" + # 1. Engine holds mpu reference + assert engine.mpu == mpu + + # 2. Parallel configuration is correct + assert mpu.get_tensor_model_parallel_world_size() == expected_tp + assert mpu.get_pipeline_model_parallel_world_size() == expected_pp + + # 3. Data parallel world size is correctly calculated + world_size = dist.get_world_size() + expected_dp = world_size // (expected_tp * expected_pp * expected_sp) + assert mpu.get_data_parallel_world_size() == expected_dp + + # 4. Config uses mpu for world_size calculation + assert engine.config.world_size == expected_dp + + return expected_dp + + def test_basic_mpu_usage(self): + """Test basic mpu parameter usage with TP and PP""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_basic") + state.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) + + config = self._get_base_config() + model = SimpleModel(hidden_dim=16) + + # Pass parallel_state module as mpu (the module provides compatibility layer) + with ps.set_current_parallel_state("test_basic"): + engine, optimizer, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify integration + with ps.set_current_parallel_state("test_basic"): + self._verify_mpu_integration(engine, ps, expected_tp=2, expected_pp=2) + + # Verify optimizer is created + assert optimizer is not None + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + assert loss is not None + engine.backward(loss) + engine.step() + + def test_config_driven_mpu(self): + """Test mpu initialized from config with sequence_parallel_size""" + from deepspeed.utils import parallel_state_deepspeed as ps + + config = { + "train_batch_size": 8, + "sequence_parallel_size": 2, + "order": "tp-sp-dp-pp", # Need to specify order when using sp + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + # Initialize from config + ps.initialize_parallel_state_from_config(config, name="config_driven_test") + + model = SimpleModel(hidden_dim=16) + + # Set current instance + ps.set_current_parallel_state("config_driven_test") + + engine, _, _, _ = deepspeed.initialize(model=model, config=config, mpu=ps, model_parameters=model.parameters()) + + # Verify SP group is created + with ps.set_current_parallel_state("config_driven_test"): + sp_world_size = ps.get_sequence_parallel_world_size() + assert sp_world_size == 2 + + # Verify integration + with ps.set_current_parallel_state("config_driven_test"): + self._verify_mpu_integration(engine, ps, expected_sp=2) + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + + def test_multi_instance_mpu(self): + """Test multiple named instances as mpu (Actor-Critic scenario)""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Initialize Actor with TP=2 + actor_state = ps.get_parallel_state_instance("actor") + actor_state.initialize_model_parallel(tensor_model_parallel_size=2) + + # Initialize Critic with TP=1 (no parallelism) + critic_state = ps.get_parallel_state_instance("critic") + critic_state.initialize_model_parallel(tensor_model_parallel_size=1) + + config = self._get_base_config() + + # Create Actor engine + actor_model = SimpleModel(hidden_dim=16) + with ps.set_current_parallel_state("actor"): + actor_engine, _, _, _ = deepspeed.initialize(model=actor_model, + config=config, + mpu=ps, + model_parameters=actor_model.parameters()) + + # Create Critic engine + critic_model = SimpleModel(hidden_dim=16) + with ps.set_current_parallel_state("critic"): + critic_engine, _, _, _ = deepspeed.initialize(model=critic_model, + config=config, + mpu=ps, + model_parameters=critic_model.parameters()) + + # Verify Actor uses TP=2 + with ps.set_current_parallel_state("actor"): + assert ps.get_tensor_model_parallel_world_size() == 2 + assert actor_engine.mpu == ps + + # Verify Critic uses TP=1 + with ps.set_current_parallel_state("critic"): + assert ps.get_tensor_model_parallel_world_size() == 1 + assert critic_engine.mpu == ps + + # Test training for 5 batches on both engines + actor_loader = random_dataloader(model=actor_engine.module, total_samples=20, hidden_dim=16, device=actor_engine.device) + critic_loader = random_dataloader(model=critic_engine.module, total_samples=20, hidden_dim=16, device=critic_engine.device) + for i, (actor_batch, critic_batch) in enumerate(zip(actor_loader, critic_loader)): + if i >= 5: + break + actor_loss = actor_engine(actor_batch[0], actor_batch[1]) + assert actor_loss is not None + actor_engine.backward(actor_loss) + actor_engine.step() + + critic_loss = critic_engine(critic_batch[0], critic_batch[1]) + assert critic_loss is not None + critic_engine.backward(critic_loss) + critic_engine.step() + + def test_mpu_with_zero_stage1(self): + """Test mpu integration with ZeRO Stage 1""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_zero") + state.initialize_model_parallel(tensor_model_parallel_size=2) + + config = { + "train_batch_size": 8, + "zero_optimization": { + "stage": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_zero"): + engine, optimizer, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify ZeRO optimizer is created + assert optimizer is not None + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + assert isinstance(optimizer, DeepSpeedZeroOptimizer) + + # Verify mpu integration + with ps.set_current_parallel_state("test_zero"): + self._verify_mpu_integration(engine, ps, expected_tp=2) + + # Verify optimizer uses correct DP group + assert optimizer.mpu == ps + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + + def test_deepspeed_config_uses_mpu(self): + """Test DeepSpeedConfig correctly uses mpu for world_size calculation""" + from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.runtime.config import DeepSpeedConfig + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_config") + state.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) + + config_dict = self._get_base_config() + + # Create DeepSpeedConfig with mpu + with ps.set_current_parallel_state("test_config"): + ds_config = DeepSpeedConfig(config_dict, mpu=ps) + + # Verify world_size calculation uses mpu + expected_dp = dist.get_world_size() // (2 * 2) + assert ds_config.world_size == expected_dp + + # Verify it matches mpu's calculation + with ps.set_current_parallel_state("test_config"): + assert ds_config.world_size == ps.get_data_parallel_world_size() + + def test_mpu_without_parallelism(self): + """Test mpu with all parallelism dimensions = 1 (no parallelism)""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_no_parallel") + state.initialize_model_parallel() + + config = self._get_base_config() + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_no_parallel"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify all dimensions are 1 + with ps.set_current_parallel_state("test_no_parallel"): + assert ps.get_tensor_model_parallel_world_size() == 1 + assert ps.get_pipeline_model_parallel_world_size() == 1 + + # DP should equal world_size + assert ps.get_data_parallel_world_size() == dist.get_world_size() + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + + def test_mpu_with_different_orders(self): + """Test mpu with different parallel dimension orders""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_order") + state.initialize_model_parallel(tensor_model_parallel_size=2, + expert_model_parallel_size=2, + order="tp-ep-dp-pp") + + config = self._get_base_config() + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_order"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify parallel configuration + with ps.set_current_parallel_state("test_order"): + assert ps.get_tensor_model_parallel_world_size() == 2 + assert ps.get_expert_model_parallel_world_size() == 2 + + # Verify DP world_size: world_size / (tp * ep) + expected_dp = dist.get_world_size() // (2 * 2) + assert ps.get_data_parallel_world_size() == expected_dp + + +class TestParallelStateConfigPriority(DistributedTest): + """Test configuration priority: params > config > defaults""" + + world_size = 4 + + def test_param_overrides_config(self): + """Function parameter should override config value""" + from deepspeed.utils import parallel_state_deepspeed as ps + + config = { + "train_batch_size": 4, + "sequence_parallel_size": 2, # Config says 2 + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + # Override with param: sp=1 + ps.initialize_parallel_state_from_config( + config, + name="param_override_test", + sequence_parallel_size=1 # Parameter overrides config + ) + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("param_override_test"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # With sp=1, SP group should not have special effect + assert engine is not None + assert engine.mpu == ps + + def test_config_overrides_default(self): + """Config value should override default value""" + from deepspeed.utils import parallel_state_deepspeed as ps + + config = { + "train_batch_size": 4, + "sequence_parallel_size": 2, # Override default (1) + "order": "tp-sp-dp-pp", # Need to specify order when using sp + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001 + } + } + } + + # Don't pass sequence_parallel_size parameter + ps.initialize_parallel_state_from_config(config, name="config_override_test") + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("config_override_test"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Verify SP is configured from config + # Since sp_size = 2, SP group should be initialized + with ps.set_current_parallel_state("config_override_test"): + sp_world_size = ps.get_sequence_parallel_world_size() + assert sp_world_size == 2 + + +class TestParallelStateValidation(DistributedTest): + """Test validation and error handling""" + + world_size = 4 + + def test_context_parallel_not_supported(self): + """Test that CP > 1 raises NotImplementedError""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # CP > 1 should raise error via initialize_parallel_state_from_config + with pytest.raises(NotImplementedError, match="does not support context_parallel_size"): + ps.initialize_parallel_state_from_config({"context_parallel_size": 2}, name="cp_test") + + def test_hierarchical_cp_not_supported(self): + """Test that hierarchical CP raises NotImplementedError""" + from deepspeed.utils import parallel_state_deepspeed as ps + + with pytest.raises(NotImplementedError, match="does not support hierarchical_context_parallel_sizes"): + ps.initialize_parallel_state_from_config({"hierarchical_context_parallel_sizes": [2, 2]}, name="hcp_test") + + +class TestAllToAllGroupsWithMPU(DistributedTest): + """Test All-to-All groups initialization with mpu""" + + world_size = 8 + + def test_all_to_all_groups_with_mpu(self): + """Test All-to-All groups work with mpu in initialize""" + from deepspeed.utils import parallel_state_deepspeed as ps + + # Use named instance to avoid test interference + state = ps.get_parallel_state_instance("test_all_to_all") + state.initialize_model_parallel() + + config = {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} + + model = SimpleModel(hidden_dim=16) + + with ps.set_current_parallel_state("test_all_to_all"): + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=ps, + model_parameters=model.parameters()) + + # Initialize All-to-All groups + with ps.set_current_parallel_state("test_all_to_all"): + all_to_all_groups = ps.initialize_all_to_all_groups() + + # Verify groups are created + assert isinstance(all_to_all_groups, dict) + assert len(all_to_all_groups) > 0 + + # Test backward compatibility interface + with ps.set_current_parallel_state("test_all_to_all"): + compat_groups = ps._get_local_all_to_all_group() + assert compat_groups == all_to_all_groups + + # Test training for 5 batches + data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + for i, batch in enumerate(data_loader): + if i >= 5: + break + loss = engine(batch[0], batch[1]) + assert loss is not None + engine.backward(loss) + engine.step() From 537a8993e8a5b72b8b033c630f69ac7c767245fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 2 Mar 2026 19:21:54 +0800 Subject: [PATCH 17/23] fix: improve initialize_parallel_state_from_config for nested config and usability - Support nested dict config keys via dot-separated paths (e.g. "tensor_parallel.autotp_size") so autotp tp_size can be resolved from config automatically - Allow config_key to be a list of candidates tried in order - Remove unused param_name argument from get_value helper - Return the ParallelState instance so callers can use it as mpu directly, e.g. ps = initialize_parallel_state_from_config(config) Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 151 +++++++++----------- 1 file changed, 70 insertions(+), 81 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index e1342021ee5c..eb768fd83815 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -720,24 +720,25 @@ def initialize_parallel_state_from_config( order: Optional[str] = None, create_gloo_process_groups: Optional[bool] = None, high_priority_stream_groups: Optional[List[str]] = None, -) -> None: +) -> ParallelState: """Initialize parallel state from DeepSpeed config.json with optional parameter overrides. - This function reads parallelism configuration from the DeepSpeed config file - (top-level fields) and automatically initializes the ParallelState instance. - This allows code to work with both explicit initialization and config-based initialization. + Reads parallelism configuration from the DeepSpeed config (including nested dicts) + and initializes the ParallelState instance. Returns the instance so it can be used + directly as the ``mpu`` argument to ``deepspeed.initialize``. - Configuration priority: function parameters > config file values > default values (1) + Configuration priority: function parameters > config file values > default values + + Config keys support dot-separated paths for nested dicts. For example, + ``tensor_model_parallel_size`` is resolved from either the top-level key + ``"tensor_model_parallel_size"`` or the nested key ``"tensor_parallel.autotp_size"``. Args: config: Either a DeepSpeedConfig object or a config dictionary. - If DeepSpeedConfig, will access its _param_dict attribute. - If dict, will use it directly. name: Optional name of the parallel state instance to initialize. If None, initializes the default global instance. - - # Parallelism dimension parameters (override config if provided): - tensor_model_parallel_size: Size of tensor model parallel group. Default: 1 + tensor_model_parallel_size: Size of tensor model parallel group. Default: 1. + Also read from ``tensor_parallel.autotp_size`` in config. pipeline_model_parallel_size: Size of pipeline model parallel group. Default: 1 virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size. Default: None pipeline_model_parallel_comm_backend: Communication backend for pipeline. Default: None @@ -753,46 +754,35 @@ def initialize_parallel_state_from_config( create_gloo_process_groups: Whether to create Gloo process groups. Default: False high_priority_stream_groups: High priority stream groups. Default: None - Example config.json (using existing DeepSpeed config fields): - { - "train_batch_size": 8, - "sequence_parallel_size": 1, - "zero_optimization": { - "stage": 1 - } - } + Returns: + The initialized (or already-initialized) ParallelState instance. - Note: - - Currently only "sequence_parallel_size" can be read from config (existing field) - - Other parallelism parameters must be passed via function parameters or use defaults - - Context Parallel is NOT supported (cp must be 1) + Example usage:: - Example usage: - # Basic usage from config file: - from deepspeed import DeepSpeedConfig - ds_config = DeepSpeedConfig("config.json") - initialize_parallel_state_from_config(ds_config) + # Use return value as mpu: + ps = initialize_parallel_state_from_config(config_dict) + model, optimizer, _, _ = deepspeed.initialize( + model=model, model_parameters=model.parameters(), + config=config_dict, mpu=ps) + + # AutoTP config (nested dict): + config_dict = { + "tensor_parallel": {"autotp_size": 4}, + ... + } + ps = initialize_parallel_state_from_config(config_dict) # Override specific parameters: - initialize_parallel_state_from_config( + ps = initialize_parallel_state_from_config( ds_config, - tensor_model_parallel_size=4, # Override config value + tensor_model_parallel_size=4, expert_model_parallel_size=2 ) - # From config dictionary: - import json - with open("config.json") as f: - config_dict = json.load(f) - initialize_parallel_state_from_config(config_dict) - - # For named instances (RL scenarios): - initialize_parallel_state_from_config(ds_config, name="actor") - initialize_parallel_state_from_config( - ds_config, - name="critic", - tensor_model_parallel_size=2 # Override for critic - ) + # Named instances (RL scenarios): + actor_ps = initialize_parallel_state_from_config(ds_config, name="actor") + critic_ps = initialize_parallel_state_from_config( + ds_config, name="critic", tensor_model_parallel_size=2) """ # Extract config dictionary if hasattr(config, '_param_dict'): @@ -807,75 +797,74 @@ def initialize_parallel_state_from_config( # Get the parallel state instance ps = get_parallel_state_instance(name) - # Check if already initialized if ps.is_initialized(): - # Already initialized, skip - return + return ps # Import logging import logging logger = logging.getLogger(__name__) - # Helper function to get value with proper priority handling - # Priority: function parameter > config file value > default value - def get_value(param_name, param_value, config_key, default_value): + def _resolve_nested_key(d, dotted_key): + """Resolve a dot-separated key path in a nested dict. Returns (found, value).""" + keys = dotted_key.split(".") + cur = d + for k in keys: + if isinstance(cur, dict) and k in cur: + cur = cur[k] + else: + return False, None + return True, cur + + def get_value(param_value, config_key, default_value): """ - Get value with priority handling. + Get value with priority: function parameter > config value > default. - Priority: - 1. If function parameter is provided -> use parameter value - 2. If config file has the value -> use config value - 3. Otherwise -> use default value + config_key can be a single dot-separated string (e.g. "tensor_parallel.autotp_size") + or a list of candidate keys tried in order. """ - # Case 1: Function parameter provided if param_value is not None: return param_value - # Case 2: Config file has the key - if config_key in config_dict: - config_value = config_dict[config_key] - return config_value + candidates = config_key if isinstance(config_key, (list, tuple)) else [config_key] + for key in candidates: + found, value = _resolve_nested_key(config_dict, key) + if found: + return value - # Case 3: Use default return default_value - # Extract parameters with proper priority: function param > config value > default init_kwargs = { "tensor_model_parallel_size": - get_value("tensor_model_parallel_size", tensor_model_parallel_size, "tensor_model_parallel_size", 1), + get_value(tensor_model_parallel_size, + ["tensor_model_parallel_size", "tensor_parallel.autotp_size"], 1), "pipeline_model_parallel_size": - get_value("pipeline_model_parallel_size", pipeline_model_parallel_size, "pipeline_model_parallel_size", 1), + get_value(pipeline_model_parallel_size, "pipeline_model_parallel_size", 1), "virtual_pipeline_model_parallel_size": - get_value("virtual_pipeline_model_parallel_size", virtual_pipeline_model_parallel_size, - "virtual_pipeline_model_parallel_size", None), + get_value(virtual_pipeline_model_parallel_size, "virtual_pipeline_model_parallel_size", None), "pipeline_model_parallel_comm_backend": - get_value("pipeline_model_parallel_comm_backend", pipeline_model_parallel_comm_backend, - "pipeline_model_parallel_comm_backend", None), + get_value(pipeline_model_parallel_comm_backend, "pipeline_model_parallel_comm_backend", None), "context_parallel_size": - get_value("context_parallel_size", context_parallel_size, "context_parallel_size", 1), + get_value(context_parallel_size, "context_parallel_size", 1), "sequence_parallel_size": - get_value("sequence_parallel_size", sequence_parallel_size, "sequence_parallel_size", 1), + get_value(sequence_parallel_size, "sequence_parallel_size", 1), "hierarchical_context_parallel_sizes": - get_value("hierarchical_context_parallel_sizes", hierarchical_context_parallel_sizes, - "hierarchical_context_parallel_sizes", None), + get_value(hierarchical_context_parallel_sizes, "hierarchical_context_parallel_sizes", None), "expert_model_parallel_size": - get_value("expert_model_parallel_size", expert_model_parallel_size, "expert_model_parallel_size", 1), + get_value(expert_model_parallel_size, "expert_model_parallel_size", 1), "num_distributed_optimizer_instances": - get_value("num_distributed_optimizer_instances", num_distributed_optimizer_instances, - "num_distributed_optimizer_instances", 1), + get_value(num_distributed_optimizer_instances, "num_distributed_optimizer_instances", 1), "expert_tensor_parallel_size": - get_value("expert_tensor_parallel_size", expert_tensor_parallel_size, "expert_tensor_parallel_size", None), + get_value(expert_tensor_parallel_size, "expert_tensor_parallel_size", None), "nccl_communicator_config_path": - get_value("nccl_communicator_config_path", nccl_communicator_config_path, "nccl_communicator_config_path", - None), + get_value(nccl_communicator_config_path, "nccl_communicator_config_path", None), "distributed_timeout_minutes": - get_value("distributed_timeout_minutes", distributed_timeout_minutes, "distributed_timeout_minutes", 30), + get_value(distributed_timeout_minutes, "distributed_timeout_minutes", 30), "order": - get_value("order", order, "order", "tp-ep-dp-pp"), + get_value(order, "order", "tp-ep-dp-pp"), "create_gloo_process_groups": - get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", False), + get_value(create_gloo_process_groups, "create_gloo_process_groups", False), "high_priority_stream_groups": - get_value("high_priority_stream_groups", high_priority_stream_groups, "high_priority_stream_groups", None), + get_value(high_priority_stream_groups, "high_priority_stream_groups", None), } # Validate context_parallel_size @@ -915,5 +904,5 @@ def get_value(param_name, param_value, config_key, default_value): ]: filtered_kwargs[key] = value - # Initialize parallel state ps.initialize_model_parallel(**filtered_kwargs) + return ps From 856ed67cb87eb9da69352e928c590a87465e3099 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 2 Mar 2026 19:30:54 +0800 Subject: [PATCH 18/23] docs: add parallel state management documentation - Add "Built-in Parallel State Management" section to training.md covering basic usage, config-based initialization, multi-instance support for RL scenarios, and backward compatibility - Add "Parallel State Initialization" section to initialize.rst with API references for initialize_parallel_state_from_config and ParallelState class Signed-off-by: Jikang Mo --- docs/_pages/training.md | 94 ++++++++++++++++++++++++++++ docs/code-docs/source/initialize.rst | 38 +++++++++++ 2 files changed, 132 insertions(+) diff --git a/docs/_pages/training.md b/docs/_pages/training.md index e31651cc487a..bdae8b563807 100644 --- a/docs/_pages/training.md +++ b/docs/_pages/training.md @@ -244,6 +244,100 @@ mpu.get_data_parallel_group() mpu.get_data_parallel_world_size() ``` +### Built-in Parallel State Management + +DeepSpeed provides a built-in `ParallelState` class that implements the `mpu` interface +with Megatron-style process group management. It supports tensor parallelism (TP), +pipeline parallelism (PP), data parallelism (DP), sequence parallelism (SP), +context parallelism (CP), and expert parallelism (EP). + +#### Basic Usage + +You can initialize the parallel state either explicitly or from a DeepSpeed config: + +```python +from deepspeed.utils import parallel_state_deepspeed as ps + +# Option 1: Initialize from config dict (also works with DeepSpeedConfig objects) +config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": {"autotp_size": 4}, + "zero_optimization": {"stage": 1} +} +parallel_state = ps.initialize_parallel_state_from_config(config_dict) + +# The returned ParallelState can be passed directly as mpu +model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=parallel_state +) +``` + +```python +# Option 2: Initialize explicitly with parallelism dimensions +parallel_state = ps.get_parallel_state_instance() +parallel_state.initialize_model_parallel( + tensor_model_parallel_size=4, + pipeline_model_parallel_size=2, + sequence_parallel_size=1, +) +``` + +#### Configuration-based Initialization + +`initialize_parallel_state_from_config` resolves parallelism parameters with +the following priority: **function parameters > config values > defaults**. + +Config keys support dot-separated paths for nested dictionaries. For example, +`tensor_model_parallel_size` can be read from `"tensor_model_parallel_size"` at +the top level or `"tensor_parallel.autotp_size"` in a nested config. + +```python +from deepspeed.utils import parallel_state_deepspeed as ps + +# Override specific parameters while reading others from config +parallel_state = ps.initialize_parallel_state_from_config( + config_dict, + tensor_model_parallel_size=4, # Override config value + expert_model_parallel_size=2, # Override config value +) +``` + +#### Multiple Instances (RL Scenarios) + +In reinforcement learning scenarios where multiple models (e.g., actor and critic) +require different parallelism configurations, you can create named instances: + +```python +from deepspeed.utils import parallel_state_deepspeed as ps + +# Create separate parallel state instances +actor_ps = ps.initialize_parallel_state_from_config( + actor_config, name="actor", + tensor_model_parallel_size=4, +) +critic_ps = ps.initialize_parallel_state_from_config( + critic_config, name="critic", + tensor_model_parallel_size=2, +) + +# Use context manager to switch between instances +with ps.set_current_parallel_state("actor"): + dp_group = ps.get_data_parallel_group() # Uses actor's groups + +with ps.set_current_parallel_state("critic"): + dp_group = ps.get_data_parallel_group() # Uses critic's groups +``` + +#### Compatibility with Existing Code + +The module-level functions in `parallel_state_deepspeed` (such as +`get_data_parallel_group()`, `get_tensor_model_parallel_world_size()`, etc.) +operate on the current active `ParallelState` instance, preserving backward +compatibility with code written against the previous `groups.py` API. + ### Integration with Megatron-LM DeepSpeed is fully compatible with [Megatron](https://github.com/NVIDIA/Megatron-LM). Please see the [Megatron-LM tutorial](/tutorials/megatron/) for details. diff --git a/docs/code-docs/source/initialize.rst b/docs/code-docs/source/initialize.rst index dd69a5dec4d2..172376043229 100644 --- a/docs/code-docs/source/initialize.rst +++ b/docs/code-docs/source/initialize.rst @@ -42,3 +42,41 @@ Distributed Initialization Optional distributed backend initialization separate from ``deepspeed.initialize()``. Useful in scenarios where the user wants to use torch distributed calls before calling ``deepspeed.initialize()``, such as when using model parallelism, pipeline parallelism, or certain data loader scenarios. .. autofunction:: deepspeed.init_distributed + + +.. _parallel-state-init: + +Parallel State Initialization +----------------------------- +DeepSpeed provides a built-in ``ParallelState`` class for Megatron-style process group management +covering tensor, pipeline, data, sequence, context, and expert parallelism. + +Use ``initialize_parallel_state_from_config`` to create and initialize a ``ParallelState`` from +a DeepSpeed config dictionary (or ``DeepSpeedConfig`` object). The returned instance implements +the ``mpu`` interface and can be passed directly to ``deepspeed.initialize(mpu=...)``. + +Example usage: + +.. code-block:: python + + from deepspeed.utils import parallel_state_deepspeed as ps + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": {"autotp_size": 4}, + } + + # Initialize and use as mpu + parallel_state = ps.initialize_parallel_state_from_config(config_dict) + model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=parallel_state, + ) + +.. autofunction:: deepspeed.utils.parallel_state_deepspeed.initialize_parallel_state_from_config + +.. autoclass:: deepspeed.utils.parallel_state.ParallelState + :members: initialize_model_parallel, is_initialized, get_tensor_model_parallel_group, get_data_parallel_group, get_pipeline_model_parallel_group, get_sequence_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_data_parallel_world_size, get_data_parallel_rank, get_pipeline_model_parallel_world_size, get_pipeline_model_parallel_rank + :noindex: From 1084757458731cbe507db268ec20c2358e503e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Wed, 4 Mar 2026 14:47:01 +0800 Subject: [PATCH 19/23] Fix ParallelState mpu compatibility and refactor unit tests Add DeepSpeed mpu compatibility aliases to ParallelState class: - get_model_parallel_world_size/rank: alias for tensor model parallel - get_tensor_model_parallel_src_rank: compute first global rank in TP group - get_data_parallel_group_ranks: expose DP global ranks - get_sequence_data_parallel_group/world_size/rank: fall back to DP group when sequence parallelism is not initialized, fixing the assertion error 'sequence and data parallel group is not initialized' Refactor test_parallel_state_deepspeed.py for CI compatibility: - Reduce world_size from 8 to 4 to match upstream CI hardware - Pass ParallelState instance directly as mpu instead of using the parallel_state_deepspeed module with set_current_parallel_state - Replace non-standard config keys (sequence_parallel_size, order) with standard ones (tensor_parallel.autotp_size) - Use train_micro_batch_size_per_gpu instead of train_batch_size - Extract common training loop into _train_steps helper method Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 41 +- .../utils/test_parallel_state_deepspeed.py | 382 ++++++------------ 2 files changed, 159 insertions(+), 264 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 93bf297d0ba0..5d213c369787 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -991,16 +991,53 @@ def get_sequence_and_data_parallel_world_size(self): """Return world size for the sequence and data parallel group.""" if dist.is_available() and dist.is_initialized(): if self.sequence_and_data_parallel_group is not None: - return self.get_sequence_and_data_parallel_group().size() + return self.sequence_and_data_parallel_group.size() + return self.get_data_parallel_world_size() return 0 def get_sequence_and_data_parallel_rank(self): """Return caller's rank in the sequence and data parallel group.""" if dist.is_available() and dist.is_initialized(): if self.sequence_and_data_parallel_group is not None: - return self.get_sequence_and_data_parallel_group().rank() + return self.sequence_and_data_parallel_group.rank() + return self.get_data_parallel_rank() return 0 + # ---- DeepSpeed mpu compatibility aliases ---- + # groups.py and engine.py call these methods on the mpu object. + + def get_model_parallel_world_size(self): + return self.get_tensor_model_parallel_world_size() + + def get_model_parallel_rank(self): + return self.get_tensor_model_parallel_rank() + + def get_tensor_model_parallel_src_rank(self): + """Global rank corresponding to the first local rank in the TP group.""" + global_rank = dist.get_rank() + local_rank = self.get_tensor_model_parallel_rank() + return global_rank - local_rank + + def get_data_parallel_group_ranks(self): + return self.data_parallel_global_ranks + + def get_sequence_data_parallel_group(self): + if self.sequence_and_data_parallel_group is not None: + return self.sequence_and_data_parallel_group + return self.get_data_parallel_group() + + def get_sequence_data_parallel_world_size(self): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_world_size() + return self.get_data_parallel_world_size() + + def get_sequence_data_parallel_rank(self): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_rank() + return self.get_data_parallel_rank() + + # ---- end compatibility aliases ---- + def is_initialized(self): """Check if parallel state has been initialized""" return self.data_parallel_group is not None diff --git a/tests/unit/utils/test_parallel_state_deepspeed.py b/tests/unit/utils/test_parallel_state_deepspeed.py index d77793716382..4356d2872001 100644 --- a/tests/unit/utils/test_parallel_state_deepspeed.py +++ b/tests/unit/utils/test_parallel_state_deepspeed.py @@ -6,8 +6,8 @@ Integration tests for using ParallelState as mpu in deepspeed.initialize() Tests the full workflow: -1. Initialize parallel_state_deepspeed with parallel configurations -2. Pass it as mpu parameter to deepspeed.initialize() +1. Initialize ParallelState with parallel configurations +2. Pass the ParallelState instance as mpu parameter to deepspeed.initialize() 3. Verify DeepSpeed Engine correctly uses the parallel state """ @@ -18,180 +18,124 @@ from unit.common import DistributedTest from unit.simple_model import SimpleModel, random_dataloader +DTYPE = torch.float + class TestParallelStateAsMPU(DistributedTest): - """Test parallel_state_deepspeed as mpu parameter in deepspeed.initialize()""" + """Test ParallelState instance as mpu parameter in deepspeed.initialize()""" - world_size = 8 + world_size = 4 def _get_base_config(self): - """Get base DeepSpeed config""" - return {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} + return {"train_micro_batch_size_per_gpu": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} - def _verify_mpu_integration(self, engine, mpu, expected_tp=1, expected_pp=1, expected_sp=1): - """Verify mpu is correctly integrated in engine""" - # 1. Engine holds mpu reference + def _verify_mpu_integration(self, engine, mpu, expected_tp=1, expected_pp=1): assert engine.mpu == mpu - - # 2. Parallel configuration is correct assert mpu.get_tensor_model_parallel_world_size() == expected_tp assert mpu.get_pipeline_model_parallel_world_size() == expected_pp - # 3. Data parallel world size is correctly calculated world_size = dist.get_world_size() - expected_dp = world_size // (expected_tp * expected_pp * expected_sp) + expected_dp = world_size // (expected_tp * expected_pp) assert mpu.get_data_parallel_world_size() == expected_dp - - # 4. Config uses mpu for world_size calculation - assert engine.config.world_size == expected_dp - - return expected_dp + assert engine._config.world_size == expected_dp + + def _train_steps(self, engine, steps=3): + data_loader = random_dataloader(model=engine, + total_samples=10, + hidden_dim=16, + device=engine.device, + dtype=DTYPE) + for i, batch in enumerate(data_loader): + if i >= steps: + break + loss = engine(batch[0], batch[1]) + assert loss is not None + engine.backward(loss) + engine.step() def test_basic_mpu_usage(self): - """Test basic mpu parameter usage with TP and PP""" + """Test basic TP with ParallelState instance as mpu""" from deepspeed.utils import parallel_state_deepspeed as ps - # Use named instance to avoid test interference state = ps.get_parallel_state_instance("test_basic") - state.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) + state.initialize_model_parallel(tensor_model_parallel_size=2) config = self._get_base_config() model = SimpleModel(hidden_dim=16) - # Pass parallel_state module as mpu (the module provides compatibility layer) - with ps.set_current_parallel_state("test_basic"): - engine, optimizer, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) - - # Verify integration - with ps.set_current_parallel_state("test_basic"): - self._verify_mpu_integration(engine, ps, expected_tp=2, expected_pp=2) + engine, optimizer, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=state, + model_parameters=model.parameters()) - # Verify optimizer is created + self._verify_mpu_integration(engine, state, expected_tp=2) assert optimizer is not None - - # Test training for 5 batches - data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) - for i, batch in enumerate(data_loader): - if i >= 5: - break - loss = engine(batch[0], batch[1]) - assert loss is not None - engine.backward(loss) - engine.step() + self._train_steps(engine) def test_config_driven_mpu(self): - """Test mpu initialized from config with sequence_parallel_size""" + """Test mpu initialized from config with tensor_model_parallel_size""" from deepspeed.utils import parallel_state_deepspeed as ps - config = { - "train_batch_size": 8, - "sequence_parallel_size": 2, - "order": "tp-sp-dp-pp", # Need to specify order when using sp - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001 - } - } + parallel_config = { + "tensor_parallel": { + "autotp_size": 2 + }, } - # Initialize from config - ps.initialize_parallel_state_from_config(config, name="config_driven_test") + state = ps.initialize_parallel_state_from_config(parallel_config, name="config_driven_test") + engine_config = self._get_base_config() model = SimpleModel(hidden_dim=16) + engine, _, _, _ = deepspeed.initialize(model=model, + config=engine_config, + mpu=state, + model_parameters=model.parameters()) - # Set current instance - ps.set_current_parallel_state("config_driven_test") - - engine, _, _, _ = deepspeed.initialize(model=model, config=config, mpu=ps, model_parameters=model.parameters()) - - # Verify SP group is created - with ps.set_current_parallel_state("config_driven_test"): - sp_world_size = ps.get_sequence_parallel_world_size() - assert sp_world_size == 2 - - # Verify integration - with ps.set_current_parallel_state("config_driven_test"): - self._verify_mpu_integration(engine, ps, expected_sp=2) - - # Test training for 5 batches - data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) - for i, batch in enumerate(data_loader): - if i >= 5: - break - loss = engine(batch[0], batch[1]) - engine.backward(loss) - engine.step() + self._verify_mpu_integration(engine, state, expected_tp=2) + self._train_steps(engine) def test_multi_instance_mpu(self): """Test multiple named instances as mpu (Actor-Critic scenario)""" from deepspeed.utils import parallel_state_deepspeed as ps - # Initialize Actor with TP=2 actor_state = ps.get_parallel_state_instance("actor") actor_state.initialize_model_parallel(tensor_model_parallel_size=2) - # Initialize Critic with TP=1 (no parallelism) critic_state = ps.get_parallel_state_instance("critic") critic_state.initialize_model_parallel(tensor_model_parallel_size=1) config = self._get_base_config() - # Create Actor engine actor_model = SimpleModel(hidden_dim=16) - with ps.set_current_parallel_state("actor"): - actor_engine, _, _, _ = deepspeed.initialize(model=actor_model, - config=config, - mpu=ps, - model_parameters=actor_model.parameters()) + actor_engine, _, _, _ = deepspeed.initialize(model=actor_model, + config=config, + mpu=actor_state, + model_parameters=actor_model.parameters()) - # Create Critic engine critic_model = SimpleModel(hidden_dim=16) - with ps.set_current_parallel_state("critic"): - critic_engine, _, _, _ = deepspeed.initialize(model=critic_model, - config=config, - mpu=ps, - model_parameters=critic_model.parameters()) - - # Verify Actor uses TP=2 - with ps.set_current_parallel_state("actor"): - assert ps.get_tensor_model_parallel_world_size() == 2 - assert actor_engine.mpu == ps - - # Verify Critic uses TP=1 - with ps.set_current_parallel_state("critic"): - assert ps.get_tensor_model_parallel_world_size() == 1 - assert critic_engine.mpu == ps - - # Test training for 5 batches on both engines - actor_loader = random_dataloader(model=actor_engine.module, total_samples=20, hidden_dim=16, device=actor_engine.device) - critic_loader = random_dataloader(model=critic_engine.module, total_samples=20, hidden_dim=16, device=critic_engine.device) - for i, (actor_batch, critic_batch) in enumerate(zip(actor_loader, critic_loader)): - if i >= 5: - break - actor_loss = actor_engine(actor_batch[0], actor_batch[1]) - assert actor_loss is not None - actor_engine.backward(actor_loss) - actor_engine.step() + critic_engine, _, _, _ = deepspeed.initialize(model=critic_model, + config=config, + mpu=critic_state, + model_parameters=critic_model.parameters()) + + assert actor_state.get_tensor_model_parallel_world_size() == 2 + assert actor_engine.mpu == actor_state + + assert critic_state.get_tensor_model_parallel_world_size() == 1 + assert critic_engine.mpu == critic_state - critic_loss = critic_engine(critic_batch[0], critic_batch[1]) - assert critic_loss is not None - critic_engine.backward(critic_loss) - critic_engine.step() + self._train_steps(actor_engine) + self._train_steps(critic_engine) def test_mpu_with_zero_stage1(self): """Test mpu integration with ZeRO Stage 1""" from deepspeed.utils import parallel_state_deepspeed as ps - # Use named instance to avoid test interference state = ps.get_parallel_state_instance("test_zero") state.initialize_model_parallel(tensor_model_parallel_size=2) config = { - "train_batch_size": 8, + "train_micro_batch_size_per_gpu": 1, "zero_optimization": { "stage": 1 }, @@ -204,96 +148,57 @@ def test_mpu_with_zero_stage1(self): } model = SimpleModel(hidden_dim=16) + engine, optimizer, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=state, + model_parameters=model.parameters()) - with ps.set_current_parallel_state("test_zero"): - engine, optimizer, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) - - # Verify ZeRO optimizer is created assert optimizer is not None from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer assert isinstance(optimizer, DeepSpeedZeroOptimizer) - # Verify mpu integration - with ps.set_current_parallel_state("test_zero"): - self._verify_mpu_integration(engine, ps, expected_tp=2) - - # Verify optimizer uses correct DP group - assert optimizer.mpu == ps - - # Test training for 5 batches - data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) - for i, batch in enumerate(data_loader): - if i >= 5: - break - loss = engine(batch[0], batch[1]) - engine.backward(loss) - engine.step() + self._verify_mpu_integration(engine, state, expected_tp=2) + self._train_steps(engine) def test_deepspeed_config_uses_mpu(self): """Test DeepSpeedConfig correctly uses mpu for world_size calculation""" from deepspeed.utils import parallel_state_deepspeed as ps from deepspeed.runtime.config import DeepSpeedConfig - # Use named instance to avoid test interference state = ps.get_parallel_state_instance("test_config") - state.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) + state.initialize_model_parallel(tensor_model_parallel_size=2) config_dict = self._get_base_config() + ds_config = DeepSpeedConfig(config_dict, mpu=state) - # Create DeepSpeedConfig with mpu - with ps.set_current_parallel_state("test_config"): - ds_config = DeepSpeedConfig(config_dict, mpu=ps) - - # Verify world_size calculation uses mpu - expected_dp = dist.get_world_size() // (2 * 2) + expected_dp = dist.get_world_size() // 2 assert ds_config.world_size == expected_dp - - # Verify it matches mpu's calculation - with ps.set_current_parallel_state("test_config"): - assert ds_config.world_size == ps.get_data_parallel_world_size() + assert ds_config.world_size == state.get_data_parallel_world_size() def test_mpu_without_parallelism(self): """Test mpu with all parallelism dimensions = 1 (no parallelism)""" from deepspeed.utils import parallel_state_deepspeed as ps - # Use named instance to avoid test interference state = ps.get_parallel_state_instance("test_no_parallel") state.initialize_model_parallel() config = self._get_base_config() model = SimpleModel(hidden_dim=16) + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=state, + model_parameters=model.parameters()) - with ps.set_current_parallel_state("test_no_parallel"): - engine, _, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) + assert state.get_tensor_model_parallel_world_size() == 1 + assert state.get_pipeline_model_parallel_world_size() == 1 + assert state.get_data_parallel_world_size() == dist.get_world_size() - # Verify all dimensions are 1 - with ps.set_current_parallel_state("test_no_parallel"): - assert ps.get_tensor_model_parallel_world_size() == 1 - assert ps.get_pipeline_model_parallel_world_size() == 1 - - # DP should equal world_size - assert ps.get_data_parallel_world_size() == dist.get_world_size() - - # Test training for 5 batches - data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) - for i, batch in enumerate(data_loader): - if i >= 5: - break - loss = engine(batch[0], batch[1]) - engine.backward(loss) - engine.step() + self._train_steps(engine) def test_mpu_with_different_orders(self): - """Test mpu with different parallel dimension orders""" + """Test mpu with custom parallel dimension order""" from deepspeed.utils import parallel_state_deepspeed as ps - # Use named instance to avoid test interference state = ps.get_parallel_state_instance("test_order") state.initialize_model_parallel(tensor_model_parallel_size=2, expert_model_parallel_size=2, @@ -301,21 +206,17 @@ def test_mpu_with_different_orders(self): config = self._get_base_config() model = SimpleModel(hidden_dim=16) + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=state, + model_parameters=model.parameters()) - with ps.set_current_parallel_state("test_order"): - engine, _, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) + assert state.get_tensor_model_parallel_world_size() == 2 + assert state.get_expert_model_parallel_world_size() == 2 - # Verify parallel configuration - with ps.set_current_parallel_state("test_order"): - assert ps.get_tensor_model_parallel_world_size() == 2 - assert ps.get_expert_model_parallel_world_size() == 2 - - # Verify DP world_size: world_size / (tp * ep) - expected_dp = dist.get_world_size() // (2 * 2) - assert ps.get_data_parallel_world_size() == expected_dp + # EP does not reduce the regular DP world size; DP = world_size / (TP * PP) + expected_dp = dist.get_world_size() // 2 + assert state.get_data_parallel_world_size() == expected_dp class TestParallelStateConfigPriority(DistributedTest): @@ -324,71 +225,38 @@ class TestParallelStateConfigPriority(DistributedTest): world_size = 4 def test_param_overrides_config(self): - """Function parameter should override config value""" + """Function parameter should override nested config value""" from deepspeed.utils import parallel_state_deepspeed as ps config = { - "train_batch_size": 4, - "sequence_parallel_size": 2, # Config says 2 - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001 - } - } + "tensor_parallel": { + "autotp_size": 2 + }, } - # Override with param: sp=1 - ps.initialize_parallel_state_from_config( + state = ps.initialize_parallel_state_from_config( config, name="param_override_test", - sequence_parallel_size=1 # Parameter overrides config + tensor_model_parallel_size=1, ) - model = SimpleModel(hidden_dim=16) - - with ps.set_current_parallel_state("param_override_test"): - engine, _, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) - - # With sp=1, SP group should not have special effect - assert engine is not None - assert engine.mpu == ps + assert state.get_tensor_model_parallel_world_size() == 1 + assert state.get_data_parallel_world_size() == dist.get_world_size() def test_config_overrides_default(self): - """Config value should override default value""" + """Nested config value (tensor_parallel.autotp_size) should override default""" from deepspeed.utils import parallel_state_deepspeed as ps config = { - "train_batch_size": 4, - "sequence_parallel_size": 2, # Override default (1) - "order": "tp-sp-dp-pp", # Need to specify order when using sp - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001 - } - } + "tensor_parallel": { + "autotp_size": 2 + }, } - # Don't pass sequence_parallel_size parameter - ps.initialize_parallel_state_from_config(config, name="config_override_test") - - model = SimpleModel(hidden_dim=16) - - with ps.set_current_parallel_state("config_override_test"): - engine, _, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) + state = ps.initialize_parallel_state_from_config(config, name="config_override_test") - # Verify SP is configured from config - # Since sp_size = 2, SP group should be initialized - with ps.set_current_parallel_state("config_override_test"): - sp_world_size = ps.get_sequence_parallel_world_size() - assert sp_world_size == 2 + assert state.get_tensor_model_parallel_world_size() == 2 + assert state.get_data_parallel_world_size() == dist.get_world_size() // 2 class TestParallelStateValidation(DistributedTest): @@ -400,58 +268,48 @@ def test_context_parallel_not_supported(self): """Test that CP > 1 raises NotImplementedError""" from deepspeed.utils import parallel_state_deepspeed as ps - # CP > 1 should raise error via initialize_parallel_state_from_config with pytest.raises(NotImplementedError, match="does not support context_parallel_size"): - ps.initialize_parallel_state_from_config({"context_parallel_size": 2}, name="cp_test") + ps.initialize_parallel_state_from_config({}, name="cp_test", context_parallel_size=2) def test_hierarchical_cp_not_supported(self): """Test that hierarchical CP raises NotImplementedError""" from deepspeed.utils import parallel_state_deepspeed as ps with pytest.raises(NotImplementedError, match="does not support hierarchical_context_parallel_sizes"): - ps.initialize_parallel_state_from_config({"hierarchical_context_parallel_sizes": [2, 2]}, name="hcp_test") + ps.initialize_parallel_state_from_config({}, name="hcp_test", hierarchical_context_parallel_sizes=[2, 2]) class TestAllToAllGroupsWithMPU(DistributedTest): """Test All-to-All groups initialization with mpu""" - world_size = 8 + world_size = 4 def test_all_to_all_groups_with_mpu(self): """Test All-to-All groups work with mpu in initialize""" from deepspeed.utils import parallel_state_deepspeed as ps - # Use named instance to avoid test interference state = ps.get_parallel_state_instance("test_all_to_all") state.initialize_model_parallel() - config = {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} + config = {"train_micro_batch_size_per_gpu": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} model = SimpleModel(hidden_dim=16) + engine, _, _, _ = deepspeed.initialize(model=model, + config=config, + mpu=state, + model_parameters=model.parameters()) - with ps.set_current_parallel_state("test_all_to_all"): - engine, _, _, _ = deepspeed.initialize(model=model, - config=config, - mpu=ps, - model_parameters=model.parameters()) - - # Initialize All-to-All groups - with ps.set_current_parallel_state("test_all_to_all"): - all_to_all_groups = ps.initialize_all_to_all_groups() - - # Verify groups are created + all_to_all_groups = state.initialize_all_to_all_groups() assert isinstance(all_to_all_groups, dict) assert len(all_to_all_groups) > 0 - # Test backward compatibility interface - with ps.set_current_parallel_state("test_all_to_all"): - compat_groups = ps._get_local_all_to_all_group() - assert compat_groups == all_to_all_groups - - # Test training for 5 batches - data_loader = random_dataloader(model=engine.module, total_samples=20, hidden_dim=16, device=engine.device) + data_loader = random_dataloader(model=engine, + total_samples=10, + hidden_dim=16, + device=engine.device, + dtype=DTYPE) for i, batch in enumerate(data_loader): - if i >= 5: + if i >= 3: break loss = engine(batch[0], batch[1]) assert loss is not None From e74fe736b406c30501f36d904492b0c4f176620f Mon Sep 17 00:00:00 2001 From: yunqing Date: Wed, 4 Mar 2026 09:57:27 +0800 Subject: [PATCH 20/23] refactor: unify sequence-data parallel API naming - parallel_state.py: rename get_sequence_and_data_parallel_{group, world_size,rank} to drop redundant 'and', making them consistent with groups.py hasattr checks; add get_model_parallel_{world_size, rank} backward-compat methods for ZeRO optimizer under SP scenarios - parallel_state_deepspeed.py: rename corresponding module-level functions to match the updated ParallelState method names; update docstrings to document groups.py compatible interface contract Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 59 ++++++++++++--------- deepspeed/utils/parallel_state_deepspeed.py | 22 ++++---- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 5d213c369787..1569227f1f0f 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -885,11 +885,16 @@ def get_sequence_parallel_group(self, check_initialized=True): assert self.sequence_parallel_group is not None, "sequence parallel group is not initialized" return self.sequence_parallel_group - def get_sequence_and_data_parallel_group(self, check_initialized=True): - """Get the sequence and data parallel group the caller rank belongs to.""" - if check_initialized: - assert self.sequence_and_data_parallel_group is not None, "sequence and data parallel group is not initialized" - return self.sequence_and_data_parallel_group + def get_sequence_data_parallel_group(self, check_initialized=True): + """Get the sequence and data parallel group the caller rank belongs to. + + DeepSpeed groups.py compatible interface: groups._get_sequence_data_parallel_group() + checks hasattr(mpu, 'get_sequence_data_parallel_group') to route ZeRO optimizer + to the correct all-reduce group spanning both SP and DP dimensions. + """ + if self.sequence_and_data_parallel_group is not None: + return self.sequence_and_data_parallel_group + return self.get_data_parallel_group() def get_embedding_group(self, check_initialized=True): """Get the embedding group the caller rank belongs to.""" @@ -987,7 +992,7 @@ def get_sequence_parallel_rank(self): return self.get_sequence_parallel_group().rank() return 0 - def get_sequence_and_data_parallel_world_size(self): + def get_sequence_data_parallel_world_size(self): """Return world size for the sequence and data parallel group.""" if dist.is_available() and dist.is_initialized(): if self.sequence_and_data_parallel_group is not None: @@ -995,7 +1000,7 @@ def get_sequence_and_data_parallel_world_size(self): return self.get_data_parallel_world_size() return 0 - def get_sequence_and_data_parallel_rank(self): + def get_sequence_data_parallel_rank(self): """Return caller's rank in the sequence and data parallel group.""" if dist.is_available() and dist.is_initialized(): if self.sequence_and_data_parallel_group is not None: @@ -1003,13 +1008,34 @@ def get_sequence_and_data_parallel_rank(self): return self.get_data_parallel_rank() return 0 - # ---- DeepSpeed mpu compatibility aliases ---- - # groups.py and engine.py call these methods on the mpu object. + # ============================================================================ + # Backward Compatibility Methods for DeepSpeed ZeRO + # ============================================================================ def get_model_parallel_world_size(self): + """Return world size for the model parallel group. + + Backward compatibility method for DeepSpeed ZeRO optimizer. + In SP scenarios, model_parallel (TP) size is always 1 since SP cannot coexist with TP. + In non-SP scenarios, model_parallel refers to tensor parallel (TP). + """ + if self.sequence_parallel_group is not None: + # SP is enabled, model_parallel (TP) size must be 1 + return 1 + # No SP, return TP size return self.get_tensor_model_parallel_world_size() def get_model_parallel_rank(self): + """Return caller's rank for the model parallel group. + + Backward compatibility method for DeepSpeed ZeRO optimizer. + In SP scenarios, model_parallel (TP) rank is always 0 since SP cannot coexist with TP. + In non-SP scenarios, model_parallel refers to tensor parallel (TP). + """ + if self.sequence_parallel_group is not None: + # SP is enabled, model_parallel (TP) rank must be 0 + return 0 + # No SP, return TP rank return self.get_tensor_model_parallel_rank() def get_tensor_model_parallel_src_rank(self): @@ -1021,21 +1047,6 @@ def get_tensor_model_parallel_src_rank(self): def get_data_parallel_group_ranks(self): return self.data_parallel_global_ranks - def get_sequence_data_parallel_group(self): - if self.sequence_and_data_parallel_group is not None: - return self.sequence_and_data_parallel_group - return self.get_data_parallel_group() - - def get_sequence_data_parallel_world_size(self): - if self.sequence_and_data_parallel_group is not None: - return self.get_sequence_and_data_parallel_world_size() - return self.get_data_parallel_world_size() - - def get_sequence_data_parallel_rank(self): - if self.sequence_and_data_parallel_group is not None: - return self.get_sequence_and_data_parallel_rank() - return self.get_data_parallel_rank() - # ---- end compatibility aliases ---- def is_initialized(self): diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index eb768fd83815..e54fe0db75ba 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -439,37 +439,41 @@ def get_sequence_parallel_rank(name: Optional[str] = None): return get_parallel_state(name).get_sequence_parallel_rank() -def get_sequence_and_data_parallel_group(name: Optional[str] = None): +def get_sequence_data_parallel_group(name: Optional[str] = None): """Get the sequence and data parallel group the caller rank belongs to. Args: name: Optional name of the parallel state instance. If None, uses current active instance. - DeepSpeed-compatible interface. + DeepSpeed groups.py compatible interface: groups._get_sequence_data_parallel_group() + checks hasattr(mpu, 'get_sequence_data_parallel_group') to route ZeRO optimizer + to the correct all-reduce group spanning both SP and DP dimensions. """ - return get_parallel_state(name).get_sequence_and_data_parallel_group() + return get_parallel_state(name).get_sequence_data_parallel_group() -def get_sequence_and_data_parallel_world_size(name: Optional[str] = None): +def get_sequence_data_parallel_world_size(name: Optional[str] = None): """Return world size for the sequence and data parallel group. Args: name: Optional name of the parallel state instance. If None, uses current active instance. - DeepSpeed-compatible interface. + DeepSpeed groups.py compatible interface: groups._get_sequence_data_parallel_world_size() + checks hasattr(mpu, 'get_sequence_data_parallel_world_size') to route ZeRO optimizer. """ - return get_parallel_state(name).get_sequence_and_data_parallel_world_size() + return get_parallel_state(name).get_sequence_data_parallel_world_size() -def get_sequence_and_data_parallel_rank(name: Optional[str] = None): +def get_sequence_data_parallel_rank(name: Optional[str] = None): """Return caller's rank in the sequence and data parallel group. Args: name: Optional name of the parallel state instance. If None, uses current active instance. - DeepSpeed-compatible interface. + DeepSpeed groups.py compatible interface: groups._get_sequence_data_parallel_rank() + checks hasattr(mpu, 'get_sequence_data_parallel_rank') to route ZeRO optimizer. """ - return get_parallel_state(name).get_sequence_and_data_parallel_rank() + return get_parallel_state(name).get_sequence_data_parallel_rank() # ============================================================================ From 445b71e788646d31718863a77460463915077d5f Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Wed, 4 Mar 2026 11:17:08 +0800 Subject: [PATCH 21/23] Rename parallel_state_deepspeed to parallel_state_wrappers Naming a utility module after xxx_deepspeed doesn't make a lot of sense. Rename the module to parallel_state_wrappers to reflect that it essentially wraps the ParallelState class for easier creation and accessing. Signed-off-by: Junjie Mao --- ...eepspeed.py => parallel_state_wrappers.py} | 6 ++--- docs/_pages/training.md | 8 +++--- docs/code-docs/source/initialize.rst | 4 +-- ...te_deepspeed.py => test_parallel_state.py} | 26 +++++++++---------- 4 files changed, 22 insertions(+), 22 deletions(-) rename deepspeed/utils/{parallel_state_deepspeed.py => parallel_state_wrappers.py} (99%) rename tests/unit/utils/{test_parallel_state_deepspeed.py => test_parallel_state.py} (93%) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_wrappers.py similarity index 99% rename from deepspeed/utils/parallel_state_deepspeed.py rename to deepspeed/utils/parallel_state_wrappers.py index e54fe0db75ba..3d3d4b9d3f34 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_wrappers.py @@ -33,11 +33,11 @@ Usage: # Basic usage (single global instance): - from parallel_state_deepspeed import get_data_parallel_group + from parallel_state_wrappers import get_data_parallel_group dp_group = get_data_parallel_group() # Multi-instance usage (for RL scenarios): - from parallel_state_deepspeed import ( + from parallel_state_wrappers import ( get_parallel_state_instance, set_current_parallel_state, get_data_parallel_group, @@ -896,7 +896,7 @@ def get_value(param_value, config_key, default_value): "hierarchical_context_parallel_sizes", "expert_model_parallel_size", "num_distributed_optimizer_instances", "expert_tensor_parallel_size", "distributed_timeout_minutes", "order", "create_gloo_process_groups" } - + for key, value in init_kwargs.items(): # Skip unsupported parameters if key not in supported_params: diff --git a/docs/_pages/training.md b/docs/_pages/training.md index bdae8b563807..9dee60cda1db 100644 --- a/docs/_pages/training.md +++ b/docs/_pages/training.md @@ -256,7 +256,7 @@ context parallelism (CP), and expert parallelism (EP). You can initialize the parallel state either explicitly or from a DeepSpeed config: ```python -from deepspeed.utils import parallel_state_deepspeed as ps +from deepspeed.utils import parallel_state_wrappers as ps # Option 1: Initialize from config dict (also works with DeepSpeedConfig objects) config_dict = { @@ -295,7 +295,7 @@ Config keys support dot-separated paths for nested dictionaries. For example, the top level or `"tensor_parallel.autotp_size"` in a nested config. ```python -from deepspeed.utils import parallel_state_deepspeed as ps +from deepspeed.utils import parallel_state_wrappers as ps # Override specific parameters while reading others from config parallel_state = ps.initialize_parallel_state_from_config( @@ -311,7 +311,7 @@ In reinforcement learning scenarios where multiple models (e.g., actor and criti require different parallelism configurations, you can create named instances: ```python -from deepspeed.utils import parallel_state_deepspeed as ps +from deepspeed.utils import parallel_state_wrappers as ps # Create separate parallel state instances actor_ps = ps.initialize_parallel_state_from_config( @@ -333,7 +333,7 @@ with ps.set_current_parallel_state("critic"): #### Compatibility with Existing Code -The module-level functions in `parallel_state_deepspeed` (such as +The module-level functions in `parallel_state_wrappers` (such as `get_data_parallel_group()`, `get_tensor_model_parallel_world_size()`, etc.) operate on the current active `ParallelState` instance, preserving backward compatibility with code written against the previous `groups.py` API. diff --git a/docs/code-docs/source/initialize.rst b/docs/code-docs/source/initialize.rst index 172376043229..df043555d9d5 100644 --- a/docs/code-docs/source/initialize.rst +++ b/docs/code-docs/source/initialize.rst @@ -59,7 +59,7 @@ Example usage: .. code-block:: python - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps config_dict = { "train_micro_batch_size_per_gpu": 1, @@ -75,7 +75,7 @@ Example usage: mpu=parallel_state, ) -.. autofunction:: deepspeed.utils.parallel_state_deepspeed.initialize_parallel_state_from_config +.. autofunction:: deepspeed.utils.parallel_state_wrappers.initialize_parallel_state_from_config .. autoclass:: deepspeed.utils.parallel_state.ParallelState :members: initialize_model_parallel, is_initialized, get_tensor_model_parallel_group, get_data_parallel_group, get_pipeline_model_parallel_group, get_sequence_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_data_parallel_world_size, get_data_parallel_rank, get_pipeline_model_parallel_world_size, get_pipeline_model_parallel_rank diff --git a/tests/unit/utils/test_parallel_state_deepspeed.py b/tests/unit/utils/test_parallel_state.py similarity index 93% rename from tests/unit/utils/test_parallel_state_deepspeed.py rename to tests/unit/utils/test_parallel_state.py index 4356d2872001..7a9edd8b6db9 100644 --- a/tests/unit/utils/test_parallel_state_deepspeed.py +++ b/tests/unit/utils/test_parallel_state.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team # DeepSpeed Team """ @@ -55,7 +55,7 @@ def _train_steps(self, engine, steps=3): def test_basic_mpu_usage(self): """Test basic TP with ParallelState instance as mpu""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps state = ps.get_parallel_state_instance("test_basic") state.initialize_model_parallel(tensor_model_parallel_size=2) @@ -74,7 +74,7 @@ def test_basic_mpu_usage(self): def test_config_driven_mpu(self): """Test mpu initialized from config with tensor_model_parallel_size""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps parallel_config = { "tensor_parallel": { @@ -96,7 +96,7 @@ def test_config_driven_mpu(self): def test_multi_instance_mpu(self): """Test multiple named instances as mpu (Actor-Critic scenario)""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps actor_state = ps.get_parallel_state_instance("actor") actor_state.initialize_model_parallel(tensor_model_parallel_size=2) @@ -129,7 +129,7 @@ def test_multi_instance_mpu(self): def test_mpu_with_zero_stage1(self): """Test mpu integration with ZeRO Stage 1""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps state = ps.get_parallel_state_instance("test_zero") state.initialize_model_parallel(tensor_model_parallel_size=2) @@ -162,7 +162,7 @@ def test_mpu_with_zero_stage1(self): def test_deepspeed_config_uses_mpu(self): """Test DeepSpeedConfig correctly uses mpu for world_size calculation""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps from deepspeed.runtime.config import DeepSpeedConfig state = ps.get_parallel_state_instance("test_config") @@ -177,7 +177,7 @@ def test_deepspeed_config_uses_mpu(self): def test_mpu_without_parallelism(self): """Test mpu with all parallelism dimensions = 1 (no parallelism)""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps state = ps.get_parallel_state_instance("test_no_parallel") state.initialize_model_parallel() @@ -197,7 +197,7 @@ def test_mpu_without_parallelism(self): def test_mpu_with_different_orders(self): """Test mpu with custom parallel dimension order""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps state = ps.get_parallel_state_instance("test_order") state.initialize_model_parallel(tensor_model_parallel_size=2, @@ -226,7 +226,7 @@ class TestParallelStateConfigPriority(DistributedTest): def test_param_overrides_config(self): """Function parameter should override nested config value""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps config = { "tensor_parallel": { @@ -245,7 +245,7 @@ def test_param_overrides_config(self): def test_config_overrides_default(self): """Nested config value (tensor_parallel.autotp_size) should override default""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps config = { "tensor_parallel": { @@ -266,14 +266,14 @@ class TestParallelStateValidation(DistributedTest): def test_context_parallel_not_supported(self): """Test that CP > 1 raises NotImplementedError""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps with pytest.raises(NotImplementedError, match="does not support context_parallel_size"): ps.initialize_parallel_state_from_config({}, name="cp_test", context_parallel_size=2) def test_hierarchical_cp_not_supported(self): """Test that hierarchical CP raises NotImplementedError""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps with pytest.raises(NotImplementedError, match="does not support hierarchical_context_parallel_sizes"): ps.initialize_parallel_state_from_config({}, name="hcp_test", hierarchical_context_parallel_sizes=[2, 2]) @@ -286,7 +286,7 @@ class TestAllToAllGroupsWithMPU(DistributedTest): def test_all_to_all_groups_with_mpu(self): """Test All-to-All groups work with mpu in initialize""" - from deepspeed.utils import parallel_state_deepspeed as ps + from deepspeed.utils import parallel_state_wrappers as ps state = ps.get_parallel_state_instance("test_all_to_all") state.initialize_model_parallel() From da15a8a05ff765d5ae2e43f5a873b01e8b94f04d Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Wed, 4 Mar 2026 16:05:19 +0800 Subject: [PATCH 22/23] parallel_state: Take parallelism sizes from existing parameters only When extracting parallelism sizes from the configuration JSON, only check officially-defined parameters. Not all parallelism sizes are configurable via JSON today, but whether and how such parameters should be added is a separate topic from this PR which is focused on unifying process group management. Signed-off-by: Junjie Mao --- deepspeed/utils/parallel_state_wrappers.py | 51 ++++++++-------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/deepspeed/utils/parallel_state_wrappers.py b/deepspeed/utils/parallel_state_wrappers.py index 3d3d4b9d3f34..c493cb4a7760 100644 --- a/deepspeed/utils/parallel_state_wrappers.py +++ b/deepspeed/utils/parallel_state_wrappers.py @@ -829,46 +829,29 @@ def get_value(param_value, config_key, default_value): if param_value is not None: return param_value - candidates = config_key if isinstance(config_key, (list, tuple)) else [config_key] - for key in candidates: - found, value = _resolve_nested_key(config_dict, key) + if config_key is not None: + found, value = _resolve_nested_key(config_dict, config_key) if found: return value return default_value init_kwargs = { - "tensor_model_parallel_size": - get_value(tensor_model_parallel_size, - ["tensor_model_parallel_size", "tensor_parallel.autotp_size"], 1), - "pipeline_model_parallel_size": - get_value(pipeline_model_parallel_size, "pipeline_model_parallel_size", 1), - "virtual_pipeline_model_parallel_size": - get_value(virtual_pipeline_model_parallel_size, "virtual_pipeline_model_parallel_size", None), - "pipeline_model_parallel_comm_backend": - get_value(pipeline_model_parallel_comm_backend, "pipeline_model_parallel_comm_backend", None), - "context_parallel_size": - get_value(context_parallel_size, "context_parallel_size", 1), - "sequence_parallel_size": - get_value(sequence_parallel_size, "sequence_parallel_size", 1), - "hierarchical_context_parallel_sizes": - get_value(hierarchical_context_parallel_sizes, "hierarchical_context_parallel_sizes", None), - "expert_model_parallel_size": - get_value(expert_model_parallel_size, "expert_model_parallel_size", 1), - "num_distributed_optimizer_instances": - get_value(num_distributed_optimizer_instances, "num_distributed_optimizer_instances", 1), - "expert_tensor_parallel_size": - get_value(expert_tensor_parallel_size, "expert_tensor_parallel_size", None), - "nccl_communicator_config_path": - get_value(nccl_communicator_config_path, "nccl_communicator_config_path", None), - "distributed_timeout_minutes": - get_value(distributed_timeout_minutes, "distributed_timeout_minutes", 30), - "order": - get_value(order, "order", "tp-ep-dp-pp"), - "create_gloo_process_groups": - get_value(create_gloo_process_groups, "create_gloo_process_groups", False), - "high_priority_stream_groups": - get_value(high_priority_stream_groups, "high_priority_stream_groups", None), + "tensor_model_parallel_size": get_value(tensor_model_parallel_size, "tensor_parallel.autotp_size", 1), + "pipeline_model_parallel_size": get_value(pipeline_model_parallel_size, None, 1), + "virtual_pipeline_model_parallel_size": get_value(virtual_pipeline_model_parallel_size, None, None), + "pipeline_model_parallel_comm_backend": get_value(pipeline_model_parallel_comm_backend, None, None), + "context_parallel_size": get_value(context_parallel_size, None, 1), + "sequence_parallel_size": get_value(sequence_parallel_size, None, 1), + "hierarchical_context_parallel_sizes": get_value(hierarchical_context_parallel_sizes, None, None), + "expert_model_parallel_size": get_value(expert_model_parallel_size, None, 1), + "num_distributed_optimizer_instances": get_value(num_distributed_optimizer_instances, None, 1), + "expert_tensor_parallel_size": get_value(expert_tensor_parallel_size, None, None), + "nccl_communicator_config_path": get_value(nccl_communicator_config_path, None, None), + "distributed_timeout_minutes": get_value(distributed_timeout_minutes, None, 30), + "order": get_value(order, None, "tp-ep-dp-pp"), + "create_gloo_process_groups": get_value(create_gloo_process_groups, None, False), + "high_priority_stream_groups": get_value(high_priority_stream_groups, None, None), } # Validate context_parallel_size From b1e16c43ec46366281af2d01f4ac79c4e164992b Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Wed, 4 Mar 2026 17:48:11 +0800 Subject: [PATCH 23/23] tests: Test ZeRO + Ulysses SP training using ParallelState Signed-off-by: Yuqing Li Signed-off-by: Junjie Mao --- tests/unit/utils/test_parallel_state.py | 158 +++++++++++++++++++++++- 1 file changed, 157 insertions(+), 1 deletion(-) diff --git a/tests/unit/utils/test_parallel_state.py b/tests/unit/utils/test_parallel_state.py index 7a9edd8b6db9..f3b6e48b901b 100644 --- a/tests/unit/utils/test_parallel_state.py +++ b/tests/unit/utils/test_parallel_state.py @@ -15,8 +15,9 @@ import torch import deepspeed import deepspeed.comm as dist -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import SimpleModel, random_dataloader +from unit.util import torch_assert_close DTYPE = torch.float @@ -315,3 +316,158 @@ def test_all_to_all_groups_with_mpu(self): assert loss is not None engine.backward(loss) engine.step() + + +@pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) +@pytest.mark.parametrize("sp_size", [2]) +class TestUlyssesSPWithParallelState(DistributedTest): + world_size = 4 + + def test_ulysses_sp_parallel_state(self, zero_stage, sp_size): + """Compare loss using mpu=ParallelState to test parallel_state's MPU interface.""" + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from deepspeed.utils.parallel_state_wrappers import get_parallel_state_instance + from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter + from deepspeed.runtime.utils import move_to_device + from deepspeed.accelerator import get_accelerator + + model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + sequence_parallel_size = sp_size + micro_batch_size = 1 + num_iterations = 10 + + seed = 42 + torch.manual_seed(seed) + get_accelerator().manual_seed_all(seed) + + base_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + } + } + dtype = preferred_dtype() + if dtype == torch.bfloat16: + base_config["bf16"] = {"enabled": True} + elif dtype == torch.float16: + base_config["fp16"] = {"enabled": True, "loss_scale": 1.0} + + def collate_fn(batch): + input_ids, position_ids = batch[0] + return dict(input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), + labels=input_ids.unsqueeze(0)) + + input_ids = torch.randint(1, 100, (num_iterations * 2, 6)) + position_ids = torch.arange(6).unsqueeze(0).expand(num_iterations * 2, -1) + ds = torch.utils.data.TensorDataset(input_ids, position_ids) + dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn, shuffle=False) + batches_full = list(dl) + + hf_model_config = AutoConfig.from_pretrained(model_name_or_path) + core_attn_implementation = "sdpa" + core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation] + + # No-SP model + + model_no_sp = AutoModelForCausalLM.from_pretrained(model_name_or_path) + model_no_sp, _, _, _ = deepspeed.initialize(config=base_config, + model=model_no_sp, + model_parameters=model_no_sp.parameters(), + mpu=None) + + losses_no_sp = [] + for i in range(num_iterations): + batch = move_to_device(batches_full[i], model_no_sp.device) + loss = model_no_sp(**batch).loss + model_no_sp.backward(loss) + model_no_sp.step() + losses_no_sp.append(loss.detach().cpu()) + + # SP model + # + # UlyssesSPAttentionHF.register_with_transformers() creates and returns an SP-specific parallel state object. + # Register explicitly here before UlyssesSPAttentionHF is adapted to the generic ParallelState. + + instance_name = "sp_test_psm" + ps = get_parallel_state_instance(instance_name) + ps.initialize_model_parallel(sequence_parallel_size=sequence_parallel_size, order="sp-dp") + # Get SP group info from the initialized instance + sp_group = ps.get_sequence_parallel_group() + sp_world_size = ps.get_sequence_parallel_world_size() + sp_rank = ps.get_sequence_parallel_rank() + + uattn_sp = UlyssesSPAttentionHF(attn=core_attn_function, + batch_size=micro_batch_size, + attn_head_count=hf_model_config.num_attention_heads, + attn_head_size=getattr( + hf_model_config, "head_dim", + hf_model_config.hidden_size // hf_model_config.num_attention_heads), + kv_head_count=hf_model_config.num_key_value_heads, + num_hidden_layers=hf_model_config.num_hidden_layers, + process_group=sp_group, + seq_length_is_variable=True, + local_seq_length=None, + global_seq_length=None) + + def uattn_sp_wrapper(module, query, key, value, attention_mask, *args, **kwargs): + return uattn_sp(module, query, key, value, None, *args, **kwargs) + + for key in list(ALL_ATTENTION_FUNCTIONS.keys()): + ALL_ATTENTION_FUNCTIONS[key] = uattn_sp_wrapper + + config_sp = dict(base_config) + config_sp["sequence_parallel_size"] = sequence_parallel_size + config_sp["gradient_accumulation_steps"] = 2 + + model_sp = AutoModelForCausalLM.from_pretrained(model_name_or_path) + + # Pass ps instance as mpu + model_sp, _, _, _ = deepspeed.initialize(config=config_sp, + model=model_sp, + model_parameters=model_sp.parameters(), + mpu=ps) + + ds_sp = torch.utils.data.TensorDataset(input_ids, position_ids) + dl_sp = torch.utils.data.DataLoader(ds_sp, batch_size=micro_batch_size, collate_fn=collate_fn, shuffle=False) + dl_sp_sharded = UlyssesSPDataLoaderAdapter(dl_sp, + sp_rank=sp_rank, + sp_group=sp_group, + sp_world_size=sp_world_size, + device=model_sp.device) + + losses_sp = [] + loss_accum = [] + + for i, batch_sp in enumerate(dl_sp_sharded): + if i >= num_iterations * 2: + break + batch_sp = move_to_device(batch_sp, model_sp.device) + outputs_sp = model_sp(**batch_sp) + shift_labels = batch_sp["shift_labels"] + loss_sp_local = model_sp.module.loss_function(logits=outputs_sp.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=model_sp.module.config.vocab_size) + losses_per_rank = torch.distributed.nn.functional.all_gather(loss_sp_local, group=sp_group) + good_tokens = sum((shift_labels != -100).view(-1)) + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + total_loss = sum(losses_per_rank[r] * good_tokens_per_rank[r] for r in range(sp_world_size)) + loss_sp = total_loss / sum(good_tokens_per_rank) + model_sp.backward(loss_sp) + model_sp.step() + loss_accum.append(loss_sp.detach().cpu()) + if len(loss_accum) == 2: + avg_loss = torch.stack(loss_accum).mean() + losses_sp.append(avg_loss) + loss_accum = [] + + for i in range(num_iterations): + torch_assert_close(losses_no_sp[i], losses_sp[i], rtol=1.6e-02, atol=1e-03)