From 8e7023ba0125296b46e03d36313cfe3dca7cbe22 Mon Sep 17 00:00:00 2001 From: "hongliang.yuan" Date: Mon, 12 Jan 2026 13:41:43 +0800 Subject: [PATCH] initial stage_indices in get_held_layers of deepseek --- colossalai/shardformer/policies/deepseek.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 9baf068aec9f..a499cfdc0d87 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -332,6 +332,7 @@ def get_held_layers(self) -> List[Module]: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) stage_indices = stage_manager.get_stage_index(layers_per_stage) + stage_manager.stage_indices = stage_indices if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: