diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index fade5e455c..b91331c931 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,6 +6,6 @@ /miles/backends/sglang_utils/ @fzyzcjy @yueming-yuan @maocheng23 @yushengsu-thu /miles/ray/ @fzyzcjy @yueming-yuan @maocheng23 /miles/rollout/ @fzyzcjy @yueming-yuan @guapisolo -/miles/rollout/session/ @fzyzcjy @yueming-yuan @guapisolo @maocheng23 +/miles/rollout/session/ @fzyzcjy @yueming-yuan @guapisolo @maocheng23 @jybsuper /miles/router/ @fzyzcjy @yueming-yuan @guapisolo -/miles/utils/ @fzyzcjy @yueming-yuan @guapisolo @maocheng23 +/miles/utils/ @fzyzcjy @yueming-yuan @guapisolo @maocheng23 @jybsuper diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 6acec1451f..ab0caa21c2 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -146,6 +146,11 @@ jobs: ${{ inputs.custom_tag && format('--custom-tag {0}', inputs.custom_tag) || '' }} \ --push + - name: Point latest to current dev + if: github.event_name == 'schedule' || inputs.simulate_schedule == true + run: | + docker buildx imagetools create -t radixark/miles:latest radixark/miles:dev + - name: Prune old dev tags if: github.event_name == 'schedule' run: | @@ -193,3 +198,33 @@ jobs: echo " Failed to delete ${TAG} (HTTP ${HTTP_CODE})" fi done + + build-and-push-dev-glm: + needs: [build-and-push] + # Only rebuild dev-glm when the dev image was built (schedule, push to main, or dispatch with image_tag=dev) + if: needs.build-and-push.result == 'success' && (github.event_name == 'schedule' || inputs.simulate_schedule == true) + runs-on: self-hosted + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver-opts: | + image=moby/buildkit:latest + network=host + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push dev-glm + run: | + docker buildx build \ + -f docker/glm5/Dockerfile.dev-glm \ + -t radixark/miles:dev-glm \ + --push \ + . diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6cd1247deb..8c314f0af4 100755 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -92,6 +92,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -166,118 +167,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - unit-test: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-unit-test')) - runs-on: self-hosted - container: - image: radixark/miles:dev - options: > - --gpus all - --ipc=host - --shm-size=32g - --ulimit memlock=-1 - --ulimit stack=67108864 - --memory=0 - --memory-swap=0 - -v /mnt/nvme0n1/miles_ci:/data/miles_ci - -v /mnt/nvme0n1/miles_ci/models:/root/models - -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets - --privileged - --ulimit nofile=65535:65535 - -v /tmp:/tmp - strategy: - fail-fast: false - matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}] - defaults: - run: - working-directory: ${{ github.workspace }} - env: - GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - HF_TOKEN: ${{ secrets.HF_TOKEN }} - MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} - MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} - MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} - MILES_TEST_USE_INT4_ROLLOUT: ${{ matrix.info.use_int4_rollout || '0' }} - MILES_TEST_USE_BRIDGE: ${{ matrix.info.use_bridge || '0' }} - MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} - MILES_TEST_FEW_GPU: '0' - SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Cleanup Ray processes - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - pkill -9 -f gcs_server 2>/dev/null || true - pkill -9 -f 'ray-dashboard' 2>/dev/null || true - pkill -9 sglang 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - sleep 3 - - - - name: Resolve dependency refs - id: resolve-refs - shell: bash - env: - PR_BODY: ${{ github.event.pull_request.body || '' }} - INPUT_MEGATRON_PR: ${{ github.event.inputs.ci_megatron_pr || '' }} - INPUT_SGLANG_PR: ${{ github.event.inputs.ci_sglang_pr || '' }} - run: | - # Priority: workflow_dispatch input > PR description > default - MEGATRON_PR="${INPUT_MEGATRON_PR}" - SGLANG_PR="${INPUT_SGLANG_PR}" - - # Parse PR description for "ci-megatron-pr:" and "ci-sglang-pr:" - if [ -n "$PR_BODY" ]; then - PR_MEGATRON_PR=$(echo "$PR_BODY" | grep -oP '(?<=ci-megatron-pr:\s)\S+' || true) - PR_SGLANG_PR=$(echo "$PR_BODY" | grep -oP '(?<=ci-sglang-pr:\s)\S+' || true) - [ -z "$MEGATRON_PR" ] && [ -n "$PR_MEGATRON_PR" ] && MEGATRON_PR="$PR_MEGATRON_PR" - [ -z "$SGLANG_PR" ] && [ -n "$PR_SGLANG_PR" ] && SGLANG_PR="$PR_SGLANG_PR" - fi - - # Defaults - [ -z "$MEGATRON_PR" ] && MEGATRON_PR="miles-main" - [ -z "$SGLANG_PR" ] && SGLANG_PR="sglang-miles" - - # Convert "#N" PR syntax to git fetch ref: "pull/N/head" - resolve_fetch_ref() { - local ref="$1" - if [[ "$ref" =~ ^#([0-9]+)$ ]]; then - echo "pull/${BASH_REMATCH[1]}/head" - else - echo "$ref" - fi - } - MEGATRON_FETCH=$(resolve_fetch_ref "$MEGATRON_PR") - SGLANG_FETCH=$(resolve_fetch_ref "$SGLANG_PR") - - echo "ci_megatron_pr=$MEGATRON_FETCH" >> $GITHUB_OUTPUT - echo "ci_sglang_pr=$SGLANG_FETCH" >> $GITHUB_OUTPUT - echo "Resolved: megatron=$MEGATRON_PR -> fetch=$MEGATRON_FETCH, sglang=$SGLANG_PR -> fetch=$SGLANG_FETCH" - - - name: Install - shell: bash - env: - MEGATRON_PR: ${{ steps.resolve-refs.outputs.ci_megatron_pr }} - SGLANG_PR: ${{ steps.resolve-refs.outputs.ci_sglang_pr }} - run: | - cd /sgl-workspace/sglang && git reset --hard HEAD && git clean -fd && git fetch origin "$SGLANG_PR" && git checkout -f FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages - cd /root/Megatron-LM && git reset --hard HEAD && git clean -fd && git fetch origin "$MEGATRON_PR" && git checkout -f FETCH_HEAD && git log --oneline -1 && pip install -e . --no-deps --break-system-packages - cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages - pip install pytest-asyncio --break-system-packages - - - - name: Execute - shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - e2e-test-sglang: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-sglang')) runs-on: self-hosted @@ -300,7 +189,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 1, "test_file": "e2e/sglang/test_chat_input_ids_equivalence.py"}, {"model_family": "qwen3", "num_gpus": 1, "test_file": "e2e/sglang/test_session_server_tool_call.py"}, {"model_family": "glm47", "num_gpus": 1, "test_file": "e2e/sglang/test_session_server_tool_call.py"}, {"model_family": "qwen3", "num_gpus": 1, "test_file": "e2e/sglang/test_tito_logprob_equivalence.py"}, {"model_family": "glm47", "num_gpus": 1, "test_file": "e2e/sglang/test_tito_logprob_equivalence.py"}] + info: [{"num_gpus": 1, "test_file": "e2e/sglang/test_chat_input_ids_equivalence.py"}, {"model_family": "qwen3", "num_gpus": 1, "test_file": "e2e/sglang/test_session_server_tool_call.py"}, {"model_family": "glm47", "num_gpus": 1, "test_file": "e2e/sglang/test_session_server_tool_call.py"}, {"model_family": "qwen3", "num_gpus": 1, "test_file": "e2e/sglang/test_tito_logprob_equivalence.py"}, {"model_family": "glm47", "num_gpus": 1, "test_file": "e2e/sglang/test_tito_logprob_equivalence.py"}, {"model_family": "qwen3_30b_a3b", "num_gpus": 1, "test_file": "e2e/sglang/test_r3_router_equivalence.py"}, {"model_family": "glm47_flash", "num_gpus": 1, "test_file": "e2e/sglang/test_r3_router_equivalence.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -316,6 +205,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -412,7 +302,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -428,6 +318,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -524,7 +415,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}] + info: [{"name": "[FSDP] qwen3-4B-fsdp-true-on-policy", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"name": "[FSDP] qwen3-vl-4B-fsdp", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-distributed", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"name": "[FSDP] qwen3-0.6B-megatron-fsdp-align", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-colocated-2xGPU", "num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -540,6 +431,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -636,7 +528,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -652,6 +544,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -764,6 +657,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -876,6 +770,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -988,6 +883,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -1100,6 +996,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -1212,6 +1109,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -1324,6 +1222,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository @@ -1375,7 +1274,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 4, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}, {"name": "qwen3-30B-A3B-bf16", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0"}, {"name": "qwen3-30B-A3B-rollout-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-rollout-int4", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0", "use_int4_rollout": "1"}] + info: [{"name": "[FSDP] qwen3-4B-fsdp-true-on-policy", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"name": "[FSDP] qwen3-vl-4B-fsdp", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-distributed", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"name": "[FSDP] qwen3-0.6B-megatron-fsdp-align", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-colocated-2xGPU", "num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}, {"name": "qwen3-30B-A3B-bf16", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0"}, {"name": "qwen3-30B-A3B-rollout-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-rollout-int4", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0", "use_int4_rollout": "1"}] defaults: run: working-directory: ${{ github.workspace }} @@ -1391,6 +1290,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 0db15cad65..23221d0249 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,10 +1,11 @@ <% set default_image = 'radixark/miles:dev' %> <% set fsdp_tests = [ - {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 8}, - {'test_file': 'e2e/fsdp/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, - {'test_file': 'e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 8}, - {'test_file': 'e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-4B-fsdp-true-on-policy', 'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-vl-4B-fsdp', 'test_file': 'e2e/fsdp/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-0.6B-fsdp-distributed', 'test_file': 'e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-0.6B-megatron-fsdp-align', 'test_file': 'e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-0.6B-fsdp-colocated-2xGPU', 'test_file': 'e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 8}, ] %> <% set megatron_tests = [ @@ -27,7 +28,6 @@ <% set short_tests = [ {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 8}, {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 8}, - {'test_file': 'e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 8}, {'test_file': 'e2e/sglang_config/test_sglang_config.py', 'num_gpus': 8}, {'test_file': 'e2e/sglang_config/test_sglang_config_mixed_offload.py', 'num_gpus': 8}, {'test_file': 'e2e/sglang_config/test_sglang_config_mixed_offload_ft.py', 'num_gpus': 8}, @@ -67,12 +67,6 @@ {'test_file': 'utils/test_sglang_config.py', 'num_gpus': 0}, ], }, - 'unit-test': { - 'label': 'run-unit-test', - 'tests': [ - {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 8} - ], - }, 'e2e-test-sglang': { 'label': 'run-ci-sglang', 'test_executor': 'pytest', @@ -82,6 +76,8 @@ {'test_file': 'e2e/sglang/test_session_server_tool_call.py', 'num_gpus': 1, 'model_family': 'glm47'}, {'test_file': 'e2e/sglang/test_tito_logprob_equivalence.py', 'num_gpus': 1, 'model_family': 'qwen3'}, {'test_file': 'e2e/sglang/test_tito_logprob_equivalence.py', 'num_gpus': 1, 'model_family': 'glm47'}, + {'test_file': 'e2e/sglang/test_r3_router_equivalence.py', 'num_gpus': 1, 'model_family': 'qwen3_30b_a3b'}, + {'test_file': 'e2e/sglang/test_r3_router_equivalence.py', 'num_gpus': 1, 'model_family': 'glm47_flash'}, ], }, 'e2e-test-short': { @@ -94,7 +90,7 @@ }, 'e2e-test-megatron': { 'label': 'run-ci-megatron', - 'tests': megatron_tests, + 'tests': megatron_tests + lora_tests, }, 'e2e-test-precision': { 'label': 'run-ci-precision', @@ -197,6 +193,7 @@ jobs: MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} MILES_TEST_FEW_GPU: '0' SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} + ROUTER_EQ_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} steps: - name: Checkout repository diff --git a/docker/Dockerfile b/docker/Dockerfile index 69e9ba1354..88d556b5c9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,10 +3,10 @@ # # 2. radixark/miles:dev-cu13-arm64 # build-arg:ENABLE_CUDA_13=1 \ -# build-arg:SGLANG_IMAGE_TAG=v0.5.9-cu130-arm64 \ +# build-arg:SGLANG_IMAGE_TAG=v0.5.10-cu130 \ # build-arg:WHEELS_TAG=cu130-aarch64 \ -ARG SGLANG_IMAGE_TAG=v0.5.9 +ARG SGLANG_IMAGE_TAG=v0.5.10 FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang # ======================================== Arguments ============================================= @@ -46,7 +46,7 @@ RUN mkdir -p /tmp/wheels && \ curl -sL "https://api.github.com/repos/${WHEELS_REPO}/releases/tags/${WHEELS_TAG}" \ | python3 -c "import sys, json, subprocess; \ [subprocess.run(['curl', '-fSL', '-o', '/tmp/wheels/' + a['name'], a['browser_download_url']], check=True) \ - for a in json.load(sys.stdin)['assets'] if a['name'].endswith('.whl')]" && \ + for a in json.load(sys.stdin)['assets'] if a['name'].endswith(('.whl', '.tar.gz'))]" && \ ls -lh /tmp/wheels/ # ====================================== Python dependencies ============================================ @@ -63,7 +63,7 @@ RUN pip install /tmp/wheels/flash_attn_3-*.whl && \ RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -RUN pip install flash-linear-attention==0.4.1 +RUN pip install flash-linear-attention==0.4.2 RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ @@ -83,12 +83,12 @@ RUN git clone https://github.com/${MEGATRON_REPO}.git --recursive -b ${MEGATRON_ RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@d64a639 --no-cache-dir --force-reinstall # RUN pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation RUN pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation -RUN pip install git+https://github.com/yushengsu-thu/Megatron-Bridge.git@merged-megatron-0.16.0rc0-miles --no-deps --no-build-isolation +RUN pip install git+https://github.com/radixark/Megatron-Bridge.git@bridge --no-deps --no-build-isolation RUN pip install megatron-energon --no-deps RUN pip install multi-storage-client --no-deps COPY requirements.txt /tmp/requirements.txt -RUN pip install -r /tmp/requirements.txt +RUN rm -rf /usr/lib/python3/dist-packages/jwt /usr/lib/python3/dist-packages/PyJWT* && pip install -r /tmp/requirements.txt # https://github.com/pytorch/pytorch/issues/168167 RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ @@ -125,4 +125,36 @@ RUN git clone https://github.com/radixark/miles.git /root/miles && \ # int4_qat RUN pip install /tmp/wheels/fake_int4_quant_cuda-*.whl +# ====================================== Install sgl-model-gateway ============================================ +# SGL_ROUTER_USE_WHEELS=0: +# Build from source https://github.com/radixark/sgl-router-for-miles +# SGL_ROUTER_USE_WHEELS=1 (default): +# Install the pre-built sgl-model-gateway wheel + +ARG SGL_ROUTER_USE_WHEELS=1 +ARG SGL_ROUTER_REPO=https://github.com/radixark/sgl-router-for-miles.git +ARG SGL_ROUTER_BRANCH=main + +RUN --mount=type=cache,target=/root/.cache/pip \ + set -eux; \ + if [ "${SGL_ROUTER_USE_WHEELS}" = "1" ]; then \ + pip install --force-reinstall /tmp/wheels/sglang_router-*.whl && \ + tar xzf /tmp/wheels/sgl-model-gateway-linux-*.tar.gz -C /usr/local/bin/ && \ + chmod +x /usr/local/bin/sgl-model-gateway; \ + elif [ "${SGL_ROUTER_USE_WHEELS}" = "0" ]; then \ + git clone --branch "${SGL_ROUTER_BRANCH}" --depth 1 "${SGL_ROUTER_REPO}" /build/sgl-model-gateway && \ + curl --proto '=https' --tlsv1.2 --retry 3 --retry-delay 2 -sSf https://sh.rustup.rs | sh -s -- -y && \ + export PATH="/root/.cargo/bin:${PATH}" && \ + python3 -m pip install maturin && \ + cd /build/sgl-model-gateway/bindings/python && \ + ulimit -n 65536 && \ + maturin build --release --features vendored-openssl --out /build/gateway_wheels && \ + cd /build/sgl-model-gateway && \ + cargo build --release --bin sgl-model-gateway --features vendored-openssl && \ + cp target/release/sgl-model-gateway /usr/local/bin/sgl-model-gateway && \ + chmod +x /usr/local/bin/sgl-model-gateway && \ + pip install --force-reinstall /build/gateway_wheels/sglang_router-*.whl && \ + rm -rf /root/.cargo /root/.rustup /build/sgl-model-gateway /build/gateway_wheels; \ + fi + RUN rm -rf /tmp/wheels diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 index 016107b92c..bd34fd8ced 100644 --- a/docker/Dockerfile.rocm_MI350-5 +++ b/docker/Dockerfile.rocm_MI350-5 @@ -1,169 +1,156 @@ -#### Use the base image for ROCm 7 / gfx950 (MI355) - -# ===================================================================== -# Docker Image Version Information (Updated: Feb 5, 2026) -# ===================================================================== -# Base image: ROCm 7 with vllm pre-built for gfx950 -# Target GPU: MI355 (gfx950) -# -# Key Dependencies: -# - sglang: sglang-miles branch -# - sgl_kernel: built from selected sglang commit -# - Megatron-LM: radixark/Megatron-LM -# - TransformerEngine: commit 90c04bcdc3c109505b318f40a39680263af55edf -# - aiter: v0.1.10.post3 -# - Ray: 2.47.1 -# -# Patches: amd_patch/sglv0.5.7/ -# - megatron.patch -# - sglang.patch -# ===================================================================== - - -FROM rocm/sgl-dev:rocm7-vllm-20250904 +# 1. rlsys/miles:MI350-355-latest +# build-arg:SGLANG_IMAGE_TAG=v0.5.10-rocm720-mi35x + +ARG SGLANG_IMAGE_TAG=v0.5.10-rocm720-mi35x +FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang SHELL ["/bin/bash", "-ceuxo", "pipefail"] -ARG MAX_JOBS=128 -ARG SGLANG_REPO=sgl-project/sglang +# ======================================== Arguments ============================================= + ARG SGLANG_BRANCH=sglang-miles ARG SGLANG_COMMIT="" + ARG MEGATRON_REPO=radixark/Megatron-LM ARG MEGATRON_BRANCH=miles-main -ARG MEGATRON_COMMIT="" -ENV MAX_JOBS=${MAX_JOBS} -# Set environment variables for gfx950 -ENV GPU_ARCH=gfx950 -ENV PYTORCH_ROCM_ARCH=gfx950 -ENV GPU_ARCH_LIST=gfx950 -ENV AMDGPU_TARGET=gfx950 +ARG MILES_COMMIT=main +ARG GPU_ARCH=gfx950 +ARG MAX_JOBS=128 -########################################### -##############1. Install AITER############# -########################################### -WORKDIR /app +ARG AITER_REPO=https://github.com/ROCm/aiter.git +ARG AITER_COMMIT=v0.1.11.post1 -RUN pip uninstall -y aiter || true -RUN rm -rf aiter -RUN git clone https://github.com/ROCm/aiter.git \ - && cd aiter \ - && git checkout v0.1.10.post3 \ - && curl -fsSL https://patch-diff.githubusercontent.com/raw/ROCm/aiter/pull/2075.patch -o /tmp/aiter-pr2075.patch \ - && git apply --3way /tmp/aiter-pr2075.patch \ - && rm -f /tmp/aiter-pr2075.patch \ - && git submodule sync --recursive \ - && git submodule update --init --recursive \ - && GPU_ARCHS=gfx950 python setup.py develop -########################################### -########################################### -########################################### - - -########################################### -####2. Install TransformerEngine for gfx950 -########################################### -WORKDIR /app - -RUN rm -rf TransformerEngine -RUN git clone https://github.com/ROCm/TransformerEngine.git \ - && cd TransformerEngine \ - && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ - && git submodule update --init --recursive +ARG RCCL_TESTS_REPO=https://github.com/ROCm/rocm-systems.git +ARG RCCL_TESTS_BRANCH=develop +ARG RCCL_TESTS_PATH=projects/rccl-tests + +ARG TRANSFORMER_ENGINE_REPO=https://github.com/ROCm/TransformerEngine.git +ARG TRANSFORMER_ENGINE_BRANCH=v2.8_rocm + +# ======================================== Setup ============================================= +WORKDIR /root/ + +ENV MAX_JOBS=${MAX_JOBS} + +# Build configuration for MI350 / gfx950. +ENV GPU_ARCH=${GPU_ARCH} +ENV PYTORCH_ROCM_ARCH=${GPU_ARCH} +ENV GPU_ARCH_LIST=${GPU_ARCH} +ENV AMDGPU_TARGET=${GPU_ARCH} + +# Transformer Engine build knobs for the v2.8_rocm branch. ENV NVTE_FRAMEWORK=pytorch -ENV NVTE_ROCM_ARCH=gfx950 +ENV NVTE_ROCM_ARCH=${GPU_ARCH} ENV NVTE_USE_HIPBLASLT=1 ENV NVTE_USE_ROCM=1 -ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - -RUN cd TransformerEngine && pip install . -v -########################################### -########################################### -########################################### - - -######################################### -####3. Install Megatron-LM -######################################### -WORKDIR /app - -RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall - -RUN pip uninstall -y megatron-core || true -RUN rm -rf Megatron-LM -RUN git clone https://github.com/${MEGATRON_REPO}.git \ - && cd Megatron-LM \ - && git fetch origin ${MEGATRON_BRANCH} \ - && if [ -n "${MEGATRON_COMMIT}" ]; then \ - git checkout ${MEGATRON_COMMIT}; \ - else \ - git checkout FETCH_HEAD; \ - fi \ - && pip install -e . -######################################### -######################################### -######################################### - - -######################################## -############ 4. Install mbridge######### -######################################## -RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps -######################################## -######################################## -######################################## - - -######################################## -######5. Install Ray#################### -######################################## -RUN pip uninstall ray -y || true -RUN pip install "ray[data,train,tune,serve]==2.47.1" -######################################## -######################################## -######################################## - - -######################################### -###6. Install torch_memory_saver######### -######################################### -RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@64a92e1d7fb822ea4af5579c8cebb162692c531c --no-cache-dir --force-reinstall -######################################### -######################################### - - -####################################### -####7. Install Apex for ROCm########### -####################################### -WORKDIR /app - -RUN pip uninstall -y apex || true -RUN rm -rf apex -RUN git clone https://github.com/ROCm/apex.git \ - && cd apex \ - && python setup.py install -####################################### -####################################### -####################################### - - -######################################## -###8. Install miles agent framework deps -######################################## -RUN pip install pydra_config==0.0.15 -RUN pip install together -RUN pip install google-generativeai -RUN pip install tensorboard -######################################## -######################################## -######################################## - - -######################################## -###9. Set performance environment vars## -######################################## +# Keep the core package enabled and skip the extra fused-attn kernel matrix rebuild. +ENV NVTE_FUSED_ATTN=0 +ENV CMAKE_PREFIX_PATH=/opt/rocm:/opt/rocm/hip:/usr/local:/usr + +# Patch Megatron's fused-kernel init for this toolchain. +COPY docker/amd_patch/latest/megatron.patch /tmp/amd_patch/megatron.patch +COPY requirements.txt /tmp/requirements.txt + +# ======================================== Apt dependencies ============================================= + +RUN apt update +# Install build tools and diagnostics utilities. +RUN apt install -y build-essential cmake dnsutils ethtool git nvtop rsync + +# Build rccl-tests diagnostics binaries. +RUN git clone --depth 1 --branch ${RCCL_TESTS_BRANCH} ${RCCL_TESTS_REPO} /tmp/rocm-systems && \ + make -C /tmp/rocm-systems/${RCCL_TESTS_PATH} -j$(nproc) \ + HIP_HOME=/opt/rocm \ + NCCL_HOME=/opt/rocm \ + GPU_TARGETS=${GPU_ARCH} && \ + cp /tmp/rocm-systems/${RCCL_TESTS_PATH}/build/*_perf /usr/local/bin/ && \ + rm -rf /tmp/rocm-systems + +# ====================================== Python dependencies ============================================ + +# Rebuild AITER at the version paired with SGLang. +RUN pip uninstall -y aiter || true +RUN pip install flydsl==0.0.1.dev95158637 psutil pybind11 +RUN cd /sgl-workspace/aiter && \ + git remote set-url origin ${AITER_REPO} && \ + git checkout ${AITER_COMMIT} && \ + git reset --hard ${AITER_COMMIT} && \ + git clean -fdx && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + # Temporary fixes for the current ROCm 7.2 image/toolchain combination. + sed -i '459 s/if.*:/if False:/' aiter/ops/triton/attention/pa_mqa_logits.py && \ + sed -i '/c1 = torch.empty((M, D, S1 + S3), dtype=dtype, device=x.device)/i\ config = dict(config)' \ + aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py && \ + GPU_ARCHS=${GPU_ARCH} pip install -e . + +# Install Transformer Engine from the requested branch. +RUN pip uninstall -y transformer-engine transformer_engine transformer_engine_torch || true +RUN rm -rf /root/TransformerEngine && \ + git clone --recursive --branch ${TRANSFORMER_ENGINE_BRANCH} ${TRANSFORMER_ENGINE_REPO} /root/TransformerEngine && \ + cd /root/TransformerEngine && \ + pip install . --no-build-isolation -v + +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +RUN GPU_ARCHS=${GPU_ARCH} BUILD_TARGET=rocm MAX_JOBS=${MAX_JOBS} \ + pip -v install flash-attn==2.8.3 --no-build-isolation + +RUN pip install flash-linear-attention==0.4.2 + +RUN rm -rf /root/Megatron-LM && \ + git clone --recursive -b ${MEGATRON_BRANCH} https://github.com/${MEGATRON_REPO}.git /root/Megatron-LM && \ + cd /root/Megatron-LM && \ + git apply /tmp/amd_patch/megatron.patch && \ + pip install -e . + +RUN pip uninstall -y sgl_kernel sglang || true +RUN cd /sgl-workspace/sglang && \ + git reset --hard && \ + git clean -fdx && \ + git fetch origin ${SGLANG_BRANCH} && \ + if [ -n "${SGLANG_COMMIT}" ]; then \ + git checkout ${SGLANG_COMMIT}; \ + else \ + git checkout FETCH_HEAD; \ + fi && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + cd sgl-kernel && \ + rm -f pyproject.toml && \ + mv pyproject_rocm.toml pyproject.toml && \ + AMDGPU_TARGET=${GPU_ARCH} python setup_rocm.py install && \ + cd .. && \ + rm -rf python/pyproject.toml && \ + mv python/pyproject_other.toml python/pyproject.toml && \ + pip install -e "python[all_hip]" --no-deps + +RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" + +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@d64a639 --no-cache-dir --force-reinstall +RUN pip install git+https://github.com/yushengsu-thu/Megatron-Bridge.git@merged-megatron-0.16.0rc0-miles --no-deps --no-build-isolation +RUN pip install megatron-energon --no-deps +RUN pip install multi-storage-client --no-deps + +RUN rm -rf /usr/lib/python3/dist-packages/jwt /usr/lib/python3/dist-packages/PyJWT* && \ + pip install -r /tmp/requirements.txt + +# Pin numpy 1.x for Megatron compatibility. +RUN pip install "numpy<2" + +# ====================================== Install main package ============================================ + +RUN git clone https://github.com/radixark/miles.git /root/miles && \ + cd /root/miles && \ + git checkout ${MILES_COMMIT} && \ + pip install -e . --no-deps + +# ====================================== Runtime knobs ============================================ + +# Runtime knobs consumed by the current SGLang/PyTorch stack. ENV HIP_FORCE_DEV_KERNARG=1 ENV HSA_NO_SCRATCH_RECLAIM=1 ENV SGLANG_USE_AITER=1 @@ -173,114 +160,11 @@ ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 ENV SGLANG_USE_ROCM700A=1 ENV NCCL_MIN_NCHANNELS=112 -ENV VLLM_FP8_PADDING=1 -ENV VLLM_FP8_ACT_PADDING=1 -ENV VLLM_FP8_WEIGHT_PADDING=1 -ENV VLLM_FP8_REDUCE_CONV=1 ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 -######################################## -######################################## -######################################## - -########################################### -##############Install SGLang############### -########################################### -WORKDIR /app - -# Install prerequisites -RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 - -# Clone SGLang -RUN pip uninstall -y sgl_kernel sglang || true -RUN rm -rf sglang -RUN git clone https://github.com/${SGLANG_REPO}.git \ - && cd sglang \ - && git fetch origin ${SGLANG_BRANCH} \ - && if [ -n "${SGLANG_COMMIT}" ]; then \ - git checkout ${SGLANG_COMMIT}; \ - else \ - git checkout FETCH_HEAD; \ - fi - -# Build sgl-kernel for gfx950 -RUN cd sglang/sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && AMDGPU_TARGET=gfx950 python setup_rocm.py install - -# Install SGLang -RUN cd sglang \ - && rm -rf python/pyproject.toml \ - && mv python/pyproject_other.toml python/pyproject.toml \ - && pip install -e "python[all_hip]" - -# Test SGLang installation -RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" +RUN rm -rf /root/.cache/pip /root/TransformerEngine /tmp/amd_patch -RUN python -m pip cache purge -########################################### -########################################### -########################################### - - -########################################### -#### APPLY PATCHES (gfx950/MI355) ######### -########################################### - -# Copy patch from miles repo -COPY amd_patch/sglv0.5.7/megatron.patch /app/patch/megatron.patch -COPY amd_patch/sglv0.5.7/sglang.patch /app/patch/sglang.patch - -# Apply Megatron patches -RUN cd /app/Megatron-LM \ - && git apply --3way /app/patch/megatron.patch \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "Patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi \ - && pip install -e . -v - -# Apply SGLang patch -RUN cd /app/sglang \ - && git apply --3way /app/patch/sglang.patch \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi - -# Copy MOE configs for gfx950/MI355 -RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ - /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' 2>/dev/null | while read f; do \ - cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ - cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ -done - -########################################### -########################################### -########################################### - - -######################################## -#### Install additional packages######## -######################################## -RUN pip install sglang-router --force-reinstall -######################################## -######################################## -######################################## - - -######################################## -# Fix click/ray incompatibility with Python 3.10 -######################################## -RUN pip install click==8.2.1 -######################################## -######################################## -######################################## - - -WORKDIR /app +WORKDIR /root/ CMD ["/usr/bin/bash"] diff --git a/docker/amd_patch/latest/megatron.patch b/docker/amd_patch/latest/megatron.patch index f6efca346d..acd64149b7 100644 --- a/docker/amd_patch/latest/megatron.patch +++ b/docker/amd_patch/latest/megatron.patch @@ -1,5 +1,4 @@ diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 --- a/megatron/legacy/fused_kernels/__init__.py +++ b/megatron/legacy/fused_kernels/__init__.py @@ -3,6 +3,7 @@ @@ -10,42 +9,12 @@ index 87cceac3..ac686d74 100644 from torch.utils import cpp_extension -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" +@@ -15,6 +16,8 @@ def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') ++ if not torch.version.cuda: ++ return -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] diff --git a/docker/amd_patch/latest/sglang.patch b/docker/amd_patch/latest/sglang.patch deleted file mode 100644 index b103263070..0000000000 --- a/docker/amd_patch/latest/sglang.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -index 6e7ea07e7..73b512f51 100644 ---- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -@@ -64,6 +64,7 @@ class CustomAllreduce: - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=_MAX_CAR_SIZE, -+ enable_register_for_capturing: bool = True, - ) -> None: - """ - Args: -@@ -410,6 +411,8 @@ class CustomAllreduce: - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - if _is_hip: -+ if self.tms_cudagraph: -+ return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) -diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index c3ca1e4f3..2bb763b6a 100644 ---- a/python/sglang/srt/distributed/parallel_state.py -+++ b/python/sglang/srt/distributed/parallel_state.py -@@ -351,10 +351,12 @@ class GroupCoordinator: - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - try: -+ tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( - group=self.cpu_group, - device=self.device, -+ enable_register_for_capturing=not tms_cudagraph, - ) - except Exception as e: - logger.warning( diff --git a/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch b/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch deleted file mode 100644 index f6efca346d..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 ---- a/megatron/legacy/fused_kernels/__init__.py -+++ b/megatron/legacy/fused_kernels/__init__.py -@@ -3,6 +3,7 @@ - import os - import pathlib - import subprocess -+import torch - - from torch.utils import cpp_extension - -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - - def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') - -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.0rc0/megatron.patch b/docker/amd_patch/sglv0.5.0rc0/megatron.patch deleted file mode 100644 index b129959aff..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/megatron.patch +++ /dev/null @@ -1,792 +0,0 @@ -diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py -index 41c21d93d..ef80f72d6 100644 ---- a/megatron/core/dist_checkpointing/strategies/common.py -+++ b/megatron/core/dist_checkpointing/strategies/common.py -@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): - msc = MultiStorageClientFeature.import_package() - return msc.torch.load(load_path, map_location='cpu') - else: -- return torch.load(load_path, map_location='cpu') -+ return torch.load(load_path, map_location='cpu', weights_only=False) - except FileNotFoundError as e: - err_msg = f'Common file {load_path} does not exist' - if MultiStorageClientFeature.is_enabled(): -diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py -index 5a1ea308d..aa701237f 100644 ---- a/megatron/core/dist_checkpointing/strategies/torch.py -+++ b/megatron/core/dist_checkpointing/strategies/torch.py -@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - def _validate_global_shapes(self, metadata, sharded_tensors): - for sh_ten in sharded_tensors: - if sh_ten.key not in metadata.state_dict_metadata: -- raise KeyError( -- f"{sh_ten.key} from model not in state dict:" -- f" {sorted(metadata.state_dict_metadata.keys())}" -- ) -+ # raise KeyError( -+ # f"{sh_ten.key} from model not in state dict:" -+ # f" {sorted(metadata.state_dict_metadata.keys())}" -+ # ) -+ print(f"{sh_ten.key} from model not in state dict, will skip") -+ continue - loaded_shape = metadata.state_dict_metadata[sh_ten.key].size - expected_shape = self._expected_shape(sh_ten) - if loaded_shape != expected_shape: -@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - tensor_metadata = self.metadata.state_dict_metadata - metadata_with_sizes = [ - (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) -- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() -+ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata - ] - try: - # Temporarily set sizes to expected shapes -@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): - planner=MCoreLoadPlanner( - shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, - allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, -+ allow_partial_load=True, - ), - ) - -diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b43..4451f2776 100644 ---- a/megatron/core/distributed/__init__.py -+++ b/megatron/core/distributed/__init__.py -@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads - from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel - from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel - from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig -+ -+# Backward compatibility patch for FSDP module reorganization -+import sys -+import importlib.util -+ -+spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') -+if spec: -+ custom_fsdp = importlib.util.module_from_spec(spec) -+ spec.loader.exec_module(custom_fsdp) -+ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp -+ if hasattr(custom_fsdp, 'MegatronFSDP'): -+ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP -diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index acb93ef78..d239db4ab 100644 ---- a/megatron/core/extensions/transformer_engine.py -+++ b/megatron/core/extensions/transformer_engine.py -@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): - ) - - for param in self.parameters(): -+ setattr(param, "parallel_mode", parallel_mode) - if is_expert: - # Reduce the gradient on the expert_data_parallel group for expert linear layers - setattr(param, "allreduce", not self.expert_parallel) -@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): - - - if HAVE_TE and is_te_min_version("1.9.0.dev0"): -+ def ceil_div(x: int, y: int) -> int: -+ return (x + y - 1) // y -+ -+ class _FakeInt4QuantizationSTE(torch.autograd.Function): -+ @staticmethod -+ def forward(ctx, x, group_size): -+ m, n = x.shape -+ block_size_m, block_size_n = 1, group_size -+ -+ -+ m_padded = ceil_div(m, block_size_m) * block_size_m -+ n_padded = ceil_div(n, block_size_n) * block_size_n -+ -+ x_padded = torch.zeros( -+ (m_padded, n_padded), -+ dtype=x.dtype, device=x.device -+ ) -+ x_padded[:m, :n] = x -+ -+ x_view = x_padded.view( -+ m_padded // block_size_m, -+ block_size_m, -+ n_padded // block_size_n, -+ block_size_n -+ ) -+ -+ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) -+ q_max = 7 -+ x_scale = x_max / q_max -+ -+ x_scale = x_scale.clamp(min=1e-5) -+ -+ x_div = x_view / x_scale -+ x_round = torch.round(x_div) -+ -+ x_q_clamped = x_round.clamp(-q_max, q_max) -+ -+ x_dequant_view = x_q_clamped * x_scale -+ -+ x_dequant_full = x_dequant_view.view_as(x_padded) -+ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) -+ -+ return x_out -+ -+ @staticmethod -+ def backward(ctx, grad_output): -+ return grad_output, None -+ -+ def fake_int4_quantization_ste(x, group_size): -+ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) -+ -+ if hasattr(x, 'main_grad'): -+ x_out.main_grad = x.main_grad -+ -+ return x_out - - class TEGroupedLinear(te.pytorch.GroupedLinear): - """ -@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) -+ - out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - -@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - return out - return out, None - -+ def _get_weight_tensors(self): -+ """Get the weight tensors of the module.""" -+ weight_tensors = super()._get_weight_tensors() -+ -+ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": -+ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) -+ -+ weight_tensors = [ -+ fake_int4_quantization_ste(w, group_size) -+ for w in weight_tensors -+ ] -+ -+ return weight_tensors -+ - def _encode_extra_state(self, state): - # TE 2.0 changed the format of extra_state to be a byte tensor - if is_te_min_version("2.0.0"): -diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -index 1fd5dcfae..c9aeef1f0 100644 ---- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py -+++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( - cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - -- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads -- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads -- mask = kv_off < head_num * stride_kv_nheads -- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] -- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] -- k = tl.load(KV_ptr + k_in_off, mask=mask) -- v = tl.load(KV_ptr + v_in_off, mask=mask) -+ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ k_off = ki_range * stride_kv_nheads + kj_range -+ if v_dim > 0: -+ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ v = tl.load(KV_ptr + v_off, mask=mask_v) -+ else: -+ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) -+ k = tl.load(KV_ptr + k_off, mask=mask_k) - -- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads -- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads -+ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads -+ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads - -- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] -- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] -- tl.store(K_ptr + k_out_off, k, mask=mask) -- tl.store(V_ptr + v_out_off, v, mask=mask) -+ k_out_off = ki_range * stride_k_nheads + kj_range -+ tl.store(K_ptr + k_out_off, k, mask=mask_k) -+ if v_dim > 0: -+ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] -+ tl.store(V_ptr + v_out_off, v, mask=mask_v) - - EMB = K_POS_EMB + pid_m * stride_emb_seq - # x1 = t[..., 0::2], x2 = t[..., 1::2] -@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( - x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - -+ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ mask_x = x_range < head_num - x_left_off = ( -- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads -+ x_range * stride_k_nheads - + k_dim - + tl.arange(0, emb_dim // 2)[None, :] - ) - x_right_off = x_left_off + emb_dim // 2 -- tl.store(K_ptr + x_left_off, x_left, mask=mask) -- tl.store(K_ptr + x_right_off, x_right, mask=mask) -+ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) -+ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) - - - @triton.autotune( -@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( - else: - token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) - -- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads -- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads -- mask = dkv_off < head_num * stride_dkv_nheads -- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] -- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] -- -- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads -- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads -- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] -- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -- dk = tl.load(dK_ptr + dk_in_off, mask=mask) -- dv = tl.load(dV_ptr + dv_in_off, mask=mask) -- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) -- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) -+ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ dk_out_off = ki_range * stride_dkv_nheads + kj_range -+ -+ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads -+ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads -+ dk_in_off = ki_range * stride_dk_nheads + kj_range -+ -+ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) -+ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) -+ -+ if v_dim > 0: -+ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -+ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) -+ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) - - if pid_head == 0: - x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): -- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads -- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim -+ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads -+ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads - mask = x_off < head_num * stride_dk_nheads - x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] - x_right_off = x_left_off + emb_dim // 2 -@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) - o_value = kv.new_empty(total_seqlen, nheads, v_dim) -+ k_dim_ceil = triton.next_power_of_2(k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_fwd_kv_kernel[grid]( -@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - emb_dim, - k_dim, -+ k_dim_ceil, - v_dim, - nheads, - batch_size, -@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) - d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) -+ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_bwd_kv_kernel[grid]( -@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - ctx.emb_dim, - ctx.k_dim, -+ k_dim_ceil, - ctx.v_dim, - nheads, - batch_size, -diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py -index 13d74aa52..060898a7a 100644 ---- a/megatron/core/models/common/language_module/language_module.py -+++ b/megatron/core/models/common/language_module/language_module.py -@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): - assert ( - column_parallel_linear is not None - ), "column_parallel_linear cannot be None when not using fused linear cross entropy." -- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) -+ # output -+ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} -+ output_layer_buffers = dict(column_parallel_linear.named_buffers()) -+ logits, _ = torch.func.functional_call( -+ column_parallel_linear, -+ {**output_layer_params, **output_layer_buffers}, -+ (hidden,), -+ col_linear_kwargs, -+ ) - - return self.compute_language_model_loss(labels, logits) - -diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index e21127b87..712793853 100755 ---- a/megatron/core/models/gpt/gpt_layer_specs.py -+++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( - use_kitchen: bool = False, - use_te_activation_func: bool = False, - fallback_to_eager_attn: bool = False, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). - -@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( - mlp=mlp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - normalization=normalization, -+ post_self_attn_layernorm=post_self_attn_layernorm, -+ post_mlp_layernorm=post_mlp_layernorm, - ) - - -@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( - mlp: ModuleSpec, - sharded_state_dict_keys_map: Optional[dict] = None, - normalization: Optional[str] = None, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Helper function to get module spec for TransformerLayer""" - -@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( - input_layernorm=input_layernorm, - self_attention=attention, - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=pre_mlp_layernorm, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - ), - ) -diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index a1230568c..1fd52f65a 100644 ---- a/megatron/core/models/gpt/gpt_model.py -+++ b/megatron/core/models/gpt/gpt_model.py -@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[Tensor] = None, -+ mtp_kwargs: Optional[dict] = {}, - ) -> Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoder and finally into the post -@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, -+ mtp_kwargs=mtp_kwargs, - ) - - def _postprocess( -@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=None, - extra_block_kwargs=None, - inference_context=None, -+ mtp_kwargs={}, - ): - """Postprocesses decoder hidden states to generate logits or compute loss. - -@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() -- if mtp_in_postprocess: -+ -+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, -@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): - return hidden_states - - # Skip when mtp_num_layers is None or 0 -- if self.config.mtp_num_layers: -- mtp_labels = labels.clone() -+ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: -+ mtp_labels = mtp_kwargs['mtp_labels'].clone() -+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) -+ - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) -+ else: -+ # Otherwise, roll the loss_mask to keep up with the mtp_labels -+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( -@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ -- 'weight': output_weight, -+ 'weight': output_weight.detach() if output_weight else None, - 'runtime_gather_output': runtime_gather_output, - }, - ) -diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index 6e093f96f..eac21a3ea 100644 ---- a/megatron/core/optimizer/distrib_optimizer.py -+++ b/megatron/core/optimizer/distrib_optimizer.py -@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - # TE FusedAdam will not accumulate step for empty param groups, so we need to - # align the step across param groups. - param_group["step"] = int(step) -+ if "step" in param_group and param_group["step"] is None: -+ del param_group["step"] - - # Grad scaler state. - if self.grad_scaler: -@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - if key == 'padding': - tensors[key] = LocalNonpersistentObject(tensors[key]) - continue -+ if key == 'step': -+ continue - assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( - tensors[key].shape, - gbuf_local_start, -diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index a273002b9..4f821cfd5 100644 ---- a/megatron/core/parallel_state.py -+++ b/megatron/core/parallel_state.py -@@ -11,6 +11,7 @@ from typing import Callable, List, Optional - - import numpy as np - import torch -+import torch.distributed as dist - - from .utils import GlobalMemoryBuffer, is_torch_min_version - -diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index ac839c21f..f18309217 100644 ---- a/megatron/core/pipeline_parallel/p2p_communication.py -+++ b/megatron/core/pipeline_parallel/p2p_communication.py -@@ -26,22 +26,22 @@ def _batched_p2p_ops( - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group -+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group -+ torch.distributed.isend, tensor_send_next, next_pipeline_rank, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, - ) - ops.append(recv_next_op) - if len(ops) > 0: -diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py -index 28cff06f5..48c9c1a25 100644 ---- a/megatron/core/transformer/moe/moe_utils.py -+++ b/megatron/core/transformer/moe/moe_utils.py -@@ -587,6 +587,9 @@ def topk_routing_with_score_function( - else: - return torch.topk(scores, k=topk, dim=1) - -+ from miles.utils.routing_replay import get_routing_replay_compute_topk -+ compute_topk = get_routing_replay_compute_topk(compute_topk) -+ - if score_function == "softmax": - if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) -diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py -index 16fc9d9af..3e95858a6 100644 ---- a/megatron/core/transformer/moe/router.py -+++ b/megatron/core/transformer/moe/router.py -@@ -201,6 +201,9 @@ class TopKRouter(Router): - self.global_tokens_per_expert = None - self.ga_steps = None - -+ from miles.utils.routing_replay import register_routing_replay -+ register_routing_replay(self) -+ - def _maintain_float32_expert_bias(self): - """ - Maintain the expert bias in float32. -diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py -index a8f4abfcd..f33f6f05e 100755 ---- a/megatron/core/transformer/multi_token_prediction.py -+++ b/megatron/core/transformer/multi_token_prediction.py -@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union - - import torch - from torch import Tensor -+import warnings - - from megatron.core import InferenceParams, parallel_state, tensor_parallel - from megatron.core.dist_checkpointing.mapping import ShardedStateDict -@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) -- position_ids, _ = roll_tensor( -- position_ids, -- shifts=-1, -- dims=-1, -- cp_group=self.cp_group, -- packed_seq_params=packed_seq_params, -- ) -+ if position_ids is not None: -+ position_ids, _ = roll_tensor( -+ position_ids, -+ shifts=-1, -+ dims=-1, -+ cp_group=self.cp_group, -+ packed_seq_params=packed_seq_params, -+ ) - # embedding - decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) -+ decoder_input = decoder_input.detach() - -- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) -+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) - - return input_ids, position_ids, decoder_input, hidden_states - -@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): - return hidden_states - - def _checkpointed_forward(self, forward_func, *args, **kwargs): -+ """Wrap `forward_func` with activation checkpointing while only passing tensors. -+ -+ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so -+ that checkpoint implementations never receive them directly, avoiding save_for_backward -+ issues with non-tensor inputs. -+ """ -+ -+ # TODO(jiajun): Is there any better implementation here? -+ positional_specs = [] -+ kw_specs = [] -+ tensor_args: List[torch.Tensor] = [] -+ -+ for arg in args: -+ if torch.is_tensor(arg): -+ positional_specs.append(('tensor', len(tensor_args))) -+ tensor_args.append(arg) -+ else: -+ positional_specs.append(('const', arg)) -+ -+ for key, value in kwargs.items(): -+ if torch.is_tensor(value): -+ kw_specs.append((key, ('tensor', len(tensor_args)))) -+ tensor_args.append(value) -+ else: -+ kw_specs.append((key, ('const', value))) -+ -+ def run(*flat_tensor_args): -+ rebuilt_args = [] -+ for spec_type, payload in positional_specs: -+ if spec_type == 'tensor': -+ rebuilt_args.append(flat_tensor_args[payload]) -+ else: -+ rebuilt_args.append(payload) -+ -+ rebuilt_kwargs = {} -+ for key, (spec_type, payload) in kw_specs: -+ if spec_type == 'tensor': -+ rebuilt_kwargs[key] = flat_tensor_args[payload] -+ else: -+ rebuilt_kwargs[key] = payload -+ -+ return forward_func(*rebuilt_args, **rebuilt_kwargs) -+ -+ tensor_args_tuple = tuple(tensor_args) -+ - def checkpoint_handler(): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: -@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), -- *args, -- **kwargs, -+ *tensor_args_tuple, - ) - else: - return tensor_parallel.checkpoint( -- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() -+ run, self.config.distribute_saved_activations, *tensor_args_tuple - ) - - if self.config.recompute_method == 'uniform': -diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index e2705bd9f..a0aa109b5 100644 ---- a/megatron/core/transformer/transformer_config.py -+++ b/megatron/core/transformer/transformer_config.py -@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): - attention_output_gate: bool = False - """Whether to apply output gate to the attention layers.""" - -+ post_self_attn_layernorm: bool = False -+ post_mlp_layernorm: bool = False -+ - test_mode: bool = False - """Whether to run real-time tests.""" - -diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 3ea405770..5a42001b9 100644 ---- a/megatron/core/transformer/transformer_layer.py -+++ b/megatron/core/transformer/transformer_layer.py -@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: - input_layernorm: Union[ModuleSpec, type] = IdentityOp - self_attention: Union[ModuleSpec, type] = IdentityOp - self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - - pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: - pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - mlp: Union[ModuleSpec, type] = IdentityOp - mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - - # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method - sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - # [Module 3: BiasDropoutFusion] - self.self_attn_bda = build_module(submodules.self_attn_bda) - -+ self.post_self_attn_layernorm = build_module( -+ submodules.post_self_attn_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon, -+ ) -+ - # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, -@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - - self.is_moe_layer = isinstance(self.mlp, MoELayer) - -+ self.post_mlp_layernorm = build_module( -+ submodules.post_mlp_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon -+ ) -+ - self.recompute_input_layernorm = False - self.recompute_pre_mlp_layernorm = False - self.recompute_mlp = False -@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - attention_output_with_bias[0] - ) - -+ attention_output, attention_output_bias = attention_output_with_bias -+ attention_output = self.post_self_attn_layernorm(attention_output) -+ attention_output_with_bias = (attention_output, attention_output_bias) -+ - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - nvtx_range_push(suffix="self_attn_bda") -@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - else: - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - -+ mlp_output, mlp_output_bias = mlp_output_with_bias -+ mlp_output = self.post_mlp_layernorm(mlp_output) -+ mlp_output_with_bias = (mlp_output, mlp_output_bias) -+ - if self.recompute_pre_mlp_layernorm: - # discard the output of the pre-mlp layernorm and register the recompute - # as a gradient hook of mlp_output_with_bias[0] -diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index b267c8a81..83736acdc 100644 ---- a/megatron/training/arguments.py -+++ b/megatron/training/arguments.py -@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): - - kw_args['inference_sampling_seed'] = args.seed - -+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm -+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm -+ - # handle quantization config - # NOTE: Kitchen arguments are only added to the namespace when - # Kitchen library is available. -@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') -+ group.add_argument('--post-self-attn-layernorm', action='store_true', -+ help='If set, use post self attention layernorm.') -+ group.add_argument('--post-mlp-layernorm', action='store_true', -+ help='If set, use post MLP layernorm.') -+ group.add_argument('--use-gated-attention', action='store_true', -+ help='If set, use gated attention as in Qwen3Next') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' -diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py -index 13b7526ca..6c590f653 100644 ---- a/megatron/training/tokenizer/tokenizer.py -+++ b/megatron/training/tokenizer/tokenizer.py -@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): - # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there - self._tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path, -- trust_remote_code=trust_remote_code, -+ trust_remote_code=True, - **kwargs, - ) - self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.0rc0/sglang.patch b/docker/amd_patch/sglv0.5.0rc0/sglang.patch deleted file mode 100644 index 990c2e6289..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/sglang.patch +++ /dev/null @@ -1,203 +0,0 @@ -diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index bdb124e51..3edf30ab1 100644 ---- a/python/sglang/srt/configs/model_config.py -+++ b/python/sglang/srt/configs/model_config.py -@@ -454,14 +454,14 @@ class ModelConfig: - ).lower() - - # Detect which checkpoint is it -- for _, method in QUANTIZATION_METHODS.items(): -- quantization_override = method.override_quantization_method( -- quant_cfg, self.quantization -- ) -- if quantization_override: -- quant_method = quantization_override -- self.quantization = quantization_override -- break -+ # for _, method in QUANTIZATION_METHODS.items(): -+ # quantization_override = method.override_quantization_method( -+ # quant_cfg, self.quantization -+ # ) -+ # if quantization_override: -+ # quant_method = quantization_override -+ # self.quantization = quantization_override -+ # break - - # Verify quantization configurations. - if self.quantization is None: -diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 2dd2c75f1..f2adb18f8 100644 ---- a/python/sglang/srt/entrypoints/http_server.py -+++ b/python/sglang/srt/entrypoints/http_server.py -@@ -264,6 +264,10 @@ async def validate_json_request(raw_request: Request): - - - @app.get("/health") -+async def health(request: Request) -> Response: -+ return Response(status_code=200) -+ -+ - @app.get("/health_generate") - async def health_generate(request: Request) -> Response: - """ -diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -index 372717bf9..40665cc90 100644 ---- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -@@ -190,6 +190,7 @@ class DeepEPBuffer: - f"Consider using --deepep-config to change the behavior." - ) - -+ num_qps_per_rank = 20 - cls._buffer = Buffer( - group, - num_nvl_bytes, -diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py -index 956264fc9..69f729336 100644 ---- a/python/sglang/srt/layers/quantization/fp8.py -+++ b/python/sglang/srt/layers/quantization/fp8.py -@@ -351,10 +351,10 @@ class Fp8LinearMethod(LinearMethodBase): - return - else: - weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data -- layer.weight = torch.nn.Parameter(weight, requires_grad=False) -- layer.weight_scale_inv = torch.nn.Parameter( -- weight_scale, requires_grad=False -- ) -+ # layer.weight = torch.nn.Parameter(weight, requires_grad=False) -+ # layer.weight_scale_inv = torch.nn.Parameter( -+ # weight_scale, requires_grad=False -+ # ) - return - - layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) -diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 95a529c89..758fbfd5f 100644 ---- a/python/sglang/srt/managers/scheduler.py -+++ b/python/sglang/srt/managers/scheduler.py -@@ -1359,7 +1359,7 @@ class Scheduler( - - if memory_leak: - msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" -- raise ValueError(msg) -+ # raise ValueError(msg) - - if self.disaggregation_mode == DisaggregationMode.DECODE: - req_total_size = ( -@@ -1374,7 +1374,7 @@ class Scheduler( - f"available_size={len(self.req_to_token_pool.free_slots)}, " - f"total_size={self.req_to_token_pool.size}\n" - ) -- raise ValueError(msg) -+ # raise ValueError(msg) - - if ( - self.enable_metrics -@@ -1830,6 +1830,7 @@ class Scheduler( - deepep_mode=DeepEPMode(self.server_args.deepep_mode), - require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), - disable_overlap_schedule=self.server_args.disable_overlap_schedule, -+ offload_tags=self.offload_tags, - ) - - def handle_dp_balance_data(self, local_batch: ScheduleBatch): -@@ -1927,6 +1928,7 @@ class Scheduler( - deepep_mode: DeepEPMode, - require_mlp_tp_gather: bool, - disable_overlap_schedule: bool, -+ offload_tags: set[str], - ): - # Check if other DP workers have running batches - if local_batch is None: -@@ -1957,7 +1959,7 @@ class Scheduler( - ) - - tbo_preparer = TboDPAttentionPreparer() -- if disable_overlap_schedule: -+ if len(offload_tags) == 0 and disable_overlap_schedule: - group = tp_group.device_group - device = tp_group.device - else: -diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 58220b1d6..3c3d081a8 100644 ---- a/python/sglang/srt/managers/tokenizer_manager.py -+++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -1044,10 +1044,15 @@ class TokenizerManager: - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() -- assert ( -- self.server_args.dp_size == 1 -- ), "dp_size must be 1 for init parameter update group" -- result = (await self.init_weights_update_group_communicator(obj))[0] -+ results = await self.init_weights_update_group_communicator(obj) -+ if self.server_args.dp_size == 1: -+ result = results[0] -+ return result.success, result.message -+ else: -+ all_success = all([r.success for r in results]) -+ all_message = [r.message for r in results] -+ all_message = " | ".join(all_message) -+ return all_success, all_message - return result.success, result.message - - async def update_weights_from_distributed( -@@ -1056,9 +1061,6 @@ class TokenizerManager: - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() -- assert ( -- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention -- ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed" - - if obj.abort_all_requests: - self.abort_request(abort_all=True) -@@ -1066,8 +1068,15 @@ class TokenizerManager: - # This means that weight sync - # cannot run while requests are in progress. - async with self.model_update_lock.writer_lock: -- result = (await self.update_weights_from_distributed_communicator(obj))[0] -- return result.success, result.message -+ results = await self.update_weights_from_distributed_communicator(obj) -+ if self.server_args.dp_size == 1: -+ result = results[0] -+ return result.success, result.message -+ else: -+ all_success = all([r.success for r in results]) -+ all_message = [r.message for r in results] -+ all_message = " | ".join(all_message) -+ return all_success, all_message - - async def update_weights_from_tensor( - self, -diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 5222bff0a..ff0bbc62a 100644 ---- a/python/sglang/srt/model_executor/model_runner.py -+++ b/python/sglang/srt/model_executor/model_runner.py -@@ -22,6 +22,7 @@ import os - import time - from dataclasses import dataclass - from typing import List, Optional, Tuple, Union -+from contextlib import nullcontext - - import torch - import torch.distributed as dist -@@ -675,7 +676,7 @@ class ModelRunner: - monkey_patch_vllm_parallel_state() - monkey_patch_isinstance_for_vllm_base_layer() - -- with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS): -+ with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS) if not self.is_draft_worker else nullcontext(): - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, -diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index e0f0b373d..a18ac10f1 100644 ---- a/python/sglang/srt/models/glm4_moe.py -+++ b/python/sglang/srt/models/glm4_moe.py -@@ -1108,5 +1108,4 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): - ) - weight_loader(param, loaded_weight) - -- - EntryClass = [Glm4MoeForCausalLM] diff --git a/docker/amd_patch/sglv0.5.10/megatron.patch b/docker/amd_patch/sglv0.5.10/megatron.patch new file mode 100644 index 0000000000..acd64149b7 --- /dev/null +++ b/docker/amd_patch/sglv0.5.10/megatron.patch @@ -0,0 +1,20 @@ +diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py +--- a/megatron/legacy/fused_kernels/__init__.py ++++ b/megatron/legacy/fused_kernels/__init__.py +@@ -3,6 +3,7 @@ + import os + import pathlib + import subprocess ++import torch + + from torch.utils import cpp_extension + +@@ -15,6 +16,8 @@ + + + def load(args): ++ if not torch.version.cuda: ++ return + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] diff --git a/docker/amd_patch/sglv0.5.7/megatron.patch b/docker/amd_patch/sglv0.5.7/megatron.patch deleted file mode 100644 index f6efca346d..0000000000 --- a/docker/amd_patch/sglv0.5.7/megatron.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 ---- a/megatron/legacy/fused_kernels/__init__.py -+++ b/megatron/legacy/fused_kernels/__init__.py -@@ -3,6 +3,7 @@ - import os - import pathlib - import subprocess -+import torch - - from torch.utils import cpp_extension - -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - - def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') - -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.7/sglang.patch b/docker/amd_patch/sglv0.5.7/sglang.patch deleted file mode 100644 index b103263070..0000000000 --- a/docker/amd_patch/sglv0.5.7/sglang.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -index 6e7ea07e7..73b512f51 100644 ---- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -@@ -64,6 +64,7 @@ class CustomAllreduce: - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=_MAX_CAR_SIZE, -+ enable_register_for_capturing: bool = True, - ) -> None: - """ - Args: -@@ -410,6 +411,8 @@ class CustomAllreduce: - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - if _is_hip: -+ if self.tms_cudagraph: -+ return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) -diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index c3ca1e4f3..2bb763b6a 100644 ---- a/python/sglang/srt/distributed/parallel_state.py -+++ b/python/sglang/srt/distributed/parallel_state.py -@@ -351,10 +351,12 @@ class GroupCoordinator: - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - try: -+ tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( - group=self.cpu_group, - device=self.device, -+ enable_register_for_capturing=not tms_cudagraph, - ) - except Exception as e: - logger.warning( diff --git a/docker/glm5/Dockerfile.dev-glm b/docker/glm5/Dockerfile.dev-glm new file mode 100644 index 0000000000..4ddbfbfc4b --- /dev/null +++ b/docker/glm5/Dockerfile.dev-glm @@ -0,0 +1,2 @@ +FROM radixark/miles:dev +RUN pip install git+https://github.com/huggingface/transformers.git@76732b4e7120808ff989edbd16401f61fa6a0afa diff --git a/docs/en/advanced/arch-support-beyond-megatron.md b/docs/en/advanced/arch-support-beyond-megatron.md index 0db8c8a40a..4b3e2a02ca 100644 --- a/docs/en/advanced/arch-support-beyond-megatron.md +++ b/docs/en/advanced/arch-support-beyond-megatron.md @@ -27,6 +27,35 @@ miles leverages this mechanism by **hijacking the spec generation stage to repla Through the coordination of these three components, we can successfully run a complex model architecture not natively supported by Megatron—using its HuggingFace implementation as the vehicle—on top of Megatron's parallel framework. This is achieved while fully retaining all key capabilities like model parallelism, MoE acceleration, and pipeline scheduling. +## Mixed-Precision: Preserving fp32 Parameters in bf16 Models + +Some model architectures require specific parameters to remain in fp32 even when the rest of the model runs in bf16. For example, Qwen3.5's `A_log` parameter must stay fp32 — if rounded to bf16, Megatron-side activations diverge from sglang's fp32 `A_log` on the rollout side, causing precision drift. + +Megatron's training stack has **three implicit cast points** that silently round fp32 parameters to bf16: `Float16Module` construction, `Bridge._weight_to_mcore_format`, and `Bridge.load_weights`. Both steps below are required — doing only one leaves a silent precision trap where the final dtype *looks* correct (fp32) but values were already rounded to bf16 precision. + +### Step 1: Mark the parameter in your model definition + +```python +from miles.backends.megatron_utils.fp32_param_utils import mark_param_dtype + +# In your model's __init__: +self.A_log = nn.Parameter(torch.log(A).to(torch.float32)) +mark_param_dtype(self.A_log, torch.float32) +``` + +`enforce_marked_param_dtypes(model)` — already wired into training and checkpoint conversion entry points — restores tagged params to fp32 after `Float16Module` casts the entire model to bf16. + +### Step 2: Override the Bridge to bypass bf16 pre-cast during weight loading + +```python +class Qwen3_5Bridge(Qwen2MoEBridge): + def _weight_to_mcore_format(self, mcore_weights_name, hf_weights): + if mcore_weights_name.endswith("self_attention.linear_attn.A_log"): + assert len(hf_weights) == 1 + return hf_weights[0].to(dtype=torch.float32).contiguous() + return super()._weight_to_mcore_format(mcore_weights_name, hf_weights) +``` + ## Current Limitations * This approach does not currently support Tensor Parallelism (TP) within the replaced module itself (e.g., the Attention layer in this case). diff --git a/docs/en/get_started/qa.md b/docs/en/get_started/qa.md index c9e8dad21f..f6d5beea11 100644 --- a/docs/en/get_started/qa.md +++ b/docs/en/get_started/qa.md @@ -65,4 +65,8 @@ 13. **Gradient becomes NaN or Inf during training.** - You can try setting the `--no-check-for-nan-in-loss-and-grad` flag to skip the corresponding training steps. \ No newline at end of file + You can try setting the `--no-check-for-nan-in-loss-and-grad` flag to skip the corresponding training steps. + +14. **NCCL error: `Failed to bind NVLink SHARP (NVLS) Multicast memory ... CUDA error 2 'out of memory'`.** + + This issue has been observed on H100 in colocate mode with piece-wise CUDA graph enabled. Piece-wise CUDA graph is now disabled by default in colocate mode. If you encounter this after explicitly enabling it via `--sglang-enforce-piecewise-cuda-graph`, remove that flag. diff --git a/examples/experimental/swe-agent-v2/README.md b/examples/experimental/swe-agent-v2/README.md index 4190c80292..1c8aebc54b 100644 --- a/examples/experimental/swe-agent-v2/README.md +++ b/examples/experimental/swe-agent-v2/README.md @@ -50,7 +50,6 @@ Docker Network (swe-net) | `swe_agent_function.py` | Custom agent function — dispatches to Harbor server, returns env metadata | | `generate.py` | Reward function, agent metrics aggregation, `RolloutFn` | | `download_and_process_data.py` | Download from HuggingFace or local JSONL, convert to Miles format | -| `prepare_harbor_tasks.py` | Convert Miles JSONL to Harbor task directories (generic fallback) | ## Step-by-Step Setup @@ -59,7 +58,6 @@ Docker Network (swe-net) - Docker with GPU support (nvidia-container-toolkit) - Model weights downloaded (e.g. `zai-org/GLM-4.7-Flash`) - `transformers>=5` (`pip install "transformers>=5"` — GLM-4.7-Flash's `glm4_moe_lite` model type is not in transformers 4.x) -- Recommended transformer version: `pip install git+https://github.com/huggingface/transformers.git@76732b4e7120808ff989edbd16401f61fa6a0afa` - Harbor task directories prepared under a shared path ### Step 1: Create Docker network @@ -117,6 +115,20 @@ pip install harbor ### Step 4: Prepare data and Harbor task directories +Harbor task directories are prepared on the agent server side using **harbor adapters**. Each adapter converts a specific dataset into Harbor's 4-file task format. For example, to prepare SWE-bench tasks: + +```bash +# On the agent server (CPU machine), inside the harbor repo: +cd $CWD/harbor/adapters/swebench && uv sync + +# Generate Harbor task directories for all SWE-bench Verified instances +uv run run_adapter.py --task-dir $HARBOR_TASKS_DIR --all +``` + +This uses the `swebench` Python package to produce correct Docker image names and Dockerfiles for each instance. Other adapters (e.g. `adapters/swe-gym`) follow the same pattern. + +To prepare training data on the Miles side: + ```bash # Inside miles container: @@ -127,10 +139,6 @@ python download_and_process_data.py --input /data/tb.jsonl --output tb.jsonl \ # Merge into one mixed JSONL cat swe.jsonl tb.jsonl > mixed.jsonl - -# Create Harbor task dirs (for custom data without a Harbor adapter) -python prepare_harbor_tasks.py --input my.jsonl --output /root/harbor_tasks/ \ - --docker-network swe-net ``` Each Harbor task directory contains 4 files: @@ -294,7 +302,7 @@ Agent containers need to resolve the Miles container's hostname. Ensure: ### `TaskNotFound` error -The task directory for the given `instance_id` doesn't exist under `HARBOR_TASKS_DIR`. Run the appropriate Harbor adapter or `prepare_harbor_tasks.py` first. +The task directory for the given `instance_id` doesn't exist under `HARBOR_TASKS_DIR`. Run the appropriate harbor adapter first (e.g. `adapters/swebench/run_adapter.py` for SWE-bench tasks). ### SGLang engines OOM (`Not enough memory`) diff --git a/examples/experimental/swe-agent-v2/prepare_harbor_tasks.py b/examples/experimental/swe-agent-v2/prepare_harbor_tasks.py deleted file mode 100644 index ec452d980b..0000000000 --- a/examples/experimental/swe-agent-v2/prepare_harbor_tasks.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Convert training data to Harbor task directories (generic fallback). - -Reads a Miles JSONL (produced by ``download_and_process_data.py``) and -creates one Harbor task directory per instance. Each task directory is -self-contained — Harbor treats all tasks identically regardless of -their origin (SWE-bench, Terminal-Bench, custom, etc.). - -For standard benchmarks, prefer using Harbor's official adapters or -``harbor run -d `` to generate task directories — they -produce the exact grading harness used upstream. This script is a -generic fallback for custom datasets. - -Usage: - - python prepare_harbor_tasks.py \\ - --input /root/custom_train.jsonl \\ - --output /root/harbor_tasks/ \\ - --docker-network swe-net - -Required metadata fields per record: - - instance_id: unique task identifier (becomes directory name) - -Optional metadata fields (read if present): - - problem_statement / instruction / prompt: task text -> instruction.md - - docker_image: base Docker image (default: ubuntu:24.04) - - setup_commands: extra Dockerfile RUN commands (str or list) - - test_script: content of tests/test.sh - - timeout: verifier timeout in seconds (default: 1800) - - repo, version: included in task.toml if present - - patch: oracle solution -> solution/solve.sh -""" - -import argparse -import json -import logging -import os -import textwrap -from pathlib import Path - -logger = logging.getLogger(__name__) - - -def _get_instruction(metadata: dict) -> str: - for key in ("problem_statement", "instruction", "prompt"): - val = metadata.get(key, "") - if val: - return val - return "" - - -def _create_task_dir( - instance_id: str, - metadata: dict, - output_dir: Path, - docker_network: str | None = None, -) -> Path: - """Create a Harbor task directory from metadata.""" - task_dir = output_dir / instance_id - task_dir.mkdir(parents=True, exist_ok=True) - - (task_dir / "instruction.md").write_text(_get_instruction(metadata)) - - repo = metadata.get("repo", "") - version = metadata.get("version", "") - timeout = metadata.get("timeout", 1800) - - toml_lines = [ - "[task]", - f'name = "{instance_id}"', - ] - if repo: - toml_lines.append(f'repo = "{repo}"') - if version: - toml_lines.append(f'version = "{version}"') - toml_lines += [ - "", - "[limits]", - f"timeout = {timeout}", - ] - (task_dir / "task.toml").write_text("\n".join(toml_lines) + "\n") - - env_dir = task_dir / "environment" - env_dir.mkdir(exist_ok=True) - - docker_image = metadata.get("docker_image", "ubuntu:24.04") - setup_cmds = metadata.get("setup_commands", "") - if isinstance(setup_cmds, list): - setup_cmds = " && ".join(setup_cmds) - setup_block = f"RUN {setup_cmds}\n" if setup_cmds else "" - - (env_dir / "Dockerfile").write_text(f"FROM {docker_image}\n{setup_block}") - - if docker_network: - compose_yaml = textwrap.dedent( - f"""\ - services: - main: - networks: - - {docker_network} - networks: - {docker_network}: - external: true - """ - ) - (env_dir / "docker-compose.yaml").write_text(compose_yaml) - - tests_dir = task_dir / "tests" - tests_dir.mkdir(exist_ok=True) - - test_script = metadata.get("test_script", "") - if test_script: - test_sh = f"#!/bin/bash\n{test_script}\n" - else: - test_sh = textwrap.dedent( - """\ - #!/bin/bash - echo 0 > /logs/verifier/reward.txt - """ - ) - - (tests_dir / "test.sh").write_text(test_sh) - os.chmod(tests_dir / "test.sh", 0o755) - - patch = metadata.get("patch", "") - if patch: - sol_dir = task_dir / "solution" - sol_dir.mkdir(exist_ok=True) - (sol_dir / "fix.patch").write_text(patch) - solve_sh = textwrap.dedent( - """\ - #!/bin/bash - git apply "$(dirname "$0")/fix.patch" - """ - ) - (sol_dir / "solve.sh").write_text(solve_sh) - os.chmod(sol_dir / "solve.sh", 0o755) - - return task_dir - - -def convert( - input_path: str, - output_dir: str, - docker_network: str | None = None, -) -> int: - """Convert all instances from JSONL to Harbor task directories. - - Returns the number of tasks created. - """ - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - - records: list[dict] = [] - - with open(input_path) as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - except json.JSONDecodeError as e: - logger.warning(f"Skipping line {line_num}: {e}") - continue - - metadata = data.get("metadata", data) - instance_id = metadata.get("instance_id", "") - if not instance_id: - logger.warning(f"Skipping line {line_num}: no instance_id") - continue - - records.append(metadata) - - if not records: - logger.warning("No valid records found") - return 0 - - count = 0 - for metadata in records: - instance_id = metadata["instance_id"] - _create_task_dir( - instance_id, - metadata, - output_path, - docker_network=docker_network, - ) - count += 1 - if count % 100 == 0: - logger.info(f"Created {count} task directories...") - - logger.info(f"Created {count} task directories in {output_dir}") - return count - - -def main(): - parser = argparse.ArgumentParser( - description="Convert training JSONL to Harbor task directories", - ) - parser.add_argument( - "--input", - required=True, - help="Path to training JSONL", - ) - parser.add_argument( - "--output", - required=True, - help="Output directory for Harbor tasks", - ) - parser.add_argument( - "--docker-network", - default=None, - help="External Docker network for containers to join " "(e.g. swe-net)", - ) - args = parser.parse_args() - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(name)s %(levelname)s %(message)s", - ) - convert(args.input, args.output, docker_network=args.docker_network) - - -if __name__ == "__main__": - main() diff --git a/examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py b/examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py new file mode 100644 index 0000000000..104e2ad2c7 --- /dev/null +++ b/examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py @@ -0,0 +1,354 @@ +"""GLM-4.7 Full (355B-A32B) fully-async reasoning training with GSM8K data. + +Disaggregated fully-async variant of run-glm47-reasoning.py: training and +rollout run on separate nodes concurrently. Uses train_async.py and the +fully_async_rollout module so that weight updates do not block generation. + +Default split: 4 nodes training + 12 nodes inference (configurable via +--train-num-nodes). Same model architecture as GLM-4.5-355B-A32B. +Targets 16 x 8-GPU H200 nodes. + +Usage: + python run-glm47-reasoning-async.py --num-nodes 16 + python run-glm47-reasoning-async.py --num-nodes 16 --train-num-nodes 8 + python run-glm47-reasoning-async.py --num-nodes 16 --rollout-fp8 + python run-glm47-reasoning-async.py --num-nodes 16 --pause-generation-mode retract --update-weight-transfer-mode p2p + python run-glm47-reasoning-async.py --num-nodes 16 --skip-prepare +""" + +import os +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +SCRIPT_DIR = Path(__file__).resolve().parent +FULLY_ASYNC_DIR = (Path(__file__).resolve().parent.parent.parent / "fully_async").resolve() + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_rollout_only"] = "normal" + run_id: str = U.create_run_id() + megatron_model_type: str = "glm4.5-355B-A32B" + num_gpus_per_node: int = 8 + megatron_path: str = "/root/Megatron-LM" + + # Paths + skip_prepare: bool = False + model_name: str = "GLM-4.7" + hf_checkpoint: str = "/models/zai-org/GLM-4.7" + ref_load: str = "/models/zai-org/GLM-4.7_torch_dist" + save_dir: str = "/root/GLM-4.7-Full_reasoning_async/" + prompt_data: str = "/root/datasets/gsm8k/train.parquet" + rollout_max_response_len: int = 1024 + + # Rollout precision + rollout_fp8: bool = False + rollout_health_check_first_wait: int = 1800 + + # Disaggregated fully-async settings + train_num_nodes: int = 4 + pause_generation_mode: Literal["in_place", "retract"] = "in_place" + update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast" + accumulate_allreduce_grads_in_fp32: bool = False + max_tokens_per_gpu: int = 2048 + optimizer_cpu_offload: bool = True + use_precision_aware_optimizer: bool = True + + # W&B settings + wandb_key: str = os.environ.get("WANDB_KEY", os.environ.get("WANDB_API_KEY", "")) + wandb_project: str = os.environ.get("WANDB_PROJECT", "glm47-full-reasoning-async") + wandb_team: str = os.environ.get("WANDB_TEAM", "") + wandb_run_name: str = "glm47-full-gsm8k-async" + + # Prometheus settings + use_prometheus: bool = True + prometheus_port: int = 9090 + prometheus_run_name: str = "glm47-full-gsm8k-async" + + +def cleanup(): + """Kill old Ray jobs and stale processes to free GPU resources.""" + my_pid = os.getpid() + ppid = os.getppid() + print(f"Cleanup starting (pid={my_pid}, ppid={ppid})") + targets = ["sglang", "train.py", "train_async.py", "MegatronTrain"] + exclude = f"grep -v '^{my_pid}$' | grep -v '^{ppid}$'" + for t in targets: + subprocess.run( + f"pgrep -f '{t}' | {exclude} | xargs -r kill 2>/dev/null || true", + shell=True, + ) + time.sleep(5) + print(f"Cleanup complete (pid={my_pid}) — old processes killed.") + + +def _convert_hf_to_fp8(args: ScriptArgs): + """Convert HF bf16 checkpoint to block-wise FP8 for SGLang rollout.""" + fp8_dir = f"{args.hf_checkpoint}-FP8" + if Path(fp8_dir).exists(): + print(f"FP8 checkpoint already exists at {fp8_dir}, skipping conversion.") + return + U.exec_command( + "python tools/convert_hf_to_fp8.py " + f"--model-dir {args.hf_checkpoint} " + f"--save-dir {fp8_dir} " + "--strategy block --block-size 128 128 " + "--max-workers 4" + ) + + +def prepare(args: ScriptArgs): + """Download GSM8K data and convert HF checkpoint to torch_dist format.""" + U.hf_download_dataset("zhuzilin/gsm8k") + + max_convert_nodes = 92 // args.num_gpus_per_node + convert_nodes = min(args.num_nodes, max_convert_nodes) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + multinode=True, + num_nodes=convert_nodes, + dir_dst=str(Path(args.ref_load).parent), + hf_checkpoint=args.hf_checkpoint, + megatron_path=args.megatron_path, + ) + + if args.rollout_fp8: + _convert_hf_to_fp8(args) + + +def execute(args: ScriptArgs): + if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p": + raise ValueError( + "in_place + p2p is not supported: P2P transfer engine conflicts with " + "active NCCL inference. Use broadcast with in_place, or retract with p2p." + ) + + hf_checkpoint = f"{args.hf_checkpoint}-FP8" if args.rollout_fp8 else args.hf_checkpoint + ckpt_args = ( + f"--hf-checkpoint {hf_checkpoint} " + f"--ref-load {args.ref_load} " + f"--save {args.save_dir} " + "--save-interval 100 " + ) + + rollout_args = ( + "--rollout-function-path fully_async_rollout.generate_rollout_fully_async " + f"--prompt-data {args.prompt_data} " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 4 " + "--rollout-temperature 0.8 " + f"--rollout-max-response-len {args.rollout_max_response_len} " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 64 " + "--balance-data " + f"--pause-generation-mode {args.pause_generation_mode} " + ) + + eval_args = ( + # "--eval-interval 20 " + # "--skip-eval-before-train " + # "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + # "--n-samples-per-eval-prompt 1 " + # "--eval-max-response-len 1024 " + # "--eval-top-k 1 " + ) + + # Disaggregated split: training on train_num_nodes, inference on the rest. + rollout_num_nodes = args.num_nodes - args.train_num_nodes + assert rollout_num_nodes > 0, ( + f"train_num_nodes ({args.train_num_nodes}) must be less than " + f"num_nodes ({args.num_nodes}) to leave room for inference" + ) + train_gpus = args.train_num_nodes * args.num_gpus_per_node + rollout_gpus = rollout_num_nodes * args.num_gpus_per_node + print( + f"Disagg split: {args.train_num_nodes} nodes ({train_gpus} GPUs) training, " + f"{rollout_num_nodes} nodes ({rollout_gpus} GPUs) inference" + ) + + # Training parallelism: TP=4, PP=2, EP chosen as largest divisor of 160 that fits. + tp, pp = 4, 2 + dp = train_gpus // (tp * pp) + assert train_gpus % (tp * pp) == 0, f"train GPUs ({train_gpus}) must be divisible by TP*PP ({tp * pp})" + num_experts = 160 + ep = max(d for d in range(1, dp + 1) if num_experts % d == 0 and dp % d == 0) + + perf_args = ( + f"--tensor-model-parallel-size {tp} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {pp} " + "--context-parallel-size 1 " + f"--expert-model-parallel-size {ep} " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {args.max_tokens_per_gpu} " + ) + if args.optimizer_cpu_offload: + perf_args += "--optimizer-cpu-offload --overlap-cpu-optimizer-d2h-h2d " + if args.use_precision_aware_optimizer: + perf_args += "--use-precision-aware-optimizer " + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.01 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.0 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + # SGLang: 4 nodes/engine with full EP + DP-attention on dedicated rollout nodes. + # 355B across 32 GPUs → ~22GB/GPU (bf16) or ~11GB/GPU (FP8) for weights. + # EP=32 with 160 experts → 5 experts/GPU. DP-attention keeps attention + # within a single node (attn_tp=8). + sglang_nodes_per_engine = min(4, rollout_num_nodes) + sglang_world_size = sglang_nodes_per_engine * args.num_gpus_per_node + num_engines = rollout_num_nodes // sglang_nodes_per_engine + assert rollout_num_nodes % sglang_nodes_per_engine == 0, ( + f"rollout nodes ({rollout_num_nodes}) must be divisible by " + f"sglang_nodes_per_engine ({sglang_nodes_per_engine})" + ) + print(f"Inference: {num_engines} engines x {sglang_world_size} GPUs/engine") + sglang_decode_max_bs = 256 + sglang_attn_tp_size = min(args.num_gpus_per_node, sglang_world_size) + sglang_attn_dp_size = sglang_world_size // sglang_attn_tp_size + + sglang_p2p_extra = "" + if args.update_weight_transfer_mode == "p2p": + sglang_p2p_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine " + + sglang_args = ( + f"--rollout-num-gpus-per-engine {sglang_world_size} " + "--sglang-mem-fraction-static 0.80 " + f"--sglang-tp-size {sglang_world_size} " + f"--sglang-ep-size {sglang_world_size} " + "--sglang-enable-dp-attention " + f"--sglang-dp-size {sglang_attn_dp_size} " + "--sglang-moe-dense-tp-size 1 " + "--sglang-enable-dp-lm-head " + "--sglang-moe-a2a-backend deepep " + "--sglang-deepep-mode low_latency " + f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " + f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " + f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " + f"{sglang_p2p_extra}" + ) + if args.rollout_fp8: + sglang_args += "--sglang-moe-runner-backend deep_gemm " + sglang_extra_env_vars = { + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": f"{sglang_decode_max_bs}", + } + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + f"--update-weight-transfer-mode {args.update_weight_transfer_mode} " + f"--update-weight-buffer-size {2 * 1024 ** 3} " + f"--actor-num-nodes {args.train_num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {rollout_gpus} " + "--grad-reduce-in-bf16 " + "--use-fault-tolerance " + f"--rollout-health-check-first-wait {args.rollout_health_check_first_wait} " + ) + if args.accumulate_allreduce_grads_in_fp32: + misc_args += "--accumulate-allreduce-grads-in-fp32 " + + debug_args = "--debug-rollout-only " if args.mode == "debug_rollout_only" else "" + + wandb_args = "" + if args.wandb_key: + wandb_args = ( + "--use-wandb " + f"--wandb-project {args.wandb_project} " + f"--wandb-group {args.wandb_run_name} " + f"--wandb-key {args.wandb_key} " + ) + if args.wandb_team: + wandb_args += f"--wandb-team {args.wandb_team} " + + prometheus_args = "" + if args.use_prometheus: + prometheus_args = ( + "--use-prometheus " + f"--prometheus-port {args.prometheus_port} " + f"--prometheus-run-name {args.prometheus_run_name} " + ) + + train_args = ( + f"{ckpt_args}" + f"{rollout_args}" + f"{eval_args}" + f"{optimizer_args}" + f"{grpo_args}" + f"{wandb_args}" + f"{prometheus_args}" + f"{perf_args}" + f"{sglang_args}" + f"{misc_args}" + f"{debug_args}" + ) + + miles_root = U.repo_base_dir + + extra_env_vars = { + "PYTHONPATH": f"{args.megatron_path}:{SCRIPT_DIR}:{FULLY_ASYNC_DIR}:{miles_root}", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "NCCL_NVLS_ENABLE": "0", + "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "false", + **sglang_extra_env_vars, + } + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + train_script="train_async.py", + megatron_path=args.megatron_path, + extra_env_vars=extra_env_vars, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + cleanup() + if not args.skip_prepare: + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/experimental/swe-agent-v2/run-glm47-reasoning.py b/examples/experimental/swe-agent-v2/run-glm47-reasoning.py new file mode 100644 index 0000000000..e133f53570 --- /dev/null +++ b/examples/experimental/swe-agent-v2/run-glm47-reasoning.py @@ -0,0 +1,308 @@ +"""GLM-4.7 Full (355B-A32B) reasoning training with GSM8K data. + +Debug script: uses math (GSM8K) data instead of agentic tool use to verify +that the training pipeline produces nonzero rewards and learns successfully. + +Same model architecture and parallelism as run-glm47-full.py. +Targets 16 x 8-GPU H200 nodes (sci-h200). + +Usage: + python run-glm47-reasoning.py --num-nodes 16 + python run-glm47-reasoning.py --num-nodes 16 --rollout-fp8 + python run-glm47-reasoning.py --num-nodes 16 --skip-prepare + python run-glm47-reasoning.py --num-nodes 16 --mode debug_rollout_only +""" + +import os +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +SCRIPT_DIR = Path(__file__).resolve().parent + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_rollout_only"] = "normal" + run_id: str = U.create_run_id() + megatron_model_type: str = "glm4.5-355B-A32B" + num_gpus_per_node: int = 8 + megatron_path: str = "/root/Megatron-LM" + + # Paths + skip_prepare: bool = False + model_name: str = "GLM-4.7" + hf_checkpoint: str = "/models/zai-org/GLM-4.7" + ref_load: str = "/models/zai-org/GLM-4.7_torch_dist" + save_dir: str = "/root/GLM-4.7-Full_reasoning/" + prompt_data: str = "/root/datasets/gsm8k/train.parquet" + rollout_max_response_len: int = 1024 + + # Rollout precision + rollout_fp8: bool = False + + # W&B settings + wandb_key: str = os.environ.get("WANDB_KEY", os.environ.get("WANDB_API_KEY", "")) + wandb_project: str = os.environ.get("WANDB_PROJECT", "glm47-full-reasoning") + wandb_team: str = os.environ.get("WANDB_TEAM", "") + wandb_run_name: str = "glm47-full-gsm8k" + + # Prometheus settings + use_prometheus: bool = True + prometheus_port: int = 9090 + prometheus_run_name: str = "glm47-full-gsm8k" + + +def cleanup(): + """Kill old Ray jobs and stale processes to free GPU resources.""" + my_pid = os.getpid() + ppid = os.getppid() + print(f"Cleanup starting (pid={my_pid}, ppid={ppid})") + targets = ["sglang", "train.py", "MegatronTrain"] + exclude = f"grep -v '^{my_pid}$' | grep -v '^{ppid}$'" + for t in targets: + subprocess.run( + f"pgrep -f '{t}' | {exclude} | xargs -r kill 2>/dev/null || true", + shell=True, + ) + time.sleep(5) + print(f"Cleanup complete (pid={my_pid}) — old processes killed.") + + +def _convert_hf_to_fp8(args: ScriptArgs): + """Convert HF bf16 checkpoint to block-wise FP8 for SGLang rollout.""" + fp8_dir = f"{args.hf_checkpoint}-FP8" + if Path(fp8_dir).exists(): + print(f"FP8 checkpoint already exists at {fp8_dir}, skipping conversion.") + return + U.exec_command( + "python tools/convert_hf_to_fp8.py " + f"--model-dir {args.hf_checkpoint} " + f"--save-dir {fp8_dir} " + "--strategy block --block-size 128 128 " + "--max-workers 4" + ) + + +def prepare(args: ScriptArgs): + """Download GSM8K data and convert HF checkpoint to torch_dist format.""" + # Download GSM8K dataset + U.hf_download_dataset("zhuzilin/gsm8k") + + # Convert checkpoint (multinode for 355B) + # The conversion tool requires world_size <= num_layers (92 for this model). + max_convert_nodes = 92 // args.num_gpus_per_node # 11 for 8 GPUs/node + convert_nodes = min(args.num_nodes, max_convert_nodes) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + multinode=True, + num_nodes=convert_nodes, + dir_dst=str(Path(args.ref_load).parent), + hf_checkpoint=args.hf_checkpoint, + megatron_path=args.megatron_path, + ) + + if args.rollout_fp8: + _convert_hf_to_fp8(args) + + +def execute(args: ScriptArgs): + hf_checkpoint = f"{args.hf_checkpoint}-FP8" if args.rollout_fp8 else args.hf_checkpoint + ckpt_args = ( + f"--hf-checkpoint {hf_checkpoint} " + f"--ref-load {args.ref_load} " + f"--save {args.save_dir} " + "--save-interval 100 " + ) + + rollout_args = ( + f"--prompt-data {args.prompt_data} " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 4 " + "--rollout-temperature 0.8 " + f"--rollout-max-response-len {args.rollout_max_response_len} " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 64 " + ) + + eval_args = ( + "--eval-interval 20 " + "--skip-eval-before-train " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + # Training parallelism: TP=4, PP=2, EP chosen as largest divisor of 160 that fits. + tp, pp = 4, 2 + total_gpus = args.num_nodes * args.num_gpus_per_node + dp = total_gpus // (tp * pp) + assert total_gpus % (tp * pp) == 0, f"total GPUs ({total_gpus}) must be divisible by TP*PP ({tp * pp})" + num_experts = 160 + ep = max(d for d in range(1, dp + 1) if num_experts % d == 0) + + perf_args = ( + f"--tensor-model-parallel-size {tp} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {pp} " + "--context-parallel-size 1 " + f"--expert-model-parallel-size {ep} " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 2048 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.01 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.0 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + # SGLang: 4 nodes/engine with full EP + DP-attention. + # 355B across 32 GPUs → ~22GB/GPU (bf16) or ~11GB/GPU (FP8) for weights, + # leaving plenty for KV cache. EP=32 with 160 experts → 5 experts/GPU. + # DP-attention keeps attention within a single node (attn_tp=8). + sglang_nodes_per_engine = min(4, args.num_nodes) + sglang_world_size = sglang_nodes_per_engine * args.num_gpus_per_node + assert ( + total_gpus % sglang_world_size == 0 + ), f"total GPUs ({total_gpus}) must be divisible by sglang_world_size ({sglang_world_size})" + sglang_decode_max_bs = 256 + sglang_attn_tp_size = min(args.num_gpus_per_node, sglang_world_size) + sglang_attn_dp_size = sglang_world_size // sglang_attn_tp_size + sglang_args = ( + f"--rollout-num-gpus-per-engine {sglang_world_size} " + "--sglang-mem-fraction-static 0.80 " + f"--sglang-tp-size {sglang_world_size} " + f"--sglang-ep-size {sglang_world_size} " + "--sglang-enable-dp-attention " + f"--sglang-dp-size {sglang_attn_dp_size} " + "--sglang-moe-dense-tp-size 1 " + "--sglang-enable-dp-lm-head " + "--sglang-moe-a2a-backend deepep " + "--sglang-deepep-mode low_latency " + f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " + f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " + f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " + ) + if args.rollout_fp8: + sglang_args += "--sglang-moe-runner-backend deep_gemm " + sglang_extra_env_vars = { + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": f"{sglang_decode_max_bs}", + } + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--colocate " + f"--update-weight-buffer-size {2 * 1024 ** 3} " + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {total_gpus} " + "--use-fault-tolerance " + ) + + debug_args = "--debug-rollout-only " if args.mode == "debug_rollout_only" else "" + + wandb_args = "" + if args.wandb_key: + wandb_args = ( + "--use-wandb " + f"--wandb-project {args.wandb_project} " + f"--wandb-group {args.wandb_run_name} " + f"--wandb-key {args.wandb_key} " + ) + if args.wandb_team: + wandb_args += f"--wandb-team {args.wandb_team} " + + prometheus_args = "" + if args.use_prometheus: + prometheus_args = ( + "--use-prometheus " + f"--prometheus-port {args.prometheus_port} " + f"--prometheus-run-name {args.prometheus_run_name} " + ) + + train_args = ( + f"{ckpt_args}" + f"{rollout_args}" + f"{eval_args}" + f"{optimizer_args}" + f"{grpo_args}" + f"{wandb_args}" + f"{prometheus_args}" + f"{perf_args}" + f"{sglang_args}" + f"{misc_args}" + f"{debug_args}" + ) + + miles_root = U.repo_base_dir + + extra_env_vars = { + "PYTHONPATH": f"{args.megatron_path}:{SCRIPT_DIR}:{miles_root}", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "NCCL_NVLS_ENABLE": "0", + "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "false", + **sglang_extra_env_vars, + } + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + megatron_path=args.megatron_path, + extra_env_vars=extra_env_vars, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + cleanup() + if not args.skip_prepare: + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/experimental/swe-agent-v2/run-qwen3.sh b/examples/experimental/swe-agent-v2/run-qwen3.sh deleted file mode 100755 index ac0d8bf863..0000000000 --- a/examples/experimental/swe-agent-v2/run-qwen3.sh +++ /dev/null @@ -1,166 +0,0 @@ -#!/bin/bash -# Agent V2 launcher (Qwen3-4B): Miles <-> Harbor agent orchestration. -# -# Supports any task type (SWE-bench, Terminal-Bench, custom) via Harbor. - -pkill -9 sglang 2>/dev/null || true -sleep 3 -ray stop --force 2>/dev/null || true -pkill -9 ray 2>/dev/null || true -pkill -9 python 2>/dev/null || true -sleep 3 -pkill -9 ray 2>/dev/null || true -pkill -9 python 2>/dev/null || true -sleep 3 - -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -MILES_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" - -source "$MILES_ROOT/scripts/models/qwen3-4B.sh" - -BASE_DIR=/root/shared -AGENT_SERVER_URL="${AGENT_SERVER_URL:-${SWE_AGENT_URL:-http://agent_env:11000}}" -HARBOR_TASKS_DIR="${HARBOR_TASKS_DIR:-/root/harbor_tasks}" -ROUTER_EXTERNAL_HOST="${MILES_ROUTER_EXTERNAL_HOST:-$(hostname)}" - -CKPT_ARGS=( - --hf-checkpoint $BASE_DIR/Qwen3-4B - --ref-load $BASE_DIR/Qwen3-4B_torch_dist - --save $BASE_DIR/Qwen3-4B_agent_V2/ - --save-interval 100 -) - -ROLLOUT_ARGS=( - --prompt-data /root/swe_train.jsonl - --input-key prompt - --metadata-key metadata - --rollout-shuffle - - --num-rollout 3000 - --rollout-batch-size 1 - --n-samples-per-prompt 1 - --rollout-temperature 0.8 - --rollout-max-response-len 8192 - --global-batch-size 1 - --balance-data -) - -PERF_ARGS=( - --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 2048 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.01 - --kl-loss-type low_var_kl - --entropy-coef 0.0 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.8 - --sglang-tool-call-parser qwen25 - --sglang-reasoning-parser qwen3 - - --use-miles-router - --sglang-router-port 30000 -) - -AGENT_ARGS=( - --custom-generate-function-path miles.rollout.generate_hub.agentic_tool_call.generate - --custom-agent-function-path swe_agent_function.run - --custom-rm-path generate.reward_func - --rollout-function-path generate.RolloutFn - --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_no_aborted - --tito-model qwen3 - --chat-template-path autofix - --use-session-server -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-agent-v2 - # --wandb-group agent-v2 -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash -) - -DEBUG_ARGS=( - --debug-rollout-only -) - -# ── Start Ray ──────────────────────────────────────────────────────── -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head \ - --node-ip-address "$MASTER_ADDR" \ - --num-gpus 1 \ - --disable-usage-stats \ - --dashboard-host=0.0.0.0 \ - --dashboard-port=8265 \ - --port=8899 - -RUNTIME_ENV=$(python3 -c " -import json, sys -print(json.dumps({'env_vars': { - 'PYTHONPATH': '/root/Megatron-LM/:${SCRIPT_DIR}:${MILES_ROOT}', - 'CUDA_DEVICE_MAX_CONNECTIONS': '1', - 'MILES_EXPERIMENTAL_ROLLOUT_REFACTOR': '1', - 'AGENT_SERVER_URL': '${AGENT_SERVER_URL}', - 'AGENT_MODEL_NAME': '${AGENT_MODEL_NAME:-model}', - 'MILES_ROUTER_EXTERNAL_HOST': '${ROUTER_EXTERNAL_HOST}', - 'HARBOR_TASKS_DIR': '${HARBOR_TASKS_DIR}', - 'MILES_HOST_IP': '${MILES_HOST_IP:-$(hostname)}', - 'NCCL_NVLS_ENABLE': '0', -}})) -") - -ray job submit \ - --address="http://127.0.0.1:8265" \ - --runtime-env-json="$RUNTIME_ENV" \ - -- python3 "$MILES_ROOT/train.py" \ - --colocate \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 1 \ - --rollout-num-gpus 1 \ - "${MODEL_ARGS[@]}" \ - "${CKPT_ARGS[@]}" \ - "${ROLLOUT_ARGS[@]}" \ - "${OPTIMIZER_ARGS[@]}" \ - "${GRPO_ARGS[@]}" \ - "${PERF_ARGS[@]}" \ - "${SGLANG_ARGS[@]}" \ - "${AGENT_ARGS[@]}" \ - "${WANDB_ARGS[@]}" \ - "${MISC_ARGS[@]}" \ - "${DEBUG_ARGS[@]}" diff --git a/examples/experimental/swe-agent-v2/run.py b/examples/experimental/swe-agent-v2/run.py index b3cdbdab68..16e97f1a1d 100644 --- a/examples/experimental/swe-agent-v2/run.py +++ b/examples/experimental/swe-agent-v2/run.py @@ -38,9 +38,14 @@ class ScriptArgs(U.ExecuteTrainConfig): hf_checkpoint: str = "zai-org/GLM-4.7-Flash" ref_load: str = "/root/GLM-4.7-Flash_torch_dist" save_dir: str = "/root/GLM-4.7-Flash_agent_v2/" - max_seq_len: int = 16384 prompt_data: str = "/root/swe_train.jsonl" + # Training settings + max_seq_len: int = 16384 + rollout_batch_size: int = 2 + n_samples_per_prompt: int = 4 + global_batch_size: int = 8 + # Agent settings agent_server_url: str = os.environ.get( "AGENT_SERVER_URL", os.environ.get("SWE_AGENT_URL", "http://agent_env:11000") @@ -104,12 +109,12 @@ def execute(args: ScriptArgs): "--metadata-key metadata " "--rollout-shuffle " "--num-rollout 3000 " - "--rollout-batch-size 2 " - "--n-samples-per-prompt 4 " + f"--rollout-batch-size {args.rollout_batch_size} " + f"--n-samples-per-prompt {args.n_samples_per_prompt} " "--rollout-temperature 0.8 " "--rollout-max-response-len 8192 " f"--max-seq-len {args.max_seq_len} " - "--global-batch-size 8 " + f"--global-batch-size {args.global_batch_size} " "--balance-data " ) diff --git a/examples/experimental/swe-agent-v2/server.py b/examples/experimental/swe-agent-v2/server.py deleted file mode 100644 index 137b998975..0000000000 --- a/examples/experimental/swe-agent-v2/server.py +++ /dev/null @@ -1,298 +0,0 @@ -""" -FastAPI server wrapping Harbor for generalized agent-environment orchestration. - -Provides a single ``/run`` endpoint that handles any task type (SWE-bench, -Terminal-Bench, custom datasets, etc.) through Harbor's unified Trial API. -Harbor handles Docker orchestration, agent execution, and grading — the -server is task-type agnostic. - -Requires: - - Harbor installed: pip install harbor-framework - - Prepared task dirs under HARBOR_TASKS_DIR (via adapters or prepare_harbor_tasks.py) - -Usage: - python server.py --port 11000 --max-concurrent 8 -""" - -import argparse -import asyncio -import logging -import os -import re -import traceback -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from pathlib import Path -from typing import Any - -import uvicorn -from fastapi import FastAPI -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -_semaphore: asyncio.Semaphore | None = None - - -@asynccontextmanager -async def _lifespan(app: FastAPI) -> AsyncIterator[None]: - global _semaphore - max_concurrent = int(os.getenv("AGENT_MAX_CONCURRENT", os.getenv("SWE_AGENT_MAX_CONCURRENT", "8"))) - _semaphore = asyncio.Semaphore(max_concurrent) - logger.info(f"Initialized semaphore with max_concurrent={max_concurrent}") - yield - - -app = FastAPI(title="Agent Environment Server (Harbor)", lifespan=_lifespan) - - -class RunRequest(BaseModel): - base_url: str - model: str - sampling_params: dict[str, Any] = {} - api_key: str = "dummy" - - instance_id: str = "" - agent_name: str = "mini-swe-agent" - max_seq_len: int | None = None - - model_config = {"extra": "allow"} - - -class RunResponse(BaseModel): - reward: float = 0.0 - exit_status: str = "" - agent_metrics: dict[str, Any] = {} - eval_report: dict[str, Any] = {} - - -def get_semaphore() -> asyncio.Semaphore: - assert _semaphore is not None, "Semaphore not initialized — server not started?" - return _semaphore - - -_TIMEOUT_EXCEPTIONS = {"AgentTimeoutError", "VerifierTimeoutError", "EnvironmentStartTimeoutError"} -_OUTPUT_LIMIT_EXCEPTIONS = {"MaxSeqLenExceededError"} - -_HOST_PROCESS_AGENTS = {"terminus-2", "terminus-1", "terminus"} - -_SAFE_INSTANCE_ID = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") - - -def _extract_exit_status(result) -> str: - """Derive exit status from Harbor TrialResult.""" - exc = getattr(result, "exception_info", None) - if exc is not None: - exc_type = getattr(exc, "exception_type", "") - if exc_type in _TIMEOUT_EXCEPTIONS: - return "TimeLimitExceeded" - if exc_type in _OUTPUT_LIMIT_EXCEPTIONS: - return "SequenceLengthLimitExceeded" - return "AgentError" - if getattr(result, "verifier_result", None) is not None: - return "Submitted" - return "Unknown" - - -def _timing_duration_sec(timing) -> float | None: - started = getattr(timing, "started_at", None) - finished = getattr(timing, "finished_at", None) - if started and finished: - return (finished - started).total_seconds() - return None - - -def _extract_reward(result) -> tuple[float, dict[str, Any]]: - """Extract scalar reward and full eval report from Harbor TrialResult. - - Looks for the ``"reward"`` key first, then falls back to the first value - in the rewards dict. Works with both ``reward.txt`` and ``reward.json``. - """ - vr = getattr(result, "verifier_result", None) - if vr is None: - return 0.0, {} - rewards = getattr(vr, "rewards", None) or {} - reward = float(rewards.get("reward", next(iter(rewards.values()), 0.0))) - return reward, dict(rewards) - - -def _extract_metrics(result) -> dict[str, Any]: - """Extract agent metrics from Harbor TrialResult.""" - metrics: dict[str, Any] = {} - try: - ar = getattr(result, "agent_result", None) - if ar is not None: - for field in ("n_input_tokens", "n_output_tokens", "cost_usd"): - val = getattr(ar, field, None) - if val is not None: - metrics[field] = val - agent_meta = getattr(ar, "metadata", None) - if isinstance(agent_meta, dict): - metrics.update(agent_meta) - - agent_timing = getattr(result, "agent_execution", None) - if agent_timing is not None: - dur = _timing_duration_sec(agent_timing) - if dur is not None: - metrics["agent_run_time"] = dur - - verifier_timing = getattr(result, "verifier", None) - if verifier_timing is not None: - dur = _timing_duration_sec(verifier_timing) - if dur is not None: - metrics["eval_time"] = dur - except Exception as e: - logger.warning(f"Failed to extract metrics: {e}", exc_info=True) - return metrics - - -def _error_response(exit_status: str) -> dict[str, Any]: - return {"reward": 0.0, "exit_status": exit_status, "agent_metrics": {}, "eval_report": {}} - - -async def _run_trial(request: RunRequest) -> dict[str, Any]: - """Run a Harbor trial for a single task instance. - - Task-type agnostic — all differentiation (environment, grading harness) - is encoded in the Harbor task directory's 4 files. - """ - try: - from harbor.models.trial.config import AgentConfig, EnvironmentConfig, TaskConfig, TrialConfig - from harbor.trial.trial import Trial - except ImportError: - logger.error("Harbor not installed. Install with: pip install harbor-framework") - return _error_response("ImportError") - - try: - tasks_dir = Path( - os.getenv("HARBOR_TASKS_DIR", "/root/harbor_tasks"), - ).resolve() - - if not request.instance_id: - logger.error("Empty instance_id") - return _error_response("InvalidInstanceId") - - raw_id = request.instance_id - if not _SAFE_INSTANCE_ID.match(raw_id): - logger.error(f"Invalid instance_id rejected: {raw_id!r}") - return _error_response("InvalidInstanceId") - - # Normalize and verify the path stays within tasks_dir. - # Uses the pattern recommended by CodeQL (py/path-injection): - # normpath(join(base, user_input)) + startswith(base) - tasks_dir_str = str(tasks_dir) - task_path = os.path.normpath(os.path.join(tasks_dir_str, raw_id)) - if not task_path.startswith(tasks_dir_str): - logger.error(f"Path traversal blocked: {raw_id!r}") - return _error_response("InvalidInstanceId") - - if not os.path.exists(task_path): - logger.error(f"Task directory not found: {task_path}") - return _error_response("TaskNotFound") - - task_path = Path(task_path) - agent_kwargs: dict[str, Any] = {} - agent_env: dict[str, str] = {} - - is_host_agent = request.agent_name in _HOST_PROCESS_AGENTS - - if "hosted_vllm" in request.model or "openai" in request.model: - agent_kwargs["model_info"] = { - "max_input_tokens": int(os.getenv("AGENT_MAX_INPUT_TOKENS", "32768")), - "max_output_tokens": int(os.getenv("AGENT_MAX_OUTPUT_TOKENS", "8192")), - "input_cost_per_token": 0.0, - "output_cost_per_token": 0.0, - } - - if request.max_seq_len is not None: - agent_kwargs["max_seq_len"] = request.max_seq_len - - if is_host_agent: - agent_kwargs["api_base"] = request.base_url - agent_kwargs["api_key"] = request.api_key or "dummy" - agent_kwargs["enable_summarize"] = False - agent_env = { - "OPENAI_API_KEY": request.api_key or "dummy", - "OPENAI_API_BASE": request.base_url, - } - else: - agent_env = { - "OPENAI_API_BASE": request.base_url, - "OPENAI_API_KEY": request.api_key, - "HOSTED_VLLM_API_BASE": request.base_url, - "HOSTED_VLLM_API_KEY": request.api_key, - "MSWEA_COST_TRACKING": "ignore_errors", - } - - config = TrialConfig( - task=TaskConfig(path=task_path), - agent=AgentConfig( - name=request.agent_name, - model_name=request.model, - env=agent_env, - kwargs=agent_kwargs, - ), - environment=EnvironmentConfig( - type="docker", - delete=os.getenv("HARBOR_DELETE_CONTAINERS", "false").lower() in ("true", "1", "t"), - ), - ) - - trial = Trial(config=config) - result = await trial.run() - - reward, eval_report = _extract_reward(result) - exit_status = _extract_exit_status(result) - agent_metrics = _extract_metrics(result) - - return { - "reward": reward, - "exit_status": exit_status, - "agent_metrics": agent_metrics, - "eval_report": eval_report, - } - - except Exception as e: - logger.error(f"Harbor trial failed: {e}\n{traceback.format_exc()}") - return _error_response(f"Error: {type(e).__name__}") - - -@app.post("/run") -async def run_instance(request: RunRequest) -> RunResponse: - """Run an agent on a single task instance via Harbor.""" - logger.info(f"Running instance: {request.instance_id}") - async with get_semaphore(): - result = await _run_trial(request) - logger.info( - f"Instance {request.instance_id} finished: exit_status={result['exit_status']}, reward={result['reward']}" - ) - return RunResponse(**result) - - -@app.get("/health") -async def health(): - return {"status": "ok"} - - -def main(): - parser = argparse.ArgumentParser(description="Agent Environment Server (Harbor)") - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=11000) - parser.add_argument("--max-concurrent", type=int, default=8) - args = parser.parse_args() - - os.environ["AGENT_MAX_CONCURRENT"] = str(args.max_concurrent) - - os.environ.setdefault("MSWEA_API_KEY", "dummy") - os.environ.setdefault("HOSTED_VLLM_API_KEY", "dummy") - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(name)s %(levelname)s %(message)s", - ) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/examples/experimental/swe-agent-v2/swe_agent_function.py b/examples/experimental/swe-agent-v2/swe_agent_function.py index fb30d8a9f5..d1460fbe06 100644 --- a/examples/experimental/swe-agent-v2/swe_agent_function.py +++ b/examples/experimental/swe-agent-v2/swe_agent_function.py @@ -14,7 +14,7 @@ import logging import os from typing import Any -from urllib.parse import urlparse, urlunparse +from urllib.parse import urlparse, urlsplit, urlunparse from miles.utils.http_utils import post @@ -49,7 +49,7 @@ async def run( netloc = f"{external_host}:{port}" if port else external_host session_url = urlunparse(parsed._replace(netloc=netloc)) - request = { + request: dict[str, Any] = { **metadata, "base_url": session_url, "model": f"openai/{model_name}", @@ -60,6 +60,17 @@ async def run( if max_seq_len is not None: request["max_seq_len"] = int(max_seq_len) + session_server_id = metadata.get("session_server_id") + if session_server_id is not None: + if external_host: + port = urlsplit(f"http://{session_server_id}").port + session_server_id = f"{external_host}:{port}" + request["session_server_id"] = session_server_id + + session_server_instance_id = metadata.get("session_server_instance_id") + if session_server_instance_id is not None: + request["session_server_instance_id"] = session_server_instance_id + try: response = await asyncio.wait_for( post(f"{agent_server_url}/run", request), diff --git a/examples/fully_async/fully_async_rollout.py b/examples/fully_async/fully_async_rollout.py index 446c882e94..e4c23cc120 100644 --- a/examples/fully_async/fully_async_rollout.py +++ b/examples/fully_async/fully_async_rollout.py @@ -1,20 +1,60 @@ import asyncio import atexit +import logging import queue import threading import time -# Import core functions from sglang_rollout directly to avoid code duplication +import aiohttp + +from miles.rollout.data_source import DataSource from miles.rollout.sglang_rollout import GenerateState, generate_and_rm_group from miles.utils.async_utils import run from miles.utils.types import Sample +logger = logging.getLogger(__name__) + + +def group_oldest_weight_version(group: list[Sample]) -> int | None: + """Return the minimum weight version across all trajectories and turns in a group.""" + versions = [s.oldest_weight_version for s in group if s.oldest_weight_version is not None] + return min(versions) if versions else None + + +class _CachedWeightVersion: + """Throttled query for the current engine weight version via /model_info.""" + + def __init__(self, ttl: float = 1.0): + self._ttl = ttl + self._value: int | None = None + self._last_query: float = 0.0 + + async def get(self, args) -> int | None: + now = time.monotonic() + if self._value is not None and (now - self._last_query) < self._ttl: + return self._value + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/model_info" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=2)) as resp: + if resp.status == 200: + data = await resp.json() + self._value = int(data["weight_version"]) + self._last_query = now + except Exception as e: + logger.debug(f"Failed to query engine weight version: {e}") + return self._value + + +_cached_version = _CachedWeightVersion() + + # Global worker manager _global_worker = None _worker_lock = threading.Lock() -def get_global_worker(args, data_buffer): +def get_global_worker(args, data_buffer: DataSource): """Get or create global worker""" global _global_worker with _worker_lock: @@ -40,7 +80,7 @@ class AsyncRolloutWorker: Supports continuous running, independent of rollout function lifecycle """ - def __init__(self, args, data_buffer, concurrency=10): + def __init__(self, args, data_buffer: DataSource, concurrency=10): self.args = args self.data_buffer = data_buffer # Directly save data_buffer reference self.concurrency = concurrency @@ -146,7 +186,7 @@ def get_queue_size(self) -> int: return self.output_queue.qsize() -async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[list[Sample]]: +async def generate_rollout_async(args, rollout_id: int, data_buffer: DataSource) -> list[list[Sample]]: """ Simplified asynchronous rollout generation - using global continuous worker """ @@ -161,9 +201,15 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis data = [] completed_groups = {} do_print = True + stale_groups_recycled = 0 + staleness_values = [] + + use_staleness_filter = getattr(args, "max_weight_staleness", None) is not None print(f"Starting async rollout generation for {target_data_size} groups") print(f"Global worker queue size: {worker.get_queue_size()}") + if use_staleness_filter: + print(f"Staleness filter enabled: max_weight_staleness={args.max_weight_staleness}") # Main loop: collect results from global worker's output queue start_time = time.time() @@ -182,6 +228,11 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis if made_progress: last_progress_time = time.time() + # Query current engine version once per collection batch (cached/throttled) + current_engine_version = None + if use_staleness_filter: + current_engine_version = await _cached_version.get(args) + # Process completed groups in order (try to maintain order, but not strict requirement) processed_any = False @@ -202,7 +253,8 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis if any_aborted: try: - # add back to buffer so it can be retried or handled by buffer policy + for s in group: + s.reset_for_retry() data_buffer.add_samples([group]) print(f"Returned aborted group {group_id} to data buffer", flush=True) except Exception as e: @@ -210,6 +262,27 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis # don't count as processed for training continue + # Staleness filter: discard groups whose oldest weight version is too far behind + oldest = group_oldest_weight_version(group) + if oldest is not None and current_engine_version is not None: + staleness = current_engine_version - oldest + staleness_values.append(staleness) + if staleness > args.max_weight_staleness: + try: + for s in group: + s.reset_for_retry() + data_buffer.add_samples([group]) + except Exception as e: + logger.warning(f"Failed to recycle stale group {group_id}: {e}") + stale_groups_recycled += 1 + logger.info( + f"Recycled stale group {group_id} " + f"(oldest_version={oldest}, current={current_engine_version}, " + f"staleness={staleness} > max={args.max_weight_staleness})" + ) + # don't count as processed for training + continue + if do_print: print( f"First rollout sample: {[group[0].prompt + group[0].response]}, " @@ -238,6 +311,13 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis duration = time.time() - start_time print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}") + if stale_groups_recycled > 0 or staleness_values: + avg_staleness = sum(staleness_values) / len(staleness_values) if staleness_values else 0 + print( + f"Staleness stats: recycled={stale_groups_recycled}, " + f"avg_staleness={avg_staleness:.1f}, " + f"max_staleness={max(staleness_values) if staleness_values else 0}" + ) if data: print( @@ -250,7 +330,7 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis return data -def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False): +def generate_rollout_fully_async(args, rollout_id, data_buffer: DataSource, evaluation=False): if evaluation: raise ValueError("Evaluation mode not supported in simple async rollout") diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh index 026e486089..bfd12696bf 100644 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ b/examples/fully_async/run-qwen3-4b-fully_async.sh @@ -56,6 +56,9 @@ ROLLOUT_ARGS=( --global-batch-size 256 --balance-data + + # for staleness control + #--max-weight-staleness 2 ) PERF_ARGS=( diff --git a/examples/fully_async/run_qwen3_30b_a3b_fully_async.py b/examples/fully_async/run_qwen3_30b_a3b_fully_async.py new file mode 100644 index 0000000000..0a2dbca924 --- /dev/null +++ b/examples/fully_async/run_qwen3_30b_a3b_fully_async.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +# in_place + broadcast +# python run_qwen3_30b_a3b_fully_async.py + +# retract + p2p +# python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode p2p + +# retract + broadcast +# python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode broadcast + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = U.create_run_id() + model_name: str = "Qwen3-30B-A3B" + megatron_model_type: str = "qwen3-30B-A3B" + num_gpus_per_node: int = 8 + data_dir: str = "/root/datasets" + model_dir: str = "/root/models" + megatron_path: str = "/root/Megatron-LM" + pause_generation_mode: Literal["in_place", "retract"] = "in_place" + update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast" + extra_args: str = "" + + +def prepare(args: ScriptArgs): + U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") + U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}") + U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + dir_dst=args.model_dir, + hf_checkpoint=f"{args.model_dir}/{args.model_name}", + megatron_path=args.megatron_path, + ) + + +def execute(args: ScriptArgs): + if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p": + raise ValueError( + "in_place + p2p is not supported: P2P transfer engine conflicts with " + "active NCCL inference. Use broadcast with in_place, or retract with p2p." + ) + + ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist" + load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints" + + ckpt_args = ( + f"--hf-checkpoint {args.model_dir}/{args.model_name}/ " + f"--ref-load {ref_load_path} " + f"--load {load_save_path} " + ) + + rollout_args = ( + "--rollout-function-path fully_async_rollout.generate_rollout_fully_async " + f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type dapo " + "--reward-key score " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + "--global-batch-size 256 " + "--balance-data " + f"--pause-generation-mode {args.pause_generation_mode} " + ) + + perf_args = ( + "--tensor-model-parallel-size 8 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-tis " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_extra = "" + if args.update_weight_transfer_mode == "p2p": + sglang_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine " + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + f"--sglang-mem-fraction-static 0.7 {sglang_extra}" + "--sglang-cuda-graph-max-bs 512 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + f"--attention-backend flash --update-weight-transfer-mode {args.update_weight_transfer_mode} " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {args.num_gpus_per_node} " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " + f"{perf_args} " + f"{sglang_args} " + f"{misc_args} " + f"{args.extra_args} " + ) + + import os + + fully_async_dir = os.path.join(os.path.dirname(os.path.abspath(__file__))) + U.execute_train( + train_args=train_args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + train_script="train_async.py", + megatron_path=args.megatron_path, + extra_env_vars={ + "FLASHINFER_DISABLE_VERSION_CHECK": "1", + "PYTHONPATH": f"{args.megatron_path}:{fully_async_dir}", + }, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/lora/run-gpt-oss-20B-megatron-moe-lora.sh b/examples/lora/run-gpt-oss-20B-megatron-moe-lora.sh new file mode 100644 index 0000000000..f349f29228 --- /dev/null +++ b/examples/lora/run-gpt-oss-20B-megatron-moe-lora.sh @@ -0,0 +1,164 @@ +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} +GPUS_PER_NODE=$(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n' | wc -l) + +# Load model architecture config +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/gpt-oss-20b.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/models/gpt-oss-20b + --megatron-to-hf-mode bridge + # --save $BASE_DIR/gpt-oss-20b-BF16 + # --save-interval 50 +) + +LORA_ARGS=( + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + --target-modules "gate_proj,up_proj,down_proj" + --sglang-lora-backend triton # !!! must for moe-lora !!!, else display "Current LoRA backend does not support LoRA on MoE layers; skipping MoE layer" +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 1 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 4096 + --rollout-temperature 1.0 + + --global-batch-size 8 +) + +EVAL_ARGS=( + --eval-interval 10 + --eval-prompt-data gsm8k /root/gsm8k/test.parquet + --eval-input-key messages + --n-samples-per-eval-prompt 1 + --eval-max-response-len 4096 + --eval-top-k 1 +) + +PERF_ARGS=( + # Parallelism: TP=4, EP=1, PP=1, CP=1 + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + # Recomputation: full recompute needed to fit optimizer states in 80GB + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # Batch size settings + # Note: --use-dynamic-batch-size is not supported with --qkv-format bshd + --micro-batch-size 1 + --max-tokens-per-gpu 4096 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # TODO: need gpt oss ckpt conversion. + # --use-kl-loss + # --kl-loss-coef 0.00 + # --kl-loss-type low_var_kl + # --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + # CPU offload optimizer states (fp32 master weights + Adam moments) to free ~30GB GPU memory + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +SGLANG_ARGS=( + # TP size for sglang inference engine + --rollout-num-gpus-per-engine 4 + --sglang-dtype bfloat16 + --sglang-decode-log-interval 1000 + --sglang-mem-fraction-static 0.2 + # Note: need to use bf16 ckpt when enable triton moe backend, eg, lmsys/gpt-oss-20b-bf16 + # mxfp4 currently not supported + --sglang-moe-runner-backend triton +) + +WANDB_ARGS=( + --use-wandb + --wandb-project miles-gpt-oss + --wandb-group "gpt-oss-20b-moe-lora" +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # Sink attention (sliding window + learnable softmax) in TE only supports BSHD/SBHD, not THD. + # Must use --qkv-format bshd for the fused backend to work with this model's attention pattern. + --qkv-format bshd + --attention-backend fused + --update-weight-buffer-size 536870912 # 512MB +) + + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus $GPUS_PER_NODE --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${LORA_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ No newline at end of file diff --git a/examples/retool/generate_with_retool.py b/examples/retool/generate_with_retool.py index f5b8ad268c..6bd5d7de29 100644 --- a/examples/retool/generate_with_retool.py +++ b/examples/retool/generate_with_retool.py @@ -96,12 +96,11 @@ def format_conversation_with_tools( def postprocess_predictions(prediction: str): """Extract action and content from prediction string""" - # Check for Answer: \boxed{...} format (only format we need for math_dapo) - # Use a more robust regex that handles nested braces - answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}" - answer_match = re.search(answer_pattern, prediction, re.DOTALL) - if answer_match: - content = answer_match.group(1).strip() + # Check for bare \boxed{...} (model may omit "Answer:" prefix) + boxed_pattern = r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}" + boxed_match = re.search(boxed_pattern, prediction, re.DOTALL) + if boxed_match: + content = boxed_match.group(1).strip() return "answer", content # Then check for tags (new format from Jinja2 template) @@ -168,14 +167,17 @@ def postprocess_responses(resp: str) -> str: last_match = matches[-1] return resp[: last_match.end()] - # Handle Answer: \boxed{...} format (only format we need for math_dapo) - if "Answer:" in resp and "\\boxed{" in resp: - # Find the last occurrence of Answer: \boxed{...} with nested braces support - answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}" - matches = list(re.finditer(answer_pattern, resp, re.DOTALL)) - if matches: - last_match = matches[-1] - return resp[: last_match.end()] + # Handle Answer: \boxed{...} or bare \boxed{...} + if "\\boxed{" in resp: + # Try "Answer: \boxed{...}" first, then bare "\boxed{...}" + for pattern in [ + r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", + r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", + ]: + matches = list(re.finditer(pattern, resp, re.DOTALL)) + if matches: + last_match = matches[-1] + return resp[: last_match.end()] return resp @@ -203,7 +205,7 @@ async def execute_predictions(prediction: str) -> str: next_obs = ( "\nMy previous action is invalid. " "If I want to execute code, I should put the code between " - " and . " + " and . " "If I want to give the final answer, I should use the format " "'Answer: \\boxed{answer}'. Let me try again.\n" ) @@ -221,7 +223,12 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: # Set up the initial prompt with system prompt and tools (outside the loop) tool_specs = tool_registry.get_tool_specs() - prompt = format_conversation_with_tools(prompt=sample.prompt, tools=tool_specs) + + if isinstance(sample.prompt, str): + # Already formatted (e.g., by --apply-chat-template), use as-is to avoid double templating + prompt = sample.prompt + else: + prompt = format_conversation_with_tools(prompt=sample.prompt, tools=tool_specs) prompt_tokens_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"] response = "" @@ -355,8 +362,7 @@ async def reward_func(args, sample, **kwargs): if not isinstance(sample, Sample): raise TypeError("Sample must be an instance of Sample class.") - # Build complete solution string - solution_str = sample.prompt + sample.response + solution_str = sample.response # Get ground truth answer - label is a string, not a dict ground_truth = sample.label if sample.label is not None else "" diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index 838ce0e2c4..99eeea3d3b 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -31,7 +31,7 @@ CKPT_ARGS=( --ref-load /root/font-info/qwen3-4b-sft_torch_dist # --load /root/Qwen3-4B_miles/ --save /root/font-info/qwen3-4b-sft/qwen3-4b-sft-multi-turn/ - --save-interval 20 + --save-interval 200 --rotary-base 5000000 ) @@ -43,12 +43,12 @@ ROLLOUT_ARGS=( --rollout-shuffle --reward-key score --num-rollout 3000 - --rollout-batch-size 32 + --rollout-batch-size 16 --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 1 - --global-batch-size 256 + --global-batch-size 128 --balance-data ) @@ -98,8 +98,8 @@ OPTIMIZER_ARGS=( WANDB_ARGS=( --use-wandb - --wandb-project miles-dapo - --wandb-group qwen3-4B-test-multi-turn + --wandb-project miles-dev-retool-v2 + --wandb-group retool-v1-qwen3-4b-sft-new --wandb-key ${WANDB_KEY} ) @@ -117,6 +117,7 @@ MISC_ARGS=( --attention-softmax-in-fp32 # need to comment this when using model with MLA --attention-backend flash + --log-passrate ) CUSTOM_ARGS=( diff --git a/examples/retool_v2/README.md b/examples/retool_v2/README.md new file mode 100644 index 0000000000..1c9752c2b9 --- /dev/null +++ b/examples/retool_v2/README.md @@ -0,0 +1,31 @@ +# Retool v2 + +This example is an upgraded version of [retool](../retool), using the updated interfaces provided by the miles framework to implement multi-turn RL training with tool calls in a cleaner way. + +## Key Differences from v1 + +**v1 (retool)** requires manually implementing the full multi-turn conversation loop in `generate_with_retool.py`, directly depending on low-level `GenerateState` and `sglang_rollout` interfaces — resulting in verbose code tightly coupled to the framework internals. + +**v2 (retool_v2)** uses the framework's standard plugin interfaces. Users only need to implement three functions and mount them via command-line arguments: + +| Argument | Description | +|----------|-------------| +| `--custom-generate-function-path` | Uses the built-in `miles.rollout.generate_hub.multi_turn.generate` — no need to implement the multi-turn loop yourself | +| `--generate-tool-specs-path` | Declare tool definitions (user-implemented) | +| `--generate-execute-tool-function-path` | Implement tool execution logic (user-implemented) | +| `--custom-rm-path` | Implement the reward function (user-implemented) | + +Users only need to focus on business logic (tool definitions, tool execution, reward calculation). Multi-turn scheduling, token concatenation, loss masking, etc. are all handled by the framework. + +## Files + +- `tool_sandbox.py`: Tool definitions (`tool_specs`), tool execution (`execute_tool`), reward function (`reward_func`), and sandboxed safe execution environment +- `run_retool_multi_turn.py`: Training launch script + +## Quick Start + +```bash +python examples/retool_v2/run_retool_multi_turn.py +``` + +For data and model preparation, refer to the [retool v1 README](../retool/README.md). diff --git a/examples/retool_v2/run_retool_multi_turn.py b/examples/retool_v2/run_retool_multi_turn.py new file mode 100644 index 0000000000..1031eba790 --- /dev/null +++ b/examples/retool_v2/run_retool_multi_turn.py @@ -0,0 +1,208 @@ +import os +from dataclasses import dataclass, field +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +WANDB_PROJECT = "miles-dev-retool-v2" +WANDB_GROUP = "sft-multi-turn-batch-32" + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = field(default_factory=U.create_run_id) + hardware: Literal["H100", "GB200", "GB300"] = "H100" + num_gpus_per_node: int | None = None + use_sft_model: bool = True + save_path: str = "/root/Qwen3-4B_miles/retool_v2_multi_turn" + prompt_data: str = "/root/dapo-math-17k/dapo-math-17k.jsonl" + generate_max_turns: int = 16 + rollout_num_gpus_per_engine: int = 2 + extra_args: str = "" + + # resolved in __post_init__, not set by user + hf_checkpoint: str = field(init=False) + ref_load: str = field(init=False) + + def __post_init__(self): + self.num_gpus_per_node = self.num_gpus_per_node or U.NUM_GPUS_OF_HARDWARE[self.hardware] + if self.use_sft_model: + self.hf_checkpoint = "/root/font-info/qwen3-4b-sft" + self.ref_load = "/root/font-info/qwen3-4b-sft_torch_dist" + else: + self.hf_checkpoint = "/root/models/Qwen3-4B" + self.ref_load = "/root/models/Qwen3-4B_torch_dist" + + +def _get_wandb_args() -> str: + WANDB_API_KEY = os.environ.get("WANDB_API_KEY") + return ( + "--use-wandb " + f"--wandb-project {WANDB_PROJECT} " + f"--wandb-group {WANDB_GROUP} " + f"--wandb-key {WANDB_API_KEY} " + ) + + +def prepare(args: ScriptArgs): + U.exec_command("mkdir -p /root/dapo-math-17k /root/aime-2024") + U.exec_command("hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/dapo-math-17k") + U.exec_command("hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/aime-2024") + + if args.use_sft_model: + U.exec_command("mkdir -p /root/font-info") + U.exec_command(f"hf download font-info/qwen3-4b-sft-SGLang-RL --local-dir {args.hf_checkpoint}") + U.convert_checkpoint( + model_name="qwen3-4b-sft", + megatron_model_type="qwen3-4B", + num_gpus_per_node=args.num_gpus_per_node, + hf_checkpoint=args.hf_checkpoint, + dir_dst="/root/font-info", + ) + else: + U.exec_command("mkdir -p /root/models") + U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") + U.convert_checkpoint( + model_name="Qwen3-4B", + megatron_model_type="qwen3-4B", + num_gpus_per_node=args.num_gpus_per_node, + dir_dst="/root/models", + ) + + +def execute(args: ScriptArgs): + megatron_model_type = "qwen3-4B" + + ckpt_args = ( + f"--hf-checkpoint {args.hf_checkpoint} " + f"--ref-load {args.ref_load} " + f"--save {args.save_path} " + f"--save-interval {2 if args.mode == 'debug_minimal' else 1000} " + f"{'--rotary-base 5000000 ' if args.use_sft_model else ''}" + ) + + custom_args = ( + "--custom-generate-function-path miles.rollout.generate_hub.multi_turn.generate " + "--generate-tool-specs-path examples.retool_v2.tool_sandbox.tool_specs " + "--generate-execute-tool-function-path examples.retool_v2.tool_sandbox.execute_tool " + "--generate-tool-call-parser qwen25 " + f"--generate-max-turns {args.generate_max_turns} " + "--log-multi-turn " + ) + + rollout_args = ( + f"--prompt-data {args.prompt_data} " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--custom-rm-path examples.retool_v2.tool_sandbox.reward_func " + "--reward-key score " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + "--global-batch-size 256 " + "--balance-data " + ) + + eval_args = "" + if args.mode != "debug_minimal": + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data aime /root/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 16 " + "--eval-max-response-len 16384 " + "--eval-top-p 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + f"--rollout-num-gpus-per-engine {args.rollout_num_gpus_per_engine} " "--sglang-mem-fraction-static 0.7 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + misc_args = ( + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + "--colocate " + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--log-passrate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{_get_wandb_args()} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{misc_args} " + f"{custom_args} " + f"{args.extra_args} " + ) + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=megatron_model_type, + extra_env_vars={ + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "PYTHONPATH": "/root/Megatron-LM/:/root/miles", + }, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/retool_v2/tool_sandbox.py b/examples/retool_v2/tool_sandbox.py new file mode 100644 index 0000000000..fc7a1dea45 --- /dev/null +++ b/examples/retool_v2/tool_sandbox.py @@ -0,0 +1,385 @@ +""" +copied from examples/retool/tool_sandbox.py +""" + +import asyncio +import gc +import os +import re +import subprocess +import tempfile +from contextlib import contextmanager +from typing import Any +import psutil + +from miles.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score +from miles.utils.types import Sample + +# Configuration for tool execution +TOOL_CONFIGS = { + "max_turns": 16, + "max_tool_calls": 16, + "tool_concurrency": 32, # Aggressive: 32 concurrent processes + # Python interpreter settings + "python_timeout": 120, # 2 minutes for complex calculations + "python_memory_limit": "4GB", # 4GB per Python process + "python_cpu_limit": 1, + # Memory management settings + "max_memory_usage": 12288, # 12GB total (75% of 16GB) + "cleanup_threshold": 6144, # 6GB + "aggressive_cleanup_threshold": 3072, # 3GB + "force_cleanup_threshold": 9216, # 9GB +} + +# Global semaphore for controlling concurrent tool executions +SEMAPHORE = asyncio.Semaphore(TOOL_CONFIGS["tool_concurrency"]) + + +def get_memory_usage() -> float: + """Get current memory usage in MB""" + process = psutil.Process() + return process.memory_info().rss / 1024 / 1024 + + +def cleanup_memory(): + """Force garbage collection to free memory""" + gc.collect() + + +def aggressive_cleanup_memory(): + """More aggressive memory cleanup""" + # Force multiple garbage collection cycles + for _ in range(3): + gc.collect() + + # Clear Python's internal caches + import sys + + # Note: sys.intern doesn't have a clear method, so we skip this + # Clear module cache if possible + if hasattr(sys, "modules"): + # Don't clear all modules, but clear some common ones that might cache data + modules_to_clear = ["numpy", "pandas", "matplotlib", "scipy"] + for module_name in modules_to_clear: + if module_name in sys.modules: + module = sys.modules[module_name] + if hasattr(module, "clear_cache"): + module.clear_cache() + + +def check_and_cleanup_memory(): + """Check memory usage and perform appropriate cleanup""" + current_memory = get_memory_usage() + + if current_memory > TOOL_CONFIGS["force_cleanup_threshold"]: + # Force aggressive cleanup + aggressive_cleanup_memory() + return f"Warning: High memory usage ({current_memory:.1f}MB), performed aggressive cleanup" + elif current_memory > TOOL_CONFIGS["cleanup_threshold"]: + # Normal cleanup + cleanup_memory() + return f"Info: Memory usage ({current_memory:.1f}MB), performed cleanup" + elif current_memory > TOOL_CONFIGS["aggressive_cleanup_threshold"]: + # Light cleanup + gc.collect() + return f"Info: Memory usage ({current_memory:.1f}MB), performed light cleanup" + + return None + + +class PythonSandbox: + """Python code sandbox, provides safe code execution environment""" + + def __init__(self, timeout: int = 10, memory_limit: str = "100MB"): + self.timeout = timeout + self.memory_limit = memory_limit + self.allowed_modules = { + "math", + "random", + "datetime", + "collections", + "itertools", + "functools", + "operator", + "statistics", + "decimal", + "fractions", + } + + def _check_code_safety(self, code: str) -> tuple[bool, str]: + """Check code safety by scanning for dangerous patterns""" + # Check for dangerous operations + dangerous_patterns = [ + r"import\s+os", + r"import\s+sys", + r"import\s+subprocess", + r"import\s+shutil", + r"import\s+glob", + r"import\s+pathlib", + r"__import__", + r"eval\s*\(", + r"exec\s*\(", + r"open\s*\(", + r"file\s*\(", + r"input\s*\(", + r"raw_input\s*\(", + r"compile\s*\(", + r"execfile\s*\(", + r"getattr\s*\(", + r"setattr\s*\(", + r"delattr\s*\(", + r"hasattr\s*\(", + r"globals\s*\(", + r"locals\s*\(", + r"vars\s*\(", + r"dir\s*\(", + r"type\s*\(", + r"isinstance\s*\(", + r"issubclass\s*\(", + r"super\s*\(", + r"property\s*\(", + r"staticmethod\s*\(", + r"classmethod\s*\(", + r"__\w+__", # double underscore methods + ] + + for pattern in dangerous_patterns: + if re.search(pattern, code, re.IGNORECASE): + return False, f"Code contains dangerous pattern: {pattern}" + + # Check imported modules + import_pattern = r"import\s+(\w+)" + from_pattern = r"from\s+(\w+)" + + imports = re.findall(import_pattern, code) + froms = re.findall(from_pattern, code) + + all_imports = set(imports + froms) + for imp in all_imports: + if imp not in self.allowed_modules: + return False, f"Import of '{imp}' is not allowed" + + return True, "Code is safe" + + @contextmanager + def _create_safe_environment(self): + """Create safe execution environment with temporary directory""" + # Create temporary directory + temp_dir = tempfile.mkdtemp(prefix="python_sandbox_") + + try: + # Create safe Python script + script_path = os.path.join(temp_dir, "code.py") + + # Set environment variables + env = os.environ.copy() + env["PYTHONPATH"] = temp_dir + env["PYTHONUNBUFFERED"] = "1" + + yield script_path, env, temp_dir + + finally: + # Clean up temporary directory + try: + import shutil + + shutil.rmtree(temp_dir) + except Exception: + pass + + async def execute_code(self, code: str) -> str: + """Execute Python code in sandbox with safety checks""" + # Check memory usage before execution + current_memory = get_memory_usage() + if current_memory > TOOL_CONFIGS["max_memory_usage"]: + aggressive_cleanup_memory() + return "Error: Memory usage too high, please try again" + + # Check code safety + is_safe, message = self._check_code_safety(code) + if not is_safe: + return f"Error: {message}" + + # Add necessary wrapper code with memory limits + # Properly indent the user code within the try block + # Handle indentation properly by adding 4 spaces to each line + indented_code = "\n".join(" " + line for line in code.split("\n")) + + wrapped_code = f"""import sys +import traceback +from io import StringIO +import resource + +# Set memory limit (4GB) +try: + resource.setrlimit(resource.RLIMIT_AS, (4 * 1024 * 1024 * 1024, -1)) +except Exception: + pass + +# Redirect stdout and stderr +old_stdout = sys.stdout +old_stderr = sys.stderr +stdout_capture = StringIO() +stderr_capture = StringIO() +sys.stdout = stdout_capture +sys.stderr = stderr_capture + +try: + # User code +{indented_code} + + # Get output + stdout_output = stdout_capture.getvalue() + stderr_output = stderr_capture.getvalue() + + # Restore standard output + sys.stdout = old_stdout + sys.stderr = old_stderr + + # Return result + result = "" + if stdout_output: + result += f"Output:\\n{{stdout_output}}" + if stderr_output: + result += f"\\nErrors:\\n{{stderr_output}}" + + print(result) + +except Exception as e: + # Restore standard output + sys.stdout = old_stdout + sys.stderr = old_stderr + + # Return error information + error_msg = f"Error: {{str(e)}}\\nTraceback:\\n{{traceback.format_exc()}}" + print(error_msg)""" + + with self._create_safe_environment() as (script_path, env, temp_dir): + # Write code to file + with open(script_path, "w") as f: + f.write(wrapped_code) + + try: + # Use subprocess to run code + process = subprocess.Popen( + ["python3", script_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + cwd=temp_dir, + text=True, + ) + + # Set timeout + try: + stdout, stderr = process.communicate(timeout=self.timeout) + + if process.returncode == 0: + result = stdout.strip() + else: + result = f"Error: Process exited with code {process.returncode}\n{stderr}" + + except subprocess.TimeoutExpired: + process.kill() + result = f"Error: Code execution timed out after {self.timeout} seconds" + + except Exception as e: + result = f"Error: Failed to execute code: {str(e)}" + + # Check memory usage after execution and cleanup if needed + cleanup_message = check_and_cleanup_memory() + if cleanup_message: + print(f"Memory cleanup: {cleanup_message}") + + return result + + +class ToolRegistry: + """Tool registry, manages available tools and their execution""" + + def __init__(self): + self.tools = {} + self.python_sandbox = PythonSandbox( + timeout=TOOL_CONFIGS["python_timeout"], memory_limit=TOOL_CONFIGS["python_memory_limit"] + ) + self._register_default_tools() + + def _register_default_tools(self): + """Register default tools in the registry""" + # Python code interpreter + self.register_tool( + "code_interpreter", + { + "type": "function", + "function": { + "name": "code_interpreter", + "description": "A tool for executing Python code in a safe sandbox environment.", + "parameters": { + "type": "object", + "properties": {"code": {"type": "string", "description": "The Python code to execute"}}, + "required": ["code"], + }, + }, + }, + ) + + def register_tool(self, name: str, tool_spec: dict[str, Any]): + """Register a new tool in the registry""" + self.tools[name] = tool_spec + + def get_tool_specs(self) -> list[dict[str, Any]]: + """Get all tool specifications as a list""" + return list(self.tools.values()) + + async def execute_tool(self, tool_name: str, arguments: dict[str, Any]) -> str: + """Execute a tool call with the given arguments""" + if tool_name not in self.tools: + return f"Error: Tool '{tool_name}' not found" + + async with SEMAPHORE: + if tool_name == "code_interpreter": + return await self._execute_python(arguments) + else: + return f"Error: Tool '{tool_name}' not implemented" + + async def _execute_python(self, arguments: dict[str, Any]) -> str: + """Execute Python code using the sandbox""" + code = arguments.get("code", "") + if isinstance(code, list): + code = "\n".join(str(item) for item in code) + if not code.strip(): + return "Error: No code provided" + + # Execute code in sandbox + result = await self.python_sandbox.execute_code(code) + return result + + +tool_registry = ToolRegistry() + +tool_specs = tool_registry.get_tool_specs() + + +async def execute_tool(name: str, params: dict) -> str: + return await tool_registry.execute_tool(name, params) + + +# Reward function that encourages tool usage +async def reward_func(args, sample: Sample, **kwargs): + """Tool call reward function using math_dapo, with bonus for tool usage.""" + solution_str = sample.prompt + sample.response if isinstance(sample.prompt, str) else sample.response + ground_truth = sample.label if sample.label is not None else "" + tool_call_count = sample.metadata.get("tool_call_count", 0) + + # use \boxed{...} answer + result = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=True) + + # encourage model to call tools + if result["score"] < 0: + tool_call_reward = tool_call_count / 2 * 0.1 + result["score"] = min(-0.6, result["score"] + tool_call_reward) + + if result["pred"] is None: + result["pred"] = "" + + return result diff --git a/miles/backends/fsdp_utils/kernels/__init__.py b/miles/backends/experimental/__init__.py similarity index 100% rename from miles/backends/fsdp_utils/kernels/__init__.py rename to miles/backends/experimental/__init__.py diff --git a/miles/backends/fsdp_utils/__init__.py b/miles/backends/experimental/fsdp_utils/__init__.py similarity index 100% rename from miles/backends/fsdp_utils/__init__.py rename to miles/backends/experimental/fsdp_utils/__init__.py diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/experimental/fsdp_utils/actor.py similarity index 98% rename from miles/backends/fsdp_utils/actor.py rename to miles/backends/experimental/fsdp_utils/actor.py index 7bdd1c17ad..47c2540f98 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/experimental/fsdp_utils/actor.py @@ -20,17 +20,17 @@ from miles.utils.timer import Timer, inverse_timer, timer from miles.utils.tracking_utils import init_tracking -from ...utils.profile_utils import TrainProfiler -from ..training_utils.ci_utils import check_grad_norm -from ..training_utils.data import DataIterator, get_batch, get_data_iterator, get_rollout_data -from ..training_utils.log_utils import ( +from ....utils.profile_utils import TrainProfiler +from ...training_utils.ci_utils import check_grad_norm +from ...training_utils.data import DataIterator, get_batch, get_data_iterator, get_rollout_data +from ...training_utils.log_utils import ( aggregate_forward_results, aggregate_train_losses, log_rollout_data, log_train_step, ) -from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function -from ..training_utils.parallel import get_parallel_state, set_parallel_state +from ...training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function +from ...training_utils.parallel import get_parallel_state, set_parallel_state from . import checkpoint from .lr_scheduler import get_lr_scheduler from .parallel import create_fsdp_parallel_state @@ -681,7 +681,7 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None): offload_policy = CPUOffloadPolicy() if cpu_offload else None layer_cls_to_wrap = model._no_split_modules - assert len(layer_cls_to_wrap) > 0 and layer_cls_to_wrap[0] is not None + assert len(layer_cls_to_wrap) > 0 and next(iter(layer_cls_to_wrap)) is not None modules = [ module diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/experimental/fsdp_utils/arguments.py similarity index 100% rename from miles/backends/fsdp_utils/arguments.py rename to miles/backends/experimental/fsdp_utils/arguments.py diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/experimental/fsdp_utils/checkpoint.py similarity index 100% rename from miles/backends/fsdp_utils/checkpoint.py rename to miles/backends/experimental/fsdp_utils/checkpoint.py diff --git a/miles/backends/fsdp_utils/models/__init__.py b/miles/backends/experimental/fsdp_utils/kernels/__init__.py similarity index 100% rename from miles/backends/fsdp_utils/models/__init__.py rename to miles/backends/experimental/fsdp_utils/kernels/__init__.py diff --git a/miles/backends/fsdp_utils/kernels/fused_experts.py b/miles/backends/experimental/fsdp_utils/kernels/fused_experts.py similarity index 100% rename from miles/backends/fsdp_utils/kernels/fused_experts.py rename to miles/backends/experimental/fsdp_utils/kernels/fused_experts.py diff --git a/miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py b/miles/backends/experimental/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py similarity index 100% rename from miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py rename to miles/backends/experimental/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py diff --git a/miles/backends/fsdp_utils/lr_scheduler.py b/miles/backends/experimental/fsdp_utils/lr_scheduler.py similarity index 100% rename from miles/backends/fsdp_utils/lr_scheduler.py rename to miles/backends/experimental/fsdp_utils/lr_scheduler.py diff --git a/miles/backends/experimental/fsdp_utils/models/__init__.py b/miles/backends/experimental/fsdp_utils/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/backends/fsdp_utils/models/qwen3_moe.py b/miles/backends/experimental/fsdp_utils/models/qwen3_moe.py similarity index 98% rename from miles/backends/fsdp_utils/models/qwen3_moe.py rename to miles/backends/experimental/fsdp_utils/models/qwen3_moe.py index fe2133f3c1..7471f4aa1a 100644 --- a/miles/backends/fsdp_utils/models/qwen3_moe.py +++ b/miles/backends/experimental/fsdp_utils/models/qwen3_moe.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP -from miles.backends.fsdp_utils.kernels.fused_experts import ( +from miles.backends.experimental.fsdp_utils.kernels.fused_experts import ( DownProjFunction, GateUpProjFunction, MoeSumReduceFunction, diff --git a/miles/backends/fsdp_utils/models/qwen3_moe_hf.py b/miles/backends/experimental/fsdp_utils/models/qwen3_moe_hf.py similarity index 100% rename from miles/backends/fsdp_utils/models/qwen3_moe_hf.py rename to miles/backends/experimental/fsdp_utils/models/qwen3_moe_hf.py diff --git a/miles/backends/fsdp_utils/parallel.py b/miles/backends/experimental/fsdp_utils/parallel.py similarity index 96% rename from miles/backends/fsdp_utils/parallel.py rename to miles/backends/experimental/fsdp_utils/parallel.py index 81e682660a..fa444975bd 100644 --- a/miles/backends/fsdp_utils/parallel.py +++ b/miles/backends/experimental/fsdp_utils/parallel.py @@ -7,7 +7,7 @@ from miles.utils.distributed_utils import get_gloo_group -from ..training_utils.parallel import GroupInfo, ParallelState +from ...training_utils.parallel import GroupInfo, ParallelState logger = logging.getLogger(__name__) diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/experimental/fsdp_utils/update_weight_utils.py similarity index 92% rename from miles/backends/fsdp_utils/update_weight_utils.py rename to miles/backends/experimental/fsdp_utils/update_weight_utils.py index 32948b2840..98000e7d34 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/experimental/fsdp_utils/update_weight_utils.py @@ -17,7 +17,7 @@ from sglang.srt.utils import MultiprocessingSerializer -from miles.utils.distributed_utils import init_process_group +from miles.utils.distributed_utils import get_gloo_group, init_process_group try: @@ -47,6 +47,13 @@ def connect_rollout_engines( def update_weights(self) -> None: self.weight_version += 1 + + if dist.get_rank() == 0: + futures = [engine.pause_generation.remote() for engine in self.rollout_engines] + futures.extend([engine.flush_cache.remote() for engine in self.rollout_engines]) + ray.get(futures) + dist.barrier(group=get_gloo_group()) + bucket = [] bucket_size = 0 for name, param in self.model.state_dict().items(): @@ -73,6 +80,11 @@ def update_weights(self) -> None: bucket = [] bucket_size = 0 + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] self.update_bucket_weights(bucket, weight_version=self.weight_version) @@ -172,8 +184,13 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) result = ray.get(ref) - if hasattr(result, "success") and not result.success: + if isinstance(result, dict): + success = result.get("success", True) + error_msg = result.get("error_message") or result.get("message", "unknown error") + else: + success = getattr(result, "success", True) error_msg = getattr(result, "error_message", "unknown error") + if not success: raise RuntimeError( f"Weight sync failed on rollout engine: {error_msg}. " f"Check SGLang version compatibility." ) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 80e7fd9339..2cfd373f49 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -71,6 +71,11 @@ def init( if self._is_main_rank: init_tracking(args, primary=False) + unsupported = {"train_actor", "train_log_probs"} & set(args.profile_target) + if unsupported and args.use_pytorch_profiler: + raise NotImplementedError( + f"--profile-target {' '.join(sorted(unsupported))} is not supported for Megatron backend" + ) self.prof = TrainProfiler(args) # read config and tokenizer serialized to prevent concurrent writing bug. @@ -110,6 +115,15 @@ def init( args, role ) + parallel_state = get_parallel_state() + if parallel_state.cp.size > 1: + from miles_plugins.models.cp_utils import detect_and_setup_hybrid_cp + + for model_chunk in self.model: + detect_and_setup_hybrid_cp( + model_chunk, parallel_state.cp.group, parallel_state.cp.rank, parallel_state.cp.size + ) + verify_megatron_parallel_state(self.model) if role == "critic": @@ -163,9 +177,8 @@ def init( # empty cache after initialization clear_memory() + self._switch_model("actor") if self.args.offload_train: - # recover to actor in the end. - self._switch_model("actor") self.sleep() self.rollout_engines = None diff --git a/miles/backends/megatron_utils/fp32_param_utils.py b/miles/backends/megatron_utils/fp32_param_utils.py new file mode 100644 index 0000000000..afd6bde7f0 --- /dev/null +++ b/miles/backends/megatron_utils/fp32_param_utils.py @@ -0,0 +1,52 @@ +import logging +from collections.abc import Sequence + +import torch +import torch.distributed as dist + +logger = logging.getLogger(__name__) + + +# Parameter attribute used by model definitions to pin parameter dtype. +FORCED_PARAM_DTYPE_ATTR = "_miles_forced_param_dtype" + + +def mark_param_dtype(param: torch.nn.Parameter, dtype: torch.dtype) -> None: + """Mark a parameter with its required runtime dtype.""" + setattr(param, FORCED_PARAM_DTYPE_ATTR, dtype) + + +def enforce_marked_param_dtypes(model_chunks: Sequence[torch.nn.Module]) -> list[str]: + """Apply dtype overrides declared on parameters via ``mark_param_dtype``. + + This keeps the policy in model definitions and avoids model-name checks in + the training/conversion mainline. + + Motivation: Megatron's ``Float16Module`` unconditionally casts every + floating-point parameter to bf16/fp16 at wrap time, and there is no + declarative opt-out in nn.Module or Megatron. Megatron's MoE router hits the + same problem and solves it with ``_maintain_float32_expert_bias`` (see + ``megatron/core/transformer/moe/router.py``), which post-hoc casts the + expert_bias back to fp32. This function generalizes that pattern: callers + mark params with their required dtype at the model-definition site, and we + re-cast after ``get_model`` so the rest of the stack (optimizer, DDP, mbridge + load path) sees the intended dtype. + """ + updated_names: list[str] = [] + for chunk in model_chunks: + for name, param in chunk.named_parameters(): + target_dtype = getattr(param, FORCED_PARAM_DTYPE_ATTR, None) + if target_dtype is None: + continue + + if param.dtype != target_dtype: + # Keep Parameter identity to avoid breaking optimizer/DDP maps. + param.data = param.data.to(dtype=target_dtype) + updated_names.append(name) + + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + if rank == 0 and updated_names: + logger.info("Enforced marked parameter dtypes for %d tensors.", len(updated_names)) + return updated_names diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py new file mode 100644 index 0000000000..7d8b92afed --- /dev/null +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py @@ -0,0 +1,254 @@ +import re + +import torch + +FP4_E2M1_MAX = 6.0 +FP8_E4M3_MAX = 448.0 +NVFP4_GROUP_SIZE = 16 + +GATED_PAIR_SUFFIXES = { + ".gate_proj.weight": "gate", + ".up_proj.weight": "up", + ".w1.weight": "gate", + ".w3.weight": "up", +} + + +def _get_ignore_rules(quantization_config) -> list[str]: + ignore_rules = quantization_config.get("ignore", []) or [] + if isinstance(ignore_rules, str): + ignore_rules = [ignore_rules] + exclude_rules = quantization_config.get("exclude_modules", []) or [] + if isinstance(exclude_rules, str): + exclude_rules = [exclude_rules] + return list(ignore_rules) + [rule for rule in exclude_rules if rule not in ignore_rules] + + +def _is_ignored(name: str, ignore_rules: list[str]) -> bool: + for rule in ignore_rules: + if rule.startswith("re:"): + if re.match(rule[3:], name): + return True + continue + if name == rule or name.startswith(f"{rule}."): + return True + return False + + +def quantize_params_nvfp4(args, megatron_name, converted_named_params, quantization_config): + assert quantization_config is not None + assert quantization_config.get("quant_algo") == "NVFP4" or quantization_config.get("quant_method") == "nvfp4" + group_size = _resolve_group_size(quantization_config) + ignore_rules = _get_ignore_rules(quantization_config) + + decoder_layers_pattern = r"decoder\.layers\.(\d+)\.(.+)" + match = re.search(decoder_layers_pattern, megatron_name) + + if not match: + # check mtp layers + mtp_layer_pattern = r"mtp\.layers\.(\d+)\.(.+)" + match = re.search(mtp_layer_pattern, megatron_name) + if not match: + return converted_named_params + _, rest = match.groups() + rest = rest.replace("transformer_layer.", "") + else: + _, rest = match.groups() + + # experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, _ = match.groups() + if rest in [ + "linear_fc1", + "linear_fc2", + ]: + return _quantize_moe_params(converted_named_params, group_size, ignore_rules) + + # shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest in [ + "linear_fc1.weight", + "linear_fc2.weight", + ]: + return _quantize_moe_params(converted_named_params, group_size, ignore_rules) + + # for other parameters, we just return the original converted_named_params + return converted_named_params + + +def _resolve_group_size(quantization_config): + group_size = quantization_config.get("group_size", NVFP4_GROUP_SIZE) + if group_size != NVFP4_GROUP_SIZE: + raise ValueError(f"NVFP4 group_size must be {NVFP4_GROUP_SIZE}, got {group_size}.") + return group_size + + +def _quantize_moe_params(converted_named_params, group_size, ignore_rules): + shared_global_amax = {} + gated_candidates = {} + for converted_name, param in converted_named_params: + base, role = _split_gated_pair_name(converted_name) + if base is None or role is None: + continue + if _should_quantize_param(converted_name, param, group_size, ignore_rules): + gated_candidates.setdefault(base, {})[role] = param + + for base, roles in gated_candidates.items(): + if "gate" in roles and "up" in roles: + gate_amax = roles["gate"].abs().max().to(torch.float32) + up_amax = roles["up"].abs().max().to(torch.float32) + shared_global_amax[base] = torch.max(gate_amax, up_amax) + + quantize_named_params = [] + for converted_name, param in converted_named_params: + if not _should_quantize_param(converted_name, param, group_size, ignore_rules): + quantize_named_params.append((converted_name, param)) + continue + base, _role = _split_gated_pair_name(converted_name) + global_amax = shared_global_amax.get(base) if base else None + qweight, block_scale, weight_scale_2 = quantize_nvfp4(param, global_amax=global_amax, group_size=group_size) + quantize_named_params.append((converted_name, qweight)) + quantize_named_params.append((converted_name.replace(".weight", ".weight_scale"), block_scale)) + quantize_named_params.append((converted_name.replace(".weight", ".weight_scale_2"), weight_scale_2)) + quantize_named_params.append( + (converted_name.replace(".weight", ".input_scale"), torch.ones_like(weight_scale_2, dtype=torch.float32)) + ) + + return quantize_named_params + + +def _should_quantize_param(name, weight, group_size, ignore_rules): + if ignore_rules and _is_ignored(name, ignore_rules): + return False + if not name.endswith(".weight"): + return False + if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + return False + if weight.dim() < 2: + return False + if weight.shape[-1] % group_size != 0: + raise ValueError(f"Last dim {weight.shape[-1]} must be divisible by {group_size} for NVFP4 ({name}).") + return True + + +def _split_gated_pair_name(name: str): + for suffix, role in GATED_PAIR_SUFFIXES.items(): + if name.endswith(suffix): + return name[: -len(suffix)], role + return None, None + + +def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: + """Quantize a tensor to FP4 E2M1 and pack two values per byte.""" + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def _quantize_nvfp4_1d( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, + group_size: int = NVFP4_GROUP_SIZE, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + NVFP4 1D quantization (tile shape = 1x16), adapted from + TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference. + + Returns: + qweight: uint8 packed fp4, shape (M, K // 2) + block_scale: float8_e4m3fn, shape (M, K // group_size) + global_scale: float32 scalar tensor + """ + weight = weight.contiguous() + m, n = weight.shape + if n % group_size != 0: + raise ValueError(f"NVFP4 requires K divisible by {group_size}, got {n}.") + + weight_f = weight.to(torch.float32) + if global_amax is None: + global_amax = torch.max(torch.abs(weight_f)) + else: + global_amax = global_amax.to(device=weight.device, dtype=torch.float32) + if global_amax.item() == 0.0: + qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device) + block_scale = torch.zeros( + (m, n // group_size), + dtype=torch.float8_e4m3fn, + device=weight.device, + ) + global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + return qweight, block_scale, global_scale + + fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32) + fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32) + + global_encode_scale = torch.div(fp8_max * fp4_max, global_amax) + # global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor(torch.finfo(torch.float32).max, device=weight.device, dtype=torch.float32), + ) + if global_encode_scale.item() == 0.0: + global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + weight_blocks = weight_f.view(m, n // group_size, group_size) + vec_max = torch.amax(torch.abs(weight_blocks), dim=-1, keepdim=True) + decode_scale = torch.div(vec_max, fp4_max) * global_encode_scale + decode_scale = torch.clamp(decode_scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn) + + encode_scale = torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale) + scaled = weight_blocks * encode_scale + clipped = torch.clamp(scaled, -fp4_max, fp4_max).reshape(m, n) + + qweight = cast_to_fp4x2(clipped) + block_scale = decode_scale.squeeze(-1) + return qweight, block_scale, global_decode_scale + + +def quantize_nvfp4( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, + group_size: int = NVFP4_GROUP_SIZE, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if weight.dim() == 2: + return _quantize_nvfp4_1d(weight, global_amax=global_amax, group_size=group_size) + if weight.dim() == 3: + if global_amax is not None: + raise ValueError("global_amax override is only supported for 2D weights.") + qweights = [] + block_scales = [] + global_scales = [] + for idx in range(weight.shape[0]): + qweight, block_scale, global_scale = _quantize_nvfp4_1d(weight[idx], group_size=group_size) + qweights.append(qweight) + block_scales.append(block_scale) + global_scales.append(global_scale) + return ( + torch.stack(qweights, dim=0), + torch.stack(block_scales, dim=0), + torch.stack(global_scales, dim=0), + ) + raise ValueError(f"Unsupported weight rank {weight.dim()} for NVFP4 quantization.") diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index e95158bbfc..a9be3696d7 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -132,6 +132,7 @@ def setup_model_and_optimizer( kwargs[f.name] = getattr(args, f.name) config = OptimizerConfig(**kwargs) config.timers = None + optimizer = get_megatron_optimizer( config=config, model_chunks=model, diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py index f6d1c28f30..6707993822 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py @@ -139,8 +139,10 @@ def _update_expert_bucket_weights( def _pause_and_prepare_engines(self) -> None: """Pause rollout engines, flush cache, and run pre-process if needed.""" if dist.get_rank() == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + mode = self.args.pause_generation_mode + ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) + if mode not in ("in_place"): + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) # int4/fp4 pre_process if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: @@ -150,19 +152,18 @@ def _pause_and_prepare_engines(self) -> None: post_process_quantization=False, ) - def _finalize_and_resume_engines(self) -> None: + def _finalize_and_resume_engines(self, post_load_weights: bool = False) -> None: """Run post-process if needed and resume rollout engines.""" if dist.get_rank() == 0: - # int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales). - if self.quantization_config and self.quantization_config["quant_method"] in [ - "compressed-tensors", - "mxfp8", - ]: - post_process_weights( - rollout_engines=self.rollout_engines, - restore_weights_before_load=False, - post_process_quantization=True, - ) + # post_process_quantization is related to the process_weights_after_loading + # in the sglang rollout side, which should always be invoked after weight + # updating. + post_process_weights( + rollout_engines=self.rollout_engines, + restore_weights_before_load=False, + post_process_quantization=True, + post_load_weights=post_load_weights, + ) ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) @torch.no_grad() diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py index 7548fc2c9c..ba6cebde7e 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py @@ -11,6 +11,9 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed.parallel_state import ParallelismContext, RankParallelismConfig +from sglang.srt.layers.moe import initialize_moe_config +from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config +from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config from sglang.srt.model_loader import get_model from sglang.srt.model_loader.parameter_mapper import ParameterMapper from sglang.srt.server_args import ServerArgs @@ -18,7 +21,6 @@ from miles.utils.distributed_utils import get_gloo_group -from ..common import post_process_weights from .mixin import DistBucketedWeightUpdateMixin from .p2p_transfer_utils import ( P2PTransferManager, @@ -125,12 +127,7 @@ def _finalize_and_resume_engines(self): for engine in self.rollout_engines ] ) - post_process_weights( - rollout_engines=self.rollout_engines, - post_process_quantization=True, - post_load_weights=True, - ) - super()._finalize_and_resume_engines() + super()._finalize_and_resume_engines(post_load_weights=True) def _update_weight_implementation( self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None @@ -265,6 +262,9 @@ def create_cpu_replica( rl_quant_profile=server_args.rl_quant_profile, ) server_args_module._global_server_args = server_args + initialize_moe_config(server_args) + initialize_fp8_gemm_config(server_args) + initialize_fp4_gemm_config(server_args) with ParallelismContext(parallelism_config): model = get_model( model_config=ModelConfig(model_path), diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 10a75fd4f3..86073d5ab3 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -176,7 +176,8 @@ def update_weights(self) -> None: rank = dist.get_rank() if rank == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) + mode = self.args.pause_generation_mode + ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: post_process_weights( @@ -205,17 +206,15 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) - # int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales). if rank == 0: - if self.quantization_config and self.quantization_config["quant_method"] in [ - "compressed-tensors", - "mxfp8", - ]: - post_process_weights( - rollout_engines=self.rollout_engines, - restore_weights_before_load=False, - post_process_quantization=True, - ) + # `post_process_quantization` is related to the `process_weights_after_loading` + # in the sglang rollout side, which should always be invoked after weight + # updating. + post_process_weights( + rollout_engines=self.rollout_engines, + restore_weights_before_load=False, + post_process_quantization=True, + ) ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) diff --git a/miles/backends/sglang_utils/arguments.py b/miles/backends/sglang_utils/arguments.py index d8ac2deacc..f4b9978a40 100644 --- a/miles/backends/sglang_utils/arguments.py +++ b/miles/backends/sglang_utils/arguments.py @@ -19,6 +19,12 @@ def add_sglang_router_arguments(parser): default=None, help="Port of the SGLang router", ) + parser.add_argument( + "--sglang-router-policy", + type=str, + default=None, + help="Routing policy for the SGLang router (e.g., 'consistent_hashing', 'round_robin')", + ) parser.add_argument( "--sglang-router-request-timeout-secs", type=int, @@ -135,5 +141,12 @@ def validate_args(args): if args.sglang_dp_size > 1: assert args.sglang_enable_dp_attention + if args.sglang_router_policy: + from miles.utils.environ import enable_experimental_rollout_refactor + + assert ( + not enable_experimental_rollout_refactor() + ), "--sglang-router-policy is not supported with MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1" + if getattr(args, "sglang_router_ip", None): args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip) diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 8b567a744d..1c215c7b62 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -501,8 +501,11 @@ def update_weights_from_distributed( payload, ) - def pause_generation(self): - response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) + def pause_generation(self, mode: str = "retract"): + response = requests.post( + f"http://{self.server_host}:{self.server_port}/pause_generation", + json={"mode": mode}, + ) response.raise_for_status() return response diff --git a/miles/backends/training_utils/cp_utils.py b/miles/backends/training_utils/cp_utils.py index 0fbba35b02..e79ccb4469 100644 --- a/miles/backends/training_utils/cp_utils.py +++ b/miles/backends/training_utils/cp_utils.py @@ -1,11 +1,20 @@ +import logging from collections.abc import Callable import torch import torch.distributed as dist +import torch.nn as nn import torch.nn.functional as F from .parallel import get_parallel_state +try: + from fla.ops.cp import build_cp_context as _fla_build_cp_context +except ImportError: + _fla_build_cp_context = None + +logger = logging.getLogger(__name__) + def get_logits_and_tokens_offset_with_cp( total_length: int, @@ -336,3 +345,30 @@ def slice_log_prob_with_cp( return chunk_1 + chunk_2 else: return torch.cat([chunk_1, chunk_2], dim=0) + + +def build_gdn_cp_context(module: nn.Module, cu_seqlens: torch.Tensor, device: torch.device): + """Build fla CP context for a GatedDeltaNet module from packed sequence boundaries. + + Args: + module: GDN module with ``cp_group`` / ``cp_world_size`` / ``conv_kernel_size``. + cu_seqlens: Global packed sequence boundaries (e.g. ``packed_seq_params.cu_seqlens_q``). + device: Target device. + + Returns ``None`` when CP is not configured on the module (``cp_group`` not set). + Raises ``RuntimeError`` if hybrid CP is configured but ``fla.ops.cp`` is missing. + """ + cp_group = getattr(module, "cp_group", None) + if cp_group is None: + return None + if _fla_build_cp_context is None: + raise RuntimeError( + "Hybrid CP requires fla.ops.cp (flash-linear-attention >= 0.4.2) " "but it could not be imported." + ) + if cu_seqlens is None or cu_seqlens.numel() < 2: + raise ValueError(f"Hybrid CP requires valid cu_seqlens (at least 2 elements) but got {cu_seqlens}") + return _fla_build_cp_context( + cu_seqlens=cu_seqlens.to(device=device, dtype=torch.int32), + group=cp_group, + conv1d_kernel_size=module.conv_kernel_size, + ) diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index c6ed8b6ffb..54aa44a6d8 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -121,6 +121,8 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc "rollout_routed_experts", "max_seq_lens", "dynamic_global_batch_size", + "weight_versions", + "metadata", ]: continue # Upload per sample mean for each rollout value @@ -151,7 +153,14 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc else: val = val.mean() * cp_size else: - val = sum(val) / len(val) + # Flatten nested lists (e.g. list of lists from async rollout) + flat = val + if isinstance(val[0], (list, tuple)): + flat = [x for sublist in val for x in sublist] + # Skip non-numeric values (e.g. strings from async rollout metadata) + if flat and not isinstance(flat[0], (int, float)): + continue + val = sum(flat) / len(flat) elif isinstance(val, torch.Tensor): val = val.float().mean() else: diff --git a/miles/ray/actor_group.py b/miles/ray/actor_group.py index 54228f4228..669988b52c 100644 --- a/miles/ray/actor_group.py +++ b/miles/ray/actor_group.py @@ -81,7 +81,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): actor_impl = MegatronTrainRayActor else: - from miles.backends.fsdp_utils import FSDPTrainRayActor + from miles.backends.experimental.fsdp_utils import FSDPTrainRayActor actor_impl = FSDPTrainRayActor diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 2a75d492b9..1aa2e91fe6 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -5,6 +5,7 @@ import os import random import time +import uuid from pathlib import Path from typing import Any @@ -723,6 +724,9 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if any(sample.multimodal_train_inputs is not None for sample in samples): train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] + if any(sample.weight_versions for sample in samples): + train_data["weight_versions"] = [sample.weight_versions for sample in samples] + if "teacher_log_probs" in samples[0].__dict__: train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples] @@ -770,6 +774,7 @@ def _split_train_data_by_dp(self, data, dp_size): "rollout_routed_experts", "prompt", "teacher_log_probs", + "weight_versions", ]: if key not in data: continue @@ -941,6 +946,9 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool router_args.log_level = "warn" router_args.request_timeout_secs = args.sglang_router_request_timeout_secs + if args.sglang_router_policy: + router_args.policy = args.sglang_router_policy + if has_pd_disaggregation: router_args.pd_disaggregation = True @@ -1114,6 +1122,8 @@ def _start_session_server(args): args.session_server_ip = args.sglang_router_ip if getattr(args, "session_server_port", None) is None: args.session_server_port = find_available_port(random.randint(5000, 6000)) + if getattr(args, "session_server_instance_id", None) is None: + args.session_server_instance_id = uuid.uuid4().hex ip, port = args.session_server_ip, args.session_server_port if not is_port_available(port): @@ -1196,6 +1206,12 @@ def compute_metrics_from_samples(args, samples): log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item() log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item() + oldest_versions = [s.oldest_weight_version for s in samples if s.oldest_weight_version is not None] + if oldest_versions: + log_dict |= dict_add_prefix(compute_statistics(oldest_versions), "weight_version/") + mixed = sum(1 for s in samples if len(set(s.weight_versions)) > 1) + log_dict["weight_version/mixed_version_ratio"] = mixed / len(samples) + tito_vals = [s.metadata.get("tito_session_mismatch") for s in samples] tito_vals = [v for v in tito_vals if v is not None] if tito_vals: diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py index cfe8a232d2..feba568164 100644 --- a/miles/rollout/generate_hub/agentic_tool_call.py +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -61,8 +61,16 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: metadata = input.sample.metadata if max_seq_len is not None: metadata = {**metadata, "max_seq_len": max_seq_len} + if tracer.session_server_instance_id: + metadata = {**metadata, "session_server_instance_id": tracer.session_server_instance_id} log_prefix = f"[session={tracer.session_id}]" + + session_ip = getattr(input.args, "session_server_ip", None) + session_port = getattr(input.args, "session_server_port", None) + if session_ip and session_port: + metadata = {**metadata, "session_server_id": f"{session_ip}:{session_port}"} + agent_metadata = None t_start = time.monotonic() try: diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 5b9a445adf..2b3016fd4f 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -18,10 +18,11 @@ class OpenAIEndpointTracer: - def __init__(self, router_url: str, session_id: str): + def __init__(self, router_url: str, session_id: str, session_server_instance_id: str | None = None): self.router_url = router_url self.session_id = session_id self.base_url = f"{router_url}/sessions/{session_id}" + self.session_server_instance_id = session_server_instance_id @staticmethod async def create(args: Namespace): @@ -33,9 +34,22 @@ async def create(args: Namespace): "Pass --use-session-server to start the session server." ) session_url = f"http://{session_ip}:{session_port}" + session_server_instance_id = None + try: + health = await post(f"{session_url}/health", {}, action="get") + if isinstance(health, dict): + session_server_instance_id = health.get("session_server_instance_id") + if session_server_instance_id is not None: + args.session_server_instance_id = session_server_instance_id + except Exception as e: + logger.warning("Failed to get session server health from %s: %s", session_url, e) response = await post(f"{session_url}/sessions", {}, action="post") session_id = response["session_id"] - return OpenAIEndpointTracer(router_url=session_url, session_id=session_id) + return OpenAIEndpointTracer( + router_url=session_url, + session_id=session_id, + session_server_instance_id=session_server_instance_id, + ) async def collect_records(self) -> tuple[list[SessionRecord], dict]: try: @@ -185,6 +199,10 @@ def _compute_sample_from_openai_record( case "abort": sample.status = Sample.Status.ABORTED + sample.prefix_cache_info.add(choice.get("meta_info", {})) + if "weight_version" in choice["meta_info"]: + sample.weight_versions.append(choice["meta_info"]["weight_version"]) + return sample diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py index 8e8f42441e..55c1cece1f 100644 --- a/miles/rollout/generate_utils/sample_utils.py +++ b/miles/rollout/generate_utils/sample_utils.py @@ -66,6 +66,7 @@ def _fill_defaults(sample: Sample): metadata=_merge_equal_value("metadata"), generate_function_path=_merge_equal_value("generate_function_path"), train_metadata=_merge_equal_value("train_metadata"), + session_id=_merge_equal_value("session_id"), non_generation_time=_merge_equal_value("non_generation_time"), spec_info=_merge_spec_info(a.spec_info, b.spec_info), prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), diff --git a/miles/rollout/session/session_server.py b/miles/rollout/session/session_server.py index 0377117bdf..bc2633350e 100644 --- a/miles/rollout/session/session_server.py +++ b/miles/rollout/session/session_server.py @@ -36,7 +36,7 @@ def __init__(self, args, backend_url: str): ) # Close the httpx connection pool when uvicorn shuts down to avoid FD leaks. - self.app.add_event_handler("shutdown", self.client.aclose) + self.app.router.on_shutdown.append(self.client.aclose) setup_session_routes(self.app, self, args) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index bf53f446f4..172e906074 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -26,6 +26,8 @@ def setup_session_routes(app, backend, args): logger.info("[session] Skipping session routes (hf_checkpoint not set).") return + session_server_instance_id = getattr(args, "session_server_instance_id", None) + tokenizer = load_tokenizer( hf_checkpoint, chat_template_path=getattr(args, "chat_template_path", None), trust_remote_code=True ) @@ -38,6 +40,13 @@ def setup_session_routes(app, backend, args): registry = SessionRegistry(args, tokenizer, tito_tokenizer=tito_tokenizer) + @app.get("/health") + async def health(): + body = {"status": "ok"} + if session_server_instance_id is not None: + body["session_server_instance_id"] = session_server_instance_id + return body + # --- DEBUG: track in-flight chat_completions --- _inflight_chat = {"count": 0} diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 6ffc67e70b..551021f53c 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -1,8 +1,7 @@ import asyncio import copy -import inspect import logging - +import uuid from argparse import Namespace from collections.abc import Callable from contextlib import contextmanager @@ -15,8 +14,9 @@ from tqdm import tqdm from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled -from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput +from miles.rollout.base_types import GenerateFnInput, RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter +from miles.rollout.inference_rollout.compatibility import load_generate_function from miles.utils import dumper_utils from miles.utils.async_utils import run from miles.utils.data import Dataset @@ -184,7 +184,12 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A if not sample.tokens: # Initialize sample.tokens for the first turn sample.tokens = prompt_ids - output = await post(url, payload) + # Use session_id for consistent hashing routing if router uses consistent_hashing policy + headers = None + if args.sglang_router_policy == "consistent_hashing" and sample.session_id: + headers = {"X-SMG-Routing-Key": sample.session_id} + + output = await post(url, payload, headers=headers) if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree @@ -255,13 +260,12 @@ async def generate_and_rm( # Check sample.generate_function_path for per-sample custom_generate_function_path (e.g., from eval dataset config) custom_func_path = getattr(sample, "generate_function_path", None) or args.custom_generate_function_path - if custom_func_path is not None: - custom_generate_func = load_function(custom_func_path) - # if signature has evaluation, pass evaluation - if "evaluation" in inspect.signature(custom_generate_func).parameters: - sample = await custom_generate_func(args, sample, sampling_params, evaluation=evaluation) - else: - sample = await custom_generate_func(args, sample, sampling_params) + generate_fn = load_generate_function(custom_func_path) if custom_func_path else None + if generate_fn is not None: + output = await generate_fn( + GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params, evaluation=evaluation) + ) + sample = output.samples else: sample = await generate(args, sample, sampling_params) @@ -299,6 +303,12 @@ async def generate_and_rm_group( if state.aborted: return group + # Generate a unique session_id for each sample in the group (consistent hashing only) + if args.sglang_router_policy == "consistent_hashing": + for sample in group: + if sample.session_id is None: + sample.session_id = str(uuid.uuid4()) + tasks = [] for idx, sample in enumerate(group): current_sampling_params = sampling_params.copy() diff --git a/miles/router/router.py b/miles/router/router.py index 09be44b033..51194be4cf 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -36,7 +36,7 @@ def __init__(self, args, verbose=False): self.verbose = verbose self.app = FastAPI() - self.app.add_event_handler("startup", self._start_background_health_check) + self.app.router.on_startup.append(self._start_background_health_check) # URL -> Active Request Count (load state) self.worker_request_counts: dict[str, int] = {} diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 51df300818..375cd6c2c2 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -385,6 +385,17 @@ def add_rollout_arguments(parser): "If set, only on-policy generated tokens will be used in training" ), ) + parser.add_argument( + "--max-weight-staleness", + type=int, + default=None, + help=( + "Maximum allowed gap between a group's oldest weight version and the current " + "engine weight version. Groups exceeding this threshold are recycled back to " + "the data buffer instead of being sent to training. Only effective in fully " + "async mode. None (default) disables staleness filtering." + ), + ) parser.add_argument( "--custom-generate-function-path", type=str, @@ -441,6 +452,19 @@ def add_rollout_arguments(parser): default=1, help="Interval for updating the weights", ) + parser.add_argument( + "--pause-generation-mode", + type=str, + choices=["abort", "retract", "in_place"], + default="retract", + help=( + "How SGLang pauses in-flight requests during weight updates. " + "'abort' immediately terminates all requests (previous default). " + "'retract' moves running requests back to the waiting queue and " + "recomputes KV cache after update. " + "'in_place' freezes requests and resumes with existing KV cache." + ), + ) parser.add_argument( "--keep-old-actor", action="store_true", @@ -1685,7 +1709,7 @@ def parse_args(add_custom_arguments=None): args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node args = set_default_megatron_args(args) else: - from miles.backends.fsdp_utils.arguments import load_fsdp_args + from miles.backends.experimental.fsdp_utils.arguments import load_fsdp_args args = load_fsdp_args(extra_args_provider=add_miles_arguments) args.rank = 0 # Primary process rank for wandb initialization @@ -1693,6 +1717,13 @@ def parse_args(add_custom_arguments=None): assert args.context_parallel_size == 1, "Context parallelism is not supported for FSDP backend." + if not args.ci_test: + raise ValueError( + "The FSDP backend has known issues with SGLang v0.5.10 and is not actively maintained in the current version. " + "It has been moved to miles.backends.experimental. " + "Contributions are welcome if you are interested in improving it." + ) + miles_validate_args(args) if backend == "megatron": @@ -1948,6 +1979,14 @@ def miles_validate_args(args): args.offload_train = True if args.offload_rollout is None: args.offload_rollout = True + if args.sglang_enforce_piecewise_cuda_graph: + logger.warning("Warning: colocate mode with --sglang-enforce-piecewise-cuda-graph may trigger NVLS OOM.") + if not args.sglang_disable_piecewise_cuda_graph: + args.sglang_disable_piecewise_cuda_graph = True + logger.info( + "Colocate mode: defaulting --sglang-disable-piecewise-cuda-graph to avoid NVLS OOM. " + "Use --sglang-enforce-piecewise-cuda-graph to override." + ) if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes: logger.info( f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} " @@ -2101,6 +2140,9 @@ def equal(x, y): ), ("rope_theta", "rotary_base", equal), ]: + # FIXME: Qwen3.5 transfomers has bug. + if getattr(hf_config, "model_type", "") == "qwen3_5_moe_text" and hf_config_name == "intermediate_size": + continue if hasattr(hf_config, hf_config_name): if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): errors.append( diff --git a/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja b/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja index 7c06122223..07d0cdadbf 100644 --- a/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja +++ b/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja @@ -75,14 +75,11 @@ {%- endif %} {%- endif %} {%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif %} {%- for message in messages %} {%- set content = render_content(message.content, true)|trim %} {%- if message.role == "system" %} {%- if not loop.first %} - {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {{- raise_exception('System message must be at the beginning.') }} {%- endif %} {%- elif message.role == "user" %} {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} @@ -151,4 +148,4 @@ {%- else %} {{- '\n' }} {%- endif %} -{%- endif %} \ No newline at end of file +{%- endif %} diff --git a/miles/utils/chat_template_utils/tito_tokenizer.py b/miles/utils/chat_template_utils/tito_tokenizer.py index ce02bec5f1..f279570919 100644 --- a/miles/utils/chat_template_utils/tito_tokenizer.py +++ b/miles/utils/chat_template_utils/tito_tokenizer.py @@ -1,15 +1,20 @@ """TITO tokenizer — incremental tokenization for pretokenized prefix reuse. ``TITOTokenizer`` computes incremental token IDs for non-assistant messages -(tool responses, system injections) that follow the assistant's generated -token sequence, then merges them with the pretokenized prefix — handling -model-specific boundary tokens at the junction. - -The default implementation uses a dummy-message diff: it tokenizes a -synthetic ``[dummy_user, dummy_assistant]`` base with and without the -appended messages, then takes the suffix difference as the incremental IDs. -Model-specific subclasses override ``merge_tokens`` to handle boundary -quirks at the junction. +(tool responses, user follow-ups, system injections) that follow the +assistant's generated token sequence, then merges them with the pretokenized +prefix — handling model-specific boundary tokens at the junction. + +The default implementation incrementally tokenizes appended non-assistant turns +with role-specific synthetic prefixes: + +- contiguous ``tool`` runs use ``[dummy_system, dummy_assistant]`` +- each ``user`` or ``system`` message uses ``[dummy_system]`` + +The appended suffix is processed left-to-right, then the generation prompt for +the next assistant turn is appended once at the end. Model-specific +subclasses only override ``merge_tokens`` for boundary quirks at the prefix +junction. """ from __future__ import annotations @@ -20,7 +25,8 @@ from miles.utils.chat_template_utils.template import apply_chat_template, assert_messages_append_only_with_allowed_role from miles.utils.chat_template_utils.token_seq_comparator import TokenSeqComparator -_DUMMY_USER: dict[str, Any] = {"role": "user", "content": "dummy"} +_DUMMY_SYSTEM: dict[str, Any] = {"role": "system", "content": "dummy system"} +_DUMMY_USER: dict[str, Any] = {"role": "user", "content": "dummy user"} def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: @@ -45,23 +51,12 @@ def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, An # --------------------------------------------------------------------------- -# Base / default tokenizer (dummy-prefix diff) +# Base / default tokenizer # --------------------------------------------------------------------------- class TITOTokenizer: - """Incremental tokenization and prefix merging using dummy-message diff. - - A synthetic base ``[dummy_user, dummy_assistant]`` simulates the assistant - turn boundary so that the diff captures the correct turn-transition tokens: - - 1. ``tokens_without`` = tokenize(base, add_generation_prompt=False) - 2. ``tokens_with`` = tokenize(base + appended, add_generation_prompt=True) - 3. ``incremental_ids = tokens_with[len(tokens_without):]`` - - Subclasses override ``merge_tokens`` to handle model-specific boundary - token quirks. - """ + """Incremental tokenization and prefix merging for appended non-assistant turns.""" max_trim_tokens: int = 0 trailing_token_ids: frozenset[int] = frozenset() @@ -87,20 +82,138 @@ def create_comparator(self) -> TokenSeqComparator: trim_trailing_ids=self.trailing_token_ids or None, ) + def _render_messages( + self, + messages: list[dict[str, Any]], + *, + add_generation_prompt: bool, + tools: list[dict[str, Any]] | None = None, + ) -> str: + return apply_chat_template( + messages, + tokenizer=self.tokenizer, + tokenize=False, + add_generation_prompt=add_generation_prompt, + tools=tools, + **self.chat_template_kwargs, + ) + + def _encode_text(self, text: str) -> list[int]: + return self.tokenizer.encode(text, add_special_tokens=False) + + def _split_appended_segments(self, appended_messages: list[dict[str, Any]]) -> list[list[dict[str, Any]]]: + segments: list[list[dict[str, Any]]] = [] + i = 0 + while i < len(appended_messages): + role = appended_messages[i]["role"] + # Many templates wrap a contiguous tool-response run as one logical + # block, so tool messages are diffed together instead of one-by-one. + if role == "tool": + j = i + 1 + while j < len(appended_messages) and appended_messages[j]["role"] == "tool": + j += 1 + segments.append(appended_messages[i:j]) + i = j + continue + if role in {"user", "system", "assistant"}: + # Assistant-role appends arise when the agent layer inserts a + # non-generated assistant turn between tool calls (e.g., + # terminus-2 self-reflection / planning turns). These are not + # produced by the inference engine, so they need to be + # incrementally tokenized the same way user/system single + # messages are. + segments.append([appended_messages[i]]) + i += 1 + continue + raise ValueError(f"unsupported appended role for TITO segmentation: {role}") + + return segments + + def _tokenize_rendered_suffix( + self, + base_messages: list[dict[str, Any]], + appended_messages: list[dict[str, Any]], + *, + tools: list[dict[str, Any]] | None = None, + add_generation_prompt: bool = False, + ) -> list[int]: + """Render *base_messages* and *base_messages + appended_messages*, return + tokens for the suffix. + + When *add_generation_prompt* is True and *appended_messages* is empty, + this computes the generation-prompt suffix (the assistant opener tokens). + """ + text_without = self._render_messages(base_messages, add_generation_prompt=False, tools=tools) + text_with = self._render_messages( + base_messages + appended_messages, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + if not text_with.startswith(text_without): + roles = [msg["role"] for msg in appended_messages] if appended_messages else ["generation_prompt"] + raise ValueError(f"rendered suffix diff failed for {roles}") + return self._encode_text(text_with[len(text_without) :]) + + def _tokenize_tool_segment( + self, + appended_messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + # No dummy user to avoid cut think issues. + return self._tokenize_rendered_suffix( + [_DUMMY_SYSTEM, _build_dummy_assistant(appended_messages)], + appended_messages, + tools=tools, + ) + + def _tokenize_user_and_system_segment( + self, + appended_message: dict[str, Any], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + # User/system single-message appends share one synthetic context. + return self._tokenize_rendered_suffix( + [_DUMMY_SYSTEM], + [appended_message], + tools=tools, + ) + + def _tokenize_assistant_segment( + self, + appended_message: dict[str, Any], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + # Assistant-role single-message appends use [dummy_system, dummy_user] + # as the synthetic context so chat templates render the assistant-turn + # boundary tokens correctly (most templates require an assistant turn + # to follow a user or tool turn, never system directly). + return self._tokenize_rendered_suffix( + [_DUMMY_SYSTEM, _DUMMY_USER], + [appended_message], + tools=tools, + ) + def tokenize_additional_non_assistant( self, old_messages: list[dict[str, Any]], new_messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, ) -> list[int]: - """Compute incremental token IDs for non-assistant messages appended - after the pretokenized prefix. - - Only handles tool responses, system injections, etc. — never an - assistant message. Validates that *new_messages* is an append-only - extension of *old_messages* via + """Compute incremental token IDs for messages appended after the + pretokenized prefix. + + Handles tool responses, user, system, and agent-layer-inserted + assistant messages (e.g. terminus-2 self-reflection turns between + tool calls). Assistant messages produced by the inference engine are + NOT passed through this path — they arrive as pretokenized IDs and + are merged directly. Validates that *new_messages* is an + append-only extension of *old_messages* via ``assert_messages_append_only_with_allowed_role``. + The name is retained for backward compatibility; the semantics are + "tokenize appended messages that were not generated by the inference + engine." Agent-layer assistant turns fall into this category. + Args: old_messages: Previously stored messages (prefix). new_messages: Full new message list (must be a superset of @@ -114,29 +227,31 @@ def tokenize_additional_non_assistant( """ assert_messages_append_only_with_allowed_role(old_messages, new_messages, self.allowed_append_roles) appended_messages = new_messages[len(old_messages) :] - - dummy_assistant = _build_dummy_assistant(appended_messages) - base_messages = [_DUMMY_USER, dummy_assistant] - - tokens_without = apply_chat_template( - base_messages, - tokenizer=self.tokenizer, - tokenize=True, - add_generation_prompt=False, + incremental: list[int] = [] + + # Incremental appended content is assembled segment-by-segment using + # the smallest synthetic context that preserves each role's boundary + # tokens. + for segment in self._split_appended_segments(appended_messages): + role = segment[0]["role"] + if role == "tool": + incremental.extend(self._tokenize_tool_segment(segment, tools)) + elif role == "user" or role == "system": + incremental.extend(self._tokenize_user_and_system_segment(segment[0], tools)) + elif role == "assistant": + incremental.extend(self._tokenize_assistant_segment(segment[0], tools)) + else: + raise ValueError(f"unsupported appended role for TITO tokenization: {role}") + + # The next assistant opener depends on the full post-append history, so + # it is derived from the real messages once and appended only at the end. + return incremental + self._tokenize_rendered_suffix( + new_messages, + [], tools=tools, - **self.chat_template_kwargs, - ) - tokens_with = apply_chat_template( - base_messages + list(appended_messages), - tokenizer=self.tokenizer, - tokenize=True, add_generation_prompt=True, - tools=tools, - **self.chat_template_kwargs, ) - return list(tokens_with[len(tokens_without) :]) - def merge_tokens( self, old_messages: list[dict[str, Any]], diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 1e673d1c2d..0aaf792659 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -185,15 +185,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60, action="post"): +async def _post(client, url, payload, max_retries=60, action="post", headers=None): retry_count = 0 while retry_count < max_retries: try: if action in ("delete", "get"): assert not payload - response = await getattr(client, action)(url) + response = await getattr(client, action)(url, headers=headers) else: - response = await getattr(client, action)(url, json=payload or {}) + response = await getattr(client, action)(url, json=payload or {}, headers=headers) response.raise_for_status() try: output = response.json() @@ -267,8 +267,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60, action="post"): - return await _post(self._client, url, payload, max_retries, action=action) + async def do_post(self, url, payload, max_retries=60, action="post", headers=None): + return await _post(self._client, url, payload, max_retries, action=action, headers=headers) # Create actors per node created = [] @@ -293,22 +293,18 @@ async def do_post(self, url, payload, max_retries=60, action="post"): # TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) -async def post(url, payload, max_retries=60, action="post"): +async def post(url, payload, max_retries=60, action="post", headers=None): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: - import ray - actor = _next_actor() if actor is not None: - # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) - return await asyncio.to_thread(ray.get, obj_ref) + return await actor.do_post.remote(url, payload, max_retries, action=action, headers=headers) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, action=action) + return await _post(_http_client, url, payload, max_retries, action=action, headers=headers) # TODO unify w/ `post` to add retries and remote-execution diff --git a/miles/utils/mask_utils.py b/miles/utils/mask_utils.py index 0ddb3a1410..cd2d72eb4c 100644 --- a/miles/utils/mask_utils.py +++ b/miles/utils/mask_utils.py @@ -38,7 +38,7 @@ def get_system_message_length(self) -> tuple[int, int]: end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2 gen_token_length = len( self.tokenizer.apply_chat_template( - test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True + test_messages, add_special_tokens=False, tokenize=True, return_dict=False, add_generation_prompt=True ) ) - len(chat_template_token_ids) @@ -53,9 +53,11 @@ def gen_multi_turn_loss_mask_qwen( for i, message in enumerate(messages): if i == 0: - message_ids = self.tokenizer.apply_chat_template([message], tokenize=True, tools=tools) + message_ids = self.tokenizer.apply_chat_template( + [message], tokenize=True, return_dict=False, tools=tools + ) else: - message_ids = self.tokenizer.apply_chat_template([message], tokenize=True) + message_ids = self.tokenizer.apply_chat_template([message], tokenize=True, return_dict=False) if message["role"] != "system" and i > 0: message_ids = message_ids[self.system_message_length :] @@ -80,16 +82,18 @@ def gen_multi_turn_loss_mask_qwen3( all_token_ids = [] prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"} - prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True) + prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True, return_dict=False) for i, message in enumerate(messages): if i == 0: tailed_message_ids = self.tokenizer.apply_chat_template( - [message, prefix_message], tokenize=True, tools=tools + [message, prefix_message], tokenize=True, return_dict=False, tools=tools ) message_ids = tailed_message_ids[: -len(prefix_token_ids)] else: - prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True) + prefixed_message_ids = self.tokenizer.apply_chat_template( + [prefix_message, message], tokenize=True, return_dict=False + ) message_ids = prefixed_message_ids[len(prefix_token_ids) :] if message["role"] != "system" and i > 0: diff --git a/miles/utils/megatron_bridge_utils.py b/miles/utils/megatron_bridge_utils.py index 9e5f065cd4..3836240f07 100644 --- a/miles/utils/megatron_bridge_utils.py +++ b/miles/utils/megatron_bridge_utils.py @@ -15,6 +15,13 @@ def patch_megatron_model(model): model_config.share_embeddings_and_output_weights = unwrapped_model.share_embeddings_and_output_weights attribute_was_added = True + # Float16Module casts buffers to bf16, but expert_bias must stay fp32. + # Restore before bridge export reads the values. + for m in model: + for module in m.modules(): + if hasattr(module, "_maintain_float32_expert_bias"): + module._maintain_float32_expert_bias() + try: yield finally: diff --git a/miles/utils/processing_utils.py b/miles/utils/processing_utils.py index 75fd2fb75e..9d12eed8b3 100644 --- a/miles/utils/processing_utils.py +++ b/miles/utils/processing_utils.py @@ -13,7 +13,26 @@ DEFAULT_PATCH_SIZE = 14 -def load_tokenizer(name_or_path: str, chat_template_path: str = None, **kwargs): +_TOKENIZER_CACHE: dict[tuple, PreTrainedTokenizerBase] = {} + + +def _make_cache_key(name_or_path: str, chat_template_path: str | None, kwargs: dict) -> tuple | None: + try: + kwargs_items = tuple(sorted(kwargs.items())) + hash(kwargs_items) + except TypeError: + return None + return (name_or_path, chat_template_path, kwargs_items) + + +def load_tokenizer(name_or_path: str, chat_template_path: str | None = None, **kwargs) -> PreTrainedTokenizerBase: + # Cache keyed by (name, chat_template_path, kwargs) — the fast suite creates + # hundreds of SessionServer / MockSGLangServer fixtures and each previously + # triggered a fresh AutoTokenizer.from_pretrained, tripping HF Hub rate limits. + cache_key = _make_cache_key(name_or_path, chat_template_path, kwargs) + if cache_key is not None and cache_key in _TOKENIZER_CACHE: + return _TOKENIZER_CACHE[cache_key] + tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) if chat_template_path: assert os.path.isfile(chat_template_path), ( @@ -23,22 +42,21 @@ def load_tokenizer(name_or_path: str, chat_template_path: str = None, **kwargs): with open(chat_template_path) as f: tokenizer.chat_template = f.read() logger.info("Loaded custom chat template from %s", chat_template_path) + + if cache_key is not None: + _TOKENIZER_CACHE[cache_key] = tokenizer return tokenizer def build_processor_kwargs(multimodal_inputs: dict | None = None) -> dict: - forced = { - # force return_tensors to None for input_ids - "return_tensors": None, - } modality_forced = {"return_tensors": "pt"} result = dict(multimodal_inputs) if multimodal_inputs else {} - result.update(forced) - - # set return_tensors="pt" for modality-specific outputs + # return_tensors=None for text (input_ids), "pt" for modality-specific outputs. + # Use per-modality dicts to avoid transformers >=5.0 duplicate kwarg error. + result["text_kwargs"] = {**result.get("text_kwargs", {}), "return_tensors": None} for key in ("audio_kwargs", "images_kwargs", "videos_kwargs"): if key in result: result[key] = {**result[key], **modality_forced} diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py index 4da7b68a0c..26a757a2d9 100644 --- a/miles/utils/test_utils/mock_tools.py +++ b/miles/utils/test_utils/mock_tools.py @@ -1,9 +1,9 @@ import json +from collections.abc import Callable from copy import deepcopy from typing import Any -from transformers import AutoTokenizer - +from miles.utils.processing_utils import load_tokenizer from miles.utils.test_utils.mock_sglang_server import ProcessResult AGENTIC_MAX_TURNS: int | None = None @@ -59,7 +59,7 @@ async def execute_tool_call(name: str, params: dict) -> str: return TOOL_EXECUTORS[name](params) -AGENTIC_RETURN_METADATA: dict[str, Any] | None = None +AGENTIC_RETURN_METADATA: dict[str, Any] | Callable | None = None async def run_agentic_tool_call( @@ -112,6 +112,8 @@ async def run_agentic_tool_call( } ) + if callable(AGENTIC_RETURN_METADATA): + return AGENTIC_RETURN_METADATA(metadata=kwargs.get("metadata")) return AGENTIC_RETURN_METADATA @@ -139,7 +141,7 @@ async def run_agentic_noop(**kwargs) -> None: ) -_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) +_TOKENIZER = load_tokenizer("Qwen/Qwen3-0.6B", trust_remote_code=True) class TwoTurnStub: diff --git a/miles/utils/types.py b/miles/utils/types.py index b36f08aecc..4d7d6ef9b2 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -47,6 +47,10 @@ class Status(Enum): # metadata used during training, e.g., what loss to use for this sample. train_metadata: dict | None = None + # Session ID for consistent hashing routing (used when router policy is consistent_hashing) + # TODO: Its definition needs to merge with the session server's session id in the new rollout function. + session_id: str | None = None + non_generation_time: float = 0.0 # time spent in non-generation steps @dataclass @@ -184,6 +188,35 @@ def strip_last_output_tokens(self, n: int, tokenizer) -> None: if self.rollout_routed_experts is not None: self.rollout_routed_experts = self.rollout_routed_experts[:-n] + def reset_for_retry(self) -> None: + """Reset generated outputs so the original prompt can be re-sampled. + + Keeps identity / prompt fields (group_index, index, prompt, label, + multimodal_inputs, metadata, generate_function_path, session_id) and + restores everything else to dataclass defaults. + """ + self.tokens = [] + self.multimodal_train_inputs = None + self.response = "" + self.response_length = 0 + self.reward = None + self.loss_mask = None + self.weight_versions = [] + self.rollout_log_probs = None + self.rollout_routed_experts = None + self.status = Sample.Status.ABORTED + self.non_generation_time = 0.0 + self.spec_info = Sample.SpecInfo() + self.prefix_cache_info = Sample.PrefixCacheInfo() + self.remove_sample = False + self.train_metadata = None + + @property + def oldest_weight_version(self) -> int | None: + """Minimum weight version across all turns (generation calls) for this trajectory.""" + numeric = [int(v) for v in self.weight_versions if str(v).isdigit()] + return min(numeric) if numeric else None + def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. diff --git a/miles_plugins/mbridge/qwen3_5.py b/miles_plugins/mbridge/qwen3_5.py index ee629d009f..8da5b7204b 100644 --- a/miles_plugins/mbridge/qwen3_5.py +++ b/miles_plugins/mbridge/qwen3_5.py @@ -254,6 +254,12 @@ def _convert_mtp_param(self, name: str) -> list[str]: def _weight_to_mcore_format( self, mcore_weights_name: str, hf_weights: list[torch.Tensor] ) -> tuple[list[str], list[torch.Tensor]]: + if mcore_weights_name.endswith("self_attention.linear_attn.A_log"): + assert len(hf_weights) == 1 + # Keep A_log in fp32 before TP scatter; this avoids precision loss + # from Bridge's global pre-cast to self.dtype. + return hf_weights[0].to(dtype=torch.float32).contiguous() + if "self_attention.linear_qkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name: # merge qkv assert len(hf_weights) == 3 diff --git a/miles_plugins/models/cp_utils.py b/miles_plugins/models/cp_utils.py new file mode 100644 index 0000000000..87f4205f34 --- /dev/null +++ b/miles_plugins/models/cp_utils.py @@ -0,0 +1,26 @@ +import logging + +import torch.distributed as dist +import torch.nn as nn + +from miles_plugins.models.hf_attention import HuggingfaceAttention + +logger = logging.getLogger(__name__) + + +def detect_and_setup_hybrid_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> int: + """Scan for GatedDeltaNet modules and configure them for native fla CP.""" + count = 0 + for module in model.modules(): + if isinstance(module, HuggingfaceAttention): + linear_attn = getattr(module, "linear_attn", None) + if linear_attn is not None: + linear_attn.cp_group = cp_group + linear_attn.cp_rank = cp_rank + linear_attn.cp_world_size = cp_world_size + module.hybrid_cp = True + count += 1 + + if count > 0: + logger.info(f"Configured hybrid CP on {count} GDN modules (fla native state passing)") + return count diff --git a/miles_plugins/models/hf_attention.py b/miles_plugins/models/hf_attention.py index 7abe09b0ee..aeacc58ea2 100644 --- a/miles_plugins/models/hf_attention.py +++ b/miles_plugins/models/hf_attention.py @@ -38,6 +38,116 @@ def _fix_dtype(d): return ns +def _get_cp_sequence_lengths(cu_seqlens, cp_size, local_total_len=None): + global_seq_lengths = [(cu_seqlens[i + 1] - cu_seqlens[i]).item() for i in range(len(cu_seqlens) - 1)] + local_seq_lengths = [] + for global_seq_len in global_seq_lengths: + if global_seq_len % cp_size != 0: + raise ValueError(f"Expected sequence length {global_seq_len} to be divisible by cp_size={cp_size}") + local_seq_lengths.append(global_seq_len // cp_size) + + if local_total_len is not None and sum(local_seq_lengths) != local_total_len: + raise ValueError(f"Expected local total length {local_total_len}, got {sum(local_seq_lengths)}") + + return global_seq_lengths, local_seq_lengths + + +def _gather_cp_tensors(x, cp_group): + gathered = [torch.empty_like(x) for _ in range(dist.get_world_size(group=cp_group))] + dist.all_gather(gathered, x.contiguous(), group=cp_group) + return gathered + + +def _zigzag_to_packed_shard_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + """Convert zigzag ring-attn layout to the contiguous packed shard expected by fla CP.""" + global_seq_lengths, local_seq_lengths = _get_cp_sequence_lengths(cu_seqlens, cp_size, hidden_states.size(0)) + gathered_by_rank = [ + gathered.split(local_seq_lengths, dim=0) for gathered in _gather_cp_tensors(hidden_states, cp_group) + ] + + full_sequences = [] + for seq_idx, global_seq_len in enumerate(global_seq_lengths): + per_rank = [rank_seqs[seq_idx] for rank_seqs in gathered_by_rank] + if global_seq_len % (2 * cp_size) == 0: + subchunk_len = global_seq_len // (2 * cp_size) + full_seq = torch.cat( + [seq[:subchunk_len] for seq in per_rank] + [seq[subchunk_len:] for seq in per_rank][::-1], + dim=0, + ) + else: + # Final local padding is appended contiguously on each rank, not in zigzag order. + full_seq = torch.cat(per_rank, dim=0) + full_sequences.append(full_seq) + + full_stream = torch.cat(full_sequences, dim=0) if full_sequences else hidden_states[:0] + shard_len = hidden_states.size(0) + return full_stream[cp_rank * shard_len : (cp_rank + 1) * shard_len] + + +def _packed_shard_to_zigzag_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + """Convert contiguous packed shard layout back to zigzag ring-attn layout.""" + global_seq_lengths, local_seq_lengths = _get_cp_sequence_lengths(cu_seqlens, cp_size, hidden_states.size(0)) + full_stream = torch.cat(_gather_cp_tensors(hidden_states, cp_group), dim=0) + full_sequences = full_stream.split(global_seq_lengths, dim=0) + + local_sequences = [] + for full_seq, global_seq_len, local_seq_len in zip( + full_sequences, global_seq_lengths, local_seq_lengths, strict=True + ): + if global_seq_len % (2 * cp_size) == 0: + subchunk_len = global_seq_len // (2 * cp_size) + parts = full_seq.split(subchunk_len, dim=0) + local_sequences.append(torch.cat([parts[cp_rank], parts[2 * cp_size - 1 - cp_rank]], dim=0)) + else: + local_sequences.append(full_seq.split(local_seq_len, dim=0)[cp_rank]) + + return torch.cat(local_sequences, dim=0) if local_sequences else hidden_states[:0] + + +class _ZigzagToPackedShard(torch.autograd.Function): + """Convert zigzag ring-attn layout to contiguous packed shards for native fla CP.""" + + @staticmethod + def forward(ctx, hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + ctx.cp_group = cp_group + ctx.cp_rank = cp_rank + ctx.cp_size = cp_size + ctx.save_for_backward(cu_seqlens) + return _zigzag_to_packed_shard_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + @staticmethod + def backward(ctx, grad_output): + (cu_seqlens,) = ctx.saved_tensors + result = _packed_shard_to_zigzag_impl(grad_output, cu_seqlens, ctx.cp_group, ctx.cp_rank, ctx.cp_size) + return result, None, None, None, None + + +class _PackedShardToZigzag(torch.autograd.Function): + """Convert contiguous packed shards back to zigzag ring-attn layout.""" + + @staticmethod + def forward(ctx, hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + ctx.cp_group = cp_group + ctx.cp_rank = cp_rank + ctx.cp_size = cp_size + ctx.save_for_backward(cu_seqlens) + return _packed_shard_to_zigzag_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + @staticmethod + def backward(ctx, grad_output): + (cu_seqlens,) = ctx.saved_tensors + result = _zigzag_to_packed_shard_impl(grad_output, cu_seqlens, ctx.cp_group, ctx.cp_rank, ctx.cp_size) + return result, None, None, None, None + + +def _zigzag_to_packed_shard(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + return _ZigzagToPackedShard.apply(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + +def _packed_shard_to_zigzag(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + return _PackedShardToZigzag.apply(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + class _AllGatherForDuplicatedComputation(torch.autograd.Function): """All-gather whose backward just returns the local gradient slice (no reduce). @@ -68,6 +178,10 @@ class HuggingfaceAttention(MegatronModule, ABC): "cross attn" specializations. """ + # Subclasses set this to True when the underlying module handles CP natively + # (e.g. via fla's state-passing CP for DeltaNet), bypassing the all-gather. + hybrid_cp: bool = False + def __init__( self, args, @@ -115,7 +229,22 @@ def forward( group=mpu.get_tensor_model_parallel_group(), ) - if mpu.get_context_parallel_world_size() > 1: + if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp: + cp_size = mpu.get_context_parallel_world_size() + # Native fla CP expects each rank to own a contiguous shard of the + # packed global token stream. In allgather-CP mode the data pipeline + # already provides that layout, so no extra relayout is + # needed here. + if not self.args.allgather_cp: + hidden_states = _zigzag_to_packed_shard( + hidden_states, + cu_seqlens, + mpu.get_context_parallel_group(), + mpu.get_context_parallel_rank(), + cp_size, + ) + + elif mpu.get_context_parallel_world_size() > 1: cp_size = mpu.get_context_parallel_world_size() # Use custom all-gather whose backward returns local gradient # instead of reduce-scatter, since the computation is duplicated. @@ -150,7 +279,17 @@ def forward( output = output.permute(1, 0, 2) # [seq_len, bsz, hidden_dim] - if mpu.get_context_parallel_world_size() > 1: + if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp: + if not self.args.allgather_cp: + output = _packed_shard_to_zigzag( + output, + cu_seqlens, + mpu.get_context_parallel_group(), + mpu.get_context_parallel_rank(), + cp_size, + ) + + elif mpu.get_context_parallel_world_size() > 1: cp_rank = mpu.get_context_parallel_rank() output_list = [] for i in range(len(cu_seqlens) - 1): diff --git a/miles_plugins/models/qwen3_5.py b/miles_plugins/models/qwen3_5.py index a796c8c49c..5c43d732dc 100644 --- a/miles_plugins/models/qwen3_5.py +++ b/miles_plugins/models/qwen3_5.py @@ -15,6 +15,9 @@ except ImportError: pass +from miles.backends.megatron_utils.fp32_param_utils import mark_param_dtype +from miles.backends.training_utils.cp_utils import build_gdn_cp_context + from .hf_attention import HuggingfaceAttention, _load_hf_config @@ -69,8 +72,12 @@ def __init__(self, config, layer_idx: int): self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) A = torch.empty(self.num_v_heads).uniform_(0, 16) - self.A_log = nn.Parameter(torch.log(A)) + self.A_log = nn.Parameter(torch.log(A).to(torch.float32)) + mark_param_dtype(self.A_log, torch.float32) + # HF stores this norm in fp32, but unlike A_log its precision impact is + # negligible and sglang runs it in bf16 on the rollout side — follow + # config.dtype (bf16) to stay equivalent to rollout. self.norm = FusedRMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, @@ -88,6 +95,8 @@ def forward( ): batch_size, seq_len, _ = hidden_states.shape + cp_context = build_gdn_cp_context(self, cu_seqlens, hidden_states.device) + # Projections (flat layout: [Q_all, K_all, V_all]) mixed_qkv = self.in_proj_qkv(hidden_states) z = self.in_proj_z(hidden_states) @@ -95,10 +104,12 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - # Convolution on the flat QKV + # Convolution on the flat QKV (pass cp_context for boundary handling) + conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens mixed_qkv, _ = self.conv1d( x=mixed_qkv, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, + cp_context=cp_context, ) # Split into Q, K, V (flat split, matching HF layout) @@ -118,17 +129,29 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - ) + if cp_context is not None: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cp_context.cu_seqlens, + cp_context=cp_context, + ) + else: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/miles_plugins/models/qwen3_next.py b/miles_plugins/models/qwen3_next.py index 92e39ff318..1dbee8acd0 100644 --- a/miles_plugins/models/qwen3_next.py +++ b/miles_plugins/models/qwen3_next.py @@ -18,6 +18,8 @@ except ImportError: pass +from miles.backends.training_utils.cp_utils import build_gdn_cp_context + from .hf_attention import HuggingfaceAttention @@ -108,6 +110,8 @@ def forward( hidden_states: torch.Tensor, cu_seqlens: torch.Tensor = None, ): + cp_context = build_gdn_cp_context(self, cu_seqlens, hidden_states.device) + projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) @@ -115,9 +119,11 @@ def forward( mixed_qkv = torch.cat((query, key, value), dim=-1) + conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens mixed_qkv, _ = self.conv1d( x=mixed_qkv, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, + cp_context=cp_context, ) query, key, value = torch.split( @@ -140,17 +146,29 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - ) + if cp_context is not None: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cp_context.cu_seqlens, + cp_context=cp_context, + ) + else: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/amd/run-qwen3-4B-amd.sh similarity index 67% rename from scripts/run-qwen3-30B-A3B.sh rename to scripts/amd/run-qwen3-4B-amd.sh index 19bc70927d..bc6d4d40c0 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/amd/run-qwen3-4B-amd.sh @@ -9,30 +9,34 @@ pkill -9 python sleep 3 pkill -9 ray pkill -9 python -pkill -9 redis set -ex +# keep Ray from blanking HIP/CUDA visibility for the job entrypoint. +export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} +export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES:-"1"} + # will prevent ray from buffering stdout/stderr export PYTHONBUFFERED=16 -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 +if [[ -n "${HIP_VISIBLE_DEVICES:-}" ]]; then + export CUDA_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES}" +fi + +NUM_GPUS=${NUM_GPUS:-8} +if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + IFS=',' read -r -a visible_gpu_ids <<< "${CUDA_VISIBLE_DEVICES}" + NUM_GPUS=${#visible_gpu_ids[@]} fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-4B.sh" CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-30B-A3B - #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 - --ref-load /root/Qwen3-30B-A3B_torch_dist - --load /root/Qwen3-30B-A3B_miles/ - --save /root/Qwen3-30B-A3B_miles/ + --hf-checkpoint /root/Qwen3-4B + --ref-load /root/Qwen3-4B_torch_dist + --load /root/Qwen3-4B_miles/ + --save /root/Qwen3-4B_miles/ --save-interval 20 ) @@ -48,7 +52,6 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 1 - --global-batch-size 256 --balance-data ) @@ -62,11 +65,11 @@ EVAL_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 - --expert-model-parallel-size 8 + --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 --recompute-granularity full @@ -75,7 +78,7 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size - --max-tokens-per-gpu 20480 + --max-tokens-per-gpu 9216 ) GRPO_ARGS=( @@ -95,23 +98,18 @@ OPTIMIZER_ARGS=( --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer ) WANDB_ARGS=( - #--use-wandb + # --use-wandb # --wandb-project miles-dev - # --wandb-group qwen3-30B-A3B-test + # --wandb-group qwen3-4B-test # --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 + --rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.7 - --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) MISC_ARGS=( @@ -127,14 +125,13 @@ MISC_ARGS=( # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ \"env_vars\": { \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } }" @@ -142,7 +139,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/scripts/run-glm4.7-flash.sh b/scripts/run-glm4.7-flash.sh index 954e2b8225..35b1687e27 100644 --- a/scripts/run-glm4.7-flash.sh +++ b/scripts/run-glm4.7-flash.sh @@ -1,8 +1,5 @@ #!/bin/bash -# Notice: run this command to upgrade transformers version which supports glm4.7-flash. -# pip install git+https://github.com/huggingface/transformers.git@76732b4e7120808ff989edbd16401f61fa6a0afa - # for rerun the task pkill -9 sglang sleep 3 diff --git a/scripts/run-llama3.2-3B-Instruct-amd.sh b/scripts/run-llama3.2-3B-Instruct-amd.sh deleted file mode 100644 index eb5d5709ce..0000000000 --- a/scripts/run-llama3.2-3B-Instruct-amd.sh +++ /dev/null @@ -1,180 +0,0 @@ -#!/bin/bash - -# hf download meta-llama/Llama-3.2-3B-Instruct --local-dir /root/Llama-3.2-3B-Instruct - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -euxo pipefail - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -# NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -# if [ "$NVLINK_COUNT" -gt 0 ]; then -# HAS_NVLINK=1 -# else -# HAS_NVLINK=0 -# fi -# echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/llama3.2-3B-Instruct-amd.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Llama-3.2-3B-Instruct - --ref-load ${MODEL_DIR}/Llama-3.2-3B-Instruct_torch_dist - --load ${MODEL_DIR}/Llama-3.2-3B-Instruct_miles/ - --save ${MODEL_DIR}/Llama-3.2-3B-Instruct_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type math - --num-epoch 1 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 16384 - --rollout-temperature 1 - - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 10 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 8 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-dev - # --wandb-group llama3.2-3B - # --wandb-key ${WANDB_API_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.4 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - ################### -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -# Build the runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/workspace/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} - - -####clear after training - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python \ No newline at end of file diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh deleted file mode 100755 index 44257cc77f..0000000000 --- a/scripts/run-qwen3-4B-amd.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - -set -euxo pipefail - - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/root}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/root}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/root}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-4B.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Qwen3-4B - --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist - --load ${MODEL_DIR}/Qwen3-4B_miles/ - --save ${MODEL_DIR}/Qwen3-4B_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - ################### -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - - -# Dynamically detect Megatron-LM installation path -MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="{ - \"env_vars\": { - \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" - } - }" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} diff --git a/scripts/run-qwen3-8B-amd.sh b/scripts/run-qwen3-8B-amd.sh deleted file mode 100644 index 979ffa18e0..0000000000 --- a/scripts/run-qwen3-8B-amd.sh +++ /dev/null @@ -1,194 +0,0 @@ -#!/bin/bash - - -# bash scripts/run-qwen3-4B-amd.sh - - -####clear before training -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - -set -euxo pipefail - - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -# Current Model convert script on AMD GPU has some issue, please download the converted model from here: https://huggingface.co/zyzshishui0627/models - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-8B.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Qwen3-8B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load ${MODEL_DIR}/Qwen3-8B_torch_dist - # --ref-load ${MODEL_DIR}/Qwen3-8B_torch_dist_amd_new - --load ${MODEL_DIR}/Qwen3-8B_miles/ - --save ${MODEL_DIR}/Qwen3-8B_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - #--use-wandb - # --wandb-project miles-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 -) -#################### - - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - - -# "PYTHONPATH": "/workspace/Megatron-LM/", -MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json='{ - "env_vars": { - "PYTHONPATH": "/workspace/Megatron-LM/", - "CUDA_DEVICE_MAX_CONNECTIONS": "1" - } - }' \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} - - - -####clear after training - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - - - - - - - - - diff --git a/scripts/run_glm47_flash.py b/scripts/run_glm47_flash.py index cb6ce15a49..e9792676d4 100644 --- a/scripts/run_glm47_flash.py +++ b/scripts/run_glm47_flash.py @@ -24,10 +24,6 @@ class ScriptArgs(U.ExecuteTrainConfig): def prepare(args: ScriptArgs): U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") - # GLM-4.7-Flash requires a newer transformers version - U.exec_command( - "pip install git+https://github.com/huggingface/transformers.git@76732b4e7120808ff989edbd16401f61fa6a0afa" - ) U.exec_command( f"hf download {args.model_org}/{args.model_name} " f"--local-dir {args.model_dir}/{args.model_name}" ) diff --git a/scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py b/scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py new file mode 100644 index 0000000000..793dc65f38 --- /dev/null +++ b/scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = U.create_run_id() + model_name: str = "Qwen3.5-35B-A3B" + megatron_model_type: str = "qwen3.5-35B-A3B" + num_gpus_per_node: int = 8 + hardware: Literal["H200"] = "H200" + enable_eval: bool = True + extra_args: str = "" + data_dir: str = "/root/datasets" + model_dir: str = "/root/models" + megatron_path: str = "/root/Megatron-LM" + + +def prepare(args: ScriptArgs): + U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") + U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}") + U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) + U.hf_download_dataset("zhuzilin/aime-2024", data_dir=args.data_dir) + + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + dir_dst=args.model_dir, + hf_checkpoint=f"{args.model_dir}/{args.model_name}", + megatron_path=args.megatron_path, + ) + + +def execute(args: ScriptArgs): + ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist" + load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints" + + ckpt_args = ( + f"--hf-checkpoint {args.model_dir}/{args.model_name} " + f"--ref-load {ref_load_path} " + f"--load {load_save_path} " + f"--save {load_save_path} " + f"--save-interval {2 if args.mode == 'debug_minimal' else 20} " + ) + + rollout_args = ( + f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + f"--num-rollout {64 if args.mode == 'debug_minimal' else 3000} " + f"--rollout-batch-size {8 if args.mode == 'debug_minimal' else 32} " + f"--n-samples-per-prompt {2 if args.mode == 'debug_minimal' else 8} " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + f"--global-batch-size {16 if args.mode == 'debug_minimal' else 256} " + "--balance-data " + ) + + eval_args = "" + if (args.mode != "debug_minimal") and args.enable_eval: + eval_args += ( + "--eval-interval 20 " + f"--eval-prompt-data aime {args.data_dir}/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 16 " + "--eval-max-response-len 16384 " + "--eval-top-p 1 " + ) + + # CP=2 EP=8: validated on 8x H200 + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + "--sglang-mem-fraction-static 0.7 " + "--sglang-ep-size 8 " + "--sglang-cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 " + # mtp speculative decoding + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + "--sglang-max-running-requests 512 " + "--sglang-mamba-scheduler-strategy extra_buffer " + ) + + mtp_args = "--enable-mtp-training " "--mtp-num-layers 1 " "--mtp-loss-scaling-factor 0.2 " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--moe-token-dispatcher-type flex " + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{misc_args} " + f"{args.extra_args} " + ) + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + extra_env_vars={ + "SGLANG_ENABLE_SPEC_V2": "1", + }, + megatron_path=args.megatron_path, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/tests/e2e/ckpt/test_glm47_flash_ckpt.py b/tests/e2e/ckpt/test_glm47_flash_ckpt.py index 72ee0127d0..c2b872a1f2 100644 --- a/tests/e2e/ckpt/test_glm47_flash_ckpt.py +++ b/tests/e2e/ckpt/test_glm47_flash_ckpt.py @@ -28,11 +28,6 @@ def _get_latest_checkpointed_iteration() -> int: def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - # GLM-4.7-Flash requires a newer transformers version. - U.exec_command( - "pip install git+https://github.com/huggingface/transformers.git@" - "76732b4e7120808ff989edbd16401f61fa6a0afa --break-system-packages" - ) U.exec_command(f"hf download zai-org/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.exec_command(f"rm -rf /root/models/{MODEL_NAME}_miles") U.hf_download_dataset("zhuzilin/dapo-math-17k") diff --git a/tests/e2e/megatron/test_glm47_flash_r3_mtp.py b/tests/e2e/megatron/test_glm47_flash_r3_mtp.py index 698105b04f..2be5dd0f28 100644 --- a/tests/e2e/megatron/test_glm47_flash_r3_mtp.py +++ b/tests/e2e/megatron/test_glm47_flash_r3_mtp.py @@ -13,10 +13,6 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - # GLM-4.7-Flash requires a newer transformers version - U.exec_command( - "pip install git+https://github.com/huggingface/transformers.git@76732b4e7120808ff989edbd16401f61fa6a0afa --break-system-packages" - ) U.exec_command(f"hf download zai-org/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/dapo-math-17k") U.hf_download_dataset("zhuzilin/aime-2024") @@ -77,7 +73,6 @@ def execute(): "--eps-clip 0.2 " "--eps-clip-high 0.28 " "--use-rollout-routing-replay " - "--use-miles-router " ) optimizer_args = ( diff --git a/tests/e2e/megatron/test_moonlight_16B_A3B_r3.py b/tests/e2e/megatron/test_moonlight_16B_A3B_r3.py index 5cb080f001..5c59cdb036 100644 --- a/tests/e2e/megatron/test_moonlight_16B_A3B_r3.py +++ b/tests/e2e/megatron/test_moonlight_16B_A3B_r3.py @@ -69,7 +69,6 @@ def execute(): "--entropy-coef 0.00 " "--eps-clip 4e-4 " "--use-rollout-routing-replay " - "--use-miles-router " ) optimizer_args = ( diff --git a/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py index 8b54176d12..3a17241f42 100644 --- a/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py +++ b/tests/e2e/megatron/test_qwen3_30B_A3B_r3.py @@ -77,7 +77,6 @@ def execute(): "--eps-clip 4e-4 " "--use-tis " "--use-rollout-routing-replay " - "--use-miles-router " ) optimizer_args = ( diff --git a/tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py b/tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py new file mode 100644 index 0000000000..f951cf4f3a --- /dev/null +++ b/tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py @@ -0,0 +1,153 @@ +"""E2E test for Qwen3.5-35B-A3B with Context Parallel (CP=2 and CP=4). + +Validates that GDN layers use real fla native CP (state passing) instead of +duplicated all-gather computation. See: https://github.com/radixark/miles/issues/878 +""" + +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3.5-35B-A3B" +MODEL_TYPE = "qwen3.5-35B-A3B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def _execute_with_cp(cp_size: int): + """Run a short training loop with the given context-parallel size.""" + assert NUM_GPUS % cp_size == 0 + ep_size = NUM_GPUS // cp_size + + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + f"--context-parallel-size {cp_size} " + f"--expert-model-parallel-size {ep_size} " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + "--sglang-mem-fraction-static 0.7 " + f"--sglang-ep-size {NUM_GPUS} " + "--sglang-max-running-requests 512 " + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + ) + + mtp_args = "--enable-mtp-training " "--mtp-num-layers 1 " "--mtp-loss-scaling-factor 0.2 " + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + "--moe-token-dispatcher-type flex " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + ) + + +def execute_cp2(): + """Qwen3.5-35B-A3B with CP=2.""" + _execute_with_cp(cp_size=2) + + +def execute_cp4(): + """Qwen3.5-35B-A3B with CP=4.""" + _execute_with_cp(cp_size=4) + + +if __name__ == "__main__": + cp_size = int(os.environ.get("CP_SIZE", "2")) + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + _execute_with_cp(cp_size) diff --git a/tests/test_qwen3_5_mtp_bridge_mapping.py b/tests/e2e/megatron/test_qwen3_5_mtp_bridge_mapping.py similarity index 100% rename from tests/test_qwen3_5_mtp_bridge_mapping.py rename to tests/e2e/megatron/test_qwen3_5_mtp_bridge_mapping.py diff --git a/tests/e2e/precision/test_hf_attention_cp_relayout.py b/tests/e2e/precision/test_hf_attention_cp_relayout.py new file mode 100644 index 0000000000..39ab46f915 --- /dev/null +++ b/tests/e2e/precision/test_hf_attention_cp_relayout.py @@ -0,0 +1,101 @@ +"""Distributed correctness test for zigzag <-> packed-shard hybrid CP relayout. + +Run with: + torchrun --nproc_per_node=2 tests/e2e/precision/test_hf_attention_cp_relayout.py + torchrun --nproc_per_node=4 tests/e2e/precision/test_hf_attention_cp_relayout.py +""" + +import os +import sys + +import torch +import torch.distributed as dist + +from miles_plugins.models.hf_attention import _packed_shard_to_zigzag, _zigzag_to_packed_shard + + +def setup_dist(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return rank, world_size, local_rank + + +def _make_subchunk(sample_id: int, sub_id: int, chunk_len: int, device: torch.device) -> torch.Tensor: + base = sample_id * 1000 + sub_id * 100 + values = torch.arange(base, base + chunk_len, device=device, dtype=torch.float32) + return values.view(-1, 1, 1) + + +def _build_rank_inputs(rank: int, world_size: int, device: torch.device): + chunk_lens = [3, 5] + tail_pad_local_len = 3 + zigzag_chunks = [] + full_sequences = [] + cu = [0] + + for sample_id, chunk_len in enumerate(chunk_lens): + subchunks = [_make_subchunk(sample_id, sub_id, chunk_len, device) for sub_id in range(2 * world_size)] + zigzag_chunks.extend([subchunks[rank], subchunks[2 * world_size - 1 - rank]]) + full_sequences.append(torch.cat(subchunks, dim=0)) + cu.append(cu[-1] + 2 * world_size * chunk_len) + + tail_pad = (rank * 10000 + torch.arange(tail_pad_local_len, device=device, dtype=torch.float32)).view(-1, 1, 1) + zigzag_chunks.append(tail_pad) + full_sequences.append( + torch.cat( + [ + (r * 10000 + torch.arange(tail_pad_local_len, device=device, dtype=torch.float32)).view(-1, 1, 1) + for r in range(world_size) + ], + dim=0, + ) + ) + cu.append(cu[-1] + world_size * tail_pad_local_len) + + zigzag = torch.cat(zigzag_chunks, dim=0).requires_grad_(True) + packed_full = torch.cat(full_sequences, dim=0) + local_len = zigzag.size(0) + packed_shard = packed_full[rank * local_len : (rank + 1) * local_len] + cu_seqlens = torch.tensor(cu, device=device, dtype=torch.int32) + return zigzag, packed_shard, cu_seqlens + + +def test_relayout(rank: int, world_size: int): + device = torch.device(f"cuda:{rank}") + cp_group = dist.group.WORLD + + zigzag, expected_packed_shard, cu_seqlens = _build_rank_inputs(rank, world_size, device) + + packed_shard = _zigzag_to_packed_shard(zigzag, cu_seqlens, cp_group, rank, world_size) + roundtrip = _packed_shard_to_zigzag(packed_shard, cu_seqlens, cp_group, rank, world_size) + + packed_ok = torch.equal(packed_shard, expected_packed_shard) + roundtrip_ok = torch.equal(roundtrip, zigzag) + + loss = roundtrip.sum() + loss.backward() + grad_ok = torch.equal(zigzag.grad, torch.ones_like(zigzag)) + + passed = packed_ok and roundtrip_ok and grad_ok + if rank == 0: + print(f"\n=== HF Attention Hybrid CP Relayout Test CP={world_size} ===") + print(f"zigzag->packed PASS: {packed_ok}") + print(f"roundtrip PASS: {roundtrip_ok}") + print(f"backward PASS: {grad_ok}") + if not passed: + sys.exit(1) + + +def main(): + rank, world_size, _ = setup_dist() + try: + test_relayout(rank, world_size) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/precision/test_qwen3_5_cp_correctness.py b/tests/e2e/precision/test_qwen3_5_cp_correctness.py new file mode 100644 index 0000000000..d0a2f3f32b --- /dev/null +++ b/tests/e2e/precision/test_qwen3_5_cp_correctness.py @@ -0,0 +1,147 @@ +"""Correctness test for Qwen3.5 GDN with native fla Context Parallel. + +Run with: + torchrun --nproc_per_node=2 tests/test_qwen3_5_cp_correctness.py # CP=2 + torchrun --nproc_per_node=4 tests/test_qwen3_5_cp_correctness.py # CP=4 + +Validates that GDN forward+backward with native fla CP produces results +consistent with the non-CP (single-rank full-sequence) baseline. +""" + +import os +import sys + +import torch +import torch.distributed as dist + + +def setup_dist(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return rank, world_size, local_rank + + +def build_gdn_module(device, dtype=torch.bfloat16): + """Build a small Qwen3.5 GDN module for testing.""" + + class FakeConfig: + hidden_size = 256 + linear_num_value_heads = 4 + linear_num_key_heads = 2 + linear_key_head_dim = 64 + linear_value_head_dim = 64 + linear_conv_kernel_dim = 4 + hidden_act = "silu" + rms_norm_eps = 1e-6 + + FakeConfig.dtype = dtype + + from miles_plugins.models.qwen3_5 import Qwen3_5GatedDeltaNet + + return Qwen3_5GatedDeltaNet(FakeConfig, layer_idx=0).to(device=device, dtype=dtype) + + +def test_cp_forward_backward(rank, world_size): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 + + # ---- Reference: full sequence on rank 0 (no CP) ---- + torch.manual_seed(42) + model_ref = build_gdn_module(device, dtype) + + total_seq_len = 128 * world_size # must be divisible by world_size + batch = 1 + + torch.manual_seed(123) + full_hidden = torch.randn(batch, total_seq_len, 256, device=device, dtype=dtype, requires_grad=True) + full_cu = torch.tensor([0, total_seq_len], dtype=torch.int32, device=device) + + # Forward without CP + ref_out = model_ref(full_hidden, cu_seqlens=full_cu) + ref_loss = ref_out.sum() + ref_loss.backward() + ref_grad = full_hidden.grad.clone() + + # ---- Test: CP across ranks ---- + torch.manual_seed(42) + model_cp = build_gdn_module(device, dtype) + # Copy weights from ref to ensure identical params + model_cp.load_state_dict(model_ref.state_dict()) + + # Set up CP context on the module + cp_group = dist.group.WORLD + model_cp.cp_group = cp_group + model_cp.cp_rank = rank + model_cp.cp_world_size = world_size + + # Each rank gets its local chunk + local_seq_len = total_seq_len // world_size + start = rank * local_seq_len + end = start + local_seq_len + + torch.manual_seed(123) + full_hidden_cp = torch.randn(batch, total_seq_len, 256, device=device, dtype=dtype) + local_hidden = full_hidden_cp[:, start:end, :].clone().contiguous().requires_grad_(True) + + # Global cu_seqlens (build_gdn_cp_context expects global boundaries) + global_cu = torch.tensor([0, total_seq_len], dtype=torch.int32, device=device) + + # Forward with CP + cp_out = model_cp(local_hidden, cu_seqlens=global_cu) + cp_loss = cp_out.sum() + + # Reduce loss across ranks to match reference + dist.all_reduce(cp_loss, op=dist.ReduceOp.SUM) + + cp_loss.backward() + + # ---- Gather outputs for comparison ---- + gathered_out = [torch.zeros_like(cp_out) for _ in range(world_size)] + dist.all_gather(gathered_out, cp_out.contiguous()) + full_cp_out = torch.cat(gathered_out, dim=1) + + gathered_grad = [torch.zeros_like(local_hidden.grad) for _ in range(world_size)] + dist.all_gather(gathered_grad, local_hidden.grad.contiguous()) + full_cp_grad = torch.cat(gathered_grad, dim=1) + + if rank == 0: + # Compare outputs + out_diff = (ref_out.detach().float() - full_cp_out.detach().float()).abs() + out_max_diff = out_diff.max().item() + out_rel_diff = (out_diff / (ref_out.detach().float().abs() + 1e-8)).max().item() + + # Compare gradients + grad_diff = (ref_grad.float() - full_cp_grad.float()).abs() + grad_max_diff = grad_diff.max().item() + grad_rel_diff = (grad_diff / (ref_grad.float().abs() + 1e-8)).max().item() + + print(f"\n=== CP={world_size} Correctness Test ===") + print(f"Forward max abs diff: {out_max_diff:.6e} max rel diff: {out_rel_diff:.6e}") + print(f"Backward max abs diff: {grad_max_diff:.6e} max rel diff: {grad_rel_diff:.6e}") + + # bf16 tolerance: 1e-2 is generous for bf16 accumulated ops + fwd_ok = out_max_diff < 1e-2 + bwd_ok = grad_max_diff < 1e-2 + print(f"Forward PASS: {fwd_ok}") + print(f"Backward PASS: {bwd_ok}") + + if not (fwd_ok and bwd_ok): + print("FAILED!") + sys.exit(1) + else: + print(f"CP={world_size} test PASSED!") + + +def main(): + rank, world_size, _ = setup_dist() + try: + test_cp_forward_backward(rank, world_size) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/sglang/test_r3_router_equivalence.py b/tests/e2e/sglang/test_r3_router_equivalence.py new file mode 100644 index 0000000000..6d6ad08234 --- /dev/null +++ b/tests/e2e/sglang/test_r3_router_equivalence.py @@ -0,0 +1,259 @@ +"""E2E test: verify sglang router and miles router produce identical rollout +routing replay results across MoE models. + +Design +~~~~~~ +For each model in ``MODEL_REGISTRY``, run the same rollout workload twice +under ``--debug-rollout-only --sglang-enable-deterministic-inference +--use-rollout-routing-replay``: + +1. ``variant=miles``: with ``--use-miles-router`` (Python middleware + router wrapping the Rust gateway). +2. ``variant=sgl``: without ``--use-miles-router`` (direct Rust gateway, + which is what PR #1015 drops R3 tests onto). + +Each run writes a JSONL of per-sample ``(tokens, rollout_log_probs, +rollout_routed_experts)`` via the custom generate function in +``utils.router_equivalence_generate``. Once both runs finish we diff +the dumps; they must match byte-for-byte (deterministic inference + +identical prompts). + +Backend / checkpoint +~~~~~~~~~~~~~~~~~~~~ +Megatron backend (same as the sibling ``tests/e2e/megatron/*_r3.py`` +tests) — sourcing ``scripts/models/{type}.sh`` populates +``args.num_layers`` / ``args.moe_router_topk`` that the rollout-side +reshape of ``routed_experts`` depends on. We do *not* set +``--use-kl-loss`` or ``--kl-coef`` > 0, which is what gates the +``--ref-load`` existence check (``miles/utils/arguments.py``), and +``--debug-rollout-only`` makes ``_compute_megatron_num_gpus`` return +``0`` so no megatron actor is spawned and the checkpoint is never +loaded. This lets us get away with a single H200 and no +``convert_hf_to_torch_dist`` step. + +Controls +~~~~~~~~ +- ``ROUTER_EQ_MODEL_FAMILY``: ``qwen3_30b_a3b`` (default) | ``glm47_flash``. +- Single H200, bf16, rollout batch 10, num_rollout 1. Single engine + (``--rollout-num-gpus-per-engine 1``) so both variants hit the same + underlying sglang process topology. +""" + +import base64 +import json +import os +import shutil +from dataclasses import dataclass +from pathlib import Path + +import miles.utils.external_utils.command_utils as U + +MODEL_FAMILY = os.environ.get("ROUTER_EQ_MODEL_FAMILY", "qwen3_30b_a3b") +DUMP_ROOT = Path(os.environ.get("ROUTER_EQ_DUMP_ROOT", "/tmp/router-eq")) +PROMPT_DATA_PATH = "/root/datasets/dapo-math-17k/dapo-math-17k.jsonl" +NUM_PROMPTS = int(os.environ.get("ROUTER_EQ_NUM_PROMPTS", "10")) +MAX_RESPONSE_LEN = int(os.environ.get("ROUTER_EQ_MAX_RESPONSE_LEN", "256")) + +# Repo root (tests/e2e/sglang/test_*.py → parents[3]). Used to prepend the +# miles repo onto the Ray actor PYTHONPATH so the custom generate function is +# importable regardless of where the worktree lives. +_REPO_ROOT = str(Path(__file__).resolve().parents[3]) + + +@dataclass(frozen=True) +class ModelConfig: + model_name: str + hf_repo: str + local_dir: str + megatron_model_type: str + reasoning_parser: str | None = None + num_gpus: int = 1 + + +MODEL_REGISTRY: dict[str, ModelConfig] = { + "qwen3_30b_a3b": ModelConfig( + model_name="Qwen3-30B-A3B", + hf_repo="Qwen/Qwen3-30B-A3B", + local_dir="/root/models/Qwen3-30B-A3B", + megatron_model_type="qwen3-30B-A3B", + reasoning_parser=None, + num_gpus=1, + ), + "glm47_flash": ModelConfig( + model_name="GLM-4.7-Flash", + hf_repo="zai-org/GLM-4.7-Flash", + local_dir="/root/models/GLM-4.7-Flash", + megatron_model_type="glm4.7-flash", + reasoning_parser="glm45", + num_gpus=1, + ), +} + + +def _get_config() -> ModelConfig: + if MODEL_FAMILY not in MODEL_REGISTRY: + raise ValueError(f"Unknown ROUTER_EQ_MODEL_FAMILY={MODEL_FAMILY!r}; choose from {list(MODEL_REGISTRY)}") + return MODEL_REGISTRY[MODEL_FAMILY] + + +def prepare() -> None: + cfg = _get_config() + U.exec_command("mkdir -p /root/models /root/datasets") + if not Path(cfg.local_dir).exists(): + U.exec_command(f"hf download {cfg.hf_repo} --local-dir {cfg.local_dir}") + if not Path(PROMPT_DATA_PATH).exists(): + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + +def _variant_dir(variant: str) -> Path: + return DUMP_ROOT / MODEL_FAMILY / variant + + +def _variant_dump_path(variant: str) -> Path: + return _variant_dir(variant) / "dump.jsonl" + + +def _build_train_args(cfg: ModelConfig, variant: str) -> str: + ckpt_args = f"--hf-checkpoint {cfg.local_dir} " + + rollout_args = ( + f"--prompt-data {PROMPT_DATA_PATH} " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rm-type deepscaler " + "--num-rollout 1 " + f"--rollout-batch-size {NUM_PROMPTS} " + "--n-samples-per-prompt 1 " + f"--rollout-max-response-len {MAX_RESPONSE_LEN} " + "--rollout-temperature 0.0 " + f"--global-batch-size {NUM_PROMPTS} " + "--rollout-seed 42 " + ) + + generate_args = "--custom-generate-function-path " "tests.e2e.sglang.utils.router_equivalence_generate.generate " + + router_args = "--use-rollout-routing-replay " + if variant == "miles": + router_args += "--use-miles-router " + + # Minimal megatron perf args — 1 GPU, no parallelism. We don't actually + # start a megatron actor under --debug-rollout-only, so these are only + # consumed by the argument parser. + perf_args = ( + "--tensor-model-parallel-size 1 " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + ) + + sglang_args = ( + f"--rollout-num-gpus-per-engine {cfg.num_gpus} " + "--sglang-enable-deterministic-inference " + "--sglang-mem-fraction-static 0.85 " + ) + if cfg.reasoning_parser: + sglang_args += f"--sglang-reasoning-parser {cfg.reasoning_parser} " + + infra_args = ( + "--debug-rollout-only " + "--ci-test " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {cfg.num_gpus} " + "--colocate " + ) + + return ckpt_args + rollout_args + generate_args + router_args + perf_args + sglang_args + infra_args + + +def _run_variant(cfg: ModelConfig, variant: str) -> None: + dump_dir = _variant_dir(variant) + if dump_dir.exists(): + shutil.rmtree(dump_dir) + dump_dir.mkdir(parents=True, exist_ok=True) + dump_path = _variant_dump_path(variant) + + train_args = _build_train_args(cfg, variant) + U.execute_train( + train_args=train_args, + num_gpus_per_node=cfg.num_gpus, + megatron_model_type=cfg.megatron_model_type, + extra_env_vars={ + "PYTHONPATH": "/root/Megatron-LM", + "MILES_ROUTER_EQ_DUMP_PATH": str(dump_path), + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + }, + ) + + +def _load_dump(path: Path) -> list[dict]: + with open(path) as f: + records = [json.loads(line) for line in f if line.strip()] + records.sort(key=lambda r: r["index"]) + return records + + +def _assert_records_equal(left: list[dict], right: list[dict]) -> None: + assert len(left) == len(right), f"dump length differs: {len(left)} vs {len(right)}" + + for i, (a, b) in enumerate(zip(left, right, strict=True)): + assert a["index"] == b["index"], f"record {i}: index {a['index']} vs {b['index']}" + + # Tokens and status must match exactly — deterministic decoding. + for field in ("status", "response_length", "tokens"): + assert ( + a[field] == b[field] + ), f"index={a['index']} field={field} mismatch:\n miles: {a[field]}\n sgl: {b[field]}" + + # Logprobs are f32 from sglang; in deterministic mode they should be + # bit-identical, but tolerate tiny float noise as a safety margin. + la = a["rollout_log_probs"] or [] + lb = b["rollout_log_probs"] or [] + assert len(la) == len(lb), f"index={a['index']} logprob length differs" + for j, (xa, xb) in enumerate(zip(la, lb, strict=True)): + assert abs(xa - xb) <= 1e-6, f"index={a['index']} logprob[{j}] {xa} vs {xb}" + + # routed_experts: deterministic int32 → must be byte-identical. + assert ( + a["rollout_routed_experts_shape"] == b["rollout_routed_experts_shape"] + ), f"index={a['index']} routed_experts_shape mismatch" + ea = a["rollout_routed_experts_b64"] + eb = b["rollout_routed_experts_b64"] + if ea is None and eb is None: + continue + assert ea is not None and eb is not None, f"index={a['index']} one side missing routed_experts" + # Compare raw bytes, not the base64 string (equivalent, but clearer error). + ba = base64.b64decode(ea) + bb = base64.b64decode(eb) + assert ba == bb, f"index={a['index']} routed_experts bytes differ" + + +def execute() -> None: + cfg = _get_config() + for variant in ("miles", "sgl"): + _run_variant(cfg, variant) + + miles_records = _load_dump(_variant_dump_path("miles")) + sgl_records = _load_dump(_variant_dump_path("sgl")) + + assert miles_records, "miles-router run produced no dump records" + assert sgl_records, "sglang-router run produced no dump records" + + _assert_records_equal(miles_records, sgl_records) + + print(f"[router-eq] model_family={MODEL_FAMILY} variants miles/sgl " f"match across {len(miles_records)} samples") + + +def test_r3_router_equivalence(): + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/e2e/sglang/test_session_server_tool_call.py b/tests/e2e/sglang/test_session_server_tool_call.py index 43dc05be10..8987903921 100644 --- a/tests/e2e/sglang/test_session_server_tool_call.py +++ b/tests/e2e/sglang/test_session_server_tool_call.py @@ -57,11 +57,6 @@ def _get_config() -> ModelConfig: def prepare(): cfg = _get_config() U.exec_command("mkdir -p /root/models /root/datasets") - if MODEL_FAMILY == "glm47": - U.exec_command( - "pip install git+https://github.com/huggingface/transformers.git@" - "76732b4e7120808ff989edbd16401f61fa6a0afa --break-system-packages" - ) U.exec_command(f"hf download {cfg.model_name} --local-dir /root/models/{cfg.model_name.split('/')[-1]}") prompts = [ @@ -156,5 +151,3 @@ def test_session_server_tool_call(): for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): os.environ.pop(proxy_var, None) execute() - if MODEL_FAMILY == "glm47": - U.exec_command("pip install transformers==4.57.1 --break-system-packages") diff --git a/tests/e2e/sglang/test_tito_logprob_equivalence.py b/tests/e2e/sglang/test_tito_logprob_equivalence.py index 80f33bdf5a..426948227b 100644 --- a/tests/e2e/sglang/test_tito_logprob_equivalence.py +++ b/tests/e2e/sglang/test_tito_logprob_equivalence.py @@ -68,11 +68,6 @@ def _get_config() -> ModelConfig: def prepare(): cfg = _get_config() U.exec_command("mkdir -p /root/models /root/datasets") - if MODEL_FAMILY == "glm47": - U.exec_command( - "pip install git+https://github.com/huggingface/transformers.git@" - "76732b4e7120808ff989edbd16401f61fa6a0afa --break-system-packages" - ) U.exec_command(f"hf download {cfg.model_name} --local-dir /root/models/{cfg.model_name.split('/')[-1]}") prompts = [ @@ -169,5 +164,3 @@ def test_tito_logprob_equivalence(): for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): os.environ.pop(proxy_var, None) execute() - if MODEL_FAMILY == "glm47": - U.exec_command("pip install transformers==4.57.1 --break-system-packages") diff --git a/tests/e2e/sglang/utils/router_equivalence_generate.py b/tests/e2e/sglang/utils/router_equivalence_generate.py new file mode 100644 index 0000000000..7d576bb969 --- /dev/null +++ b/tests/e2e/sglang/utils/router_equivalence_generate.py @@ -0,0 +1,71 @@ +"""Custom generate function for router-equivalence e2e test. + +Wraps the stock ``single_turn.generate`` and, after each rollout, appends a +JSON record to ``$MILES_ROUTER_EQ_DUMP_PATH`` capturing the fields that +must match byte-for-byte between two runs using different routers: + +- ``tokens`` (the full input + output token ids) +- ``rollout_log_probs`` (per-output-token logprob) +- ``rollout_routed_experts`` (shape + base64-encoded int32 bytes) + +The dump is later loaded by ``test_r3_router_equivalence`` and diffed +between a ``--use-miles-router`` run and a sglang-router run. +""" + +import base64 +import json +import logging +import os +from pathlib import Path + +import numpy as np + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_hub.single_turn import generate as _base_generate +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + +_DUMP_PATH_ENV = "MILES_ROUTER_EQ_DUMP_PATH" + + +def _dump_sample(sample: Sample) -> dict: + re = sample.rollout_routed_experts + if re is not None: + arr = np.ascontiguousarray(re, dtype=np.int32) + experts_shape = list(arr.shape) + experts_b64 = base64.b64encode(arr.tobytes()).decode("ascii") + else: + experts_shape = None + experts_b64 = None + + return { + "index": sample.index, + "status": str(sample.status), + "response_length": sample.response_length, + "tokens": list(sample.tokens) if sample.tokens is not None else None, + "rollout_log_probs": list(sample.rollout_log_probs) if sample.rollout_log_probs is not None else None, + "rollout_routed_experts_shape": experts_shape, + "rollout_routed_experts_b64": experts_b64, + } + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + out = await _base_generate(input) + + dump_path = os.environ.get(_DUMP_PATH_ENV) + if not dump_path: + logger.warning("%s not set; not dumping", _DUMP_PATH_ENV) + return out + + samples = out.samples if isinstance(out.samples, list) else [out.samples] + + Path(dump_path).parent.mkdir(parents=True, exist_ok=True) + with open(dump_path, "a") as f: + for s in samples: + f.write(json.dumps(_dump_sample(s)) + "\n") + + return out + + +generate.add_arguments = getattr(_base_generate, "add_arguments", None) diff --git a/tests/e2e/short/test_dumper.py b/tests/e2e/short/test_dumper.py index c89eb526da..40471aa777 100644 --- a/tests/e2e/short/test_dumper.py +++ b/tests/e2e/short/test_dumper.py @@ -143,7 +143,7 @@ def _execute(perf_args: str, dump_subdir: str, dump_dir: str) -> None: "--attention-backend flash " f"--actor-num-nodes 1 --actor-num-gpus-per-node {NUM_GPUS} --colocate " "--moe-token-dispatcher-type alltoall " - "--use-miles-router --use-rollout-routing-replay " + "--use-rollout-routing-replay " ) train_args = " ".join( diff --git a/tests/fast/backends/megatron_utils/test_fp32_param_utils.py b/tests/fast/backends/megatron_utils/test_fp32_param_utils.py new file mode 100644 index 0000000000..c98b2c3e5d --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_fp32_param_utils.py @@ -0,0 +1,262 @@ +"""Tests for the A_log fp32 preservation chain. + +Feature: Qwen3.5's ``A_log`` must end up as fp32 in the Megatron parameter +after hf->mcore conversion, because the chunk-gated-delta-rule kernel relies +on that precision. Two complementary pieces keep this invariant: + +- Downstream — ``enforce_marked_param_dtypes`` (this module): + Megatron's ``Float16Module`` unconditionally casts every floating-point + parameter to bf16/fp16 at wrap time. There is no declarative opt-out in + nn.Module or Megatron; even Megatron's own MoE router uses the same + post-hoc ``.data = ...to(float32)`` pattern in + ``_maintain_float32_expert_bias``. We generalize that by letting model + definitions declare intent via ``mark_param_dtype`` and re-casting after + ``get_model`` returns. +- Upstream — ``Qwen3_5Bridge._weight_to_mcore_format``: + mbridge's base ``_weight_to_mcore_format`` pre-casts every HF tensor to + ``self.dtype`` (bf16) before TP scatter. For A_log that pre-cast rounds + the fp32 HF value. The override returns A_log as fp32 early, bypassing + that pre-cast entirely. + +The end-to-end test ties both halves together and checks bit-exact equality +with the HF fp32 source — this is the regression guard against the original +``patch_weight_to_mcore_format_preserve_fp32`` failure mode, where only the +upstream cast was intercepted and the downstream ``t.to(param.dtype)`` in +``Bridge.load_weights`` still demoted A_log back to bf16. +""" + +import pytest +import torch +import torch.nn as nn + +from miles.backends.megatron_utils.fp32_param_utils import ( + FORCED_PARAM_DTYPE_ATTR, + enforce_marked_param_dtypes, + mark_param_dtype, +) + + +# --------------------------------------------------------------------------- +# Downstream: mark_param_dtype + enforce_marked_param_dtypes +# --------------------------------------------------------------------------- + + +class _ToyModule(nn.Module): + """Minimal stand-in for Qwen3_5GatedDeltaNet: one marked fp32 param plus + one regular bf16-target param, so we can check the collateral damage + boundary of ``enforce_marked_param_dtypes``.""" + + def __init__(self, num_heads: int = 8): + super().__init__() + A = torch.empty(num_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A).to(torch.float32)) + mark_param_dtype(self.A_log, torch.float32) + self.in_proj = nn.Linear(16, num_heads, bias=False) + + +class TestMarkParamDtype: + def test_attaches_expected_attribute(self): + p = nn.Parameter(torch.zeros(4)) + mark_param_dtype(p, torch.float32) + assert getattr(p, FORCED_PARAM_DTYPE_ATTR) is torch.float32 + + def test_overwrites_previous_mark(self): + p = nn.Parameter(torch.zeros(4)) + mark_param_dtype(p, torch.float32) + mark_param_dtype(p, torch.float64) + assert getattr(p, FORCED_PARAM_DTYPE_ATTR) is torch.float64 + + +class TestEnforceMarkedParamDtypes: + def test_recasts_marked_param_back_to_fp32_after_float16_wrap(self): + """Simulates the full Megatron path: construct -> bfloat16() (what + ``Float16Module(...)`` does) -> enforce. A_log must come out fp32.""" + m = _ToyModule() + assert m.A_log.dtype == torch.float32 + + # Simulate Float16Module(config, m) — module.bfloat16() in the ctor + # demotes every floating param including the marked one. + m.bfloat16() + assert m.A_log.dtype == torch.bfloat16 + + enforce_marked_param_dtypes([m]) + assert m.A_log.dtype == torch.float32 + + def test_preserves_parameter_identity(self): + """Optimizer and DDP bucket parameters by Python identity, set up + AFTER ``enforce_marked_param_dtypes`` runs. If we re-bind via + ``self.A_log = nn.Parameter(...)`` the id changes and the optimizer + map breaks. We must only mutate ``.data``.""" + m = _ToyModule() + m.bfloat16() + before_id = id(m.A_log) + before_param_obj = m.A_log + + enforce_marked_param_dtypes([m]) + + assert id(m.A_log) == before_id + assert m.A_log is before_param_obj + + def test_leaves_unmarked_params_alone(self): + m = _ToyModule() + m.bfloat16() + assert m.in_proj.weight.dtype == torch.bfloat16 + + enforce_marked_param_dtypes([m]) + assert m.in_proj.weight.dtype == torch.bfloat16 + + def test_is_noop_when_already_target_dtype(self): + """Idempotency — second call must not re-allocate or change anything. + Guards against accidental double-work when the hook is called on + both the training and conversion entrypoints in the same process.""" + m = _ToyModule() + m.bfloat16() + enforce_marked_param_dtypes([m]) + + data_before = m.A_log.data + updated = enforce_marked_param_dtypes([m]) + assert m.A_log.dtype == torch.float32 + # ``.data`` should be the same tensor object (no unnecessary realloc). + assert m.A_log.data.data_ptr() == data_before.data_ptr() + # Name is still reported even on the no-realloc path — this is by + # design so the rank-0 log line reflects policy coverage, not churn. + assert any(n.endswith("A_log") for n in updated) + + def test_walks_multiple_model_chunks(self): + """``setup_model_and_optimizer`` passes a list of model chunks (for + virtual pipeline parallelism). The helper must iterate all of them.""" + chunks = [_ToyModule(), _ToyModule()] + for c in chunks: + c.bfloat16() + + enforce_marked_param_dtypes(chunks) + for c in chunks: + assert c.A_log.dtype == torch.float32 + + def test_returns_empty_when_no_marks(self): + m = nn.Linear(4, 4) + m.bfloat16() + assert enforce_marked_param_dtypes([m]) == [] + + +# --------------------------------------------------------------------------- +# Upstream: Qwen3_5Bridge._weight_to_mcore_format +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def bridge_stub(): + """Build a ``Qwen3_5Bridge`` without invoking ``__init__`` — ``__init__`` + needs a real HF config. The A_log branch only reads ``self.dtype``, which + we set directly, so skipping init is safe and lets this test stay + CPU-only and dep-free.""" + pytest.importorskip("mbridge") + from miles_plugins.mbridge.qwen3_5 import Qwen3_5Bridge + + bridge = Qwen3_5Bridge.__new__(Qwen3_5Bridge) + return bridge + + +class TestQwen3_5BridgeALogOverride: + A_LOG_NAME = "decoder.layers.0.self_attention.linear_attn.A_log" + + def test_returns_fp32_when_bridge_dtype_is_bf16(self, bridge_stub): + """The override must bypass mbridge's ``w.to(self.dtype)`` pre-cast + that would otherwise round HF fp32 to bf16 here.""" + bridge_stub.dtype = torch.bfloat16 + hf_tensor = torch.randn(32, dtype=torch.float32) + + out = bridge_stub._weight_to_mcore_format(self.A_LOG_NAME, [hf_tensor]) + + assert out.dtype == torch.float32 + assert torch.equal(out, hf_tensor) + assert out.is_contiguous() + + def test_upcasts_when_hf_input_is_bf16(self, bridge_stub): + """A_log arriving as bf16 (non-canonical ckpt) is still forced to + fp32 — the invariant is the output dtype, not the input's.""" + bridge_stub.dtype = torch.bfloat16 + hf_tensor = torch.randn(32, dtype=torch.bfloat16) + + out = bridge_stub._weight_to_mcore_format(self.A_LOG_NAME, [hf_tensor]) + + assert out.dtype == torch.float32 + + def test_mtp_layer_a_log_also_matches(self, bridge_stub): + """The override uses ``endswith`` so MTP-layer A_log + (``mtp.layers.{idx}...``) also matches — MTP is a real Qwen3.5 + variant and must not silently skip the override.""" + bridge_stub.dtype = torch.bfloat16 + hf_tensor = torch.randn(32, dtype=torch.float32) + + out = bridge_stub._weight_to_mcore_format("mtp.layers.0.self_attention.linear_attn.A_log", [hf_tensor]) + assert out.dtype == torch.float32 + + +# --------------------------------------------------------------------------- +# End-to-end: the two halves together, matching ``Bridge.load_weights``. +# --------------------------------------------------------------------------- + + +class TestALogLoadPathEndToEnd: + """Replays the dtype-relevant subset of ``Bridge.load_weights`` on a toy + model, as documented in ``tools/debug_a_log_old_flow.py``. No distributed + or real safetensor IO — only the two cast points we care about. + + Expected outcome: HF fp32 value lands in the Megatron A_log param + bit-exactly. Regression target: the OLD ``patch_weight_to_mcore_format_preserve_fp32`` + failed here because ``bridge.py:246`` still cast down to ``param.dtype == bf16``. + """ + + def test_lossless_roundtrip(self, bridge_stub): + a_log_name = "decoder.layers.0.self_attention.linear_attn.A_log" + hf_tensor = torch.randn(32, dtype=torch.float32) + + # 1. Build model (A_log marked fp32 at definition site). + model = _ToyModule(num_heads=32) + + # 2. Megatron wraps with Float16Module → .bfloat16(). + model.bfloat16() + + # 3. enforce_marked_param_dtypes restores A_log to fp32 BEFORE + # load_weights runs, so ``param.dtype`` at bridge.py:246 is fp32. + enforce_marked_param_dtypes([model]) + assert model.A_log.dtype == torch.float32 + + # 4. mbridge: _weight_to_mcore_format (with override → fp32). + bridge_stub.dtype = torch.bfloat16 # would demote without override + mcore_weight = bridge_stub._weight_to_mcore_format(a_log_name, [hf_tensor]) + assert mcore_weight.dtype == torch.float32 + + # 5. mbridge bridge.py:246 — ``t.to(param.device, dtype=param.dtype)``. + param = model.A_log + staged = mcore_weight.to(param.device, dtype=param.dtype).contiguous() + assert staged.dtype == torch.float32 # no-op cast + + # 6. mbridge bridge.py:258 — ``param.copy_(param_to_load)``. + param.data.copy_(staged) + + # Bit-exact round-trip: both halves were required to get here. + assert model.A_log.dtype == torch.float32 + assert torch.equal(model.A_log.data, hf_tensor) + + def test_old_patch_only_regresses_without_enforce(self, bridge_stub): + """Negative control: if we DROP ``enforce_marked_param_dtypes`` and + only keep the upstream override (the shape of the old patch), the + downstream ``t.to(param.dtype)`` still rounds to bf16. This pins the + old failure mode so it cannot be re-introduced by accident.""" + a_log_name = "decoder.layers.0.self_attention.linear_attn.A_log" + # Use a value where bf16 rounding is observable. + hf_tensor = torch.tensor([0.970378123] * 8, dtype=torch.float32) + + model = _ToyModule(num_heads=8) + model.bfloat16() # A_log is bf16; no enforce call here on purpose. + + bridge_stub.dtype = torch.bfloat16 + mcore_weight = bridge_stub._weight_to_mcore_format(a_log_name, [hf_tensor]) + assert mcore_weight.dtype == torch.float32 + + staged = mcore_weight.to(model.A_log.device, dtype=model.A_log.dtype).contiguous() + # Regression check: demoted to bf16 because param.dtype is bf16. + assert staged.dtype == torch.bfloat16 + assert not torch.equal(staged.to(torch.float32), hf_tensor) diff --git a/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py b/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py index a72ca582e5..81748bca75 100644 --- a/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py +++ b/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py @@ -56,6 +56,7 @@ def _make_args(**overrides): update_weight_buffer_size=1 << 30, actor_num_nodes=1, actor_num_gpus_per_node=1, + pause_generation_mode="retract", ) defaults.update(overrides) return Namespace(**defaults) @@ -210,11 +211,14 @@ def test_raises_on_zero_lora_chunks(self, mock_iter_base, mock_dist, mock_ray, m with pytest.raises(RuntimeError, match="zero chunks"): updater.update_weights() + @patch("miles.backends.megatron_utils.update_weight.common.ray") @patch(f"{_UW_MODULE}.get_gloo_group", return_value=MagicMock()) @patch(f"{_UW_MODULE}.ray") @patch(f"{_UW_MODULE}.dist") @patch(f"{_UW_MODULE}.HfWeightIteratorBase") - def test_no_raise_for_base_model_zero_chunks(self, mock_iter_base, mock_dist, mock_ray, mock_gloo): + def test_no_raise_for_base_model_zero_chunks( + self, mock_iter_base, mock_dist, mock_ray, mock_gloo, mock_common_ray + ): """Base model weight sync with zero chunks is valid (e.g. empty model state).""" from miles.backends.megatron_utils.update_weight.update_weight_from_tensor import UpdateWeightFromTensor diff --git a/tests/fast/backends/training_utils/__init__.py b/tests/fast/backends/training_utils/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/fast/backends/training_utils/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py index 91e0467e97..7e70dcd187 100644 --- a/tests/fast/fixtures/generation_fixtures.py +++ b/tests/fast/fixtures/generation_fixtures.py @@ -2,6 +2,7 @@ Fixtures to test custom-generate-function """ +import uuid from argparse import Namespace from contextlib import contextmanager from dataclasses import dataclass @@ -160,6 +161,7 @@ def make_args( "pytest", "--train-backend", "fsdp", + "--ci-test", "--rollout-batch-size", "1", "--num-rollout", @@ -235,6 +237,7 @@ def with_session_server( chat_template_path=chat_template_path, tito_model="default", use_rollout_routing_replay=use_rollout_routing_replay, + session_server_instance_id=uuid.uuid4().hex, ) session_server = SessionServer(args, backend_url=backend_url) diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py index b54c7b9a51..90bfdd197d 100644 --- a/tests/fast/fixtures/rollout_fixtures.py +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -43,6 +43,7 @@ def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | Non "pytest", "--train-backend", "fsdp", + "--ci-test", "--rollout-batch-size", "1", "--n-samples-per-prompt", diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py index 95fe9c3f8d..38145f93d2 100644 --- a/tests/fast/rollout/generate_hub/test_multi_turn.py +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -1,3 +1,4 @@ +import re from copy import deepcopy from dataclasses import dataclass, replace from itertools import groupby @@ -6,8 +7,8 @@ import pybase64 import pytest from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate -from transformers import AutoTokenizer +from miles.utils.processing_utils import load_tokenizer from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub from miles.utils.types import Sample @@ -24,7 +25,7 @@ def is_agentic_variant(variant: str) -> bool: MODEL_NAME = "Qwen/Qwen3-0.6B" DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} -TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) +TOKENIZER = load_tokenizer(MODEL_NAME, trust_remote_code=True) @pytest.fixture( @@ -667,6 +668,29 @@ def test_agent_returns_none_metadata_unchanged(self, variant, generation_env): assert s.metadata.get("instance_id") == "test-123" assert "reward" not in s.metadata + def test_session_server_identity_forwarded_to_agent_metadata(self, variant, generation_env): + from miles.utils.test_utils import mock_tools + + generation_env.mock_server.process_fn = TwoTurnStub.process_fn + + _SESSION_KEYS = ("session_server_id", "session_server_instance_id") + + def _echo_session(metadata=None): + metadata = metadata or {} + return {k: metadata[k] for k in _SESSION_KEYS if k in metadata} + + mock_tools.AGENTIC_RETURN_METADATA = _echo_session + try: + result = _run_generate(variant, generation_env, make_sample(prompt=TwoTurnStub.PROMPT)) + finally: + mock_tools.AGENTIC_RETURN_METADATA = None + + samples = listify(result.sample) + expected_session_server_id = f"127.0.0.1:{generation_env.args.session_server_port}" + for s in samples: + assert s.metadata["session_server_id"] == expected_session_server_id + assert re.fullmatch(r"[0-9a-f]{32}", s.metadata["session_server_instance_id"]) + class TestAgentNoRecords: """When agent makes no model calls, generate should return an ABORTED sample.""" diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py index 0f2305e753..6aa1d6606b 100644 --- a/tests/fast/rollout/generate_hub/test_tool_call_utils.py +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -1,30 +1,46 @@ import pytest from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses +from miles.utils.processing_utils import load_tokenizer TOOL_CALL_TEST_MODELS = [ "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen3-0.6B", "Qwen/Qwen3-4B-Instruct-2507", "Qwen/Qwen3-Coder-30B-A3B-Instruct", + "Qwen/Qwen3.5-0.8B", + "Qwen/Qwen3-Coder-Next", # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI "mistralai/Mistral-7B-Instruct-v0.3", - "deepseek-ai/DeepSeek-V3", - "stepfun-ai/step3", "MiniMaxAI/MiniMax-M2", + "MiniMaxAI/MiniMax-M2.5", "internlm/internlm3-8b-instruct", - "THUDM/glm-4-9b-chat", + "zai-org/GLM-4.7-Flash", + "stepfun-ai/Step-3.5-Flash", "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2.5", "XiaomiMiMo/MiMo-7B-RL", + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", ] -SINGLE_TOOL_CALL_ONLY_MODELS = [ - # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo +# Models that fail decode round-trip under transformers>=5.x due to upstream tokenizer issues. +# These are excluded from TOOL_CALL_TEST_MODELS but listed here for tracking. +# - DeepSeek-V3, step3: transformers v5 unified LlamaTokenizer overwrites their ByteLevel +# pre_tokenizer/decoder with Metaspace, causing decode(encode(text)) != text. +# See https://github.com/huggingface/transformers/issues/43066 +# - DeepSeek-V3.1: its tool-call chat template concatenates function.arguments as a string, +# but our dummy tool-call shape provides a dict, raising TypeError before the round-trip check. +# - glm-4-9b-chat: v5 removed the legacy _decode special-token segmentation, exposing a bug in +# the model's custom convert_tokens_to_string (doesn't handle str-type special tokens). +TOOL_CALL_KNOWN_FAILURES = [ + "deepseek-ai/DeepSeek-V3", + "deepseek-ai/DeepSeek-V3.1", + "stepfun-ai/step3", + "THUDM/glm-4-9b-chat", ] -# Models where tokenize->decode produces extra whitespace vs direct string diff -TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ - "THUDM/glm-4-9b-chat", +SINGLE_TOOL_CALL_ONLY_MODELS = [ + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo ] SAMPLE_TOOL_RESPONSES = [ @@ -46,9 +62,7 @@ class TestTokenizeToolResponses: @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) def test_snapshot(self, model_name): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + tokenizer = load_tokenizer(model_name, trust_remote_code=True) token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) decoded = tokenizer.decode(token_ids) @@ -69,9 +83,7 @@ def test_tokenize_tool_responses(self, model_name, num_tools): if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: pytest.skip(f"{model_name} only supports single tool call") - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + tokenizer = load_tokenizer(model_name, trust_remote_code=True) tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] assert len(tool_responses) == num_tools @@ -83,11 +95,6 @@ def test_tokenize_tool_responses(self, model_name, num_tools): base_messages = [_DUMMY_USER, dummy_assistant] expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) - if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: - # Some models produce whitespace differences between tokenize->decode and direct string diff - actual_str = actual_str.replace(" ", "") - expected_str = expected_str.replace(" ", "") - assert actual_str == expected_str, f"{model_name=}" @staticmethod diff --git a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py index 2128791b3d..6d90fdd38b 100644 --- a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py +++ b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py @@ -9,7 +9,10 @@ import pytest -from miles.rollout.generate_utils.openai_endpoint_utils import compute_samples_from_openai_records +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) from miles.rollout.generate_utils.sample_utils import merge_samples from miles.rollout.session.session_types import SessionRecord from miles.utils.types import Sample @@ -46,6 +49,8 @@ def _make_record( output_token_ids: list[int], output_log_probs: list[float] | None = None, finish_reason: str = "stop", + cached_tokens: int | None = None, + prompt_tokens: int | None = None, ) -> SessionRecord: """Build a minimal session record mimicking SGLang's response format. @@ -59,6 +64,14 @@ def _make_record( logprobs_content = [ {"logprob": lp, "token": f"t{tid}"} for tid, lp in zip(output_token_ids, output_log_probs, strict=True) ] + meta_info = { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": len(output_token_ids), + } + if cached_tokens is not None: + meta_info["cached_tokens"] = cached_tokens + if prompt_tokens is not None: + meta_info["prompt_tokens"] = prompt_tokens return SessionRecord( timestamp=0.0, method="POST", @@ -72,16 +85,40 @@ def _make_record( "message": {"role": "assistant", "content": "response"}, "finish_reason": finish_reason, "logprobs": {"content": logprobs_content}, - "meta_info": { - "output_token_logprobs": output_token_logprobs, - "completion_tokens": len(output_token_ids), - }, + "meta_info": meta_info, } ] }, ) +@pytest.mark.asyncio +async def test_create_fetches_session_server_instance_id(monkeypatch): + calls: list[tuple[str, str]] = [] + + async def fake_post(url: str, payload: dict, action: str = "post"): + calls.append((action, url)) + if action == "get": + assert url == "http://127.0.0.1:12345/health" + return {"status": "ok", "session_server_instance_id": "server-instance-123"} + assert action == "post" + assert url == "http://127.0.0.1:12345/sessions" + return {"session_id": "session-123"} + + monkeypatch.setattr("miles.rollout.generate_utils.openai_endpoint_utils.post", fake_post) + + args = SimpleNamespace(session_server_ip="127.0.0.1", session_server_port=12345) + tracer = await OpenAIEndpointTracer.create(args) + + assert tracer.base_url == "http://127.0.0.1:12345/sessions/session-123" + assert tracer.session_server_instance_id == "server-instance-123" + assert args.session_server_instance_id == "server-instance-123" + assert calls == [ + ("get", "http://127.0.0.1:12345/health"), + ("post", "http://127.0.0.1:12345/sessions"), + ] + + # ── test: compute_samples_from_openai_records ──────────────────────── @@ -527,3 +564,65 @@ def test_no_thinking_tokens_prefix_chain_holds(self): merged = merge_samples(samples, tok) assert merged.tokens == [1, 2, 3, 10, 11, 20, 21, 30, 31] + + +# ── test: prefix cache info population ──────────────────────────────── + + +class TestPrefixCacheInfo: + """Validate that prefix cache statistics from meta_info are collected.""" + + def test_single_record_with_cache_stats(self): + """cached_tokens and prompt_tokens from meta_info populate prefix_cache_info.""" + tok = _mock_tokenizer() + record = _make_record( + prompt_token_ids=[1, 2, 3], + output_token_ids=[10, 11], + cached_tokens=2, + prompt_tokens=3, + ) + input_sample = _make_input_sample() + samples = compute_samples_from_openai_records(_ARGS, input_sample, [record], tok) + + assert samples[0].prefix_cache_info.cached_tokens == 2 + assert samples[0].prefix_cache_info.total_prompt_tokens == 3 + + def test_multi_turn_cache_stats_accumulate_after_merge(self): + """After merge_samples, prefix_cache_info sums across turns.""" + tok = _mock_tokenizer() + records = [ + _make_record( + prompt_token_ids=[1, 2, 3], + output_token_ids=[10, 11], + output_log_probs=[-0.1, -0.2], + cached_tokens=0, + prompt_tokens=3, + ), + _make_record( + prompt_token_ids=[1, 2, 3, 10, 11, 20, 21], + output_token_ids=[30, 31], + output_log_probs=[-0.3, -0.4], + cached_tokens=5, + prompt_tokens=7, + ), + ] + input_sample = _make_input_sample() + samples = compute_samples_from_openai_records(_ARGS, input_sample, records, tok) + merged = merge_samples(samples, tok) + + assert merged.prefix_cache_info.cached_tokens == 0 + 5 + assert merged.prefix_cache_info.total_prompt_tokens == 3 + 7 + assert merged.prefix_cache_info.prefix_cache_hit_rate == 5 / 10 + + def test_missing_cache_fields_default_to_zero(self): + """Records without cached_tokens/prompt_tokens give zero prefix_cache_info (regression).""" + tok = _mock_tokenizer() + record = _make_record( + prompt_token_ids=[1, 2, 3], + output_token_ids=[10, 11], + ) + input_sample = _make_input_sample() + samples = compute_samples_from_openai_records(_ARGS, input_sample, [record], tok) + + assert samples[0].prefix_cache_info.cached_tokens == 0 + assert samples[0].prefix_cache_info.total_prompt_tokens == 0 diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py index ca47edeeb6..d848ef0b2a 100644 --- a/tests/fast/rollout/inference_rollout/conftest.py +++ b/tests/fast/rollout/inference_rollout/conftest.py @@ -10,6 +10,7 @@ def _build_mock_args(extra_argv: list[str] | None = None): "pytest", "--train-backend", "fsdp", + "--ci-test", "--rollout-batch-size", "2", "--n-samples-per-prompt", diff --git a/tests/fast/router/test_session_pretokenized_e2e.py b/tests/fast/router/test_session_pretokenized_e2e.py index e5b7ed0419..68917a68f1 100644 --- a/tests/fast/router/test_session_pretokenized_e2e.py +++ b/tests/fast/router/test_session_pretokenized_e2e.py @@ -65,7 +65,10 @@ class ModelTemplateConfig: "Qwen/Qwen3-4B-Thinking-2507", try_get_fixed_chat_template("Qwen/Qwen3-4B-Thinking-2507"), ), - "qwen3.5-native": ModelTemplateConfig("Qwen/Qwen3.5-0.8B", None), + "qwen3.5-fixed": ModelTemplateConfig( + "Qwen/Qwen3.5-0.8B", + try_get_fixed_chat_template("Qwen/Qwen3.5-0.8B"), + ), "qwen3-next-instruct-native": ModelTemplateConfig("Qwen/Qwen3-Next-80B-A3B-Instruct", None), "qwen3-next-thinking-fixed": ModelTemplateConfig( "Qwen/Qwen3-Next-80B-A3B-Thinking", diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py index 23fc683647..8dd58189e1 100644 --- a/tests/fast/router/test_sessions.py +++ b/tests/fast/router/test_sessions.py @@ -1,5 +1,7 @@ """Integration tests for session HTTP routes (create / get / delete / proxy).""" +import re +import uuid from types import SimpleNamespace from unittest.mock import patch @@ -41,6 +43,7 @@ def patched_chat_response(self, payload: dict) -> dict: hf_checkpoint="Qwen/Qwen3-0.6B", chat_template_path=None, trajectory_manager="linear_trajectory", + session_server_instance_id=uuid.uuid4().hex, ) server_obj = SessionServer(args, backend_url=backend.url) @@ -57,6 +60,19 @@ def patched_chat_response(self, payload: dict) -> dict: class TestSessionRoutes: + def test_health_reports_stable_instance_id(self, router_env): + first = requests.get(f"{router_env.url}/health", timeout=5.0) + second = requests.get(f"{router_env.url}/health", timeout=5.0) + + assert first.status_code == 200 + assert second.status_code == 200 + first_body = first.json() + second_body = second.json() + assert first_body["status"] == "ok" + assert second_body["status"] == "ok" + assert re.fullmatch(r"[0-9a-f]{32}", first_body["session_server_instance_id"]) + assert second_body["session_server_instance_id"] == first_body["session_server_instance_id"] + def test_create_session(self, router_env): response = requests.post(f"{router_env.url}/sessions", timeout=5.0) assert response.status_code == 200 diff --git a/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py b/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py index 78177ee63b..d37142a6c0 100644 --- a/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py +++ b/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py @@ -13,7 +13,11 @@ from miles.utils.chat_template_utils.autofix import try_get_fixed_chat_template from miles.utils.chat_template_utils.template import load_hf_chat_template -from miles.utils.test_utils.chat_template_verify import assert_pretokenized_equals_standard, simulate_pretokenized_path +from miles.utils.test_utils.chat_template_verify import ( + assert_pretokenized_equals_standard, + simulate_pretokenized_path, + verify_append_only, +) from miles.utils.test_utils.mock_trajectories import ( MultiTurnTrajectory, MultiUserTurnThinkingTrajectory, @@ -70,21 +74,15 @@ def _load_fixed(hf_id: str) -> str: ) -def _to_pytest_params(cases, include_tools=True): +def _to_pytest_params(cases): """Convert (name, cls, n, tools) tuples to pytest.param list.""" - params = [] - for name, cls, n, tools in cases: - if include_tools: - params.append(pytest.param(cls, n, tools, id=name)) - else: - params.append(pytest.param(cls, n, tools, id=name)) - return params + return [pytest.param(cls, n, tools, id=name) for name, cls, n, tools in cases] -_STANDARD_CASES = _to_pytest_params(STANDARD_CASES) -_THINKING_CASES = _to_pytest_params(THINKING_CASES) -_INTERMEDIATE_SYSTEM_CASES = _to_pytest_params(INTERMEDIATE_SYSTEM_CASES) -_INTERMEDIATE_SYSTEM_THINKING_CASES = _to_pytest_params(INTERMEDIATE_SYSTEM_THINKING_CASES) +_STANDARD_PARAMS = _to_pytest_params(STANDARD_CASES) +_THINKING_PARAMS = _to_pytest_params(THINKING_CASES) +_INTERMEDIATE_SYSTEM_PARAMS = _to_pytest_params(INTERMEDIATE_SYSTEM_CASES) +_INTERMEDIATE_SYSTEM_THINKING_PARAMS = _to_pytest_params(INTERMEDIATE_SYSTEM_THINKING_CASES) # (chat_template, trajectory_cls, pretokenize_n) — original templates that break prefix invariant _MISMATCH_CASES = [ @@ -101,11 +99,53 @@ def _to_pytest_params(cases, include_tools=True): ), ] -# Template parametrization lists -all_template_ids = list(ALL_TEMPLATES.keys()) -all_template_values = list(ALL_TEMPLATES.values()) -thinking_template_ids = list(TEMPLATES_WITH_THINKING.keys()) -thinking_template_values = list(TEMPLATES_WITH_THINKING.values()) + +def _template_params(templates: dict[str, str]) -> list: + """Convert a {name: template_str} dict to a list of pytest.param(template_str, id=name).""" + return [pytest.param(v, id=k) for k, v in templates.items()] + + +# Intermediate-system compatibility: only qwen3.5_fixed is known to reject them. +# test_intermediate_system_probe_matrix locks this set against drift. +_INTERMEDIATE_SYSTEM_FORBIDDEN = {"qwen3.5_fixed"} +_INTERMEDIATE_SYSTEM_TEMPLATES = {k: v for k, v in ALL_TEMPLATES.items() if k not in _INTERMEDIATE_SYSTEM_FORBIDDEN} +_INTERMEDIATE_SYSTEM_THINKING_TEMPLATES = { + k: v for k, v in TEMPLATES_WITH_THINKING.items() if k not in _INTERMEDIATE_SYSTEM_FORBIDDEN +} + + +def _collect_intermediate_system_failures(template_id: str, chat_template: str) -> list[str]: + failures: list[str] = [] + for case_name, traj_cls, n, tools in INTERMEDIATE_SYSTEM_CASES: + result = verify_append_only(chat_template, deepcopy(traj_cls.MESSAGES), n, tools=tools, case_name=case_name) + if not result.passed: + failures.append(f"{case_name}: {result.error}") + + if template_id in TEMPLATES_WITH_THINKING: + for enable in (True, False): + suffix = "thinking_on" if enable else "thinking_off" + for case_name, traj_cls, n, tools in INTERMEDIATE_SYSTEM_THINKING_CASES: + full_case_name = f"{case_name}[{suffix}]" + result = verify_append_only( + chat_template, + deepcopy(traj_cls.MESSAGES), + n, + tools=tools, + case_name=full_case_name, + enable_thinking=enable, + ) + if not result.passed: + failures.append(f"{full_case_name}: {result.error}") + + return failures + + +def _format_failure_map(failure_map: dict[str, list[str]]) -> str: + lines: list[str] = [] + for template_id in sorted(failure_map): + lines.append(f"{template_id}:") + lines.extend(f" - {item}" for item in failure_map[template_id]) + return "\n".join(lines) # =========================================================================== @@ -113,8 +153,8 @@ def _to_pytest_params(cases, include_tools=True): # =========================================================================== -@pytest.mark.parametrize("chat_template", all_template_values, ids=all_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _STANDARD_CASES) +@pytest.mark.parametrize("chat_template", _template_params(ALL_TEMPLATES)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _STANDARD_PARAMS) def test_pretokenized_equals_standard(chat_template, trajectory_cls, pretokenize_n, tools): """Pretokenized incremental path produces same text as standard full render.""" assert_pretokenized_equals_standard( @@ -130,8 +170,8 @@ def test_pretokenized_equals_standard(chat_template, trajectory_cls, pretokenize # =========================================================================== -@pytest.mark.parametrize("chat_template", thinking_template_values, ids=thinking_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _THINKING_CASES) +@pytest.mark.parametrize("chat_template", _template_params(TEMPLATES_WITH_THINKING)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _THINKING_PARAMS) @pytest.mark.parametrize("enable_thinking", [True, False], ids=["thinking_on", "thinking_off"]) def test_pretokenized_thinking(chat_template, trajectory_cls, pretokenize_n, tools, enable_thinking): """Thinking-capable templates work with pretokenized path and enable_thinking flag.""" @@ -149,10 +189,29 @@ def test_pretokenized_thinking(chat_template, trajectory_cls, pretokenize_n, too # =========================================================================== -@pytest.mark.parametrize("chat_template", all_template_values, ids=all_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_CASES) +def test_intermediate_system_probe_matrix(): + """Probe ALL_TEMPLATES and lock the allow/forbid intermediate-system matrix.""" + failure_map: dict[str, list[str]] = {} + for template_id, chat_template in ALL_TEMPLATES.items(): + failures = _collect_intermediate_system_failures(template_id, chat_template) + if failures: + failure_map[template_id] = failures + + detected_forbidden = set(failure_map.keys()) + assert detected_forbidden == _INTERMEDIATE_SYSTEM_FORBIDDEN, ( + f"Intermediate-system forbidden set changed.\n" + f"expected={sorted(_INTERMEDIATE_SYSTEM_FORBIDDEN)}\n" + f"detected={sorted(detected_forbidden)}\n" + f"{_format_failure_map(failure_map)}" + ) + qwen35_failures = failure_map.get("qwen3.5_fixed", []) + assert any("System message must be at the beginning." in failure for failure in qwen35_failures), qwen35_failures + + +@pytest.mark.parametrize("chat_template", _template_params(_INTERMEDIATE_SYSTEM_TEMPLATES)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_PARAMS) def test_pretokenized_intermediate_system(chat_template, trajectory_cls, pretokenize_n, tools): - """All templates support intermediate system messages (converted to user role in fixed templates).""" + """Templates in the allowlist support intermediate system messages.""" assert_pretokenized_equals_standard( chat_template=chat_template, messages=deepcopy(trajectory_cls.MESSAGES), @@ -161,13 +220,13 @@ def test_pretokenized_intermediate_system(chat_template, trajectory_cls, pretoke ) -@pytest.mark.parametrize("chat_template", thinking_template_values, ids=thinking_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_THINKING_CASES) +@pytest.mark.parametrize("chat_template", _template_params(_INTERMEDIATE_SYSTEM_THINKING_TEMPLATES)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_THINKING_PARAMS) @pytest.mark.parametrize("enable_thinking", [True, False], ids=["thinking_on", "thinking_off"]) def test_pretokenized_intermediate_system_thinking( chat_template, trajectory_cls, pretokenize_n, tools, enable_thinking ): - """Thinking templates support intermediate system messages with thinking.""" + """Thinking templates in the allowlist support intermediate system messages.""" assert_pretokenized_equals_standard( chat_template=chat_template, messages=deepcopy(trajectory_cls.MESSAGES), @@ -204,7 +263,7 @@ def test_original_template_prefix_mismatch(chat_template, trajectory_cls, pretok _CROSS_USER_THINKING_N = last_user_index(MultiUserTurnThinkingTrajectory.MESSAGES) -@pytest.mark.parametrize("chat_template", thinking_template_values, ids=thinking_template_ids) +@pytest.mark.parametrize("chat_template", _template_params(TEMPLATES_WITH_THINKING)) @pytest.mark.parametrize("enable_thinking", [True, False], ids=["thinking_on", "thinking_off"]) def test_cross_user_turn_thinking_prefix_mismatch(chat_template, enable_thinking): """Thinking templates compress reasoning_content from earlier user turns, breaking prefix invariant.""" diff --git a/tests/fast/utils/chat_template_utils/test_template.py b/tests/fast/utils/chat_template_utils/test_template.py index 225a9bf178..5ba907f2a9 100644 --- a/tests/fast/utils/chat_template_utils/test_template.py +++ b/tests/fast/utils/chat_template_utils/test_template.py @@ -60,6 +60,7 @@ def _make_serving(tokenizer) -> OpenAIServingChat: serving.use_dpsk_v32_encoding = False serving.is_gpt_oss = False serving.tool_call_parser = None + serving.reasoning_parser = None return serving @@ -118,6 +119,10 @@ def tokenizer(request) -> AutoTokenizer: # Trajectory / kwargs definitions # --------------------------------------------------------------------------- +_NO_INTERMEDIATE_SYSTEM_MODELS = { + "Qwen/Qwen3.5-4B", +} + _STANDARD_CASES = [ pytest.param(SingleToolTrajectory, {}, id="single_tool"), pytest.param(MultiTurnTrajectory, {}, id="multi_turn"), @@ -128,7 +133,7 @@ def tokenizer(request) -> AutoTokenizer: pytest.param(MultiTurnNoToolTrajectory, {}, id="multi_turn_no_tool"), ] -# Trajectories with intermediate system messages (Qwen3.5 uses fixed template). +# Trajectories with intermediate system messages. _INTERMEDIATE_SYSTEM_CASES = [ pytest.param(RetrySystemTrajectory, {}, id="retry_system"), pytest.param(IntermediateSystemTrajectory, {}, id="intermediate_system"), @@ -180,6 +185,8 @@ def test_standard(self, tokenizer, traj_cls, kwargs): @pytest.mark.parametrize("traj_cls, kwargs", _INTERMEDIATE_SYSTEM_CASES) def test_intermediate_system(self, tokenizer, traj_cls, kwargs): + if tokenizer.name_or_path in _NO_INTERMEDIATE_SYSTEM_MODELS: + pytest.skip(f"{tokenizer.name_or_path} intentionally forbids intermediate system messages") _assert_aligned(tokenizer, traj_cls, kwargs) @pytest.mark.parametrize("traj_cls, kwargs", _THINKING_CASES) @@ -188,6 +195,8 @@ def test_thinking(self, tokenizer, traj_cls, kwargs): @pytest.mark.parametrize("traj_cls, kwargs", _INTERMEDIATE_SYSTEM_THINKING_CASES) def test_intermediate_system_thinking(self, tokenizer, traj_cls, kwargs): + if tokenizer.name_or_path in _NO_INTERMEDIATE_SYSTEM_MODELS: + pytest.skip(f"{tokenizer.name_or_path} intentionally forbids intermediate system messages") _assert_aligned(tokenizer, traj_cls, kwargs) def test_json_string_arguments(self, tokenizer): diff --git a/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py b/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py index 1321260525..4172a089f6 100644 --- a/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py +++ b/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py @@ -26,16 +26,18 @@ - Default: plain concatenation (no boundary handling). TestTokenizeAdditional - Behavioral tests for tokenize_additional_non_assistant — the dummy-prefix - diff that computes incremental token IDs for appended non-assistant messages. + Behavioral tests for tokenize_additional_non_assistant — the role-segmented + synthetic-prefix diff that computes incremental token IDs for appended + non-assistant messages. ``test_produces_nonempty_incremental`` is parametrized over: _TOOL_TRAJECTORIES (trajectory classes) × _TITO_MODELS (qwen3, glm47) Split points are auto-detected by _find_tito_splits from message structure, so adding a trajectory to _TOOL_TRAJECTORIES automatically extends coverage. - Remaining tests verify append-only validation (reject prefix mutation, - fewer messages, or forbidden roles like assistant). + Remaining tests cover segmentation logic, generation-prompt timing, + reasoning-content shape, merge structure preservation, and append-only + validation (reject prefix mutation, fewer messages, or forbidden roles). TestFactory get_tito_tokenizer factory: string/enum dispatch, invalid input handling. @@ -43,16 +45,21 @@ from __future__ import annotations +from pathlib import Path + import pytest from transformers import AutoTokenizer +from miles.utils.chat_template_utils import MismatchType, apply_chat_template, try_get_fixed_chat_template from miles.utils.chat_template_utils.tito_tokenizer import ( GLM47TITOTokenizer, Qwen3TITOTokenizer, TITOTokenizer, TITOTokenizerType, + _build_dummy_assistant, get_tito_tokenizer, ) +from miles.utils.processing_utils import load_tokenizer from miles.utils.test_utils.mock_trajectories import ( IntermediateSystemTrajectory, LongChainTrajectory, @@ -68,13 +75,19 @@ # Tokenizer cache # --------------------------------------------------------------------------- -_TOK_CACHE: dict[str, AutoTokenizer] = {} +_TOK_CACHE: dict[tuple[str, str | None], AutoTokenizer] = {} def _get_tokenizer(model_id: str) -> AutoTokenizer: - if model_id not in _TOK_CACHE: - _TOK_CACHE[model_id] = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - return _TOK_CACHE[model_id] + chat_template_path = try_get_fixed_chat_template(model_id) + cache_key = (model_id, chat_template_path) + if cache_key not in _TOK_CACHE: + _TOK_CACHE[cache_key] = load_tokenizer( + model_id, + chat_template_path=chat_template_path, + trust_remote_code=True, + ) + return _TOK_CACHE[cache_key] # --------------------------------------------------------------------------- @@ -91,32 +104,28 @@ def _get_tokenizer(model_id: str) -> AutoTokenizer: } -# TODO: "user" is intentionally excluded — the dummy-prefix diff in -# tokenize_additional_non_assistant assumes appended messages don't change how -# earlier turns render, which breaks for user messages on context-sensitive -# templates (e.g. Qwen3's last_query_index). Only tool and system are safe. -_TOOL_AND_SYSTEM = ["tool", "system"] +_ALLOWED_APPEND_ROLES = ["tool", "user", "system"] @pytest.fixture(params=list(_TITO_MODELS.keys())) def tito(request) -> TITOTokenizer: model_id, cls = _TITO_MODELS[request.param] - return cls(_get_tokenizer(model_id), allowed_append_roles=_TOOL_AND_SYSTEM) + return cls(_get_tokenizer(model_id), allowed_append_roles=_ALLOWED_APPEND_ROLES) @pytest.fixture def qwen3_tito() -> Qwen3TITOTokenizer: - return Qwen3TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_TOOL_AND_SYSTEM) + return Qwen3TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_ALLOWED_APPEND_ROLES) @pytest.fixture def glm47_tito() -> GLM47TITOTokenizer: - return GLM47TITOTokenizer(_get_tokenizer("zai-org/GLM-4.7-Flash"), allowed_append_roles=_TOOL_AND_SYSTEM) + return GLM47TITOTokenizer(_get_tokenizer("zai-org/GLM-4.7-Flash"), allowed_append_roles=_ALLOWED_APPEND_ROLES) @pytest.fixture def default_tito() -> TITOTokenizer: - return TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_TOOL_AND_SYSTEM) + return TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_ALLOWED_APPEND_ROLES) # --------------------------------------------------------------------------- @@ -156,8 +165,8 @@ def _split_at(traj_cls, pos: int): """Split trajectory at *pos* into ``(old_msgs, new_msgs, tools)``. ``old_msgs = messages[:pos]`` — the pretokenized prefix (ends with assistant turn). - ``new_msgs`` extends through all subsequent non-assistant messages (tool/system), - stopping before the next assistant turn. + ``new_msgs`` extends through all subsequent non-assistant messages + (tool/user/system), stopping before the next assistant turn. """ msgs = traj_cls.MESSAGES end = pos @@ -281,7 +290,7 @@ def test_empty_prefix(self, qwen3_tito: Qwen3TITOTokenizer): # --------------------------------------------------------------------------- -# TestTokenizeAdditional — incremental tokenization via dummy-prefix diff +# TestTokenizeAdditional — incremental tokenization via role-segmented synthetic diff # # test_produces_nonempty_incremental is the scalable core: parametrized over # _TRAJ_CASES (trajectories × split points) × tito fixture (models). @@ -306,6 +315,108 @@ def test_produces_nonempty_incremental(self, tito: TITOTokenizer, traj_cls, pos) incremental = tito.tokenize_additional_non_assistant(old_msgs, new_msgs, tools) assert len(incremental) > 0 + def test_contiguous_tool_segment_is_tokenized_together(self, qwen3_tito: Qwen3TITOTokenizer): + old_msgs, new_msgs, tools = _split_at(MultiToolSingleTurnTrajectory, 3) + appended = new_msgs[len(old_msgs) :] + + segments = qwen3_tito._split_appended_segments(appended) + assert len(segments) == 1 + assert [msg["role"] for msg in segments[0]] == ["tool", "tool"] + + incremental = qwen3_tito.tokenize_additional_non_assistant(old_msgs, new_msgs, tools) + decoded = qwen3_tito.tokenizer.decode(incremental) + assert MultiToolSingleTurnTrajectory.MESSAGES[3]["content"] in decoded + assert MultiToolSingleTurnTrajectory.MESSAGES[4]["content"] in decoded + + def test_user_and_system_segments_are_singletons(self, default_tito: TITOTokenizer): + appended = [ + {"role": "system", "content": "Use JSON."}, + {"role": "user", "content": "Hello"}, + {"role": "tool", "tool_call_id": "call_1", "content": '{"ok": true}'}, + {"role": "tool", "tool_call_id": "call_2", "content": '{"ok": false}'}, + {"role": "user", "content": "Try again"}, + ] + + segments = default_tito._split_appended_segments(appended) + assert [[msg["role"] for msg in segment] for segment in segments] == [ + ["system"], + ["user"], + ["tool", "tool"], + ["user"], + ] + + def test_generation_prompt_is_appended_once_for_full_suffix(self, qwen3_tito: Qwen3TITOTokenizer): + old_msgs = list(SingleToolThinkingTrajectory.MESSAGES[:3]) + new_msgs = old_msgs + [ + SingleToolThinkingTrajectory.MESSAGES[3], + {"role": "user", "content": "Now check Shanghai too."}, + ] + tools = SingleToolThinkingTrajectory.TOOLS + + incremental = qwen3_tito.tokenize_additional_non_assistant(old_msgs, new_msgs, tools) + decoded = qwen3_tito.tokenizer.decode(incremental) + assert decoded.count(qwen3_tito._assistant_start_str) == 1 + assert decoded.endswith( + qwen3_tito.tokenizer.decode( + qwen3_tito._tokenize_rendered_suffix(new_msgs, [], tools=tools, add_generation_prompt=True) + ) + ) + + def test_qwen3_tool_dummy_assistant_preserves_reasoning_shape(self): + thinking_template_path = ( + Path(__file__).resolve().parents[4] + / "miles/utils/chat_template_utils/templates/qwen3_thinking_2507_and_next_fixed.jinja" + ) + thinking_tito = Qwen3TITOTokenizer( + load_tokenizer( + "Qwen/Qwen3-4B-Instruct-2507", + chat_template_path=str(thinking_template_path), + trust_remote_code=True, + ), + allowed_append_roles=_ALLOWED_APPEND_ROLES, + ) + tool_messages = [SingleToolThinkingTrajectory.MESSAGES[3]] + dummy_assistant = _build_dummy_assistant(tool_messages) + rendered = thinking_tito._render_messages( + [{"role": "system", "content": "dummy system"}, dummy_assistant], + add_generation_prompt=False, + tools=SingleToolThinkingTrajectory.TOOLS, + ) + + assert dummy_assistant["reasoning_content"] == " " + assert rendered.endswith( + '<|im_start|>assistant\n\n{"name": "dummy_func", "arguments": {}}\n<|im_end|>\n' + ) + + @pytest.mark.parametrize( + "traj_cls, pos", + [ + pytest.param(SingleToolTrajectory, 3, id="single-tool"), + pytest.param(RetrySystemTrajectory, 3, id="tool-plus-system"), + pytest.param(IntermediateSystemTrajectory, 3, id="intermediate-system"), + ], + ) + def test_qwen3_merge_preserves_non_assistant_structure(self, qwen3_tito: Qwen3TITOTokenizer, traj_cls, pos): + """Merged tokens may differ in assistant text, but not in tool/system structure.""" + old_msgs, new_msgs, tools = _split_at(traj_cls, pos) + pretokenized = apply_chat_template( + old_msgs, + tokenizer=qwen3_tito.tokenizer, + tokenize=True, + add_generation_prompt=False, + tools=tools, + ) + merged = qwen3_tito.merge_tokens(old_msgs, new_msgs, pretokenized, tools) + expected = apply_chat_template( + new_msgs, + tokenizer=qwen3_tito.tokenizer, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + mismatches = qwen3_tito.create_comparator().compare_sequences(expected, merged) + assert all(m.type == MismatchType.ASSISTANT_TEXT for m in mismatches) + # -- Append-only validation (assert_messages_append_only_with_allowed_role is called internally) -- def test_rejects_prefix_mutation(self, qwen3_tito: Qwen3TITOTokenizer): diff --git a/tests/fast/utils/chat_template_utils/test_tito_tokenizer_model_matrix.py b/tests/fast/utils/chat_template_utils/test_tito_tokenizer_model_matrix.py new file mode 100644 index 0000000000..e846e8285b --- /dev/null +++ b/tests/fast/utils/chat_template_utils/test_tito_tokenizer_model_matrix.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass + +import pytest +from transformers import AutoTokenizer + +from miles.utils.chat_template_utils import MismatchType, apply_chat_template, try_get_fixed_chat_template +from miles.utils.chat_template_utils.tito_tokenizer import TITOTokenizer, TITOTokenizerType, get_tito_tokenizer +from miles.utils.processing_utils import load_tokenizer +from miles.utils.test_utils.mock_trajectories import ( + MultiUserTurnThinkingTrajectory, + SimpleNoToolTrajectory, + SingleToolThinkingTrajectory, + SingleToolTrajectory, +) + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3.5-0.8B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI + "zai-org/GLM-4.7-Flash", + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "MiniMaxAI/MiniMax-M2.5", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2.5", + "XiaomiMiMo/MiMo-7B-RL", + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", +] + +# Models excluded from TITO testing due to known template incompatibilities. +# Filtered out of parametrized test cases below. +_TITO_EXCLUDED_MODELS: dict[str, str] = { + "Qwen/Qwen3.5-0.8B": ( + "The qwen3.5 fixed template rejects non-first system messages with " + "'System message must be at the beginning'. TITO's synthetic bases " + "place system first, so this exclusion may be removable — needs testing." + ), + "deepseek-ai/DeepSeek-V3": ( + "TITO tokenizes each tool segment independently via _tokenize_tool_segment, " + "which causes DeepSeek-V3's template to emit extra " + "<|tool_outputs_begin|>/<|tool_outputs_end|> wrappers that differ from " + "full-conversation rendering." + ), +} +_TITO_TEST_MODELS = [m for m in TOOL_CALL_TEST_MODELS if m not in _TITO_EXCLUDED_MODELS] + +_ALLOWED_APPEND_ROLES = ["tool", "user", "system"] +_TOK_CACHE: dict[tuple[str, str | None], AutoTokenizer] = {} +_ASSISTANT_START_BY_MODEL: dict[str, str] = { + "Qwen/Qwen2.5-0.5B-Instruct": "<|im_start|>assistant\n", + "mistralai/Mistral-7B-Instruct-v0.3": "[/INST]", + "deepseek-ai/DeepSeek-V3": "<|Assistant|>", + "stepfun-ai/step3": "<|BOT|>assistant\n", + "MiniMaxAI/MiniMax-M2": "]~b]ai\n", + "MiniMaxAI/MiniMax-M2.5": "]~b]ai\n", + "internlm/internlm3-8b-instruct": "<|im_start|>assistant\n", + "THUDM/glm-4-9b-chat": "<|assistant|>", + "moonshotai/Kimi-K2-Instruct": "<|im_assistant|>assistant<|im_middle|>", + "moonshotai/Kimi-K2.5": "<|im_assistant|>assistant<|im_middle|>", + "XiaomiMiMo/MiMo-7B-RL": "<|im_start|>assistant\n", + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "<|im_start|>assistant\n", +} +_NO_SYSTEM_APPEND_MODELS = { + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "MiniMaxAI/MiniMax-M2.5", +} +_CONTENT_WHITESPACE_AGNOSTIC_MODELS = { + "stepfun-ai/step3", +} + + +@dataclass(frozen=True) +class AppendCase: + name: str + old_messages: list[dict] + appended_messages: list[dict] + tools: list[dict] | None + required_contents: tuple[str, ...] = () + + +_APPEND_CASES = [ + AppendCase( + name="single_tool", + old_messages=deepcopy(SingleToolTrajectory.MESSAGES[:3]), + appended_messages=deepcopy([SingleToolTrajectory.MESSAGES[3]]), + tools=deepcopy(SingleToolTrajectory.TOOLS), + required_contents=(SingleToolTrajectory.MESSAGES[3]["content"],), + ), + AppendCase( + name="single_user", + old_messages=deepcopy(MultiUserTurnThinkingTrajectory.MESSAGES[:5]), + appended_messages=deepcopy([MultiUserTurnThinkingTrajectory.MESSAGES[5]]), + tools=deepcopy(MultiUserTurnThinkingTrajectory.TOOLS), + required_contents=(MultiUserTurnThinkingTrajectory.MESSAGES[5]["content"],), + ), + AppendCase( + name="single_system", + old_messages=deepcopy(SimpleNoToolTrajectory.MESSAGES), + appended_messages=[{"role": "system", "content": "Please answer in one short sentence."}], + tools=None, + required_contents=("Please answer in one short sentence.",), + ), + AppendCase( + name="alternating_user_tool", + old_messages=deepcopy(SingleToolThinkingTrajectory.MESSAGES[:3]), + appended_messages=[ + deepcopy(SingleToolThinkingTrajectory.MESSAGES[3]), + {"role": "user", "content": "Now check Shanghai too."}, + { + "role": "tool", + "tool_call_id": "call_followup_1", + "content": '{"temperature": 30, "condition": "cloudy"}', + }, + {"role": "user", "content": "And tell me the date as well."}, + ], + tools=deepcopy(SingleToolThinkingTrajectory.TOOLS), + required_contents=( + SingleToolThinkingTrajectory.MESSAGES[3]["content"], + "Now check Shanghai too.", + '{"temperature": 30, "condition": "cloudy"}', + "And tell me the date as well.", + ), + ), +] + +_ALL_PARAMS = [ + pytest.param(model_name, case, id=f"{case.name}-{model_name}") + for model_name in _TITO_TEST_MODELS + for case in _APPEND_CASES + if not (case.name == "single_system" and model_name in _NO_SYSTEM_APPEND_MODELS) +] + + +def _resolve_tito_type(model_name: str) -> TITOTokenizerType: + lowered = model_name.lower() + if "qwen3" in lowered: + return TITOTokenizerType.QWEN3 + if "glm-4.7" in lowered: + return TITOTokenizerType.GLM47 + return TITOTokenizerType.DEFAULT + + +def _get_tokenizer(model_name: str) -> AutoTokenizer: + chat_template_path = try_get_fixed_chat_template(model_name) + cache_key = (model_name, chat_template_path) + if cache_key not in _TOK_CACHE: + _TOK_CACHE[cache_key] = load_tokenizer( + model_name, + chat_template_path=chat_template_path, + trust_remote_code=True, + ) + return _TOK_CACHE[cache_key] + + +def _get_tito(model_name: str, tokenizer: AutoTokenizer) -> TITOTokenizer: + tokenizer_type = _resolve_tito_type(model_name) + kwargs = { + "tokenizer_type": tokenizer_type, + "allowed_append_roles": _ALLOWED_APPEND_ROLES, + } + if tokenizer_type == TITOTokenizerType.DEFAULT: + kwargs["assistant_start_str"] = _ASSISTANT_START_BY_MODEL[model_name] + return get_tito_tokenizer(tokenizer, **kwargs) + + +def _render_ids( + tokenizer: AutoTokenizer, messages: list[dict], tools: list[dict] | None, *, add_generation_prompt: bool +) -> list[int]: + return apply_chat_template( + messages, + tokenizer=tokenizer, + tokenize=True, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + +def _assert_only_assistant_mismatches(tito: TITOTokenizer, expected: list[int], merged: list[int]) -> None: + mismatches = tito.create_comparator().compare_sequences(expected, merged) + bad = [m for m in mismatches if m.type != MismatchType.ASSISTANT_TEXT] + assert not bad, [m.to_dict() for m in bad] + + +def _assert_contents_in_order( + incremental_text: str, required_contents: tuple[str, ...], *, model_name: str, case_name: str +) -> None: + if model_name in _CONTENT_WHITESPACE_AGNOSTIC_MODELS: + incremental_text = "".join(incremental_text.split()) + required_contents = tuple("".join(content.split()) for content in required_contents) + cursor = 0 + for content in required_contents: + found = incremental_text.find(content, cursor) + assert found >= 0, f"{model_name=} {case_name=} missing ordered content {content!r}" + cursor = found + len(content) + + +def _run_case(model_name: str, case: AppendCase) -> tuple[TITOTokenizer, list[int], list[int], str]: + tokenizer = _get_tokenizer(model_name) + tito = _get_tito(model_name, tokenizer) + old_messages = deepcopy(case.old_messages) + new_messages = old_messages + deepcopy(case.appended_messages) + try: + expected = _render_ids(tokenizer, new_messages, case.tools, add_generation_prompt=True) + pretokenized = _render_ids(tokenizer, old_messages, case.tools, add_generation_prompt=False) + except Exception as exc: + pytest.skip(f"{model_name} cannot render case {case.name}: {type(exc).__name__}: {exc}") + merged = tito.merge_tokens(old_messages, new_messages, pretokenized, case.tools) + incremental_text = tokenizer.decode(tito.tokenize_additional_non_assistant(old_messages, new_messages, case.tools)) + return tito, merged, expected, incremental_text + + +@pytest.mark.parametrize(("model_name", "case"), _ALL_PARAMS) +def test_appended_non_assistant_content_preserved(model_name: str, case: AppendCase): + tito, merged, expected, incremental_text = _run_case(model_name, case) + _assert_only_assistant_mismatches(tito, expected, merged) + _assert_contents_in_order(incremental_text, case.required_contents, model_name=model_name, case_name=case.name) diff --git a/tests/test_fused_experts_backward.py b/tests/test_fused_experts_backward.py index e2a94897b2..a89e2de51a 100644 --- a/tests/test_fused_experts_backward.py +++ b/tests/test_fused_experts_backward.py @@ -260,8 +260,8 @@ def backward(ctx, grad_output): # Import Triton Implementation # ============================================================================ -from miles.backends.fsdp_utils.kernels.fused_experts import DownProjFunction as DownProjFunctionTriton -from miles.backends.fsdp_utils.kernels.fused_experts import GateUpProjFunction as GateUpProjFunctionTriton +from miles.backends.experimental.fsdp_utils.kernels.fused_experts import DownProjFunction as DownProjFunctionTriton +from miles.backends.experimental.fsdp_utils.kernels.fused_experts import GateUpProjFunction as GateUpProjFunctionTriton # ============================================================================ # Test Fixtures and Utilities diff --git a/tools/convert_hf_to_nvfp4.py b/tools/convert_hf_to_nvfp4.py new file mode 100644 index 0000000000..2de2183a41 --- /dev/null +++ b/tools/convert_hf_to_nvfp4.py @@ -0,0 +1,526 @@ +""" +python tools/convert_hf_to_nvfp4.py [-h] [--model-dir MODEL_DIR] [--save-dir SAVE_DIR] + [--device DEVICE] [--keep-last-n KEEP_LAST_N] [--keep-first-n KEEP_FIRST_N] + +Convert a BF16/FP16/FP32 HF safetensors checkpoint to NVFP4 (E2M1) for MoE +expert GEMMs only. Dense linear layers are left unmodified. + +This follows the NVFP4 reference quantization in Transformer Engine and uses +1D block scaling (NVTE_NVFP4_1D_SCALING, group size = 16). +""" + +import argparse +import gc +import json +import os +import shutil + +import safetensors +import safetensors.torch +import torch +from tqdm import tqdm + +FP4_E2M1_MAX = 6.0 +FP8_E4M3_MAX = 448.0 +NVFP4_GROUP_SIZE = 16 +DEFAULT_KV_CACHE_SCHEME = {"dynamic": False, "num_bits": 8, "type": "float"} +DEFAULT_KV_CACHE_QUANT_ALGO = "FP8" + +EXPERT_WEIGHT_SUFFIXES = ( + ".w1.weight", + ".w2.weight", + ".w3.weight", + ".gate_proj.weight", + ".up_proj.weight", + ".down_proj.weight", + ".gate_up_proj.weight", +) + +EXPERT_NAME_MARKERS = ( + ".experts.", + ".shared_experts.", + "block_sparse_moe.experts.", + ".moe.experts.", +) + +FUSED_QKV_SUFFIXES = (".q_proj", ".k_proj", ".v_proj") +GATED_PAIR_SUFFIXES = { + ".gate_proj.weight": "gate", + ".up_proj.weight": "up", + ".w1.weight": "gate", + ".w3.weight": "up", +} + + +def _is_moe_expert_weight_name(name: str) -> bool: + if not name.endswith(".weight"): + return False + if not any(marker in name for marker in EXPERT_NAME_MARKERS): + return False + return any(name.endswith(suffix) for suffix in EXPERT_WEIGHT_SUFFIXES) + + +def _extract_layer_id(name: str) -> int | None: + parts = name.split(".") + for idx, part in enumerate(parts): + if part == "layers" and idx + 1 < len(parts): + layer_id = parts[idx + 1] + if layer_id.isdigit(): + return int(layer_id) + return None + + +def _get_num_hidden_layers(model_dir: str) -> int: + config_path = os.path.join(model_dir, "config.json") + if not os.path.exists(config_path): + raise ValueError("config.json is required to use --keep-first-n or --keep-last-n.") + cfg = json.load(open(config_path)) + num_layers = cfg.get("num_hidden_layers") + if num_layers is None and isinstance(cfg.get("text_config"), dict): + num_layers = cfg["text_config"].get("num_hidden_layers") + if num_layers is None: + raise ValueError("num_hidden_layers not found in config.json.") + return int(num_layers) + + +def _get_last_n_layer_ids(num_layers: int, keep_last_n: int) -> set[int]: + if keep_last_n <= 0: + return set() + start = max(0, num_layers - keep_last_n) + return set(range(start, num_layers)) + + +def _get_first_n_layer_ids(num_layers: int, keep_first_n: int) -> set[int]: + if keep_first_n <= 0: + return set() + end = min(num_layers, keep_first_n) + return set(range(0, end)) + + +def _build_keep_last_n_ignore_list(num_layers: int, keep_last_n: int) -> list[str]: + if keep_last_n <= 0: + return [] + start = max(0, num_layers - keep_last_n) + ignore_list = [] + for layer_id in range(start, num_layers): + prefix = f"model.layers.{layer_id}" + ignore_list.extend( + [ + f"{prefix}.self_attn.qkv_proj", + f"{prefix}.self_attn.o_proj", + f"{prefix}.mlp", + f"{prefix}.mlp.experts", + ] + ) + return ignore_list + + +def _build_keep_first_n_ignore_list(num_layers: int, keep_first_n: int) -> list[str]: + if keep_first_n <= 0: + return [] + end = min(num_layers, keep_first_n) + ignore_list = [] + for layer_id in range(0, end): + prefix = f"model.layers.{layer_id}" + ignore_list.extend( + [ + f"{prefix}.self_attn.qkv_proj", + f"{prefix}.self_attn.o_proj", + f"{prefix}.mlp", + f"{prefix}.mlp.experts", + ] + ) + return ignore_list + + +def should_quantize( + name: str, + weight: torch.Tensor, + skip_layers: set[int] | None = None, +) -> bool: + if skip_layers: + layer_id = _extract_layer_id(name) + if layer_id is not None and layer_id in skip_layers: + return False + if not _is_moe_expert_weight_name(name): + return False + if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + return False + if weight.dim() < 2: + return False + if weight.shape[-1] % NVFP4_GROUP_SIZE != 0: + raise ValueError( + f"Last dim {weight.shape[-1]} must be divisible by {NVFP4_GROUP_SIZE} " f"for NVFP4 quantization ({name})." + ) + return True + + +def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: + """Quantize a tensor to FP4 E2M1 and pack two values per byte.""" + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def _quantize_nvfp4_1d( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + NVFP4 1D quantization (tile shape = 1x16), adapted from + TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference. + + Returns: + qweight: uint8 packed fp4, shape (M, K // 2) + block_scale: float8_e4m3fn, shape (M, K // 16) + global_scale: float32 scalar tensor + """ + weight = weight.contiguous() + m, n = weight.shape + if n % NVFP4_GROUP_SIZE != 0: + raise ValueError(f"NVFP4 requires K divisible by {NVFP4_GROUP_SIZE}, got {n}.") + + weight_f = weight.to(torch.float32) + if global_amax is None: + global_amax = torch.max(torch.abs(weight_f)) + else: + global_amax = global_amax.to(device=weight.device, dtype=torch.float32) + if global_amax.item() == 0.0: + qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device) + block_scale = torch.zeros( + (m, n // NVFP4_GROUP_SIZE), + dtype=torch.float8_e4m3fn, + device=weight.device, + ) + global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + return qweight, block_scale, global_scale + + fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32) + fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32) + + global_encode_scale = torch.div(fp8_max * fp4_max, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor(torch.finfo(torch.float32).max, device=weight.device, dtype=torch.float32), + ) + if global_encode_scale.item() == 0.0: + global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + weight_blocks = weight_f.view(m, n // NVFP4_GROUP_SIZE, NVFP4_GROUP_SIZE) + vec_max = torch.amax(torch.abs(weight_blocks), dim=-1, keepdim=True) + decode_scale = torch.div(vec_max, fp4_max) * global_encode_scale + decode_scale = torch.clamp(decode_scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn) + + encode_scale = torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale) + scaled = weight_blocks * encode_scale + clipped = torch.clamp(scaled, -fp4_max, fp4_max).reshape(m, n) + + qweight = cast_to_fp4x2(clipped) + block_scale = decode_scale.squeeze(-1) + return qweight, block_scale, global_decode_scale + + +def quantize_nvfp4( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if weight.dim() == 2: + return _quantize_nvfp4_1d(weight, global_amax=global_amax) + if weight.dim() == 3: + if global_amax is not None: + raise ValueError("global_amax override is only supported for 2D weights.") + qweights = [] + block_scales = [] + global_scales = [] + for idx in range(weight.shape[0]): + qweight, block_scale, global_scale = _quantize_nvfp4_1d(weight[idx]) + qweights.append(qweight) + block_scales.append(block_scale) + global_scales.append(global_scale) + return ( + torch.stack(qweights, dim=0), + torch.stack(block_scales, dim=0), + torch.stack(global_scales, dim=0), + ) + raise ValueError(f"Unsupported weight rank {weight.dim()} for NVFP4 quantization.") + + +class ConversionResult: + def __init__(self) -> None: + self.weight_map: dict[str, str] = {} + self.total_size: int = 0 + self.modules_to_not_convert: list[str] = [] + + def add_result(self, filename: str, q_weights: dict[str, torch.Tensor], module_names: list[str]) -> None: + for key, tensor in q_weights.items(): + self.weight_map[key] = filename + self.total_size += tensor.numel() * tensor.element_size() + self.modules_to_not_convert.extend(module_names) + + +def _update_quantization_config(cfg: dict, ignore_list: list[str]) -> None: + quant_cfg = cfg.get("quantization_config") + if not isinstance(quant_cfg, dict): + quant_cfg = {} + + quant_cfg["quant_algo"] = "NVFP4" + quant_cfg["quant_method"] = "modelopt" + quant_cfg["group_size"] = NVFP4_GROUP_SIZE + quant_cfg["ignore"] = ignore_list + quant_cfg.setdefault("kv_cache_scheme", DEFAULT_KV_CACHE_SCHEME) + + config_groups = quant_cfg.get("config_groups") + if isinstance(config_groups, dict): + for group in config_groups.values(): + if not isinstance(group, dict): + continue + group.setdefault("targets", ["Linear"]) + for key in ("input_activations", "weights"): + section = group.get(key) + if not isinstance(section, dict): + continue + section.setdefault("dynamic", False) + section.setdefault("num_bits", 4) + section.setdefault("type", "float") + section["group_size"] = NVFP4_GROUP_SIZE + + cfg["quantization_config"] = quant_cfg + + +def _write_hf_quant_config(output_path: str, ignore_list: list[str], input_path: str) -> None: + hf_quant_path = os.path.join(input_path, "hf_quant_config.json") + if os.path.exists(hf_quant_path): + with open(hf_quant_path) as f: + hf_quant_cfg = json.load(f) + else: + hf_quant_cfg = {"producer": {"name": "modelopt"}} + + quant_section = hf_quant_cfg.get("quantization") + if not isinstance(quant_section, dict): + quant_section = {} + + quant_section["quant_algo"] = "NVFP4" + quant_section["kv_cache_quant_algo"] = DEFAULT_KV_CACHE_QUANT_ALGO + quant_section["group_size"] = NVFP4_GROUP_SIZE + quant_section["exclude_modules"] = ignore_list + hf_quant_cfg["quantization"] = quant_section + + with open(os.path.join(output_path, "hf_quant_config.json"), "w") as f: + json.dump(hf_quant_cfg, f, indent=2) + + +def _augment_ignore_list(ignore_list: list[str]) -> list[str]: + ignore_set = set(ignore_list) + extra = set() + for name in ignore_list: + if name.endswith(FUSED_QKV_SUFFIXES): + for suffix in FUSED_QKV_SUFFIXES: + if name.endswith(suffix): + extra.add(name[: -len(suffix)] + ".qkv_proj") + break + ignore_set.update(extra) + return sorted(ignore_set) + + +def _split_gated_pair_name(name: str) -> tuple[str | None, str | None]: + for suffix, role in GATED_PAIR_SUFFIXES.items(): + if name.endswith(suffix): + return name[: -len(suffix)], role + return None, None + + +def _collect_shared_global_amax( + *, + input_path: str, + safetensors_files: list[str], + device: str, + skip_layers: set[int], +) -> dict[str, torch.Tensor]: + """Collect shared gate/up amax across all shards to keep w1/w3 scales equal.""" + gate_amax: dict[str, torch.Tensor] = {} + up_amax: dict[str, torch.Tensor] = {} + for filename in safetensors_files: + with safetensors.safe_open(os.path.join(input_path, filename), framework="pt", device=device) as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if not should_quantize(key, tensor, skip_layers): + continue + base, role = _split_gated_pair_name(key) + if base is None or role is None: + continue + amax = tensor.abs().max().to(torch.float32) + if role == "gate": + prev = gate_amax.get(base) + gate_amax[base] = amax if prev is None else torch.max(prev, amax) + elif role == "up": + prev = up_amax.get(base) + up_amax[base] = amax if prev is None else torch.max(prev, amax) + else: + continue + + shared_global_amax: dict[str, torch.Tensor] = {} + for base in gate_amax.keys() & up_amax.keys(): + shared_global_amax[base] = torch.max(gate_amax[base], up_amax[base]) + return shared_global_amax + + +def process_file( + input_path: str, + output_path: str, + filename: str, + result_collector: ConversionResult, + device: str, + skip_layers: set[int], + shared_global_amax: dict[str, torch.Tensor], +) -> None: + if not filename.endswith(".safetensors"): + return + + modules_to_not_convert: list[str] = [] + q_weights: dict[str, torch.Tensor] = {} + + with safetensors.safe_open(os.path.join(input_path, filename), framework="pt", device=device) as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if should_quantize(key, tensor, skip_layers): + base, _role = _split_gated_pair_name(key) + global_amax = shared_global_amax.get(base) if base else None + qweight, block_scale, weight_scale_2 = quantize_nvfp4(tensor, global_amax=global_amax) + q_weights[key] = qweight + q_weights[key.replace(".weight", ".weight_scale")] = block_scale + q_weights[key.replace(".weight", ".weight_scale_2")] = weight_scale_2 + q_weights[key.replace(".weight", ".input_scale")] = torch.ones_like( + weight_scale_2, dtype=torch.float32 + ) + else: + if key.endswith(".weight"): + modules_to_not_convert.append(key.replace(".weight", "")) + q_weights[key] = tensor + + safetensors.torch.save_file(q_weights, os.path.join(output_path, filename), metadata={"format": "pt"}) + result_collector.add_result(filename, q_weights, modules_to_not_convert) + + +def convert_nvfp4(model_dir: str, save_dir: str, device: str, keep_last_n: int, keep_first_n: int) -> None: + input_path = os.path.abspath(model_dir) + output_path = os.path.abspath(save_dir) + os.makedirs(output_path, exist_ok=True) + + for filename in os.listdir(input_path): + if not filename.endswith(".safetensors") and not os.path.isdir(os.path.join(input_path, filename)): + shutil.copyfile(os.path.join(input_path, filename), os.path.join(output_path, filename)) + + safetensors_files = [f for f in os.listdir(input_path) if f.endswith(".safetensors")] + + num_layers = _get_num_hidden_layers(input_path) if (keep_last_n > 0 or keep_first_n > 0) else 0 + skip_layers = _get_last_n_layer_ids(num_layers, keep_last_n) | _get_first_n_layer_ids(num_layers, keep_first_n) + keep_last_ignore = _build_keep_last_n_ignore_list(num_layers, keep_last_n) + keep_first_ignore = _build_keep_first_n_ignore_list(num_layers, keep_first_n) + + shared_global_amax = _collect_shared_global_amax( + input_path=input_path, + safetensors_files=safetensors_files, + device=device, + skip_layers=skip_layers, + ) + result_collector = ConversionResult() + for filename in tqdm(safetensors_files, desc="Processing files"): + process_file( + input_path, + output_path, + filename, + result_collector, + device, + skip_layers, + shared_global_amax, + ) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + ignore_list = _augment_ignore_list(result_collector.modules_to_not_convert + keep_last_ignore + keep_first_ignore) + + config_path = os.path.join(input_path, "config.json") + if os.path.exists(config_path): + cfg = json.load(open(config_path)) + _update_quantization_config(cfg, ignore_list) + json.dump(cfg, open(os.path.join(output_path, "config.json"), "w"), indent=2) + + _write_hf_quant_config(output_path, ignore_list, input_path) + + index_dict = { + "weight_map": result_collector.weight_map, + "metadata": {"total_size": result_collector.total_size}, + } + json.dump(index_dict, open(os.path.join(output_path, "model.safetensors.index.json"), "w"), indent=2) + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-dir", type=str, required=True, help="Path to HF safetensors model.") + parser.add_argument("--save-dir", type=str, required=True, help="Path to save converted model.") + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Torch device to run quantization on (default: cuda).", + ) + parser.add_argument( + "--keep-last-n", + type=int, + default=0, + help="Keep the last N transformer layers unquantized (BF16/FP16).", + ) + parser.add_argument( + "--keep-first-n", + type=int, + default=0, + help="Keep the first N transformer layers unquantized (BF16/FP16).", + ) + args = parser.parse_args() + + if isinstance(args.device, str) and args.device.isdigit(): + device = torch.device(f"cuda:{args.device}") + else: + device = torch.device(args.device) + + if device.type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot run NVFP4 quantization.") + if device.index is None: + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + if not os.path.exists(args.save_dir): + print(f"Creating directory {args.save_dir}") + os.makedirs(args.save_dir) + elif not os.path.isdir(args.save_dir): + raise ValueError("The save_dir should be a directory.") + + convert_nvfp4(args.model_dir, args.save_dir, str(device), args.keep_last_n, args.keep_first_n) + + +if __name__ == "__main__": + main() diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index 354c216a65..1ea11cf156 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -1,7 +1,6 @@ import gc import os import shutil -from functools import wraps import torch import torch.distributed as dist @@ -12,7 +11,6 @@ import miles_plugins.mbridge # noqa: F401 from mbridge import AutoBridge -from mbridge.core.bridge import Bridge from miles.backends.megatron_utils.arguments import set_default_megatron_args from miles.backends.megatron_utils.initialize import init from miles.backends.megatron_utils.model_provider import get_model_provider_func @@ -21,24 +19,6 @@ from miles_plugins.models.hf_attention import _load_hf_config -def patch_weight_to_mcore_format_preserve_fp32(): - - original_method = Bridge._weight_to_mcore_format - - @wraps(original_method) - def patched_method(self, mcore_weights_name, hf_weights): - original_dtype = getattr(self, "dtype", None) - self.dtype = None - try: - result = original_method(self, mcore_weights_name, hf_weights) - finally: - self.dtype = original_dtype - return result - - Bridge._weight_to_mcore_format = patched_method - print("[Patch] Applied patch to preserve FP32 precision in _weight_to_mcore_format") - - def add_convertion_args(parser): """Add conversion arguments to the parser""" parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path") @@ -138,9 +118,6 @@ def main(): # Fallback for configs with model_type unknown to installed transformers. bridge = AutoBridge.from_config(_load_hf_config(hf_model_path)) - # Patch to preserve FP32 precision for _keep_fp32 params - patch_weight_to_mcore_format_preserve_fp32() - bridge.load_weights(model, hf_model_path, memory_efficient=True) print(f"Model loaded: {hf_model_path}")