
Introduce Megatron-style parallel state management#7726

Open
eternalNight wants to merge 24 commits into deepspeedai:master from openanolis:eternalNight/unify_process_group_management

Conversation

@eternalNight
Contributor

@eternalNight eternalNight commented Dec 15, 2025

Summary

This PR, authored by @hahaha3210 and @Daydreamer-Li, introduces ParallelState, a class that manages process groups for an arbitrary combination of parallel strategies including TP, EP, PP and DP.

As discussed in #7680, the primary approach borrows the process group creation logic from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py but encapsulates the state (i.e., process groups, ranks, and world sizes) in a class. This design enables the coexistence of multiple, independent parallelism configurations within a single process, which is particularly valuable in scenarios involving multiple models, such as reinforcement learning (RL) workflows. ParallelState objects can be created prior to calls to deepspeed.initialize so that process groups are available to custom modules, such as UlyssesSPAttentionHF, at an early stage.

Compatibility of ParallelState with the current process group management facilities (including deepspeed.runtime.sequence_parallel.parallel_state_sp and deepspeed.utils.groups) is tested by test_mpu.py.

Usage

Basic usage (single global instance passed to deepspeed.initialize):

    import json

    import deepspeed
    # `ps` refers to the parallel state module introduced by this PR
    with open("config.json") as f:
        config_dict = json.load(f)

    parallel_state = ps.initialize_parallel_state_from_config(config_dict)
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=config_dict,
        mpu=parallel_state,
    )

The ParallelState class is compatible with Megatron's parallel_state module and thus can also be passed as mpu to existing implementations of parallelism techniques.
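To make the Megatron-style layout concrete, here is a small pure-Python sketch of how a global rank decomposes into per-dimension coordinates under an order string such as "tp-dp-pp" (the first dimension varies fastest). The function name `decompose_rank` is hypothetical and only illustrates the layout convention; it is not the PR's actual code.

```python
def decompose_rank(rank, order_sizes):
    """Split a global rank into per-dimension coordinates.

    order_sizes: list of (name, size) pairs, innermost (fastest-varying)
    dimension first, mirroring a Megatron-style order like "tp-dp-pp".
    """
    coords = {}
    for name, size in order_sizes:
        coords[name] = rank % size
        rank //= size
    return coords

# 8 ranks with tp=2, dp=2, pp=2: TP varies fastest across adjacent ranks.
layout = [("tp", 2), ("dp", 2), ("pp", 2)]
print(decompose_rank(5, layout))  # {'tp': 1, 'dp': 0, 'pp': 1}
```

Accessors such as get_tensor_model_parallel_rank then simply return one coordinate of this decomposition for the calling rank.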

Opens

  • Support for Ulysses SP is yet to be added. Dimensions for Ulysses SP can now be set by passing sequence_parallel_size to initialize_parallel_state_from_config.
  • Support creating a ParallelState from a config object rather than specifying each parallel dimension explicitly. initialize_parallel_state_from_config now accepts a config argument that can be either a dict or a config object.
  • Are the wrappers in parallel_state_deepspeed.py necessary? If so, is there a more concise way to implement its APIs, which share similar code patterns? parallel_state_deepspeed.py is preserved and renamed to parallel_state_wrappers.py to provide helper methods (create from config, wrap with context, etc.) around ParallelState for easier use.
  • Are GLOO process groups necessary for DeepSpeed? If not, we can strip them from the draft. GLOO process groups are not created by default because deepspeed.comm.new_group does not accept a backend argument yet. The related code is kept so that we can easily bring it back when necessary.
  • Tweaking NCCL options requires ProcessGroupNCCL.Options from torch.distributed, which deepspeed.comm does not provide today. Should we introduce that to deepspeed.comm, or make the format-checking script allow that specific use of torch.distributed? Left as future work.

Future work

  1. Configuration of process group options (only ProcessGroupNCCL.Options as of PyTorch 2.9).
  2. Add config items for parallelism dimensions not covered in the JSON config today.
  3. Lazy creation of process groups. The Megatron-style manager creates all process groups for all potential combinations of parallelism in one shot. That has little impact on GPU memory usage, as recent NCCL versions allocate GPU memory on the first collective operation rather than on communicator creation. But it still forks a significant number of threads (~2000 on 8 GPUs) and may hit container PID limits. Possible optimizations include (1) not creating pgs for parallelism of size 1 at all, assuming they'll never be used; (2) creating pgs lazily (i.e., on the first getter), assuming all ranks will call the same getters in the same order; (3) merging pgs among the same ranks, at the cost of synchronizing ops on them. This will be looked into in a follow-up PR.
  4. While the unit tests in this PR work properly, there is still code that refers to specific mpu/parallel_state modules explicitly and is thus not friendly to an explicitly-passed mpu.
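As a rough illustration of optimization (2) above, a lazy, memoizing getter might look like the following sketch. Here `new_group_fn` stands in for deepspeed.comm.new_group, the class name is hypothetical, and the cache-key convention assumes all ranks request the same groups in the same order.

```python
class LazyGroups:
    """Create process groups on first access instead of up front."""

    def __init__(self, new_group_fn):
        self._new_group = new_group_fn
        self._cache = {}

    def get(self, ranks):
        key = tuple(ranks)
        if key not in self._cache:  # created once, on the first getter call
            self._cache[key] = self._new_group(list(ranks))
        return self._cache[key]

# With a fake factory we can observe that only requested groups are built.
created = []
groups = LazyGroups(lambda ranks: created.append(ranks) or len(created))
groups.get([0, 1]); groups.get([0, 1]); groups.get([2, 3])
print(created)  # [[0, 1], [2, 3]] -- the duplicate request hit the cache
```

The same-order assumption matters because all members of a group must call the group-creation collective together; a per-rank divergence in getter order would deadlock.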

@sfc-gh-truwase
Collaborator

@stas00 @tohtana @delock FYI

@delock
Collaborator

delock commented Dec 16, 2025

I like the idea of putting the parallel dimensions in a single place rather than relying on users reading deeply into the documentation to figure out how to turn on each parallelism dimension. I also agree with the open that the class can be created from a config object. Does it make sense to have the config in the config.json file, or is a separate config file more flexible?

@eternalNight
Contributor Author

I like the idea of putting the parallel dimensions in a single place rather than relying on users reading deeply into the documentation to figure out how to turn on each parallelism dimension. I also agree with the open that the class can be created from a config object. Does it make sense to have the config in the config.json file, or is a separate config file more flexible?

My rough idea is to reuse the current config.json, which already provides the dimensions of various parallel techniques. Moving parallelism-related configs to a separate file would be such a huge change that it could be a big obstacle for users trying to upgrade.

@eternalNight eternalNight force-pushed the eternalNight/unify_process_group_management branch 2 times, most recently from 383eeb1 to fa34116 Compare January 26, 2026 11:14
@eternalNight eternalNight marked this pull request as ready for review January 26, 2026 11:15
@eternalNight eternalNight self-assigned this Jan 27, 2026
@tohtana
Collaborator

tohtana commented Jan 28, 2026

Hi @eternalNight, this is amazing!
Do you have a plan to add a document and unit tests? They will help users learn how to use this new style of parallelization configs.

@eternalNight
Contributor Author

Hi @eternalNight, this is amazing! Do you have a plan to add a document and unit tests? They will help users learn how to use this new style of parallelization configs.

Good point & will do.

@Daydreamer-Li Daydreamer-Li force-pushed the eternalNight/unify_process_group_management branch from 45bfe99 to 39cd316 Compare February 11, 2026 03:56
@sfc-gh-truwase
Collaborator

@eternalNight, I am curious how this is going? Thanks

@eternalNight
Contributor Author

@eternalNight, I am curious how this is going? Thanks

We're just back from the lunar new year holidays and are polishing the documentation. We will probably push another version this week.

@delock
Collaborator

delock commented Mar 4, 2026

@eternalNight Can you sign-off to fix DCO error? Thanks!

@eternalNight
Contributor Author

@eternalNight Can you sign-off to fix DCO error? Thanks!

All commits are signed on this branch. The DCO errors are due to a mismatch between the author info in the commit metadata (which is less official due to local configs) and the Signed-off-by lines (which use the official names and mail addresses).

言枢 and others added 13 commits March 4, 2026 16:28
Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Add comprehensive config.json support for parallelism configuration with
smart priority handling and context parallel validation.

Key features:
- Support all parallelism dimensions via config.json
- Config priority: config file > function params > defaults
- Conflict detection with warning logs
- Context parallel validation (CP must be 1)
- Backward compatible with existing code

Changes:
- Add 14 optional parameters to initialize_parallel_state_from_config()
- Implement 3-tier priority system with conflict detection
- Add CP validation: raise NotImplementedError if CP > 1
- Update default order from "tp-cp-ep-dp-pp" to "tp-ep-dp-pp"
- Add detailed docstrings and usage examples

This allows users to configure all parallel dimensions in config.json
instead of reading documentation and manually calling initialize_model_parallel.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
- Extend RankGenerator to include SP dimension and enforce TP/PP/EP compatibility

- Initialize sequence parallel and sequence+data parallel process groups in ParallelState.initialize_model_parallel

- Add sequence-parallel accessor stubs in parallel_state_deepspeed for future unified SP interfaces

Signed-off-by: Yuqing Li <lyq491672@alibaba-inc.com>
Remove Chinese inline comment from the example config.json
docstring to comply with DeepSpeed community coding standards.

This ensures all comments and documentation are in English only.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
The deepspeed.comm.new_group() wrapper only accepts 'ranks' parameter,
but _create_group() needs to pass additional parameters like timeout,
backend, pg_options, etc. to support advanced process group configuration.

This fix uses torch.distributed.new_group() directly to support all
parameters while still using deepspeed.comm for other operations.

Fixes TypeError: new_group() got an unexpected keyword argument 'timeout'

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
- Include sequence_parallel_size in model_size calculation
- Fix SP group count: num_sequence_parallel_groups = data_parallel_size
- Use consecutive rank grouping for SP (not RankGenerator)
- SP uses different parallelism model than TP/PP/CP/EP

Signed-off-by: Yuqing Li <lyq491672@alibaba-inc.com>
Updated _create_group() to use deepspeed.comm.new_group() which currently
only supports 'ranks' parameter. Other parameters (timeout, backend,
pg_options, etc.) are commented out and documented in TODO comments.

For non-nccl backends, the function returns None with a warning, as these
are not yet supported by the deepspeed.comm interface.

These parameters will be enabled once DeepSpeed's comm interface is enhanced
to support them.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
Migrate _get_local_all_to_all_group functionality from groups.py to the
new parallel_state architecture to support ZeRO++ quantized gradients.

Changes in parallel_state.py:
- Add all_to_all_groups and all_to_all_initialized to ParallelState class
- Implement initialize_all_to_all_groups() method to create local and global
  All-to-All groups based on node topology
- Implement get_all_to_all_groups() method to retrieve initialized groups

Changes in parallel_state_deepspeed.py:
- Add initialize_all_to_all_groups() wrapper function
- Add get_all_to_all_groups() wrapper function
- Add _get_local_all_to_all_group() for backward compatibility with groups.py

Benefits:
- Supports multi-instance scenarios (e.g., RL with actor/critic models)
- Consistent with the new parallel_state architecture
- Maintains backward compatibility with existing groups.py interface
- Enables future config-based initialization of All-to-All groups

Note: This does not remove the implementation from groups.py yet to maintain
backward compatibility during the transition period.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
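The node-topology grouping described in the commit above can be illustrated with a pure-Python sketch that only computes the rank lists; the actual group creation would go through deepspeed.comm. The helper name is hypothetical.

```python
def local_all_to_all_rank_lists(world_size, gpus_per_node):
    """Return one rank list per node for local All-to-All groups.

    Simplified illustration of grouping ranks by node topology; each
    node's GPUs form one local group for intra-node All-to-All.
    """
    assert world_size % gpus_per_node == 0
    return [list(range(n * gpus_per_node, (n + 1) * gpus_per_node))
            for n in range(world_size // gpus_per_node)]

# 8 ranks spread over 2 nodes with 4 GPUs each:
print(local_all_to_all_rank_lists(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```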
DeepSpeed's comm interface does not support gloo backend,
so set create_gloo_process_groups default to False.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
- Replace manual consecutive rank grouping with RankGenerator.get_ranks('sp')
- Remove redundant world_size validation logic (handled by RankGenerator)
- Reduce SP group creation code from 41 lines to 26 lines
- Maintain same SP group topology: consecutive ranks [0,1], [2,3] for sp_size=2
- Fix code style issues: remove unused import, update warning message

This change unifies process group creation by leveraging RankGenerator's
orthogonal parallelism algorithm, which naturally produces consecutive
rank grouping when order='sp-dp'.

Signed-off-by: Yuqing Li <lyq491672@alibaba-inc.com>
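The orthogonal-parallelism algorithm referred to above can be sketched as follows: all ranks that agree on every coordinate except one dimension form a group for that dimension. This simplified `get_ranks` is a hypothetical illustration, not the PR's RankGenerator, but it reproduces the consecutive [0, 1], [2, 3] SP grouping described in the commit message for order='sp-dp' with sp_size=2.

```python
from itertools import product

def get_ranks(sizes, order, dim):
    """Enumerate the rank groups for dimension `dim`.

    `order` lists dimension names with the first entry varying fastest;
    ranks sharing all coordinates except `dim` form one group.
    """
    strides, stride = {}, 1
    for name in order:
        strides[name] = stride
        stride *= sizes[name]
    others = [n for n in order if n != dim]
    groups = []
    for coords in product(*(range(sizes[n]) for n in others)):
        base = sum(c * strides[n] for n, c in zip(others, coords))
        groups.append([base + i * strides[dim] for i in range(sizes[dim])])
    return groups

# order='sp-dp', sp=2, dp=2: SP groups come out consecutive.
print(get_ranks({"sp": 2, "dp": 2}, ["sp", "dp"], "sp"))  # [[0, 1], [2, 3]]
```

With the same layout, the DP groups are the strided complement [[0, 2], [1, 3]], which is exactly the orthogonality property the commit relies on.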
1. Change create_gloo_process_groups from true to false
   - Aligns with default value change in previous commit
   - DeepSpeed comm interface does not support gloo backend

2. Correct Sequence Parallel usage description
   - SP is included in model_size calculation (tp * pp * cp * sp)
   - SP can be used together with TP/PP/EP
   - Number of SP groups equals data_parallel_size
   - SP uses consecutive rank grouping (not orthogonal like TP/PP/CP/EP)

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
Remove is_torch_min_version function that is never called in the codebase.
This reduces code complexity and removes unnecessary dependencies.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
言枢 and others added 10 commits March 4, 2026 16:28
Remove the 'parallelism' config block concept and directly read from
existing DeepSpeed config fields. This avoids adding new top-level
config structure which requires changes throughout the codebase.

Changes:
- Remove 'parallelism' nested config block from examples
- Read 'sequence_parallel_size' directly from top-level config
- Change priority: function params > config values > defaults (was: config > params > defaults)
- Update create_gloo_process_groups default from True to False
- Simplify documentation to reflect current implementation

This makes the config-based initialization fully backward compatible
without requiring any new config schema validation or parsing logic.

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
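The resolution order stated above (function params > config values > defaults, with a warning on conflicts) can be sketched with a hypothetical helper:

```python
def resolve(param, config_value, default, name="sequence_parallel_size"):
    """Resolve one parallelism size with the stated priority:
    explicit function parameter > config value > built-in default.
    Warns when an explicit parameter conflicts with the config.
    """
    if param is not None:
        if config_value is not None and config_value != param:
            print(f"warning: {name}: parameter {param} overrides "
                  f"config value {config_value}")
        return param
    if config_value is not None:
        return config_value
    return default

print(resolve(None, 4, 1))  # 4  (config value wins over the default)
print(resolve(2, 4, 1))     # 2  (explicit parameter wins, with a warning)
```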
The test_mpu.py script is used to verify the equivalence between the existing process group management facilities and the proposed, unified ParallelState. It is meant to be a temporary helper and will not be useful after we switch the existing implementations to the new interfaces. Thus, remove it from the current PR.

The test is still available at
https://gist.github.com/eternalNight/b76c72216b4be84832b615b76465396f.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
…ig and add integration tests

Add a supported_params whitelist to prevent unsupported parameters
(nccl_communicator_config_path, high_priority_stream_groups) from being
passed to initialize_model_parallel. Also add comprehensive integration
tests for ParallelState as mpu with 5-batch training loops.

Signed-off-by: Yuqing Li <lyq491672@alibaba-inc.com>
…and usability

- Support nested dict config keys via dot-separated paths (e.g.
  "tensor_parallel.autotp_size") so autotp tp_size can be resolved
  from config automatically
- Allow config_key to be a list of candidates tried in order
- Remove unused param_name argument from get_value helper
- Return the ParallelState instance so callers can use it as mpu
  directly, e.g. ps = initialize_parallel_state_from_config(config)

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
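A minimal sketch of the dot-separated lookup described above, with support for a list of candidate keys tried in order (the helper name is hypothetical):

```python
def get_config_value(config, config_key, default=None):
    """Look up a possibly nested config key.

    config_key may be a dot-separated path such as
    "tensor_parallel.autotp_size", or a list of candidate keys
    tried in order; the first hit wins.
    """
    keys = config_key if isinstance(config_key, list) else [config_key]
    for key in keys:
        node = config
        for part in key.split("."):
            if not isinstance(node, dict) or part not in node:
                node = None
                break
            node = node[part]
        if node is not None:
            return node
    return default

cfg = {"tensor_parallel": {"autotp_size": 2}}
print(get_config_value(cfg, "tensor_parallel.autotp_size"))       # 2
print(get_config_value(cfg, ["sequence_parallel_size",
                             "tensor_parallel.autotp_size"], 1))  # 2
```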
- Add "Built-in Parallel State Management" section to training.md
  covering basic usage, config-based initialization, multi-instance
  support for RL scenarios, and backward compatibility
- Add "Parallel State Initialization" section to initialize.rst with
  API references for initialize_parallel_state_from_config and
  ParallelState class

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
Add DeepSpeed mpu compatibility aliases to ParallelState class:
- get_model_parallel_world_size/rank: alias for tensor model parallel
- get_tensor_model_parallel_src_rank: compute first global rank in TP group
- get_data_parallel_group_ranks: expose DP global ranks
- get_sequence_data_parallel_group/world_size/rank: fall back to DP group
  when sequence parallelism is not initialized, fixing the assertion error
  'sequence and data parallel group is not initialized'

Refactor test_parallel_state_deepspeed.py for CI compatibility:
- Reduce world_size from 8 to 4 to match upstream CI hardware
- Pass ParallelState instance directly as mpu instead of using the
  parallel_state_deepspeed module with set_current_parallel_state
- Replace non-standard config keys (sequence_parallel_size, order) with
  standard ones (tensor_parallel.autotp_size)
- Use train_micro_batch_size_per_gpu instead of train_batch_size
- Extract common training loop into _train_steps helper method

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
- parallel_state.py: rename get_sequence_and_data_parallel_{group,
  world_size,rank} to drop redundant 'and', making them consistent
  with groups.py hasattr checks; add get_model_parallel_{world_size,
  rank} backward-compat methods for ZeRO optimizer under SP scenarios

- parallel_state_deepspeed.py: rename corresponding module-level
  functions to match the updated ParallelState method names; update
  docstrings to document groups.py compatible interface contract

Signed-off-by: Yuqing Li <lyq491672@alibaba-inc.com>
Naming a utility module after xxx_deepspeed doesn't make a lot of sense.
Rename the module to parallel_state_wrappers to reflect that it
essentially wraps the ParallelState class for easier creation and
accessing.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
When extracting parallelism sizes from the configuration JSON, only
check officially-defined parameters.

Not all parallelism sizes are configurable via JSON today, but whether
and how such parameters should be added is a separate topic from this PR
which is focused on unifying process group management.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Signed-off-by: Yuqing Li <lyq491672@alibaba-inc.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
@eternalNight eternalNight force-pushed the eternalNight/unify_process_group_management branch from 9e9416a to b1e16c4 Compare March 4, 2026 10:11
@eternalNight
Contributor Author

@sfc-gh-truwase @delock @tohtana @stas00 Here's the final refresh and rebase of the ParallelState class and the helpers around it, with unit tests and documentation on usage. Please kindly review.

Future work is listed in the PR description.
