diff --git a/blogs/huggingface-tp/README.md b/blogs/huggingface-tp/README.md index 44469f1818c3..ceb95417ae5f 100644 --- a/blogs/huggingface-tp/README.md +++ b/blogs/huggingface-tp/README.md @@ -230,13 +230,11 @@ Furthermore, if users are not using transformers library, you can replace the `` # Ongoing Work - **Optimization**: Communication/Activation optimization. -- **Usability**: Support [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple AutoTP parser and more model testing, - +- **Usability**: Support the [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple the AutoTP parser, and expand model testing. + - [UPDATE] We now support [custom partitioning](https://deepspeed.readthedocs.io/en/latest/training.html#custom-layer-specs) in the same spirit as HF's partitioning plan, and will build Transformers TP plan support on top of that ([PR](https://github.com/deepspeedai/DeepSpeed/pull/7806)). Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending. - Welcome bug reports, enhancement, and additional model training examples. - # Contributors This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft. 
diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 10b71eef7a5c..0d53a172e64e 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -48,6 +48,8 @@ from .pipe import PipelineModule from .git_version_info import version, git_hash, git_branch +from .runtime.tensor_parallel.init_utils import (load_ds_config, merge_tp_model_init_into_config, + record_tp_model_init_args) def _parse_version(version_str): @@ -159,17 +161,6 @@ def initialize(args=None, if config is None and config_params is not None: config = config_params - mesh_device = None - if mesh_param: - logger.info(f"mesh_param to Initialize mesh device: {mesh_param}") - mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel")) - #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device - elif config is not None: - if "sequence_parallel_size" in config and "data_parallel_size" in config: - logger.info(f"config to Initialize mesh device: {config}") - mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \ - ("data_parallel", "sequence_parallel")) - # Check for deepscale_config for backwards compat if hasattr(args, "deepscale_config") and args.deepscale_config is not None: logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") @@ -184,6 +175,26 @@ def initialize(args=None, assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call" config = args.deepspeed_config assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file" + + if not isinstance(config, dict): + config = load_ds_config(config) + + mesh_device = None + if mesh_param: + logger.info(f"mesh_param to Initialize mesh device: {mesh_param}") + mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", 
"sequence_parallel")) + #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device + else: + if "sequence_parallel_size" in config and "data_parallel_size" in config: + logger.info(f"config to Initialize mesh device: {config}") + mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \ + ("data_parallel", "sequence_parallel")) + + merge_tp_model_init_into_config(config, mpu, mesh_param, dist) + + autotp_size = config.get("tensor_parallel", {}).get("autotp_size", 0) + if autotp_size and autotp_size > 0: + set_autotp_mode(training=True) if not isinstance(model, PipelineModule): config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device) set_optimizer_flags(config_class, model) @@ -379,7 +390,49 @@ def init_inference(model, config=None, **kwargs): def tp_model_init(model, tp_size, dtype, config=None, **kwargs): """ - Initialize the model for tensor parallelism. + Record tensor-parallel initialization arguments for training. + + Note (compatibility and initialization behavior): + AutoTP sharding is applied during ``deepspeed.initialize(...)``. This + function exists for backward compatibility and only records TP arguments so + they can be validated and merged with the DeepSpeed config at initialization. + When you use both (i.e., calling ``set_autotp_mode(training=True)`` and + ``deepspeed.tp_model_init`` while also passing the config to + ``deepspeed.initialize``), DeepSpeed merges the settings at initialization. + Conflicting settings raise an error. The table below summarizes the behavior + across input combinations. + + Inputs: + - TPI: tp_model_init was called? (Y/N) + - TPG: tp_model_init provided tp_group? (Y/N) + - CFG: tensor_parallel in DeepSpeed config? (Y/N) + - MPU: mpu passed to deepspeed.initialize()? 
(Y/N) + + | TPI | TPG | CFG | MPU | Outcome | Notes | + |-----|-----|-----|-----|----------------------------------------|-------| + | N | N | N | N | Error | No TP intent; nothing to initialize | + | N | N | N | Y | No AutoTP | mpu may be used for other MP, but TP not enabled | + | N | N | Y | N | Init AutoTP from config | Use config; need TP group via config-driven init | + | N | N | Y | Y | Init AutoTP from config | mpu used to build TP group | + | Y | N | N | N | Error | No TP group source | + | Y | N | N | Y | Init AutoTP from tp_model_init | Use recorded args + mpu for TP group | + | Y | N | Y | N | Init AutoTP from config | Fill missing from TPI; error on mismatches; need TP group source | + | Y | N | Y | Y | Init AutoTP from config | Fill missing from TPI; error on mismatches | + | Y | Y | N | N | Init AutoTP from tp_model_init | Use recorded tp_group; config absent | + | Y | Y | N | Y | Error | tp_group + mpu conflict | + | Y | Y | Y | N | Init AutoTP from config | Error on mismatches; use tp_group from TPI; reject mpu | + | Y | Y | Y | Y | Error | tp_group + mpu conflict | + + Field-level merge rules when both tp_model_init and config exist: + - Canonical source: config + - Allowed: fill missing config fields from tp_model_init + - Error on mismatch: autotp_size, dtype, tp_group size or identity + + Extra checks: + - If tp_group is provided, reject mpu. + - If tp_group is not provided, require mpu (or another TP group source). + - If tensor_parallel is absent and only tp_model_init was called, require + a TP group source (direct tp_group or mpu). Args: model (torch.nn.Module): The model to be initialized. @@ -387,23 +440,15 @@ def tp_model_init(model, tp_size, dtype, config=None, **kwargs): dtype (torch.dtype): The data type to be used for the model. Returns: - torch.nn.Module: The initialized model with tensor parallelism. + torch.nn.Module: The original model (no sharding applied here). 
""" - # avoid re-entry if hasattr(model, 'ds_autotp_parsed'): - logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.") - return - - set_autotp_mode(training=True) + logger.warning("ds_autotp_parsed' attribute already exists in the model; tp_model_init is now record-only.") - from deepspeed.runtime.tensor_parallel import TpTrainingManager - # The expected usage here is for it to be invoked by transformers package. + tp_group = kwargs.get("tp_group") + record_tp_model_init_args(tp_size=tp_size, dtype=dtype, tp_group=tp_group, dist_module=dist) - #TODO: We should provide a custom TP mapping solution without using autoTP - #as modifying the autoTP logic may be more difficult for users compared to configuring it - - model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module - - setattr(model, 'ds_autotp_parsed', True) + # Keep AutoTP training mode active for backward compatibility. + set_autotp_mode(training=True) return model diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py index 9fc2f979a04b..2299ef6e7c3a 100755 --- a/deepspeed/module_inject/__init__.py +++ b/deepspeed/module_inject/__init__.py @@ -6,5 +6,6 @@ from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection from .module_quantize import quantize_transformer_layer from .replace_policy import HFBertLayerPolicy -from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode +from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode, SubParamLinearLayer, SubParamLinearAllreduce from .policy import DSPolicy +from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType, AutoTPPresets, merge_autotp_configs diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 82cd9042071e..121e3938444a 100755 --- 
a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -17,6 +17,7 @@ from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list from deepspeed.utils import groups from deepspeed.module_inject.layers import is_autotp_training_mode +from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType def move(tensor, device, copy=True): @@ -199,7 +200,8 @@ def __init__(self, state_dict, linear_layer_setting, orig_layer_impl, - keep_module_on_host=False): + keep_module_on_host=False, + partition_config: Optional[AutoTPConfig] = None): self.module = module self.all_reduce_linears = all_reduce_linears self.prefix = prefix @@ -211,6 +213,7 @@ def __init__(self, self.orig_layer_impl = orig_layer_impl self.linear_policies = None self.conv_linear_layer = False + self.partition_config = partition_config TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host) def in_module_list(module, module_list): @@ -353,6 +356,11 @@ def _replace(self, child, name, conv_linear_layer): weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) + + # If partition_config is provided, use the new configurable API + if self.partition_config is not None: + return self._replace_with_config(child, name) + # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or ( ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))): @@ -396,6 +404,87 @@ def _replace(self, child, name, conv_linear_layer): return LinearLayer(child, self.mp_group, name=name) + def _replace_with_config(self, child, name): + """ + Replace layer using the new configurable AutoTP API. + + This method uses TPLayerSpec to determine how to partition the layer. 
+ """ + if getattr(child, "replaced", False) == True: + return child + + # Build the full parameter name for pattern matching + param_name = name + ".weight" if not name.endswith(".weight") else name + + # Find matching spec + model_type = self._get_model_type() + spec = self.partition_config.find_matching_spec(param_name, model_type) + + if spec is None: + # No matching spec found + if self.partition_config.strict_mode: + raise ValueError(f"No matching spec for {param_name}") + # Default: column parallel for Linear layers + spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN) + + setattr(child, "replaced", True) + + if spec.partition_type == PartitionType.SKIP: + return child + + if spec.partition_type == PartitionType.ROW: + return self._create_row_parallel_layer(child, spec, name) + else: + return self._create_column_parallel_layer(child, spec, name) + + def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str): + """Create row-parallel layer (AllReduce after forward).""" + # Check for lm_head / embed_out + if name == "lm_head" or name == 'embed_out': + return LmHeadLinearAllreduce(module, self.mp_group) + + if spec.shape is not None: + return SubParamLinearAllreduce( + module, + self.mp_group, + shape=spec.shape, + partition_dim=spec.get_partition_dim(), + name=name, + ) + return LinearAllreduce(module, self.mp_group, name=name) + + def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str): + """Create column-parallel layer (AllReduce in backward).""" + if spec.shape is not None: + return SubParamLinearLayer( + module, + self.mp_group, + shape=spec.shape, + partition_dim=spec.get_partition_dim(), + name=name, + ) + return LinearLayer(module, self.mp_group, name=name) + + def _get_model_type(self) -> Optional[str]: + """Extract model type from module config or class name.""" + config = getattr(self.module, "config", None) + if config is not None: + model_type = getattr(config, "model_type", None) + if 
model_type: + return str(model_type).lower() + module_str = str(type(self.module)) + # Try to extract model type from class name (e.g., "LlamaDecoderLayer" -> "llama") + patterns = [ + r"(\w+)DecoderLayer", + r"(\w+)Block", + r"(\w+)Layer", + ] + for pattern in patterns: + match = re.search(pattern, module_str) + if match: + return match.group(1).lower() + return None + def _slice_embedding(self, child, name, conv_linear_layer): if getattr(child, "replaced", False) == True: return diff --git a/deepspeed/module_inject/autotp_config.py b/deepspeed/module_inject/autotp_config.py new file mode 100644 index 000000000000..4bafea806829 --- /dev/null +++ b/deepspeed/module_inject/autotp_config.py @@ -0,0 +1,569 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Configurable AutoTP API + +This module provides a unified specification for tensor parallel layer partitioning. +The design is inspired by Universal Checkpointing's SubparamShape and provides +a single, well-defined format that users can easily understand, customize, and extend. +""" + +import re +from dataclasses import dataclass, field +from typing import List, Tuple, Union, Optional +from enum import Enum +from deepspeed.utils.logging import warning_once + + +class PartitionType(Enum): + """How the layer should be partitioned for tensor parallelism.""" + COLUMN = "column" # Partition output dim, AllReduce in backward + ROW = "row" # Partition input dim, AllReduce in forward + SKIP = "skip" # Do not partition this layer + + +@dataclass +class TPLayerSpec: + """ + Unified specification for tensor parallel layer partitioning. + + This is inspired by Universal Checkpointing's SubparamShape but extended + for AutoTP's needs (forward/backward communication patterns). 
+ + The `shape` parameter supports at most 1-level nesting at the partition dimension: + - (3, -1) -> 3 equal-size sub-params + - ((q, k, v), -1) -> 3 unequal-size sub-params (1-level nesting) + + Examples: + # Simple row-parallel layer (e.g., o_proj, down_proj) + TPLayerSpec( + patterns=[".*\\.o_proj$", ".*\\.down_proj$"], + partition_type=PartitionType.ROW, + ) + + # Simple column-parallel layer (e.g., q_proj, k_proj, v_proj) + TPLayerSpec( + patterns=[".*\\.[qkv]_proj$"], + partition_type=PartitionType.COLUMN, + ) + + # Fused QKV - GLM style [Q, K, V] concatenated on dim 0 + TPLayerSpec( + patterns=[".*\\.query_key_value\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(3, -1), # 3 equal sub-params, -1 = infer + partition_dim=0, + ) + + # Fused QKV - Bloom style [q1,k1,v1,q2,k2,v2,...] + TPLayerSpec( + patterns=[".*\\.query_key_value\\.weight$"], + partition_type=PartitionType.COLUMN, + # No reshape needed, just split along dim 0 + ) + + # GQA with different Q/K/V sizes (1-level nesting) + TPLayerSpec( + patterns=[".*\\.qkv_proj\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=((q_size, k_size, v_size), -1), # Unequal sub-params + partition_dim=0, + ) + + # Chunked MLP (gate_up_proj) + TPLayerSpec( + patterns=[".*\\.gate_up_proj\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(2, -1), # [gate, up] packed + partition_dim=0, + ) + + # MoE FFN with expert dimension + TPLayerSpec( + patterns=[".*\\.experts\\..*\\.w1\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(num_experts, -1, hidden_in), # View as 3D + partition_dim=1, # Partition the hidden_out dimension + ) + + # Skip layer (e.g., MoE gate) + TPLayerSpec( + patterns=[".*\\.gate$", ".*\\.router$"], + partition_type=PartitionType.SKIP, + ) + """ + + # Layer identification - regex patterns to match parameter names + patterns: List[str] + + # Partition type determines communication pattern + partition_type: PartitionType = PartitionType.COLUMN + + # Optional: logical 
shape for partitioning + # - Use -1 for dimensions that should be inferred + # - Use tuple of ints at partition_dim for unequal sub-params (1-level nesting only) + # Examples: + # (3, -1) -> 3 equal sub-params + # ((4096, 1024, 1024), -1) -> 3 unequal sub-params (GQA) + # (n_experts, -1, hidden) -> MoE reshape + shape: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None + + # Which dimension to partition (after optional reshape) + # Default: 0 for COLUMN, 1 for ROW (standard 2D weight matrix) + partition_dim: Optional[int] = None + + # Optional: model type constraint (only apply for specific models) + model_types: Optional[List[str]] = None + + def __post_init__(self): + if isinstance(self.partition_type, str): + self.partition_type = PartitionType(self.partition_type.lower()) + if self.shape is not None: + self.shape = self._normalize_shape(self.shape) + self._validate_shape_format() + + @staticmethod + def _normalize_shape(shape): + if isinstance(shape, list): + return tuple(TPLayerSpec._normalize_shape(item) for item in shape) + if isinstance(shape, tuple): + return tuple(TPLayerSpec._normalize_shape(item) if isinstance(item, list) else item for item in shape) + return shape + + def _validate_shape_format(self): + if not isinstance(self.shape, tuple): + raise ValueError("AutoTP shape must be a tuple of ints or a tuple at partition_dim.") + partition_dim = self.get_partition_dim() + if partition_dim < 0 or partition_dim >= len(self.shape): + raise ValueError( + f"AutoTP partition_dim {partition_dim} is out of range for shape length {len(self.shape)}.") + nested_tuple_seen = False + for idx, dim in enumerate(self.shape): + if isinstance(dim, tuple): + if idx != partition_dim: + raise ValueError( + f"AutoTP shape nested tuple only allowed at partition_dim={partition_dim}, got at {idx}.") + if nested_tuple_seen: + raise ValueError("AutoTP shape supports only 1-level nesting at partition_dim.") + nested_tuple_seen = True + if len(dim) == 0: + raise 
ValueError("AutoTP shape nested tuple cannot be empty.") + for val in dim: + if isinstance(val, tuple): + raise ValueError("AutoTP shape supports only 1-level nesting at partition_dim.") + if not isinstance(val, int) or val <= 0: + raise ValueError("AutoTP nested sub-parameter sizes must be positive integers.") + elif isinstance(dim, int): + if dim == 0 or dim < -1: + raise ValueError("AutoTP shape dimensions must be positive integers or -1.") + else: + raise ValueError("AutoTP shape must contain only integers or a tuple at partition_dim.") + + def get_partition_dim(self) -> int: + """Get effective partition dimension.""" + if self.partition_dim is not None: + return self.partition_dim + # Default based on partition type for 2D weight matrices + return 0 if self.partition_type == PartitionType.COLUMN else 1 + + def has_unequal_sub_params(self) -> bool: + """Check if this spec has unequal sub-parameters (nested tuple at partition_dim).""" + if self.shape is None: + return False + dim = self.get_partition_dim() + if dim >= len(self.shape): + return False + return isinstance(self.shape[dim], tuple) + + def get_sub_param_sizes(self) -> Optional[Tuple[int, ...]]: + """Get sub-parameter sizes if using unequal sub-params.""" + if not self.has_unequal_sub_params(): + return None + return self.shape[self.get_partition_dim()] + + def get_num_sub_params(self) -> Optional[int]: + """Get the number of sub-parameters.""" + if self.shape is None: + return None + dim = self.get_partition_dim() + if dim >= len(self.shape): + return None + if isinstance(self.shape[dim], tuple): + return len(self.shape[dim]) + elif isinstance(self.shape[dim], int) and self.shape[dim] > 0: + return self.shape[dim] + return None + + def matches(self, param_name: str, model_type: Optional[str] = None) -> bool: + """Check if this spec matches the given parameter.""" + # Check model type constraint + if self.model_types: + if model_type is None: + return False + model_type_norm = str(model_type).lower() + 
model_types_norm = [str(mt).lower() for mt in self.model_types] + if model_type_norm not in model_types_norm: + return False + # Check pattern match + return any(re.match(pattern, param_name) for pattern in self.patterns) + + +@dataclass +class AutoTPConfig: + """ + Configuration for Automatic Tensor Parallelism. + + Example usage: + config = AutoTPConfig( + tp_size=4, + layer_specs=[ + # Row-parallel layers (AllReduce after forward) + TPLayerSpec( + patterns=[".*\\.o_proj", ".*\\.down_proj"], + partition_type=PartitionType.ROW, + ), + # Column-parallel layers + TPLayerSpec( + patterns=[".*\\.[qkv]_proj", ".*\\.up_proj", ".*\\.gate_proj"], + partition_type=PartitionType.COLUMN, + ), + # Skip MoE gates + TPLayerSpec( + patterns=[".*\\.gate$"], + partition_type=PartitionType.SKIP, + ), + ], + ) + """ + + tp_size: int = 1 + + # Unified layer specifications + layer_specs: List[TPLayerSpec] = field(default_factory=list) + + # Embedding configuration + embedding_partition_dim: int = 1 # Usually partition vocab dim + + # LM head configuration + lm_head_patterns: List[str] = field(default_factory=lambda: ["lm_head", "embed_out"]) + + # Behavior flags + use_default_specs: bool = True # Merge with built-in specs + strict_mode: bool = False # Fail if unmatched Linear layers found + + def find_matching_spec(self, param_name: str, model_type: Optional[str] = None) -> Optional[TPLayerSpec]: + """Find the first matching spec for a parameter.""" + matches = [spec for spec in self.layer_specs if spec.matches(param_name, model_type)] + if not matches: + return None + if len(matches) > 1: + matched_patterns = [spec.patterns for spec in matches] + warning_once(f"AutoTPConfig: parameter {param_name} matched multiple layer_specs {matched_patterns}; " + "using the first match.") + return matches[0] + + @classmethod + def from_dict(cls, config_dict: dict) -> "AutoTPConfig": + """Create config from dictionary (JSON config).""" + layer_specs = [] + for spec_dict in 
config_dict.get("layer_specs", []): + # Convert partition_type string to enum + partition_type_str = spec_dict.get("partition_type", "column") + if isinstance(partition_type_str, str): + partition_type = PartitionType(partition_type_str.lower()) + else: + partition_type = partition_type_str + + # Convert shape from list to tuple if necessary + shape = spec_dict.get("shape") + if shape is not None: + shape = cls._convert_shape(shape) + + layer_specs.append( + TPLayerSpec( + patterns=spec_dict.get("patterns", []), + partition_type=partition_type, + shape=shape, + partition_dim=spec_dict.get("partition_dim"), + model_types=spec_dict.get("model_types"), + )) + + return cls( + tp_size=config_dict.get("tp_size", 1), + layer_specs=layer_specs, + embedding_partition_dim=config_dict.get("embedding_partition_dim", 1), + lm_head_patterns=config_dict.get("lm_head_patterns", ["lm_head", "embed_out"]), + use_default_specs=config_dict.get("use_default_specs", True), + strict_mode=config_dict.get("strict_mode", False), + ) + + @staticmethod + def _convert_shape(shape): + """Convert shape from list to tuple, handling nested structures.""" + if isinstance(shape, list): + return tuple(AutoTPConfig._convert_shape(item) if isinstance(item, list) else item for item in shape) + return shape + + +class AutoTPPresets: + """Built-in presets for common model architectures.""" + + @staticmethod + def llama() -> AutoTPConfig: + """LLaMA-style models (separate Q, K, V projections).""" + return AutoTPConfig(layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def llama_gqa(num_heads: int, 
num_kv_heads: int, head_dim: int) -> AutoTPConfig: + """LLaMA with Grouped Query Attention (fused QKV variant).""" + q_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # Fused QKV with unequal sizes (GQA) + TPLayerSpec( + patterns=[r".*\.self_attn\.qkv_proj\.weight$"], + partition_type=PartitionType.COLUMN, + shape=((q_size, kv_size, kv_size), -1), # 1-level nesting + partition_dim=0, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def bloom() -> AutoTPConfig: + """BLOOM-style models (fused QKV with interleaved heads).""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attention\.dense\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attention\.query_key_value\.weight$"], + partition_type=PartitionType.COLUMN, + # Bloom style: [q1,k1,v1,q2,k2,v2,...] 
- no reshape needed + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_4h_to_h\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_h_to_4h\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def chatglm() -> AutoTPConfig: + """ChatGLM-style models (GLM-style fused QKV).""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attention\.dense\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attention\.query_key_value\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(3, -1), # [Q, K, V] concatenated + partition_dim=0, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_4h_to_h\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_h_to_4h\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(2, -1), # [gate, up] packed + partition_dim=0, + ), + ], ) + + @staticmethod + def mixtral() -> AutoTPConfig: + """Mixtral MoE model.""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # MoE experts + TPLayerSpec( + patterns=[r".*\.block_sparse_moe\.experts\.\d+\.w2\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.block_sparse_moe\.experts\.\d+\.w[13]\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # Skip MoE gate + TPLayerSpec( + patterns=[r".*\.block_sparse_moe\.gate\.weight$"], + partition_type=PartitionType.SKIP, + ), + ], ) + + @staticmethod + def deepseek_v2() -> AutoTPConfig: + """DeepSeek-V2 with MLA (Multi-head Latent Attention).""" + return AutoTPConfig( + layer_specs=[ + # Standard attention output + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # MLA uses compressed KV, 
skip low-rank projections + TPLayerSpec( + patterns=[r".*\.self_attn\.(q_a_proj|kv_a_proj_with_mqa)\.weight$"], + partition_type=PartitionType.SKIP, + ), + # Q/K/V projections from latent + TPLayerSpec( + patterns=[r".*\.self_attn\.(q_b_proj|kv_b_proj)\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # MoE experts + TPLayerSpec( + patterns=[r".*\.mlp\.experts\.\d+\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.experts\.\d+\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # Skip MoE gate + TPLayerSpec( + patterns=[r".*\.mlp\.gate\.weight$"], + partition_type=PartitionType.SKIP, + ), + # Shared expert + TPLayerSpec( + patterns=[r".*\.mlp\.shared_experts\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.shared_experts\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def qwen2() -> AutoTPConfig: + """Qwen2 model.""" + return AutoTPConfig(layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def phi3() -> AutoTPConfig: + """Phi3 model with fused QKV and chunked MLP.""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # Phi3 has fused qkv_proj + TPLayerSpec( + patterns=[r".*\.self_attn\.qkv_proj\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(3, -1), # [Q, K, V] concatenated + partition_dim=0, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + 
partition_type=PartitionType.ROW, + ), + # Phi3 has gate_up_proj fused + TPLayerSpec( + patterns=[r".*\.mlp\.gate_up_proj\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(2, -1), # [gate, up] packed + partition_dim=0, + ), + ], ) + + @staticmethod + def get_preset(model_type: str) -> Optional[AutoTPConfig]: + """Get a preset configuration by model type name.""" + presets = { + "llama": AutoTPPresets.llama, + "bloom": AutoTPPresets.bloom, + "chatglm": AutoTPPresets.chatglm, + "mixtral": AutoTPPresets.mixtral, + "deepseek_v2": AutoTPPresets.deepseek_v2, + "qwen2": AutoTPPresets.qwen2, + "phi3": AutoTPPresets.phi3, + } + preset_fn = presets.get(model_type.lower()) + if preset_fn: + return preset_fn() + return None + + +def merge_autotp_configs(base: AutoTPConfig, override: AutoTPConfig) -> AutoTPConfig: + """Merge two AutoTP configs, with override taking precedence.""" + # Combine layer specs - override specs come first (higher priority) + merged_specs = list(override.layer_specs) + list(base.layer_specs) + + return AutoTPConfig( + tp_size=override.tp_size if override.tp_size > 1 else base.tp_size, + layer_specs=merged_specs, + embedding_partition_dim=override.embedding_partition_dim, + lm_head_patterns=override.lm_head_patterns or base.lm_head_patterns, + use_default_specs=override.use_default_specs, + strict_mode=override.strict_mode, + ) diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index a9faac71361a..5f55d2e78e26 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -20,7 +20,8 @@ __all__ = [ "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce", - "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer" + "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer", + "SubParamLinearLayer", "SubParamLinearAllreduce" ] 
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE @@ -801,6 +802,356 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int return super().forward(positions + self.offset) +def _shape_prod(values): + result = 1 + for val in values: + result *= val + return result + + +def _normalize_shape_spec(shape): + if isinstance(shape, list): + return tuple(_normalize_shape_spec(item) for item in shape) + if isinstance(shape, tuple): + return tuple(_normalize_shape_spec(item) if isinstance(item, list) else item for item in shape) + return shape + + +def _infer_subparam_logical_shapes(weight_shape, shape, partition_dim, name=None): + shape = _normalize_shape_spec(shape) + if not isinstance(shape, tuple): + raise ValueError("AutoTP shape must be a tuple for sub-parameter partitioning.") + if partition_dim < 0 or partition_dim >= len(shape): + raise ValueError(f"AutoTP partition_dim {partition_dim} is out of range for shape length {len(shape)}.") + + layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer" + partition_elem = shape[partition_dim] + subparam_sizes = None + num_subparams = None + + if isinstance(partition_elem, tuple): + if len(partition_elem) == 0: + raise ValueError(f"{layer_label} sub-parameter size tuple cannot be empty.") + if any(isinstance(val, tuple) for val in partition_elem): + raise ValueError(f"{layer_label} supports only 1-level nesting at partition_dim.") + if any((not isinstance(val, int)) or val <= 0 for val in partition_elem): + raise ValueError(f"{layer_label} sub-parameter sizes must be positive integers.") + subparam_sizes = tuple(int(val) for val in partition_elem) + partition_dim_size = sum(subparam_sizes) + elif isinstance(partition_elem, int): + if partition_elem == -1: + partition_dim_size = None + elif partition_elem > 0: + num_subparams = partition_elem + partition_dim_size = None + else: + raise ValueError(f"{layer_label} partition_dim spec must be positive integer or -1.") + else: + raise 
ValueError(f"{layer_label} partition_dim spec must be int or tuple.") + + logical_dims = [] + for idx, dim in enumerate(shape): + if idx == partition_dim: + logical_dims.append(partition_dim_size) + continue + if isinstance(dim, tuple): + raise ValueError(f"{layer_label} nested tuple only allowed at partition_dim={partition_dim}.") + if isinstance(dim, int): + if dim == -1: + logical_dims.append(None) + elif dim > 0: + logical_dims.append(dim) + else: + raise ValueError(f"{layer_label} shape dimensions must be positive integers or -1.") + else: + raise ValueError(f"{layer_label} shape dimensions must be integers.") + + total_numel = _shape_prod(weight_shape) + known_product = _shape_prod([dim for dim in logical_dims if dim is not None]) + unknown_indices = [idx for idx, dim in enumerate(logical_dims) if dim is None] + + if len(unknown_indices) == 0: + if known_product != total_numel: + raise ValueError(f"{layer_label} shape product {known_product} != weight numel {total_numel}.") + elif len(unknown_indices) == 1: + inferred = total_numel // known_product + if inferred * known_product != total_numel: + raise ValueError(f"{layer_label} cannot infer shape for weight with numel {total_numel}.") + logical_dims[unknown_indices[0]] = inferred + else: + if len(shape) == len(weight_shape): + for idx in unknown_indices: + logical_dims[idx] = weight_shape[idx] + if _shape_prod(logical_dims) != total_numel: + raise ValueError( + f"{layer_label} shape product {_shape_prod(logical_dims)} != weight numel {total_numel}.") + else: + raise ValueError(f"{layer_label} shape has multiple inferred dims and is ambiguous for weight.") + + logical_shape = tuple(logical_dims) + if logical_shape[-1] != weight_shape[-1]: + raise ValueError( + f"{layer_label} shape last dim {logical_shape[-1]} must match weight input dim {weight_shape[-1]}.") + + output_shape = logical_shape[:-1] + if len(output_shape) == 0: + raise ValueError(f"{layer_label} shape must include at least one output dimension.") 
+ if _shape_prod(output_shape) != weight_shape[0]: + raise ValueError( + f"{layer_label} output shape product {_shape_prod(output_shape)} != weight output dim {weight_shape[0]}.") + + partition_dim_size = logical_shape[partition_dim] + if partition_dim_size is None or partition_dim_size <= 0: + raise ValueError(f"{layer_label} partition_dim size must be a positive integer.") + + if num_subparams is not None: + if partition_dim_size % num_subparams != 0: + raise ValueError( + f"{layer_label} partition_dim size {partition_dim_size} not divisible by sub-param count {num_subparams}." + ) + subparam_sizes = tuple([partition_dim_size // num_subparams] * num_subparams) + + if subparam_sizes is not None and sum(subparam_sizes) != partition_dim_size: + raise ValueError( + f"{layer_label} sub-parameter sizes sum {sum(subparam_sizes)} != partition_dim size {partition_dim_size}.") + + bias_partition_dim = partition_dim if partition_dim < len(output_shape) else None + return logical_shape, output_shape, subparam_sizes, bias_partition_dim + + +def _partition_logical_tensor(tensor, partition_dim, tp_world_size, tp_index, name=None, subparam_sizes=None): + if tp_world_size == 1: + return tensor + layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer" + if subparam_sizes: + for size in subparam_sizes: + if size % tp_world_size != 0: + raise ValueError(f"{layer_label} sub-parameter size {size} not divisible by tp_size {tp_world_size}.") + sub_params = torch.split(tensor, subparam_sizes, dim=partition_dim) + partitioned_sub_params = [torch.chunk(sp, tp_world_size, dim=partition_dim)[tp_index] for sp in sub_params] + return torch.cat(partitioned_sub_params, dim=partition_dim) + if tensor.shape[partition_dim] % tp_world_size != 0: + raise ValueError( + f"{layer_label} partition_dim size {tensor.shape[partition_dim]} not divisible by tp_size {tp_world_size}." 
+ ) + return torch.chunk(tensor, tp_world_size, dim=partition_dim)[tp_index] + + +def _all_gather_along_dim(tensor, partition_dim, mp_group, tp_world_size): + if mp_group is None or tp_world_size == 1: + return tensor + perm = [partition_dim] + [idx for idx in range(tensor.dim()) if idx != partition_dim] + inv_perm = [0] * len(perm) + for idx, dim in enumerate(perm): + inv_perm[dim] = idx + tensor_perm = tensor.permute(perm).contiguous() + output = torch.empty((tp_world_size * tensor_perm.shape[0], *tensor_perm.shape[1:]), + dtype=tensor.dtype, + device=tensor.device) + dist.all_gather_into_tensor(output, tensor_perm, group=mp_group) + return output.permute(inv_perm).contiguous() + + +def _gather_logical_tensor(tensor, + logical_shape, + partition_dim, + mp_group, + tp_world_size, + name=None, + subparam_sizes=None): + if mp_group is None or tp_world_size == 1: + return tensor.reshape(logical_shape) + layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer" + if logical_shape[partition_dim] % tp_world_size != 0: + raise ValueError( + f"{layer_label} partition_dim size {logical_shape[partition_dim]} not divisible by tp_size {tp_world_size}." 
+ ) + partitioned_shape = list(logical_shape) + partitioned_shape[partition_dim] = logical_shape[partition_dim] // tp_world_size + tensor_view = tensor.reshape(partitioned_shape) + + if subparam_sizes: + for size in subparam_sizes: + if size % tp_world_size != 0: + raise ValueError(f"{layer_label} sub-parameter size {size} not divisible by tp_size {tp_world_size}.") + partitioned_sizes = [size // tp_world_size for size in subparam_sizes] + sub_params = torch.split(tensor_view, partitioned_sizes, dim=partition_dim) + gathered_sub_params = [_all_gather_along_dim(sp, partition_dim, mp_group, tp_world_size) for sp in sub_params] + return torch.cat(gathered_sub_params, dim=partition_dim) + return _all_gather_along_dim(tensor_view, partition_dim, mp_group, tp_world_size) + + +class SubParamLinearLayer(TensorParallel_Layer): + """ + Column-parallel linear layer with sub-parameter support. + + Handles cases where weights contain multiple logical sub-parameters + that need to be partitioned separately (e.g., fused QKV, chunked MLP, GQA). 
+ + The `shape` parameter controls how the weight is viewed and partitioned: + - (3, -1) with partition_dim=0: 3 equal sub-params, partition each at dim 0 + - ((q, k, v), -1) with partition_dim=0: 3 unequal sub-params (1-level nesting) + """ + + def __init__(self, module, mp_group, shape, partition_dim=0, **kwargs): + super(SubParamLinearLayer, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + self.shape = shape + self.partition_dim = partition_dim + + self._orig_weight_shape = tuple(module.weight.shape) + self._orig_bias_shape = tuple(module.bias.shape) if self.bias is not None else None + (self._logical_shape, self._output_shape, self._subparam_sizes, + self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape, + self.partition_dim, self.name) + if self.bias is not None and self.bias.numel() != _shape_prod(self._output_shape): + raise ValueError(f"AutoTP layer '{self.name}' bias size {self.bias.numel()} does not match output shape " + f"{self._output_shape}.") + + self._tp_partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_tp_params(self.bias) + + def forward(self, input): + if getattr(self, 'mp_group', None) is not None: + input = ColumnParallel.apply(self.mp_group, input) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.bias is not None: + output = add_bias(output, self.bias) + return output + + @torch.no_grad() + def gather_params(self, params_list): + """Gather partitioned parameters back to full size.""" + for idx, param in enumerate(params_list): + if param is None: + continue + params_list[idx].data_partition = param.data + if idx == 0: + full_view = _gather_logical_tensor(param, + self._logical_shape, + self.partition_dim, + self.mp_group, + self.tp_world_size, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[idx].data = 
full_view.reshape(self._orig_weight_shape) + else: + if self._bias_partition_dim is None: + params_list[idx].data = param.data + else: + full_bias_view = _gather_logical_tensor(param, + self._output_shape, + self._bias_partition_dim, + self.mp_group, + self.tp_world_size, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[idx].data = full_bias_view.reshape(self._orig_bias_shape) + + @torch.no_grad() + def _tp_partition(self, params_list): + weight = params_list[0] + if weight is None: + return + + weight_view = weight.reshape(self._logical_shape) + partitioned_view = _partition_logical_tensor(weight_view, + self.partition_dim, + self.tp_world_size, + self.tp_index, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[0].data = self.move(partitioned_view.reshape(-1, partitioned_view.shape[-1])).detach() + + if params_list[1] is not None: + if self._bias_partition_dim is None: + params_list[1].data = self.move(params_list[1]).detach() + else: + bias_view = params_list[1].reshape(self._output_shape) + bias_partitioned = _partition_logical_tensor(bias_view, + self._bias_partition_dim, + self.tp_world_size, + self.tp_index, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[1].data = self.move(bias_partitioned.reshape(-1)).detach() + + +class SubParamLinearAllreduce(TensorParallel_Layer): + """ + Row-parallel linear layer with sub-parameter support (AllReduce after forward). + + Handles cases where weights contain multiple logical sub-parameters + that need to be partitioned separately. 
+ """ + + def __init__(self, module, mp_group, shape, partition_dim=1, **kwargs): + super(SubParamLinearAllreduce, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + self.shape = shape + self.partition_dim = partition_dim + + self._orig_weight_shape = tuple(module.weight.shape) + self._orig_bias_shape = tuple(module.bias.shape) if self.bias is not None else None + (self._logical_shape, self._output_shape, self._subparam_sizes, + self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape, + self.partition_dim, self.name) + + self._tp_partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_requires_grad(self.bias) + + def forward(self, input): + output = torch.matmul(input, self.weight.transpose(-1, -2)) + output = RowParallel.apply(self.mp_group, output, not self.is_training_mode()) + if self.bias is not None: + output = add_bias(output, self.bias) + return output + + @torch.no_grad() + def gather_params(self, params_list): + """Gather partitioned parameters back to full size.""" + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't gather bias for row parallel + return + params_list[idx].data_partition = param.data + full_view = _gather_logical_tensor(param, + self._logical_shape, + self.partition_dim, + self.mp_group, + self.tp_world_size, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[idx].data = full_view.reshape(self._orig_weight_shape) + + @torch.no_grad() + def _tp_partition(self, params_list): + weight = params_list[0] + if weight is None: + return + + weight_view = weight.reshape(self._logical_shape) + partitioned_view = _partition_logical_tensor(weight_view, + self.partition_dim, + self.tp_world_size, + self.tp_index, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[0].data = self.move(partitioned_view.reshape(-1, 
partitioned_view.shape[-1])).detach() + + # Bias is not partitioned for row parallel (it's applied after all-reduce) + if params_list[1] is not None: + params_list[1].data = self.move(params_list[1]).detach() + + class RMSNormalize(nn.Module): def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None): diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 26752cfa4fec..263369fc0484 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -273,9 +273,20 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group) + # Get the configurable partition_config if available + partition_config = None + if hasattr(config, 'get_partition_config_object'): + partition_config = config.get_partition_config_object() + # 1. Create AutoTP object - _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl, - config.keep_module_on_host) + _autotp = AutoTP(module, + all_reduce_linears, + prefix, + state_dict, + linear_layer_setting, + orig_layer_impl, + config.keep_module_on_host, + partition_config=partition_config) # 2. Set the tensor parallelism config _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index ded262edcf61..f1dbaae43ec9 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -4,7 +4,12 @@ # DeepSpeed Team from deepspeed import comm as dist -global num_kv_heads + +# Defaults for optional TP globals. These can be overridden by setters. 
+num_kv_heads = None +num_attention_heads = None +n_embd = None +tp_grain_size = 1 def set_num_kv_heads(num): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5f8585d65d76..410d1ad6c46c 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -499,6 +499,7 @@ def _optimized_linear_offload_setup(self): def _configure_tensor_parallel(self, model, tp_config): self._configure_tensor_parallel_states(model) configure_tensor_parallel_runtime(tp_config) + self._apply_autotp_partitioning(model, tp_config) def _configure_tensor_parallel_states(self, model): """ @@ -564,6 +565,55 @@ def broadcast_and_check(args, bcast_rank, bcast_group): prepend=True, with_kwargs=True) + def _apply_autotp_partitioning(self, model, tp_config): + if getattr(model, "ds_autotp_parsed", False): + return + if get_accelerator().is_available() and self.local_rank >= 0: + get_accelerator().set_device(self.local_rank) + + tp_size = self.autotp_size() + if tp_config.tensor_parallel.tp_size not in (1, tp_size): + raise ValueError(f"tensor_parallel.tp.tp_size ({tp_config.tensor_parallel.tp_size}) " + f"does not match tensor_parallel.autotp_size ({tp_size}).") + tp_config.tensor_parallel.tp_size = tp_size + if tp_config.tensor_parallel.tp_group is None: + tp_config.tensor_parallel.tp_group = groups.get_tensor_model_parallel_group() + + from deepspeed.module_inject.auto_tp import AutoTP + + partition_config = None + if hasattr(tp_config, "get_partition_config_object"): + partition_config = tp_config.get_partition_config_object() + + if partition_config is not None: + autotp = AutoTP(module=model, + all_reduce_linears=(), + prefix="", + state_dict=None, + linear_layer_setting=(torch.nn.Linear, torch.nn.Embedding), + orig_layer_impl=None, + keep_module_on_host=tp_config.keep_module_on_host, + partition_config=partition_config) + autotp.set_tensor_parallel_config(tp_size, tp_config.tensor_parallel.tp_group) + autotp.update_linear_policies() + 
autotp._replace_module(model) + setattr(model, "ds_autotp_parsed", True) + return + + if tp_size <= 1: + setattr(model, "ds_autotp_parsed", True) + return + + model_config = getattr(model, "config", None) + from deepspeed.module_inject import replace_transformer_layer + + parser_dict = AutoTP.tp_parser(model) + for client_module, injection_policy in parser_dict: + tp_config.injection_policy_tuple = injection_policy + replace_transformer_layer(client_module, model, None, tp_config, model_config) + + setattr(model, "ds_autotp_parsed", True) + def __del__(self): try: self.destroy() diff --git a/deepspeed/runtime/tensor_parallel/config.py b/deepspeed/runtime/tensor_parallel/config.py index 957984e9f8b3..326277f6eb5c 100644 --- a/deepspeed/runtime/tensor_parallel/config.py +++ b/deepspeed/runtime/tensor_parallel/config.py @@ -7,7 +7,7 @@ from deepspeed.runtime.config_utils import DeepSpeedConfigModel import torch from pydantic import Field -from typing import Optional +from typing import Optional, Dict, Any class AUTOTP_MODE(Enum): @@ -57,6 +57,41 @@ class TPTrainingConfig(DeepSpeedConfigModel): """ injection_policy_tuple: Optional[tuple] = None + + # New configurable AutoTP settings + partition_config: Optional[Dict[str, Any]] = None + """ + Configuration for the new configurable AutoTP API. + Allows users to specify custom layer partitioning rules via TPLayerSpec. + + Example: + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + """ + + preset_model: Optional[str] = None + """ + Use a built-in preset for common model architectures. 
+ Available presets: "llama", "bloom", "chatglm", "mixtral", "deepseek_v2", "qwen2", "phi3" + """ + #The following parameters are required by autoTP parser. ######################################## keep_module_on_host: bool = False @@ -74,8 +109,35 @@ class TPTrainingConfig(DeepSpeedConfigModel): linear layers as a tuple: `(attention_output projection, transformer output projection)` """ + ######################################## + def get_partition_config_object(self): + """ + Get the AutoTPConfig object from the configuration. + Returns None if no custom config is specified. + """ + from deepspeed.module_inject.autotp_config import AutoTPConfig, AutoTPPresets, merge_autotp_configs + + config = None + + # First check for preset + if self.preset_model: + config = AutoTPPresets.get_preset(self.preset_model) + + # Then check for custom config + if self.partition_config: + custom_config = AutoTPConfig.from_dict(self.partition_config) + if config and custom_config.use_default_specs: + config = merge_autotp_configs(config, custom_config) + else: + config = custom_config + + if config: + config.tp_size = self.autotp_size + + return config + def get_tensor_parallel_config(ds_config): diff --git a/deepspeed/runtime/tensor_parallel/init_utils.py b/deepspeed/runtime/tensor_parallel/init_utils.py new file mode 100644 index 000000000000..95dacb984cbe --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/init_utils.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import base64 +import os +from typing import Optional, Union + +import hjson +import torch + +from deepspeed.runtime.config_utils import dict_raise_error_on_duplicate_keys + +_TP_MODEL_INIT_ARGS = None + + +def load_ds_config(config: Union[str, dict]) -> dict: + if isinstance(config, dict): + return config + if isinstance(config, str): + if os.path.exists(config): + return hjson.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys) + try: + config_decoded = base64.urlsafe_b64decode(config).decode('utf-8') + return hjson.loads(config_decoded) + except (UnicodeDecodeError, AttributeError, ValueError) as exc: + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. " + f"Received: {config}") from exc + raise ValueError(f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. " + f"Received: {config}") + + +def record_tp_model_init_args(tp_size, dtype, tp_group, dist_module): + global _TP_MODEL_INIT_ARGS + new_args = { + "tp_size": tp_size, + "dtype": dtype, + "tp_group": tp_group, + } + + if _TP_MODEL_INIT_ARGS is None: + _TP_MODEL_INIT_ARGS = new_args + return + + if _TP_MODEL_INIT_ARGS["tp_size"] != tp_size or _TP_MODEL_INIT_ARGS["dtype"] != dtype: + raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.") + + existing_group = _TP_MODEL_INIT_ARGS.get("tp_group") + if existing_group is None and tp_group is None: + return + if (existing_group is None) != (tp_group is None): + raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.") + + existing_group_size = tp_group_world_size(existing_group, dist_module) + new_group_size = tp_group_world_size(tp_group, dist_module) + if existing_group_size != new_group_size: + raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.") + + +def 
tp_group_world_size(tp_group, dist_module): + if tp_group is None or dist_module is None: + return None + return dist_module.get_world_size(group=tp_group) + + +def infer_config_dtype(config_dict: dict) -> Optional[torch.dtype]: + bf16_config = config_dict.get("bf16", {}) + if isinstance(bf16_config, dict) and bf16_config.get("enabled", False): + return torch.bfloat16 + fp16_config = config_dict.get("fp16", {}) + if isinstance(fp16_config, dict) and fp16_config.get("enabled", False): + return torch.float16 + return None + + +def merge_tp_model_init_into_config(config_dict: dict, mpu, mesh_param, dist_module): + if _TP_MODEL_INIT_ARGS is None: + return + + tp_size = _TP_MODEL_INIT_ARGS["tp_size"] + dtype = _TP_MODEL_INIT_ARGS["dtype"] + tp_group = _TP_MODEL_INIT_ARGS["tp_group"] + + if tp_group is not None and mpu is not None: + raise ValueError("tp_model_init provided tp_group; deepspeed.initialize must not receive mpu.") + if tp_group is None and mpu is None and mesh_param is None: + raise ValueError("tp_model_init did not provide tp_group; deepspeed.initialize requires mpu or mesh_param.") + + tp_section = config_dict.get("tensor_parallel") + if tp_section is None: + tp_section = {} + config_dict["tensor_parallel"] = tp_section + + config_autotp_size = tp_section.get("autotp_size") + if config_autotp_size is not None and config_autotp_size != tp_size: + raise ValueError( + f"Conflicting tensor_parallel.autotp_size in config ({config_autotp_size}) and tp_model_init ({tp_size}).") + + if config_autotp_size is None: + tp_section["autotp_size"] = tp_size + + tp_config = tp_section.get("tp") or {} + if not isinstance(tp_config, dict): + raise ValueError("tensor_parallel.tp must be a dict when provided.") + + config_tp_size = tp_config.get("tp_size") + if config_tp_size is not None and config_tp_size != tp_size: + raise ValueError( + f"Conflicting tensor_parallel.tp.tp_size in config ({config_tp_size}) and tp_model_init ({tp_size}).") + if config_tp_size is None: + 
tp_config["tp_size"] = tp_size + + if tp_group is not None: + config_tp_group = tp_config.get("tp_group") + if config_tp_group is not None and config_tp_group is not tp_group: + raise ValueError("Conflicting tensor_parallel.tp.tp_group in config and tp_model_init.") + tp_config["tp_group"] = tp_group + + tp_group_size = tp_group_world_size(tp_group, dist_module) + if tp_group_size is not None and tp_group_size != tp_size: + raise ValueError(f"tp_model_init tp_size ({tp_size}) does not match tp_group size ({tp_group_size}).") + + tp_section["tp"] = tp_config + + config_dtype = infer_config_dtype(config_dict) + if config_dtype is not None and config_dtype != dtype: + raise ValueError(f"Conflicting dtype: config uses {config_dtype} but tp_model_init requested {dtype}.") + + tp_dtype = tp_section.get("dtype") + if tp_dtype is not None: + if isinstance(tp_dtype, str): + tp_dtype_map = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + tp_dtype_value = tp_dtype_map.get(tp_dtype.lower()) + else: + tp_dtype_value = tp_dtype + if tp_dtype_value is not None and tp_dtype_value != dtype: + raise ValueError(f"Conflicting tensor_parallel.dtype in config ({tp_dtype}) and tp_model_init ({dtype}).") diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 5ac33f3e4447..d5344d3b2320 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -730,6 +730,96 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s | -------------------------------------------------------------------------------------------------------------- | ------- | | Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. | `true` | +### Tensor Parallel (AutoTP) +Configure AutoTP tensor parallelism for training via the DeepSpeed config and hybrid TP + ZeRO. AutoTP supports ZeRO stages 0, 1, and 2 (stage 3 is not supported). 
`deepspeed.tp_model_init()` remains supported for backward compatibility but is not required when `tensor_parallel` is set in the config. +```json + "tensor_parallel": { + "autotp_size": 4, + "preset_model": "llama", + "tp_overlap_comm": false, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + } + ] + } + } +``` +**tensor_parallel**: [dictionary] + +| Description | Default | +| ------------------------------------------------------------------------------------------ | ------- | +| Enable AutoTP tensor parallelism and configure preset or custom partitioning rules. | `{}` | + +***autotp_size***: [integer] + +| Description | Default | +| --------------------------------------------------------------------------- | ------- | +| Tensor-parallel degree. Set to `0` to disable AutoTP. | `0` | + +***preset_model***: [string] + +| Description | Default | +| ----------------------------------------------------------------------------------------------------- | ------- | +| Built-in model presets: `llama`, `bloom`, `chatglm`, `mixtral`, `deepseek_v2`, `qwen2`, `phi3`. | `null` | + +***tp_overlap_comm***: [boolean] + +| Description | Default | +| -------------------------------------------------------------------------------------------------------- | ------- | +| Overlap tensor-parallel allreduce communication with computation (training only). | `false` | + +***partition_config***: [dictionary] + +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Custom AutoTP layer partitioning rules. Use with or without `preset_model` to customize sharding patterns. 
| `null` | + +***use_default_specs***: [boolean] + +| Description | Default | +| -------------------------------------------------------------------------------------------------------------------- | ------- | +| Merge custom `layer_specs` with preset defaults when `preset_model` is set; otherwise use only custom specs. | `true` | + +***layer_specs***: [list] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Ordered list of pattern rules that define how to partition matching parameters. | `[]` | + +***patterns***: [list of strings] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Regex patterns to match parameter names for this partition rule. | `[]` | + +***partition_type***: [string] + +| Description | Default | +| ---------------------------------------------------------------------------- | ------- | +| Partition type for matching parameters: `row`, `column`, or `skip`. | `column` | + +***shape***: [list] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Optional sub-parameter shape for fused weights before TP partitioning (e.g., `[2, -1]`). | `null` | + +***partition_dim***: [integer] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Dimension to split when `shape` is provided (e.g., `0` for fused QKV or gate/up). | `null` | + +***model_types***: [list of strings] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Optional model type filters (from `model.config.model_type`) for shared configs. 
| `null` | + ***ignore_unused_parameters***: [boolean] | Description | Default | diff --git a/docs/_tutorials/automatic-tensor-parallelism.md b/docs/_tutorials/automatic-tensor-parallelism.md index f18b823e2490..1433df080ee3 100755 --- a/docs/_tutorials/automatic-tensor-parallelism.md +++ b/docs/_tutorials/automatic-tensor-parallelism.md @@ -3,6 +3,8 @@ title: "Automatic Tensor Parallelism for HuggingFace Models" tags: inference --- +> **Note:** This tutorial covers AutoTP for **inference**. For **training** with tensor parallelism and ZeRO optimization, see [AutoTP Training API](autotp-training). + # Contents * [Introduction](#introduction) * [Example Script](#example-script) diff --git a/docs/_tutorials/autotp-training.md b/docs/_tutorials/autotp-training.md new file mode 100644 index 000000000000..9eb51d3e0900 --- /dev/null +++ b/docs/_tutorials/autotp-training.md @@ -0,0 +1,188 @@ +--- +title: "AutoTP Training API" +tags: training tensor-parallelism +--- + +# AutoTP Training API + +This tutorial covers the **AutoTP Training API** for combining tensor parallelism with ZeRO optimization during training. For inference-only tensor parallelism, see [Automatic Tensor Parallelism for HuggingFace Models](automatic-tensor-parallelism). + +## Contents +- [Introduction](#introduction) +- [Quick Start](#quick-start) +- [Custom Layer Specifications](#custom-layer-specifications) +- [Limitations](#limitations) + +## Introduction + +The AutoTP Training API enables hybrid parallelism by combining: +- **Tensor Parallelism (TP)**: Split model weights across GPUs within a node +- **Data Parallelism (DP)**: Replicate model across GPU groups +- **ZeRO Optimization**: Memory-efficient optimizer states (Stage 0, 1, or 2) + +Tensor parallelism (TP) splits the computations and parameters of large layers +across multiple GPUs so each rank holds only a shard of the weight matrix. 
This +is an efficient way to train large-scale transformer models by reducing per-GPU +memory pressure while keeping the layer math distributed across the TP group. + + +## Quick Start + +### Basic Usage + +AutoTP training can be enabled entirely through the DeepSpeed config. When +`tensor_parallel` is set in the config, `deepspeed.initialize(...)` applies +AutoTP sharding during engine initialization, so the training loop itself does +not change. + +```python +import torch +import deepspeed + +# 1. Create your model +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B") + +# 2. Define the DeepSpeed config with tensor_parallel settings +ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": {"stage": 2}, + "bf16": {"enabled": True}, + "tensor_parallel": {"autotp_size": 4}, +} + +# 3. Initialize DeepSpeed with AutoTP + ZeRO +engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + mpu=mpu # Model parallel unit (optional if you provide tp_group elsewhere) +) + +# 4. Train as usual +for batch in dataloader: + outputs = engine(input_ids=batch["input_ids"], labels=batch["labels"]) + engine.backward(outputs.loss) + engine.step() +``` + +Compatibility note: For backward compatibility, you can still call +`set_autotp_mode(training=True)` and `deepspeed.tp_model_init(...)`, but they +are not required when the DeepSpeed config provides the necessary +`tensor_parallel` settings. + +### Preset-based Sharding + +If your model matches a built-in preset, set `tensor_parallel.preset_model` in the DeepSpeed config: + +```json +{ + "train_batch_size": 8, + "train_micro_batch_size_per_gpu": 1, + "bf16": { "enabled": true }, + "zero_optimization": { "stage": 2 }, + "tensor_parallel": { + "autotp_size": 4, + "preset_model": "llama" + } +} +``` + +For the list of available presets, see [supported models](../code-docs/training#autotp-supported-models). 
+ + + +## Custom Patterns + +If you are training a custom model, define regex-based patterns and partition rules in `tensor_parallel.partition_config`: + +```json +{ + "tensor_parallel": { + "autotp_size": 4, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + } +} +``` + +## Custom Layer Specifications + +For models not covered by presets, define custom layer specs: + +```json +{ + "tensor_parallel": { + "autotp_size": 4, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + } +} +``` + +### Fused Layers with Unequal Sub-parameters (GQA) + +For Grouped Query Attention with different Q/K/V sizes: + +```json +{ + "tensor_parallel": { + "partition_config": { + "layer_specs": [ + { + "patterns": [".*\\.qkv_proj\\.weight$"], + "partition_type": "column", + "shape": [[q_size, kv_size, kv_size], -1], + "partition_dim": 0 + } + ] + } + } +} +``` + +## Limitations + +1. **ZeRO Stage 3 not supported**: AutoTP currently only works with ZeRO stages 0, 1, and 2. + +2. **TP size must divide model dimensions**: The tensor parallel size must evenly divide the attention head count and hidden dimensions. 
+ + +## See Also + +- [Automatic Tensor Parallelism for Inference](automatic-tensor-parallelism) +- [ZeRO Optimization](zero) +- [DeepSpeed Configuration](config-json) diff --git a/docs/code-docs/source/_static/autotp-subparams-gate-up.png b/docs/code-docs/source/_static/autotp-subparams-gate-up.png new file mode 100644 index 000000000000..116cedf92667 Binary files /dev/null and b/docs/code-docs/source/_static/autotp-subparams-gate-up.png differ diff --git a/docs/code-docs/source/_static/autotp-subparams-gqa.png b/docs/code-docs/source/_static/autotp-subparams-gqa.png new file mode 100644 index 000000000000..15a1045a1dbf Binary files /dev/null and b/docs/code-docs/source/_static/autotp-subparams-gqa.png differ diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 1974db70c226..9b01b2f85315 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -257,3 +257,222 @@ Besides the use of multiple DeepSpeedEngines, the above differs from typical usa You can call ``loss.backward()`` once for the shared loss. **Note:** Previously, you had to call ``_backward_epilogue`` on each model engine after ``loss.backward()``. However, starting from v0.18.3, DeepSpeed automatically handles this internally, so you no longer need to call ``_backward_epilogue`` manually. + + +Automatic Tensor Parallel Training +---------------------------------- +DeepSpeed supports **Automatic Tensor Parallel (AutoTP) training** for sharding +model weights across GPUs while remaining compatible with ZeRO and standard +training workflows. This training API is different from the inference-only +tensor parallel API exposed by ``deepspeed.init_inference``. + +Tensor parallelism (TP) splits the computations and parameters of large layers +across multiple GPUs so each rank holds only a shard of the weight matrix. 
This +is an efficient way to train large-scale transformer models by reducing per-GPU +memory pressure while keeping the layer math distributed across the TP group. + +AutoTP training is enabled by setting ``tensor_parallel`` in the DeepSpeed +config and passing it to ``deepspeed.initialize``. DeepSpeed applies AutoTP +sharding during engine initialization; calling ``deepspeed.tp_model_init``, which we previously used to initialize AutoTP, is now optional. +See :ref:`autotp-training-init-details` for more details. + +.. code-block:: python + + import deepspeed + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": {"stage": 2}, + "tensor_parallel": {"autotp_size": 4}, + } + + engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + mpu=mpu, # optional: TP/DP process groups + ) + +.. note:: + AutoTP training supports ZeRO stages 0, 1, and 2. ZeRO Stage 3 is not supported. + +.. _autotp-training-init-details: + +Initialization behavior +~~~~~~~~~~~~~~~~~~~~~~~ + +AutoTP previously required calling ``set_autotp_mode(training=True)`` and ``deepspeed.tp_model_init`` before ``deepspeed.initialize``. Now we can include all the necessary configurations in the DeepSpeed config. +We still support the traditional initialization path for backward compatibility. +When you use both (i.e. calling ``set_autotp_mode(training=True)`` and ``deepspeed.tp_model_init`` and passing the config to ``deepspeed.initialize``), we will merge the settings at initialization. When we have conflicting settings, we will error out. + +Parameter partitioning +~~~~~~~~~~~~~~~~~~~~~~ +TP sharding needs to know which parameter tensors should be partitioned and +along which dimensions. AutoTP provides three ways to balance ready-to-use +defaults with customizability: + +* **Heuristics**: automatic sharding based on parameter names and model rules. +* **Preset**: choose a built-in model family via ``preset_model``. 
+* **Custom specs**: define regex patterns and partition rules via ``partition_config``. + +Heuristic rules +^^^^^^^^^^^^^^^ +Heuristics use parameter names and model-specific rules to decide how to shard +layers. If you are training a supported model (see +:ref:`autotp-supported-models`), the heuristic rules automatically shard the +model, so you only need to add ``autotp_size``. + +.. code-block:: json + + { + ... + "tensor_parallel": { + "autotp_size": 4 + }, + "zero_optimization": { + ... + }, + ... + } + +Preset-based partitioning +^^^^^^^^^^^^^^^^^^^^^^^^^ +You can explicitly specify the model family with ``preset_model``: + +.. code-block:: json + + { + "tensor_parallel": { + "autotp_size": 4, + "preset_model": "llama" + } + } + +See :ref:`autotp-supported-models` for the supported preset names and the +implementation in `AutoTPPresets `_. +If you add a new model family, you can easily add a new preset by defining +patterns like the existing presets, and we welcome PRs for those additions. + +Custom layer specs +^^^^^^^^^^^^^^^^^^ +If you are training a custom model, you can use ``partition_config`` to specify +custom regex-based patterns and partition settings. + +.. code-block:: json + + { + "tensor_parallel": { + "autotp_size": 4, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + } + } + +You can also set ``use_default_specs`` to ``true`` to merge your custom +patterns on top of the preset (when ``preset_model`` is provided). + +For fused or packed weights (for example QKV or gate/up projections), the +``shape`` and ``partition_dim`` options control sub-parameter partitioning. 
+Sub-parameter partitioning lets AutoTP split a single weight tensor into +logical chunks before applying tensor-parallel sharding. For example, the +``gate_up_proj`` weight can be viewed as two packed matrices (gate and up) by +setting ``shape`` to ``[2, -1]`` and ``partition_dim`` to ``0``; AutoTP then +partitions each chunk consistently across tensor-parallel ranks. + +.. image:: /_static/autotp-subparams-gate-up.png + :alt: AutoTP sub-parameter partitioning + +Another example is GQA-style fused QKV weights. The tensor can contain unequal +Q/K/V segments stacked along the output dimension. For example, set ``shape`` +to the explicit sizes (for example ``[(q_size, kv_size, kv_size), -1]``) and +``partition_dim`` to ``0`` so AutoTP splits the Q, K, and V regions first, then +shards each region across tensor-parallel ranks. + +.. code-block:: json + + { + "patterns": [".*\\.qkv_proj\\.weight$"], + "partition_type": "column", + "shape": [[q_size, kv_size, kv_size], -1], + "partition_dim": 0 + } + +.. image:: /_static/autotp-subparams-gqa.png + :alt: AutoTP sub-parameter partitioning + + +Model-type filtering for shared configs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Use ``model_types`` when you want a single config to work across multiple model +families but apply different specs. This is useful in shared training scripts +or when patterns overlap across architectures. + +.. code-block:: json + + { + "tensor_parallel": { + "autotp_size": 4, + "partition_config": { + "layer_specs": [ + { + "patterns": [".*\\.qkv_proj\\.weight$"], + "partition_type": "column", + "shape": [[q_size, kv_size, kv_size], -1], + "partition_dim": 0, + "model_types": ["llama"] + }, + { + "patterns": [".*\\.qkv_proj\\.weight$"], + "partition_type": "column", + "shape": [3, -1], + "partition_dim": 0, + "model_types": ["qwen2"] + } + ] + } + } + } + + +.. 
_autotp-supported-models:
+
+Supported models
+~~~~~~~~~~~~~~~~
+The following model families are supported by built-in AutoTP presets:
+
+- ``llama``
+- ``bloom``
+- ``chatglm``
+- ``mixtral``
+- ``deepseek_v2``
+- ``qwen2``
+- ``phi3``
+
+Preset definitions live in the ``AutoTPPresets`` class in the DeepSpeed source.
+To support a new model family, add a preset that follows the pattern of the
+existing definitions; pull requests for new presets are welcome.
+
+These strings are the values accepted by ``preset_model`` and are matched
+against the model type in ``model.config.model_type`` (case-insensitive). When
+``preset_model`` is not set, AutoTP uses the legacy automatic sharding rules
+unless you provide a custom ``partition_config``.
+These presets are also useful when you want to extend the default patterns:
+set ``use_default_specs`` to ``true`` in ``partition_config`` to merge your custom
+specs on top of the selected preset.
diff --git a/tests/unit/model_parallelism/test_autotp_custom_patterns.py b/tests/unit/model_parallelism/test_autotp_custom_patterns.py
new file mode 100644
index 000000000000..d19d8b595eeb
--- /dev/null
+++ b/tests/unit/model_parallelism/test_autotp_custom_patterns.py
@@ -0,0 +1,302 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed.comm as dist +import deepspeed +from copy import deepcopy +from torch import nn + +from unit.common import DistributedTest, preferred_dtype +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import groups +from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer) +from deepspeed.module_inject.autotp_config import AutoTPConfig +from deepspeed.module_inject.auto_tp import AutoTP + + +def skip_on_device(): + if get_accelerator().device_name() == 'xpu': + pytest.skip("XPU requires a higher version for test") + + +class SequentialLinearModel(torch.nn.Module): + + def __init__(self, hidden_dim, nlayers=1): + super(SequentialLinearModel, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) + + def forward(self, x): + for layer in self.linears: + x = layer(x) + return x + + +def init_tp_engine(tp_size, partition_config=None): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size, + }, + "zero_optimization": { + "stage": 0, + } + } + if partition_config is not None: + config_dict["tensor_parallel"]["partition_config"] = partition_config + else: + config_dict["tensor_parallel"]["partition_config"] = { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=8) + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + +def apply_autotp_with_partition_config(model, tp_size, partition_config): + 
groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size) + autotp_config = AutoTPConfig.from_dict(partition_config) + autotp = AutoTP(module=model, + all_reduce_linears=[], + prefix="", + state_dict=None, + linear_layer_setting=None, + orig_layer_impl=None, + keep_module_on_host=False, + partition_config=autotp_config) + autotp.set_tensor_parallel_config(tp_size, groups.get_tensor_model_parallel_group()) + autotp.update_linear_policies() + autotp._replace_module(model) + return model + + +def gather_subparam_output(output, subparam_sizes, mp_group): + tp_world_size = dist.get_world_size(group=mp_group) + local_sizes = [size // tp_world_size for size in subparam_sizes] + output_chunks = torch.split(output, local_sizes, dim=-1) + gathered_chunks = [] + for chunk in output_chunks: + chunk = chunk.contiguous() + gathered = [torch.empty_like(chunk) for _ in range(tp_world_size)] + dist.all_gather(gathered, chunk, group=mp_group) + gathered_chunks.append(torch.cat(gathered, dim=-1)) + return torch.cat(gathered_chunks, dim=-1) + + +class TestAutoTPCustomPatterns(DistributedTest): + world_size = 2 + reuse_dist_env = False + + def test_custom_pattern_replacement(self): + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "row", + }, + { + "patterns": [".*linears\\.1\\.weight$"], + "partition_type": "column", + }, + { + "patterns": [".*linears\\.2\\.weight$"], + "partition_type": "skip", + }, + ], + } + model = SequentialLinearModel(hidden_dim=16, nlayers=3) + model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config) + + assert isinstance(model.linears[0], LinearAllreduce) + assert isinstance(model.linears[1], LinearLayer) + assert isinstance(model.linears[2], nn.Linear) + + def test_custom_patterns_applied_via_config(self): + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + 
"patterns": [".*linears\\.0\\.weight$"], + "partition_type": "row", + }, + { + "patterns": [".*linears\\.1\\.weight$"], + "partition_type": "column", + }, + { + "patterns": [".*linears\\.2\\.weight$"], + "partition_type": "skip", + }, + ], + } + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": 2, + "partition_config": partition_config, + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=16, nlayers=3) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert isinstance(engine.module.linears[0], LinearAllreduce) + assert isinstance(engine.module.linears[1], LinearLayer) + assert isinstance(engine.module.linears[2], nn.Linear) + + def test_first_match_precedence(self): + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "skip", + }, + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "column", + }, + ], + } + model = SequentialLinearModel(hidden_dim=16, nlayers=1) + model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config) + + assert isinstance(model.linears[0], nn.Linear) + + +def test_invalid_custom_shape_rejected(): + bad_config = { + "layer_specs": [{ + "patterns": [".*"], + "partition_type": "column", + "shape": [2, [1, 1]], + }] + } + with pytest.raises(ValueError, match="nested tuple only allowed at partition_dim"): + AutoTPConfig.from_dict(bad_config) + + +class TestAutoTPFusedWeights(DistributedTest): + world_size = 2 + reuse_dist_env = False + + def test_gate_up_fused_weight_partition(self): + skip_on_device() + 
init_tp_engine(tp_size=2) + + hidden_dim = 8 + torch.manual_seed(42) + linear = nn.Linear(hidden_dim, + hidden_dim * 2, + bias=True, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + full_weight = deepcopy(linear.weight.data) + full_bias = deepcopy(linear.bias.data) + + layer = SubParamLinearLayer(deepcopy(linear), + groups.get_tensor_model_parallel_group(), + shape=(2, -1), + partition_dim=0, + name="mlp.gate_up_proj") + assert layer._subparam_sizes == (hidden_dim, hidden_dim) + assert layer.weight.shape == (hidden_dim, hidden_dim) + + layer.gather_params([layer.weight, layer.bias]) + torch.testing.assert_close(layer.weight.data, full_weight) + torch.testing.assert_close(layer.bias.data, full_bias) + + def test_gqa_uneven_qkv_fused_weight_partition(self): + skip_on_device() + init_tp_engine(tp_size=2) + + hidden_dim = 8 + q_size, k_size, v_size = 8, 4, 4 + torch.manual_seed(123) + linear = nn.Linear(hidden_dim, + q_size + k_size + v_size, + bias=True, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + full_weight = deepcopy(linear.weight.data) + full_bias = deepcopy(linear.bias.data) + + layer = SubParamLinearLayer(deepcopy(linear), + groups.get_tensor_model_parallel_group(), + shape=((q_size, k_size, v_size), -1), + partition_dim=0, + name="self_attn.qkv_proj") + assert layer._subparam_sizes == (q_size, k_size, v_size) + assert layer.weight.shape == ((q_size + k_size + v_size) // 2, hidden_dim) + + layer.gather_params([layer.weight, layer.bias]) + torch.testing.assert_close(layer.weight.data, full_weight) + torch.testing.assert_close(layer.bias.data, full_bias) + + def test_gqa_uneven_qkv_fused_forward(self): + skip_on_device() + groups._init_tp_mesh_device(tensor_model_parallel_size=2) + + hidden_dim = 8 + q_size, k_size, v_size = 8, 4, 4 + torch.manual_seed(321) + linear = nn.Linear(hidden_dim, + q_size + k_size + v_size, + bias=True, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + layer = 
SubParamLinearLayer(deepcopy(linear), + groups.get_tensor_model_parallel_group(), + shape=((q_size, k_size, v_size), -1), + partition_dim=0, + name="self_attn.qkv_proj") + + torch.manual_seed(42) + inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device()) + full_output = linear(inputs) + tp_output = layer(inputs) + + gathered_output = gather_subparam_output(tp_output, (q_size, k_size, v_size), + groups.get_tensor_model_parallel_group()) + atol = 1e-3 + rtol = 2e-2 + if preferred_dtype() is torch.float32: + atol = 1e-5 + rtol = 1e-5 + torch.testing.assert_close(gathered_output, full_output, atol=atol, rtol=rtol) diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py index 9d2b04211520..baaca247c229 100644 --- a/tests/unit/model_parallelism/test_autotp_training.py +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -16,7 +16,7 @@ from deepspeed.utils import groups from contextlib import contextmanager from torch import nn -from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode +from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode, is_autotp_training_mode from unit.checkpoint.common import compare_lr_scheduler_states, compare_optimizer_states import os from deepspeed.runtime.utils import is_model_parallel_parameter @@ -27,6 +27,39 @@ def skip_on_device(): pytest.skip("XPU requires a higher version for test") +def reset_tp_model_init_state(): + deepspeed._TP_MODEL_INIT_ARGS = None + set_autotp_mode(training=False) + + +class DummyMPU: + + def __init__(self, tp_world_size=1): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.tp_world_size = tp_world_size + self.dp_group = dist.get_world_group() + self.tp_group = dist.get_world_group() + + def get_model_parallel_rank(self): + return self.rank % self.tp_world_size + + def 
get_model_parallel_world_size(self): + return self.tp_world_size + + def get_data_parallel_rank(self): + return self.rank // self.tp_world_size + + def get_data_parallel_world_size(self): + return self.world_size // self.tp_world_size + + def get_data_parallel_group(self): + return self.dp_group + + def get_model_parallel_group(self): + return self.tp_group + + class SequentialLinearModel(torch.nn.Module): def __init__(self, hidden_dim, empty_grad=False, nlayers=1): @@ -69,13 +102,19 @@ class TestTpParallelStates(DistributedTest): def test(self, tp_size: int): skip_on_device() - set_autotp_mode(training=True) dp_size = 4 / tp_size hidden_dim = 128 config_dict = { "train_micro_batch_size_per_gpu": 1, "tensor_parallel": { - "autotp_size": tp_size + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } }, "zero_optimization": { "stage": 0 @@ -87,6 +126,104 @@ def test(self, tp_size: int): assert groups.get_data_parallel_world_size() == dp_size +class TestTpModelInitCompatibility(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test_tp_model_init_merges_config(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype()) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 0, + } + } + engine, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=DummyMPU()) + assert engine.autotp_size() == 1 + assert is_autotp_training_mode() + + def test_tp_model_init_config_autotp_size_mismatch(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype()) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2, + }, + 
"zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="tensor_parallel.autotp_size"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU()) + + def test_tp_model_init_requires_mpu_or_mesh_param(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype()) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="requires mpu or mesh_param"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + def test_tp_model_init_tp_group_rejects_mpu(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + tp_group = dist.new_group(ranks=[0]) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype(), tp_group=tp_group) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="tp_model_init provided tp_group"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU()) + + def test_tp_model_init_dtype_mismatch(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=torch.float16) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "bf16": { + "enabled": True, + }, + "zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="Conflicting dtype"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU()) + + @pytest.mark.sequential + @pytest.mark.parametrize("tp_size", [2, 4]) + @pytest.mark.parametrize("tp_overlap_comm", [True, False]) + def test_tp_model_init_row_parallel(self, tp_size: int, tp_overlap_comm: bool): + 
run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=False, use_tp_model_init=True) + + @pytest.mark.sequential + @pytest.mark.parametrize("tp_size", [2, 4]) + @pytest.mark.parametrize("tp_overlap_comm", [True, False]) + def test_tp_model_init_column_parallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=True, use_tp_model_init=True) + + @pytest.mark.parametrize("tp_size", [2, 4]) class TestTpDataloaderCorrectness(DistributedTest): world_size = 4 @@ -95,7 +232,6 @@ class TestTpDataloaderCorrectness(DistributedTest): def test(self, tp_size: int): skip_on_device() hidden_dim = 128 - set_autotp_mode(training=True) config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -106,7 +242,14 @@ def test(self, tp_size: int): } }, "tensor_parallel": { - "autotp_size": tp_size + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } }, "zero_optimization": { "stage": 0, @@ -164,137 +307,116 @@ def process_linear_layer(hidden_dim, input): return torch_linear, torch_out -@pytest.mark.sequential -@pytest.mark.parametrize("tp_size", [2, 4]) -@pytest.mark.parametrize("tp_overlap_comm", [True, False]) -class TestTpLayerFwdBwd(DistributedTest): - world_size = 4 - reuse_dist_env = False - - def testRowParallel(self, tp_size: int, tp_overlap_comm: bool): - skip_on_device() - hidden_dim = 128 - batch_size_per_device = 1 - set_autotp_mode(training=True) - config_dict = { - "train_micro_batch_size_per_gpu": 1, - "steps_per_print": 1, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-6 - } - }, - "tensor_parallel": { - "autotp_size": tp_size, - "tp_overlap_comm": tp_overlap_comm - }, - "zero_optimization": { - "stage": 0, +def run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel, use_tp_model_init=False): + skip_on_device() + hidden_dim = 128 + batch_size_per_device = 
1 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 } + }, + "tensor_parallel": { + "autotp_size": tp_size, + "tp_overlap_comm": tp_overlap_comm + }, + "zero_optimization": { + "stage": 0, } - if preferred_dtype() is torch.float16: - config_dict["fp16"] = {"enabled": True} - elif preferred_dtype() is torch.bfloat16: - config_dict["bf16"] = {"enabled": True} - model = SequentialLinearModel(hidden_dim=hidden_dim) + } + partition_type = "column" if column_parallel else "row" + config_dict["tensor_parallel"]["partition_config"] = { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": partition_type, + }], + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=hidden_dim) + if use_tp_model_init: + reset_tp_model_init_state() + deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype()) + mpu = DummyMPU(tp_world_size=tp_size) + model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=mpu) + else: model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) - input = torch.randn(batch_size_per_device, - hidden_dim, - dtype=preferred_dtype(), - requires_grad=True, - device=get_accelerator().current_device()) - dist.broadcast(input, - groups.get_tensor_model_parallel_src_rank(), - group=groups.get_tensor_model_parallel_group()) - - torch_linear, torch_out = process_linear_layer(hidden_dim, input) - linear = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) - - input_ = torch.chunk(input, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] - out = linear(input_.to(get_accelerator().current_device())) + input = 
torch.randn(batch_size_per_device, + hidden_dim, + dtype=preferred_dtype(), + requires_grad=True, + device=get_accelerator().current_device()) + dist.broadcast(input, groups.get_tensor_model_parallel_src_rank(), group=groups.get_tensor_model_parallel_group()) + + # Note: correctness checks below use standalone TP wrappers and do not + # rely on the model's AutoTP-partitioned parameters. + torch_linear, torch_out = process_linear_layer(hidden_dim, input) + if column_parallel: + linear = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + out = linear(input.to(get_accelerator().current_device())) loss = out.sum() loss.backward() - torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()] - torch_bias_grad = torch_linear.bias.grad - # Use assert_close with rtol for proper floating-point comparisons + cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] + torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] + torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] + torch.testing.assert_close(linear.bias.grad, torch_bias_grad.to(get_accelerator().current_device()), atol=1e-3, rtol=1e-3) - # The gradient of the weight is not the same as the torch_linear.weight.grad torch.testing.assert_close(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3, rtol=1e-3) - torch.testing.assert_close(out, torch_out.to(get_accelerator().current_device()), atol=1e-2, rtol=1e-2) - - def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool): - skip_on_device() - hidden_dim = 128 - batch_size_per_device = 1 - set_autotp_mode(training=True) - config_dict = { - "train_micro_batch_size_per_gpu": 1, - "steps_per_print": 1, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-6 - } - }, - "tensor_parallel": { - "autotp_size": 
tp_size, - "tp_overlap_comm": tp_overlap_comm - }, - "zero_optimization": { - "stage": 0, - } - } - if preferred_dtype() is torch.float16: - config_dict["fp16"] = {"enabled": True} - elif preferred_dtype() is torch.bfloat16: - config_dict["bf16"] = {"enabled": True} - - model = SequentialLinearModel(hidden_dim=hidden_dim) - model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) - input = torch.randn(batch_size_per_device, - hidden_dim, - dtype=preferred_dtype(), - requires_grad=True, - device=get_accelerator().current_device()) - dist.broadcast(input, - groups.get_tensor_model_parallel_src_rank(), - group=groups.get_tensor_model_parallel_group()) - - torch_linear, torch_out = process_linear_layer(hidden_dim, input) - - linear = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) - - out = linear(input.to(get_accelerator().current_device())) + torch.testing.assert_close(cur_device_out.to(get_accelerator().current_device()).contiguous(), + out.contiguous(), + atol=1e-2, + rtol=1e-2) + else: + linear = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + input_ = torch.chunk(input, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] + out = linear(input_.to(get_accelerator().current_device())) loss = out.sum() loss.backward() - cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] - torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] - - torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] - # Use assert_close with rtol for proper floating-point comparisons + torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()] + torch_bias_grad = torch_linear.bias.grad torch.testing.assert_close(linear.bias.grad, 
torch_bias_grad.to(get_accelerator().current_device()), atol=1e-3, rtol=1e-3) - torch.testing.assert_close(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3, rtol=1e-3) - torch.testing.assert_close(cur_device_out.to(get_accelerator().current_device()).contiguous(), - out.contiguous(), - atol=1e-2, - rtol=1e-2) + torch.testing.assert_close(out, torch_out.to(get_accelerator().current_device()), atol=1e-2, rtol=1e-2) + + +@pytest.mark.sequential +@pytest.mark.parametrize("tp_size", [2, 4]) +@pytest.mark.parametrize("tp_overlap_comm", [True, False]) +class TestTpLayerFwdBwd(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def testRowParallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=False) + + def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=True) # @pytest.mark.sequential @@ -307,7 +429,6 @@ def test(self, layer_type): skip_on_device() tp_size = 4 hidden_dim = 128 - set_autotp_mode(training=True) config_dict = { "train_micro_batch_size_per_gpu": 1, "optimizer": { @@ -317,7 +438,14 @@ def test(self, layer_type): } }, "tensor_parallel": { - "autotp_size": tp_size + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } }, "zero_optimization": { "stage": 0, @@ -382,6 +510,15 @@ def test(self, layer_type): def dummy_init_engine(config): # This is a dummy initialization function for the DeepSpeed engine. # We only need to use the config to initialize the distributed settings for the test. 
+ # Add default partition_config for simple test models if not provided + if "tensor_parallel" in config and "partition_config" not in config["tensor_parallel"]: + config["tensor_parallel"]["partition_config"] = { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } model = SequentialLinearModel(hidden_dim=8) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) @@ -412,7 +549,6 @@ class TestSave(DistributedTest): def test_save_original_weight(self, tp_size: int, zero_stage: int): skip_on_device() hidden_dim = 64 - set_autotp_mode(training=True) config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -471,7 +607,6 @@ def compare_state_dicts(state_dict1, state_dict2): def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int): skip_on_device() hidden_dim = 64 - set_autotp_mode(training=True) config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -544,7 +679,6 @@ class TestTpGradNorm(DistributedTest): def test(self, tp_size: int, zero_stage: int): skip_on_device() hidden_dim = 64 - set_autotp_mode(training=True) config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1,