diff --git a/cookbook/megatron/npu/tp_lora_npu.py b/cookbook/megatron/npu/tp_lora_npu.py new file mode 100644 index 00000000..21abd028 --- /dev/null +++ b/cookbook/megatron/npu/tp_lora_npu.py @@ -0,0 +1,74 @@ +import os + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +# Build a device mesh for the verified NPU LoRA smoke. +MODEL_ID = os.environ.get('TWINKLE_LOCAL_MODEL_DIR', 'ms://Qwen/Qwen3-4B') +DATASET_PATH = os.environ.get( + 'TWINKLE_LOCAL_DATASET_PATH', + 'ms://swift/self-cognition', +) +MAX_STEPS = int(os.environ.get('TWINKLE_MAX_STEPS', '10')) +TRAIN_SAMPLES = int(os.environ.get('TWINKLE_TRAIN_SAMPLE_LIMIT', '160')) +BATCH_SIZE = int(os.environ.get('TWINKLE_BATCH_SIZE', '16')) + +# 8 cards: dp=2, tp=2, pp=2 +device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +def build_dataloader() -> DataLoader: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_SAMPLES))) + dataset.set_template('Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + return DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + +def train(): + dataloader = build_dataloader() + + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=1e-4) + + # Keep the scheduler compatible with the shortened smoke run. 
+ lr_decay_steps = max(MAX_STEPS, 2) + model.set_lr_scheduler( + scheduler_cls='default', + lr_warmup_steps=1, + lr_decay_steps=lr_decay_steps, + ) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + 'LoRA NPU smoke config: ' + f'model_id={MODEL_ID}, dataset={DATASET_PATH}, batch_size={BATCH_SIZE}, ' + f'train_samples={TRAIN_SAMPLES}, max_steps={MAX_STEPS}' + ) + logger.info(f'dataloader_steps={len(dataloader)}') + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + metric = model.calculate_metric(is_training=True) + logger.info(f'step={step} metric={metric}') + if step + 1 >= MAX_STEPS: + break + + model.save('last-checkpoint') + + +if __name__ == '__main__': + train() diff --git a/cookbook/megatron/npu/tp_lora_npu.sh b/cookbook/megatron/npu/tp_lora_npu.sh new file mode 100755 index 00000000..418f3ff3 --- /dev/null +++ b/cookbook/megatron/npu/tp_lora_npu.sh @@ -0,0 +1,4 @@ +MEGATRON_LM_PATH=${MEGATRON_LM_PATH:-/path/to/Megatron-LM} +ASCEND_RT_VISIBLE_DEVICES=${ASCEND_RT_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} \ +PYTHONPATH="${MEGATRON_LM_PATH}:${PYTHONPATH:-}" \ +torchrun --nproc_per_node=8 cookbook/megatron/npu/tp_lora_npu.py diff --git a/cookbook/megatron/npu/tp_moe_lora_npu.py b/cookbook/megatron/npu/tp_moe_lora_npu.py new file mode 100644 index 00000000..aa7af11c --- /dev/null +++ b/cookbook/megatron/npu/tp_moe_lora_npu.py @@ -0,0 +1,95 @@ +import os + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +# Build a device mesh for the verified NPU MoE LoRA smoke. +# Expert LoRA currently only supports ETP=1, so we keep TP at 1 here. 
+MODEL_ID = os.environ.get( + 'TWINKLE_LOCAL_MODEL_DIR', + 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507', +) +DATASET_PATH = os.environ.get( + 'TWINKLE_LOCAL_DATASET_PATH', + 'ms://swift/self-cognition', +) +MAX_STEPS = int(os.environ.get('TWINKLE_MAX_STEPS', '10')) +TRAIN_SAMPLES = int(os.environ.get('TWINKLE_TRAIN_SAMPLE_LIMIT', '80')) +BATCH_SIZE = int(os.environ.get('TWINKLE_BATCH_SIZE', '8')) +DP_SIZE = int(os.environ.get('TWINKLE_DP_SIZE', '8')) +TP_SIZE = int(os.environ.get('TWINKLE_TP_SIZE', '1')) +EP_SIZE = int(os.environ.get('TWINKLE_EP_SIZE', '2')) +PP_SIZE = int(os.environ.get('TWINKLE_PP_SIZE', '1')) +CP_SIZE = int(os.environ.get('TWINKLE_CP_SIZE', '1')) +LR = float(os.environ.get('TWINKLE_LR', '1e-4')) + +# 8 cards: dp=8, tp=1, ep=2, pp=1, cp=1 +device_mesh = DeviceMesh.from_sizes( + dp_size=DP_SIZE, + tp_size=TP_SIZE, + pp_size=PP_SIZE, + cp_size=CP_SIZE, + ep_size=EP_SIZE, +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +def build_dataloader() -> DataLoader: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_SAMPLES))) + dataset.set_template('Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + return DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + +def _to_loss_value(outputs) -> float: + loss = outputs['loss'] if isinstance(outputs, dict) else outputs.loss + return float(loss.detach().cpu()) if hasattr(loss, 'detach') else float(loss) + + +def train(): + dataloader = build_dataloader() + + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=LR) + + # Keep the scheduler compatible with the shortened smoke run. 
+ lr_decay_steps = max(MAX_STEPS, 2) + model.set_lr_scheduler( + scheduler_cls='default', + lr_warmup_steps=1, + lr_decay_steps=lr_decay_steps, + ) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + 'MoE LoRA NPU smoke config: ' + f'model_id={MODEL_ID}, dataset={DATASET_PATH}, batch_size={BATCH_SIZE}, ' + f'train_samples={TRAIN_SAMPLES}, max_steps={MAX_STEPS}, ' + f'dp={DP_SIZE}, tp={TP_SIZE}, ep={EP_SIZE}, pp={PP_SIZE}, cp={CP_SIZE}' + ) + logger.info(f'dataloader_steps={len(dataloader)}') + + for step, batch in enumerate(dataloader): + outputs = model.forward_backward(inputs=batch) + model.clip_grad_and_step() + logger.info(f'step={step} loss={_to_loss_value(outputs)}') + if step + 1 >= MAX_STEPS: + break + + model.save('last-checkpoint') + + +if __name__ == '__main__': + train() diff --git a/cookbook/megatron/npu/tp_moe_lora_npu.sh b/cookbook/megatron/npu/tp_moe_lora_npu.sh new file mode 100755 index 00000000..944ae67b --- /dev/null +++ b/cookbook/megatron/npu/tp_moe_lora_npu.sh @@ -0,0 +1,4 @@ +MEGATRON_LM_PATH=${MEGATRON_LM_PATH:-/path/to/Megatron-LM} +ASCEND_RT_VISIBLE_DEVICES=${ASCEND_RT_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} \ +PYTHONPATH="${MEGATRON_LM_PATH}:${PYTHONPATH:-}" \ +torchrun --nproc_per_node=8 cookbook/megatron/npu/tp_moe_lora_npu.py diff --git a/docs/source_en/Usage Guide/NPU-Support.md b/docs/source_en/Usage Guide/NPU-Support.md index e2b5e6da..d5e52e56 100644 --- a/docs/source_en/Usage Guide/NPU-Support.md +++ b/docs/source_en/Usage Guide/NPU-Support.md @@ -18,6 +18,7 @@ Before getting started, please ensure your system meets the following requiremen - torch and torch_npu versions **must be exactly the same** (e.g., both 2.7.1) - Python 3.11 is recommended for best compatibility - CANN toolkit requires approximately 10GB+ disk space +- If you need to use the **Megatron backend** (TP/PP/EP parallelism), you also need to install MindSpeed and prepare Megatron-LM source code. 
See the "[Megatron Training Environment Setup](#4-megatron-training-environment-setup-optional)" section below ## Supported Hardware @@ -75,7 +76,36 @@ pip install vllm-ascend==0.11.0rc3 - Ensure CANN environment is activated before installation: `source /usr/local/Ascend/ascend-toolkit/set_env.sh` - Recommended versions are vLLM 0.11.0 and vLLM-Ascend 0.11.0rc3 -### 4. Verify Installation +### 4. Megatron Training Environment Setup (Optional) + +If you need to use the Megatron backend for advanced parallel training such as TP/PP/EP, the following additional environment setup is required. This step is not needed if you only use DP/FSDP parallelism. + +#### Install MindSpeed + +MindSpeed is a required acceleration library for running Megatron on Ascend NPU, providing operator adaptation and distributed communication optimization. + +**Installation**: Refer to the [MindSpeed Official Repository](https://gitcode.com/Ascend/MindSpeed) for installation instructions. + +#### Clone Megatron-LM Source Code + +Megatron training requires Megatron-LM source code: + +```bash +git clone https://github.com/NVIDIA/Megatron-LM.git -b core_r0.12.0 +``` + +#### Configure PYTHONPATH + +Before running Megatron training scripts, you need to add the Megatron-LM source code to `PYTHONPATH` (and also Twinkle's `src` directory if Twinkle is not installed in the current environment): + +```bash +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export PYTHONPATH=${MEGATRON_LM_PATH}:${PYTHONPATH} +``` + +> **Tip**: `cookbook/megatron/tp.sh` and `cookbook/megatron/tp_moe.sh` already include automatic PYTHONPATH configuration. You can use these scripts directly to launch training without manual setup. Default paths can be overridden via the `TWINKLE_SRC_PATH` and `MEGATRON_LM_PATH` environment variables. + +### 5. 
Verify Installation Create test script `verify_npu.py`: @@ -155,6 +185,54 @@ python cookbook/grpo/lora_npu.py - ✅ Optional TorchSampler or vLLMSampler - ✅ Complete RL training workflow +### Megatron MoE LoRA Fine-tuning + +Verified 8-card TP+EP LoRA training example: + +**Example Path**: [cookbook/megatron/npu/tp_moe_lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/megatron/npu/tp_moe_lora_npu.py) + +**Run Method**: +```bash +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export PYTHONPATH=${MEGATRON_LM_PATH}:${PYTHONPATH} + +torchrun --nproc_per_node=8 cookbook/megatron/npu/tp_moe_lora_npu.py +``` + +**Notes**: +- Current expert LoRA only supports `ETP=1` +- This example uses the verified topology: `DP=8, TP=1, EP=2, PP=1, CP=1` +- If you raise `TP` to `2` together with `EP=2`, the framework will reject it explicitly + +**Example Features**: +- ✅ MoE + LoRA fine-tuning +- ✅ Megatron backend (DP=8, TP=1, EP=2) +- ✅ 10-step continuous loss printing + checkpoint saving + +### Megatron LoRA Fine-tuning + +Verified 8-card TP+PP LoRA fine-tuning example: + +**Example Path**: [cookbook/megatron/npu/tp_lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/megatron/npu/tp_lora_npu.py) + +**Run Method**: +```bash +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export TWINKLE_SRC_PATH=/path/to/twinkle/src +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export PYTHONPATH=${TWINKLE_SRC_PATH}:${MEGATRON_LM_PATH}:${PYTHONPATH} + +torchrun --nproc_per_node=8 cookbook/megatron/npu/tp_lora_npu.py +``` + +**Example Features**: +- ✅ LoRA fine-tuning (r=8, target_modules=all-linear) +- ✅ Megatron backend (DP=2, TP=2, PP=2) +- ✅ 10-step continuous metric printing + checkpoint saving + +**Note**: For MoE models, expert LoRA is currently only available with `ETP=1`; MoE LoRA fine-tuning with ETP>1 is not supported yet. 
+ ### More Examples Check the `cookbook/remote/tinker/ascend/` directory for remote training server-side configuration. @@ -167,15 +245,15 @@ Twinkle currently supports the following **verified** parallelization strategies |---------|------|---------|---------| | DP (Data Parallel) | Data parallelism | ✅ | Verified (see cookbook/sft/lora_npu.py) | | FSDP (Fully Sharded Data Parallel) | Fully sharded data parallelism | ✅ | Verified (see cookbook/sft/lora_npu.py) | -| TP (Tensor Parallel) | Tensor parallelism (Megatron) | 🚧 | To be verified | -| PP (Pipeline Parallel) | Pipeline parallelism (Megatron) | 🚧 | To be verified | -| CP (Context Parallel) | Context parallelism | 🚧 | To be verified | -| EP (Expert Parallel) | Expert parallelism (MoE) | 🚧 | To be verified | +| TP (Tensor Parallel) | Tensor parallelism (Megatron) | ✅ | Verified (see cookbook/megatron/npu/) | +| PP (Pipeline Parallel) | Pipeline parallelism (Megatron) | ✅ | Verified (see cookbook/megatron/npu/) | +| CP (Context Parallel) | Context parallelism | ❌ | Not supported for now | +| EP (Expert Parallel) | Expert parallelism (MoE) | ✅ | Verified (see cookbook/megatron/npu/tp_moe_lora_npu.py) | **Legend**: - ✅ Verified: Has actual running example code - 🚧 To be verified: Theoretically supported but no NPU verification example yet -- ❌ Not supported: Not available in current version +- ❌ Not supported for now: the current implementation path does not support it, so keep it disabled on NPU Megatron ### DP + FSDP Example @@ -193,7 +271,29 @@ device_mesh = DeviceMesh( ) ``` -**Note**: Megatron backend (TP/PP/EP) support on NPU is under development, with no available examples yet. If you need these advanced parallelization strategies, please verify in GPU environment first or follow project updates. 
+### Megatron TP + PP Example + +The following configuration is from `cookbook/megatron/npu/tp_lora_npu.py`, verified in an actual 8-card NPU environment: + +```python +from twinkle import DeviceMesh + +# 8 cards: dp=2, tp=2, pp=2 +device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) +``` + +### Megatron TP + EP Example (MoE Model) + +The following configuration is from `cookbook/megatron/npu/tp_moe_lora_npu.py`, verified in an actual 8-card NPU environment: + +```python +from twinkle import DeviceMesh + +# 8 cards: dp=8, tp=1, ep=2, pp=1, cp=1 +device_mesh = DeviceMesh.from_sizes(dp_size=8, tp_size=1, pp_size=1, cp_size=1, ep_size=2) +``` + +**Note**: Context Parallel (CP) is not supported yet on NPU Megatron. Please keep `cp_size=1`. ## Common Issues @@ -223,6 +323,54 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl - Refer to [Ascend Community Version Compatibility Table](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/softwareinstall/instg/atlasdeploy_03_0015.html) - Install corresponding CANN toolkit version +### 3. Megatron Training Reports ModuleNotFoundError: No module named 'megatron' + +**Problem**: Running Megatron training scripts reports that the `megatron` module cannot be found. + +**Solution**: +- Confirm that Megatron-LM source code has been cloned and its path is added to `PYTHONPATH` +- Confirm that `TWINKLE_SRC_PATH` points to Twinkle's `src` directory +- Refer to the PYTHONPATH configuration in `cookbook/megatron/tp.sh` + +```bash +export PYTHONPATH=/path/to/twinkle/src:/path/to/Megatron-LM:${PYTHONPATH} +``` + +### 4. NPU Cards Occupied Causing Training Failure + +**Problem**: Training fails with HCCL communication timeout or device unavailable errors after launch. 
+ +**Solution**: +- First use `npu-smi info` to check which cards are occupied by other processes +- Set `ASCEND_RT_VISIBLE_DEVICES` to specify only available cards +- Ensure `torchrun --nproc_per_node` count matches the number of cards in `ASCEND_RT_VISIBLE_DEVICES` + +```bash +# Check card usage +npu-smi info + +# Assuming cards 0,1,2,3 are free and 4,5,6,7 are occupied +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +torchrun --nproc_per_node=4 your_script.py +``` + +### 5. torchrun Using Wrong Python Environment + +**Problem**: Multi-card training reports many missing package errors (e.g., `ModuleNotFoundError: No module named 'datasets'`), but `pip list` locally shows these packages. + +**Solution**: +- Check if `which torchrun` points to the current Conda environment +- If it points to the system Python, activate the correct environment first + +```bash +# Check torchrun source +which torchrun + +# Ensure it comes from the current Conda environment +conda activate your_env +which torchrun # Should point to a path within the conda environment +``` + ## Feature Support Status Feature support matrix based on actual code verification: @@ -236,10 +384,13 @@ Feature support matrix based on actual code verification: | Ray Distributed | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available | | TorchSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available | | vLLMSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available | -| Full Fine-tuning | ✅ | 🚧 | - | Theoretically supported, to be verified | | QLoRA | ✅ | ❌ | - | Quantization operators not yet supported | | DPO | ✅ | 🚧 | - | Theoretically supported, to be verified | -| Megatron TP/PP | ✅ | 🚧 | - | To be adapted and verified | +| Megatron TP/PP | ✅ | ✅ | cookbook/megatron/npu/tp_lora_npu.py | Verified (dp=2, tp=2, pp=2) | +| Megatron EP (MoE) | ✅ | ✅ | cookbook/megatron/npu/tp_moe_lora_npu.py | Verified (dp=8, tp=1, ep=2) | +| Megatron MoE LoRA (ETP=1) | ✅ | ✅ | cookbook/megatron/npu/tp_moe_lora_npu.py | 
Verified (dp=8, tp=1, ep=2) | +| Megatron LoRA | ✅ | ✅ | cookbook/megatron/npu/tp_lora_npu.py | Verified (dp=2, tp=2, pp=2) | +| MoE + LoRA + ETP > 1 | ✅ | ❌ | - | Expert LoRA not supported when ETP>1 | | Flash Attention | ✅ | ⚠️ | - | Some operators not supported | **Legend**: @@ -269,6 +420,14 @@ Twinkle provides the following verified NPU training examples: - Supports Reference Model - Optional TorchSampler or vLLMSampler +### Megatron Training +- **8-card LoRA Fine-tuning**: [cookbook/megatron/npu/tp_lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/megatron/npu/tp_lora_npu.py) + - LoRA fine-tuning, DP=2, TP=2, PP=2 + - Supports all-linear target modules +- **8-card MoE LoRA Fine-tuning**: [cookbook/megatron/npu/tp_moe_lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/megatron/npu/tp_moe_lora_npu.py) + - MoE LoRA fine-tuning, DP=8, TP=1, EP=2 + - Expert LoRA currently requires ETP=1 + ### Remote Training (Tinker Protocol) - **Server Configuration**: [cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend) - Provides HTTP API interface @@ -284,6 +443,14 @@ python cookbook/sft/lora_npu.py # GRPO training export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python cookbook/grpo/lora_npu.py + +# Megatron LoRA training (use sh script directly) +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +bash cookbook/megatron/npu/tp_lora_npu.sh + +# Megatron MoE LoRA training +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +bash cookbook/megatron/npu/tp_moe_lora_npu.sh ``` ## Reference Resources diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" index 3241dbf5..52c922f9 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" +++ 
"b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" @@ -18,6 +18,7 @@ - torch 和 torch_npu 版本**必须完全一致**(例如都为 2.7.1) - 推荐使用 Python 3.11 以获得最佳兼容性 - CANN 工具包需要约 10GB+ 磁盘空间 +- 如果需要使用 **Megatron 后端**(TP/PP/EP 并行),还需额外安装 MindSpeed 并准备 Megatron-LM 源码,详见下方「[Megatron 训练环境准备](#4-megatron-训练环境准备可选)」章节 ## 支持的硬件 @@ -75,7 +76,34 @@ pip install vllm-ascend==0.11.0rc3 - 安装前确保已激活 CANN 环境:`source /usr/local/Ascend/ascend-toolkit/set_env.sh` - 推荐使用的版本为 vLLM 0.11.0 和 vLLM-Ascend 0.11.0rc3 -### 4. 验证安装 +### 4. Megatron 训练环境准备(可选) + +如果需要使用 Megatron 后端进行 TP/PP/EP 等高级并行训练,需要额外准备以下环境。仅使用 DP/FSDP 并行时无需此步骤。 + +#### 安装 MindSpeed + +MindSpeed 是昇腾 NPU 上运行 Megatron 的必要加速库,提供算子适配和分布式通信优化。 + +**安装方式**:参考 [MindSpeed 官方仓库](https://gitcode.com/Ascend/MindSpeed) 的安装说明。 + +#### 克隆 Megatron-LM 源码 + +Megatron 训练需要 Megatron-LM 源码: + +```bash +git clone https://github.com/NVIDIA/Megatron-LM.git -b core_r0.12.0 +``` + +#### 配置 PYTHONPATH + +运行 Megatron 训练脚本前,需要将 Twinkle 源码和 Megatron-LM 源码同时加入 `PYTHONPATH`: + +```bash +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export PYTHONPATH=${MEGATRON_LM_PATH}:${PYTHONPATH} +``` + +### 5. 
验证安装 创建测试脚本 `verify_npu.py`: @@ -155,6 +183,53 @@ python cookbook/grpo/lora_npu.py - ✅ 可选 TorchSampler 或 vLLMSampler - ✅ 完整的 RL 训练流程 +### Megatron MoE LoRA 微调 + +已验证的 8 卡 TP+EP LoRA 训练示例: + +**示例路径**:[cookbook/megatron/npu/tp_moe_lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/megatron/npu/tp_moe_lora_npu.py) + +**运行方式**: +```bash +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export PYTHONPATH=${MEGATRON_LM_PATH}:${PYTHONPATH} + +torchrun --nproc_per_node=8 cookbook/megatron/npu/tp_moe_lora_npu.py +``` + +**说明**: +- 当前 expert LoRA 仅支持 `ETP=1` +- 这份示例使用已验证拓扑:`DP=8, TP=1, EP=2, PP=1, CP=1` +- 如果把 `TP` 提到 `2` 再配 `EP=2`,框架会明确拒绝 + +**示例特性**: +- ✅ MoE + LoRA 微调 +- ✅ Megatron 后端(DP=8, TP=1, EP=2) +- ✅ 10 步 loss 连续打印 + checkpoint 保存 + +### Megatron LoRA 微调 + +已验证的 8 卡 TP+PP LoRA 微调示例: + +**示例路径**:[cookbook/megatron/npu/tp_lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/megatron/npu/tp_lora_npu.py) + +**运行方式**: +```bash +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export PYTHONPATH=${MEGATRON_LM_PATH}:${PYTHONPATH} + +# 运行训练 +torchrun --nproc_per_node=8 cookbook/megatron/npu/tp_lora_npu.py +``` + +**示例特性**: +- ✅ LoRA 微调(r=8, target_modules=all-linear) +- ✅ Megatron 后端(DP=2, TP=2, PP=2) +- ✅ 10 步 metric 连续打印 + checkpoint 保存 + + ### 更多示例 查看 `cookbook/remote/tinker/ascend/` 目录了解远程训练服务端配置。 @@ -167,15 +242,15 @@ Twinkle 在 NPU 上目前支持以下**经过验证**的并行策略: |---------|------|---------|---------| | DP (Data Parallel) | 数据并行 | ✅ | 已验证(见 cookbook/sft/lora_npu.py) | | FSDP (Fully Sharded Data Parallel) | 完全分片数据并行 | ✅ | 已验证(见 cookbook/sft/lora_npu.py) | -| TP (Tensor Parallel) | 张量并行(Megatron) | 🚧 | 待验证 | -| PP (Pipeline Parallel) | 流水线并行(Megatron) | 🚧 | 待验证 | -| CP (Context Parallel) | 上下文并行 | 🚧 | 待验证 | -| EP (Expert Parallel) | 专家并行(MoE) | 🚧 | 待验证 | +| TP (Tensor Parallel) | 张量并行(Megatron) | ✅ | 已验证(见 cookbook/megatron/npu/) | +| PP (Pipeline 
Parallel) | 流水线并行(Megatron) | ✅ | 已验证(见 cookbook/megatron/npu/) | +| CP (Context Parallel) | 上下文并行 | ❌ | 暂不支持 | +| EP (Expert Parallel) | 专家并行(MoE) | ✅ | 已验证(见 cookbook/megatron/npu/tp_moe_lora_npu.py) | **图例说明**: - ✅ 已验证:有实际运行示例代码 - 🚧 待验证:理论上支持但暂无 NPU 验证示例 -- ❌ 不支持:当前版本不可用 +- ❌ 暂不支持:当前实现路径明确不支持,NPU Megatron 不要开启 ### DP + FSDP 示例 @@ -193,7 +268,29 @@ device_mesh = DeviceMesh( ) ``` -**注意**:Megatron 后端(TP/PP/EP)在 NPU 上的支持正在开发中,暂无可用示例。如需使用这些高级并行策略,请先在 GPU 环境下验证,或关注项目更新。 +### Megatron TP + PP 示例(Dense LoRA) + +以下配置来自 `cookbook/megatron/npu/tp_lora_npu.py`,在实际 8 卡 NPU 环境中验证通过: + +```python +from twinkle import DeviceMesh + +# 8 卡:dp=2, tp=2, pp=2 +device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) +``` + +### Megatron TP + EP 示例(MoE LoRA) + +以下配置来自 `cookbook/megatron/npu/tp_moe_lora_npu.py`,在实际 8 卡 NPU 环境中验证通过: + +```python +from twinkle import DeviceMesh + +# 8 卡:dp=8, tp=1, ep=2, pp=1, cp=1 +device_mesh = DeviceMesh.from_sizes(dp_size=8, tp_size=1, pp_size=1, cp_size=1, ep_size=2) +``` + +**注意**:Context Parallel(CP)在 NPU Megatron 上暂不支持,建议保持 `cp_size=1`。 ## 常见问题 @@ -223,6 +320,18 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl - 参考[昇腾社区版本配套表](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/softwareinstall/instg/atlasdeploy_03_0015.html) - 安装对应版本的 CANN 工具包 +### 3. 
Megatron 训练报 ModuleNotFoundError: No module named 'megatron' + +**问题**:运行 Megatron 训练脚本时报找不到 `megatron` 模块。 + +**解决方案**: +- 确认已克隆 Megatron-LM 源码,并将其路径加入 `PYTHONPATH` +- 参考 `cookbook/megatron/tp.sh` 中的 PYTHONPATH 配置 + +```bash +export PYTHONPATH=/path/to/Megatron-LM:${PYTHONPATH} +``` + ## 功能支持情况 基于实际代码验证的功能支持矩阵: @@ -236,10 +345,13 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl | Ray 分布式 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 | | TorchSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 | | vLLMSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 | -| 全量微调 | ✅ | 🚧 | - | 理论支持,待验证 | | QLoRA | ✅ | ❌ | - | 量化算子暂不支持 | | DPO | ✅ | 🚧 | - | 理论支持,待验证 | -| Megatron TP/PP | ✅ | 🚧 | - | 待适配和验证 | +| Megatron TP/PP | ✅ | ✅ | cookbook/megatron/npu/tp_lora_npu.py | 已验证(dp=2, tp=2, pp=2) | +| Megatron EP(MoE) | ✅ | ✅ | cookbook/megatron/npu/tp_moe_lora_npu.py | 已验证(dp=8, tp=1, ep=2) | +| Megatron LoRA | ✅ | ✅ | cookbook/megatron/npu/tp_lora_npu.py | 已验证(dp=2, tp=2, pp=2) | +| Megatron MoE LoRA(ETP=1) | ✅ | ✅ | cookbook/megatron/npu/tp_moe_lora_npu.py | 已验证(dp=8, tp=1, ep=2) | +| MoE + LoRA + ETP>1 | ✅ | ❌ | - | Expert LoRA 在 ETP>1 时不支持 | | Flash Attention | ✅ | ⚠️ | - | 部分算子不支持 | **图例说明**: diff --git a/src/twinkle/model/megatron/__init__.py b/src/twinkle/model/megatron/__init__.py index 0f462566..117c6156 100644 --- a/src/twinkle/model/megatron/__init__.py +++ b/src/twinkle/model/megatron/__init__.py @@ -6,8 +6,14 @@ # Follow the same LazyModule approach as `twinkle.model`: only import when those symbols are actually accessed. from typing import TYPE_CHECKING +from twinkle import Platform from twinkle.utils.import_utils import _LazyModule +if Platform.device_prefix() == 'npu': + # MindSpeed needs to patch `torch.compile`/TE symbols before any `megatron.core` + # module binds them by value. Keeping this import early is the smallest reliable hook. 
+ import mindspeed.megatron_adaptor # noqa: F401 + if TYPE_CHECKING: from .megatron import MegatronModel, MegatronStrategy from .multi_lora_megatron import MultiLoraMegatronModel diff --git a/src/twinkle/model/megatron/_mindspeed_args.py b/src/twinkle/model/megatron/_mindspeed_args.py new file mode 100644 index 00000000..43382317 --- /dev/null +++ b/src/twinkle/model/megatron/_mindspeed_args.py @@ -0,0 +1,148 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Private MindSpeed runtime args helpers for NPU Megatron runs.""" + +import argparse +import json +import torch +from typing import Any, Dict + +from .utils import convert_hf_config + + +def sanitize_mindspeed_values(values: Dict[str, Any]) -> Dict[str, Any]: + return {key: value for key, value in values.items() if isinstance(key, str) and key.isidentifier()} + + +def _resolve_optimization_level(values: Dict[str, Any]) -> int: + if any(( + bool(values.get('multi_latent_attention')), + bool(values.get('multi_head_latent_attention')), + values.get('q_lora_rank') is not None, + values.get('num_experts', 0) > 0, + bool(values.get('moe_grouped_gemm')), + bool(values.get('moe_fb_overlap')), + bool(values.get('moe_alltoall_overlap_comm')), + bool(values.get('moe_allgather_overlap_comm')), + bool(values.get('balanced_moe_experts')), + values.get('schedules_method') == 'dualpipev', + )): + return 2 + return 0 + + +def _update_sanitized(values: Dict[str, Any], section: Dict[str, Any]) -> None: + values.update(sanitize_mindspeed_values(section)) + + +def _build_fixed_runtime_defaults() -> Dict[str, Any]: + # Fixed MindSpeed / TE runtime defaults. 
+ return { + 'transformer_impl': 'transformer_engine', + 'fp8': None, + 'optimizer_selection': 'fused_adamw', + 'shape_order': 'SBH', + 'use_ascend_mc2': False, + 'enable_gloo_process_groups': True, + 'disable_gloo_group': False, + } + + +def _build_topology_and_shape_defaults(args: Any, values: Dict[str, Any], rope_scaling: Dict[str, + Any]) -> Dict[str, Any]: + # Core topology and transformer shape. + return { + 'tensor_model_parallel_size': args.tp_size, + 'pipeline_model_parallel_size': args.pp_size, + 'context_parallel_size': args.cp_size, + 'expert_model_parallel_size': args.ep_size, + 'expert_tensor_parallel_size': args.etp_size, + 'virtual_pipeline_model_parallel_size': args.vpp_size, + 'sequence_parallel': bool(args.sequence_parallel), + 'num_layers': int(args.num_layers), + 'hidden_size': int(args.hidden_size), + 'num_attention_heads': int(args.num_attention_heads), + 'num_query_groups': int(args.num_query_groups or args.num_attention_heads), + 'ffn_hidden_size': int(args.ffn_hidden_size), + 'mtp_num_layers': int(args.mtp_num_layers or 0), + 'bf16': args.params_dtype == torch.bfloat16, + 'fp16': args.params_dtype == torch.float16, + 'position_embedding_type': values.get('position_embedding_type', 'rope'), + 'rope_scaling_type': rope_scaling.get('rope_type') or rope_scaling.get('type'), + 'yarn_scaling_factor': rope_scaling.get('factor'), + 'rope_scaling_mscale': rope_scaling.get('mscale'), + 'rope_scaling_mscale_all_dim': rope_scaling.get('mscale_all_dim'), + } + + +def _build_moe_runtime_defaults(values: Dict[str, Any], args: Any, num_experts: int) -> Dict[str, Any]: + # MoE runtime knobs. 
+ return { + 'num_experts': num_experts, + 'num_moe_experts': num_experts or None, + 'moe_grouped_gemm': bool(values.get('moe_grouped_gemm', False) or num_experts > 0), + 'moe_token_dispatcher_type': values.get('moe_token_dispatcher_type') + or ('alltoall' if num_experts > 0 else None), + 'moe_router_topk': int(values.get('moe_router_topk', args.num_experts_per_tok) or 2), + } + + +def _build_mla_runtime_defaults(values: Dict[str, Any], q_lora_rank: Any, multi_latent_attention: bool, + qk_layernorm: bool, args: Any) -> Dict[str, Any]: + # MLA / DeepSeek-style attention knobs. + return { + 'multi_latent_attention': multi_latent_attention, + 'multi_head_latent_attention': multi_latent_attention, + 'q_lora_rank': q_lora_rank, + 'kv_lora_rank': values.get('kv_lora_rank'), + 'qk_layernorm': qk_layernorm, + 'use_qk_norm': qk_layernorm, + 'qk_nope_head_dim': values.get('qk_head_dim', values.get('qk_nope_head_dim')), + 'qk_rope_head_dim': values.get('qk_pos_emb_head_dim', values.get('qk_rope_head_dim')), + 'v_head_dim': values.get('v_head_dim', args.kv_channels), + } + + +def build_mindspeed_namespace(args: Any, defaults: Dict[str, Any]) -> argparse.Namespace: + """Build MindSpeed runtime args namespace from Twinkle args. + + If there are fields with the same name, the one at the lowest level will be overwritten. + + Merges three layers in order of precedence (later layers override earlier ones): + 1. MindSpeed defaults (~100+ fields from register_args) - lowest priority + 2. HF config (layers, heads, MoE params via convert_hf_config()) - medium priority + 3. Twinkle args (tp/pp/cp/ep, dtype) - highest priority, overrides all + + Args: + args: TwinkleMegatronArgs instance. + defaults: MindSpeed default values from args_utils.get_mindspeed_args(). + + Returns: + Merged MindSpeed runtime arguments as Namespace. 
+ """ + if getattr(args, 'fp8', None): + raise RuntimeError('MindSpeed NPU TE bootstrap does not support FP8.') + + values = sanitize_mindspeed_values(defaults.copy()) + hf_config = getattr(args, 'hf_config', None) + if hf_config is not None: + values.update(sanitize_mindspeed_values(convert_hf_config(hf_config))) + + rope_scaling = args.rope_scaling if isinstance(args.rope_scaling, dict) else {} + num_experts = int(getattr(args, 'num_experts', 0) or values.get('num_experts', 0) or 0) + q_lora_rank = values.get('q_lora_rank', getattr(args, 'q_lora_rank', None)) + multi_latent_attention = bool( + getattr(args, 'multi_latent_attention', False) or values.get('multi_latent_attention', False) + or values.get('multi_head_latent_attention', False) or q_lora_rank is not None) + qk_layernorm = bool(getattr(args, 'qk_layernorm', False) or values.get('qk_layernorm', False)) + + _update_sanitized(values, _build_fixed_runtime_defaults()) + _update_sanitized(values, _build_topology_and_shape_defaults(args, values, rope_scaling)) + _update_sanitized(values, _build_moe_runtime_defaults(values, args, num_experts)) + _update_sanitized(values, + _build_mla_runtime_defaults(values, q_lora_rank, multi_latent_attention, qk_layernorm, args)) + values['optimization_level'] = _resolve_optimization_level(values) + return argparse.Namespace(**sanitize_mindspeed_values(values)) + + +def get_mindspeed_signature(namespace: argparse.Namespace) -> str: + return json.dumps(sanitize_mindspeed_values(vars(namespace).copy()), sort_keys=True, default=str) diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index eacc10db..070fe750 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -6,12 +6,97 @@ from types import SimpleNamespace from typing import Any, Dict, List, Literal, Optional -from twinkle import DeviceMesh +from twinkle import DeviceMesh, Platform, get_logger from twinkle.utils import exists from .utils import 
convert_hf_config # Global args storage _GLOBAL_ARGS: Optional['TwinkleMegatronArgs'] = None +logger = get_logger() + + +def _normalize_word_embedding_allreduce_call(*call_args, **call_kwargs): + """Normalize Megatron's private word-embedding helper call. + + Megatron Core has changed the helper signature across releases: + - 0.12.1: (model, config) + - 0.16.1: (model, config, embd_group, pp_group) + - future releases may add more positional/keyword args. + + We keep the semantics stable and only normalize the known pieces. + """ + model = call_kwargs.pop('model', call_args[0] if call_args else None) + config = call_kwargs.pop('config', call_args[1] if len(call_args) > 1 else None) + if model is None or config is None: + raise TypeError('word-embedding finalize helper requires at least model and config arguments.') + + embd_group = call_kwargs.pop('embd_group', call_args[2] if len(call_args) > 2 else None) + pp_group = call_kwargs.pop('pp_group', call_args[3] if len(call_args) > 3 else None) + return model, config, embd_group, pp_group, call_kwargs + + +def _allreduce_word_embedding_grads_allow_none(*call_args, **call_kwargs): + """None-safe drop-in for Megatron's private embedding all-reduce helper. + + This wrapper intentionally accepts arbitrary positional/keyword arguments so + it can survive Megatron helper signature drift across versions. 
+ """ + from megatron.core import parallel_state + from megatron.core.distributed.finalize_model_grads import (_get_main_grad_attr, _reshard_if_dtensor, + _unshard_if_dtensor, get_attr_wrapped_model) + + model, config, embd_group, pp_group, _ = _normalize_word_embedding_allreduce_call(*call_args, **call_kwargs) + if embd_group is None: + embd_group = parallel_state.get_embedding_group() + if pp_group is None: + pp_group = parallel_state.get_pipeline_model_parallel_group() + + def _get_main_grad_attr_compat(weight, ddp_config): + try: + helper_params = inspect.signature(_get_main_grad_attr).parameters + except (TypeError, ValueError): + helper_params = None + + if helper_params is not None and len(helper_params) <= 1: + return _get_main_grad_attr(weight) + return _get_main_grad_attr(weight, ddp_config.use_custom_fsdp) + + if parallel_state.is_rank_in_embedding_group( + ignore_virtual=True) and torch.distributed.get_world_size(embd_group) > 1: + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + model_module = model[0] + elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): + model_module = model[-1] + else: + model_module = model[0] + + ddp_config = model_module.ddp_config + model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) + + if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0): + weight = model_module.shared_embedding_or_output_weight() + if weight is None: + logger.warning_once( + 'Megatron LoRA finalize skipped shared embedding/output weight all-reduce ' + 'because the tied weight is missing on this pipeline stage.', + hash_id='megatron_lora_skip_embedding_allreduce_missing_weight', + ) + return + + grad_attr = _get_main_grad_attr_compat(weight, ddp_config) + orig_grad = getattr(weight, grad_attr, None) + grad = _unshard_if_dtensor(orig_grad) + if grad is None: + logger.warning_once( + 'Megatron LoRA finalize skipped shared embedding/output weight all-reduce 
' + 'because the tied weight has no grad. This is expected when LoRA freezes ' + 'the base embedding/output weight.', + hash_id='megatron_lora_skip_embedding_allreduce_none_grad', + ) + return + + torch.distributed.all_reduce(grad, group=embd_group) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) def get_args() -> 'TwinkleMegatronArgs': @@ -231,6 +316,17 @@ def ep_size(self) -> int: @property def expert_tensor_parallel_size(self) -> int: + if not exists('megatron_core>=0.13'): + # megatron_core<0.13 does not have a separate ETP config. For expert ColumnParallelLinear, + # the internal path still uses the dense TP group, and parameter sharding is determined by tp_size. + # etp_size has no practical effect here. + # Force alignment here to avoid a mismatch where GPTBridge shards by etp_size while + # the parameters were built according to tp_size. + tp = self.device_mesh.tp_world_size or 1 + if self.device_mesh.etp_size is not None and self.device_mesh.etp_world_size != tp: + logger.warning(f'etp_size={self.device_mesh.etp_world_size} is ignored on ' + f'megatron_core<0.13; expert TP is tied to tp_size={tp}') + return tp return self.device_mesh.etp_world_size @property @@ -333,10 +429,11 @@ def from_hf_config( model_type = getattr(hf_config, 'model_type', 'qwen2') - # Detect multimodal model from the registered MegatronModelMeta - from .model.register import get_megatron_model_meta - model_meta = get_megatron_model_meta(model_type) - is_multimodal = model_meta.is_multimodal if model_meta is not None else False + # Detect multimodal models without importing the Megatron registry. + # The registry import chain can pull in megatron.core, which must stay + # behind the MindSpeed bootstrap on NPU. 
+ from .model.constant import MLLMModelType + is_multimodal = model_type in {value for key, value in vars(MLLMModelType).items() if not key.startswith('_')} # Determine QKV bias if hasattr(text_config, 'attention_bias'): @@ -470,27 +567,46 @@ def create_model(self, ) -> List[nn.Module]: # Recompute all layers for maximum memory savings recompute_num_layers = num_layers // self.pp_size - # Create finalize_model_grads function for DP gradient synchronization - # Megatron's native finalize_model_grads requires DDP-wrapped models with ddp_config. - # For PEFT/LoRA models, we use a custom implementation that handles non-DDP models. + # Custom finalize_model_grads for LoRA, registered via TransformerConfig. + # Fixes two issues with Megatron's native finalize_model_grads: + # + # 1. Bare models (single-rank / no-op wrap) only carry ddp_config but lack + # finish_grad_sync(), so we gate on real DDP capability instead. + # + # 2. In multi-rank LoRA + PP, native _allreduce_word_embedding_grads assumes + # shared embedding/output weight always has a grad. LoRA freezes the base + # weight so grad is None -> all_reduce(None) crashes. We monkey-patch that + # one helper to skip None grads, reusing the rest of native finalize via + # try/finally to avoid forking the entire module. from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads def finalize_model_grads_for_lora(model, *args, **kwargs): + import importlib + from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from megatron.core.distributed.finalize_model_grads import (_get_main_grad_attr, _reshard_if_dtensor, + _unshard_if_dtensor, get_attr_wrapped_model) from peft import PeftModel as _PeftModel - # Check if model is DDP-wrapped (has ddp_config) - # Need to unwrap PeftModel to check the underlying model + # Unwrap PeftModel -> LoraModel -> real model to check DDP capability. 
def _get_base_model(m): if isinstance(m, _PeftModel): return _get_base_model(m.base_model.model) return m + # Fix 1: check real DDP capability, not just ddp_config presence. base_model = _get_base_model(model[0]) - if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'ddp_config'): - # Use native implementation for DDP models - return _native_finalize_model_grads(model, *args, **kwargs) - + if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'finish_grad_sync'): + # Fix 2: temporarily swap in the None-safe embedding allreduce. + finalize_model_grads_mod = importlib.import_module('megatron.core.distributed.finalize_model_grads') + orig_allreduce_word_embedding_grads = finalize_model_grads_mod._allreduce_word_embedding_grads + finalize_model_grads_mod._allreduce_word_embedding_grads = _allreduce_word_embedding_grads_allow_none + try: + return _native_finalize_model_grads(model, *args, **kwargs) + finally: + finalize_model_grads_mod._allreduce_word_embedding_grads = orig_allreduce_word_embedding_grads + + # Bare model (single-rank / no-op wrap): no DDP sync, skip. return # MoE configuration @@ -565,6 +681,7 @@ def _get_base_model(m): bias_activation_fusion = use_swiglu and not has_bias if 'moe_token_dispatcher_type' not in moe_kwargs: moe_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather' + is_npu = Platform.device_prefix() == 'npu' config = TransformerConfig( num_layers=num_layers, hidden_size=mg_config_dict['hidden_size'], @@ -595,11 +712,15 @@ def _get_base_model(m): hidden_dropout=0.0, attention_dropout=0.0, # Performance optimizations - masked_softmax_fusion=True, # Fused attention softmax + # NPU fallback: the current environment does not provide the TBE-backed + # fused softmax kernel that MindSpeed's NPU path selects by default. + # Keep the GPU fast path unchanged, but fall back to unfused softmax on NPU + # so attention can run without a hard dependency on `tbe`. 
+ masked_softmax_fusion=not is_npu, bias_dropout_fusion=True, # Fused bias + dropout apply_rope_fusion=True, # Fused RoPE application attention_softmax_in_fp32=True, # Numerical stability - attention_backend=AttnBackend.flash, # FlashAttention for speed + attention_backend=AttnBackend.flash, # Activation recomputation for memory efficiency recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4b37973c..84c81bcd 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -257,7 +257,11 @@ def _create_megatron_model( **kwargs, ) -> List[nn.Module]: from .args import get_args + from .mindspeed_bootstrap import bootstrap_mindspeed_for_npu args = get_args() + # Convert the args into a MindSpeed-compatible form so Megatron and + # MindSpeed can share the same runtime arguments. + bootstrap_mindspeed_for_npu(args) self.initialize(**kwargs) model = args.create_model() @@ -265,7 +269,6 @@ def _create_megatron_model( bridge = self._bridge for _model in model: bridge.load_weights(_model, args.model_dir) - if dist.is_initialized(): dist.barrier() @@ -492,9 +495,14 @@ def forward_step_func(data_iterator, model): logps = None if labels is not None and mpu.is_pipeline_last_stage(): loss_mask = (labels != -100).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 + # Avoid bool advanced indexing here. On NPU this lowers to + # aclnnNonzeroV2 inside AdvancedIndex and can crash during + # end-to-end training; torch.where preserves the same masking + # semantics without going through that path. 
+ masked_labels = torch.where(loss_mask, labels, torch.zeros_like(labels)) + output_tensor.div_(temperature) + logps = selective_log_softmax(output_tensor, masked_labels) if cp_size > 1: logps = self._postprocess_tensor_cp(logps) @@ -788,23 +796,39 @@ def _create_megatron_optimizer(self, **kwargs): # Build optimizer config lr = kwargs.pop('lr', 1e-4) + use_gloo_process_groups = kwargs.pop('use_gloo_process_groups', True) use_distributed_optimizer: bool = kwargs.pop('use_distributed_optimizer', False) - - opt_config = OptimizerConfig( - optimizer='adam', - lr=lr, - min_lr=kwargs.get('min_lr', 0.0), - weight_decay=kwargs.get('weight_decay', 0.01), - adam_beta1=kwargs.get('adam_beta1', 0.9), - adam_beta2=kwargs.get('adam_beta2', 0.999), - adam_eps=kwargs.get('adam_eps', 1e-8), - clip_grad=kwargs.get('clip_grad', 1.0), - bf16=kwargs.get('bf16', True), - use_distributed_optimizer=use_distributed_optimizer, - overlap_param_gather=kwargs.get('overlap_param_gather', False), - log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), - **kwargs, - ) + # Some Megatron-LM versions (e.g. 0.12.1) only accept + # overlap_param_gather_with_optimizer_step here. + # overlap_param_gather still exists on ddp_config / distributed + # optimizer paths, but passing it directly into OptimizerConfig + # raises TypeError on this branch. + config_sig = inspect.signature(OptimizerConfig).parameters + config_kwargs = { + 'optimizer': 'adam', + 'lr': lr, + 'min_lr': kwargs.get('min_lr', 0.0), + 'weight_decay': kwargs.get('weight_decay', 0.01), + 'adam_beta1': kwargs.get('adam_beta1', 0.9), + 'adam_beta2': kwargs.get('adam_beta2', 0.999), + 'adam_eps': kwargs.get('adam_eps', 1e-8), + 'clip_grad': kwargs.get('clip_grad', 1.0), + 'bf16': kwargs.get('bf16', True), + 'use_distributed_optimizer': use_distributed_optimizer, + 'log_num_zeros_in_grad': kwargs.get('log_num_zeros_in_grad', False), + } + # Keep the old knob only if this Megatron version still exposes it. 
+ # Some branches wire it through ddp_config instead of OptimizerConfig. + if 'overlap_param_gather' in config_sig: + config_kwargs['overlap_param_gather'] = kwargs.get('overlap_param_gather', False) + if 'overlap_param_gather_with_optimizer_step' in config_sig: + config_kwargs['overlap_param_gather_with_optimizer_step'] = kwargs.get( + 'overlap_param_gather_with_optimizer_step', kwargs.get('overlap_param_gather', False)) + for key, value in kwargs.items(): + if key in config_sig and key not in config_kwargs: + config_kwargs[key] = value + + opt_config = OptimizerConfig(**config_kwargs) # Ensure each model chunk has ddp_config attached (required by Megatron optimizer) from megatron.core.distributed import DistributedDataParallelConfig @@ -814,6 +838,7 @@ def _create_megatron_optimizer(self, **kwargs): optimizer = get_megatron_optimizer( config=opt_config, model_chunks=model_chunks, + use_gloo_process_groups=use_gloo_process_groups, ) return optimizer @@ -1508,11 +1533,39 @@ def initialize(self, **kwargs) -> None: if self._initialized: return + import torch.distributed as dist from megatron.core import parallel_state from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from .args import get_args self._try_init_process_group() + # `self._try_init_process_group()` only initializes torch.distributed when + # Platform.get_world_size() > 1. In single-card local runs that means no + # default process group exists yet, but Megatron still expects one before + # parallel_state.initialize_model_parallel(). Keep a rank=0/world_size=1 + # local PG here so single-card smoke and local training can reach the same + # Megatron initialization path as multi-card jobs. + if Platform.device_prefix() == 'npu' and dist.is_initialized(): + # The default NPU process group carries a `device_id` binding. + # Clear it first so Gloo subgroups do not inherit it and fail. 
+ default_pg = dist.distributed_c10d._get_default_group() + if getattr(default_pg, 'bound_device_id', None) is not None: + default_pg.bound_device_id = None + if not dist.is_initialized(): + from twinkle import find_free_port + + backend = Platform.device_backend() + init_kwargs = { + 'backend': backend, + 'init_method': f'tcp://127.0.0.1:{find_free_port()}', + 'rank': 0, + 'world_size': 1, + } + if backend == 'nccl': + init_kwargs['device_id'] = torch.device(Platform.get_local_device()) + # Do not bind `device_id` on the default NPU process group, + # otherwise later Gloo subgroups will inherit it. + dist.init_process_group(**init_kwargs) args = get_args() init_kwargs = { 'tensor_model_parallel_size': args.tensor_model_parallel_size, @@ -1521,6 +1574,10 @@ def initialize(self, **kwargs) -> None: 'virtual_pipeline_model_parallel_size': args.virtual_pipeline_model_parallel_size, 'expert_model_parallel_size': args.expert_model_parallel_size, } + if Platform.device_prefix() == 'npu': + # Enable auxiliary Gloo groups on NPU and keep them separate from + # the main HCCL groups. + init_kwargs['create_gloo_process_groups'] = True if args.order: init_kwargs['order'] = args.order diff --git a/src/twinkle/model/megatron/mindspeed_bootstrap.py b/src/twinkle/model/megatron/mindspeed_bootstrap.py new file mode 100644 index 00000000..bc7f80ef --- /dev/null +++ b/src/twinkle/model/megatron/mindspeed_bootstrap.py @@ -0,0 +1,81 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""MindSpeed bootstrap helpers for NPU Megatron runs.""" + +import importlib +import inspect +from argparse import Namespace +from typing import Any, Dict, Optional + +from twinkle import Platform +from ._mindspeed_args import build_mindspeed_namespace, get_mindspeed_signature, sanitize_mindspeed_values + +_DEFAULT_MINDSPEED_VALUES: Optional[Dict[str, Any]] = None +_RUNTIME_MINDSPEED_ARGS: Optional[Namespace] = None +_LAST_REPATCH_SIGNATURE: Optional[str] = None + + +def _get_mindspeed_defaults(args_utils) -> Dict[str, Any]: + global _DEFAULT_MINDSPEED_VALUES + + if _DEFAULT_MINDSPEED_VALUES is None: + defaults = args_utils.get_mindspeed_args(get_defaults=True) + _DEFAULT_MINDSPEED_VALUES = sanitize_mindspeed_values(vars(defaults).copy()) + return _DEFAULT_MINDSPEED_VALUES + + +def _install_full_args_provider(args_utils) -> None: + if getattr(args_utils, '_TWINKLE_RUNTIME_PROVIDER_INSTALLED', False): + return + + def get_full_args(): + if _RUNTIME_MINDSPEED_ARGS is None: + raise RuntimeError('MindSpeed runtime args are not initialized before bootstrap.') + return _RUNTIME_MINDSPEED_ARGS + + args_utils.get_full_args = get_full_args + args_utils._TWINKLE_RUNTIME_PROVIDER_INSTALLED = True + + +def _set_runtime_args(args_utils, runtime_args: Namespace) -> None: + global _RUNTIME_MINDSPEED_ARGS + + _RUNTIME_MINDSPEED_ARGS = runtime_args + args_utils._MINDSPEED_ARGS = runtime_args + + +def _import_mindspeed_adaptor(args_utils): + patch_utils = importlib.import_module('mindspeed.patch_utils') + if not hasattr(patch_utils, 'inspect'): + patch_utils.inspect = inspect + return importlib.import_module('mindspeed.megatron_adaptor') + + +def bootstrap_mindspeed_for_npu(args: Any) -> Optional[Dict[str, Any]]: + global _LAST_REPATCH_SIGNATURE + + if Platform.device_prefix() != 'npu': + return None + + try: + args_utils = importlib.import_module('mindspeed.args_utils') + except ModuleNotFoundError as exc: + raise RuntimeError('MindSpeed is required for Twinkle NPU Megatron runs. 
' + 'Please install MindSpeed in the current environment.') from exc + # Fetch MindSpeed defaults here, then merge them with Twinkle args to + # build the final MindSpeed runtime args. + runtime_args = build_mindspeed_namespace(args, _get_mindspeed_defaults(args_utils)) + # Replace get_full_args in mindspeed.args_utils so it returns the runtime + # args constructed by Twinkle. + _install_full_args_provider(args_utils) + # Store the constructed runtime args in mindspeed.args_utils so later + # MindSpeed modules can consume them. + _set_runtime_args(args_utils, runtime_args) + + signature = get_mindspeed_signature(runtime_args) + adaptor = _import_mindspeed_adaptor(args_utils) + if signature != _LAST_REPATCH_SIGNATURE: + if _LAST_REPATCH_SIGNATURE is not None: + adaptor.repatch(vars(runtime_args).copy()) + _LAST_REPATCH_SIGNATURE = signature + + return vars(runtime_args).copy() diff --git a/src/twinkle/model/megatron/model/__init__.py b/src/twinkle/model/megatron/model/__init__.py index c61acef9..71e44abc 100644 --- a/src/twinkle/model/megatron/model/__init__.py +++ b/src/twinkle/model/megatron/model/__init__.py @@ -1,4 +1,4 @@ -from . import gpts, mm_gpts from .constant import MegatronModelType from .gpt_bridge import GPTBridge -from .register import MegatronModelLoader, MegatronModelMeta, get_megatron_model_meta, register_megatron_model +from .register import (MegatronModelLoader, MegatronModelMeta, ensure_megatron_model_registry, get_megatron_model_meta, + register_megatron_model) diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py index 07dfd82a..8382401e 100644 --- a/src/twinkle/model/megatron/model/register.py +++ b/src/twinkle/model/megatron/model/register.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import importlib import torch.nn as nn from dataclasses import dataclass from typing import List, Optional, Type @@ -6,6 +7,7 @@ from .constant import MLLMMegatronModelType MEGATRON_MODEL_MAPPING = {} +_MODELS_REGISTERED = False @dataclass @@ -86,8 +88,18 @@ def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: _MODEL_META_MAPPING = None +def ensure_megatron_model_registry() -> None: + global _MODELS_REGISTERED + if _MODELS_REGISTERED: + return + importlib.import_module(f'{__package__}.gpts') + importlib.import_module(f'{__package__}.mm_gpts') + _MODELS_REGISTERED = True + + def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]: global _MODEL_META_MAPPING + ensure_megatron_model_registry() if _MODEL_META_MAPPING is None: _MODEL_META_MAPPING = {} for k, megatron_model_meta in MEGATRON_MODEL_MAPPING.items(): diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index 0cdeb81d..b43f7930 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -42,6 +42,25 @@ def gather_object(object: Any, device_mesh: DeviceMesh, process_group=None): import torch.distributed as dist output_objects = [object] if device_mesh is not None and device_mesh.data_world_size > 1: + # 1. NPU/HCCL object collectives are brittle here. Megatron already creates + # an equivalent Gloo DP group on NPU, so use that backend for Python + # object gather to keep the metric path identical to GPU's semantics + # without relying on HCCL object support. + # 2. We previously left this path on the default backend and saw metric + # gathering hang in `dist.all_gather_object(...)` on 8-card NPU smoke, + # with pystack pointing at the object-collective call chain. Switching + # to the Megatron-created Gloo DP group is what unblocked the metric path. + # 3. 
If CP is enabled, Megatron builds a separate DP-Gloo group that includes + # the context-parallel dimension; using the plain DP group would pick the + # wrong rank set for metric aggregation. + if Platform.device_prefix() == 'npu': + try: + from megatron.core import parallel_state as mpu + + process_group = mpu.get_data_parallel_group_gloo( + with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1) + except Exception: + pass group_size = dist.get_world_size(group=process_group) output_objects = [None for _ in range(group_size)] dist.all_gather_object(output_objects, object, group=process_group)