diff --git a/blogs/huggingface-tp/README.md b/blogs/huggingface-tp/README.md
index 44469f1818c3..ceb95417ae5f 100644
--- a/blogs/huggingface-tp/README.md
+++ b/blogs/huggingface-tp/README.md
@@ -230,13 +230,11 @@ Furthermore, if users are not using transformers library, you can replace the ``
# Ongoing Work
- **Optimization**: Communication/Activation optimization.
-- **Usability**: Support [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple AutoTP parser and more model testing,
-
+- **Usability**: Support the [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple the AutoTP parser, and expand model testing.
+ - [UPDATE] We now support [custom partitioning](https://deepspeed.readthedocs.io/en/latest/training.html#custom-layer-specs) in the same spirit as HF's partitioning plan, and will build Transformers TP plan support on top of that ([PR](https://github.com/deepspeedai/DeepSpeed/pull/7806)).
Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending.
-
Welcome bug reports, enhancement, and additional model training examples.
-
# Contributors
This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft.
diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 10b71eef7a5c..0d53a172e64e 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -48,6 +48,8 @@
from .pipe import PipelineModule
from .git_version_info import version, git_hash, git_branch
+from .runtime.tensor_parallel.init_utils import (load_ds_config, merge_tp_model_init_into_config,
+ record_tp_model_init_args)
def _parse_version(version_str):
@@ -159,17 +161,6 @@ def initialize(args=None,
if config is None and config_params is not None:
config = config_params
- mesh_device = None
- if mesh_param:
- logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
- mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
- #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device
- elif config is not None:
- if "sequence_parallel_size" in config and "data_parallel_size" in config:
- logger.info(f"config to Initialize mesh device: {config}")
- mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
- ("data_parallel", "sequence_parallel"))
-
# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
@@ -184,6 +175,26 @@ def initialize(args=None,
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
+
+ if not isinstance(config, dict):
+ config = load_ds_config(config)
+
+ mesh_device = None
+ if mesh_param:
+ logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
+ mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
+    # If the config provides both sequence_parallel_size and data_parallel_size, use them to initialize the mesh device
+ else:
+ if "sequence_parallel_size" in config and "data_parallel_size" in config:
+ logger.info(f"config to Initialize mesh device: {config}")
+ mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
+ ("data_parallel", "sequence_parallel"))
+
+ merge_tp_model_init_into_config(config, mpu, mesh_param, dist)
+
+ autotp_size = config.get("tensor_parallel", {}).get("autotp_size", 0)
+ if autotp_size and autotp_size > 0:
+ set_autotp_mode(training=True)
if not isinstance(model, PipelineModule):
config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
set_optimizer_flags(config_class, model)
@@ -379,7 +390,49 @@ def init_inference(model, config=None, **kwargs):
def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
"""
- Initialize the model for tensor parallelism.
+ Record tensor-parallel initialization arguments for training.
+
+ Note (compatibility and initialization behavior):
+ AutoTP sharding is applied during ``deepspeed.initialize(...)``. This
+ function exists for backward compatibility and only records TP arguments so
+ they can be validated and merged with the DeepSpeed config at initialization.
+ When you use both (i.e., calling ``set_autotp_mode(training=True)`` and
+ ``deepspeed.tp_model_init`` while also passing the config to
+ ``deepspeed.initialize``), DeepSpeed merges the settings at initialization.
+ Conflicting settings raise an error. The table below summarizes the behavior
+ across input combinations.
+
+ Inputs:
+ - TPI: tp_model_init was called? (Y/N)
+ - TPG: tp_model_init provided tp_group? (Y/N)
+ - CFG: tensor_parallel in DeepSpeed config? (Y/N)
+ - MPU: mpu passed to deepspeed.initialize()? (Y/N)
+
+ | TPI | TPG | CFG | MPU | Outcome | Notes |
+ |-----|-----|-----|-----|----------------------------------------|-------|
+ | N | N | N | N | Error | No TP intent; nothing to initialize |
+ | N | N | N | Y | No AutoTP | mpu may be used for other MP, but TP not enabled |
+ | N | N | Y | N | Init AutoTP from config | Use config; need TP group via config-driven init |
+ | N | N | Y | Y | Init AutoTP from config | mpu used to build TP group |
+ | Y | N | N | N | Error | No TP group source |
+ | Y | N | N | Y | Init AutoTP from tp_model_init | Use recorded args + mpu for TP group |
+ | Y | N | Y | N | Init AutoTP from config | Fill missing from TPI; error on mismatches; need TP group source |
+ | Y | N | Y | Y | Init AutoTP from config | Fill missing from TPI; error on mismatches |
+ | Y | Y | N | N | Init AutoTP from tp_model_init | Use recorded tp_group; config absent |
+ | Y | Y | N | Y | Error | tp_group + mpu conflict |
+ | Y | Y | Y | N | Init AutoTP from config | Error on mismatches; use tp_group from TPI; reject mpu |
+ | Y | Y | Y | Y | Error | tp_group + mpu conflict |
+
+ Field-level merge rules when both tp_model_init and config exist:
+ - Canonical source: config
+ - Allowed: fill missing config fields from tp_model_init
+ - Error on mismatch: autotp_size, dtype, tp_group size or identity
+
+ Extra checks:
+ - If tp_group is provided, reject mpu.
+ - If tp_group is not provided, require mpu (or another TP group source).
+ - If tensor_parallel is absent and only tp_model_init was called, require
+ a TP group source (direct tp_group or mpu).
Args:
model (torch.nn.Module): The model to be initialized.
@@ -387,23 +440,15 @@ def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
dtype (torch.dtype): The data type to be used for the model.
Returns:
- torch.nn.Module: The initialized model with tensor parallelism.
+ torch.nn.Module: The original model (no sharding applied here).
"""
- # avoid re-entry
if hasattr(model, 'ds_autotp_parsed'):
- logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
- return
-
- set_autotp_mode(training=True)
+ logger.warning("'ds_autotp_parsed' attribute already exists in the model; tp_model_init is now record-only.")
- from deepspeed.runtime.tensor_parallel import TpTrainingManager
- # The expected usage here is for it to be invoked by transformers package.
+ tp_group = kwargs.get("tp_group")
+ record_tp_model_init_args(tp_size=tp_size, dtype=dtype, tp_group=tp_group, dist_module=dist)
- #TODO: We should provide a custom TP mapping solution without using autoTP
- #as modifying the autoTP logic may be more difficult for users compared to configuring it
-
- model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module
-
- setattr(model, 'ds_autotp_parsed', True)
+ # Keep AutoTP training mode active for backward compatibility.
+ set_autotp_mode(training=True)
return model
diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py
index 9fc2f979a04b..2299ef6e7c3a 100755
--- a/deepspeed/module_inject/__init__.py
+++ b/deepspeed/module_inject/__init__.py
@@ -6,5 +6,6 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
-from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
+from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode, SubParamLinearLayer, SubParamLinearAllreduce
from .policy import DSPolicy
+from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType, AutoTPPresets, merge_autotp_configs
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 82cd9042071e..121e3938444a 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -17,6 +17,7 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode
+from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType
def move(tensor, device, copy=True):
@@ -199,7 +200,8 @@ def __init__(self,
state_dict,
linear_layer_setting,
orig_layer_impl,
- keep_module_on_host=False):
+ keep_module_on_host=False,
+ partition_config: Optional[AutoTPConfig] = None):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
@@ -211,6 +213,7 @@ def __init__(self,
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
+ self.partition_config = partition_config
TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)
def in_module_list(module, module_list):
@@ -353,6 +356,11 @@ def _replace(self, child, name, conv_linear_layer):
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+
+ # If partition_config is provided, use the new configurable API
+ if self.partition_config is not None:
+ return self._replace_with_config(child, name)
+
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
@@ -396,6 +404,87 @@ def _replace(self, child, name, conv_linear_layer):
return LinearLayer(child, self.mp_group, name=name)
+ def _replace_with_config(self, child, name):
+ """
+ Replace layer using the new configurable AutoTP API.
+
+ This method uses TPLayerSpec to determine how to partition the layer.
+ """
+ if getattr(child, "replaced", False) == True:
+ return child
+
+ # Build the full parameter name for pattern matching
+ param_name = name + ".weight" if not name.endswith(".weight") else name
+
+ # Find matching spec
+ model_type = self._get_model_type()
+ spec = self.partition_config.find_matching_spec(param_name, model_type)
+
+ if spec is None:
+ # No matching spec found
+ if self.partition_config.strict_mode:
+ raise ValueError(f"No matching spec for {param_name}")
+ # Default: column parallel for Linear layers
+ spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN)
+
+ setattr(child, "replaced", True)
+
+ if spec.partition_type == PartitionType.SKIP:
+ return child
+
+ if spec.partition_type == PartitionType.ROW:
+ return self._create_row_parallel_layer(child, spec, name)
+ else:
+ return self._create_column_parallel_layer(child, spec, name)
+
+ def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
+ """Create row-parallel layer (AllReduce after forward)."""
+ # Check for lm_head / embed_out
+ if name == "lm_head" or name == 'embed_out':
+ return LmHeadLinearAllreduce(module, self.mp_group)
+
+ if spec.shape is not None:
+ return SubParamLinearAllreduce(
+ module,
+ self.mp_group,
+ shape=spec.shape,
+ partition_dim=spec.get_partition_dim(),
+ name=name,
+ )
+ return LinearAllreduce(module, self.mp_group, name=name)
+
+ def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
+ """Create column-parallel layer (AllReduce in backward)."""
+ if spec.shape is not None:
+ return SubParamLinearLayer(
+ module,
+ self.mp_group,
+ shape=spec.shape,
+ partition_dim=spec.get_partition_dim(),
+ name=name,
+ )
+ return LinearLayer(module, self.mp_group, name=name)
+
+ def _get_model_type(self) -> Optional[str]:
+ """Extract model type from module config or class name."""
+ config = getattr(self.module, "config", None)
+ if config is not None:
+ model_type = getattr(config, "model_type", None)
+ if model_type:
+ return str(model_type).lower()
+ module_str = str(type(self.module))
+ # Try to extract model type from class name (e.g., "LlamaDecoderLayer" -> "llama")
+ patterns = [
+ r"(\w+)DecoderLayer",
+ r"(\w+)Block",
+ r"(\w+)Layer",
+ ]
+ for pattern in patterns:
+ match = re.search(pattern, module_str)
+ if match:
+ return match.group(1).lower()
+ return None
+
def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
diff --git a/deepspeed/module_inject/autotp_config.py b/deepspeed/module_inject/autotp_config.py
new file mode 100644
index 000000000000..4bafea806829
--- /dev/null
+++ b/deepspeed/module_inject/autotp_config.py
@@ -0,0 +1,569 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+Configurable AutoTP API
+
+This module provides a unified specification for tensor parallel layer partitioning.
+The design is inspired by Universal Checkpointing's SubparamShape and provides
+a single, well-defined format that users can easily understand, customize, and extend.
+"""
+
+import re
+from dataclasses import dataclass, field
+from typing import List, Tuple, Union, Optional
+from enum import Enum
+from deepspeed.utils.logging import warning_once
+
+
+class PartitionType(Enum):
+ """How the layer should be partitioned for tensor parallelism."""
+ COLUMN = "column" # Partition output dim, AllReduce in backward
+ ROW = "row" # Partition input dim, AllReduce in forward
+ SKIP = "skip" # Do not partition this layer
+
+
+@dataclass
+class TPLayerSpec:
+ """
+ Unified specification for tensor parallel layer partitioning.
+
+ This is inspired by Universal Checkpointing's SubparamShape but extended
+ for AutoTP's needs (forward/backward communication patterns).
+
+ The `shape` parameter supports at most 1-level nesting at the partition dimension:
+ - (3, -1) -> 3 equal-size sub-params
+ - ((q, k, v), -1) -> 3 unequal-size sub-params (1-level nesting)
+
+ Examples:
+ # Simple row-parallel layer (e.g., o_proj, down_proj)
+ TPLayerSpec(
+ patterns=[".*\\.o_proj$", ".*\\.down_proj$"],
+ partition_type=PartitionType.ROW,
+ )
+
+ # Simple column-parallel layer (e.g., q_proj, k_proj, v_proj)
+ TPLayerSpec(
+ patterns=[".*\\.[qkv]_proj$"],
+ partition_type=PartitionType.COLUMN,
+ )
+
+ # Fused QKV - GLM style [Q, K, V] concatenated on dim 0
+ TPLayerSpec(
+ patterns=[".*\\.query_key_value\\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(3, -1), # 3 equal sub-params, -1 = infer
+ partition_dim=0,
+ )
+
+ # Fused QKV - Bloom style [q1,k1,v1,q2,k2,v2,...]
+ TPLayerSpec(
+ patterns=[".*\\.query_key_value\\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ # No reshape needed, just split along dim 0
+ )
+
+ # GQA with different Q/K/V sizes (1-level nesting)
+ TPLayerSpec(
+ patterns=[".*\\.qkv_proj\\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=((q_size, k_size, v_size), -1), # Unequal sub-params
+ partition_dim=0,
+ )
+
+ # Chunked MLP (gate_up_proj)
+ TPLayerSpec(
+ patterns=[".*\\.gate_up_proj\\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(2, -1), # [gate, up] packed
+ partition_dim=0,
+ )
+
+ # MoE FFN with expert dimension
+ TPLayerSpec(
+ patterns=[".*\\.experts\\..*\\.w1\\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(num_experts, -1, hidden_in), # View as 3D
+ partition_dim=1, # Partition the hidden_out dimension
+ )
+
+ # Skip layer (e.g., MoE gate)
+ TPLayerSpec(
+ patterns=[".*\\.gate$", ".*\\.router$"],
+ partition_type=PartitionType.SKIP,
+ )
+ """
+
+ # Layer identification - regex patterns to match parameter names
+ patterns: List[str]
+
+ # Partition type determines communication pattern
+ partition_type: PartitionType = PartitionType.COLUMN
+
+ # Optional: logical shape for partitioning
+ # - Use -1 for dimensions that should be inferred
+ # - Use tuple of ints at partition_dim for unequal sub-params (1-level nesting only)
+ # Examples:
+ # (3, -1) -> 3 equal sub-params
+ # ((4096, 1024, 1024), -1) -> 3 unequal sub-params (GQA)
+ # (n_experts, -1, hidden) -> MoE reshape
+ shape: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None
+
+ # Which dimension to partition (after optional reshape)
+ # Default: 0 for COLUMN, 1 for ROW (standard 2D weight matrix)
+ partition_dim: Optional[int] = None
+
+ # Optional: model type constraint (only apply for specific models)
+ model_types: Optional[List[str]] = None
+
+ def __post_init__(self):
+ if isinstance(self.partition_type, str):
+ self.partition_type = PartitionType(self.partition_type.lower())
+ if self.shape is not None:
+ self.shape = self._normalize_shape(self.shape)
+ self._validate_shape_format()
+
+ @staticmethod
+ def _normalize_shape(shape):
+ if isinstance(shape, list):
+ return tuple(TPLayerSpec._normalize_shape(item) for item in shape)
+ if isinstance(shape, tuple):
+ return tuple(TPLayerSpec._normalize_shape(item) if isinstance(item, list) else item for item in shape)
+ return shape
+
+ def _validate_shape_format(self):
+ if not isinstance(self.shape, tuple):
+ raise ValueError("AutoTP shape must be a tuple of ints or a tuple at partition_dim.")
+ partition_dim = self.get_partition_dim()
+ if partition_dim < 0 or partition_dim >= len(self.shape):
+ raise ValueError(
+ f"AutoTP partition_dim {partition_dim} is out of range for shape length {len(self.shape)}.")
+ nested_tuple_seen = False
+ for idx, dim in enumerate(self.shape):
+ if isinstance(dim, tuple):
+ if idx != partition_dim:
+ raise ValueError(
+ f"AutoTP shape nested tuple only allowed at partition_dim={partition_dim}, got at {idx}.")
+ if nested_tuple_seen:
+ raise ValueError("AutoTP shape supports only 1-level nesting at partition_dim.")
+ nested_tuple_seen = True
+ if len(dim) == 0:
+ raise ValueError("AutoTP shape nested tuple cannot be empty.")
+ for val in dim:
+ if isinstance(val, tuple):
+ raise ValueError("AutoTP shape supports only 1-level nesting at partition_dim.")
+ if not isinstance(val, int) or val <= 0:
+ raise ValueError("AutoTP nested sub-parameter sizes must be positive integers.")
+ elif isinstance(dim, int):
+ if dim == 0 or dim < -1:
+ raise ValueError("AutoTP shape dimensions must be positive integers or -1.")
+ else:
+ raise ValueError("AutoTP shape must contain only integers or a tuple at partition_dim.")
+
+ def get_partition_dim(self) -> int:
+ """Get effective partition dimension."""
+ if self.partition_dim is not None:
+ return self.partition_dim
+ # Default based on partition type for 2D weight matrices
+ return 0 if self.partition_type == PartitionType.COLUMN else 1
+
+ def has_unequal_sub_params(self) -> bool:
+ """Check if this spec has unequal sub-parameters (nested tuple at partition_dim)."""
+ if self.shape is None:
+ return False
+ dim = self.get_partition_dim()
+ if dim >= len(self.shape):
+ return False
+ return isinstance(self.shape[dim], tuple)
+
+ def get_sub_param_sizes(self) -> Optional[Tuple[int, ...]]:
+ """Get sub-parameter sizes if using unequal sub-params."""
+ if not self.has_unequal_sub_params():
+ return None
+ return self.shape[self.get_partition_dim()]
+
+ def get_num_sub_params(self) -> Optional[int]:
+ """Get the number of sub-parameters."""
+ if self.shape is None:
+ return None
+ dim = self.get_partition_dim()
+ if dim >= len(self.shape):
+ return None
+ if isinstance(self.shape[dim], tuple):
+ return len(self.shape[dim])
+ elif isinstance(self.shape[dim], int) and self.shape[dim] > 0:
+ return self.shape[dim]
+ return None
+
+ def matches(self, param_name: str, model_type: Optional[str] = None) -> bool:
+ """Check if this spec matches the given parameter."""
+ # Check model type constraint
+ if self.model_types:
+ if model_type is None:
+ return False
+ model_type_norm = str(model_type).lower()
+ model_types_norm = [str(mt).lower() for mt in self.model_types]
+ if model_type_norm not in model_types_norm:
+ return False
+ # Check pattern match
+ return any(re.match(pattern, param_name) for pattern in self.patterns)
+
+
+@dataclass
+class AutoTPConfig:
+ """
+ Configuration for Automatic Tensor Parallelism.
+
+ Example usage:
+ config = AutoTPConfig(
+ tp_size=4,
+ layer_specs=[
+ # Row-parallel layers (AllReduce after forward)
+ TPLayerSpec(
+ patterns=[".*\\.o_proj", ".*\\.down_proj"],
+ partition_type=PartitionType.ROW,
+ ),
+ # Column-parallel layers
+ TPLayerSpec(
+ patterns=[".*\\.[qkv]_proj", ".*\\.up_proj", ".*\\.gate_proj"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ # Skip MoE gates
+ TPLayerSpec(
+ patterns=[".*\\.gate$"],
+ partition_type=PartitionType.SKIP,
+ ),
+ ],
+ )
+ """
+
+ tp_size: int = 1
+
+ # Unified layer specifications
+ layer_specs: List[TPLayerSpec] = field(default_factory=list)
+
+ # Embedding configuration
+ embedding_partition_dim: int = 1 # Usually partition vocab dim
+
+ # LM head configuration
+ lm_head_patterns: List[str] = field(default_factory=lambda: ["lm_head", "embed_out"])
+
+ # Behavior flags
+ use_default_specs: bool = True # Merge with built-in specs
+ strict_mode: bool = False # Fail if unmatched Linear layers found
+
+ def find_matching_spec(self, param_name: str, model_type: Optional[str] = None) -> Optional[TPLayerSpec]:
+ """Find the first matching spec for a parameter."""
+ matches = [spec for spec in self.layer_specs if spec.matches(param_name, model_type)]
+ if not matches:
+ return None
+ if len(matches) > 1:
+ matched_patterns = [spec.patterns for spec in matches]
+ warning_once(f"AutoTPConfig: parameter {param_name} matched multiple layer_specs {matched_patterns}; "
+ "using the first match.")
+ return matches[0]
+
+ @classmethod
+ def from_dict(cls, config_dict: dict) -> "AutoTPConfig":
+ """Create config from dictionary (JSON config)."""
+ layer_specs = []
+ for spec_dict in config_dict.get("layer_specs", []):
+ # Convert partition_type string to enum
+ partition_type_str = spec_dict.get("partition_type", "column")
+ if isinstance(partition_type_str, str):
+ partition_type = PartitionType(partition_type_str.lower())
+ else:
+ partition_type = partition_type_str
+
+ # Convert shape from list to tuple if necessary
+ shape = spec_dict.get("shape")
+ if shape is not None:
+ shape = cls._convert_shape(shape)
+
+ layer_specs.append(
+ TPLayerSpec(
+ patterns=spec_dict.get("patterns", []),
+ partition_type=partition_type,
+ shape=shape,
+ partition_dim=spec_dict.get("partition_dim"),
+ model_types=spec_dict.get("model_types"),
+ ))
+
+ return cls(
+ tp_size=config_dict.get("tp_size", 1),
+ layer_specs=layer_specs,
+ embedding_partition_dim=config_dict.get("embedding_partition_dim", 1),
+ lm_head_patterns=config_dict.get("lm_head_patterns", ["lm_head", "embed_out"]),
+ use_default_specs=config_dict.get("use_default_specs", True),
+ strict_mode=config_dict.get("strict_mode", False),
+ )
+
+ @staticmethod
+ def _convert_shape(shape):
+ """Convert shape from list to tuple, handling nested structures."""
+ if isinstance(shape, list):
+ return tuple(AutoTPConfig._convert_shape(item) if isinstance(item, list) else item for item in shape)
+ return shape
+
+
+class AutoTPPresets:
+ """Built-in presets for common model architectures."""
+
+ @staticmethod
+ def llama() -> AutoTPConfig:
+ """LLaMA-style models (separate Q, K, V projections)."""
+ return AutoTPConfig(layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.o_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.down_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ ], )
+
+ @staticmethod
+ def llama_gqa(num_heads: int, num_kv_heads: int, head_dim: int) -> AutoTPConfig:
+ """LLaMA with Grouped Query Attention (fused QKV variant)."""
+ q_size = num_heads * head_dim
+ kv_size = num_kv_heads * head_dim
+ return AutoTPConfig(
+ layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.o_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ # Fused QKV with unequal sizes (GQA)
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.qkv_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=((q_size, kv_size, kv_size), -1), # 1-level nesting
+ partition_dim=0,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.down_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ ], )
+
+ @staticmethod
+ def bloom() -> AutoTPConfig:
+ """BLOOM-style models (fused QKV with interleaved heads)."""
+ return AutoTPConfig(
+ layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attention\.dense\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.self_attention\.query_key_value\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ # Bloom style: [q1,k1,v1,q2,k2,v2,...] - no reshape needed
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.dense_4h_to_h\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.dense_h_to_4h\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ ], )
+
+ @staticmethod
+ def chatglm() -> AutoTPConfig:
+ """ChatGLM-style models (GLM-style fused QKV)."""
+ return AutoTPConfig(
+ layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attention\.dense\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.self_attention\.query_key_value\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(3, -1), # [Q, K, V] concatenated
+ partition_dim=0,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.dense_4h_to_h\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.dense_h_to_4h\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(2, -1), # [gate, up] packed
+ partition_dim=0,
+ ),
+ ], )
+
+ @staticmethod
+ def mixtral() -> AutoTPConfig:
+ """Mixtral MoE model."""
+ return AutoTPConfig(
+ layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.o_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ # MoE experts
+ TPLayerSpec(
+ patterns=[r".*\.block_sparse_moe\.experts\.\d+\.w2\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.block_sparse_moe\.experts\.\d+\.w[13]\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ # Skip MoE gate
+ TPLayerSpec(
+ patterns=[r".*\.block_sparse_moe\.gate\.weight$"],
+ partition_type=PartitionType.SKIP,
+ ),
+ ], )
+
+ @staticmethod
+ def deepseek_v2() -> AutoTPConfig:
+ """DeepSeek-V2 with MLA (Multi-head Latent Attention)."""
+ return AutoTPConfig(
+ layer_specs=[
+ # Standard attention output
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.o_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ # MLA uses compressed KV, skip low-rank projections
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.(q_a_proj|kv_a_proj_with_mqa)\.weight$"],
+ partition_type=PartitionType.SKIP,
+ ),
+ # Q/K/V projections from latent
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.(q_b_proj|kv_b_proj)\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ # MoE experts
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.experts\.\d+\.down_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.experts\.\d+\.(up|gate)_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ # Skip MoE gate
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.gate\.weight$"],
+ partition_type=PartitionType.SKIP,
+ ),
+ # Shared expert
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.shared_experts\.down_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.shared_experts\.(up|gate)_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ ], )
+
+ @staticmethod
+ def qwen2() -> AutoTPConfig:
+ """Qwen2 model."""
+ return AutoTPConfig(layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.o_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.down_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ ),
+ ], )
+
+ @staticmethod
+ def phi3() -> AutoTPConfig:
+ """Phi3 model with fused QKV and chunked MLP."""
+ return AutoTPConfig(
+ layer_specs=[
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.o_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ # Phi3 has fused qkv_proj
+ TPLayerSpec(
+ patterns=[r".*\.self_attn\.qkv_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(3, -1), # [Q, K, V] concatenated
+ partition_dim=0,
+ ),
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.down_proj\.weight$"],
+ partition_type=PartitionType.ROW,
+ ),
+ # Phi3 has gate_up_proj fused
+ TPLayerSpec(
+ patterns=[r".*\.mlp\.gate_up_proj\.weight$"],
+ partition_type=PartitionType.COLUMN,
+ shape=(2, -1), # [gate, up] packed
+ partition_dim=0,
+ ),
+ ], )
+
+ @staticmethod
+ def get_preset(model_type: str) -> Optional[AutoTPConfig]:
+ """Get a preset configuration by model type name."""
+ presets = {
+ "llama": AutoTPPresets.llama,
+ "bloom": AutoTPPresets.bloom,
+ "chatglm": AutoTPPresets.chatglm,
+ "mixtral": AutoTPPresets.mixtral,
+ "deepseek_v2": AutoTPPresets.deepseek_v2,
+ "qwen2": AutoTPPresets.qwen2,
+ "phi3": AutoTPPresets.phi3,
+ }
+ preset_fn = presets.get(model_type.lower())
+ if preset_fn:
+ return preset_fn()
+ return None
+
+
+def merge_autotp_configs(base: AutoTPConfig, override: AutoTPConfig) -> AutoTPConfig:
+ """Merge two AutoTP configs, with override taking precedence."""
+ # Combine layer specs - override specs come first (higher priority)
+ merged_specs = list(override.layer_specs) + list(base.layer_specs)
+
+ return AutoTPConfig(
+ tp_size=override.tp_size if override.tp_size > 1 else base.tp_size,
+ layer_specs=merged_specs,
+ embedding_partition_dim=override.embedding_partition_dim,
+ lm_head_patterns=override.lm_head_patterns or base.lm_head_patterns,
+ use_default_specs=override.use_default_specs,
+ strict_mode=override.strict_mode,
+ )
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index a9faac71361a..5f55d2e78e26 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -20,7 +20,8 @@
__all__ = [
"TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
- "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
+ "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer",
+ "SubParamLinearLayer", "SubParamLinearAllreduce"
]
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
@@ -801,6 +802,356 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
return super().forward(positions + self.offset)
+def _shape_prod(values):
+ result = 1
+ for val in values:
+ result *= val
+ return result
+
+
+def _normalize_shape_spec(shape):
+ if isinstance(shape, list):
+ return tuple(_normalize_shape_spec(item) for item in shape)
+ if isinstance(shape, tuple):
+ return tuple(_normalize_shape_spec(item) if isinstance(item, list) else item for item in shape)
+ return shape
+
+
+def _infer_subparam_logical_shapes(weight_shape, shape, partition_dim, name=None):
+ shape = _normalize_shape_spec(shape)
+ if not isinstance(shape, tuple):
+ raise ValueError("AutoTP shape must be a tuple for sub-parameter partitioning.")
+ if partition_dim < 0 or partition_dim >= len(shape):
+ raise ValueError(f"AutoTP partition_dim {partition_dim} is out of range for shape length {len(shape)}.")
+
+ layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer"
+ partition_elem = shape[partition_dim]
+ subparam_sizes = None
+ num_subparams = None
+
+ if isinstance(partition_elem, tuple):
+ if len(partition_elem) == 0:
+ raise ValueError(f"{layer_label} sub-parameter size tuple cannot be empty.")
+ if any(isinstance(val, tuple) for val in partition_elem):
+ raise ValueError(f"{layer_label} supports only 1-level nesting at partition_dim.")
+ if any((not isinstance(val, int)) or val <= 0 for val in partition_elem):
+ raise ValueError(f"{layer_label} sub-parameter sizes must be positive integers.")
+ subparam_sizes = tuple(int(val) for val in partition_elem)
+ partition_dim_size = sum(subparam_sizes)
+ elif isinstance(partition_elem, int):
+ if partition_elem == -1:
+ partition_dim_size = None
+ elif partition_elem > 0:
+ num_subparams = partition_elem
+ partition_dim_size = None
+ else:
+ raise ValueError(f"{layer_label} partition_dim spec must be positive integer or -1.")
+ else:
+ raise ValueError(f"{layer_label} partition_dim spec must be int or tuple.")
+
+ logical_dims = []
+ for idx, dim in enumerate(shape):
+ if idx == partition_dim:
+ logical_dims.append(partition_dim_size)
+ continue
+ if isinstance(dim, tuple):
+ raise ValueError(f"{layer_label} nested tuple only allowed at partition_dim={partition_dim}.")
+ if isinstance(dim, int):
+ if dim == -1:
+ logical_dims.append(None)
+ elif dim > 0:
+ logical_dims.append(dim)
+ else:
+ raise ValueError(f"{layer_label} shape dimensions must be positive integers or -1.")
+ else:
+ raise ValueError(f"{layer_label} shape dimensions must be integers.")
+
+ total_numel = _shape_prod(weight_shape)
+ known_product = _shape_prod([dim for dim in logical_dims if dim is not None])
+ unknown_indices = [idx for idx, dim in enumerate(logical_dims) if dim is None]
+
+ if len(unknown_indices) == 0:
+ if known_product != total_numel:
+ raise ValueError(f"{layer_label} shape product {known_product} != weight numel {total_numel}.")
+ elif len(unknown_indices) == 1:
+ inferred = total_numel // known_product
+ if inferred * known_product != total_numel:
+ raise ValueError(f"{layer_label} cannot infer shape for weight with numel {total_numel}.")
+ logical_dims[unknown_indices[0]] = inferred
+ else:
+ if len(shape) == len(weight_shape):
+ for idx in unknown_indices:
+ logical_dims[idx] = weight_shape[idx]
+ if _shape_prod(logical_dims) != total_numel:
+ raise ValueError(
+ f"{layer_label} shape product {_shape_prod(logical_dims)} != weight numel {total_numel}.")
+ else:
+ raise ValueError(f"{layer_label} shape has multiple inferred dims and is ambiguous for weight.")
+
+ logical_shape = tuple(logical_dims)
+ if logical_shape[-1] != weight_shape[-1]:
+ raise ValueError(
+ f"{layer_label} shape last dim {logical_shape[-1]} must match weight input dim {weight_shape[-1]}.")
+
+ output_shape = logical_shape[:-1]
+ if len(output_shape) == 0:
+ raise ValueError(f"{layer_label} shape must include at least one output dimension.")
+ if _shape_prod(output_shape) != weight_shape[0]:
+ raise ValueError(
+ f"{layer_label} output shape product {_shape_prod(output_shape)} != weight output dim {weight_shape[0]}.")
+
+ partition_dim_size = logical_shape[partition_dim]
+ if partition_dim_size is None or partition_dim_size <= 0:
+ raise ValueError(f"{layer_label} partition_dim size must be a positive integer.")
+
+ if num_subparams is not None:
+ if partition_dim_size % num_subparams != 0:
+ raise ValueError(
+ f"{layer_label} partition_dim size {partition_dim_size} not divisible by sub-param count {num_subparams}."
+ )
+ subparam_sizes = tuple([partition_dim_size // num_subparams] * num_subparams)
+
+ if subparam_sizes is not None and sum(subparam_sizes) != partition_dim_size:
+ raise ValueError(
+ f"{layer_label} sub-parameter sizes sum {sum(subparam_sizes)} != partition_dim size {partition_dim_size}.")
+
+ bias_partition_dim = partition_dim if partition_dim < len(output_shape) else None
+ return logical_shape, output_shape, subparam_sizes, bias_partition_dim
+
+
+def _partition_logical_tensor(tensor, partition_dim, tp_world_size, tp_index, name=None, subparam_sizes=None):
+ if tp_world_size == 1:
+ return tensor
+ layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer"
+ if subparam_sizes:
+ for size in subparam_sizes:
+ if size % tp_world_size != 0:
+ raise ValueError(f"{layer_label} sub-parameter size {size} not divisible by tp_size {tp_world_size}.")
+ sub_params = torch.split(tensor, subparam_sizes, dim=partition_dim)
+ partitioned_sub_params = [torch.chunk(sp, tp_world_size, dim=partition_dim)[tp_index] for sp in sub_params]
+ return torch.cat(partitioned_sub_params, dim=partition_dim)
+ if tensor.shape[partition_dim] % tp_world_size != 0:
+ raise ValueError(
+ f"{layer_label} partition_dim size {tensor.shape[partition_dim]} not divisible by tp_size {tp_world_size}."
+ )
+ return torch.chunk(tensor, tp_world_size, dim=partition_dim)[tp_index]
+
+
+def _all_gather_along_dim(tensor, partition_dim, mp_group, tp_world_size):
+ if mp_group is None or tp_world_size == 1:
+ return tensor
+ perm = [partition_dim] + [idx for idx in range(tensor.dim()) if idx != partition_dim]
+ inv_perm = [0] * len(perm)
+ for idx, dim in enumerate(perm):
+ inv_perm[dim] = idx
+ tensor_perm = tensor.permute(perm).contiguous()
+ output = torch.empty((tp_world_size * tensor_perm.shape[0], *tensor_perm.shape[1:]),
+ dtype=tensor.dtype,
+ device=tensor.device)
+ dist.all_gather_into_tensor(output, tensor_perm, group=mp_group)
+ return output.permute(inv_perm).contiguous()
+
+
+def _gather_logical_tensor(tensor,
+ logical_shape,
+ partition_dim,
+ mp_group,
+ tp_world_size,
+ name=None,
+ subparam_sizes=None):
+ if mp_group is None or tp_world_size == 1:
+ return tensor.reshape(logical_shape)
+ layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer"
+ if logical_shape[partition_dim] % tp_world_size != 0:
+ raise ValueError(
+ f"{layer_label} partition_dim size {logical_shape[partition_dim]} not divisible by tp_size {tp_world_size}."
+ )
+ partitioned_shape = list(logical_shape)
+ partitioned_shape[partition_dim] = logical_shape[partition_dim] // tp_world_size
+ tensor_view = tensor.reshape(partitioned_shape)
+
+ if subparam_sizes:
+ for size in subparam_sizes:
+ if size % tp_world_size != 0:
+ raise ValueError(f"{layer_label} sub-parameter size {size} not divisible by tp_size {tp_world_size}.")
+ partitioned_sizes = [size // tp_world_size for size in subparam_sizes]
+ sub_params = torch.split(tensor_view, partitioned_sizes, dim=partition_dim)
+ gathered_sub_params = [_all_gather_along_dim(sp, partition_dim, mp_group, tp_world_size) for sp in sub_params]
+ return torch.cat(gathered_sub_params, dim=partition_dim)
+ return _all_gather_along_dim(tensor_view, partition_dim, mp_group, tp_world_size)
+
+
+class SubParamLinearLayer(TensorParallel_Layer):
+ """
+ Column-parallel linear layer with sub-parameter support.
+
+ Handles cases where weights contain multiple logical sub-parameters
+ that need to be partitioned separately (e.g., fused QKV, chunked MLP, GQA).
+
+ The `shape` parameter controls how the weight is viewed and partitioned:
+ - (3, -1) with partition_dim=0: 3 equal sub-params, partition each at dim 0
+ - ((q, k, v), -1) with partition_dim=0: 3 unequal sub-params (1-level nesting)
+ """
+
+ def __init__(self, module, mp_group, shape, partition_dim=0, **kwargs):
+ super(SubParamLinearLayer, self).__init__(mp_group, **kwargs)
+ self.weight = module.weight
+ self.bias = module.bias
+ self.shape = shape
+ self.partition_dim = partition_dim
+
+ self._orig_weight_shape = tuple(module.weight.shape)
+ self._orig_bias_shape = tuple(module.bias.shape) if self.bias is not None else None
+ (self._logical_shape, self._output_shape, self._subparam_sizes,
+ self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape,
+ self.partition_dim, self.name)
+ if self.bias is not None and self.bias.numel() != _shape_prod(self._output_shape):
+ raise ValueError(f"AutoTP layer '{self.name}' bias size {self.bias.numel()} does not match output shape "
+ f"{self._output_shape}.")
+
+ self._tp_partition([self.weight, self.bias])
+ self.support_training = True
+ self.config_tp_params(self.weight)
+ if self.bias is not None:
+ self.config_tp_params(self.bias)
+
+ def forward(self, input):
+ if getattr(self, 'mp_group', None) is not None:
+ input = ColumnParallel.apply(self.mp_group, input)
+ output = torch.matmul(input, self.weight.transpose(-1, -2))
+ if self.bias is not None:
+ output = add_bias(output, self.bias)
+ return output
+
+ @torch.no_grad()
+ def gather_params(self, params_list):
+ """Gather partitioned parameters back to full size."""
+ for idx, param in enumerate(params_list):
+ if param is None:
+ continue
+ params_list[idx].data_partition = param.data
+ if idx == 0:
+ full_view = _gather_logical_tensor(param,
+ self._logical_shape,
+ self.partition_dim,
+ self.mp_group,
+ self.tp_world_size,
+ name=self.name,
+ subparam_sizes=self._subparam_sizes)
+ params_list[idx].data = full_view.reshape(self._orig_weight_shape)
+ else:
+ if self._bias_partition_dim is None:
+ params_list[idx].data = param.data
+ else:
+ full_bias_view = _gather_logical_tensor(param,
+ self._output_shape,
+ self._bias_partition_dim,
+ self.mp_group,
+ self.tp_world_size,
+ name=self.name,
+ subparam_sizes=self._subparam_sizes)
+ params_list[idx].data = full_bias_view.reshape(self._orig_bias_shape)
+
+ @torch.no_grad()
+ def _tp_partition(self, params_list):
+ weight = params_list[0]
+ if weight is None:
+ return
+
+ weight_view = weight.reshape(self._logical_shape)
+ partitioned_view = _partition_logical_tensor(weight_view,
+ self.partition_dim,
+ self.tp_world_size,
+ self.tp_index,
+ name=self.name,
+ subparam_sizes=self._subparam_sizes)
+ params_list[0].data = self.move(partitioned_view.reshape(-1, partitioned_view.shape[-1])).detach()
+
+ if params_list[1] is not None:
+ if self._bias_partition_dim is None:
+ params_list[1].data = self.move(params_list[1]).detach()
+ else:
+ bias_view = params_list[1].reshape(self._output_shape)
+ bias_partitioned = _partition_logical_tensor(bias_view,
+ self._bias_partition_dim,
+ self.tp_world_size,
+ self.tp_index,
+ name=self.name,
+ subparam_sizes=self._subparam_sizes)
+ params_list[1].data = self.move(bias_partitioned.reshape(-1)).detach()
+
+
+class SubParamLinearAllreduce(TensorParallel_Layer):
+ """
+ Row-parallel linear layer with sub-parameter support (AllReduce after forward).
+
+ Handles cases where weights contain multiple logical sub-parameters
+ that need to be partitioned separately.
+ """
+
+ def __init__(self, module, mp_group, shape, partition_dim=1, **kwargs):
+ super(SubParamLinearAllreduce, self).__init__(mp_group, **kwargs)
+ self.weight = module.weight
+ self.bias = module.bias
+ self.shape = shape
+ self.partition_dim = partition_dim
+
+ self._orig_weight_shape = tuple(module.weight.shape)
+ self._orig_bias_shape = tuple(module.bias.shape) if self.bias is not None else None
+ (self._logical_shape, self._output_shape, self._subparam_sizes,
+ self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape,
+ self.partition_dim, self.name)
+
+ self._tp_partition([self.weight, self.bias])
+ self.support_training = True
+ self.config_tp_params(self.weight)
+ if self.bias is not None:
+ self.config_requires_grad(self.bias)
+
+ def forward(self, input):
+ output = torch.matmul(input, self.weight.transpose(-1, -2))
+ output = RowParallel.apply(self.mp_group, output, not self.is_training_mode())
+ if self.bias is not None:
+ output = add_bias(output, self.bias)
+ return output
+
+ @torch.no_grad()
+ def gather_params(self, params_list):
+ """Gather partitioned parameters back to full size."""
+ for idx, param in enumerate(params_list):
+ if param is None or idx > 0:
+ # only the weight (idx 0) is gathered; bias is not partitioned for row parallel
+ return
+ params_list[idx].data_partition = param.data
+ full_view = _gather_logical_tensor(param,
+ self._logical_shape,
+ self.partition_dim,
+ self.mp_group,
+ self.tp_world_size,
+ name=self.name,
+ subparam_sizes=self._subparam_sizes)
+ params_list[idx].data = full_view.reshape(self._orig_weight_shape)
+
+ @torch.no_grad()
+ def _tp_partition(self, params_list):
+ weight = params_list[0]
+ if weight is None:
+ return
+
+ weight_view = weight.reshape(self._logical_shape)
+ partitioned_view = _partition_logical_tensor(weight_view,
+ self.partition_dim,
+ self.tp_world_size,
+ self.tp_index,
+ name=self.name,
+ subparam_sizes=self._subparam_sizes)
+ params_list[0].data = self.move(partitioned_view.reshape(-1, partitioned_view.shape[-1])).detach()
+
+ # Bias is not partitioned for row parallel (it's applied after all-reduce)
+ if params_list[1] is not None:
+ params_list[1].data = self.move(params_list[1]).detach()
+
+
class RMSNormalize(nn.Module):
def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 26752cfa4fec..263369fc0484 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -273,9 +273,20 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False,
def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
#mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)
+ # Get the configurable partition_config if available
+ partition_config = None
+ if hasattr(config, 'get_partition_config_object'):
+ partition_config = config.get_partition_config_object()
+
# 1. Create AutoTP object
- _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
- config.keep_module_on_host)
+ _autotp = AutoTP(module,
+ all_reduce_linears,
+ prefix,
+ state_dict,
+ linear_layer_setting,
+ orig_layer_impl,
+ config.keep_module_on_host,
+ partition_config=partition_config)
# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index ded262edcf61..f1dbaae43ec9 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -4,7 +4,12 @@
# DeepSpeed Team
from deepspeed import comm as dist
-global num_kv_heads
+
+# Defaults for optional TP globals. These can be overridden by setters.
+num_kv_heads = None
+num_attention_heads = None
+n_embd = None
+tp_grain_size = 1
def set_num_kv_heads(num):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5f8585d65d76..410d1ad6c46c 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -499,6 +499,7 @@ def _optimized_linear_offload_setup(self):
def _configure_tensor_parallel(self, model, tp_config):
self._configure_tensor_parallel_states(model)
configure_tensor_parallel_runtime(tp_config)
+ self._apply_autotp_partitioning(model, tp_config)
def _configure_tensor_parallel_states(self, model):
"""
@@ -564,6 +565,55 @@ def broadcast_and_check(args, bcast_rank, bcast_group):
prepend=True,
with_kwargs=True)
+ def _apply_autotp_partitioning(self, model, tp_config):
+ if getattr(model, "ds_autotp_parsed", False):
+ return
+ if get_accelerator().is_available() and self.local_rank >= 0:
+ get_accelerator().set_device(self.local_rank)
+
+ tp_size = self.autotp_size()
+ if tp_config.tensor_parallel.tp_size not in (1, tp_size):
+ raise ValueError(f"tensor_parallel.tp.tp_size ({tp_config.tensor_parallel.tp_size}) "
+ f"does not match tensor_parallel.autotp_size ({tp_size}).")
+ tp_config.tensor_parallel.tp_size = tp_size
+ if tp_config.tensor_parallel.tp_group is None:
+ tp_config.tensor_parallel.tp_group = groups.get_tensor_model_parallel_group()
+
+ from deepspeed.module_inject.auto_tp import AutoTP
+
+ partition_config = None
+ if hasattr(tp_config, "get_partition_config_object"):
+ partition_config = tp_config.get_partition_config_object()
+
+ if partition_config is not None:
+ autotp = AutoTP(module=model,
+ all_reduce_linears=(),
+ prefix="",
+ state_dict=None,
+ linear_layer_setting=(torch.nn.Linear, torch.nn.Embedding),
+ orig_layer_impl=None,
+ keep_module_on_host=tp_config.keep_module_on_host,
+ partition_config=partition_config)
+ autotp.set_tensor_parallel_config(tp_size, tp_config.tensor_parallel.tp_group)
+ autotp.update_linear_policies()
+ autotp._replace_module(model)
+ setattr(model, "ds_autotp_parsed", True)
+ return
+
+ if tp_size <= 1:
+ setattr(model, "ds_autotp_parsed", True)
+ return
+
+ model_config = getattr(model, "config", None)
+ from deepspeed.module_inject import replace_transformer_layer
+
+ parser_dict = AutoTP.tp_parser(model)
+ for client_module, injection_policy in parser_dict:
+ tp_config.injection_policy_tuple = injection_policy
+ replace_transformer_layer(client_module, model, None, tp_config, model_config)
+
+ setattr(model, "ds_autotp_parsed", True)
+
def __del__(self):
try:
self.destroy()
diff --git a/deepspeed/runtime/tensor_parallel/config.py b/deepspeed/runtime/tensor_parallel/config.py
index 957984e9f8b3..326277f6eb5c 100644
--- a/deepspeed/runtime/tensor_parallel/config.py
+++ b/deepspeed/runtime/tensor_parallel/config.py
@@ -7,7 +7,7 @@
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
import torch
from pydantic import Field
-from typing import Optional
+from typing import Optional, Dict, Any
class AUTOTP_MODE(Enum):
@@ -57,6 +57,41 @@ class TPTrainingConfig(DeepSpeedConfigModel):
"""
injection_policy_tuple: Optional[tuple] = None
+
+ # New configurable AutoTP settings
+ partition_config: Optional[Dict[str, Any]] = None
+ """
+ Configuration for the new configurable AutoTP API.
+ Allows users to specify custom layer partitioning rules via TPLayerSpec.
+
+ Example:
+ "partition_config": {
+ "use_default_specs": false,
+ "layer_specs": [
+ {
+ "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
+ "partition_type": "row"
+ },
+ {
+ "patterns": [".*\\.[qkv]_proj\\.weight$"],
+ "partition_type": "column"
+ },
+ {
+ "patterns": [".*\\.gate_up_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [2, -1],
+ "partition_dim": 0
+ }
+ ]
+ }
+ """
+
+ preset_model: Optional[str] = None
+ """
+ Use a built-in preset for common model architectures.
+ Available presets: "llama", "bloom", "chatglm", "mixtral", "deepseek_v2", "qwen2", "phi3"
+ """
+
#The following parameters are required by autoTP parser.
########################################
keep_module_on_host: bool = False
@@ -74,8 +109,35 @@ class TPTrainingConfig(DeepSpeedConfigModel):
linear layers as a tuple:
`(attention_output projection, transformer output projection)`
"""
+
########################################
+ def get_partition_config_object(self):
+ """
+ Get the AutoTPConfig object from the configuration.
+ Returns None if no custom config is specified.
+ """
+ from deepspeed.module_inject.autotp_config import AutoTPConfig, AutoTPPresets, merge_autotp_configs
+
+ config = None
+
+ # First check for preset
+ if self.preset_model:
+ config = AutoTPPresets.get_preset(self.preset_model)
+
+ # Then check for custom config
+ if self.partition_config:
+ custom_config = AutoTPConfig.from_dict(self.partition_config)
+ if config and custom_config.use_default_specs:
+ config = merge_autotp_configs(config, custom_config)
+ else:
+ config = custom_config
+
+ if config:
+ config.tp_size = self.autotp_size
+
+ return config
+
def get_tensor_parallel_config(ds_config):
diff --git a/deepspeed/runtime/tensor_parallel/init_utils.py b/deepspeed/runtime/tensor_parallel/init_utils.py
new file mode 100644
index 000000000000..95dacb984cbe
--- /dev/null
+++ b/deepspeed/runtime/tensor_parallel/init_utils.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import base64
+import os
+from typing import Optional, Union
+
+import hjson
+import torch
+
+from deepspeed.runtime.config_utils import dict_raise_error_on_duplicate_keys
+
+_TP_MODEL_INIT_ARGS = None
+
+
+def load_ds_config(config: Union[str, dict]) -> dict:
+ if isinstance(config, dict):
+ return config
+ if isinstance(config, str):
+ if os.path.exists(config):
+ return hjson.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
+ try:
+ config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
+ return hjson.loads(config_decoded)
+ except (UnicodeDecodeError, AttributeError, ValueError) as exc:
+ raise ValueError(
+ f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. "
+ f"Received: {config}") from exc
+ raise ValueError(f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. "
+ f"Received: {config}")
+
+
+def record_tp_model_init_args(tp_size, dtype, tp_group, dist_module):
+ global _TP_MODEL_INIT_ARGS
+ new_args = {
+ "tp_size": tp_size,
+ "dtype": dtype,
+ "tp_group": tp_group,
+ }
+
+ if _TP_MODEL_INIT_ARGS is None:
+ _TP_MODEL_INIT_ARGS = new_args
+ return
+
+ if _TP_MODEL_INIT_ARGS["tp_size"] != tp_size or _TP_MODEL_INIT_ARGS["dtype"] != dtype:
+ raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.")
+
+ existing_group = _TP_MODEL_INIT_ARGS.get("tp_group")
+ if existing_group is None and tp_group is None:
+ return
+ if (existing_group is None) != (tp_group is None):
+ raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.")
+
+ existing_group_size = tp_group_world_size(existing_group, dist_module)
+ new_group_size = tp_group_world_size(tp_group, dist_module)
+ if existing_group_size != new_group_size:
+ raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.")
+
+
+def tp_group_world_size(tp_group, dist_module):
+ if tp_group is None or dist_module is None:
+ return None
+ return dist_module.get_world_size(group=tp_group)
+
+
+def infer_config_dtype(config_dict: dict) -> Optional[torch.dtype]:
+ bf16_config = config_dict.get("bf16", {})
+ if isinstance(bf16_config, dict) and bf16_config.get("enabled", False):
+ return torch.bfloat16
+ fp16_config = config_dict.get("fp16", {})
+ if isinstance(fp16_config, dict) and fp16_config.get("enabled", False):
+ return torch.float16
+ return None
+
+
+def merge_tp_model_init_into_config(config_dict: dict, mpu, mesh_param, dist_module):
+ if _TP_MODEL_INIT_ARGS is None:
+ return
+
+ tp_size = _TP_MODEL_INIT_ARGS["tp_size"]
+ dtype = _TP_MODEL_INIT_ARGS["dtype"]
+ tp_group = _TP_MODEL_INIT_ARGS["tp_group"]
+
+ if tp_group is not None and mpu is not None:
+ raise ValueError("tp_model_init provided tp_group; deepspeed.initialize must not receive mpu.")
+ if tp_group is None and mpu is None and mesh_param is None:
+ raise ValueError("tp_model_init did not provide tp_group; deepspeed.initialize requires mpu or mesh_param.")
+
+ tp_section = config_dict.get("tensor_parallel")
+ if tp_section is None:
+ tp_section = {}
+ config_dict["tensor_parallel"] = tp_section
+
+ config_autotp_size = tp_section.get("autotp_size")
+ if config_autotp_size is not None and config_autotp_size != tp_size:
+ raise ValueError(
+ f"Conflicting tensor_parallel.autotp_size in config ({config_autotp_size}) and tp_model_init ({tp_size}).")
+
+ if config_autotp_size is None:
+ tp_section["autotp_size"] = tp_size
+
+ tp_config = tp_section.get("tp") or {}
+ if not isinstance(tp_config, dict):
+ raise ValueError("tensor_parallel.tp must be a dict when provided.")
+
+ config_tp_size = tp_config.get("tp_size")
+ if config_tp_size is not None and config_tp_size != tp_size:
+ raise ValueError(
+ f"Conflicting tensor_parallel.tp.tp_size in config ({config_tp_size}) and tp_model_init ({tp_size}).")
+ if config_tp_size is None:
+ tp_config["tp_size"] = tp_size
+
+ if tp_group is not None:
+ config_tp_group = tp_config.get("tp_group")
+ if config_tp_group is not None and config_tp_group is not tp_group:
+ raise ValueError("Conflicting tensor_parallel.tp.tp_group in config and tp_model_init.")
+ tp_config["tp_group"] = tp_group
+
+ tp_group_size = tp_group_world_size(tp_group, dist_module)
+ if tp_group_size is not None and tp_group_size != tp_size:
+ raise ValueError(f"tp_model_init tp_size ({tp_size}) does not match tp_group size ({tp_group_size}).")
+
+ tp_section["tp"] = tp_config
+
+ config_dtype = infer_config_dtype(config_dict)
+ if config_dtype is not None and config_dtype != dtype:
+ raise ValueError(f"Conflicting dtype: config uses {config_dtype} but tp_model_init requested {dtype}.")
+
+ tp_dtype = tp_section.get("dtype")
+ if tp_dtype is not None:
+ if isinstance(tp_dtype, str):
+ tp_dtype_map = {
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32,
+ }
+ tp_dtype_value = tp_dtype_map.get(tp_dtype.lower())
+ else:
+ tp_dtype_value = tp_dtype
+ if tp_dtype_value is not None and tp_dtype_value != dtype:
+ raise ValueError(f"Conflicting tensor_parallel.dtype in config ({tp_dtype}) and tp_model_init ({dtype}).")
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 5ac33f3e4447..d5344d3b2320 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -730,6 +730,96 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
| -------------------------------------------------------------------------------------------------------------- | ------- |
| Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. | `true` |
+### Tensor Parallel (AutoTP)
+Configure AutoTP tensor parallelism for training through the DeepSpeed config, enabling hybrid TP + ZeRO. AutoTP supports ZeRO stages 0, 1, and 2 (stage 3 is not supported). `deepspeed.tp_model_init()` remains supported for backward compatibility but is not required when `tensor_parallel` is set in the config.
+```json
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "preset_model": "llama",
+ "tp_overlap_comm": false,
+ "partition_config": {
+ "use_default_specs": false,
+ "layer_specs": [
+ {
+ "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
+ "partition_type": "row"
+ }
+ ]
+ }
+ }
+```
+**tensor_parallel**: [dictionary]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------ | ------- |
+| Enable AutoTP tensor parallelism and configure preset or custom partitioning rules. | `{}` |
+
+***autotp_size***: [integer]
+
+| Description | Default |
+| --------------------------------------------------------------------------- | ------- |
+| Tensor-parallel degree. Set to `0` to disable AutoTP. | `0` |
+
+***preset_model***: [string]
+
+| Description | Default |
+| ----------------------------------------------------------------------------------------------------- | ------- |
+| Built-in model presets: `llama`, `bloom`, `chatglm`, `mixtral`, `deepseek_v2`, `qwen2`, `phi3`. | `null` |
+
+***tp_overlap_comm***: [boolean]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------- | ------- |
+| Overlap tensor-parallel allreduce communication with computation (training only). | `false` |
+
+***partition_config***: [dictionary]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------------------------------------------- | ------- |
+| Custom AutoTP layer partitioning rules. Use with or without `preset_model` to customize sharding patterns. | `null` |
+
+***use_default_specs***: [boolean]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------------------- | ------- |
+| Merge custom `layer_specs` with preset defaults when `preset_model` is set; otherwise use only custom specs. | `true` |
+
+***layer_specs***: [list]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------------------- | ------- |
+| Ordered list of pattern rules that define how to partition matching parameters. | `[]` |
+
+***patterns***: [list of strings]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------------------- | ------- |
+| Regex patterns to match parameter names for this partition rule. | `[]` |
+
+***partition_type***: [string]
+
+| Description | Default |
+| ---------------------------------------------------------------------------- | ------- |
+| Partition type for matching parameters: `row`, `column`, or `skip`. | `column` |
+
+***shape***: [list]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------------------- | ------- |
+| Optional sub-parameter shape for fused weights before TP partitioning (e.g., `[2, -1]`). | `null` |
+
+***partition_dim***: [integer]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------------------- | ------- |
+| Dimension to split when `shape` is provided (e.g., `0` for fused QKV or gate/up). | `null` |
+
+***model_types***: [list of strings]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------------------- | ------- |
+| Optional model type filters (from `model.config.model_type`) for shared configs. | `null` |
+
***ignore_unused_parameters***: [boolean]
| Description | Default |
diff --git a/docs/_tutorials/automatic-tensor-parallelism.md b/docs/_tutorials/automatic-tensor-parallelism.md
index f18b823e2490..1433df080ee3 100755
--- a/docs/_tutorials/automatic-tensor-parallelism.md
+++ b/docs/_tutorials/automatic-tensor-parallelism.md
@@ -3,6 +3,8 @@ title: "Automatic Tensor Parallelism for HuggingFace Models"
tags: inference
---
+> **Note:** This tutorial covers AutoTP for **inference**. For **training** with tensor parallelism and ZeRO optimization, see [AutoTP Training API](autotp-training).
+
# Contents
* [Introduction](#introduction)
* [Example Script](#example-script)
diff --git a/docs/_tutorials/autotp-training.md b/docs/_tutorials/autotp-training.md
new file mode 100644
index 000000000000..9eb51d3e0900
--- /dev/null
+++ b/docs/_tutorials/autotp-training.md
@@ -0,0 +1,188 @@
+---
+title: "AutoTP Training API"
+tags: training tensor-parallelism
+---
+
+# AutoTP Training API
+
+This tutorial covers the **AutoTP Training API** for combining tensor parallelism with ZeRO optimization during training. For inference-only tensor parallelism, see [Automatic Tensor Parallelism for HuggingFace Models](automatic-tensor-parallelism).
+
+## Contents
+- [Introduction](#introduction)
+- [Quick Start](#quick-start)
+- [Custom Layer Specifications](#custom-layer-specifications)
+- [Limitations](#limitations)
+
+## Introduction
+
+The AutoTP Training API enables hybrid parallelism by combining:
+- **Tensor Parallelism (TP)**: Split model weights across GPUs within a node
+- **Data Parallelism (DP)**: Replicate model across GPU groups
+- **ZeRO Optimization**: Memory-efficient optimizer states (Stage 0, 1, or 2)
+
+Tensor parallelism (TP) splits the computations and parameters of large layers
+across multiple GPUs so each rank holds only a shard of the weight matrix. This
+is an efficient way to train large-scale transformer models by reducing per-GPU
+memory pressure while keeping the layer math distributed across the TP group.
+
+
+## Quick Start
+
+### Basic Usage
+
+AutoTP training can be enabled entirely through the DeepSpeed config. When
+`tensor_parallel` is set in the config, `deepspeed.initialize(...)` applies
+AutoTP sharding during engine initialization, so the training loop itself does
+not change.
+
+```python
+import deepspeed
+from transformers import AutoModelForCausalLM
+
+# 1. Create your model
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
+
+# 2. Define the DeepSpeed config with tensor_parallel settings
+ds_config = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {"stage": 2},
+ "bf16": {"enabled": True},
+ "tensor_parallel": {"autotp_size": 4},
+}
+
+# 3. Initialize DeepSpeed with AutoTP + ZeRO
+engine, optimizer, _, _ = deepspeed.initialize(
+    model=model,
+    optimizer=optimizer,  # your torch optimizer, e.g. torch.optim.AdamW(model.parameters())
+    config=ds_config,
+    mpu=mpu,  # model parallel unit (optional if you provide tp_group elsewhere)
+)
+
+# 4. Train as usual
+for batch in dataloader:
+ outputs = engine(input_ids=batch["input_ids"], labels=batch["labels"])
+ engine.backward(outputs.loss)
+ engine.step()
+```
+
+Compatibility note: For backward compatibility, you can still call
+`set_autotp_mode(training=True)` and `deepspeed.tp_model_init(...)`, but they
+are not required when the DeepSpeed config provides the necessary
+`tensor_parallel` settings.
+
+### Preset-based Sharding
+
+If your model matches a built-in preset, set `tensor_parallel.preset_model` in the DeepSpeed config:
+
+```json
+{
+ "train_batch_size": 8,
+ "train_micro_batch_size_per_gpu": 1,
+ "bf16": { "enabled": true },
+ "zero_optimization": { "stage": 2 },
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "preset_model": "llama"
+ }
+}
+```
+
+For the list of available presets, see [supported models](../code-docs/training#autotp-supported-models).
+
+
+
+## Custom Patterns
+
+If you are training a custom model, define regex-based patterns and partition rules in `tensor_parallel.partition_config`:
+
+```json
+{
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "partition_config": {
+ "use_default_specs": false,
+ "layer_specs": [
+ {
+ "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
+ "partition_type": "row"
+ },
+ {
+ "patterns": [".*\\.[qkv]_proj\\.weight$"],
+ "partition_type": "column"
+ },
+ {
+ "patterns": [".*\\.gate_up_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [2, -1],
+ "partition_dim": 0
+ }
+ ]
+ }
+ }
+}
+```
+
+## Custom Layer Specifications
+
+For models not covered by presets, define custom layer specs:
+
+```json
+{
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "partition_config": {
+ "use_default_specs": false,
+ "layer_specs": [
+ {
+ "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
+ "partition_type": "row"
+ },
+ {
+ "patterns": [".*\\.[qkv]_proj\\.weight$"],
+ "partition_type": "column"
+ },
+ {
+ "patterns": [".*\\.gate_up_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [2, -1],
+ "partition_dim": 0
+ }
+ ]
+ }
+ }
+}
+```
+
+### Fused Layers with Unequal Sub-parameters (GQA)
+
+For Grouped Query Attention with different Q/K/V sizes (replace `q_size` and `kv_size` below with your model's actual dimensions):
+
+```json
+{
+ "tensor_parallel": {
+ "partition_config": {
+ "layer_specs": [
+ {
+ "patterns": [".*\\.qkv_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [[q_size, kv_size, kv_size], -1],
+ "partition_dim": 0
+ }
+ ]
+ }
+ }
+}
+```
+
+## Limitations
+
+1. **ZeRO Stage 3 not supported**: AutoTP currently only works with ZeRO stages 0, 1, and 2.
+
+2. **TP size must divide model dimensions**: The tensor parallel size must evenly divide the attention head count and hidden dimensions.
+
+
+## See Also
+
+- [Automatic Tensor Parallelism for Inference](automatic-tensor-parallelism)
+- [ZeRO Optimization](zero)
+- [DeepSpeed Configuration](config-json)
diff --git a/docs/code-docs/source/_static/autotp-subparams-gate-up.png b/docs/code-docs/source/_static/autotp-subparams-gate-up.png
new file mode 100644
index 000000000000..116cedf92667
Binary files /dev/null and b/docs/code-docs/source/_static/autotp-subparams-gate-up.png differ
diff --git a/docs/code-docs/source/_static/autotp-subparams-gqa.png b/docs/code-docs/source/_static/autotp-subparams-gqa.png
new file mode 100644
index 000000000000..15a1045a1dbf
Binary files /dev/null and b/docs/code-docs/source/_static/autotp-subparams-gqa.png differ
diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst
index 1974db70c226..9b01b2f85315 100644
--- a/docs/code-docs/source/training.rst
+++ b/docs/code-docs/source/training.rst
@@ -257,3 +257,222 @@ Besides the use of multiple DeepSpeedEngines, the above differs from typical usa
You can call ``loss.backward()`` once for the shared loss.
**Note:** Previously, you had to call ``_backward_epilogue`` on each model engine after ``loss.backward()``. However, starting from v0.18.3, DeepSpeed automatically handles this internally, so you no longer need to call ``_backward_epilogue`` manually.
+
+
+Automatic Tensor Parallel Training
+----------------------------------
+DeepSpeed supports **Automatic Tensor Parallel (AutoTP) training** for sharding
+model weights across GPUs while remaining compatible with ZeRO and standard
+training workflows. This training API is different from the inference-only
+tensor parallel API exposed by ``deepspeed.init_inference``.
+
+Tensor parallelism (TP) splits the computations and parameters of large layers
+across multiple GPUs so each rank holds only a shard of the weight matrix. This
+is an efficient way to train large-scale transformer models by reducing per-GPU
+memory pressure while keeping the layer math distributed across the TP group.
+
+AutoTP training is enabled by setting ``tensor_parallel`` in the DeepSpeed
+config and passing it to ``deepspeed.initialize``. DeepSpeed applies AutoTP
+sharding during engine initialization; calling ``deepspeed.tp_model_init``, which we previously used to initialize AutoTP, is now optional.
+See :ref:`autotp-training-init-details` for more details.
+
+.. code-block:: python
+
+ import deepspeed
+
+ ds_config = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {"stage": 2},
+ "tensor_parallel": {"autotp_size": 4},
+ }
+
+ engine, optimizer, _, _ = deepspeed.initialize(
+ model=model,
+ optimizer=optimizer,
+ config=ds_config,
+ mpu=mpu, # optional: TP/DP process groups
+ )
+
+.. note::
+ AutoTP training supports ZeRO stages 0, 1, and 2. ZeRO Stage 3 is not supported.
+
+.. _autotp-training-init-details:
+
+Initialization behavior
+~~~~~~~~~~~~~~~~~~~~~~~
+
+AutoTP previously required calling ``set_autotp_mode(training=True)`` and ``deepspeed.tp_model_init`` before ``deepspeed.initialize``. Now we can include all the necessary configurations in the DeepSpeed config.
+We still support the traditional initialization path for backward compatibility.
+When you use both (i.e. calling ``set_autotp_mode(training=True)`` and ``deepspeed.tp_model_init`` while also passing the config to ``deepspeed.initialize``), the settings are merged at initialization; if they conflict, initialization raises an error.
+
+Parameter partitioning
+~~~~~~~~~~~~~~~~~~~~~~
+TP sharding needs to know which parameter tensors should be partitioned and
+along which dimensions. AutoTP provides three ways to balance ready-to-use
+defaults with customizability:
+
+* **Heuristics**: automatic sharding based on parameter names and model rules.
+* **Preset**: choose a built-in model family via ``preset_model``.
+* **Custom specs**: define regex patterns and partition rules via ``partition_config``.
+
+Heuristic rules
+^^^^^^^^^^^^^^^
+Heuristics use parameter names and model-specific rules to decide how to shard
+layers. If you are training a supported model (see
+:ref:`autotp-supported-models`), the heuristic rules automatically shard the
+model, so you only need to add ``autotp_size``.
+
+.. code-block:: json
+
+ {
+ ...
+ "tensor_parallel": {
+ "autotp_size": 4
+ },
+ "zero_optimization": {
+ ...
+ },
+ ...
+ }
+
+Preset-based partitioning
+^^^^^^^^^^^^^^^^^^^^^^^^^
+You can explicitly specify the model family with ``preset_model``:
+
+.. code-block:: json
+
+ {
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "preset_model": "llama"
+ }
+ }
+
+See :ref:`autotp-supported-models` for the supported preset names and the
+implementation in the ``AutoTPPresets`` class.
+If you add a new model family, you can easily add a new preset by defining
+patterns like the existing presets, and we welcome PRs for those additions.
+
+Custom layer specs
+^^^^^^^^^^^^^^^^^^
+If you are training a custom model, you can use ``partition_config`` to specify
+custom regex-based patterns and partition settings.
+
+.. code-block:: json
+
+ {
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "partition_config": {
+ "use_default_specs": false,
+ "layer_specs": [
+ {
+ "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
+ "partition_type": "row"
+ },
+ {
+ "patterns": [".*\\.[qkv]_proj\\.weight$"],
+ "partition_type": "column"
+ },
+ {
+ "patterns": [".*\\.gate_up_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [2, -1],
+ "partition_dim": 0
+ }
+ ]
+ }
+ }
+ }
+
+You can also set ``use_default_specs`` to ``true`` to merge your custom
+patterns on top of the preset (when ``preset_model`` is provided).
+
+For fused or packed weights (for example QKV or gate/up projections), the
+``shape`` and ``partition_dim`` options control sub-parameter partitioning.
+Sub-parameter partitioning lets AutoTP split a single weight tensor into
+logical chunks before applying tensor-parallel sharding. For example, the
+``gate_up_proj`` weight can be viewed as two packed matrices (gate and up) by
+setting ``shape`` to ``[2, -1]`` and ``partition_dim`` to ``0``; AutoTP then
+partitions each chunk consistently across tensor-parallel ranks.
+
+.. image:: /_static/autotp-subparams-gate-up.png
+ :alt: AutoTP sub-parameter partitioning
+
+Another example is GQA-style fused QKV weights. The tensor can contain unequal
+Q/K/V segments stacked along the output dimension. For example, set ``shape``
+to the explicit sizes (for example ``[[q_size, kv_size, kv_size], -1]``) and
+``partition_dim`` to ``0`` so AutoTP splits the Q, K, and V regions first, then
+shards each region across tensor-parallel ranks.
+
+.. code-block:: json
+
+ {
+ "patterns": [".*\\.qkv_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [[q_size, kv_size, kv_size], -1],
+ "partition_dim": 0
+ }
+
+.. image:: /_static/autotp-subparams-gqa.png
+   :alt: AutoTP GQA fused-QKV sub-parameter partitioning
+
+
+Model-type filtering for shared configs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Use ``model_types`` when you want a single config to work across multiple model
+families but apply different specs. This is useful in shared training scripts
+or when patterns overlap across architectures.
+
+.. code-block:: json
+
+ {
+ "tensor_parallel": {
+ "autotp_size": 4,
+ "partition_config": {
+ "layer_specs": [
+ {
+ "patterns": [".*\\.qkv_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [[q_size, kv_size, kv_size], -1],
+ "partition_dim": 0,
+ "model_types": ["llama"]
+ },
+ {
+ "patterns": [".*\\.qkv_proj\\.weight$"],
+ "partition_type": "column",
+ "shape": [3, -1],
+ "partition_dim": 0,
+ "model_types": ["qwen2"]
+ }
+ ]
+ }
+ }
+ }
+
+
+.. _autotp-supported-models:
+
+Supported models
+~~~~~~~~~~~~~~~~
+The following model families are supported by built-in AutoTP presets:
+
+- ``llama``
+- ``bloom``
+- ``chatglm``
+- ``mixtral``
+- ``deepseek_v2``
+- ``qwen2``
+- ``phi3``
+
+Preset definitions live in the ``AutoTPPresets`` class.
+If you add a new model family, you can easily add a new preset by defining
+patterns like the existing presets, and we welcome PRs for those additions.
+
+These strings are the values accepted by ``preset_model`` and are matched
+against the model type in ``model.config.model_type`` (case-insensitive). When
+``preset_model`` is not set, AutoTP uses the legacy automatic sharding rules
+unless you provide a custom ``partition_config``.
+These presets are also useful when you want to extend the default patterns:
+set ``use_default_specs`` to ``true`` in ``partition_config`` to merge your custom
+specs on top of the selected preset.
diff --git a/tests/unit/model_parallelism/test_autotp_custom_patterns.py b/tests/unit/model_parallelism/test_autotp_custom_patterns.py
new file mode 100644
index 000000000000..d19d8b595eeb
--- /dev/null
+++ b/tests/unit/model_parallelism/test_autotp_custom_patterns.py
@@ -0,0 +1,302 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed.comm as dist
+import deepspeed
+from copy import deepcopy
+from torch import nn
+
+from unit.common import DistributedTest, preferred_dtype
+from deepspeed.accelerator import get_accelerator
+from deepspeed.utils import groups
+from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer)
+from deepspeed.module_inject.autotp_config import AutoTPConfig
+from deepspeed.module_inject.auto_tp import AutoTP
+
+
+def skip_on_device():
+ if get_accelerator().device_name() == 'xpu':
+ pytest.skip("XPU requires a higher version for test")
+
+
+class SequentialLinearModel(torch.nn.Module):
+
+ def __init__(self, hidden_dim, nlayers=1):
+ super(SequentialLinearModel, self).__init__()
+ self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
+
+ def forward(self, x):
+ for layer in self.linears:
+ x = layer(x)
+ return x
+
+
+def init_tp_engine(tp_size, partition_config=None):
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-6
+ }
+ },
+ "tensor_parallel": {
+ "autotp_size": tp_size,
+ },
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ if partition_config is not None:
+ config_dict["tensor_parallel"]["partition_config"] = partition_config
+ else:
+ config_dict["tensor_parallel"]["partition_config"] = {
+ "use_default_specs": False,
+ "layer_specs": [{
+ "patterns": [".*\\.weight$"],
+ "partition_type": "skip",
+ }],
+ }
+ if preferred_dtype() is torch.float16:
+ config_dict["fp16"] = {"enabled": True}
+ elif preferred_dtype() is torch.bfloat16:
+ config_dict["bf16"] = {"enabled": True}
+
+ model = SequentialLinearModel(hidden_dim=8)
+ deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+
+
+def apply_autotp_with_partition_config(model, tp_size, partition_config):
+ groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size)
+ autotp_config = AutoTPConfig.from_dict(partition_config)
+ autotp = AutoTP(module=model,
+ all_reduce_linears=[],
+ prefix="",
+ state_dict=None,
+ linear_layer_setting=None,
+ orig_layer_impl=None,
+ keep_module_on_host=False,
+ partition_config=autotp_config)
+ autotp.set_tensor_parallel_config(tp_size, groups.get_tensor_model_parallel_group())
+ autotp.update_linear_policies()
+ autotp._replace_module(model)
+ return model
+
+
+def gather_subparam_output(output, subparam_sizes, mp_group):
+ tp_world_size = dist.get_world_size(group=mp_group)
+ local_sizes = [size // tp_world_size for size in subparam_sizes]
+ output_chunks = torch.split(output, local_sizes, dim=-1)
+ gathered_chunks = []
+ for chunk in output_chunks:
+ chunk = chunk.contiguous()
+ gathered = [torch.empty_like(chunk) for _ in range(tp_world_size)]
+ dist.all_gather(gathered, chunk, group=mp_group)
+ gathered_chunks.append(torch.cat(gathered, dim=-1))
+ return torch.cat(gathered_chunks, dim=-1)
+
+
+class TestAutoTPCustomPatterns(DistributedTest):
+ world_size = 2
+ reuse_dist_env = False
+
+ def test_custom_pattern_replacement(self):
+ skip_on_device()
+ partition_config = {
+ "use_default_specs":
+ False,
+ "layer_specs": [
+ {
+ "patterns": [".*linears\\.0\\.weight$"],
+ "partition_type": "row",
+ },
+ {
+ "patterns": [".*linears\\.1\\.weight$"],
+ "partition_type": "column",
+ },
+ {
+ "patterns": [".*linears\\.2\\.weight$"],
+ "partition_type": "skip",
+ },
+ ],
+ }
+ model = SequentialLinearModel(hidden_dim=16, nlayers=3)
+ model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config)
+
+ assert isinstance(model.linears[0], LinearAllreduce)
+ assert isinstance(model.linears[1], LinearLayer)
+ assert isinstance(model.linears[2], nn.Linear)
+
+ def test_custom_patterns_applied_via_config(self):
+ skip_on_device()
+ partition_config = {
+ "use_default_specs":
+ False,
+ "layer_specs": [
+ {
+ "patterns": [".*linears\\.0\\.weight$"],
+ "partition_type": "row",
+ },
+ {
+ "patterns": [".*linears\\.1\\.weight$"],
+ "partition_type": "column",
+ },
+ {
+ "patterns": [".*linears\\.2\\.weight$"],
+ "partition_type": "skip",
+ },
+ ],
+ }
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-6
+ }
+ },
+ "tensor_parallel": {
+ "autotp_size": 2,
+ "partition_config": partition_config,
+ },
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ if preferred_dtype() is torch.float16:
+ config_dict["fp16"] = {"enabled": True}
+ elif preferred_dtype() is torch.bfloat16:
+ config_dict["bf16"] = {"enabled": True}
+
+ model = SequentialLinearModel(hidden_dim=16, nlayers=3)
+ engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+ assert isinstance(engine.module.linears[0], LinearAllreduce)
+ assert isinstance(engine.module.linears[1], LinearLayer)
+ assert isinstance(engine.module.linears[2], nn.Linear)
+
+ def test_first_match_precedence(self):
+ skip_on_device()
+ partition_config = {
+ "use_default_specs":
+ False,
+ "layer_specs": [
+ {
+ "patterns": [".*linears\\.0\\.weight$"],
+ "partition_type": "skip",
+ },
+ {
+ "patterns": [".*linears\\.0\\.weight$"],
+ "partition_type": "column",
+ },
+ ],
+ }
+ model = SequentialLinearModel(hidden_dim=16, nlayers=1)
+ model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config)
+
+ assert isinstance(model.linears[0], nn.Linear)
+
+
+def test_invalid_custom_shape_rejected():
+ bad_config = {
+ "layer_specs": [{
+ "patterns": [".*"],
+ "partition_type": "column",
+ "shape": [2, [1, 1]],
+ }]
+ }
+ with pytest.raises(ValueError, match="nested tuple only allowed at partition_dim"):
+ AutoTPConfig.from_dict(bad_config)
+
+
+class TestAutoTPFusedWeights(DistributedTest):
+ world_size = 2
+ reuse_dist_env = False
+
+ def test_gate_up_fused_weight_partition(self):
+ skip_on_device()
+ init_tp_engine(tp_size=2)
+
+ hidden_dim = 8
+ torch.manual_seed(42)
+ linear = nn.Linear(hidden_dim,
+ hidden_dim * 2,
+ bias=True,
+ dtype=preferred_dtype(),
+ device=get_accelerator().current_device())
+ full_weight = deepcopy(linear.weight.data)
+ full_bias = deepcopy(linear.bias.data)
+
+ layer = SubParamLinearLayer(deepcopy(linear),
+ groups.get_tensor_model_parallel_group(),
+ shape=(2, -1),
+ partition_dim=0,
+ name="mlp.gate_up_proj")
+ assert layer._subparam_sizes == (hidden_dim, hidden_dim)
+ assert layer.weight.shape == (hidden_dim, hidden_dim)
+
+ layer.gather_params([layer.weight, layer.bias])
+ torch.testing.assert_close(layer.weight.data, full_weight)
+ torch.testing.assert_close(layer.bias.data, full_bias)
+
+ def test_gqa_uneven_qkv_fused_weight_partition(self):
+ skip_on_device()
+ init_tp_engine(tp_size=2)
+
+ hidden_dim = 8
+ q_size, k_size, v_size = 8, 4, 4
+ torch.manual_seed(123)
+ linear = nn.Linear(hidden_dim,
+ q_size + k_size + v_size,
+ bias=True,
+ dtype=preferred_dtype(),
+ device=get_accelerator().current_device())
+ full_weight = deepcopy(linear.weight.data)
+ full_bias = deepcopy(linear.bias.data)
+
+ layer = SubParamLinearLayer(deepcopy(linear),
+ groups.get_tensor_model_parallel_group(),
+ shape=((q_size, k_size, v_size), -1),
+ partition_dim=0,
+ name="self_attn.qkv_proj")
+ assert layer._subparam_sizes == (q_size, k_size, v_size)
+ assert layer.weight.shape == ((q_size + k_size + v_size) // 2, hidden_dim)
+
+ layer.gather_params([layer.weight, layer.bias])
+ torch.testing.assert_close(layer.weight.data, full_weight)
+ torch.testing.assert_close(layer.bias.data, full_bias)
+
+ def test_gqa_uneven_qkv_fused_forward(self):
+ skip_on_device()
+ groups._init_tp_mesh_device(tensor_model_parallel_size=2)
+
+ hidden_dim = 8
+ q_size, k_size, v_size = 8, 4, 4
+ torch.manual_seed(321)
+ linear = nn.Linear(hidden_dim,
+ q_size + k_size + v_size,
+ bias=True,
+ dtype=preferred_dtype(),
+ device=get_accelerator().current_device())
+ layer = SubParamLinearLayer(deepcopy(linear),
+ groups.get_tensor_model_parallel_group(),
+ shape=((q_size, k_size, v_size), -1),
+ partition_dim=0,
+ name="self_attn.qkv_proj")
+
+ torch.manual_seed(42)
+ inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device())
+ full_output = linear(inputs)
+ tp_output = layer(inputs)
+
+ gathered_output = gather_subparam_output(tp_output, (q_size, k_size, v_size),
+ groups.get_tensor_model_parallel_group())
+ atol = 1e-3
+ rtol = 2e-2
+ if preferred_dtype() is torch.float32:
+ atol = 1e-5
+ rtol = 1e-5
+ torch.testing.assert_close(gathered_output, full_output, atol=atol, rtol=rtol)
diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py
index 9d2b04211520..baaca247c229 100644
--- a/tests/unit/model_parallelism/test_autotp_training.py
+++ b/tests/unit/model_parallelism/test_autotp_training.py
@@ -16,7 +16,7 @@
from deepspeed.utils import groups
from contextlib import contextmanager
from torch import nn
-from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode
+from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode, is_autotp_training_mode
from unit.checkpoint.common import compare_lr_scheduler_states, compare_optimizer_states
import os
from deepspeed.runtime.utils import is_model_parallel_parameter
@@ -27,6 +27,39 @@ def skip_on_device():
pytest.skip("XPU requires a higher version for test")
+def reset_tp_model_init_state():
+ deepspeed._TP_MODEL_INIT_ARGS = None
+ set_autotp_mode(training=False)
+
+
+class DummyMPU:
+
+ def __init__(self, tp_world_size=1):
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ self.tp_world_size = tp_world_size
+ self.dp_group = dist.get_world_group()
+ self.tp_group = dist.get_world_group()
+
+ def get_model_parallel_rank(self):
+ return self.rank % self.tp_world_size
+
+ def get_model_parallel_world_size(self):
+ return self.tp_world_size
+
+ def get_data_parallel_rank(self):
+ return self.rank // self.tp_world_size
+
+ def get_data_parallel_world_size(self):
+ return self.world_size // self.tp_world_size
+
+ def get_data_parallel_group(self):
+ return self.dp_group
+
+ def get_model_parallel_group(self):
+ return self.tp_group
+
+
class SequentialLinearModel(torch.nn.Module):
def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
@@ -69,13 +102,19 @@ class TestTpParallelStates(DistributedTest):
def test(self, tp_size: int):
skip_on_device()
- set_autotp_mode(training=True)
dp_size = 4 / tp_size
hidden_dim = 128
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"tensor_parallel": {
- "autotp_size": tp_size
+ "autotp_size": tp_size,
+ "partition_config": {
+ "use_default_specs": False,
+ "layer_specs": [{
+ "patterns": [".*\\.weight$"],
+ "partition_type": "skip",
+ }],
+ }
},
"zero_optimization": {
"stage": 0
@@ -87,6 +126,104 @@ def test(self, tp_size: int):
assert groups.get_data_parallel_world_size() == dp_size
+class TestTpModelInitCompatibility(DistributedTest):
+ world_size = 4
+ reuse_dist_env = False
+
+ def test_tp_model_init_merges_config(self):
+ skip_on_device()
+ reset_tp_model_init_state()
+ model = SimpleModel(hidden_dim=8)
+ deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype())
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ engine, _, _, _ = deepspeed.initialize(model=model,
+ model_parameters=model.parameters(),
+ config=config_dict,
+ mpu=DummyMPU())
+ assert engine.autotp_size() == 1
+ assert is_autotp_training_mode()
+
+ def test_tp_model_init_config_autotp_size_mismatch(self):
+ skip_on_device()
+ reset_tp_model_init_state()
+ model = SimpleModel(hidden_dim=8)
+ deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype())
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "tensor_parallel": {
+ "autotp_size": 2,
+ },
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ with pytest.raises(ValueError, match="tensor_parallel.autotp_size"):
+ deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU())
+
+ def test_tp_model_init_requires_mpu_or_mesh_param(self):
+ skip_on_device()
+ reset_tp_model_init_state()
+ model = SimpleModel(hidden_dim=8)
+ deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype())
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ with pytest.raises(ValueError, match="requires mpu or mesh_param"):
+ deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+
+ def test_tp_model_init_tp_group_rejects_mpu(self):
+ skip_on_device()
+ reset_tp_model_init_state()
+ model = SimpleModel(hidden_dim=8)
+ tp_group = dist.new_group(ranks=[0])
+ deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype(), tp_group=tp_group)
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ with pytest.raises(ValueError, match="tp_model_init provided tp_group"):
+ deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU())
+
+ def test_tp_model_init_dtype_mismatch(self):
+ skip_on_device()
+ reset_tp_model_init_state()
+ model = SimpleModel(hidden_dim=8)
+ deepspeed.tp_model_init(model, tp_size=1, dtype=torch.float16)
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "bf16": {
+ "enabled": True,
+ },
+ "zero_optimization": {
+ "stage": 0,
+ }
+ }
+ with pytest.raises(ValueError, match="Conflicting dtype"):
+ deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU())
+
+ @pytest.mark.sequential
+ @pytest.mark.parametrize("tp_size", [2, 4])
+ @pytest.mark.parametrize("tp_overlap_comm", [True, False])
+ def test_tp_model_init_row_parallel(self, tp_size: int, tp_overlap_comm: bool):
+ run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=False, use_tp_model_init=True)
+
+ @pytest.mark.sequential
+ @pytest.mark.parametrize("tp_size", [2, 4])
+ @pytest.mark.parametrize("tp_overlap_comm", [True, False])
+ def test_tp_model_init_column_parallel(self, tp_size: int, tp_overlap_comm: bool):
+ run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=True, use_tp_model_init=True)
+
+
@pytest.mark.parametrize("tp_size", [2, 4])
class TestTpDataloaderCorrectness(DistributedTest):
world_size = 4
@@ -95,7 +232,6 @@ class TestTpDataloaderCorrectness(DistributedTest):
def test(self, tp_size: int):
skip_on_device()
hidden_dim = 128
- set_autotp_mode(training=True)
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
@@ -106,7 +242,14 @@ def test(self, tp_size: int):
}
},
"tensor_parallel": {
- "autotp_size": tp_size
+ "autotp_size": tp_size,
+ "partition_config": {
+ "use_default_specs": False,
+ "layer_specs": [{
+ "patterns": [".*\\.weight$"],
+ "partition_type": "skip",
+ }],
+ }
},
"zero_optimization": {
"stage": 0,
@@ -164,137 +307,116 @@ def process_linear_layer(hidden_dim, input):
return torch_linear, torch_out
-@pytest.mark.sequential
-@pytest.mark.parametrize("tp_size", [2, 4])
-@pytest.mark.parametrize("tp_overlap_comm", [True, False])
-class TestTpLayerFwdBwd(DistributedTest):
- world_size = 4
- reuse_dist_env = False
-
- def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
- skip_on_device()
- hidden_dim = 128
- batch_size_per_device = 1
- set_autotp_mode(training=True)
- config_dict = {
- "train_micro_batch_size_per_gpu": 1,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 1e-6
- }
- },
- "tensor_parallel": {
- "autotp_size": tp_size,
- "tp_overlap_comm": tp_overlap_comm
- },
- "zero_optimization": {
- "stage": 0,
+def run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel, use_tp_model_init=False):
+ skip_on_device()
+ hidden_dim = 128
+ batch_size_per_device = 1
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "steps_per_print": 1,
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-6
}
+ },
+ "tensor_parallel": {
+ "autotp_size": tp_size,
+ "tp_overlap_comm": tp_overlap_comm
+ },
+ "zero_optimization": {
+ "stage": 0,
}
- if preferred_dtype() is torch.float16:
- config_dict["fp16"] = {"enabled": True}
- elif preferred_dtype() is torch.bfloat16:
- config_dict["bf16"] = {"enabled": True}
- model = SequentialLinearModel(hidden_dim=hidden_dim)
+ }
+ partition_type = "column" if column_parallel else "row"
+ config_dict["tensor_parallel"]["partition_config"] = {
+ "use_default_specs": False,
+ "layer_specs": [{
+ "patterns": [".*\\.weight$"],
+ "partition_type": partition_type,
+ }],
+ }
+ if preferred_dtype() is torch.float16:
+ config_dict["fp16"] = {"enabled": True}
+ elif preferred_dtype() is torch.bfloat16:
+ config_dict["bf16"] = {"enabled": True}
+
+ model = SequentialLinearModel(hidden_dim=hidden_dim)
+ if use_tp_model_init:
+ reset_tp_model_init_state()
+ deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype())
+ mpu = DummyMPU(tp_world_size=tp_size)
+ model, _, _, _ = deepspeed.initialize(model=model,
+ model_parameters=model.parameters(),
+ config=config_dict,
+ mpu=mpu)
+ else:
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
- input = torch.randn(batch_size_per_device,
- hidden_dim,
- dtype=preferred_dtype(),
- requires_grad=True,
- device=get_accelerator().current_device())
- dist.broadcast(input,
- groups.get_tensor_model_parallel_src_rank(),
- group=groups.get_tensor_model_parallel_group())
-
- torch_linear, torch_out = process_linear_layer(hidden_dim, input)
- linear = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
-
- input_ = torch.chunk(input, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()]
- out = linear(input_.to(get_accelerator().current_device()))
+ input = torch.randn(batch_size_per_device,
+ hidden_dim,
+ dtype=preferred_dtype(),
+ requires_grad=True,
+ device=get_accelerator().current_device())
+ dist.broadcast(input, groups.get_tensor_model_parallel_src_rank(), group=groups.get_tensor_model_parallel_group())
+
+ # Note: correctness checks below use standalone TP wrappers and do not
+ # rely on the model's AutoTP-partitioned parameters.
+ torch_linear, torch_out = process_linear_layer(hidden_dim, input)
+ if column_parallel:
+ linear = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
+ out = linear(input.to(get_accelerator().current_device()))
loss = out.sum()
loss.backward()
- torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
- torch_bias_grad = torch_linear.bias.grad
- # Use assert_close with rtol for proper floating-point comparisons
+ cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()]
+ torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
+ torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
+
torch.testing.assert_close(linear.bias.grad,
torch_bias_grad.to(get_accelerator().current_device()),
atol=1e-3,
rtol=1e-3)
- # The gradient of the weight is not the same as the torch_linear.weight.grad
torch.testing.assert_close(linear.weight.grad,
torch_grad.to(get_accelerator().current_device()),
atol=1e-3,
rtol=1e-3)
- torch.testing.assert_close(out, torch_out.to(get_accelerator().current_device()), atol=1e-2, rtol=1e-2)
-
- def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
- skip_on_device()
- hidden_dim = 128
- batch_size_per_device = 1
- set_autotp_mode(training=True)
- config_dict = {
- "train_micro_batch_size_per_gpu": 1,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 1e-6
- }
- },
- "tensor_parallel": {
- "autotp_size": tp_size,
- "tp_overlap_comm": tp_overlap_comm
- },
- "zero_optimization": {
- "stage": 0,
- }
- }
- if preferred_dtype() is torch.float16:
- config_dict["fp16"] = {"enabled": True}
- elif preferred_dtype() is torch.bfloat16:
- config_dict["bf16"] = {"enabled": True}
-
- model = SequentialLinearModel(hidden_dim=hidden_dim)
- model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
- input = torch.randn(batch_size_per_device,
- hidden_dim,
- dtype=preferred_dtype(),
- requires_grad=True,
- device=get_accelerator().current_device())
- dist.broadcast(input,
- groups.get_tensor_model_parallel_src_rank(),
- group=groups.get_tensor_model_parallel_group())
-
- torch_linear, torch_out = process_linear_layer(hidden_dim, input)
-
- linear = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
-
- out = linear(input.to(get_accelerator().current_device()))
+ torch.testing.assert_close(cur_device_out.to(get_accelerator().current_device()).contiguous(),
+ out.contiguous(),
+ atol=1e-2,
+ rtol=1e-2)
+ else:
+ linear = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
+ input_ = torch.chunk(input, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()]
+ out = linear(input_.to(get_accelerator().current_device()))
loss = out.sum()
loss.backward()
- cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()]
- torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
-
- torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
- # Use assert_close with rtol for proper floating-point comparisons
+ torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
+ torch_bias_grad = torch_linear.bias.grad
torch.testing.assert_close(linear.bias.grad,
torch_bias_grad.to(get_accelerator().current_device()),
atol=1e-3,
rtol=1e-3)
-
torch.testing.assert_close(linear.weight.grad,
torch_grad.to(get_accelerator().current_device()),
atol=1e-3,
rtol=1e-3)
- torch.testing.assert_close(cur_device_out.to(get_accelerator().current_device()).contiguous(),
- out.contiguous(),
- atol=1e-2,
- rtol=1e-2)
+ torch.testing.assert_close(out, torch_out.to(get_accelerator().current_device()), atol=1e-2, rtol=1e-2)
+
+
+@pytest.mark.sequential
+@pytest.mark.parametrize("tp_size", [2, 4])
+@pytest.mark.parametrize("tp_overlap_comm", [True, False])
+class TestTpLayerFwdBwd(DistributedTest):
+ world_size = 4
+ reuse_dist_env = False
+
+ def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
+ run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=False)
+
+ def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
+ run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=True)
# @pytest.mark.sequential
@@ -307,7 +429,6 @@ def test(self, layer_type):
skip_on_device()
tp_size = 4
hidden_dim = 128
- set_autotp_mode(training=True)
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"optimizer": {
@@ -317,7 +438,14 @@ def test(self, layer_type):
}
},
"tensor_parallel": {
- "autotp_size": tp_size
+ "autotp_size": tp_size,
+ "partition_config": {
+ "use_default_specs": False,
+ "layer_specs": [{
+ "patterns": [".*\\.weight$"],
+ "partition_type": "skip",
+ }],
+ }
},
"zero_optimization": {
"stage": 0,
@@ -382,6 +510,15 @@ def test(self, layer_type):
def dummy_init_engine(config):
# This is a dummy initialization function for the DeepSpeed engine.
# We only need to use the config to initialize the distributed settings for the test.
+ # Add default partition_config for simple test models if not provided
+ if "tensor_parallel" in config and "partition_config" not in config["tensor_parallel"]:
+ config["tensor_parallel"]["partition_config"] = {
+ "use_default_specs": False,
+ "layer_specs": [{
+ "patterns": [".*\\.weight$"],
+ "partition_type": "skip",
+ }],
+ }
model = SequentialLinearModel(hidden_dim=8)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
@@ -412,7 +549,6 @@ class TestSave(DistributedTest):
def test_save_original_weight(self, tp_size: int, zero_stage: int):
skip_on_device()
hidden_dim = 64
- set_autotp_mode(training=True)
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
@@ -471,7 +607,6 @@ def compare_state_dicts(state_dict1, state_dict2):
def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
skip_on_device()
hidden_dim = 64
- set_autotp_mode(training=True)
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
@@ -544,7 +679,6 @@ class TestTpGradNorm(DistributedTest):
def test(self, tp_size: int, zero_stage: int):
skip_on_device()
hidden_dim = 64
- set_autotp_mode(training=True)
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,