Skip to content

Commit 45bfe99

Browse files
yunqing authored and cursoragent committed
fix: filter unsupported params in initialize_parallel_state_from_config and add integration tests
Add a supported_params whitelist to prevent unsupported parameters (nccl_communicator_config_path, high_priority_stream_groups) from being passed to initialize_model_parallel. Also add comprehensive integration tests for ParallelState as mpu with 5-batch training loops. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent fa34116 commit 45bfe99

2 files changed

Lines changed: 473 additions & 4 deletions

File tree

deepspeed/utils/parallel_state_deepspeed.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -894,14 +894,24 @@ def get_value(param_name, param_value, config_key, default_value):
894894

895895
# Remove None values for optional parameters (except those that can be None)
896896
# Keep None for: virtual_pipeline_model_parallel_size, pipeline_model_parallel_comm_backend,
897-
# hierarchical_context_parallel_sizes, expert_tensor_parallel_size, nccl_communicator_config_path,
898-
# high_priority_stream_groups
897+
# hierarchical_context_parallel_sizes, expert_tensor_parallel_size
898+
# Note: nccl_communicator_config_path and high_priority_stream_groups are not supported by initialize_model_parallel
899899
filtered_kwargs = {}
900+
supported_params = {
901+
"tensor_model_parallel_size", "pipeline_model_parallel_size", "virtual_pipeline_model_parallel_size",
902+
"pipeline_model_parallel_comm_backend", "context_parallel_size", "sequence_parallel_size",
903+
"hierarchical_context_parallel_sizes", "expert_model_parallel_size", "num_distributed_optimizer_instances",
904+
"expert_tensor_parallel_size", "distributed_timeout_minutes", "order", "create_gloo_process_groups"
905+
}
906+
900907
for key, value in init_kwargs.items():
908+
# Skip unsupported parameters
909+
if key not in supported_params:
910+
continue
911+
# Keep None for parameters that can be None
901912
if value is not None or key in [
902913
"virtual_pipeline_model_parallel_size", "pipeline_model_parallel_comm_backend",
903-
"hierarchical_context_parallel_sizes", "expert_tensor_parallel_size", "nccl_communicator_config_path",
904-
"high_priority_stream_groups"
914+
"hierarchical_context_parallel_sizes", "expert_tensor_parallel_size"
905915
]:
906916
filtered_kwargs[key] = value
907917

0 commit comments

Comments (0)