Skip to content
97 changes: 71 additions & 26 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from .pipe import PipelineModule

from .git_version_info import version, git_hash, git_branch
from .runtime.tensor_parallel.init_utils import (load_ds_config, merge_tp_model_init_into_config,
record_tp_model_init_args)


def _parse_version(version_str):
Expand Down Expand Up @@ -159,17 +161,6 @@ def initialize(args=None,
if config is None and config_params is not None:
config = config_params

mesh_device = None
if mesh_param:
logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
#if config file has sequence parallelize and data parallelize, then use them to initialize mesh device
elif config is not None:
if "sequence_parallel_size" in config and "data_parallel_size" in config:
logger.info(f"config to Initialize mesh device: {config}")
mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
("data_parallel", "sequence_parallel"))

# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
Expand All @@ -184,6 +175,26 @@ def initialize(args=None,
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"

if not isinstance(config, dict):
config = load_ds_config(config)

mesh_device = None
if mesh_param:
logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
# If the config specifies both data_parallel_size and sequence_parallel_size, use them to initialize the mesh device
else:
if "sequence_parallel_size" in config and "data_parallel_size" in config:
logger.info(f"config to Initialize mesh device: {config}")
mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
("data_parallel", "sequence_parallel"))

merge_tp_model_init_into_config(config, mpu, mesh_param, dist)

autotp_size = config.get("tensor_parallel", {}).get("autotp_size", 0)
if autotp_size and autotp_size > 0:
set_autotp_mode(training=True)
if not isinstance(model, PipelineModule):
config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
set_optimizer_flags(config_class, model)
Expand Down Expand Up @@ -379,31 +390,65 @@ def init_inference(model, config=None, **kwargs):

def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
"""
Initialize the model for tensor parallelism.
Record tensor-parallel initialization arguments for training.

Note (compatibility and initialization behavior):
AutoTP sharding is applied during ``deepspeed.initialize(...)``. This
function exists for backward compatibility and only records TP arguments so
they can be validated and merged with the DeepSpeed config at initialization.
When you use both (i.e., calling ``set_autotp_mode(training=True)`` and
``deepspeed.tp_model_init`` while also passing the config to
``deepspeed.initialize``), DeepSpeed merges the settings at initialization.
Conflicting settings raise an error. The table below summarizes the behavior
across input combinations.

Inputs:
- TPI: tp_model_init was called? (Y/N)
- TPG: tp_model_init provided tp_group? (Y/N)
- CFG: tensor_parallel in DeepSpeed config? (Y/N)
- MPU: mpu passed to deepspeed.initialize()? (Y/N)

| TPI | TPG | CFG | MPU | Outcome | Notes |
|-----|-----|-----|-----|----------------------------------------|-------|
| N | N | N | N | Error | No TP intent; nothing to initialize |
| N | N | N | Y | No AutoTP | mpu may be used for other MP, but TP not enabled |
| N | N | Y | N | Init AutoTP from config | Use config; need TP group via config-driven init |
| N | N | Y | Y | Init AutoTP from config | mpu used to build TP group |
| Y | N | N | N | Error | No TP group source |
| Y | N | N | Y | Init AutoTP from tp_model_init | Use recorded args + mpu for TP group |
| Y | N | Y | N | Init AutoTP from config | Fill missing from TPI; error on mismatches; need TP group source |
| Y | N | Y | Y | Init AutoTP from config | Fill missing from TPI; error on mismatches |
| Y | Y | N | N | Init AutoTP from tp_model_init | Use recorded tp_group; config absent |
| Y | Y | N | Y | Error | tp_group + mpu conflict |
| Y | Y | Y | N | Init AutoTP from config | Error on mismatches; use tp_group from TPI; reject mpu |
| Y | Y | Y | Y | Error | tp_group + mpu conflict |

Field-level merge rules when both tp_model_init and config exist:
- Canonical source: config
- Allowed: fill missing config fields from tp_model_init
- Error on mismatch: autotp_size, dtype, tp_group size or identity

Extra checks:
- If tp_group is provided, reject mpu.
- If tp_group is not provided, require mpu (or another TP group source).
- If tensor_parallel is absent and only tp_model_init was called, require
a TP group source (direct tp_group or mpu).

Args:
model (torch.nn.Module): The model to be initialized.
tp_size (int): The tensor parallelism size.
dtype (torch.dtype): The data type to be used for the model.

Returns:
torch.nn.Module: The initialized model with tensor parallelism.
torch.nn.Module: The original model (no sharding applied here).
"""
# avoid re-entry
if hasattr(model, 'ds_autotp_parsed'):
logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
return

set_autotp_mode(training=True)
logger.warning("ds_autotp_parsed' attribute already exists in the model; tp_model_init is now record-only.")

from deepspeed.runtime.tensor_parallel import TpTrainingManager
# The expected usage here is for it to be invoked by transformers package.
tp_group = kwargs.get("tp_group")
record_tp_model_init_args(tp_size=tp_size, dtype=dtype, tp_group=tp_group, dist_module=dist)

#TODO: We should provide a custom TP mapping solution without using autoTP
#as modifying the autoTP logic may be more difficult for users compared to configuring it

model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module

setattr(model, 'ds_autotp_parsed', True)
# Keep AutoTP training mode active for backward compatibility.
set_autotp_mode(training=True)

return model
3 changes: 2 additions & 1 deletion deepspeed/module_inject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode, SubParamLinearLayer, SubParamLinearAllreduce
from .policy import DSPolicy
from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType, AutoTPPresets, merge_autotp_configs
91 changes: 90 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode
from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType


def move(tensor, device, copy=True):
Expand Down Expand Up @@ -199,7 +200,8 @@ def __init__(self,
state_dict,
linear_layer_setting,
orig_layer_impl,
keep_module_on_host=False):
keep_module_on_host=False,
partition_config: Optional[AutoTPConfig] = None):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
Expand All @@ -211,6 +213,7 @@ def __init__(self,
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
self.partition_config = partition_config
TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)

def in_module_list(module, module_list):
Expand Down Expand Up @@ -353,6 +356,11 @@ def _replace(self, child, name, conv_linear_layer):

weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

# If partition_config is provided, use the new configurable API
if self.partition_config is not None:
return self._replace_with_config(child, name)

# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
Expand Down Expand Up @@ -396,6 +404,87 @@ def _replace(self, child, name, conv_linear_layer):

return LinearLayer(child, self.mp_group, name=name)

def _replace_with_config(self, child, name):
    """
    Replace a layer using the configurable AutoTP API.

    A ``TPLayerSpec`` matched from ``self.partition_config`` decides how the
    layer is partitioned (row-parallel, column-parallel, or skipped).

    Args:
        child: The sub-module (e.g. ``nn.Linear``) to replace.
        name: Dotted name of ``child`` within the parent module.

    Returns:
        The replacement module, or ``child`` unchanged when it was already
        replaced or the matching spec says to skip it.

    Raises:
        ValueError: If no spec matches and the config is in strict mode.
    """
    # Idempotent re-entry guard: already-processed modules are returned
    # untouched. (Truthiness check instead of the non-idiomatic `== True`.)
    if getattr(child, "replaced", False):
        return child

    # Specs match against full parameter names, so normalize to "<name>.weight".
    param_name = name if name.endswith(".weight") else name + ".weight"

    # Look up the spec for this layer, scoped by model type when detectable.
    model_type = self._get_model_type()
    spec = self.partition_config.find_matching_spec(param_name, model_type)

    if spec is None:
        if self.partition_config.strict_mode:
            raise ValueError(f"No matching spec for {param_name}")
        # Fallback: column-parallel, the default for Linear layers.
        spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN)

    setattr(child, "replaced", True)

    if spec.partition_type == PartitionType.SKIP:
        return child
    if spec.partition_type == PartitionType.ROW:
        return self._create_row_parallel_layer(child, spec, name)
    return self._create_column_parallel_layer(child, spec, name)

def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
    """Build a row-parallel replacement (all-reduce applied after forward)."""
    # Output-embedding layers get the dedicated LM-head variant.
    if name in ("lm_head", "embed_out"):
        return LmHeadLinearAllreduce(module, self.mp_group)

    # NOTE(review): a spec-provided shape appears to denote a packed parameter
    # that needs per-sub-param sharding — confirm with AutoTPConfig docs.
    if spec.shape is None:
        return LinearAllreduce(module, self.mp_group, name=name)
    return SubParamLinearAllreduce(module,
                                   self.mp_group,
                                   shape=spec.shape,
                                   partition_dim=spec.get_partition_dim(),
                                   name=name)

def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
    """Build a column-parallel replacement (all-reduce happens in backward)."""
    # Plain Linear layers shard directly; a spec-provided shape selects the
    # sub-parameter variant instead.
    if spec.shape is None:
        return LinearLayer(module, self.mp_group, name=name)
    return SubParamLinearLayer(module,
                               self.mp_group,
                               shape=spec.shape,
                               partition_dim=spec.get_partition_dim(),
                               name=name)

def _get_model_type(self) -> Optional[str]:
"""Extract model type from module config or class name."""
config = getattr(self.module, "config", None)
if config is not None:
model_type = getattr(config, "model_type", None)
if model_type:
return str(model_type).lower()
module_str = str(type(self.module))
# Try to extract model type from class name (e.g., "LlamaDecoderLayer" -> "llama")
patterns = [
r"(\w+)DecoderLayer",
r"(\w+)Block",
r"(\w+)Layer",
]
for pattern in patterns:
match = re.search(pattern, module_str)
if match:
return match.group(1).lower()
return None

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
Expand Down
Loading