Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ def _replace_with_config(self, child, name):
# No matching spec found
if self.partition_config.strict_mode:
raise ValueError(f"No matching spec for {param_name}")
# Default: column parallel for Linear layers
spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN)
# With partition_config, rely only on explicit specs and skip unmatched layers.
return child

setattr(child, "replaced", True)

Expand All @@ -439,6 +439,8 @@ def _replace_with_config(self, child, name):

def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
"""Create row-parallel layer (AllReduce after forward)."""
if self.conv_linear_layer:
return Conv_LinearALlreduce(module, self.mp_group, name=name)
# Check for lm_head / embed_out
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(module, self.mp_group)
Expand All @@ -455,6 +457,12 @@ def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):

def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
"""Create column-parallel layer (AllReduce in backward)."""
if self.conv_linear_layer:
return conv_LinearLayer(module, self.mp_group, name=name)
# Only use fused-QKV heuristics when no partition_config is provided.
elif self.partition_config is None and require_tp_fused_qkvw(name, self.mp_size):
# Check and handle fused qkv for TP
return fused_LinearLayer(module, self.mp_group, fused_module=self.module)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these fixes exercised by a test? i.e., a model with a conv linear layer or a fused qkv weight.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! I added a test to validate that we use these layers for new custom patterns when a partition config is given.

if spec.shape is not None:
return SubParamLinearLayer(
module,
Expand Down Expand Up @@ -488,6 +496,7 @@ def _get_model_type(self) -> Optional[str]:
def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return

mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

if hasattr(child.weight, 'ds_tensor'):
Expand Down Expand Up @@ -551,7 +560,30 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
continue
if len(child._buffers) != 0 and self.state_dict is not None:
Loading.load_buffer(child, self.state_dict, checking_key)
if child.__class__ in self.linear_policies:

# When using partition_config (custom patterns/presets), use pattern-based routing
# instead of linear_policies. This keeps all pattern logic centralized here.
if self.partition_config is not None:
full_name = prev_name + '.' + name if prev_name else name
if isinstance(child, nn.Embedding):
# Check if embedding matches any pattern
param_name = full_name + ".weight"
model_type = self._get_model_type()
spec = self.partition_config.find_matching_spec(param_name, model_type)
if spec is not None and spec.partition_type != PartitionType.SKIP:
new_child = self._slice_embedding(child, full_name, False)
if new_child is not None:
setattr(r_module, name, new_child)
# If no pattern matched or skip, leave embedding unchanged
elif hasattr(child, "weight") and getattr(child.weight, "dim", lambda: 0)() == 2:
new_child = self._replace_with_config(child, full_name)
if new_child is not None:
setattr(r_module, name, new_child)
else:
self.update_mp_params(child)
self._replace_module(child, name, class_name)
# Traditional path: use linear_policies for type-based routing
elif child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
self.conv_linear_layer))
elif any(isinstance(child, lp) for lp in self.linear_policies):
Expand Down
4 changes: 3 additions & 1 deletion deepspeed/runtime/tensor_parallel/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def merge_tp_model_init_into_config(config_dict: dict, mpu, mesh_param, dist_mod
if tp_group is not None and mpu is not None:
raise ValueError("tp_model_init provided tp_group; deepspeed.initialize must not receive mpu.")
if tp_group is None and mpu is None and mesh_param is None:
raise ValueError("tp_model_init did not provide tp_group; deepspeed.initialize requires mpu or mesh_param.")
# Auto-create TP groups for compatibility with HF Trainer (mpu is not passed).
from deepspeed.utils import groups
groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size)

tp_section = config_dict.get("tensor_parallel")
if tp_section is None:
Expand Down
206 changes: 199 additions & 7 deletions tests/unit/model_parallelism/test_autotp_custom_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from unit.common import DistributedTest, preferred_dtype
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import groups
from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer)
from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer, fused_LinearLayer)
from deepspeed.module_inject.autotp_config import AutoTPConfig
from deepspeed.module_inject.auto_tp import AutoTP

Expand All @@ -35,6 +35,49 @@ def forward(self, x):
return x


class CustomLinearModule(torch.nn.Module):
    """Linear-like module built from raw weight/bias parameters (no nn.Linear).

    Lets tests confirm that pattern-based AutoTP partitioning handles
    custom modules that merely resemble linear layers.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        weight = torch.empty(hidden_dim, hidden_dim)
        bias = torch.empty(hidden_dim)
        # Same init order/shapes as before so seeded RNG draws are identical.
        torch.nn.init.uniform_(weight, -0.02, 0.02)
        torch.nn.init.uniform_(bias, -0.02, 0.02)
        self.weight = torch.nn.Parameter(weight)
        self.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        # Equivalent to F.linear(x, self.weight, self.bias).
        return torch.matmul(x, self.weight.transpose(-1, -2)) + self.bias


class CustomLinearModel(torch.nn.Module):
    """Minimal model exposing one CustomLinearModule under the name ``custom``."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.custom = CustomLinearModule(hidden_dim)

    def forward(self, x):
        # Pure pass-through to the wrapped module.
        return self.custom(x)


class QKVLinearModule(torch.nn.Module):
    """Attention-style module whose fused projection is named ``qkv_proj``.

    The name is chosen so it would match the legacy fused-QKV heuristic,
    letting tests check that explicit partition patterns take precedence.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Fused projection: Q, K and V concatenated along the output dim.
        self.qkv_proj = torch.nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, x):
        return self.qkv_proj(x)


class QKVLinearModel(torch.nn.Module):
    """Minimal model exposing a QKVLinearModule under ``self_attn``."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.self_attn = QKVLinearModule(hidden_dim)

    def forward(self, x):
        # Pure pass-through to the attention stub.
        return self.self_attn(x)


def init_tp_engine(tp_size, partition_config=None):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
Expand Down Expand Up @@ -100,6 +143,15 @@ def gather_subparam_output(output, subparam_sizes, mp_group):
return torch.cat(gathered_chunks, dim=-1)


def assert_close_for_preferred_dtype(actual, expected):
    """Compare tensors with tolerances matched to the preferred dtype.

    Float32 runs are held to tight tolerances; reduced-precision runs
    (fp16/bf16) get looser ones to absorb accumulation error.
    """
    if preferred_dtype() is torch.float32:
        atol, rtol = 1e-5, 1e-5
    else:
        atol, rtol = 1e-3, 2e-2
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)


class TestAutoTPCustomPatterns(DistributedTest):
world_size = 2
reuse_dist_env = False
Expand Down Expand Up @@ -178,6 +230,151 @@ def test_custom_patterns_applied_via_config(self):
assert isinstance(engine.module.linears[1], LinearLayer)
assert isinstance(engine.module.linears[2], nn.Linear)

def test_use_default_specs_false_skips_unmatched_layers(self):
    """With defaults disabled, only pattern-matched layers are sharded."""
    skip_on_device()
    # Layers 0 and 1 get explicit specs; layer 2 must stay a plain nn.Linear.
    layer_specs = [
        {"patterns": [".*linears\\.0\\.weight$"], "partition_type": "row"},
        {"patterns": [".*linears\\.1\\.weight$"], "partition_type": "column"},
    ]
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
        "tensor_parallel": {
            "autotp_size": 2,
            "partition_config": {
                "use_default_specs": False,
                "layer_specs": layer_specs,
            },
        },
        "zero_optimization": {"stage": 0},
    }
    dtype = preferred_dtype()
    if dtype is torch.float16:
        config_dict["fp16"] = {"enabled": True}
    elif dtype is torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    model = SequentialLinearModel(hidden_dim=16, nlayers=3)
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    assert isinstance(engine.module.linears[0], LinearAllreduce)
    assert isinstance(engine.module.linears[1], LinearLayer)
    assert isinstance(engine.module.linears[2], nn.Linear)

def test_custom_module_replacement_with_patterns(self):
    """A custom linear-like module is partitioned when a pattern targets it."""
    skip_on_device()
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
        "tensor_parallel": {
            "autotp_size": 2,
            "partition_config": {
                "use_default_specs": False,
                "layer_specs": [{
                    "patterns": [".*custom\\.weight$"],
                    "partition_type": "column",
                }],
            },
        },
        "zero_optimization": {"stage": 0},
    }
    dtype = preferred_dtype()
    if dtype is torch.float16:
        config_dict["fp16"] = {"enabled": True}
    elif dtype is torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    model = CustomLinearModel(hidden_dim=16)
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    # The raw-weight module (not an nn.Linear) must be replaced via the
    # pattern-based path.
    assert isinstance(engine.module.custom, LinearLayer)

def test_custom_pattern_disables_fused_qkv_heuristic(self):
    """Explicit partition patterns must override the legacy fused-QKV path.

    The layer name ``self_attn.qkv_proj`` would normally trigger the
    fused-QKV heuristic; a shape-based custom spec must win instead,
    yielding a SubParamLinearLayer whose gathered parameters and forward
    output match an unsharded baseline copy of the model.
    """
    skip_on_device()
    # Use a qkv_proj name that would trigger the fused-QKV heuristic, then
    # verify custom patterns override that path and preserve correctness.
    torch.manual_seed(1234)
    hidden_dim = 16
    # Equal Q/K/V sub-parameter sizes along the output dimension.
    qkv_sizes = (hidden_dim, hidden_dim, hidden_dim)
    partition_config = {
        "use_default_specs":
        False,
        "layer_specs": [
            {
                "patterns": [".*self_attn\\.qkv_proj\\.weight$"],
                "partition_type": "column",
                # Shape-based spec: split the first dim into Q/K/V chunks.
                "shape": [list(qkv_sizes), -1],
                "partition_dim": 0,
            },
        ],
    }
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-6
            }
        },
        "tensor_parallel": {
            "autotp_size": 2,
            "partition_config": partition_config,
        },
        "zero_optimization": {
            "stage": 0,
        }
    }
    if preferred_dtype() is torch.float16:
        config_dict["fp16"] = {"enabled": True}
    elif preferred_dtype() is torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    model = QKVLinearModel(hidden_dim=hidden_dim)
    # Keep an unsharded baseline for parameter/output comparison.
    baseline = deepcopy(model).to(get_accelerator().current_device(), dtype=preferred_dtype())
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    qkv_layer = engine.module.self_attn.qkv_proj
    # Custom pattern should force SubParamLinearLayer (shape-based path),
    # and avoid the legacy fused-QKV heuristic despite the qkv_proj name.
    assert isinstance(qkv_layer, SubParamLinearLayer)
    assert not isinstance(qkv_layer, fused_LinearLayer)

    assert qkv_layer.partition_dim == 0
    assert qkv_layer._subparam_sizes == qkv_sizes
    assert qkv_layer._orig_weight_shape == (hidden_dim * 3, hidden_dim)

    # Gathering the sharded params must reproduce the baseline exactly.
    qkv_layer.gather_params([qkv_layer.weight, qkv_layer.bias])
    torch.testing.assert_close(qkv_layer.weight, baseline.self_attn.qkv_proj.weight)
    if qkv_layer.bias is not None:
        torch.testing.assert_close(qkv_layer.bias, baseline.self_attn.qkv_proj.bias)

    # Forward pass on the TP engine must match the unsharded baseline
    # (within dtype-appropriate tolerances).
    torch.manual_seed(4321)
    inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device())
    full_output = baseline(inputs)
    tp_output = engine.module(inputs)
    assert_close_for_preferred_dtype(tp_output, full_output)

def test_first_match_precedence(self):
skip_on_device()
partition_config = {
Expand Down Expand Up @@ -294,9 +491,4 @@ def test_gqa_uneven_qkv_fused_forward(self):

gathered_output = gather_subparam_output(tp_output, (q_size, k_size, v_size),
groups.get_tensor_model_parallel_group())
atol = 1e-3
rtol = 2e-2
if preferred_dtype() is torch.float32:
atol = 1e-5
rtol = 1e-5
torch.testing.assert_close(gathered_output, full_output, atol=atol, rtol=rtol)
assert_close_for_preferred_dtype(gathered_output, full_output)
21 changes: 17 additions & 4 deletions tests/unit/model_parallelism/test_autotp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,32 @@ def test_tp_model_init_config_autotp_size_mismatch(self):
with pytest.raises(ValueError, match="tensor_parallel.autotp_size"):
deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU())

def test_tp_model_init_requires_mpu_or_mesh_param(self):
def test_tp_model_init_autocreates_tp_group(self):
skip_on_device()
reset_tp_model_init_state()
# Verify tp_model_init creates TP groups when no mpu is provided.
model = SimpleModel(hidden_dim=8)
deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype())
tp_size = 2
deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype())
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"tensor_parallel": {
"partition_config": {
"use_default_specs": False,
"layer_specs": [{
"patterns": [".*\\.weight$"],
"partition_type": "skip",
}],
}
},
"zero_optimization": {
"stage": 0,
}
}
with pytest.raises(ValueError, match="requires mpu or mesh_param"):
deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
assert engine.autotp_size() == tp_size
assert groups.get_tensor_model_parallel_world_size() == tp_size
assert groups.get_data_parallel_world_size() == dist.get_world_size() // tp_size

def test_tp_model_init_tp_group_rejects_mpu(self):
skip_on_device()
Expand Down