From d967e936e08a21bef5926519fdb4f30ff667a5d3 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sun, 5 Apr 2026 04:43:40 -0700 Subject: [PATCH 01/44] Auto-create sessions on first access for restart tolerance When the session server restarts, all in-memory sessions are lost. Previously this returned 404 to every active agent, cascading failures across all running trials. Now, get_or_create_session() auto-creates the session if it does not exist, allowing agents to transparently recover after a router restart. The GET /sessions/{session_id} and chat completions endpoints both use this new method. --- miles/rollout/session/linear_trajectory.py | 8 ++++++++ miles/rollout/session/sessions.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index 79d092349b..ab86151889 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -268,6 +268,14 @@ def get_session(self, session_id: str) -> LinearTrajectory: raise SessionNotFoundError(f"session not found: session_id={session_id}") return session + def get_or_create_session(self, session_id: str) -> LinearTrajectory: + session = self.sessions.get(session_id) + if session is None: + logger.warning("Auto-creating session %s (not found, likely router restart)", session_id) + session = LinearTrajectory() + self.sessions[session_id] = session + return session + def remove_session(self, session_id: str) -> None: if self.sessions.pop(session_id, None) is None: raise SessionNotFoundError(f"session not found: session_id={session_id}") diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index bf53f446f4..394641a0e1 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -67,7 +67,7 @@ async def create_session(): @app.get("/sessions/{session_id}") async def get_session(session_id: str): - session = registry.get_session(session_id) + session = registry.get_or_create_session(session_id) metadata = {} try: mismatch = registry.compute_session_mismatch(session) @@ -114,7 +114,7 @@ async def chat_completions(request: Request, session_id: str): """ _inflight_chat["count"] += 1 try: - session = registry.get_session(session_id) + session = registry.get_or_create_session(session_id) if session.closing: raise SessionNotFoundError(f"session not found: session_id={session_id}") From 2ae96dd9e9f4284e692520668da31c722316eeab Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sun, 5 Apr 2026 05:03:23 -0700 Subject: [PATCH 02/44] Fix: recover from SGLang rollback failures in session proxy When SGLang returns 400 "rollback failed" (prefix-cache state mismatch), retry the request once without pretokenized input_ids. This bypasses prefix continuation and lets SGLang process the request from scratch. Previously, rollback failures were passed through to the caller as fatal errors, ending the session on the first request. --- miles/rollout/session/sessions.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index bf53f446f4..b5cd436561 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -167,7 +167,24 @@ async def chat_completions(request: Request, session_id: str): # pass it through to the agent without recording — the agent can retry # or handle the error. if result["status_code"] != 200: - return backend.build_proxy_response(result) + # Rollback failures indicate corrupted prefix-cache state in SGLang. + # Retry once without pretokenized input_ids so SGLang processes the + # request from scratch instead of attempting prefix continuation. + error_body = result.get("response_body", b"") + if isinstance(error_body, bytes): + error_body = error_body.decode("utf-8", errors="replace") + if result["status_code"] == 400 and "rollback failed" in error_body: + logger.warning( + "SGLang rollback failed for session %s, retrying without prefix continuation", + session_id, + ) + request_body.pop("input_ids", None) + retry_body = json.dumps(request_body).encode() + result = await backend.do_proxy(request, "v1/chat/completions", body=retry_body) + if result["status_code"] != 200: + return backend.build_proxy_response(result) + else: + return backend.build_proxy_response(result) response = json.loads(result["response_body"]) From e0fc889fb898cc099da02fe697b790c2abba73f4 Mon Sep 17 00:00:00 2001 From: JD Date: Sun, 5 Apr 2026 13:51:25 -0700 Subject: [PATCH 03/44] Revert "[BUGFIX] [P2PRDMA] Add rollout post-processing after P2PRDMA weight updates" (#882) --- .../update_weight/update_weight_from_distributed/p2p.py | 1 - 1 file changed, 1 deletion(-) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py index 7548fc2c9c..9702b31431 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py @@ -127,7 +127,6 @@ def _finalize_and_resume_engines(self): ) post_process_weights( rollout_engines=self.rollout_engines, - post_process_quantization=True, post_load_weights=True, ) super()._finalize_and_resume_engines() From dd188aaddb887e5c1de066277e7c9f10eec297b0 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sun, 5 Apr 2026 14:15:47 -0700 Subject: [PATCH 04/44] Fix null body crash and case-insensitive rollback detection - Guard against None response_body (use `or b""` instead of default) - Use .lower() for case-insensitive "rollback failed" matching - Only retry without prefix continuation if input_ids was present --- miles/rollout/session/sessions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index b5cd436561..12dd21c75d 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -170,10 +170,14 @@ async def chat_completions(request: Request, session_id: str): # Rollback failures indicate corrupted prefix-cache state in SGLang. # Retry once without pretokenized input_ids so SGLang processes the # request from scratch instead of attempting prefix continuation. - error_body = result.get("response_body", b"") + error_body = result.get("response_body") or b"" if isinstance(error_body, bytes): error_body = error_body.decode("utf-8", errors="replace") - if result["status_code"] == 400 and "rollback failed" in error_body: + if ( + result["status_code"] == 400 + and "rollback failed" in error_body.lower() + and "input_ids" in request_body + ): logger.warning( "SGLang rollback failed for session %s, retrying without prefix continuation", session_id, From 29a0dcad45e36985c7af6f7c1dcf653e0eace4f2 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sun, 5 Apr 2026 14:15:50 -0700 Subject: [PATCH 05/44] Add session TTL eviction and make GET endpoint side-effect free - GET /sessions/{id} returns empty response for unknown sessions instead of auto-creating (keeps the endpoint idempotent) - Auto-creation in POST still works for restart tolerance - Add TTL-based eviction (2h) for auto-created sessions to prevent unbounded memory growth --- miles/rollout/session/linear_trajectory.py | 24 ++++++++++++++++++++++ miles/rollout/session/sessions.py | 4 +++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index ab86151889..3646150c2d 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -1,5 +1,6 @@ import asyncio import logging +import time import uuid from dataclasses import dataclass, field from typing import Any @@ -252,6 +253,7 @@ class SessionRegistry: def __init__(self, args, tokenizer: Any, *, tito_tokenizer: TITOTokenizer): self.sessions: dict[str, LinearTrajectory] = {} + self._session_last_access: dict[str, float] = {} self.args = args self.tokenizer = tokenizer self.tito_tokenizer = tito_tokenizer @@ -271,11 +273,33 @@ def get_session(self, session_id: str) -> LinearTrajectory: def get_or_create_session(self, session_id: str) -> LinearTrajectory: session = self.sessions.get(session_id) if session is None: + self._evict_stale_sessions() logger.warning("Auto-creating session %s (not found, likely router restart)", session_id) session = LinearTrajectory() self.sessions[session_id] = session + self._session_last_access[session_id] = time.monotonic() + else: + self._session_last_access[session_id] = time.monotonic() return session + _SESSION_TTL_SECS: int = 7200 # 2 hours + _MAX_AUTO_CREATED: int = 10000 + + def _evict_stale_sessions(self) -> None: + """Remove auto-created sessions older than _SESSION_TTL_SECS.""" + if not self._session_last_access: + return + now = time.monotonic() + stale = [ + sid for sid, ts in self._session_last_access.items() + if now - ts > self._SESSION_TTL_SECS + ] + for sid in stale: + self.sessions.pop(sid, None) + self._session_last_access.pop(sid, None) + if stale: + logger.info("Evicted %d stale auto-created sessions", len(stale)) + def remove_session(self, session_id: str) -> None: if self.sessions.pop(session_id, None) is None: raise SessionNotFoundError(f"session not found: session_id={session_id}") diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 394641a0e1..d71f577795 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -67,7 +67,9 @@ async def create_session(): @app.get("/sessions/{session_id}") async def get_session(session_id: str): - session = registry.get_or_create_session(session_id) + session = registry.sessions.get(session_id) + if session is None: + return GetSessionResponse(session_id=session_id, records=[], metadata={}) metadata = {} try: mismatch = registry.compute_session_mismatch(session) From f700bd8a6e316f58b500928914aa3eaae0bd615d Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sun, 5 Apr 2026 14:15:53 -0700 Subject: [PATCH 06/44] Fix _truncate_sample_output to truncate rollout_routed_experts rollout_routed_experts has shape (len(tokens)-1, num_layers, topk) and was not truncated alongside tokens/logprobs/loss_mask, causing Sample.validate() to fail with an assertion error when agentic sessions exceed max_seq_len with routing replay enabled. --- miles/rollout/generate_utils/openai_endpoint_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 5b9a445adf..1b449d203a 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -228,4 +228,8 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None sample.rollout_log_probs = sample.rollout_log_probs[:keep_tokens] if sample.loss_mask is not None: sample.loss_mask = sample.loss_mask[:keep_tokens] + if sample.rollout_routed_experts is not None: + # rollout_routed_experts has shape (len(tokens) - 1, ...), so slice to + # match the new total token count minus one. + sample.rollout_routed_experts = sample.rollout_routed_experts[:prompt_len - 1 + keep_tokens] sample.status = Sample.Status.TRUNCATED From 25645357a7d51daa307a9fa25081095bc3cfb2a1 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sun, 5 Apr 2026 15:25:57 -0700 Subject: [PATCH 07/44] Simplify rollout_routed_experts slice to len(tokens) - 1 Directly reflects the invariant rather than reconstructing the index from prompt_len and keep_tokens. --- miles/rollout/generate_utils/openai_endpoint_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 1b449d203a..484eb99752 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -229,7 +229,5 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None if sample.loss_mask is not None: sample.loss_mask = sample.loss_mask[:keep_tokens] if sample.rollout_routed_experts is not None: - # rollout_routed_experts has shape (len(tokens) - 1, ...), so slice to - # match the new total token count minus one. - sample.rollout_routed_experts = sample.rollout_routed_experts[:prompt_len - 1 + keep_tokens] + sample.rollout_routed_experts = sample.rollout_routed_experts[:len(sample.tokens) - 1] sample.status = Sample.Status.TRUNCATED From ef5dda66b41eb46a8be2eacda129ef3914cafc42 Mon Sep 17 00:00:00 2001 From: "Ethan (Yusheng) Su" Date: Sun, 5 Apr 2026 15:46:23 -0700 Subject: [PATCH 08/44] [Fix] fix ci (#894) --- .github/workflows/pr-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6cd1247deb..9a47e2a0ef 100755 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -1375,7 +1375,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 4, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}, {"name": "qwen3-30B-A3B-bf16", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0"}, {"name": "qwen3-30B-A3B-rollout-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-rollout-int4", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0", "use_int4_rollout": "1"}] + info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}, {"name": "qwen3-30B-A3B-bf16", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0"}, {"name": "qwen3-30B-A3B-rollout-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-rollout-int4", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0", "use_int4_rollout": "1"}] defaults: run: working-directory: ${{ github.workspace }} From a3db3a9ef38100047eda2947b81462af69827411 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 6 Apr 2026 07:36:03 +0800 Subject: [PATCH 09/44] Avoid threading for ray getting object (#886) --- miles/utils/http_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 1e673d1c2d..0548ec6b1d 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -297,13 +297,9 @@ async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: - import ray - actor = _next_actor() if actor is not None: - # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) - return await asyncio.to_thread(ray.get, obj_ref) + return await actor.do_post.remote(url, payload, max_retries, action=action) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local From 4dd7770ed8caf59e45f387c5af7061e5c7e2cc41 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 6 Apr 2026 07:36:32 +0800 Subject: [PATCH 10/44] Add explicit errors for unsupported Megatron profiles (#887) --- miles/backends/megatron_utils/actor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 80e7fd9339..0c265a508e 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -71,6 +71,11 @@ def init( if self._is_main_rank: init_tracking(args, primary=False) + unsupported = {"train_actor", "train_log_probs"} & set(args.profile_target) + if unsupported and args.use_pytorch_profiler: + raise NotImplementedError( + f"--profile-target {' '.join(sorted(unsupported))} is not supported for Megatron backend" + ) self.prof = TrainProfiler(args) # read config and tokenizer serialized to prevent concurrent writing bug. From 649a3538dc794db124c969da1ccafcd270d8a6a5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 5 Apr 2026 21:12:23 -0700 Subject: [PATCH 11/44] Add nvfp4 quantizer files (#907) --- .../processors/quantizer_nvfp4.py | 254 +++++++++ tools/convert_hf_to_nvfp4.py | 526 ++++++++++++++++++ 2 files changed, 780 insertions(+) create mode 100644 miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py create mode 100644 tools/convert_hf_to_nvfp4.py diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py new file mode 100644 index 0000000000..7d8b92afed --- /dev/null +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py @@ -0,0 +1,254 @@ +import re + +import torch + +FP4_E2M1_MAX = 6.0 +FP8_E4M3_MAX = 448.0 +NVFP4_GROUP_SIZE = 16 + +GATED_PAIR_SUFFIXES = { + ".gate_proj.weight": "gate", + ".up_proj.weight": "up", + ".w1.weight": "gate", + ".w3.weight": "up", +} + + +def _get_ignore_rules(quantization_config) -> list[str]: + ignore_rules = quantization_config.get("ignore", []) or [] + if isinstance(ignore_rules, str): + ignore_rules = [ignore_rules] + exclude_rules = quantization_config.get("exclude_modules", []) or [] + if isinstance(exclude_rules, str): + exclude_rules = [exclude_rules] + return list(ignore_rules) + [rule for rule in exclude_rules if rule not in ignore_rules] + + +def _is_ignored(name: str, ignore_rules: list[str]) -> bool: + for rule in ignore_rules: + if rule.startswith("re:"): + if re.match(rule[3:], name): + return True + continue + if name == rule or name.startswith(f"{rule}."): + return True + return False + + +def quantize_params_nvfp4(args, megatron_name, converted_named_params, quantization_config): + assert quantization_config is not None + assert quantization_config.get("quant_algo") == "NVFP4" or quantization_config.get("quant_method") == "nvfp4" + group_size = _resolve_group_size(quantization_config) + ignore_rules = _get_ignore_rules(quantization_config) + + decoder_layers_pattern = r"decoder\.layers\.(\d+)\.(.+)" + match = re.search(decoder_layers_pattern, megatron_name) + + if not match: + # check mtp layers + mtp_layer_pattern = r"mtp\.layers\.(\d+)\.(.+)" + match = re.search(mtp_layer_pattern, megatron_name) + if not match: + return converted_named_params + _, rest = match.groups() + rest = rest.replace("transformer_layer.", "") + else: + _, rest = match.groups() + + # experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, _ = match.groups() + if rest in [ + "linear_fc1", + "linear_fc2", + ]: + return _quantize_moe_params(converted_named_params, group_size, ignore_rules) + + # shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest in [ + "linear_fc1.weight", + "linear_fc2.weight", + ]: + return _quantize_moe_params(converted_named_params, group_size, ignore_rules) + + # for other parameters, we just return the original converted_named_params + return converted_named_params + + +def _resolve_group_size(quantization_config): + group_size = quantization_config.get("group_size", NVFP4_GROUP_SIZE) + if group_size != NVFP4_GROUP_SIZE: + raise ValueError(f"NVFP4 group_size must be {NVFP4_GROUP_SIZE}, got {group_size}.") + return group_size + + +def _quantize_moe_params(converted_named_params, group_size, ignore_rules): + shared_global_amax = {} + gated_candidates = {} + for converted_name, param in converted_named_params: + base, role = _split_gated_pair_name(converted_name) + if base is None or role is None: + continue + if _should_quantize_param(converted_name, param, group_size, ignore_rules): + gated_candidates.setdefault(base, {})[role] = param + + for base, roles in gated_candidates.items(): + if "gate" in roles and "up" in roles: + gate_amax = roles["gate"].abs().max().to(torch.float32) + up_amax = roles["up"].abs().max().to(torch.float32) + shared_global_amax[base] = torch.max(gate_amax, up_amax) + + quantize_named_params = [] + for converted_name, param in converted_named_params: + if not _should_quantize_param(converted_name, param, group_size, ignore_rules): + quantize_named_params.append((converted_name, param)) + continue + base, _role = _split_gated_pair_name(converted_name) + global_amax = shared_global_amax.get(base) if base else None + qweight, block_scale, weight_scale_2 = quantize_nvfp4(param, global_amax=global_amax, group_size=group_size) + quantize_named_params.append((converted_name, qweight)) + quantize_named_params.append((converted_name.replace(".weight", ".weight_scale"), block_scale)) + quantize_named_params.append((converted_name.replace(".weight", ".weight_scale_2"), weight_scale_2)) + quantize_named_params.append( + (converted_name.replace(".weight", ".input_scale"), torch.ones_like(weight_scale_2, dtype=torch.float32)) + ) + + return quantize_named_params + + +def _should_quantize_param(name, weight, group_size, ignore_rules): + if ignore_rules and _is_ignored(name, ignore_rules): + return False + if not name.endswith(".weight"): + return False + if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + return False + if weight.dim() < 2: + return False + if weight.shape[-1] % group_size != 0: + raise ValueError(f"Last dim {weight.shape[-1]} must be divisible by {group_size} for NVFP4 ({name}).") + return True + + +def _split_gated_pair_name(name: str): + for suffix, role in GATED_PAIR_SUFFIXES.items(): + if name.endswith(suffix): + return name[: -len(suffix)], role + return None, None + + +def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: + """Quantize a tensor to FP4 E2M1 and pack two values per byte.""" + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def _quantize_nvfp4_1d( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, + group_size: int = NVFP4_GROUP_SIZE, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + NVFP4 1D quantization (tile shape = 1x16), adapted from + TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference. + + Returns: + qweight: uint8 packed fp4, shape (M, K // 2) + block_scale: float8_e4m3fn, shape (M, K // group_size) + global_scale: float32 scalar tensor + """ + weight = weight.contiguous() + m, n = weight.shape + if n % group_size != 0: + raise ValueError(f"NVFP4 requires K divisible by {group_size}, got {n}.") + + weight_f = weight.to(torch.float32) + if global_amax is None: + global_amax = torch.max(torch.abs(weight_f)) + else: + global_amax = global_amax.to(device=weight.device, dtype=torch.float32) + if global_amax.item() == 0.0: + qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device) + block_scale = torch.zeros( + (m, n // group_size), + dtype=torch.float8_e4m3fn, + device=weight.device, + ) + global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + return qweight, block_scale, global_scale + + fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32) + fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32) + + global_encode_scale = torch.div(fp8_max * fp4_max, global_amax) + # global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor(torch.finfo(torch.float32).max, device=weight.device, dtype=torch.float32), + ) + if global_encode_scale.item() == 0.0: + global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + weight_blocks = weight_f.view(m, n // group_size, group_size) + vec_max = torch.amax(torch.abs(weight_blocks), dim=-1, keepdim=True) + decode_scale = torch.div(vec_max, fp4_max) * global_encode_scale + decode_scale = torch.clamp(decode_scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn) + + encode_scale = torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale) + scaled = weight_blocks * encode_scale + clipped = torch.clamp(scaled, -fp4_max, fp4_max).reshape(m, n) + + qweight = cast_to_fp4x2(clipped) + block_scale = decode_scale.squeeze(-1) + return qweight, block_scale, global_decode_scale + + +def quantize_nvfp4( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, + group_size: int = NVFP4_GROUP_SIZE, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if weight.dim() == 2: + return _quantize_nvfp4_1d(weight, global_amax=global_amax, group_size=group_size) + if weight.dim() == 3: + if global_amax is not None: + raise ValueError("global_amax override is only supported for 2D weights.") + qweights = [] + block_scales = [] + global_scales = [] + for idx in range(weight.shape[0]): + qweight, block_scale, global_scale = _quantize_nvfp4_1d(weight[idx], group_size=group_size) + qweights.append(qweight) + block_scales.append(block_scale) + global_scales.append(global_scale) + return ( + torch.stack(qweights, dim=0), + torch.stack(block_scales, dim=0), + torch.stack(global_scales, dim=0), + ) + raise ValueError(f"Unsupported weight rank {weight.dim()} for NVFP4 quantization.") diff --git a/tools/convert_hf_to_nvfp4.py b/tools/convert_hf_to_nvfp4.py new file mode 100644 index 0000000000..2de2183a41 --- /dev/null +++ b/tools/convert_hf_to_nvfp4.py @@ -0,0 +1,526 @@ +""" +python tools/convert_hf_to_nvfp4.py [-h] [--model-dir MODEL_DIR] [--save-dir SAVE_DIR] + [--device DEVICE] [--keep-last-n KEEP_LAST_N] [--keep-first-n KEEP_FIRST_N] + +Convert a BF16/FP16/FP32 HF safetensors checkpoint to NVFP4 (E2M1) for MoE +expert GEMMs only. Dense linear layers are left unmodified. + +This follows the NVFP4 reference quantization in Transformer Engine and uses +1D block scaling (NVTE_NVFP4_1D_SCALING, group size = 16). +""" + +import argparse +import gc +import json +import os +import shutil + +import safetensors +import safetensors.torch +import torch +from tqdm import tqdm + +FP4_E2M1_MAX = 6.0 +FP8_E4M3_MAX = 448.0 +NVFP4_GROUP_SIZE = 16 +DEFAULT_KV_CACHE_SCHEME = {"dynamic": False, "num_bits": 8, "type": "float"} +DEFAULT_KV_CACHE_QUANT_ALGO = "FP8" + +EXPERT_WEIGHT_SUFFIXES = ( + ".w1.weight", + ".w2.weight", + ".w3.weight", + ".gate_proj.weight", + ".up_proj.weight", + ".down_proj.weight", + ".gate_up_proj.weight", +) + +EXPERT_NAME_MARKERS = ( + ".experts.", + ".shared_experts.", + "block_sparse_moe.experts.", + ".moe.experts.", +) + +FUSED_QKV_SUFFIXES = (".q_proj", ".k_proj", ".v_proj") +GATED_PAIR_SUFFIXES = { + ".gate_proj.weight": "gate", + ".up_proj.weight": "up", + ".w1.weight": "gate", + ".w3.weight": "up", +} + + +def _is_moe_expert_weight_name(name: str) -> bool: + if not name.endswith(".weight"): + return False + if not any(marker in name for marker in EXPERT_NAME_MARKERS): + return False + return any(name.endswith(suffix) for suffix in EXPERT_WEIGHT_SUFFIXES) + + +def _extract_layer_id(name: str) -> int | None: + parts = name.split(".") + for idx, part in enumerate(parts): + if part == "layers" and idx + 1 < len(parts): + layer_id = parts[idx + 1] + if layer_id.isdigit(): + return int(layer_id) + return None + + +def _get_num_hidden_layers(model_dir: str) -> int: + config_path = os.path.join(model_dir, "config.json") + if not os.path.exists(config_path): + raise ValueError("config.json is required to use --keep-first-n or --keep-last-n.") + cfg = json.load(open(config_path)) + num_layers = cfg.get("num_hidden_layers") + if num_layers is None and isinstance(cfg.get("text_config"), dict): + num_layers = cfg["text_config"].get("num_hidden_layers") + if num_layers is None: + raise ValueError("num_hidden_layers not found in config.json.") + return int(num_layers) + + +def _get_last_n_layer_ids(num_layers: int, keep_last_n: int) -> set[int]: + if keep_last_n <= 0: + return set() + start = max(0, num_layers - keep_last_n) + return set(range(start, num_layers)) + + +def _get_first_n_layer_ids(num_layers: int, keep_first_n: int) -> set[int]: + if keep_first_n <= 0: + return set() + end = min(num_layers, keep_first_n) + return set(range(0, end)) + + +def _build_keep_last_n_ignore_list(num_layers: int, keep_last_n: int) -> list[str]: + if keep_last_n <= 0: + return [] + start = max(0, num_layers - keep_last_n) + ignore_list = [] + for layer_id in range(start, num_layers): + prefix = f"model.layers.{layer_id}" + ignore_list.extend( + [ + f"{prefix}.self_attn.qkv_proj", + f"{prefix}.self_attn.o_proj", + f"{prefix}.mlp", + f"{prefix}.mlp.experts", + ] + ) + return ignore_list + + +def _build_keep_first_n_ignore_list(num_layers: int, keep_first_n: int) -> list[str]: + if keep_first_n <= 0: + return [] + end = min(num_layers, keep_first_n) + ignore_list = [] + for layer_id in range(0, end): + prefix = f"model.layers.{layer_id}" + ignore_list.extend( + [ + f"{prefix}.self_attn.qkv_proj", + f"{prefix}.self_attn.o_proj", + f"{prefix}.mlp", + f"{prefix}.mlp.experts", + ] + ) + return ignore_list + + +def should_quantize( + name: str, + weight: torch.Tensor, + skip_layers: set[int] | None = None, +) -> bool: + if skip_layers: + layer_id = _extract_layer_id(name) + if layer_id is not None and layer_id in skip_layers: + return False + if not _is_moe_expert_weight_name(name): + return False + if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + return False + if weight.dim() < 2: + return False + if weight.shape[-1] % NVFP4_GROUP_SIZE != 0: + raise ValueError( + f"Last dim {weight.shape[-1]} must be divisible by {NVFP4_GROUP_SIZE} " f"for NVFP4 quantization ({name})." + ) + return True + + +def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor: + """Quantize a tensor to FP4 E2M1 and pack two values per byte.""" + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def _quantize_nvfp4_1d( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + NVFP4 1D quantization (tile shape = 1x16), adapted from + TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference. + + Returns: + qweight: uint8 packed fp4, shape (M, K // 2) + block_scale: float8_e4m3fn, shape (M, K // 16) + global_scale: float32 scalar tensor + """ + weight = weight.contiguous() + m, n = weight.shape + if n % NVFP4_GROUP_SIZE != 0: + raise ValueError(f"NVFP4 requires K divisible by {NVFP4_GROUP_SIZE}, got {n}.") + + weight_f = weight.to(torch.float32) + if global_amax is None: + global_amax = torch.max(torch.abs(weight_f)) + else: + global_amax = global_amax.to(device=weight.device, dtype=torch.float32) + if global_amax.item() == 0.0: + qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device) + block_scale = torch.zeros( + (m, n // NVFP4_GROUP_SIZE), + dtype=torch.float8_e4m3fn, + device=weight.device, + ) + global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + return qweight, block_scale, global_scale + + fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32) + fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32) + + global_encode_scale = torch.div(fp8_max * fp4_max, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor(torch.finfo(torch.float32).max, device=weight.device, dtype=torch.float32), + ) + if global_encode_scale.item() == 0.0: + global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + weight_blocks = weight_f.view(m, n // NVFP4_GROUP_SIZE, NVFP4_GROUP_SIZE) + vec_max = torch.amax(torch.abs(weight_blocks), dim=-1, keepdim=True) + decode_scale = torch.div(vec_max, fp4_max) * global_encode_scale + decode_scale = torch.clamp(decode_scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn) + + encode_scale = torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale) + scaled = weight_blocks * encode_scale + clipped = torch.clamp(scaled, -fp4_max, fp4_max).reshape(m, n) + + qweight = cast_to_fp4x2(clipped) + block_scale = decode_scale.squeeze(-1) + return qweight, block_scale, global_decode_scale + + +def quantize_nvfp4( + weight: torch.Tensor, + global_amax: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if weight.dim() == 2: + return _quantize_nvfp4_1d(weight, global_amax=global_amax) + if weight.dim() == 3: + if global_amax is not None: + raise ValueError("global_amax override is only supported for 2D weights.") + qweights = [] + block_scales = [] + global_scales = [] + for idx in range(weight.shape[0]): + qweight, block_scale, global_scale = _quantize_nvfp4_1d(weight[idx]) + qweights.append(qweight) + block_scales.append(block_scale) + global_scales.append(global_scale) + return ( + torch.stack(qweights, dim=0), + torch.stack(block_scales, dim=0), + torch.stack(global_scales, dim=0), + ) + raise ValueError(f"Unsupported weight rank {weight.dim()} for NVFP4 quantization.") + + +class ConversionResult: + def __init__(self) -> None: + self.weight_map: dict[str, str] = {} + self.total_size: int = 0 + self.modules_to_not_convert: list[str] = [] + + def add_result(self, filename: str, q_weights: dict[str, torch.Tensor], module_names: list[str]) -> None: + for key, tensor in q_weights.items(): + self.weight_map[key] = filename + self.total_size += tensor.numel() * tensor.element_size() + self.modules_to_not_convert.extend(module_names) + + +def _update_quantization_config(cfg: dict, ignore_list: list[str]) -> None: + quant_cfg = cfg.get("quantization_config") + if not isinstance(quant_cfg, dict): + quant_cfg = {} + + quant_cfg["quant_algo"] = "NVFP4" + quant_cfg["quant_method"] = "modelopt" + quant_cfg["group_size"] = NVFP4_GROUP_SIZE + quant_cfg["ignore"] = ignore_list + quant_cfg.setdefault("kv_cache_scheme", DEFAULT_KV_CACHE_SCHEME) + + config_groups = quant_cfg.get("config_groups") + if isinstance(config_groups, dict): + for group in config_groups.values(): + if not isinstance(group, dict): + continue + group.setdefault("targets", ["Linear"]) + for key in ("input_activations", "weights"): + section = group.get(key) + if not isinstance(section, dict): + continue + section.setdefault("dynamic", False) + section.setdefault("num_bits", 4) + section.setdefault("type", "float") + section["group_size"] = NVFP4_GROUP_SIZE + + cfg["quantization_config"] = quant_cfg + + +def _write_hf_quant_config(output_path: str, ignore_list: list[str], input_path: str) -> None: + hf_quant_path = os.path.join(input_path, "hf_quant_config.json") + if os.path.exists(hf_quant_path): + with open(hf_quant_path) as f: + hf_quant_cfg = json.load(f) + else: + hf_quant_cfg = {"producer": {"name": "modelopt"}} + + quant_section = hf_quant_cfg.get("quantization") + if not isinstance(quant_section, dict): + quant_section = {} + + quant_section["quant_algo"] = "NVFP4" + quant_section["kv_cache_quant_algo"] = DEFAULT_KV_CACHE_QUANT_ALGO + quant_section["group_size"] = NVFP4_GROUP_SIZE + quant_section["exclude_modules"] = ignore_list + hf_quant_cfg["quantization"] = quant_section + + with open(os.path.join(output_path, "hf_quant_config.json"), "w") as f: + json.dump(hf_quant_cfg, f, indent=2) + + +def _augment_ignore_list(ignore_list: list[str]) -> list[str]: + ignore_set = set(ignore_list) + extra = set() + for name in ignore_list: + if name.endswith(FUSED_QKV_SUFFIXES): + for suffix in FUSED_QKV_SUFFIXES: + if name.endswith(suffix): + extra.add(name[: -len(suffix)] + ".qkv_proj") + break + ignore_set.update(extra) + return sorted(ignore_set) + + +def _split_gated_pair_name(name: str) -> tuple[str | None, str | None]: + for suffix, role in GATED_PAIR_SUFFIXES.items(): + if name.endswith(suffix): + return name[: -len(suffix)], role + return None, None + + +def _collect_shared_global_amax( + *, + input_path: str, + safetensors_files: list[str], + device: str, + skip_layers: set[int], +) -> dict[str, torch.Tensor]: + """Collect shared gate/up amax across all shards to keep w1/w3 scales equal.""" + gate_amax: dict[str, torch.Tensor] = {} + up_amax: dict[str, torch.Tensor] = {} + for filename in safetensors_files: + with safetensors.safe_open(os.path.join(input_path, filename), framework="pt", device=device) as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if not should_quantize(key, tensor, skip_layers): + continue + base, role = _split_gated_pair_name(key) + if base is None or role is None: + continue + amax = tensor.abs().max().to(torch.float32) + if role == "gate": + prev = gate_amax.get(base) + gate_amax[base] = amax if prev is None else torch.max(prev, amax) + elif role == "up": + prev = up_amax.get(base) + up_amax[base] = amax if prev is None else torch.max(prev, amax) + else: + continue + + shared_global_amax: dict[str, torch.Tensor] = {} + for base in gate_amax.keys() & up_amax.keys(): + shared_global_amax[base] = torch.max(gate_amax[base], up_amax[base]) + return shared_global_amax + + +def process_file( + input_path: str, + output_path: str, + filename: str, + result_collector: ConversionResult, + device: str, + skip_layers: set[int], + shared_global_amax: dict[str, torch.Tensor], +) -> None: + if not filename.endswith(".safetensors"): + return + + modules_to_not_convert: list[str] = [] + q_weights: dict[str, torch.Tensor] = {} + + with safetensors.safe_open(os.path.join(input_path, filename), framework="pt", device=device) as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if should_quantize(key, tensor, skip_layers): + base, _role = _split_gated_pair_name(key) + global_amax = shared_global_amax.get(base) if base else None + qweight, block_scale, weight_scale_2 = quantize_nvfp4(tensor, global_amax=global_amax) + q_weights[key] = qweight + q_weights[key.replace(".weight", ".weight_scale")] = block_scale + q_weights[key.replace(".weight", ".weight_scale_2")] = weight_scale_2 + q_weights[key.replace(".weight", ".input_scale")] = torch.ones_like( + weight_scale_2, dtype=torch.float32 + ) + else: + if key.endswith(".weight"): + modules_to_not_convert.append(key.replace(".weight", "")) + q_weights[key] = tensor + + safetensors.torch.save_file(q_weights, os.path.join(output_path, filename), metadata={"format": "pt"}) + result_collector.add_result(filename, q_weights, modules_to_not_convert) + + +def convert_nvfp4(model_dir: str, save_dir: str, device: str, keep_last_n: int, keep_first_n: int) -> None: + input_path = os.path.abspath(model_dir) + output_path = os.path.abspath(save_dir) + os.makedirs(output_path, exist_ok=True) + + for filename in os.listdir(input_path): + if not filename.endswith(".safetensors") and not os.path.isdir(os.path.join(input_path, filename)): + shutil.copyfile(os.path.join(input_path, filename), os.path.join(output_path, filename)) + + safetensors_files = [f for f in os.listdir(input_path) if f.endswith(".safetensors")] + + num_layers = _get_num_hidden_layers(input_path) if (keep_last_n > 0 or keep_first_n > 0) else 0 + skip_layers = _get_last_n_layer_ids(num_layers, keep_last_n) | _get_first_n_layer_ids(num_layers, keep_first_n) + keep_last_ignore = _build_keep_last_n_ignore_list(num_layers, keep_last_n) + keep_first_ignore = _build_keep_first_n_ignore_list(num_layers, keep_first_n) + + shared_global_amax = _collect_shared_global_amax( + input_path=input_path, + safetensors_files=safetensors_files, + device=device, + skip_layers=skip_layers, + ) + result_collector = ConversionResult() + for filename in tqdm(safetensors_files, desc="Processing files"): + process_file( + input_path, + output_path, + filename, + result_collector, + device, + skip_layers, + shared_global_amax, + ) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + ignore_list = _augment_ignore_list(result_collector.modules_to_not_convert + keep_last_ignore + keep_first_ignore) + + config_path = os.path.join(input_path, "config.json") + if os.path.exists(config_path): + cfg = json.load(open(config_path)) + _update_quantization_config(cfg, ignore_list) + json.dump(cfg, open(os.path.join(output_path, "config.json"), "w"), indent=2) + + _write_hf_quant_config(output_path, ignore_list, input_path) + + index_dict = { + "weight_map": result_collector.weight_map, + "metadata": {"total_size": result_collector.total_size}, + } + json.dump(index_dict, open(os.path.join(output_path, "model.safetensors.index.json"), "w"), indent=2) + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-dir", type=str, required=True, help="Path to HF safetensors model.") + parser.add_argument("--save-dir", type=str, required=True, help="Path to save converted model.") + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Torch device to run quantization on (default: cuda).", + ) + parser.add_argument( + "--keep-last-n", + type=int, + default=0, + help="Keep the last N transformer layers unquantized (BF16/FP16).", + ) + parser.add_argument( + "--keep-first-n", + type=int, + default=0, + help="Keep the first N transformer layers unquantized (BF16/FP16).", + ) + args = parser.parse_args() + + if isinstance(args.device, str) and args.device.isdigit(): + device = torch.device(f"cuda:{args.device}") + else: + device = torch.device(args.device) + + if device.type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot run NVFP4 quantization.") + if device.index is None: + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + if not os.path.exists(args.save_dir): + print(f"Creating directory {args.save_dir}") + os.makedirs(args.save_dir) + elif not os.path.isdir(args.save_dir): + raise ValueError("The save_dir should be a directory.") + + convert_nvfp4(args.model_dir, args.save_dir, str(device), args.keep_last_n, args.keep_first_n) + + +if __name__ == "__main__": + main() From 3572922ec9057c9e5b022a091c67fa227e85b698 Mon Sep 17 00:00:00 2001 From: Zhichen Zeng Date: Mon, 6 Apr 2026 13:55:26 -0700 Subject: [PATCH 12/44] Bump flash-linear-attention version to 0.4.2 (#892) --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 69e9ba1354..e6a2dd03b8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -63,7 +63,7 @@ RUN pip install /tmp/wheels/flash_attn_3-*.whl && \ RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -RUN pip install flash-linear-attention==0.4.1 +RUN pip install flash-linear-attention==0.4.2 RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ From 8146a786563c44f73a99505ebe7d3b371e65e6bd Mon Sep 17 00:00:00 2001 From: JensenFire Date: Tue, 7 Apr 2026 08:33:27 +0800 Subject: [PATCH 13/44] [BUGFIX] Invoke "post_process_quantization" by default after weight updating (#890) Co-authored-by: Yueming Yuan --- .../update_weight_from_distributed/mixin.py | 21 +++++++++---------- .../update_weight_from_distributed/p2p.py | 7 +------ .../update_weight_from_tensor.py | 18 +++++++--------- .../test_lora_weight_sync_validation.py | 5 ++++- 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py index f6d1c28f30..79217bbb48 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py @@ -150,19 +150,18 @@ def _pause_and_prepare_engines(self) -> None: post_process_quantization=False, ) - def _finalize_and_resume_engines(self) -> None: + def _finalize_and_resume_engines(self, post_load_weights: bool = False) -> None: """Run post-process if needed and resume rollout engines.""" if dist.get_rank() == 0: - # int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales). - if self.quantization_config and self.quantization_config["quant_method"] in [ - "compressed-tensors", - "mxfp8", - ]: - post_process_weights( - rollout_engines=self.rollout_engines, - restore_weights_before_load=False, - post_process_quantization=True, - ) + # post_process_quantization is related to the process_weights_after_loading + # in the sglang rollout side, which should always be invoked after weight + # updating. + post_process_weights( + rollout_engines=self.rollout_engines, + restore_weights_before_load=False, + post_process_quantization=True, + post_load_weights=post_load_weights, + ) ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) @torch.no_grad() diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py index 9702b31431..fe287f72c4 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py @@ -18,7 +18,6 @@ from miles.utils.distributed_utils import get_gloo_group -from ..common import post_process_weights from .mixin import DistBucketedWeightUpdateMixin from .p2p_transfer_utils import ( P2PTransferManager, @@ -125,11 +124,7 @@ def _finalize_and_resume_engines(self): for engine in self.rollout_engines ] ) - post_process_weights( - rollout_engines=self.rollout_engines, - post_load_weights=True, - ) - super()._finalize_and_resume_engines() + super()._finalize_and_resume_engines(post_load_weights=True) def _update_weight_implementation( self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 10a75fd4f3..48c85e958b 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -205,17 +205,15 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) - # int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales). if rank == 0: - if self.quantization_config and self.quantization_config["quant_method"] in [ - "compressed-tensors", - "mxfp8", - ]: - post_process_weights( - rollout_engines=self.rollout_engines, - restore_weights_before_load=False, - post_process_quantization=True, - ) + # `post_process_quantization` is related to the `process_weights_after_loading` + # in the sglang rollout side, which should always be invoked after weight + # updating. + post_process_weights( + rollout_engines=self.rollout_engines, + restore_weights_before_load=False, + post_process_quantization=True, + ) ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) diff --git a/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py b/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py index a72ca582e5..569039c84b 100644 --- a/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py +++ b/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py @@ -210,11 +210,14 @@ def test_raises_on_zero_lora_chunks(self, mock_iter_base, mock_dist, mock_ray, m with pytest.raises(RuntimeError, match="zero chunks"): updater.update_weights() + @patch("miles.backends.megatron_utils.update_weight.common.ray") @patch(f"{_UW_MODULE}.get_gloo_group", return_value=MagicMock()) @patch(f"{_UW_MODULE}.ray") @patch(f"{_UW_MODULE}.dist") @patch(f"{_UW_MODULE}.HfWeightIteratorBase") - def test_no_raise_for_base_model_zero_chunks(self, mock_iter_base, mock_dist, mock_ray, mock_gloo): + def test_no_raise_for_base_model_zero_chunks( + self, mock_iter_base, mock_dist, mock_ray, mock_gloo, mock_common_ray + ): """Base model weight sync with zero chunks is valid (e.g. empty model state).""" from miles.backends.megatron_utils.update_weight.update_weight_from_tensor import UpdateWeightFromTensor From eaa36a246456b989f92c2779214cb28793f039b1 Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Mon, 6 Apr 2026 20:02:27 -0700 Subject: [PATCH 14/44] Add heartbeat and id to session server (#866) --- .../swe-agent-v2/swe_agent_function.py | 15 +++++++-- miles/ray/rollout.py | 3 ++ .../rollout/generate_hub/agentic_tool_call.py | 8 +++++ .../generate_utils/openai_endpoint_utils.py | 18 +++++++++-- miles/rollout/session/sessions.py | 9 ++++++ miles/utils/test_utils/mock_tools.py | 5 ++- tests/fast/fixtures/generation_fixtures.py | 2 ++ .../rollout/generate_hub/test_multi_turn.py | 24 ++++++++++++++ .../test_openai_endpoint_utils.py | 32 ++++++++++++++++++- tests/fast/router/test_sessions.py | 16 ++++++++++ 10 files changed, 126 insertions(+), 6 deletions(-) diff --git a/examples/experimental/swe-agent-v2/swe_agent_function.py b/examples/experimental/swe-agent-v2/swe_agent_function.py index fb30d8a9f5..d1460fbe06 100644 --- a/examples/experimental/swe-agent-v2/swe_agent_function.py +++ b/examples/experimental/swe-agent-v2/swe_agent_function.py @@ -14,7 +14,7 @@ import logging import os from typing import Any -from urllib.parse import urlparse, urlunparse +from urllib.parse import urlparse, urlsplit, urlunparse from miles.utils.http_utils import post @@ -49,7 +49,7 @@ async def run( netloc = f"{external_host}:{port}" if port else external_host session_url = urlunparse(parsed._replace(netloc=netloc)) - request = { + request: dict[str, Any] = { **metadata, "base_url": session_url, "model": f"openai/{model_name}", @@ -60,6 +60,17 @@ async def run( if max_seq_len is not None: request["max_seq_len"] = int(max_seq_len) + session_server_id = metadata.get("session_server_id") + if session_server_id is not None: + if external_host: + port = urlsplit(f"http://{session_server_id}").port + session_server_id = f"{external_host}:{port}" + request["session_server_id"] = session_server_id + + session_server_instance_id = metadata.get("session_server_instance_id") + if session_server_instance_id is not None: + request["session_server_instance_id"] = session_server_instance_id + try: response = await asyncio.wait_for( post(f"{agent_server_url}/run", request), diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 2a75d492b9..d23e59b2d1 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -5,6 +5,7 @@ import os import random import time +import uuid from pathlib import Path from typing import Any @@ -1114,6 +1115,8 @@ def _start_session_server(args): args.session_server_ip = args.sglang_router_ip if getattr(args, "session_server_port", None) is None: args.session_server_port = find_available_port(random.randint(5000, 6000)) + if getattr(args, "session_server_instance_id", None) is None: + args.session_server_instance_id = uuid.uuid4().hex ip, port = args.session_server_ip, args.session_server_port if not is_port_available(port): diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py index cfe8a232d2..feba568164 100644 --- a/miles/rollout/generate_hub/agentic_tool_call.py +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -61,8 +61,16 @@ async def generate(input: GenerateFnInput) -> GenerateFnOutput: metadata = input.sample.metadata if max_seq_len is not None: metadata = {**metadata, "max_seq_len": max_seq_len} + if tracer.session_server_instance_id: + metadata = {**metadata, "session_server_instance_id": tracer.session_server_instance_id} log_prefix = f"[session={tracer.session_id}]" + + session_ip = getattr(input.args, "session_server_ip", None) + session_port = getattr(input.args, "session_server_port", None) + if session_ip and session_port: + metadata = {**metadata, "session_server_id": f"{session_ip}:{session_port}"} + agent_metadata = None t_start = time.monotonic() try: diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 5b9a445adf..7602e51468 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -18,10 +18,11 @@ class OpenAIEndpointTracer: - def __init__(self, router_url: str, session_id: str): + def __init__(self, router_url: str, session_id: str, session_server_instance_id: str | None = None): self.router_url = router_url self.session_id = session_id self.base_url = f"{router_url}/sessions/{session_id}" + self.session_server_instance_id = session_server_instance_id @staticmethod async def create(args: Namespace): @@ -33,9 +34,22 @@ async def create(args: Namespace): "Pass --use-session-server to start the session server." ) session_url = f"http://{session_ip}:{session_port}" + session_server_instance_id = None + try: + health = await post(f"{session_url}/health", {}, action="get") + if isinstance(health, dict): + session_server_instance_id = health.get("session_server_instance_id") + if session_server_instance_id is not None: + args.session_server_instance_id = session_server_instance_id + except Exception as e: + logger.warning("Failed to get session server health from %s: %s", session_url, e) response = await post(f"{session_url}/sessions", {}, action="post") session_id = response["session_id"] - return OpenAIEndpointTracer(router_url=session_url, session_id=session_id) + return OpenAIEndpointTracer( + router_url=session_url, + session_id=session_id, + session_server_instance_id=session_server_instance_id, + ) async def collect_records(self) -> tuple[list[SessionRecord], dict]: try: diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index bf53f446f4..172e906074 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -26,6 +26,8 @@ def setup_session_routes(app, backend, args): logger.info("[session] Skipping session routes (hf_checkpoint not set).") return + session_server_instance_id = getattr(args, "session_server_instance_id", None) + tokenizer = load_tokenizer( hf_checkpoint, chat_template_path=getattr(args, "chat_template_path", None), trust_remote_code=True ) @@ -38,6 +40,13 @@ def setup_session_routes(app, backend, args): registry = SessionRegistry(args, tokenizer, tito_tokenizer=tito_tokenizer) + @app.get("/health") + async def health(): + body = {"status": "ok"} + if session_server_instance_id is not None: + body["session_server_instance_id"] = session_server_instance_id + return body + # --- DEBUG: track in-flight chat_completions --- _inflight_chat = {"count": 0} diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py index 4da7b68a0c..1038805881 100644 --- a/miles/utils/test_utils/mock_tools.py +++ b/miles/utils/test_utils/mock_tools.py @@ -1,4 +1,5 @@ import json +from collections.abc import Callable from copy import deepcopy from typing import Any @@ -59,7 +60,7 @@ async def execute_tool_call(name: str, params: dict) -> str: return TOOL_EXECUTORS[name](params) -AGENTIC_RETURN_METADATA: dict[str, Any] | None = None +AGENTIC_RETURN_METADATA: dict[str, Any] | Callable | None = None async def run_agentic_tool_call( @@ -112,6 +113,8 @@ async def run_agentic_tool_call( } ) + if callable(AGENTIC_RETURN_METADATA): + return AGENTIC_RETURN_METADATA(metadata=kwargs.get("metadata")) return AGENTIC_RETURN_METADATA diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py index 91e0467e97..a56ec33df1 100644 --- a/tests/fast/fixtures/generation_fixtures.py +++ b/tests/fast/fixtures/generation_fixtures.py @@ -2,6 +2,7 @@ Fixtures to test custom-generate-function """ +import uuid from argparse import Namespace from contextlib import contextmanager from dataclasses import dataclass @@ -235,6 +236,7 @@ def with_session_server( chat_template_path=chat_template_path, tito_model="default", use_rollout_routing_replay=use_rollout_routing_replay, + session_server_instance_id=uuid.uuid4().hex, ) session_server = SessionServer(args, backend_url=backend_url) diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py index 95fe9c3f8d..e84767d083 100644 --- a/tests/fast/rollout/generate_hub/test_multi_turn.py +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -1,3 +1,4 @@ +import re from copy import deepcopy from dataclasses import dataclass, replace from itertools import groupby @@ -667,6 +668,29 @@ def test_agent_returns_none_metadata_unchanged(self, variant, generation_env): assert s.metadata.get("instance_id") == "test-123" assert "reward" not in s.metadata + def test_session_server_identity_forwarded_to_agent_metadata(self, variant, generation_env): + from miles.utils.test_utils import mock_tools + + generation_env.mock_server.process_fn = TwoTurnStub.process_fn + + _SESSION_KEYS = ("session_server_id", "session_server_instance_id") + + def _echo_session(metadata=None): + metadata = metadata or {} + return {k: metadata[k] for k in _SESSION_KEYS if k in metadata} + + mock_tools.AGENTIC_RETURN_METADATA = _echo_session + try: + result = _run_generate(variant, generation_env, make_sample(prompt=TwoTurnStub.PROMPT)) + finally: + mock_tools.AGENTIC_RETURN_METADATA = None + + samples = listify(result.sample) + expected_session_server_id = f"127.0.0.1:{generation_env.args.session_server_port}" + for s in samples: + assert s.metadata["session_server_id"] == expected_session_server_id + assert re.fullmatch(r"[0-9a-f]{32}", s.metadata["session_server_instance_id"]) + class TestAgentNoRecords: """When agent makes no model calls, generate should return an ABORTED sample.""" diff --git a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py index 2128791b3d..e8cb2eb340 100644 --- a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py +++ b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py @@ -9,7 +9,10 @@ import pytest -from miles.rollout.generate_utils.openai_endpoint_utils import compute_samples_from_openai_records +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) from miles.rollout.generate_utils.sample_utils import merge_samples from miles.rollout.session.session_types import SessionRecord from miles.utils.types import Sample @@ -82,6 +85,33 @@ def _make_record( ) +@pytest.mark.asyncio +async def test_create_fetches_session_server_instance_id(monkeypatch): + calls: list[tuple[str, str]] = [] + + async def fake_post(url: str, payload: dict, action: str = "post"): + calls.append((action, url)) + if action == "get": + assert url == "http://127.0.0.1:12345/health" + return {"status": "ok", "session_server_instance_id": "server-instance-123"} + assert action == "post" + assert url == "http://127.0.0.1:12345/sessions" + return {"session_id": "session-123"} + + monkeypatch.setattr("miles.rollout.generate_utils.openai_endpoint_utils.post", fake_post) + + args = SimpleNamespace(session_server_ip="127.0.0.1", session_server_port=12345) + tracer = await OpenAIEndpointTracer.create(args) + + assert tracer.base_url == "http://127.0.0.1:12345/sessions/session-123" + assert tracer.session_server_instance_id == "server-instance-123" + assert args.session_server_instance_id == "server-instance-123" + assert calls == [ + ("get", "http://127.0.0.1:12345/health"), + ("post", "http://127.0.0.1:12345/sessions"), + ] + + # ── test: compute_samples_from_openai_records ──────────────────────── diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py index 23fc683647..8dd58189e1 100644 --- a/tests/fast/router/test_sessions.py +++ b/tests/fast/router/test_sessions.py @@ -1,5 +1,7 @@ """Integration tests for session HTTP routes (create / get / delete / proxy).""" +import re +import uuid from types import SimpleNamespace from unittest.mock import patch @@ -41,6 +43,7 @@ def patched_chat_response(self, payload: dict) -> dict: hf_checkpoint="Qwen/Qwen3-0.6B", chat_template_path=None, trajectory_manager="linear_trajectory", + session_server_instance_id=uuid.uuid4().hex, ) server_obj = SessionServer(args, backend_url=backend.url) @@ -57,6 +60,19 @@ def patched_chat_response(self, payload: dict) -> dict: class TestSessionRoutes: + def test_health_reports_stable_instance_id(self, router_env): + first = requests.get(f"{router_env.url}/health", timeout=5.0) + second = requests.get(f"{router_env.url}/health", timeout=5.0) + + assert first.status_code == 200 + assert second.status_code == 200 + first_body = first.json() + second_body = second.json() + assert first_body["status"] == "ok" + assert second_body["status"] == "ok" + assert re.fullmatch(r"[0-9a-f]{32}", first_body["session_server_instance_id"]) + assert second_body["session_server_instance_id"] == first_body["session_server_instance_id"] + def test_create_session(self, router_env): response = requests.post(f"{router_env.url}/sessions", timeout=5.0) assert response.status_code == 200 From 70dc402ec1cb0317ab9c1c7cf5f2c4a608c289db Mon Sep 17 00:00:00 2001 From: Douglas Yang Date: Mon, 6 Apr 2026 20:19:26 -0700 Subject: [PATCH 15/44] fix: adding thin glm5 image to docker build + latest tag sync (#871) --- .github/workflows/docker-build.yml | 35 ++++++++++++++++++++++++++++++ docker/glm5/Dockerfile.dev-glm | 2 ++ 2 files changed, 37 insertions(+) create mode 100644 docker/glm5/Dockerfile.dev-glm diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 6acec1451f..ab0caa21c2 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -146,6 +146,11 @@ jobs: ${{ inputs.custom_tag && format('--custom-tag {0}', inputs.custom_tag) || '' }} \ --push + - name: Point latest to current dev + if: github.event_name == 'schedule' || inputs.simulate_schedule == true + run: | + docker buildx imagetools create -t radixark/miles:latest radixark/miles:dev + - name: Prune old dev tags if: github.event_name == 'schedule' run: | @@ -193,3 +198,33 @@ jobs: echo " Failed to delete ${TAG} (HTTP ${HTTP_CODE})" fi done + + build-and-push-dev-glm: + needs: [build-and-push] + # Only rebuild dev-glm when the dev image was built (schedule, push to main, or dispatch with image_tag=dev) + if: needs.build-and-push.result == 'success' && (github.event_name == 'schedule' || inputs.simulate_schedule == true) + runs-on: self-hosted + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver-opts: | + image=moby/buildkit:latest + network=host + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push dev-glm + run: | + docker buildx build \ + -f docker/glm5/Dockerfile.dev-glm \ + -t radixark/miles:dev-glm \ + --push \ + . diff --git a/docker/glm5/Dockerfile.dev-glm b/docker/glm5/Dockerfile.dev-glm new file mode 100644 index 0000000000..4ddbfbfc4b --- /dev/null +++ b/docker/glm5/Dockerfile.dev-glm @@ -0,0 +1,2 @@ +FROM radixark/miles:dev +RUN pip install git+https://github.com/huggingface/transformers.git@76732b4e7120808ff989edbd16401f61fa6a0afa From c198efa6035e7451e6f6735ba17d334bcdb6fe28 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Mon, 6 Apr 2026 22:04:04 -0700 Subject: [PATCH 16/44] Add consistent hashing routing policy for rollout (#891) Co-authored-by: Yueming Yuan --- miles/backends/sglang_utils/arguments.py | 13 +++++++++++++ miles/ray/rollout.py | 3 +++ miles/rollout/generate_utils/sample_utils.py | 1 + miles/rollout/sglang_rollout.py | 15 +++++++++++++-- miles/utils/http_utils.py | 16 ++++++++-------- miles/utils/types.py | 4 ++++ 6 files changed, 42 insertions(+), 10 deletions(-) diff --git a/miles/backends/sglang_utils/arguments.py b/miles/backends/sglang_utils/arguments.py index d8ac2deacc..f4b9978a40 100644 --- a/miles/backends/sglang_utils/arguments.py +++ b/miles/backends/sglang_utils/arguments.py @@ -19,6 +19,12 @@ def add_sglang_router_arguments(parser): default=None, help="Port of the SGLang router", ) + parser.add_argument( + "--sglang-router-policy", + type=str, + default=None, + help="Routing policy for the SGLang router (e.g., 'consistent_hashing', 'round_robin')", + ) parser.add_argument( "--sglang-router-request-timeout-secs", type=int, @@ -135,5 +141,12 @@ def validate_args(args): if args.sglang_dp_size > 1: assert args.sglang_enable_dp_attention + if args.sglang_router_policy: + from miles.utils.environ import enable_experimental_rollout_refactor + + assert ( + not enable_experimental_rollout_refactor() + ), "--sglang-router-policy is not supported with MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1" + if getattr(args, "sglang_router_ip", None): args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip) diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index d23e59b2d1..d625ae4c5e 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -942,6 +942,9 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool router_args.log_level = "warn" router_args.request_timeout_secs = args.sglang_router_request_timeout_secs + if args.sglang_router_policy: + router_args.policy = args.sglang_router_policy + if has_pd_disaggregation: router_args.pd_disaggregation = True diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py index 8e8f42441e..55c1cece1f 100644 --- a/miles/rollout/generate_utils/sample_utils.py +++ b/miles/rollout/generate_utils/sample_utils.py @@ -66,6 +66,7 @@ def _fill_defaults(sample: Sample): metadata=_merge_equal_value("metadata"), generate_function_path=_merge_equal_value("generate_function_path"), train_metadata=_merge_equal_value("train_metadata"), + session_id=_merge_equal_value("session_id"), non_generation_time=_merge_equal_value("non_generation_time"), spec_info=_merge_spec_info(a.spec_info, b.spec_info), prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 6ffc67e70b..2d3fdadb2d 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -2,7 +2,7 @@ import copy import inspect import logging - +import uuid from argparse import Namespace from collections.abc import Callable from contextlib import contextmanager @@ -184,7 +184,12 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A if not sample.tokens: # Initialize sample.tokens for the first turn sample.tokens = prompt_ids - output = await post(url, payload) + # Use session_id for consistent hashing routing if router uses consistent_hashing policy + headers = None + if args.sglang_router_policy == "consistent_hashing" and sample.session_id: + headers = {"X-SMG-Routing-Key": sample.session_id} + + output = await post(url, payload, headers=headers) if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree @@ -299,6 +304,12 @@ async def generate_and_rm_group( if state.aborted: return group + # Generate a unique session_id for each sample in the group (consistent hashing only) + if args.sglang_router_policy == "consistent_hashing": + for sample in group: + if sample.session_id is None: + sample.session_id = str(uuid.uuid4()) + tasks = [] for idx, sample in enumerate(group): current_sampling_params = sampling_params.copy() diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 0548ec6b1d..0aaf792659 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -185,15 +185,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60, action="post"): +async def _post(client, url, payload, max_retries=60, action="post", headers=None): retry_count = 0 while retry_count < max_retries: try: if action in ("delete", "get"): assert not payload - response = await getattr(client, action)(url) + response = await getattr(client, action)(url, headers=headers) else: - response = await getattr(client, action)(url, json=payload or {}) + response = await getattr(client, action)(url, json=payload or {}, headers=headers) response.raise_for_status() try: output = response.json() @@ -267,8 +267,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60, action="post"): - return await _post(self._client, url, payload, max_retries, action=action) + async def do_post(self, url, payload, max_retries=60, action="post", headers=None): + return await _post(self._client, url, payload, max_retries, action=action, headers=headers) # Create actors per node created = [] @@ -293,18 +293,18 @@ async def do_post(self, url, payload, max_retries=60, action="post"): # TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) -async def post(url, payload, max_retries=60, action="post"): +async def post(url, payload, max_retries=60, action="post", headers=None): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: actor = _next_actor() if actor is not None: - return await actor.do_post.remote(url, payload, max_retries, action=action) + return await actor.do_post.remote(url, payload, max_retries, action=action, headers=headers) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, action=action) + return await _post(_http_client, url, payload, max_retries, action=action, headers=headers) # TODO unify w/ `post` to add retries and remote-execution diff --git a/miles/utils/types.py b/miles/utils/types.py index b36f08aecc..86c30a906c 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -47,6 +47,10 @@ class Status(Enum): # metadata used during training, e.g., what loss to use for this sample. train_metadata: dict | None = None + # Session ID for consistent hashing routing (used when router policy is consistent_hashing) + # TODO: Its definition needs to merge with the session server's session id in the new rollout function. + session_id: str | None = None + non_generation_time: float = 0.0 # time spent in non-generation steps @dataclass From afc5b55cefe9153517c0015f7111c5ff2564e64b Mon Sep 17 00:00:00 2001 From: Huapeng Zhou <73010314+PopSoda2002@users.noreply.github.com> Date: Tue, 7 Apr 2026 06:53:59 -0700 Subject: [PATCH 17/44] [example] add retool v2 example with multi-turn framework interfaces (#654) Co-authored-by: GuanxingLu Co-authored-by: Claude Opus 4.6 (1M context) --- examples/retool/generate_with_retool.py | 42 ++- examples/retool/retool_qwen3_4b_rl.sh | 11 +- examples/retool_v2/README.md | 31 ++ examples/retool_v2/run_retool_multi_turn.py | 208 +++++++++++ examples/retool_v2/tool_sandbox.py | 385 ++++++++++++++++++++ 5 files changed, 654 insertions(+), 23 deletions(-) create mode 100644 examples/retool_v2/README.md create mode 100644 examples/retool_v2/run_retool_multi_turn.py create mode 100644 examples/retool_v2/tool_sandbox.py diff --git a/examples/retool/generate_with_retool.py b/examples/retool/generate_with_retool.py index f5b8ad268c..6bd5d7de29 100644 --- a/examples/retool/generate_with_retool.py +++ b/examples/retool/generate_with_retool.py @@ -96,12 +96,11 @@ def format_conversation_with_tools( def postprocess_predictions(prediction: str): """Extract action and content from prediction string""" - # Check for Answer: \boxed{...} format (only format we need for math_dapo) - # Use a more robust regex that handles nested braces - answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}" - answer_match = re.search(answer_pattern, prediction, re.DOTALL) - if answer_match: - content = answer_match.group(1).strip() + # Check for bare \boxed{...} (model may omit "Answer:" prefix) + boxed_pattern = r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}" + boxed_match = re.search(boxed_pattern, prediction, re.DOTALL) + if boxed_match: + content = boxed_match.group(1).strip() return "answer", content # Then check for tags (new format from Jinja2 template) @@ -168,14 +167,17 @@ def postprocess_responses(resp: str) -> str: last_match = matches[-1] return resp[: last_match.end()] - # Handle Answer: \boxed{...} format (only format we need for math_dapo) - if "Answer:" in resp and "\\boxed{" in resp: - # Find the last occurrence of Answer: \boxed{...} with nested braces support - answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}" - matches = list(re.finditer(answer_pattern, resp, re.DOTALL)) - if matches: - last_match = matches[-1] - return resp[: last_match.end()] + # Handle Answer: \boxed{...} or bare \boxed{...} + if "\\boxed{" in resp: + # Try "Answer: \boxed{...}" first, then bare "\boxed{...}" + for pattern in [ + r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", + r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", + ]: + matches = list(re.finditer(pattern, resp, re.DOTALL)) + if matches: + last_match = matches[-1] + return resp[: last_match.end()] return resp @@ -203,7 +205,7 @@ async def execute_predictions(prediction: str) -> str: next_obs = ( "\nMy previous action is invalid. " "If I want to execute code, I should put the code between " - " and . " + " and . " "If I want to give the final answer, I should use the format " "'Answer: \\boxed{answer}'. Let me try again.\n" ) @@ -221,7 +223,12 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: # Set up the initial prompt with system prompt and tools (outside the loop) tool_specs = tool_registry.get_tool_specs() - prompt = format_conversation_with_tools(prompt=sample.prompt, tools=tool_specs) + + if isinstance(sample.prompt, str): + # Already formatted (e.g., by --apply-chat-template), use as-is to avoid double templating + prompt = sample.prompt + else: + prompt = format_conversation_with_tools(prompt=sample.prompt, tools=tool_specs) prompt_tokens_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"] response = "" @@ -355,8 +362,7 @@ async def reward_func(args, sample, **kwargs): if not isinstance(sample, Sample): raise TypeError("Sample must be an instance of Sample class.") - # Build complete solution string - solution_str = sample.prompt + sample.response + solution_str = sample.response # Get ground truth answer - label is a string, not a dict ground_truth = sample.label if sample.label is not None else "" diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index 838ce0e2c4..99eeea3d3b 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -31,7 +31,7 @@ CKPT_ARGS=( --ref-load /root/font-info/qwen3-4b-sft_torch_dist # --load /root/Qwen3-4B_miles/ --save /root/font-info/qwen3-4b-sft/qwen3-4b-sft-multi-turn/ - --save-interval 20 + --save-interval 200 --rotary-base 5000000 ) @@ -43,12 +43,12 @@ ROLLOUT_ARGS=( --rollout-shuffle --reward-key score --num-rollout 3000 - --rollout-batch-size 32 + --rollout-batch-size 16 --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 1 - --global-batch-size 256 + --global-batch-size 128 --balance-data ) @@ -98,8 +98,8 @@ OPTIMIZER_ARGS=( WANDB_ARGS=( --use-wandb - --wandb-project miles-dapo - --wandb-group qwen3-4B-test-multi-turn + --wandb-project miles-dev-retool-v2 + --wandb-group retool-v1-qwen3-4b-sft-new --wandb-key ${WANDB_KEY} ) @@ -117,6 +117,7 @@ MISC_ARGS=( --attention-softmax-in-fp32 # need to comment this when using model with MLA --attention-backend flash + --log-passrate ) CUSTOM_ARGS=( diff --git a/examples/retool_v2/README.md b/examples/retool_v2/README.md new file mode 100644 index 0000000000..1c9752c2b9 --- /dev/null +++ b/examples/retool_v2/README.md @@ -0,0 +1,31 @@ +# Retool v2 + +This example is an upgraded version of [retool](../retool), using the updated interfaces provided by the miles framework to implement multi-turn RL training with tool calls in a cleaner way. + +## Key Differences from v1 + +**v1 (retool)** requires manually implementing the full multi-turn conversation loop in `generate_with_retool.py`, directly depending on low-level `GenerateState` and `sglang_rollout` interfaces — resulting in verbose code tightly coupled to the framework internals. + +**v2 (retool_v2)** uses the framework's standard plugin interfaces. Users only need to implement three functions and mount them via command-line arguments: + +| Argument | Description | +|----------|-------------| +| `--custom-generate-function-path` | Uses the built-in `miles.rollout.generate_hub.multi_turn.generate` — no need to implement the multi-turn loop yourself | +| `--generate-tool-specs-path` | Declare tool definitions (user-implemented) | +| `--generate-execute-tool-function-path` | Implement tool execution logic (user-implemented) | +| `--custom-rm-path` | Implement the reward function (user-implemented) | + +Users only need to focus on business logic (tool definitions, tool execution, reward calculation). Multi-turn scheduling, token concatenation, loss masking, etc. are all handled by the framework. + +## Files + +- `tool_sandbox.py`: Tool definitions (`tool_specs`), tool execution (`execute_tool`), reward function (`reward_func`), and sandboxed safe execution environment +- `run_retool_multi_turn.py`: Training launch script + +## Quick Start + +```bash +python examples/retool_v2/run_retool_multi_turn.py +``` + +For data and model preparation, refer to the [retool v1 README](../retool/README.md). diff --git a/examples/retool_v2/run_retool_multi_turn.py b/examples/retool_v2/run_retool_multi_turn.py new file mode 100644 index 0000000000..1031eba790 --- /dev/null +++ b/examples/retool_v2/run_retool_multi_turn.py @@ -0,0 +1,208 @@ +import os +from dataclasses import dataclass, field +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +WANDB_PROJECT = "miles-dev-retool-v2" +WANDB_GROUP = "sft-multi-turn-batch-32" + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = field(default_factory=U.create_run_id) + hardware: Literal["H100", "GB200", "GB300"] = "H100" + num_gpus_per_node: int | None = None + use_sft_model: bool = True + save_path: str = "/root/Qwen3-4B_miles/retool_v2_multi_turn" + prompt_data: str = "/root/dapo-math-17k/dapo-math-17k.jsonl" + generate_max_turns: int = 16 + rollout_num_gpus_per_engine: int = 2 + extra_args: str = "" + + # resolved in __post_init__, not set by user + hf_checkpoint: str = field(init=False) + ref_load: str = field(init=False) + + def __post_init__(self): + self.num_gpus_per_node = self.num_gpus_per_node or U.NUM_GPUS_OF_HARDWARE[self.hardware] + if self.use_sft_model: + self.hf_checkpoint = "/root/font-info/qwen3-4b-sft" + self.ref_load = "/root/font-info/qwen3-4b-sft_torch_dist" + else: + self.hf_checkpoint = "/root/models/Qwen3-4B" + self.ref_load = "/root/models/Qwen3-4B_torch_dist" + + +def _get_wandb_args() -> str: + WANDB_API_KEY = os.environ.get("WANDB_API_KEY") + return ( + "--use-wandb " + f"--wandb-project {WANDB_PROJECT} " + f"--wandb-group {WANDB_GROUP} " + f"--wandb-key {WANDB_API_KEY} " + ) + + +def prepare(args: ScriptArgs): + U.exec_command("mkdir -p /root/dapo-math-17k /root/aime-2024") + U.exec_command("hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/dapo-math-17k") + U.exec_command("hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/aime-2024") + + if args.use_sft_model: + U.exec_command("mkdir -p /root/font-info") + U.exec_command(f"hf download font-info/qwen3-4b-sft-SGLang-RL --local-dir {args.hf_checkpoint}") + U.convert_checkpoint( + model_name="qwen3-4b-sft", + megatron_model_type="qwen3-4B", + num_gpus_per_node=args.num_gpus_per_node, + hf_checkpoint=args.hf_checkpoint, + dir_dst="/root/font-info", + ) + else: + U.exec_command("mkdir -p /root/models") + U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") + U.convert_checkpoint( + model_name="Qwen3-4B", + megatron_model_type="qwen3-4B", + num_gpus_per_node=args.num_gpus_per_node, + dir_dst="/root/models", + ) + + +def execute(args: ScriptArgs): + megatron_model_type = "qwen3-4B" + + ckpt_args = ( + f"--hf-checkpoint {args.hf_checkpoint} " + f"--ref-load {args.ref_load} " + f"--save {args.save_path} " + f"--save-interval {2 if args.mode == 'debug_minimal' else 1000} " + f"{'--rotary-base 5000000 ' if args.use_sft_model else ''}" + ) + + custom_args = ( + "--custom-generate-function-path miles.rollout.generate_hub.multi_turn.generate " + "--generate-tool-specs-path examples.retool_v2.tool_sandbox.tool_specs " + "--generate-execute-tool-function-path examples.retool_v2.tool_sandbox.execute_tool " + "--generate-tool-call-parser qwen25 " + f"--generate-max-turns {args.generate_max_turns} " + "--log-multi-turn " + ) + + rollout_args = ( + f"--prompt-data {args.prompt_data} " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--custom-rm-path examples.retool_v2.tool_sandbox.reward_func " + "--reward-key score " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + "--global-batch-size 256 " + "--balance-data " + ) + + eval_args = "" + if args.mode != "debug_minimal": + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data aime /root/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 16 " + "--eval-max-response-len 16384 " + "--eval-top-p 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + f"--rollout-num-gpus-per-engine {args.rollout_num_gpus_per_engine} " "--sglang-mem-fraction-static 0.7 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + misc_args = ( + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + "--colocate " + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--log-passrate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{_get_wandb_args()} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{misc_args} " + f"{custom_args} " + f"{args.extra_args} " + ) + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=megatron_model_type, + extra_env_vars={ + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "PYTHONPATH": "/root/Megatron-LM/:/root/miles", + }, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/retool_v2/tool_sandbox.py b/examples/retool_v2/tool_sandbox.py new file mode 100644 index 0000000000..fc7a1dea45 --- /dev/null +++ b/examples/retool_v2/tool_sandbox.py @@ -0,0 +1,385 @@ +""" +copied from examples/retool/tool_sandbox.py +""" + +import asyncio +import gc +import os +import re +import subprocess +import tempfile +from contextlib import contextmanager +from typing import Any +import psutil + +from miles.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score +from miles.utils.types import Sample + +# Configuration for tool execution +TOOL_CONFIGS = { + "max_turns": 16, + "max_tool_calls": 16, + "tool_concurrency": 32, # Aggressive: 32 concurrent processes + # Python interpreter settings + "python_timeout": 120, # 2 minutes for complex calculations + "python_memory_limit": "4GB", # 4GB per Python process + "python_cpu_limit": 1, + # Memory management settings + "max_memory_usage": 12288, # 12GB total (75% of 16GB) + "cleanup_threshold": 6144, # 6GB + "aggressive_cleanup_threshold": 3072, # 3GB + "force_cleanup_threshold": 9216, # 9GB +} + +# Global semaphore for controlling concurrent tool executions +SEMAPHORE = asyncio.Semaphore(TOOL_CONFIGS["tool_concurrency"]) + + +def get_memory_usage() -> float: + """Get current memory usage in MB""" + process = psutil.Process() + return process.memory_info().rss / 1024 / 1024 + + +def cleanup_memory(): + """Force garbage collection to free memory""" + gc.collect() + + +def aggressive_cleanup_memory(): + """More aggressive memory cleanup""" + # Force multiple garbage collection cycles + for _ in range(3): + gc.collect() + + # Clear Python's internal caches + import sys + + # Note: sys.intern doesn't have a clear method, so we skip this + # Clear module cache if possible + if hasattr(sys, "modules"): + # Don't clear all modules, but clear some common ones that might cache data + modules_to_clear = ["numpy", "pandas", "matplotlib", "scipy"] + for module_name in modules_to_clear: + if module_name in sys.modules: + module = sys.modules[module_name] + if hasattr(module, "clear_cache"): + module.clear_cache() + + +def check_and_cleanup_memory(): + """Check memory usage and perform appropriate cleanup""" + current_memory = get_memory_usage() + + if current_memory > TOOL_CONFIGS["force_cleanup_threshold"]: + # Force aggressive cleanup + aggressive_cleanup_memory() + return f"Warning: High memory usage ({current_memory:.1f}MB), performed aggressive cleanup" + elif current_memory > TOOL_CONFIGS["cleanup_threshold"]: + # Normal cleanup + cleanup_memory() + return f"Info: Memory usage ({current_memory:.1f}MB), performed cleanup" + elif current_memory > TOOL_CONFIGS["aggressive_cleanup_threshold"]: + # Light cleanup + gc.collect() + return f"Info: Memory usage ({current_memory:.1f}MB), performed light cleanup" + + return None + + +class PythonSandbox: + """Python code sandbox, provides safe code execution environment""" + + def __init__(self, timeout: int = 10, memory_limit: str = "100MB"): + self.timeout = timeout + self.memory_limit = memory_limit + self.allowed_modules = { + "math", + "random", + "datetime", + "collections", + "itertools", + "functools", + "operator", + "statistics", + "decimal", + "fractions", + } + + def _check_code_safety(self, code: str) -> tuple[bool, str]: + """Check code safety by scanning for dangerous patterns""" + # Check for dangerous operations + dangerous_patterns = [ + r"import\s+os", + r"import\s+sys", + r"import\s+subprocess", + r"import\s+shutil", + r"import\s+glob", + r"import\s+pathlib", + r"__import__", + r"eval\s*\(", + r"exec\s*\(", + r"open\s*\(", + r"file\s*\(", + r"input\s*\(", + r"raw_input\s*\(", + r"compile\s*\(", + r"execfile\s*\(", + r"getattr\s*\(", + r"setattr\s*\(", + r"delattr\s*\(", + r"hasattr\s*\(", + r"globals\s*\(", + r"locals\s*\(", + r"vars\s*\(", + r"dir\s*\(", + r"type\s*\(", + r"isinstance\s*\(", + r"issubclass\s*\(", + r"super\s*\(", + r"property\s*\(", + r"staticmethod\s*\(", + r"classmethod\s*\(", + r"__\w+__", # double underscore methods + ] + + for pattern in dangerous_patterns: + if re.search(pattern, code, re.IGNORECASE): + return False, f"Code contains dangerous pattern: {pattern}" + + # Check imported modules + import_pattern = r"import\s+(\w+)" + from_pattern = r"from\s+(\w+)" + + imports = re.findall(import_pattern, code) + froms = re.findall(from_pattern, code) + + all_imports = set(imports + froms) + for imp in all_imports: + if imp not in self.allowed_modules: + return False, f"Import of '{imp}' is not allowed" + + return True, "Code is safe" + + @contextmanager + def _create_safe_environment(self): + """Create safe execution environment with temporary directory""" + # Create temporary directory + temp_dir = tempfile.mkdtemp(prefix="python_sandbox_") + + try: + # Create safe Python script + script_path = os.path.join(temp_dir, "code.py") + + # Set environment variables + env = os.environ.copy() + env["PYTHONPATH"] = temp_dir + env["PYTHONUNBUFFERED"] = "1" + + yield script_path, env, temp_dir + + finally: + # Clean up temporary directory + try: + import shutil + + shutil.rmtree(temp_dir) + except Exception: + pass + + async def execute_code(self, code: str) -> str: + """Execute Python code in sandbox with safety checks""" + # Check memory usage before execution + current_memory = get_memory_usage() + if current_memory > TOOL_CONFIGS["max_memory_usage"]: + aggressive_cleanup_memory() + return "Error: Memory usage too high, please try again" + + # Check code safety + is_safe, message = self._check_code_safety(code) + if not is_safe: + return f"Error: {message}" + + # Add necessary wrapper code with memory limits + # Properly indent the user code within the try block + # Handle indentation properly by adding 4 spaces to each line + indented_code = "\n".join(" " + line for line in code.split("\n")) + + wrapped_code = f"""import sys +import traceback +from io import StringIO +import resource + +# Set memory limit (4GB) +try: + resource.setrlimit(resource.RLIMIT_AS, (4 * 1024 * 1024 * 1024, -1)) +except Exception: + pass + +# Redirect stdout and stderr +old_stdout = sys.stdout +old_stderr = sys.stderr +stdout_capture = StringIO() +stderr_capture = StringIO() +sys.stdout = stdout_capture +sys.stderr = stderr_capture + +try: + # User code +{indented_code} + + # Get output + stdout_output = stdout_capture.getvalue() + stderr_output = stderr_capture.getvalue() + + # Restore standard output + sys.stdout = old_stdout + sys.stderr = old_stderr + + # Return result + result = "" + if stdout_output: + result += f"Output:\\n{{stdout_output}}" + if stderr_output: + result += f"\\nErrors:\\n{{stderr_output}}" + + print(result) + +except Exception as e: + # Restore standard output + sys.stdout = old_stdout + sys.stderr = old_stderr + + # Return error information + error_msg = f"Error: {{str(e)}}\\nTraceback:\\n{{traceback.format_exc()}}" + print(error_msg)""" + + with self._create_safe_environment() as (script_path, env, temp_dir): + # Write code to file + with open(script_path, "w") as f: + f.write(wrapped_code) + + try: + # Use subprocess to run code + process = subprocess.Popen( + ["python3", script_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + cwd=temp_dir, + text=True, + ) + + # Set timeout + try: + stdout, stderr = process.communicate(timeout=self.timeout) + + if process.returncode == 0: + result = stdout.strip() + else: + result = f"Error: Process exited with code {process.returncode}\n{stderr}" + + except subprocess.TimeoutExpired: + process.kill() + result = f"Error: Code execution timed out after {self.timeout} seconds" + + except Exception as e: + result = f"Error: Failed to execute code: {str(e)}" + + # Check memory usage after execution and cleanup if needed + cleanup_message = check_and_cleanup_memory() + if cleanup_message: + print(f"Memory cleanup: {cleanup_message}") + + return result + + +class ToolRegistry: + """Tool registry, manages available tools and their execution""" + + def __init__(self): + self.tools = {} + self.python_sandbox = PythonSandbox( + timeout=TOOL_CONFIGS["python_timeout"], memory_limit=TOOL_CONFIGS["python_memory_limit"] + ) + self._register_default_tools() + + def _register_default_tools(self): + """Register default tools in the registry""" + # Python code interpreter + self.register_tool( + "code_interpreter", + { + "type": "function", + "function": { + "name": "code_interpreter", + "description": "A tool for executing Python code in a safe sandbox environment.", + "parameters": { + "type": "object", + "properties": {"code": {"type": "string", "description": "The Python code to execute"}}, + "required": ["code"], + }, + }, + }, + ) + + def register_tool(self, name: str, tool_spec: dict[str, Any]): + """Register a new tool in the registry""" + self.tools[name] = tool_spec + + def get_tool_specs(self) -> list[dict[str, Any]]: + """Get all tool specifications as a list""" + return list(self.tools.values()) + + async def execute_tool(self, tool_name: str, arguments: dict[str, Any]) -> str: + """Execute a tool call with the given arguments""" + if tool_name not in self.tools: + return f"Error: Tool '{tool_name}' not found" + + async with SEMAPHORE: + if tool_name == "code_interpreter": + return await self._execute_python(arguments) + else: + return f"Error: Tool '{tool_name}' not implemented" + + async def _execute_python(self, arguments: dict[str, Any]) -> str: + """Execute Python code using the sandbox""" + code = arguments.get("code", "") + if isinstance(code, list): + code = "\n".join(str(item) for item in code) + if not code.strip(): + return "Error: No code provided" + + # Execute code in sandbox + result = await self.python_sandbox.execute_code(code) + return result + + +tool_registry = ToolRegistry() + +tool_specs = tool_registry.get_tool_specs() + + +async def execute_tool(name: str, params: dict) -> str: + return await tool_registry.execute_tool(name, params) + + +# Reward function that encourages tool usage +async def reward_func(args, sample: Sample, **kwargs): + """Tool call reward function using math_dapo, with bonus for tool usage.""" + solution_str = sample.prompt + sample.response if isinstance(sample.prompt, str) else sample.response + ground_truth = sample.label if sample.label is not None else "" + tool_call_count = sample.metadata.get("tool_call_count", 0) + + # use \boxed{...} answer + result = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=True) + + # encourage model to call tools + if result["score"] < 0: + tool_call_reward = tool_call_count / 2 * 0.1 + result["score"] = min(-0.6, result["score"] + tool_call_reward) + + if result["pred"] is None: + result["pred"] = "" + + return result From 4db9bfe9b5d871c4b337dc2a7ee6a918e1e0c877 Mon Sep 17 00:00:00 2001 From: Shi-Dong Date: Tue, 7 Apr 2026 13:53:37 -0700 Subject: [PATCH 18/44] Expose rollout-batch-size, n-samples-per-prompt, global-batch-size as CLI args in swe-agent-v2 (#954) Co-authored-by: Shi Dong --- examples/experimental/swe-agent-v2/run.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/experimental/swe-agent-v2/run.py b/examples/experimental/swe-agent-v2/run.py index b3cdbdab68..16e97f1a1d 100644 --- a/examples/experimental/swe-agent-v2/run.py +++ b/examples/experimental/swe-agent-v2/run.py @@ -38,9 +38,14 @@ class ScriptArgs(U.ExecuteTrainConfig): hf_checkpoint: str = "zai-org/GLM-4.7-Flash" ref_load: str = "/root/GLM-4.7-Flash_torch_dist" save_dir: str = "/root/GLM-4.7-Flash_agent_v2/" - max_seq_len: int = 16384 prompt_data: str = "/root/swe_train.jsonl" + # Training settings + max_seq_len: int = 16384 + rollout_batch_size: int = 2 + n_samples_per_prompt: int = 4 + global_batch_size: int = 8 + # Agent settings agent_server_url: str = os.environ.get( "AGENT_SERVER_URL", os.environ.get("SWE_AGENT_URL", "http://agent_env:11000") @@ -104,12 +109,12 @@ def execute(args: ScriptArgs): "--metadata-key metadata " "--rollout-shuffle " "--num-rollout 3000 " - "--rollout-batch-size 2 " - "--n-samples-per-prompt 4 " + f"--rollout-batch-size {args.rollout_batch_size} " + f"--n-samples-per-prompt {args.n_samples_per_prompt} " "--rollout-temperature 0.8 " "--rollout-max-response-len 8192 " f"--max-seq-len {args.max_seq_len} " - "--global-batch-size 8 " + f"--global-batch-size {args.global_batch_size} " "--balance-data " ) From 6b58ebd3ab0ad5788ad072df7adff22e85c1e386 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:01:08 -0700 Subject: [PATCH 19/44] chore: remove obsolete swe-agent server.py and run-qwen3.sh (#952) Co-authored-by: Claude Opus 4.6 (1M context) --- .../experimental/swe-agent-v2/run-qwen3.sh | 166 ---------- examples/experimental/swe-agent-v2/server.py | 298 ------------------ 2 files changed, 464 deletions(-) delete mode 100755 examples/experimental/swe-agent-v2/run-qwen3.sh delete mode 100644 examples/experimental/swe-agent-v2/server.py diff --git a/examples/experimental/swe-agent-v2/run-qwen3.sh b/examples/experimental/swe-agent-v2/run-qwen3.sh deleted file mode 100755 index ac0d8bf863..0000000000 --- a/examples/experimental/swe-agent-v2/run-qwen3.sh +++ /dev/null @@ -1,166 +0,0 @@ -#!/bin/bash -# Agent V2 launcher (Qwen3-4B): Miles <-> Harbor agent orchestration. -# -# Supports any task type (SWE-bench, Terminal-Bench, custom) via Harbor. - -pkill -9 sglang 2>/dev/null || true -sleep 3 -ray stop --force 2>/dev/null || true -pkill -9 ray 2>/dev/null || true -pkill -9 python 2>/dev/null || true -sleep 3 -pkill -9 ray 2>/dev/null || true -pkill -9 python 2>/dev/null || true -sleep 3 - -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -MILES_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" - -source "$MILES_ROOT/scripts/models/qwen3-4B.sh" - -BASE_DIR=/root/shared -AGENT_SERVER_URL="${AGENT_SERVER_URL:-${SWE_AGENT_URL:-http://agent_env:11000}}" -HARBOR_TASKS_DIR="${HARBOR_TASKS_DIR:-/root/harbor_tasks}" -ROUTER_EXTERNAL_HOST="${MILES_ROUTER_EXTERNAL_HOST:-$(hostname)}" - -CKPT_ARGS=( - --hf-checkpoint $BASE_DIR/Qwen3-4B - --ref-load $BASE_DIR/Qwen3-4B_torch_dist - --save $BASE_DIR/Qwen3-4B_agent_V2/ - --save-interval 100 -) - -ROLLOUT_ARGS=( - --prompt-data /root/swe_train.jsonl - --input-key prompt - --metadata-key metadata - --rollout-shuffle - - --num-rollout 3000 - --rollout-batch-size 1 - --n-samples-per-prompt 1 - --rollout-temperature 0.8 - --rollout-max-response-len 8192 - --global-batch-size 1 - --balance-data -) - -PERF_ARGS=( - --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 2048 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.01 - --kl-loss-type low_var_kl - --entropy-coef 0.0 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.8 - --sglang-tool-call-parser qwen25 - --sglang-reasoning-parser qwen3 - - --use-miles-router - --sglang-router-port 30000 -) - -AGENT_ARGS=( - --custom-generate-function-path miles.rollout.generate_hub.agentic_tool_call.generate - --custom-agent-function-path swe_agent_function.run - --custom-rm-path generate.reward_func - --rollout-function-path generate.RolloutFn - --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_no_aborted - --tito-model qwen3 - --chat-template-path autofix - --use-session-server -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-agent-v2 - # --wandb-group agent-v2 -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash -) - -DEBUG_ARGS=( - --debug-rollout-only -) - -# ── Start Ray ──────────────────────────────────────────────────────── -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head \ - --node-ip-address "$MASTER_ADDR" \ - --num-gpus 1 \ - --disable-usage-stats \ - --dashboard-host=0.0.0.0 \ - --dashboard-port=8265 \ - --port=8899 - -RUNTIME_ENV=$(python3 -c " -import json, sys -print(json.dumps({'env_vars': { - 'PYTHONPATH': '/root/Megatron-LM/:${SCRIPT_DIR}:${MILES_ROOT}', - 'CUDA_DEVICE_MAX_CONNECTIONS': '1', - 'MILES_EXPERIMENTAL_ROLLOUT_REFACTOR': '1', - 'AGENT_SERVER_URL': '${AGENT_SERVER_URL}', - 'AGENT_MODEL_NAME': '${AGENT_MODEL_NAME:-model}', - 'MILES_ROUTER_EXTERNAL_HOST': '${ROUTER_EXTERNAL_HOST}', - 'HARBOR_TASKS_DIR': '${HARBOR_TASKS_DIR}', - 'MILES_HOST_IP': '${MILES_HOST_IP:-$(hostname)}', - 'NCCL_NVLS_ENABLE': '0', -}})) -") - -ray job submit \ - --address="http://127.0.0.1:8265" \ - --runtime-env-json="$RUNTIME_ENV" \ - -- python3 "$MILES_ROOT/train.py" \ - --colocate \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 1 \ - --rollout-num-gpus 1 \ - "${MODEL_ARGS[@]}" \ - "${CKPT_ARGS[@]}" \ - "${ROLLOUT_ARGS[@]}" \ - "${OPTIMIZER_ARGS[@]}" \ - "${GRPO_ARGS[@]}" \ - "${PERF_ARGS[@]}" \ - "${SGLANG_ARGS[@]}" \ - "${AGENT_ARGS[@]}" \ - "${WANDB_ARGS[@]}" \ - "${MISC_ARGS[@]}" \ - "${DEBUG_ARGS[@]}" diff --git a/examples/experimental/swe-agent-v2/server.py b/examples/experimental/swe-agent-v2/server.py deleted file mode 100644 index 137b998975..0000000000 --- a/examples/experimental/swe-agent-v2/server.py +++ /dev/null @@ -1,298 +0,0 @@ -""" -FastAPI server wrapping Harbor for generalized agent-environment orchestration. - -Provides a single ``/run`` endpoint that handles any task type (SWE-bench, -Terminal-Bench, custom datasets, etc.) through Harbor's unified Trial API. -Harbor handles Docker orchestration, agent execution, and grading — the -server is task-type agnostic. - -Requires: - - Harbor installed: pip install harbor-framework - - Prepared task dirs under HARBOR_TASKS_DIR (via adapters or prepare_harbor_tasks.py) - -Usage: - python server.py --port 11000 --max-concurrent 8 -""" - -import argparse -import asyncio -import logging -import os -import re -import traceback -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from pathlib import Path -from typing import Any - -import uvicorn -from fastapi import FastAPI -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -_semaphore: asyncio.Semaphore | None = None - - -@asynccontextmanager -async def _lifespan(app: FastAPI) -> AsyncIterator[None]: - global _semaphore - max_concurrent = int(os.getenv("AGENT_MAX_CONCURRENT", os.getenv("SWE_AGENT_MAX_CONCURRENT", "8"))) - _semaphore = asyncio.Semaphore(max_concurrent) - logger.info(f"Initialized semaphore with max_concurrent={max_concurrent}") - yield - - -app = FastAPI(title="Agent Environment Server (Harbor)", lifespan=_lifespan) - - -class RunRequest(BaseModel): - base_url: str - model: str - sampling_params: dict[str, Any] = {} - api_key: str = "dummy" - - instance_id: str = "" - agent_name: str = "mini-swe-agent" - max_seq_len: int | None = None - - model_config = {"extra": "allow"} - - -class RunResponse(BaseModel): - reward: float = 0.0 - exit_status: str = "" - agent_metrics: dict[str, Any] = {} - eval_report: dict[str, Any] = {} - - -def get_semaphore() -> asyncio.Semaphore: - assert _semaphore is not None, "Semaphore not initialized — server not started?" - return _semaphore - - -_TIMEOUT_EXCEPTIONS = {"AgentTimeoutError", "VerifierTimeoutError", "EnvironmentStartTimeoutError"} -_OUTPUT_LIMIT_EXCEPTIONS = {"MaxSeqLenExceededError"} - -_HOST_PROCESS_AGENTS = {"terminus-2", "terminus-1", "terminus"} - -_SAFE_INSTANCE_ID = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") - - -def _extract_exit_status(result) -> str: - """Derive exit status from Harbor TrialResult.""" - exc = getattr(result, "exception_info", None) - if exc is not None: - exc_type = getattr(exc, "exception_type", "") - if exc_type in _TIMEOUT_EXCEPTIONS: - return "TimeLimitExceeded" - if exc_type in _OUTPUT_LIMIT_EXCEPTIONS: - return "SequenceLengthLimitExceeded" - return "AgentError" - if getattr(result, "verifier_result", None) is not None: - return "Submitted" - return "Unknown" - - -def _timing_duration_sec(timing) -> float | None: - started = getattr(timing, "started_at", None) - finished = getattr(timing, "finished_at", None) - if started and finished: - return (finished - started).total_seconds() - return None - - -def _extract_reward(result) -> tuple[float, dict[str, Any]]: - """Extract scalar reward and full eval report from Harbor TrialResult. - - Looks for the ``"reward"`` key first, then falls back to the first value - in the rewards dict. Works with both ``reward.txt`` and ``reward.json``. - """ - vr = getattr(result, "verifier_result", None) - if vr is None: - return 0.0, {} - rewards = getattr(vr, "rewards", None) or {} - reward = float(rewards.get("reward", next(iter(rewards.values()), 0.0))) - return reward, dict(rewards) - - -def _extract_metrics(result) -> dict[str, Any]: - """Extract agent metrics from Harbor TrialResult.""" - metrics: dict[str, Any] = {} - try: - ar = getattr(result, "agent_result", None) - if ar is not None: - for field in ("n_input_tokens", "n_output_tokens", "cost_usd"): - val = getattr(ar, field, None) - if val is not None: - metrics[field] = val - agent_meta = getattr(ar, "metadata", None) - if isinstance(agent_meta, dict): - metrics.update(agent_meta) - - agent_timing = getattr(result, "agent_execution", None) - if agent_timing is not None: - dur = _timing_duration_sec(agent_timing) - if dur is not None: - metrics["agent_run_time"] = dur - - verifier_timing = getattr(result, "verifier", None) - if verifier_timing is not None: - dur = _timing_duration_sec(verifier_timing) - if dur is not None: - metrics["eval_time"] = dur - except Exception as e: - logger.warning(f"Failed to extract metrics: {e}", exc_info=True) - return metrics - - -def _error_response(exit_status: str) -> dict[str, Any]: - return {"reward": 0.0, "exit_status": exit_status, "agent_metrics": {}, "eval_report": {}} - - -async def _run_trial(request: RunRequest) -> dict[str, Any]: - """Run a Harbor trial for a single task instance. - - Task-type agnostic — all differentiation (environment, grading harness) - is encoded in the Harbor task directory's 4 files. - """ - try: - from harbor.models.trial.config import AgentConfig, EnvironmentConfig, TaskConfig, TrialConfig - from harbor.trial.trial import Trial - except ImportError: - logger.error("Harbor not installed. Install with: pip install harbor-framework") - return _error_response("ImportError") - - try: - tasks_dir = Path( - os.getenv("HARBOR_TASKS_DIR", "/root/harbor_tasks"), - ).resolve() - - if not request.instance_id: - logger.error("Empty instance_id") - return _error_response("InvalidInstanceId") - - raw_id = request.instance_id - if not _SAFE_INSTANCE_ID.match(raw_id): - logger.error(f"Invalid instance_id rejected: {raw_id!r}") - return _error_response("InvalidInstanceId") - - # Normalize and verify the path stays within tasks_dir. - # Uses the pattern recommended by CodeQL (py/path-injection): - # normpath(join(base, user_input)) + startswith(base) - tasks_dir_str = str(tasks_dir) - task_path = os.path.normpath(os.path.join(tasks_dir_str, raw_id)) - if not task_path.startswith(tasks_dir_str): - logger.error(f"Path traversal blocked: {raw_id!r}") - return _error_response("InvalidInstanceId") - - if not os.path.exists(task_path): - logger.error(f"Task directory not found: {task_path}") - return _error_response("TaskNotFound") - - task_path = Path(task_path) - agent_kwargs: dict[str, Any] = {} - agent_env: dict[str, str] = {} - - is_host_agent = request.agent_name in _HOST_PROCESS_AGENTS - - if "hosted_vllm" in request.model or "openai" in request.model: - agent_kwargs["model_info"] = { - "max_input_tokens": int(os.getenv("AGENT_MAX_INPUT_TOKENS", "32768")), - "max_output_tokens": int(os.getenv("AGENT_MAX_OUTPUT_TOKENS", "8192")), - "input_cost_per_token": 0.0, - "output_cost_per_token": 0.0, - } - - if request.max_seq_len is not None: - agent_kwargs["max_seq_len"] = request.max_seq_len - - if is_host_agent: - agent_kwargs["api_base"] = request.base_url - agent_kwargs["api_key"] = request.api_key or "dummy" - agent_kwargs["enable_summarize"] = False - agent_env = { - "OPENAI_API_KEY": request.api_key or "dummy", - "OPENAI_API_BASE": request.base_url, - } - else: - agent_env = { - "OPENAI_API_BASE": request.base_url, - "OPENAI_API_KEY": request.api_key, - "HOSTED_VLLM_API_BASE": request.base_url, - "HOSTED_VLLM_API_KEY": request.api_key, - "MSWEA_COST_TRACKING": "ignore_errors", - } - - config = TrialConfig( - task=TaskConfig(path=task_path), - agent=AgentConfig( - name=request.agent_name, - model_name=request.model, - env=agent_env, - kwargs=agent_kwargs, - ), - environment=EnvironmentConfig( - type="docker", - delete=os.getenv("HARBOR_DELETE_CONTAINERS", "false").lower() in ("true", "1", "t"), - ), - ) - - trial = Trial(config=config) - result = await trial.run() - - reward, eval_report = _extract_reward(result) - exit_status = _extract_exit_status(result) - agent_metrics = _extract_metrics(result) - - return { - "reward": reward, - "exit_status": exit_status, - "agent_metrics": agent_metrics, - "eval_report": eval_report, - } - - except Exception as e: - logger.error(f"Harbor trial failed: {e}\n{traceback.format_exc()}") - return _error_response(f"Error: {type(e).__name__}") - - -@app.post("/run") -async def run_instance(request: RunRequest) -> RunResponse: - """Run an agent on a single task instance via Harbor.""" - logger.info(f"Running instance: {request.instance_id}") - async with get_semaphore(): - result = await _run_trial(request) - logger.info( - f"Instance {request.instance_id} finished: exit_status={result['exit_status']}, reward={result['reward']}" - ) - return RunResponse(**result) - - -@app.get("/health") -async def health(): - return {"status": "ok"} - - -def main(): - parser = argparse.ArgumentParser(description="Agent Environment Server (Harbor)") - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=11000) - parser.add_argument("--max-concurrent", type=int, default=8) - args = parser.parse_args() - - os.environ["AGENT_MAX_CONCURRENT"] = str(args.max_concurrent) - - os.environ.setdefault("MSWEA_API_KEY", "dummy") - os.environ.setdefault("HOSTED_VLLM_API_KEY", "dummy") - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(name)s %(levelname)s %(message)s", - ) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() From 41615af98ef921e3d48dc23a11dcabbf4c1e2ea0 Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Wed, 8 Apr 2026 22:29:02 -0700 Subject: [PATCH 20/44] Add weight staleness control for fully async rollout (#958) --- examples/fully_async/fully_async_rollout.py | 92 +++++++++++++++++-- .../fully_async/run-qwen3-4b-fully_async.sh | 3 + miles/ray/rollout.py | 10 ++ .../generate_utils/openai_endpoint_utils.py | 3 + miles/utils/arguments.py | 11 +++ miles/utils/types.py | 30 ++++++ 6 files changed, 143 insertions(+), 6 deletions(-) diff --git a/examples/fully_async/fully_async_rollout.py b/examples/fully_async/fully_async_rollout.py index 446c882e94..e4c23cc120 100644 --- a/examples/fully_async/fully_async_rollout.py +++ b/examples/fully_async/fully_async_rollout.py @@ -1,20 +1,60 @@ import asyncio import atexit +import logging import queue import threading import time -# Import core functions from sglang_rollout directly to avoid code duplication +import aiohttp + +from miles.rollout.data_source import DataSource from miles.rollout.sglang_rollout import GenerateState, generate_and_rm_group from miles.utils.async_utils import run from miles.utils.types import Sample +logger = logging.getLogger(__name__) + + +def group_oldest_weight_version(group: list[Sample]) -> int | None: + """Return the minimum weight version across all trajectories and turns in a group.""" + versions = [s.oldest_weight_version for s in group if s.oldest_weight_version is not None] + return min(versions) if versions else None + + +class _CachedWeightVersion: + """Throttled query for the current engine weight version via /model_info.""" + + def __init__(self, ttl: float = 1.0): + self._ttl = ttl + self._value: int | None = None + self._last_query: float = 0.0 + + async def get(self, args) -> int | None: + now = time.monotonic() + if self._value is not None and (now - self._last_query) < self._ttl: + return self._value + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/model_info" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=2)) as resp: + if resp.status == 200: + data = await resp.json() + self._value = int(data["weight_version"]) + self._last_query = now + except Exception as e: + logger.debug(f"Failed to query engine weight version: {e}") + return self._value + + +_cached_version = _CachedWeightVersion() + + # Global worker manager _global_worker = None _worker_lock = threading.Lock() -def get_global_worker(args, data_buffer): +def get_global_worker(args, data_buffer: DataSource): """Get or create global worker""" global _global_worker with _worker_lock: @@ -40,7 +80,7 @@ class AsyncRolloutWorker: Supports continuous running, independent of rollout function lifecycle """ - def __init__(self, args, data_buffer, concurrency=10): + def __init__(self, args, data_buffer: DataSource, concurrency=10): self.args = args self.data_buffer = data_buffer # Directly save data_buffer reference self.concurrency = concurrency @@ -146,7 +186,7 @@ def get_queue_size(self) -> int: return self.output_queue.qsize() -async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[list[Sample]]: +async def generate_rollout_async(args, rollout_id: int, data_buffer: DataSource) -> list[list[Sample]]: """ Simplified asynchronous rollout generation - using global continuous worker """ @@ -161,9 +201,15 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis data = [] completed_groups = {} do_print = True + stale_groups_recycled = 0 + staleness_values = [] + + use_staleness_filter = getattr(args, "max_weight_staleness", None) is not None print(f"Starting async rollout generation for {target_data_size} groups") print(f"Global worker queue size: {worker.get_queue_size()}") + if use_staleness_filter: + print(f"Staleness filter enabled: max_weight_staleness={args.max_weight_staleness}") # Main loop: collect results from global worker's output queue start_time = time.time() @@ -182,6 +228,11 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis if made_progress: last_progress_time = time.time() + # Query current engine version once per collection batch (cached/throttled) + current_engine_version = None + if use_staleness_filter: + current_engine_version = await _cached_version.get(args) + # Process completed groups in order (try to maintain order, but not strict requirement) processed_any = False @@ -202,7 +253,8 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis if any_aborted: try: - # add back to buffer so it can be retried or handled by buffer policy + for s in group: + s.reset_for_retry() data_buffer.add_samples([group]) print(f"Returned aborted group {group_id} to data buffer", flush=True) except Exception as e: @@ -210,6 +262,27 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis # don't count as processed for training continue + # Staleness filter: discard groups whose oldest weight version is too far behind + oldest = group_oldest_weight_version(group) + if oldest is not None and current_engine_version is not None: + staleness = current_engine_version - oldest + staleness_values.append(staleness) + if staleness > args.max_weight_staleness: + try: + for s in group: + s.reset_for_retry() + data_buffer.add_samples([group]) + except Exception as e: + logger.warning(f"Failed to recycle stale group {group_id}: {e}") + stale_groups_recycled += 1 + logger.info( + f"Recycled stale group {group_id} " + f"(oldest_version={oldest}, current={current_engine_version}, " + f"staleness={staleness} > max={args.max_weight_staleness})" + ) + # don't count as processed for training + continue + if do_print: print( f"First rollout sample: {[group[0].prompt + group[0].response]}, " @@ -238,6 +311,13 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis duration = time.time() - start_time print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}") + if stale_groups_recycled > 0 or staleness_values: + avg_staleness = sum(staleness_values) / len(staleness_values) if staleness_values else 0 + print( + f"Staleness stats: recycled={stale_groups_recycled}, " + f"avg_staleness={avg_staleness:.1f}, " + f"max_staleness={max(staleness_values) if staleness_values else 0}" + ) if data: print( @@ -250,7 +330,7 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis return data -def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False): +def generate_rollout_fully_async(args, rollout_id, data_buffer: DataSource, evaluation=False): if evaluation: raise ValueError("Evaluation mode not supported in simple async rollout") diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh index 026e486089..bfd12696bf 100644 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ b/examples/fully_async/run-qwen3-4b-fully_async.sh @@ -56,6 +56,9 @@ ROLLOUT_ARGS=( --global-batch-size 256 --balance-data + + # for staleness control + #--max-weight-staleness 2 ) PERF_ARGS=( diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index d625ae4c5e..1aa2e91fe6 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -724,6 +724,9 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if any(sample.multimodal_train_inputs is not None for sample in samples): train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] + if any(sample.weight_versions for sample in samples): + train_data["weight_versions"] = [sample.weight_versions for sample in samples] + if "teacher_log_probs" in samples[0].__dict__: train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples] @@ -771,6 +774,7 @@ def _split_train_data_by_dp(self, data, dp_size): "rollout_routed_experts", "prompt", "teacher_log_probs", + "weight_versions", ]: if key not in data: continue @@ -1202,6 +1206,12 @@ def compute_metrics_from_samples(args, samples): log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item() log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item() + oldest_versions = [s.oldest_weight_version for s in samples if s.oldest_weight_version is not None] + if oldest_versions: + log_dict |= dict_add_prefix(compute_statistics(oldest_versions), "weight_version/") + mixed = sum(1 for s in samples if len(set(s.weight_versions)) > 1) + log_dict["weight_version/mixed_version_ratio"] = mixed / len(samples) + tito_vals = [s.metadata.get("tito_session_mismatch") for s in samples] tito_vals = [v for v in tito_vals if v is not None] if tito_vals: diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 7602e51468..d054cf2c52 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -199,6 +199,9 @@ def _compute_sample_from_openai_record( case "abort": sample.status = Sample.Status.ABORTED + if "weight_version" in choice["meta_info"]: + sample.weight_versions.append(choice["meta_info"]["weight_version"]) + return sample diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 51df300818..4caa0bdf6d 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -385,6 +385,17 @@ def add_rollout_arguments(parser): "If set, only on-policy generated tokens will be used in training" ), ) + parser.add_argument( + "--max-weight-staleness", + type=int, + default=None, + help=( + "Maximum allowed gap between a group's oldest weight version and the current " + "engine weight version. Groups exceeding this threshold are recycled back to " + "the data buffer instead of being sent to training. Only effective in fully " + "async mode. None (default) disables staleness filtering." + ), + ) parser.add_argument( "--custom-generate-function-path", type=str, diff --git a/miles/utils/types.py b/miles/utils/types.py index 86c30a906c..540648e1fd 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -188,6 +188,36 @@ def strip_last_output_tokens(self, n: int, tokenizer) -> None: if self.rollout_routed_experts is not None: self.rollout_routed_experts = self.rollout_routed_experts[:-n] + def reset_for_retry(self) -> None: + """Reset generated outputs so the original prompt can be re-sampled. + + Keeps identity / prompt fields (group_index, index, prompt, label, + multimodal_inputs, metadata, generate_function_path, session_id) and + restores everything else to dataclass defaults. + """ + self.tokens = [] + self.multimodal_train_inputs = None + self.response = "" + self.response_length = 0 + self.reward = None + self.loss_mask = None + self.weight_versions = [] + self.rollout_log_probs = None + self.rollout_routed_experts = None + self.status = Sample.Status.ABORTED + self.non_generation_time = 0.0 + self.spec_info = Sample.SpecInfo() + self.prefix_cache_info = Sample.PrefixCacheInfo() + self.remove_sample = False + self.train_metadata = None + + @property + def oldest_weight_version(self) -> int | None: + """Minimum weight version across all turns (generation calls) for this trajectory.""" + if not self.weight_versions: + return None + return min(int(v) for v in self.weight_versions) + def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. From 94dbb8fe22e9fff2b4b9b24d1c121dbca1bc5a7b Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Wed, 8 Apr 2026 23:08:16 -0700 Subject: [PATCH 21/44] Fix/pause generation mode (#924) Co-authored-by: Yueming Yuan --- .../update_weight_from_distributed/mixin.py | 3 ++- .../update_weight/update_weight_from_tensor.py | 3 ++- miles/backends/sglang_utils/sglang_engine.py | 7 +++++-- miles/utils/arguments.py | 13 +++++++++++++ .../test_lora_weight_sync_validation.py | 1 + 5 files changed, 23 insertions(+), 4 deletions(-) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py index 79217bbb48..006f15516d 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py @@ -139,7 +139,8 @@ def _update_expert_bucket_weights( def _pause_and_prepare_engines(self) -> None: """Pause rollout engines, flush cache, and run pre-process if needed.""" if dist.get_rank() == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) + mode = self.args.pause_generation_mode + ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) # int4/fp4 pre_process diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 48c85e958b..86073d5ab3 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -176,7 +176,8 @@ def update_weights(self) -> None: rank = dist.get_rank() if rank == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) + mode = self.args.pause_generation_mode + ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: post_process_weights( diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 8b567a744d..1c215c7b62 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -501,8 +501,11 @@ def update_weights_from_distributed( payload, ) - def pause_generation(self): - response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) + def pause_generation(self, mode: str = "retract"): + response = requests.post( + f"http://{self.server_host}:{self.server_port}/pause_generation", + json={"mode": mode}, + ) response.raise_for_status() return response diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 4caa0bdf6d..0bd8b8cba3 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -452,6 +452,19 @@ def add_rollout_arguments(parser): default=1, help="Interval for updating the weights", ) + parser.add_argument( + "--pause-generation-mode", + type=str, + choices=["abort", "retract", "in_place"], + default="retract", + help=( + "How SGLang pauses in-flight requests during weight updates. " + "'abort' immediately terminates all requests (previous default). " + "'retract' moves running requests back to the waiting queue and " + "recomputes KV cache after update. " + "'in_place' freezes requests and resumes with existing KV cache." + ), + ) parser.add_argument( "--keep-old-actor", action="store_true", diff --git a/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py b/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py index 569039c84b..81748bca75 100644 --- a/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py +++ b/tests/fast/backends/megatron_utils/test_lora_weight_sync_validation.py @@ -56,6 +56,7 @@ def _make_args(**overrides): update_weight_buffer_size=1 << 30, actor_num_nodes=1, actor_num_gpus_per_node=1, + pause_generation_mode="retract", ) defaults.update(overrides) return Namespace(**defaults) From 4d8b00733be065c67e59016ad000f07c107a95f5 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 12:43:15 -0700 Subject: [PATCH 22/44] [v0.5.10][1] Bump sglang to v0.5.10 (#898) --- docker/Dockerfile | 6 +++--- miles/rollout/session/session_server.py | 2 +- miles/router/router.py | 2 +- tests/fast/utils/chat_template_utils/test_template.py | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index e6a2dd03b8..5afad3920c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,10 +3,10 @@ # # 2. radixark/miles:dev-cu13-arm64 # build-arg:ENABLE_CUDA_13=1 \ -# build-arg:SGLANG_IMAGE_TAG=v0.5.9-cu130-arm64 \ +# build-arg:SGLANG_IMAGE_TAG=v0.5.10-cu130 \ # build-arg:WHEELS_TAG=cu130-aarch64 \ -ARG SGLANG_IMAGE_TAG=v0.5.9 +ARG SGLANG_IMAGE_TAG=v0.5.10 FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang # ======================================== Arguments ============================================= @@ -88,7 +88,7 @@ RUN pip install megatron-energon --no-deps RUN pip install multi-storage-client --no-deps COPY requirements.txt /tmp/requirements.txt -RUN pip install -r /tmp/requirements.txt +RUN rm -rf /usr/lib/python3/dist-packages/jwt /usr/lib/python3/dist-packages/PyJWT* && pip install -r /tmp/requirements.txt # https://github.com/pytorch/pytorch/issues/168167 RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ diff --git a/miles/rollout/session/session_server.py b/miles/rollout/session/session_server.py index 0377117bdf..bc2633350e 100644 --- a/miles/rollout/session/session_server.py +++ b/miles/rollout/session/session_server.py @@ -36,7 +36,7 @@ def __init__(self, args, backend_url: str): ) # Close the httpx connection pool when uvicorn shuts down to avoid FD leaks. - self.app.add_event_handler("shutdown", self.client.aclose) + self.app.router.on_shutdown.append(self.client.aclose) setup_session_routes(self.app, self, args) diff --git a/miles/router/router.py b/miles/router/router.py index 09be44b033..51194be4cf 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -36,7 +36,7 @@ def __init__(self, args, verbose=False): self.verbose = verbose self.app = FastAPI() - self.app.add_event_handler("startup", self._start_background_health_check) + self.app.router.on_startup.append(self._start_background_health_check) # URL -> Active Request Count (load state) self.worker_request_counts: dict[str, int] = {} diff --git a/tests/fast/utils/chat_template_utils/test_template.py b/tests/fast/utils/chat_template_utils/test_template.py index 225a9bf178..39ac412954 100644 --- a/tests/fast/utils/chat_template_utils/test_template.py +++ b/tests/fast/utils/chat_template_utils/test_template.py @@ -60,6 +60,7 @@ def _make_serving(tokenizer) -> OpenAIServingChat: serving.use_dpsk_v32_encoding = False serving.is_gpt_oss = False serving.tool_call_parser = None + serving.reasoning_parser = None return serving From ef228e648230868720f838028d5962b22f4aa360 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 12:50:01 -0700 Subject: [PATCH 23/44] [v0.5.10][2] Fix apply_chat_template behavior for transformers >=5.0 (#926) Co-authored-by: guapisolo Co-authored-by: Claude Opus 4.6 (1M context) --- miles/utils/mask_utils.py | 16 +++++---- .../generate_hub/test_tool_call_utils.py | 36 ++++++++++++------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/miles/utils/mask_utils.py b/miles/utils/mask_utils.py index 0ddb3a1410..cd2d72eb4c 100644 --- a/miles/utils/mask_utils.py +++ b/miles/utils/mask_utils.py @@ -38,7 +38,7 @@ def get_system_message_length(self) -> tuple[int, int]: end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2 gen_token_length = len( self.tokenizer.apply_chat_template( - test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True + test_messages, add_special_tokens=False, tokenize=True, return_dict=False, add_generation_prompt=True ) ) - len(chat_template_token_ids) @@ -53,9 +53,11 @@ def gen_multi_turn_loss_mask_qwen( for i, message in enumerate(messages): if i == 0: - message_ids = self.tokenizer.apply_chat_template([message], tokenize=True, tools=tools) + message_ids = self.tokenizer.apply_chat_template( + [message], tokenize=True, return_dict=False, tools=tools + ) else: - message_ids = self.tokenizer.apply_chat_template([message], tokenize=True) + message_ids = self.tokenizer.apply_chat_template([message], tokenize=True, return_dict=False) if message["role"] != "system" and i > 0: message_ids = message_ids[self.system_message_length :] @@ -80,16 +82,18 @@ def gen_multi_turn_loss_mask_qwen3( all_token_ids = [] prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"} - prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True) + prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True, return_dict=False) for i, message in enumerate(messages): if i == 0: tailed_message_ids = self.tokenizer.apply_chat_template( - [message, prefix_message], tokenize=True, tools=tools + [message, prefix_message], tokenize=True, return_dict=False, tools=tools ) message_ids = tailed_message_ids[: -len(prefix_token_ids)] else: - prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True) + prefixed_message_ids = self.tokenizer.apply_chat_template( + [prefix_message, message], tokenize=True, return_dict=False + ) message_ids = prefixed_message_ids[len(prefix_token_ids) :] if message["role"] != "system" and i > 0: diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py index 0f2305e753..b4f684dd80 100644 --- a/tests/fast/rollout/generate_hub/test_tool_call_utils.py +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -7,24 +7,39 @@ "Qwen/Qwen3-0.6B", "Qwen/Qwen3-4B-Instruct-2507", "Qwen/Qwen3-Coder-30B-A3B-Instruct", + "Qwen/Qwen3.5-0.8B", + "Qwen/Qwen3-Coder-Next", # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI "mistralai/Mistral-7B-Instruct-v0.3", - "deepseek-ai/DeepSeek-V3", - "stepfun-ai/step3", "MiniMaxAI/MiniMax-M2", + "MiniMaxAI/MiniMax-M2.5", "internlm/internlm3-8b-instruct", - "THUDM/glm-4-9b-chat", + "zai-org/GLM-4.7-Flash", + "stepfun-ai/Step-3.5-Flash", "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2.5", "XiaomiMiMo/MiMo-7B-RL", + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", ] -SINGLE_TOOL_CALL_ONLY_MODELS = [ - # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo +# Models that fail decode round-trip under transformers>=5.x due to upstream tokenizer issues. +# These are excluded from TOOL_CALL_TEST_MODELS but listed here for tracking. +# - DeepSeek-V3, step3: transformers v5 unified LlamaTokenizer overwrites their ByteLevel +# pre_tokenizer/decoder with Metaspace, causing decode(encode(text)) != text. +# See https://github.com/huggingface/transformers/issues/43066 +# - DeepSeek-V3.1: its tool-call chat template concatenates function.arguments as a string, +# but our dummy tool-call shape provides a dict, raising TypeError before the round-trip check. +# - glm-4-9b-chat: v5 removed the legacy _decode special-token segmentation, exposing a bug in +# the model's custom convert_tokens_to_string (doesn't handle str-type special tokens). +TOOL_CALL_KNOWN_FAILURES = [ + "deepseek-ai/DeepSeek-V3", + "deepseek-ai/DeepSeek-V3.1", + "stepfun-ai/step3", + "THUDM/glm-4-9b-chat", ] -# Models where tokenize->decode produces extra whitespace vs direct string diff -TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ - "THUDM/glm-4-9b-chat", +SINGLE_TOOL_CALL_ONLY_MODELS = [ + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo ] SAMPLE_TOOL_RESPONSES = [ @@ -83,11 +98,6 @@ def test_tokenize_tool_responses(self, model_name, num_tools): base_messages = [_DUMMY_USER, dummy_assistant] expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) - if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: - # Some models produce whitespace differences between tokenize->decode and direct string diff - actual_str = actual_str.replace(" ", "") - expected_str = expected_str.replace(" ", "") - assert actual_str == expected_str, f"{model_name=}" @staticmethod From b1a4346cf2725b13767991d3c6f0c98b4108fd56 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 12:50:16 -0700 Subject: [PATCH 24/44] [v0.5.10][3] Fix processor return_tensors duplicate kwarg for transformers >=5.0 (#927) Co-authored-by: guapisolo Co-authored-by: Claude Opus 4.6 (1M context) --- miles/utils/processing_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/miles/utils/processing_utils.py b/miles/utils/processing_utils.py index 75fd2fb75e..3291a7ffe4 100644 --- a/miles/utils/processing_utils.py +++ b/miles/utils/processing_utils.py @@ -28,17 +28,13 @@ def load_tokenizer(name_or_path: str, chat_template_path: str = None, **kwargs): def build_processor_kwargs(multimodal_inputs: dict | None = None) -> dict: - forced = { - # force return_tensors to None for input_ids - "return_tensors": None, - } modality_forced = {"return_tensors": "pt"} result = dict(multimodal_inputs) if multimodal_inputs else {} - result.update(forced) - - # set return_tensors="pt" for modality-specific outputs + # return_tensors=None for text (input_ids), "pt" for modality-specific outputs. + # Use per-modality dicts to avoid transformers >=5.0 duplicate kwarg error. + result["text_kwargs"] = {**result.get("text_kwargs", {}), "return_tensors": None} for key in ("audio_kwargs", "images_kwargs", "videos_kwargs"): if key in result: result[key] = {**result[key], **modality_forced} From 2a991083f1f09dc6d6b436b9d822d26c4678bdd4 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 12:50:36 -0700 Subject: [PATCH 25/44] [v0.5.10][4] Fix _no_split_modules set not subscriptable in transformers >=5.0 (#931) --- miles/backends/fsdp_utils/actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 7bdd1c17ad..5bb3dd0c7b 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -681,7 +681,7 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None): offload_policy = CPUOffloadPolicy() if cpu_offload else None layer_cls_to_wrap = model._no_split_modules - assert len(layer_cls_to_wrap) > 0 and layer_cls_to_wrap[0] is not None + assert len(layer_cls_to_wrap) > 0 and next(iter(layer_cls_to_wrap)) is not None modules = [ module From c74392dd56b3198f6c6baad02f5836077cdf0efe Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 12:50:48 -0700 Subject: [PATCH 26/44] [v0.5.10][5] Disable piecewise cuda graph to avoid NVLS oom (#935) --- docs/en/get_started/qa.md | 6 +++++- miles/utils/arguments.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/en/get_started/qa.md b/docs/en/get_started/qa.md index c9e8dad21f..f6d5beea11 100644 --- a/docs/en/get_started/qa.md +++ b/docs/en/get_started/qa.md @@ -65,4 +65,8 @@ 13. **Gradient becomes NaN or Inf during training.** - You can try setting the `--no-check-for-nan-in-loss-and-grad` flag to skip the corresponding training steps. \ No newline at end of file + You can try setting the `--no-check-for-nan-in-loss-and-grad` flag to skip the corresponding training steps. + +14. **NCCL error: `Failed to bind NVLink SHARP (NVLS) Multicast memory ... CUDA error 2 'out of memory'`.** + + This issue has been observed on H100 in colocate mode with piece-wise CUDA graph enabled. Piece-wise CUDA graph is now disabled by default in colocate mode. If you encounter this after explicitly enabling it via `--sglang-enforce-piecewise-cuda-graph`, remove that flag. diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 0bd8b8cba3..5ee2beea87 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1972,6 +1972,14 @@ def miles_validate_args(args): args.offload_train = True if args.offload_rollout is None: args.offload_rollout = True + if args.sglang_enforce_piecewise_cuda_graph: + logger.warning("Warning: colocate mode with --sglang-enforce-piecewise-cuda-graph may trigger NVLS OOM.") + if not args.sglang_disable_piecewise_cuda_graph: + args.sglang_disable_piecewise_cuda_graph = True + logger.info( + "Colocate mode: defaulting --sglang-disable-piecewise-cuda-graph to avoid NVLS OOM. " + "Use --sglang-enforce-piecewise-cuda-graph to override." + ) if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes: logger.info( f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} " From d6158f8ae9b412e4a29b9b8fb30955484300f38e Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 12:51:46 -0700 Subject: [PATCH 27/44] [v0.5.10][6][FSDP] fix outdated weight update logic in FSDP (#948) Co-authored-by: guapisolo Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: maocheng23 <35615230+maocheng23@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../fsdp_utils/update_weight_utils.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index 32948b2840..98000e7d34 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -17,7 +17,7 @@ from sglang.srt.utils import MultiprocessingSerializer -from miles.utils.distributed_utils import init_process_group +from miles.utils.distributed_utils import get_gloo_group, init_process_group try: @@ -47,6 +47,13 @@ def connect_rollout_engines( def update_weights(self) -> None: self.weight_version += 1 + + if dist.get_rank() == 0: + futures = [engine.pause_generation.remote() for engine in self.rollout_engines] + futures.extend([engine.flush_cache.remote() for engine in self.rollout_engines]) + ray.get(futures) + dist.barrier(group=get_gloo_group()) + bucket = [] bucket_size = 0 for name, param in self.model.state_dict().items(): @@ -73,6 +80,11 @@ def update_weights(self) -> None: bucket = [] bucket_size = 0 + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] self.update_bucket_weights(bucket, weight_version=self.weight_version) @@ -172,8 +184,13 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) result = ray.get(ref) - if hasattr(result, "success") and not result.success: + if isinstance(result, dict): + success = result.get("success", True) + error_msg = result.get("error_message") or result.get("message", "unknown error") + else: + success = getattr(result, "success", True) error_msg = getattr(result, "error_message", "unknown error") + if not success: raise RuntimeError( f"Weight sync failed on rollout engine: {error_msg}. " f"Check SGLang version compatibility." ) From c4e50c8d2536c621c26160a33768e575b5c0d944 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 14:29:18 -0700 Subject: [PATCH 28/44] [v0.5.10][7][FSDP] move FSDP to experimental and disable by default (#961) --- .github/workflows/pr-test.yml | 118 +----------------- .github/workflows/pr-test.yml.j2 | 16 +-- .../kernels => experimental}/__init__.py | 0 .../{ => experimental}/fsdp_utils/__init__.py | 0 .../{ => experimental}/fsdp_utils/actor.py | 12 +- .../fsdp_utils/arguments.py | 0 .../fsdp_utils/checkpoint.py | 0 .../fsdp_utils/kernels}/__init__.py | 0 .../fsdp_utils/kernels/fused_experts.py | 0 .../fused_moe_triton_backward_kernels.py | 0 .../fsdp_utils/lr_scheduler.py | 0 .../fsdp_utils/models/__init__.py | 0 .../fsdp_utils/models/qwen3_moe.py | 2 +- .../fsdp_utils/models/qwen3_moe_hf.py | 0 .../{ => experimental}/fsdp_utils/parallel.py | 2 +- .../fsdp_utils/update_weight_utils.py | 0 miles/ray/actor_group.py | 2 +- miles/utils/arguments.py | 9 +- tests/fast/fixtures/generation_fixtures.py | 1 + tests/fast/fixtures/rollout_fixtures.py | 1 + .../rollout/inference_rollout/conftest.py | 1 + tests/test_fused_experts_backward.py | 4 +- 22 files changed, 30 insertions(+), 138 deletions(-) rename miles/backends/{fsdp_utils/kernels => experimental}/__init__.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/__init__.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/actor.py (98%) rename miles/backends/{ => experimental}/fsdp_utils/arguments.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/checkpoint.py (100%) rename miles/backends/{fsdp_utils/models => experimental/fsdp_utils/kernels}/__init__.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/kernels/fused_experts.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/lr_scheduler.py (100%) create mode 100644 miles/backends/experimental/fsdp_utils/models/__init__.py rename miles/backends/{ => experimental}/fsdp_utils/models/qwen3_moe.py (98%) rename miles/backends/{ => experimental}/fsdp_utils/models/qwen3_moe_hf.py (100%) rename miles/backends/{ => experimental}/fsdp_utils/parallel.py (96%) rename miles/backends/{ => experimental}/fsdp_utils/update_weight_utils.py (100%) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 9a47e2a0ef..656c43f9ea 100755 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -166,118 +166,6 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} - unit-test: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-unit-test')) - runs-on: self-hosted - container: - image: radixark/miles:dev - options: > - --gpus all - --ipc=host - --shm-size=32g - --ulimit memlock=-1 - --ulimit stack=67108864 - --memory=0 - --memory-swap=0 - -v /mnt/nvme0n1/miles_ci:/data/miles_ci - -v /mnt/nvme0n1/miles_ci/models:/root/models - -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets - --privileged - --ulimit nofile=65535:65535 - -v /tmp:/tmp - strategy: - fail-fast: false - matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}] - defaults: - run: - working-directory: ${{ github.workspace }} - env: - GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - HF_TOKEN: ${{ secrets.HF_TOKEN }} - MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} - MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} - MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} - MILES_TEST_USE_INT4_ROLLOUT: ${{ matrix.info.use_int4_rollout || '0' }} - MILES_TEST_USE_BRIDGE: ${{ matrix.info.use_bridge || '0' }} - MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} - MILES_TEST_FEW_GPU: '0' - SESSION_TEST_MODEL_FAMILY: ${{ matrix.info.model_family || '' }} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Cleanup Ray processes - shell: bash - run: | - pkill -9 -f 'ray::' 2>/dev/null || true - pkill -9 -f raylet 2>/dev/null || true - pkill -9 -f gcs_server 2>/dev/null || true - pkill -9 -f 'ray-dashboard' 2>/dev/null || true - pkill -9 sglang 2>/dev/null || true - ray stop --force 2>/dev/null || true - rm -rf /tmp/ray/* 2>/dev/null || true - sleep 3 - - - - name: Resolve dependency refs - id: resolve-refs - shell: bash - env: - PR_BODY: ${{ github.event.pull_request.body || '' }} - INPUT_MEGATRON_PR: ${{ github.event.inputs.ci_megatron_pr || '' }} - INPUT_SGLANG_PR: ${{ github.event.inputs.ci_sglang_pr || '' }} - run: | - # Priority: workflow_dispatch input > PR description > default - MEGATRON_PR="${INPUT_MEGATRON_PR}" - SGLANG_PR="${INPUT_SGLANG_PR}" - - # Parse PR description for "ci-megatron-pr:" and "ci-sglang-pr:" - if [ -n "$PR_BODY" ]; then - PR_MEGATRON_PR=$(echo "$PR_BODY" | grep -oP '(?<=ci-megatron-pr:\s)\S+' || true) - PR_SGLANG_PR=$(echo "$PR_BODY" | grep -oP '(?<=ci-sglang-pr:\s)\S+' || true) - [ -z "$MEGATRON_PR" ] && [ -n "$PR_MEGATRON_PR" ] && MEGATRON_PR="$PR_MEGATRON_PR" - [ -z "$SGLANG_PR" ] && [ -n "$PR_SGLANG_PR" ] && SGLANG_PR="$PR_SGLANG_PR" - fi - - # Defaults - [ -z "$MEGATRON_PR" ] && MEGATRON_PR="miles-main" - [ -z "$SGLANG_PR" ] && SGLANG_PR="sglang-miles" - - # Convert "#N" PR syntax to git fetch ref: "pull/N/head" - resolve_fetch_ref() { - local ref="$1" - if [[ "$ref" =~ ^#([0-9]+)$ ]]; then - echo "pull/${BASH_REMATCH[1]}/head" - else - echo "$ref" - fi - } - MEGATRON_FETCH=$(resolve_fetch_ref "$MEGATRON_PR") - SGLANG_FETCH=$(resolve_fetch_ref "$SGLANG_PR") - - echo "ci_megatron_pr=$MEGATRON_FETCH" >> $GITHUB_OUTPUT - echo "ci_sglang_pr=$SGLANG_FETCH" >> $GITHUB_OUTPUT - echo "Resolved: megatron=$MEGATRON_PR -> fetch=$MEGATRON_FETCH, sglang=$SGLANG_PR -> fetch=$SGLANG_FETCH" - - - name: Install - shell: bash - env: - MEGATRON_PR: ${{ steps.resolve-refs.outputs.ci_megatron_pr }} - SGLANG_PR: ${{ steps.resolve-refs.outputs.ci_sglang_pr }} - run: | - cd /sgl-workspace/sglang && git reset --hard HEAD && git clean -fd && git fetch origin "$SGLANG_PR" && git checkout -f FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages - cd /root/Megatron-LM && git reset --hard HEAD && git clean -fd && git fetch origin "$MEGATRON_PR" && git checkout -f FETCH_HEAD && git log --oneline -1 && pip install -e . --no-deps --break-system-packages - cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages - pip install pytest-asyncio --break-system-packages - - - - name: Execute - shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} - e2e-test-sglang: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-sglang')) runs-on: self-hosted @@ -412,7 +300,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -524,7 +412,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}] + info: [{"name": "[FSDP] qwen3-4B-fsdp-true-on-policy", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"name": "[FSDP] qwen3-vl-4B-fsdp", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-distributed", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"name": "[FSDP] qwen3-0.6B-megatron-fsdp-align", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-colocated-2xGPU", "num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -1375,7 +1263,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}, {"name": "qwen3-30B-A3B-bf16", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0"}, {"name": "qwen3-30B-A3B-rollout-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-rollout-int4", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0", "use_int4_rollout": "1"}] + info: [{"name": "[FSDP] qwen3-4B-fsdp-true-on-policy", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"name": "[FSDP] qwen3-vl-4B-fsdp", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-distributed", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"name": "[FSDP] qwen3-0.6B-megatron-fsdp-align", "num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"name": "[FSDP] qwen3-0.6B-fsdp-colocated-2xGPU", "num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"name": "qwen3-30B-A3B-deepep-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-bridge", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_bridge": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload.py"}, {"num_gpus": 8, "test_file": "e2e/sglang_config/test_sglang_config_mixed_offload_ft.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}, {"name": "qwen3-30B-A3B-bf16", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0"}, {"name": "qwen3-30B-A3B-rollout-fp8", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"name": "qwen3-30B-A3B-rollout-int4", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "0", "use_fp8_rollout": "0", "use_int4_rollout": "1"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 0db15cad65..cc97dfc415 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,10 +1,11 @@ <% set default_image = 'radixark/miles:dev' %> <% set fsdp_tests = [ - {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 8}, - {'test_file': 'e2e/fsdp/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, - {'test_file': 'e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 8}, - {'test_file': 'e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-4B-fsdp-true-on-policy', 'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-vl-4B-fsdp', 'test_file': 'e2e/fsdp/test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-0.6B-fsdp-distributed', 'test_file': 'e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-0.6B-megatron-fsdp-align', 'test_file': 'e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 8}, + {'name': '[FSDP] qwen3-0.6B-fsdp-colocated-2xGPU', 'test_file': 'e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 8}, ] %> <% set megatron_tests = [ @@ -27,7 +28,6 @@ <% set short_tests = [ {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 8}, {'test_file': 'e2e/short/test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 8}, - {'test_file': 'e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 8}, {'test_file': 'e2e/sglang_config/test_sglang_config.py', 'num_gpus': 8}, {'test_file': 'e2e/sglang_config/test_sglang_config_mixed_offload.py', 'num_gpus': 8}, {'test_file': 'e2e/sglang_config/test_sglang_config_mixed_offload_ft.py', 'num_gpus': 8}, @@ -67,12 +67,6 @@ {'test_file': 'utils/test_sglang_config.py', 'num_gpus': 0}, ], }, - 'unit-test': { - 'label': 'run-unit-test', - 'tests': [ - {'test_file': 'e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 8} - ], - }, 'e2e-test-sglang': { 'label': 'run-ci-sglang', 'test_executor': 'pytest', diff --git a/miles/backends/fsdp_utils/kernels/__init__.py b/miles/backends/experimental/__init__.py similarity index 100% rename from miles/backends/fsdp_utils/kernels/__init__.py rename to miles/backends/experimental/__init__.py diff --git a/miles/backends/fsdp_utils/__init__.py b/miles/backends/experimental/fsdp_utils/__init__.py similarity index 100% rename from miles/backends/fsdp_utils/__init__.py rename to miles/backends/experimental/fsdp_utils/__init__.py diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/experimental/fsdp_utils/actor.py similarity index 98% rename from miles/backends/fsdp_utils/actor.py rename to miles/backends/experimental/fsdp_utils/actor.py index 5bb3dd0c7b..47c2540f98 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/experimental/fsdp_utils/actor.py @@ -20,17 +20,17 @@ from miles.utils.timer import Timer, inverse_timer, timer from miles.utils.tracking_utils import init_tracking -from ...utils.profile_utils import TrainProfiler -from ..training_utils.ci_utils import check_grad_norm -from ..training_utils.data import DataIterator, get_batch, get_data_iterator, get_rollout_data -from ..training_utils.log_utils import ( +from ....utils.profile_utils import TrainProfiler +from ...training_utils.ci_utils import check_grad_norm +from ...training_utils.data import DataIterator, get_batch, get_data_iterator, get_rollout_data +from ...training_utils.log_utils import ( aggregate_forward_results, aggregate_train_losses, log_rollout_data, log_train_step, ) -from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function -from ..training_utils.parallel import get_parallel_state, set_parallel_state +from ...training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function +from ...training_utils.parallel import get_parallel_state, set_parallel_state from . import checkpoint from .lr_scheduler import get_lr_scheduler from .parallel import create_fsdp_parallel_state diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/experimental/fsdp_utils/arguments.py similarity index 100% rename from miles/backends/fsdp_utils/arguments.py rename to miles/backends/experimental/fsdp_utils/arguments.py diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/experimental/fsdp_utils/checkpoint.py similarity index 100% rename from miles/backends/fsdp_utils/checkpoint.py rename to miles/backends/experimental/fsdp_utils/checkpoint.py diff --git a/miles/backends/fsdp_utils/models/__init__.py b/miles/backends/experimental/fsdp_utils/kernels/__init__.py similarity index 100% rename from miles/backends/fsdp_utils/models/__init__.py rename to miles/backends/experimental/fsdp_utils/kernels/__init__.py diff --git a/miles/backends/fsdp_utils/kernels/fused_experts.py b/miles/backends/experimental/fsdp_utils/kernels/fused_experts.py similarity index 100% rename from miles/backends/fsdp_utils/kernels/fused_experts.py rename to miles/backends/experimental/fsdp_utils/kernels/fused_experts.py diff --git a/miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py b/miles/backends/experimental/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py similarity index 100% rename from miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py rename to miles/backends/experimental/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py diff --git a/miles/backends/fsdp_utils/lr_scheduler.py b/miles/backends/experimental/fsdp_utils/lr_scheduler.py similarity index 100% rename from miles/backends/fsdp_utils/lr_scheduler.py rename to miles/backends/experimental/fsdp_utils/lr_scheduler.py diff --git a/miles/backends/experimental/fsdp_utils/models/__init__.py b/miles/backends/experimental/fsdp_utils/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/miles/backends/fsdp_utils/models/qwen3_moe.py b/miles/backends/experimental/fsdp_utils/models/qwen3_moe.py similarity index 98% rename from miles/backends/fsdp_utils/models/qwen3_moe.py rename to miles/backends/experimental/fsdp_utils/models/qwen3_moe.py index fe2133f3c1..7471f4aa1a 100644 --- a/miles/backends/fsdp_utils/models/qwen3_moe.py +++ b/miles/backends/experimental/fsdp_utils/models/qwen3_moe.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP -from miles.backends.fsdp_utils.kernels.fused_experts import ( +from miles.backends.experimental.fsdp_utils.kernels.fused_experts import ( DownProjFunction, GateUpProjFunction, MoeSumReduceFunction, diff --git a/miles/backends/fsdp_utils/models/qwen3_moe_hf.py b/miles/backends/experimental/fsdp_utils/models/qwen3_moe_hf.py similarity index 100% rename from miles/backends/fsdp_utils/models/qwen3_moe_hf.py rename to miles/backends/experimental/fsdp_utils/models/qwen3_moe_hf.py diff --git a/miles/backends/fsdp_utils/parallel.py b/miles/backends/experimental/fsdp_utils/parallel.py similarity index 96% rename from miles/backends/fsdp_utils/parallel.py rename to miles/backends/experimental/fsdp_utils/parallel.py index 81e682660a..fa444975bd 100644 --- a/miles/backends/fsdp_utils/parallel.py +++ b/miles/backends/experimental/fsdp_utils/parallel.py @@ -7,7 +7,7 @@ from miles.utils.distributed_utils import get_gloo_group -from ..training_utils.parallel import GroupInfo, ParallelState +from ...training_utils.parallel import GroupInfo, ParallelState logger = logging.getLogger(__name__) diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/experimental/fsdp_utils/update_weight_utils.py similarity index 100% rename from miles/backends/fsdp_utils/update_weight_utils.py rename to miles/backends/experimental/fsdp_utils/update_weight_utils.py diff --git a/miles/ray/actor_group.py b/miles/ray/actor_group.py index 54228f4228..669988b52c 100644 --- a/miles/ray/actor_group.py +++ b/miles/ray/actor_group.py @@ -81,7 +81,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): actor_impl = MegatronTrainRayActor else: - from miles.backends.fsdp_utils import FSDPTrainRayActor + from miles.backends.experimental.fsdp_utils import FSDPTrainRayActor actor_impl = FSDPTrainRayActor diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 5ee2beea87..f7726618cb 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1709,7 +1709,7 @@ def parse_args(add_custom_arguments=None): args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node args = set_default_megatron_args(args) else: - from miles.backends.fsdp_utils.arguments import load_fsdp_args + from miles.backends.experimental.fsdp_utils.arguments import load_fsdp_args args = load_fsdp_args(extra_args_provider=add_miles_arguments) args.rank = 0 # Primary process rank for wandb initialization @@ -1717,6 +1717,13 @@ def parse_args(add_custom_arguments=None): assert args.context_parallel_size == 1, "Context parallelism is not supported for FSDP backend." + if not args.ci_test: + raise ValueError( + "The FSDP backend has known issues with SGLang v0.5.10 and is not actively maintained in the current version. " + "It has been moved to miles.backends.experimental. " + "Contributions are welcome if you are interested in improving it." + ) + miles_validate_args(args) if backend == "megatron": diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py index a56ec33df1..7e70dcd187 100644 --- a/tests/fast/fixtures/generation_fixtures.py +++ b/tests/fast/fixtures/generation_fixtures.py @@ -161,6 +161,7 @@ def make_args( "pytest", "--train-backend", "fsdp", + "--ci-test", "--rollout-batch-size", "1", "--num-rollout", diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py index b54c7b9a51..90bfdd197d 100644 --- a/tests/fast/fixtures/rollout_fixtures.py +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -43,6 +43,7 @@ def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | Non "pytest", "--train-backend", "fsdp", + "--ci-test", "--rollout-batch-size", "1", "--n-samples-per-prompt", diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py index ca47edeeb6..d848ef0b2a 100644 --- a/tests/fast/rollout/inference_rollout/conftest.py +++ b/tests/fast/rollout/inference_rollout/conftest.py @@ -10,6 +10,7 @@ def _build_mock_args(extra_argv: list[str] | None = None): "pytest", "--train-backend", "fsdp", + "--ci-test", "--rollout-batch-size", "2", "--n-samples-per-prompt", diff --git a/tests/test_fused_experts_backward.py b/tests/test_fused_experts_backward.py index e2a94897b2..a89e2de51a 100644 --- a/tests/test_fused_experts_backward.py +++ b/tests/test_fused_experts_backward.py @@ -260,8 +260,8 @@ def backward(ctx, grad_output): # Import Triton Implementation # ============================================================================ -from miles.backends.fsdp_utils.kernels.fused_experts import DownProjFunction as DownProjFunctionTriton -from miles.backends.fsdp_utils.kernels.fused_experts import GateUpProjFunction as GateUpProjFunctionTriton +from miles.backends.experimental.fsdp_utils.kernels.fused_experts import DownProjFunction as DownProjFunctionTriton +from miles.backends.experimental.fsdp_utils.kernels.fused_experts import GateUpProjFunction as GateUpProjFunctionTriton # ============================================================================ # Test Fixtures and Utilities From 8d66ac196799a0b5fd22a7e8a8d0d57c87b62ad6 Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:21:02 -0700 Subject: [PATCH 29/44] Add skiplist and more robust calculation on val (#965) --- miles/backends/training_utils/log_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index c6ed8b6ffb..54aa44a6d8 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -121,6 +121,8 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc "rollout_routed_experts", "max_seq_lens", "dynamic_global_batch_size", + "weight_versions", + "metadata", ]: continue # Upload per sample mean for each rollout value @@ -151,7 +153,14 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc else: val = val.mean() * cp_size else: - val = sum(val) / len(val) + # Flatten nested lists (e.g. list of lists from async rollout) + flat = val + if isinstance(val[0], (list, tuple)): + flat = [x for sublist in val for x in sublist] + # Skip non-numeric values (e.g. strings from async rollout metadata) + if flat and not isinstance(flat[0], (int, float)): + continue + val = sum(flat) / len(flat) elif isinstance(val, torch.Tensor): val = val.float().mean() else: From 02f6e05007e059fe024f78890c07cae0e5e2380c Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 9 Apr 2026 23:19:09 -0700 Subject: [PATCH 30/44] [fix] tiny fix debug rollout only in weight version check (#967) --- miles/utils/types.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/miles/utils/types.py b/miles/utils/types.py index 540648e1fd..4d7d6ef9b2 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -214,9 +214,8 @@ def reset_for_retry(self) -> None: @property def oldest_weight_version(self) -> int | None: """Minimum weight version across all turns (generation calls) for this trajectory.""" - if not self.weight_versions: - return None - return min(int(v) for v in self.weight_versions) + numeric = [int(v) for v in self.weight_versions if str(v).isdigit()] + return min(numeric) if numeric else None def update_from_meta_info(self, args, meta_info: dict): """ From eb294e332aea137efccadc79a7e7bacf18911c73 Mon Sep 17 00:00:00 2001 From: Zhichen Zeng Date: Fri, 10 Apr 2026 20:06:54 -0700 Subject: [PATCH 31/44] feat: real cp support with relayout fix for qwen3.5 train/rollout mismatch (#885) --- miles/backends/megatron_utils/actor.py | 9 + miles/backends/training_utils/cp_utils.py | 36 ++++ miles/utils/arguments.py | 3 + miles_plugins/models/cp_utils.py | 26 +++ miles_plugins/models/hf_attention.py | 143 +++++++++++++- miles_plugins/models/qwen3_5.py | 44 +++-- miles_plugins/models/qwen3_next.py | 42 +++-- scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py | 178 ++++++++++++++++++ tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py | 153 +++++++++++++++ .../test_qwen3_5_mtp_bridge_mapping.py | 0 .../test_hf_attention_cp_relayout.py | 101 ++++++++++ .../precision/test_qwen3_5_cp_correctness.py | 147 +++++++++++++++ .../fast/backends/training_utils/__init__.py | 1 + 13 files changed, 856 insertions(+), 27 deletions(-) create mode 100644 miles_plugins/models/cp_utils.py create mode 100644 scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py create mode 100644 tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py rename tests/{ => e2e/megatron}/test_qwen3_5_mtp_bridge_mapping.py (100%) create mode 100644 tests/e2e/precision/test_hf_attention_cp_relayout.py create mode 100644 tests/e2e/precision/test_qwen3_5_cp_correctness.py create mode 100644 tests/fast/backends/training_utils/__init__.py diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 0c265a508e..8b3ea9975c 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -115,6 +115,15 @@ def init( args, role ) + parallel_state = get_parallel_state() + if parallel_state.cp.size > 1: + from miles_plugins.models.cp_utils import detect_and_setup_hybrid_cp + + for model_chunk in self.model: + detect_and_setup_hybrid_cp( + model_chunk, parallel_state.cp.group, parallel_state.cp.rank, parallel_state.cp.size + ) + verify_megatron_parallel_state(self.model) if role == "critic": diff --git a/miles/backends/training_utils/cp_utils.py b/miles/backends/training_utils/cp_utils.py index 0fbba35b02..e79ccb4469 100644 --- a/miles/backends/training_utils/cp_utils.py +++ b/miles/backends/training_utils/cp_utils.py @@ -1,11 +1,20 @@ +import logging from collections.abc import Callable import torch import torch.distributed as dist +import torch.nn as nn import torch.nn.functional as F from .parallel import get_parallel_state +try: + from fla.ops.cp import build_cp_context as _fla_build_cp_context +except ImportError: + _fla_build_cp_context = None + +logger = logging.getLogger(__name__) + def get_logits_and_tokens_offset_with_cp( total_length: int, @@ -336,3 +345,30 @@ def slice_log_prob_with_cp( return chunk_1 + chunk_2 else: return torch.cat([chunk_1, chunk_2], dim=0) + + +def build_gdn_cp_context(module: nn.Module, cu_seqlens: torch.Tensor, device: torch.device): + """Build fla CP context for a GatedDeltaNet module from packed sequence boundaries. + + Args: + module: GDN module with ``cp_group`` / ``cp_world_size`` / ``conv_kernel_size``. + cu_seqlens: Global packed sequence boundaries (e.g. ``packed_seq_params.cu_seqlens_q``). + device: Target device. + + Returns ``None`` when CP is not configured on the module (``cp_group`` not set). + Raises ``RuntimeError`` if hybrid CP is configured but ``fla.ops.cp`` is missing. + """ + cp_group = getattr(module, "cp_group", None) + if cp_group is None: + return None + if _fla_build_cp_context is None: + raise RuntimeError( + "Hybrid CP requires fla.ops.cp (flash-linear-attention >= 0.4.2) " "but it could not be imported." + ) + if cu_seqlens is None or cu_seqlens.numel() < 2: + raise ValueError(f"Hybrid CP requires valid cu_seqlens (at least 2 elements) but got {cu_seqlens}") + return _fla_build_cp_context( + cu_seqlens=cu_seqlens.to(device=device, dtype=torch.int32), + group=cp_group, + conv1d_kernel_size=module.conv_kernel_size, + ) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index f7726618cb..375cd6c2c2 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -2140,6 +2140,9 @@ def equal(x, y): ), ("rope_theta", "rotary_base", equal), ]: + # FIXME: Qwen3.5 transfomers has bug. + if getattr(hf_config, "model_type", "") == "qwen3_5_moe_text" and hf_config_name == "intermediate_size": + continue if hasattr(hf_config, hf_config_name): if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): errors.append( diff --git a/miles_plugins/models/cp_utils.py b/miles_plugins/models/cp_utils.py new file mode 100644 index 0000000000..87f4205f34 --- /dev/null +++ b/miles_plugins/models/cp_utils.py @@ -0,0 +1,26 @@ +import logging + +import torch.distributed as dist +import torch.nn as nn + +from miles_plugins.models.hf_attention import HuggingfaceAttention + +logger = logging.getLogger(__name__) + + +def detect_and_setup_hybrid_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> int: + """Scan for GatedDeltaNet modules and configure them for native fla CP.""" + count = 0 + for module in model.modules(): + if isinstance(module, HuggingfaceAttention): + linear_attn = getattr(module, "linear_attn", None) + if linear_attn is not None: + linear_attn.cp_group = cp_group + linear_attn.cp_rank = cp_rank + linear_attn.cp_world_size = cp_world_size + module.hybrid_cp = True + count += 1 + + if count > 0: + logger.info(f"Configured hybrid CP on {count} GDN modules (fla native state passing)") + return count diff --git a/miles_plugins/models/hf_attention.py b/miles_plugins/models/hf_attention.py index 7abe09b0ee..aeacc58ea2 100644 --- a/miles_plugins/models/hf_attention.py +++ b/miles_plugins/models/hf_attention.py @@ -38,6 +38,116 @@ def _fix_dtype(d): return ns +def _get_cp_sequence_lengths(cu_seqlens, cp_size, local_total_len=None): + global_seq_lengths = [(cu_seqlens[i + 1] - cu_seqlens[i]).item() for i in range(len(cu_seqlens) - 1)] + local_seq_lengths = [] + for global_seq_len in global_seq_lengths: + if global_seq_len % cp_size != 0: + raise ValueError(f"Expected sequence length {global_seq_len} to be divisible by cp_size={cp_size}") + local_seq_lengths.append(global_seq_len // cp_size) + + if local_total_len is not None and sum(local_seq_lengths) != local_total_len: + raise ValueError(f"Expected local total length {local_total_len}, got {sum(local_seq_lengths)}") + + return global_seq_lengths, local_seq_lengths + + +def _gather_cp_tensors(x, cp_group): + gathered = [torch.empty_like(x) for _ in range(dist.get_world_size(group=cp_group))] + dist.all_gather(gathered, x.contiguous(), group=cp_group) + return gathered + + +def _zigzag_to_packed_shard_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + """Convert zigzag ring-attn layout to the contiguous packed shard expected by fla CP.""" + global_seq_lengths, local_seq_lengths = _get_cp_sequence_lengths(cu_seqlens, cp_size, hidden_states.size(0)) + gathered_by_rank = [ + gathered.split(local_seq_lengths, dim=0) for gathered in _gather_cp_tensors(hidden_states, cp_group) + ] + + full_sequences = [] + for seq_idx, global_seq_len in enumerate(global_seq_lengths): + per_rank = [rank_seqs[seq_idx] for rank_seqs in gathered_by_rank] + if global_seq_len % (2 * cp_size) == 0: + subchunk_len = global_seq_len // (2 * cp_size) + full_seq = torch.cat( + [seq[:subchunk_len] for seq in per_rank] + [seq[subchunk_len:] for seq in per_rank][::-1], + dim=0, + ) + else: + # Final local padding is appended contiguously on each rank, not in zigzag order. + full_seq = torch.cat(per_rank, dim=0) + full_sequences.append(full_seq) + + full_stream = torch.cat(full_sequences, dim=0) if full_sequences else hidden_states[:0] + shard_len = hidden_states.size(0) + return full_stream[cp_rank * shard_len : (cp_rank + 1) * shard_len] + + +def _packed_shard_to_zigzag_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + """Convert contiguous packed shard layout back to zigzag ring-attn layout.""" + global_seq_lengths, local_seq_lengths = _get_cp_sequence_lengths(cu_seqlens, cp_size, hidden_states.size(0)) + full_stream = torch.cat(_gather_cp_tensors(hidden_states, cp_group), dim=0) + full_sequences = full_stream.split(global_seq_lengths, dim=0) + + local_sequences = [] + for full_seq, global_seq_len, local_seq_len in zip( + full_sequences, global_seq_lengths, local_seq_lengths, strict=True + ): + if global_seq_len % (2 * cp_size) == 0: + subchunk_len = global_seq_len // (2 * cp_size) + parts = full_seq.split(subchunk_len, dim=0) + local_sequences.append(torch.cat([parts[cp_rank], parts[2 * cp_size - 1 - cp_rank]], dim=0)) + else: + local_sequences.append(full_seq.split(local_seq_len, dim=0)[cp_rank]) + + return torch.cat(local_sequences, dim=0) if local_sequences else hidden_states[:0] + + +class _ZigzagToPackedShard(torch.autograd.Function): + """Convert zigzag ring-attn layout to contiguous packed shards for native fla CP.""" + + @staticmethod + def forward(ctx, hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + ctx.cp_group = cp_group + ctx.cp_rank = cp_rank + ctx.cp_size = cp_size + ctx.save_for_backward(cu_seqlens) + return _zigzag_to_packed_shard_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + @staticmethod + def backward(ctx, grad_output): + (cu_seqlens,) = ctx.saved_tensors + result = _packed_shard_to_zigzag_impl(grad_output, cu_seqlens, ctx.cp_group, ctx.cp_rank, ctx.cp_size) + return result, None, None, None, None + + +class _PackedShardToZigzag(torch.autograd.Function): + """Convert contiguous packed shards back to zigzag ring-attn layout.""" + + @staticmethod + def forward(ctx, hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + ctx.cp_group = cp_group + ctx.cp_rank = cp_rank + ctx.cp_size = cp_size + ctx.save_for_backward(cu_seqlens) + return _packed_shard_to_zigzag_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + @staticmethod + def backward(ctx, grad_output): + (cu_seqlens,) = ctx.saved_tensors + result = _zigzag_to_packed_shard_impl(grad_output, cu_seqlens, ctx.cp_group, ctx.cp_rank, ctx.cp_size) + return result, None, None, None, None + + +def _zigzag_to_packed_shard(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + return _ZigzagToPackedShard.apply(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + +def _packed_shard_to_zigzag(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size): + return _PackedShardToZigzag.apply(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size) + + class _AllGatherForDuplicatedComputation(torch.autograd.Function): """All-gather whose backward just returns the local gradient slice (no reduce). @@ -68,6 +178,10 @@ class HuggingfaceAttention(MegatronModule, ABC): "cross attn" specializations. """ + # Subclasses set this to True when the underlying module handles CP natively + # (e.g. via fla's state-passing CP for DeltaNet), bypassing the all-gather. + hybrid_cp: bool = False + def __init__( self, args, @@ -115,7 +229,22 @@ def forward( group=mpu.get_tensor_model_parallel_group(), ) - if mpu.get_context_parallel_world_size() > 1: + if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp: + cp_size = mpu.get_context_parallel_world_size() + # Native fla CP expects each rank to own a contiguous shard of the + # packed global token stream. In allgather-CP mode the data pipeline + # already provides that layout, so no extra relayout is + # needed here. + if not self.args.allgather_cp: + hidden_states = _zigzag_to_packed_shard( + hidden_states, + cu_seqlens, + mpu.get_context_parallel_group(), + mpu.get_context_parallel_rank(), + cp_size, + ) + + elif mpu.get_context_parallel_world_size() > 1: cp_size = mpu.get_context_parallel_world_size() # Use custom all-gather whose backward returns local gradient # instead of reduce-scatter, since the computation is duplicated. @@ -150,7 +279,17 @@ def forward( output = output.permute(1, 0, 2) # [seq_len, bsz, hidden_dim] - if mpu.get_context_parallel_world_size() > 1: + if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp: + if not self.args.allgather_cp: + output = _packed_shard_to_zigzag( + output, + cu_seqlens, + mpu.get_context_parallel_group(), + mpu.get_context_parallel_rank(), + cp_size, + ) + + elif mpu.get_context_parallel_world_size() > 1: cp_rank = mpu.get_context_parallel_rank() output_list = [] for i in range(len(cu_seqlens) - 1): diff --git a/miles_plugins/models/qwen3_5.py b/miles_plugins/models/qwen3_5.py index a796c8c49c..794cf73808 100644 --- a/miles_plugins/models/qwen3_5.py +++ b/miles_plugins/models/qwen3_5.py @@ -15,6 +15,8 @@ except ImportError: pass +from miles.backends.training_utils.cp_utils import build_gdn_cp_context + from .hf_attention import HuggingfaceAttention, _load_hf_config @@ -88,6 +90,8 @@ def forward( ): batch_size, seq_len, _ = hidden_states.shape + cp_context = build_gdn_cp_context(self, cu_seqlens, hidden_states.device) + # Projections (flat layout: [Q_all, K_all, V_all]) mixed_qkv = self.in_proj_qkv(hidden_states) z = self.in_proj_z(hidden_states) @@ -95,10 +99,12 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - # Convolution on the flat QKV + # Convolution on the flat QKV (pass cp_context for boundary handling) + conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens mixed_qkv, _ = self.conv1d( x=mixed_qkv, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, + cp_context=cp_context, ) # Split into Q, K, V (flat split, matching HF layout) @@ -118,17 +124,29 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - ) + if cp_context is not None: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cp_context.cu_seqlens, + cp_context=cp_context, + ) + else: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/miles_plugins/models/qwen3_next.py b/miles_plugins/models/qwen3_next.py index 92e39ff318..1dbee8acd0 100644 --- a/miles_plugins/models/qwen3_next.py +++ b/miles_plugins/models/qwen3_next.py @@ -18,6 +18,8 @@ except ImportError: pass +from miles.backends.training_utils.cp_utils import build_gdn_cp_context + from .hf_attention import HuggingfaceAttention @@ -108,6 +110,8 @@ def forward( hidden_states: torch.Tensor, cu_seqlens: torch.Tensor = None, ): + cp_context = build_gdn_cp_context(self, cu_seqlens, hidden_states.device) + projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) @@ -115,9 +119,11 @@ def forward( mixed_qkv = torch.cat((query, key, value), dim=-1) + conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens mixed_qkv, _ = self.conv1d( x=mixed_qkv, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, + cp_context=cp_context, ) query, key, value = torch.split( @@ -140,17 +146,29 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - ) + if cp_context is not None: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cp_context.cu_seqlens, + cp_context=cp_context, + ) + else: + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py b/scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py new file mode 100644 index 0000000000..ee70d5f8b0 --- /dev/null +++ b/scripts/run_qwen3_5_35b_a3b_mtp_cp2_ep8.py @@ -0,0 +1,178 @@ +from dataclasses import dataclass +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = U.create_run_id() + model_name: str = "Qwen3.5-35B-A3B" + megatron_model_type: str = "qwen3.5-35B-A3B" + num_gpus_per_node: int = 8 + hardware: Literal["H200"] = "H200" + enable_eval: bool = True + extra_args: str = "" + data_dir: str = "/root/datasets" + model_dir: str = "/root/models" + megatron_path: str = "/root/Megatron-LM" + + +def prepare(args: ScriptArgs): + U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") + U.exec_command("pip install transformers==5.2.0") + U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}") + U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) + U.hf_download_dataset("zhuzilin/aime-2024", data_dir=args.data_dir) + + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + dir_dst=args.model_dir, + hf_checkpoint=f"{args.model_dir}/{args.model_name}", + megatron_path=args.megatron_path, + ) + + +def execute(args: ScriptArgs): + ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist" + load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints" + + ckpt_args = ( + f"--hf-checkpoint {args.model_dir}/{args.model_name} " + f"--ref-load {ref_load_path} " + f"--load {load_save_path} " + f"--save {load_save_path} " + f"--save-interval {2 if args.mode == 'debug_minimal' else 20} " + ) + + rollout_args = ( + f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + f"--num-rollout {64 if args.mode == 'debug_minimal' else 3000} " + f"--rollout-batch-size {8 if args.mode == 'debug_minimal' else 32} " + f"--n-samples-per-prompt {2 if args.mode == 'debug_minimal' else 8} " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + f"--global-batch-size {16 if args.mode == 'debug_minimal' else 256} " + "--balance-data " + ) + + eval_args = "" + if (args.mode != "debug_minimal") and args.enable_eval: + eval_args += ( + "--eval-interval 20 " + f"--eval-prompt-data aime {args.data_dir}/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 16 " + "--eval-max-response-len 16384 " + "--eval-top-p 1 " + ) + + # CP=2 EP=8: validated on 8x H200 + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + "--sglang-mem-fraction-static 0.7 " + "--sglang-ep-size 8 " + "--sglang-cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 " + # mtp speculative decoding + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + "--sglang-max-running-requests 512 " + "--sglang-mamba-scheduler-strategy extra_buffer " + ) + + mtp_args = "--enable-mtp-training " "--mtp-num-layers 1 " "--mtp-loss-scaling-factor 0.2 " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--moe-token-dispatcher-type flex " + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{misc_args} " + f"{args.extra_args} " + ) + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + extra_env_vars={ + "SGLANG_ENABLE_SPEC_V2": "1", + }, + megatron_path=args.megatron_path, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py b/tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py new file mode 100644 index 0000000000..f951cf4f3a --- /dev/null +++ b/tests/e2e/megatron/test_qwen3_5_35B_A3B_cp.py @@ -0,0 +1,153 @@ +"""E2E test for Qwen3.5-35B-A3B with Context Parallel (CP=2 and CP=4). + +Validates that GDN layers use real fla native CP (state passing) instead of +duplicated all-gather computation. See: https://github.com/radixark/miles/issues/878 +""" + +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3.5-35B-A3B" +MODEL_TYPE = "qwen3.5-35B-A3B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def _execute_with_cp(cp_size: int): + """Run a short training loop with the given context-parallel size.""" + assert NUM_GPUS % cp_size == 0 + ep_size = NUM_GPUS // cp_size + + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + f"--context-parallel-size {cp_size} " + f"--expert-model-parallel-size {ep_size} " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + "--sglang-mem-fraction-static 0.7 " + f"--sglang-ep-size {NUM_GPUS} " + "--sglang-max-running-requests 512 " + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + ) + + mtp_args = "--enable-mtp-training " "--mtp-num-layers 1 " "--mtp-loss-scaling-factor 0.2 " + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + "--moe-token-dispatcher-type flex " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + ) + + +def execute_cp2(): + """Qwen3.5-35B-A3B with CP=2.""" + _execute_with_cp(cp_size=2) + + +def execute_cp4(): + """Qwen3.5-35B-A3B with CP=4.""" + _execute_with_cp(cp_size=4) + + +if __name__ == "__main__": + cp_size = int(os.environ.get("CP_SIZE", "2")) + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + _execute_with_cp(cp_size) diff --git a/tests/test_qwen3_5_mtp_bridge_mapping.py b/tests/e2e/megatron/test_qwen3_5_mtp_bridge_mapping.py similarity index 100% rename from tests/test_qwen3_5_mtp_bridge_mapping.py rename to tests/e2e/megatron/test_qwen3_5_mtp_bridge_mapping.py diff --git a/tests/e2e/precision/test_hf_attention_cp_relayout.py b/tests/e2e/precision/test_hf_attention_cp_relayout.py new file mode 100644 index 0000000000..39ab46f915 --- /dev/null +++ b/tests/e2e/precision/test_hf_attention_cp_relayout.py @@ -0,0 +1,101 @@ +"""Distributed correctness test for zigzag <-> packed-shard hybrid CP relayout. + +Run with: + torchrun --nproc_per_node=2 tests/e2e/precision/test_hf_attention_cp_relayout.py + torchrun --nproc_per_node=4 tests/e2e/precision/test_hf_attention_cp_relayout.py +""" + +import os +import sys + +import torch +import torch.distributed as dist + +from miles_plugins.models.hf_attention import _packed_shard_to_zigzag, _zigzag_to_packed_shard + + +def setup_dist(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return rank, world_size, local_rank + + +def _make_subchunk(sample_id: int, sub_id: int, chunk_len: int, device: torch.device) -> torch.Tensor: + base = sample_id * 1000 + sub_id * 100 + values = torch.arange(base, base + chunk_len, device=device, dtype=torch.float32) + return values.view(-1, 1, 1) + + +def _build_rank_inputs(rank: int, world_size: int, device: torch.device): + chunk_lens = [3, 5] + tail_pad_local_len = 3 + zigzag_chunks = [] + full_sequences = [] + cu = [0] + + for sample_id, chunk_len in enumerate(chunk_lens): + subchunks = [_make_subchunk(sample_id, sub_id, chunk_len, device) for sub_id in range(2 * world_size)] + zigzag_chunks.extend([subchunks[rank], subchunks[2 * world_size - 1 - rank]]) + full_sequences.append(torch.cat(subchunks, dim=0)) + cu.append(cu[-1] + 2 * world_size * chunk_len) + + tail_pad = (rank * 10000 + torch.arange(tail_pad_local_len, device=device, dtype=torch.float32)).view(-1, 1, 1) + zigzag_chunks.append(tail_pad) + full_sequences.append( + torch.cat( + [ + (r * 10000 + torch.arange(tail_pad_local_len, device=device, dtype=torch.float32)).view(-1, 1, 1) + for r in range(world_size) + ], + dim=0, + ) + ) + cu.append(cu[-1] + world_size * tail_pad_local_len) + + zigzag = torch.cat(zigzag_chunks, dim=0).requires_grad_(True) + packed_full = torch.cat(full_sequences, dim=0) + local_len = zigzag.size(0) + packed_shard = packed_full[rank * local_len : (rank + 1) * local_len] + cu_seqlens = torch.tensor(cu, device=device, dtype=torch.int32) + return zigzag, packed_shard, cu_seqlens + + +def test_relayout(rank: int, world_size: int): + device = torch.device(f"cuda:{rank}") + cp_group = dist.group.WORLD + + zigzag, expected_packed_shard, cu_seqlens = _build_rank_inputs(rank, world_size, device) + + packed_shard = _zigzag_to_packed_shard(zigzag, cu_seqlens, cp_group, rank, world_size) + roundtrip = _packed_shard_to_zigzag(packed_shard, cu_seqlens, cp_group, rank, world_size) + + packed_ok = torch.equal(packed_shard, expected_packed_shard) + roundtrip_ok = torch.equal(roundtrip, zigzag) + + loss = roundtrip.sum() + loss.backward() + grad_ok = torch.equal(zigzag.grad, torch.ones_like(zigzag)) + + passed = packed_ok and roundtrip_ok and grad_ok + if rank == 0: + print(f"\n=== HF Attention Hybrid CP Relayout Test CP={world_size} ===") + print(f"zigzag->packed PASS: {packed_ok}") + print(f"roundtrip PASS: {roundtrip_ok}") + print(f"backward PASS: {grad_ok}") + if not passed: + sys.exit(1) + + +def main(): + rank, world_size, _ = setup_dist() + try: + test_relayout(rank, world_size) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/precision/test_qwen3_5_cp_correctness.py b/tests/e2e/precision/test_qwen3_5_cp_correctness.py new file mode 100644 index 0000000000..d0a2f3f32b --- /dev/null +++ b/tests/e2e/precision/test_qwen3_5_cp_correctness.py @@ -0,0 +1,147 @@ +"""Correctness test for Qwen3.5 GDN with native fla Context Parallel. + +Run with: + torchrun --nproc_per_node=2 tests/test_qwen3_5_cp_correctness.py # CP=2 + torchrun --nproc_per_node=4 tests/test_qwen3_5_cp_correctness.py # CP=4 + +Validates that GDN forward+backward with native fla CP produces results +consistent with the non-CP (single-rank full-sequence) baseline. +""" + +import os +import sys + +import torch +import torch.distributed as dist + + +def setup_dist(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return rank, world_size, local_rank + + +def build_gdn_module(device, dtype=torch.bfloat16): + """Build a small Qwen3.5 GDN module for testing.""" + + class FakeConfig: + hidden_size = 256 + linear_num_value_heads = 4 + linear_num_key_heads = 2 + linear_key_head_dim = 64 + linear_value_head_dim = 64 + linear_conv_kernel_dim = 4 + hidden_act = "silu" + rms_norm_eps = 1e-6 + + FakeConfig.dtype = dtype + + from miles_plugins.models.qwen3_5 import Qwen3_5GatedDeltaNet + + return Qwen3_5GatedDeltaNet(FakeConfig, layer_idx=0).to(device=device, dtype=dtype) + + +def test_cp_forward_backward(rank, world_size): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 + + # ---- Reference: full sequence on rank 0 (no CP) ---- + torch.manual_seed(42) + model_ref = build_gdn_module(device, dtype) + + total_seq_len = 128 * world_size # must be divisible by world_size + batch = 1 + + torch.manual_seed(123) + full_hidden = torch.randn(batch, total_seq_len, 256, device=device, dtype=dtype, requires_grad=True) + full_cu = torch.tensor([0, total_seq_len], dtype=torch.int32, device=device) + + # Forward without CP + ref_out = model_ref(full_hidden, cu_seqlens=full_cu) + ref_loss = ref_out.sum() + ref_loss.backward() + ref_grad = full_hidden.grad.clone() + + # ---- Test: CP across ranks ---- + torch.manual_seed(42) + model_cp = build_gdn_module(device, dtype) + # Copy weights from ref to ensure identical params + model_cp.load_state_dict(model_ref.state_dict()) + + # Set up CP context on the module + cp_group = dist.group.WORLD + model_cp.cp_group = cp_group + model_cp.cp_rank = rank + model_cp.cp_world_size = world_size + + # Each rank gets its local chunk + local_seq_len = total_seq_len // world_size + start = rank * local_seq_len + end = start + local_seq_len + + torch.manual_seed(123) + full_hidden_cp = torch.randn(batch, total_seq_len, 256, device=device, dtype=dtype) + local_hidden = full_hidden_cp[:, start:end, :].clone().contiguous().requires_grad_(True) + + # Global cu_seqlens (build_gdn_cp_context expects global boundaries) + global_cu = torch.tensor([0, total_seq_len], dtype=torch.int32, device=device) + + # Forward with CP + cp_out = model_cp(local_hidden, cu_seqlens=global_cu) + cp_loss = cp_out.sum() + + # Reduce loss across ranks to match reference + dist.all_reduce(cp_loss, op=dist.ReduceOp.SUM) + + cp_loss.backward() + + # ---- Gather outputs for comparison ---- + gathered_out = [torch.zeros_like(cp_out) for _ in range(world_size)] + dist.all_gather(gathered_out, cp_out.contiguous()) + full_cp_out = torch.cat(gathered_out, dim=1) + + gathered_grad = [torch.zeros_like(local_hidden.grad) for _ in range(world_size)] + dist.all_gather(gathered_grad, local_hidden.grad.contiguous()) + full_cp_grad = torch.cat(gathered_grad, dim=1) + + if rank == 0: + # Compare outputs + out_diff = (ref_out.detach().float() - full_cp_out.detach().float()).abs() + out_max_diff = out_diff.max().item() + out_rel_diff = (out_diff / (ref_out.detach().float().abs() + 1e-8)).max().item() + + # Compare gradients + grad_diff = (ref_grad.float() - full_cp_grad.float()).abs() + grad_max_diff = grad_diff.max().item() + grad_rel_diff = (grad_diff / (ref_grad.float().abs() + 1e-8)).max().item() + + print(f"\n=== CP={world_size} Correctness Test ===") + print(f"Forward max abs diff: {out_max_diff:.6e} max rel diff: {out_rel_diff:.6e}") + print(f"Backward max abs diff: {grad_max_diff:.6e} max rel diff: {grad_rel_diff:.6e}") + + # bf16 tolerance: 1e-2 is generous for bf16 accumulated ops + fwd_ok = out_max_diff < 1e-2 + bwd_ok = grad_max_diff < 1e-2 + print(f"Forward PASS: {fwd_ok}") + print(f"Backward PASS: {bwd_ok}") + + if not (fwd_ok and bwd_ok): + print("FAILED!") + sys.exit(1) + else: + print(f"CP={world_size} test PASSED!") + + +def main(): + rank, world_size, _ = setup_dist() + try: + test_cp_forward_backward(rank, world_size) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/fast/backends/training_utils/__init__.py b/tests/fast/backends/training_utils/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/fast/backends/training_utils/__init__.py @@ -0,0 +1 @@ + From 82bf1962b2c69d5cc5ae4144f6dfd59f002a201d Mon Sep 17 00:00:00 2001 From: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com> Date: Sun, 12 Apr 2026 21:31:41 -0700 Subject: [PATCH 32/44] [AMD] Upgrade to sglv0.5.10 (#973) --- docker/Dockerfile.rocm_MI350-5 | 396 ++++----- docker/amd_patch/latest/megatron.patch | 41 +- docker/amd_patch/latest/sglang.patch | 38 - .../amd_megatron_fused_kernels_init.patch | 51 -- docker/amd_patch/sglv0.5.0rc0/megatron.patch | 792 ------------------ docker/amd_patch/sglv0.5.0rc0/sglang.patch | 203 ----- docker/amd_patch/sglv0.5.10/megatron.patch | 20 + docker/amd_patch/sglv0.5.7/megatron.patch | 51 -- docker/amd_patch/sglv0.5.7/sglang.patch | 38 - .../run-qwen3-4B-amd.sh} | 57 +- scripts/run-llama3.2-3B-Instruct-amd.sh | 180 ---- scripts/run-qwen3-4B-amd.sh | 161 ---- scripts/run-qwen3-8B-amd.sh | 194 ----- 13 files changed, 192 insertions(+), 2030 deletions(-) delete mode 100644 docker/amd_patch/latest/sglang.patch delete mode 100644 docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch delete mode 100644 docker/amd_patch/sglv0.5.0rc0/megatron.patch delete mode 100644 docker/amd_patch/sglv0.5.0rc0/sglang.patch create mode 100644 docker/amd_patch/sglv0.5.10/megatron.patch delete mode 100644 docker/amd_patch/sglv0.5.7/megatron.patch delete mode 100644 docker/amd_patch/sglv0.5.7/sglang.patch rename scripts/{run-qwen3-30B-A3B.sh => amd/run-qwen3-4B-amd.sh} (67%) delete mode 100644 scripts/run-llama3.2-3B-Instruct-amd.sh delete mode 100755 scripts/run-qwen3-4B-amd.sh delete mode 100644 scripts/run-qwen3-8B-amd.sh diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 index 016107b92c..bd34fd8ced 100644 --- a/docker/Dockerfile.rocm_MI350-5 +++ b/docker/Dockerfile.rocm_MI350-5 @@ -1,169 +1,156 @@ -#### Use the base image for ROCm 7 / gfx950 (MI355) - -# ===================================================================== -# Docker Image Version Information (Updated: Feb 5, 2026) -# ===================================================================== -# Base image: ROCm 7 with vllm pre-built for gfx950 -# Target GPU: MI355 (gfx950) -# -# Key Dependencies: -# - sglang: sglang-miles branch -# - sgl_kernel: built from selected sglang commit -# - Megatron-LM: radixark/Megatron-LM -# - TransformerEngine: commit 90c04bcdc3c109505b318f40a39680263af55edf -# - aiter: v0.1.10.post3 -# - Ray: 2.47.1 -# -# Patches: amd_patch/sglv0.5.7/ -# - megatron.patch -# - sglang.patch -# ===================================================================== - - -FROM rocm/sgl-dev:rocm7-vllm-20250904 +# 1. rlsys/miles:MI350-355-latest +# build-arg:SGLANG_IMAGE_TAG=v0.5.10-rocm720-mi35x + +ARG SGLANG_IMAGE_TAG=v0.5.10-rocm720-mi35x +FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang SHELL ["/bin/bash", "-ceuxo", "pipefail"] -ARG MAX_JOBS=128 -ARG SGLANG_REPO=sgl-project/sglang +# ======================================== Arguments ============================================= + ARG SGLANG_BRANCH=sglang-miles ARG SGLANG_COMMIT="" + ARG MEGATRON_REPO=radixark/Megatron-LM ARG MEGATRON_BRANCH=miles-main -ARG MEGATRON_COMMIT="" -ENV MAX_JOBS=${MAX_JOBS} -# Set environment variables for gfx950 -ENV GPU_ARCH=gfx950 -ENV PYTORCH_ROCM_ARCH=gfx950 -ENV GPU_ARCH_LIST=gfx950 -ENV AMDGPU_TARGET=gfx950 +ARG MILES_COMMIT=main +ARG GPU_ARCH=gfx950 +ARG MAX_JOBS=128 -########################################### -##############1. Install AITER############# -########################################### -WORKDIR /app +ARG AITER_REPO=https://github.com/ROCm/aiter.git +ARG AITER_COMMIT=v0.1.11.post1 -RUN pip uninstall -y aiter || true -RUN rm -rf aiter -RUN git clone https://github.com/ROCm/aiter.git \ - && cd aiter \ - && git checkout v0.1.10.post3 \ - && curl -fsSL https://patch-diff.githubusercontent.com/raw/ROCm/aiter/pull/2075.patch -o /tmp/aiter-pr2075.patch \ - && git apply --3way /tmp/aiter-pr2075.patch \ - && rm -f /tmp/aiter-pr2075.patch \ - && git submodule sync --recursive \ - && git submodule update --init --recursive \ - && GPU_ARCHS=gfx950 python setup.py develop -########################################### -########################################### -########################################### - - -########################################### -####2. Install TransformerEngine for gfx950 -########################################### -WORKDIR /app - -RUN rm -rf TransformerEngine -RUN git clone https://github.com/ROCm/TransformerEngine.git \ - && cd TransformerEngine \ - && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ - && git submodule update --init --recursive +ARG RCCL_TESTS_REPO=https://github.com/ROCm/rocm-systems.git +ARG RCCL_TESTS_BRANCH=develop +ARG RCCL_TESTS_PATH=projects/rccl-tests + +ARG TRANSFORMER_ENGINE_REPO=https://github.com/ROCm/TransformerEngine.git +ARG TRANSFORMER_ENGINE_BRANCH=v2.8_rocm + +# ======================================== Setup ============================================= +WORKDIR /root/ + +ENV MAX_JOBS=${MAX_JOBS} + +# Build configuration for MI350 / gfx950. +ENV GPU_ARCH=${GPU_ARCH} +ENV PYTORCH_ROCM_ARCH=${GPU_ARCH} +ENV GPU_ARCH_LIST=${GPU_ARCH} +ENV AMDGPU_TARGET=${GPU_ARCH} + +# Transformer Engine build knobs for the v2.8_rocm branch. ENV NVTE_FRAMEWORK=pytorch -ENV NVTE_ROCM_ARCH=gfx950 +ENV NVTE_ROCM_ARCH=${GPU_ARCH} ENV NVTE_USE_HIPBLASLT=1 ENV NVTE_USE_ROCM=1 -ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - -RUN cd TransformerEngine && pip install . -v -########################################### -########################################### -########################################### - - -######################################### -####3. Install Megatron-LM -######################################### -WORKDIR /app - -RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall - -RUN pip uninstall -y megatron-core || true -RUN rm -rf Megatron-LM -RUN git clone https://github.com/${MEGATRON_REPO}.git \ - && cd Megatron-LM \ - && git fetch origin ${MEGATRON_BRANCH} \ - && if [ -n "${MEGATRON_COMMIT}" ]; then \ - git checkout ${MEGATRON_COMMIT}; \ - else \ - git checkout FETCH_HEAD; \ - fi \ - && pip install -e . -######################################### -######################################### -######################################### - - -######################################## -############ 4. Install mbridge######### -######################################## -RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps -######################################## -######################################## -######################################## - - -######################################## -######5. Install Ray#################### -######################################## -RUN pip uninstall ray -y || true -RUN pip install "ray[data,train,tune,serve]==2.47.1" -######################################## -######################################## -######################################## - - -######################################### -###6. Install torch_memory_saver######### -######################################### -RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@64a92e1d7fb822ea4af5579c8cebb162692c531c --no-cache-dir --force-reinstall -######################################### -######################################### - - -####################################### -####7. Install Apex for ROCm########### -####################################### -WORKDIR /app - -RUN pip uninstall -y apex || true -RUN rm -rf apex -RUN git clone https://github.com/ROCm/apex.git \ - && cd apex \ - && python setup.py install -####################################### -####################################### -####################################### - - -######################################## -###8. Install miles agent framework deps -######################################## -RUN pip install pydra_config==0.0.15 -RUN pip install together -RUN pip install google-generativeai -RUN pip install tensorboard -######################################## -######################################## -######################################## - - -######################################## -###9. Set performance environment vars## -######################################## +# Keep the core package enabled and skip the extra fused-attn kernel matrix rebuild. +ENV NVTE_FUSED_ATTN=0 +ENV CMAKE_PREFIX_PATH=/opt/rocm:/opt/rocm/hip:/usr/local:/usr + +# Patch Megatron's fused-kernel init for this toolchain. +COPY docker/amd_patch/latest/megatron.patch /tmp/amd_patch/megatron.patch +COPY requirements.txt /tmp/requirements.txt + +# ======================================== Apt dependencies ============================================= + +RUN apt update +# Install build tools and diagnostics utilities. +RUN apt install -y build-essential cmake dnsutils ethtool git nvtop rsync + +# Build rccl-tests diagnostics binaries. +RUN git clone --depth 1 --branch ${RCCL_TESTS_BRANCH} ${RCCL_TESTS_REPO} /tmp/rocm-systems && \ + make -C /tmp/rocm-systems/${RCCL_TESTS_PATH} -j$(nproc) \ + HIP_HOME=/opt/rocm \ + NCCL_HOME=/opt/rocm \ + GPU_TARGETS=${GPU_ARCH} && \ + cp /tmp/rocm-systems/${RCCL_TESTS_PATH}/build/*_perf /usr/local/bin/ && \ + rm -rf /tmp/rocm-systems + +# ====================================== Python dependencies ============================================ + +# Rebuild AITER at the version paired with SGLang. +RUN pip uninstall -y aiter || true +RUN pip install flydsl==0.0.1.dev95158637 psutil pybind11 +RUN cd /sgl-workspace/aiter && \ + git remote set-url origin ${AITER_REPO} && \ + git checkout ${AITER_COMMIT} && \ + git reset --hard ${AITER_COMMIT} && \ + git clean -fdx && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + # Temporary fixes for the current ROCm 7.2 image/toolchain combination. + sed -i '459 s/if.*:/if False:/' aiter/ops/triton/attention/pa_mqa_logits.py && \ + sed -i '/c1 = torch.empty((M, D, S1 + S3), dtype=dtype, device=x.device)/i\ config = dict(config)' \ + aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py && \ + GPU_ARCHS=${GPU_ARCH} pip install -e . + +# Install Transformer Engine from the requested branch. +RUN pip uninstall -y transformer-engine transformer_engine transformer_engine_torch || true +RUN rm -rf /root/TransformerEngine && \ + git clone --recursive --branch ${TRANSFORMER_ENGINE_BRANCH} ${TRANSFORMER_ENGINE_REPO} /root/TransformerEngine && \ + cd /root/TransformerEngine && \ + pip install . --no-build-isolation -v + +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +RUN GPU_ARCHS=${GPU_ARCH} BUILD_TARGET=rocm MAX_JOBS=${MAX_JOBS} \ + pip -v install flash-attn==2.8.3 --no-build-isolation + +RUN pip install flash-linear-attention==0.4.2 + +RUN rm -rf /root/Megatron-LM && \ + git clone --recursive -b ${MEGATRON_BRANCH} https://github.com/${MEGATRON_REPO}.git /root/Megatron-LM && \ + cd /root/Megatron-LM && \ + git apply /tmp/amd_patch/megatron.patch && \ + pip install -e . + +RUN pip uninstall -y sgl_kernel sglang || true +RUN cd /sgl-workspace/sglang && \ + git reset --hard && \ + git clean -fdx && \ + git fetch origin ${SGLANG_BRANCH} && \ + if [ -n "${SGLANG_COMMIT}" ]; then \ + git checkout ${SGLANG_COMMIT}; \ + else \ + git checkout FETCH_HEAD; \ + fi && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + cd sgl-kernel && \ + rm -f pyproject.toml && \ + mv pyproject_rocm.toml pyproject.toml && \ + AMDGPU_TARGET=${GPU_ARCH} python setup_rocm.py install && \ + cd .. && \ + rm -rf python/pyproject.toml && \ + mv python/pyproject_other.toml python/pyproject.toml && \ + pip install -e "python[all_hip]" --no-deps + +RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" + +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@d64a639 --no-cache-dir --force-reinstall +RUN pip install git+https://github.com/yushengsu-thu/Megatron-Bridge.git@merged-megatron-0.16.0rc0-miles --no-deps --no-build-isolation +RUN pip install megatron-energon --no-deps +RUN pip install multi-storage-client --no-deps + +RUN rm -rf /usr/lib/python3/dist-packages/jwt /usr/lib/python3/dist-packages/PyJWT* && \ + pip install -r /tmp/requirements.txt + +# Pin numpy 1.x for Megatron compatibility. +RUN pip install "numpy<2" + +# ====================================== Install main package ============================================ + +RUN git clone https://github.com/radixark/miles.git /root/miles && \ + cd /root/miles && \ + git checkout ${MILES_COMMIT} && \ + pip install -e . --no-deps + +# ====================================== Runtime knobs ============================================ + +# Runtime knobs consumed by the current SGLang/PyTorch stack. ENV HIP_FORCE_DEV_KERNARG=1 ENV HSA_NO_SCRATCH_RECLAIM=1 ENV SGLANG_USE_AITER=1 @@ -173,114 +160,11 @@ ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 ENV SGLANG_USE_ROCM700A=1 ENV NCCL_MIN_NCHANNELS=112 -ENV VLLM_FP8_PADDING=1 -ENV VLLM_FP8_ACT_PADDING=1 -ENV VLLM_FP8_WEIGHT_PADDING=1 -ENV VLLM_FP8_REDUCE_CONV=1 ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 -######################################## -######################################## -######################################## - -########################################### -##############Install SGLang############### -########################################### -WORKDIR /app - -# Install prerequisites -RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 - -# Clone SGLang -RUN pip uninstall -y sgl_kernel sglang || true -RUN rm -rf sglang -RUN git clone https://github.com/${SGLANG_REPO}.git \ - && cd sglang \ - && git fetch origin ${SGLANG_BRANCH} \ - && if [ -n "${SGLANG_COMMIT}" ]; then \ - git checkout ${SGLANG_COMMIT}; \ - else \ - git checkout FETCH_HEAD; \ - fi - -# Build sgl-kernel for gfx950 -RUN cd sglang/sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && AMDGPU_TARGET=gfx950 python setup_rocm.py install - -# Install SGLang -RUN cd sglang \ - && rm -rf python/pyproject.toml \ - && mv python/pyproject_other.toml python/pyproject.toml \ - && pip install -e "python[all_hip]" - -# Test SGLang installation -RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" +RUN rm -rf /root/.cache/pip /root/TransformerEngine /tmp/amd_patch -RUN python -m pip cache purge -########################################### -########################################### -########################################### - - -########################################### -#### APPLY PATCHES (gfx950/MI355) ######### -########################################### - -# Copy patch from miles repo -COPY amd_patch/sglv0.5.7/megatron.patch /app/patch/megatron.patch -COPY amd_patch/sglv0.5.7/sglang.patch /app/patch/sglang.patch - -# Apply Megatron patches -RUN cd /app/Megatron-LM \ - && git apply --3way /app/patch/megatron.patch \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "Patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi \ - && pip install -e . -v - -# Apply SGLang patch -RUN cd /app/sglang \ - && git apply --3way /app/patch/sglang.patch \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi - -# Copy MOE configs for gfx950/MI355 -RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ - /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' 2>/dev/null | while read f; do \ - cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ - cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ -done - -########################################### -########################################### -########################################### - - -######################################## -#### Install additional packages######## -######################################## -RUN pip install sglang-router --force-reinstall -######################################## -######################################## -######################################## - - -######################################## -# Fix click/ray incompatibility with Python 3.10 -######################################## -RUN pip install click==8.2.1 -######################################## -######################################## -######################################## - - -WORKDIR /app +WORKDIR /root/ CMD ["/usr/bin/bash"] diff --git a/docker/amd_patch/latest/megatron.patch b/docker/amd_patch/latest/megatron.patch index f6efca346d..acd64149b7 100644 --- a/docker/amd_patch/latest/megatron.patch +++ b/docker/amd_patch/latest/megatron.patch @@ -1,5 +1,4 @@ diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 --- a/megatron/legacy/fused_kernels/__init__.py +++ b/megatron/legacy/fused_kernels/__init__.py @@ -3,6 +3,7 @@ @@ -10,42 +9,12 @@ index 87cceac3..ac686d74 100644 from torch.utils import cpp_extension -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" +@@ -15,6 +16,8 @@ def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') ++ if not torch.version.cuda: ++ return -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] diff --git a/docker/amd_patch/latest/sglang.patch b/docker/amd_patch/latest/sglang.patch deleted file mode 100644 index b103263070..0000000000 --- a/docker/amd_patch/latest/sglang.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -index 6e7ea07e7..73b512f51 100644 ---- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -@@ -64,6 +64,7 @@ class CustomAllreduce: - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=_MAX_CAR_SIZE, -+ enable_register_for_capturing: bool = True, - ) -> None: - """ - Args: -@@ -410,6 +411,8 @@ class CustomAllreduce: - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - if _is_hip: -+ if self.tms_cudagraph: -+ return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) -diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index c3ca1e4f3..2bb763b6a 100644 ---- a/python/sglang/srt/distributed/parallel_state.py -+++ b/python/sglang/srt/distributed/parallel_state.py -@@ -351,10 +351,12 @@ class GroupCoordinator: - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - try: -+ tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( - group=self.cpu_group, - device=self.device, -+ enable_register_for_capturing=not tms_cudagraph, - ) - except Exception as e: - logger.warning( diff --git a/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch b/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch deleted file mode 100644 index f6efca346d..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 ---- a/megatron/legacy/fused_kernels/__init__.py -+++ b/megatron/legacy/fused_kernels/__init__.py -@@ -3,6 +3,7 @@ - import os - import pathlib - import subprocess -+import torch - - from torch.utils import cpp_extension - -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - - def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') - -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.0rc0/megatron.patch b/docker/amd_patch/sglv0.5.0rc0/megatron.patch deleted file mode 100644 index b129959aff..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/megatron.patch +++ /dev/null @@ -1,792 +0,0 @@ -diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py -index 41c21d93d..ef80f72d6 100644 ---- a/megatron/core/dist_checkpointing/strategies/common.py -+++ b/megatron/core/dist_checkpointing/strategies/common.py -@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): - msc = MultiStorageClientFeature.import_package() - return msc.torch.load(load_path, map_location='cpu') - else: -- return torch.load(load_path, map_location='cpu') -+ return torch.load(load_path, map_location='cpu', weights_only=False) - except FileNotFoundError as e: - err_msg = f'Common file {load_path} does not exist' - if MultiStorageClientFeature.is_enabled(): -diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py -index 5a1ea308d..aa701237f 100644 ---- a/megatron/core/dist_checkpointing/strategies/torch.py -+++ b/megatron/core/dist_checkpointing/strategies/torch.py -@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - def _validate_global_shapes(self, metadata, sharded_tensors): - for sh_ten in sharded_tensors: - if sh_ten.key not in metadata.state_dict_metadata: -- raise KeyError( -- f"{sh_ten.key} from model not in state dict:" -- f" {sorted(metadata.state_dict_metadata.keys())}" -- ) -+ # raise KeyError( -+ # f"{sh_ten.key} from model not in state dict:" -+ # f" {sorted(metadata.state_dict_metadata.keys())}" -+ # ) -+ print(f"{sh_ten.key} from model not in state dict, will skip") -+ continue - loaded_shape = metadata.state_dict_metadata[sh_ten.key].size - expected_shape = self._expected_shape(sh_ten) - if loaded_shape != expected_shape: -@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - tensor_metadata = self.metadata.state_dict_metadata - metadata_with_sizes = [ - (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) -- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() -+ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata - ] - try: - # Temporarily set sizes to expected shapes -@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): - planner=MCoreLoadPlanner( - shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, - allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, -+ allow_partial_load=True, - ), - ) - -diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b43..4451f2776 100644 ---- a/megatron/core/distributed/__init__.py -+++ b/megatron/core/distributed/__init__.py -@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads - from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel - from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel - from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig -+ -+# Backward compatibility patch for FSDP module reorganization -+import sys -+import importlib.util -+ -+spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') -+if spec: -+ custom_fsdp = importlib.util.module_from_spec(spec) -+ spec.loader.exec_module(custom_fsdp) -+ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp -+ if hasattr(custom_fsdp, 'MegatronFSDP'): -+ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP -diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index acb93ef78..d239db4ab 100644 ---- a/megatron/core/extensions/transformer_engine.py -+++ b/megatron/core/extensions/transformer_engine.py -@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): - ) - - for param in self.parameters(): -+ setattr(param, "parallel_mode", parallel_mode) - if is_expert: - # Reduce the gradient on the expert_data_parallel group for expert linear layers - setattr(param, "allreduce", not self.expert_parallel) -@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): - - - if HAVE_TE and is_te_min_version("1.9.0.dev0"): -+ def ceil_div(x: int, y: int) -> int: -+ return (x + y - 1) // y -+ -+ class _FakeInt4QuantizationSTE(torch.autograd.Function): -+ @staticmethod -+ def forward(ctx, x, group_size): -+ m, n = x.shape -+ block_size_m, block_size_n = 1, group_size -+ -+ -+ m_padded = ceil_div(m, block_size_m) * block_size_m -+ n_padded = ceil_div(n, block_size_n) * block_size_n -+ -+ x_padded = torch.zeros( -+ (m_padded, n_padded), -+ dtype=x.dtype, device=x.device -+ ) -+ x_padded[:m, :n] = x -+ -+ x_view = x_padded.view( -+ m_padded // block_size_m, -+ block_size_m, -+ n_padded // block_size_n, -+ block_size_n -+ ) -+ -+ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) -+ q_max = 7 -+ x_scale = x_max / q_max -+ -+ x_scale = x_scale.clamp(min=1e-5) -+ -+ x_div = x_view / x_scale -+ x_round = torch.round(x_div) -+ -+ x_q_clamped = x_round.clamp(-q_max, q_max) -+ -+ x_dequant_view = x_q_clamped * x_scale -+ -+ x_dequant_full = x_dequant_view.view_as(x_padded) -+ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) -+ -+ return x_out -+ -+ @staticmethod -+ def backward(ctx, grad_output): -+ return grad_output, None -+ -+ def fake_int4_quantization_ste(x, group_size): -+ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) -+ -+ if hasattr(x, 'main_grad'): -+ x_out.main_grad = x.main_grad -+ -+ return x_out - - class TEGroupedLinear(te.pytorch.GroupedLinear): - """ -@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) -+ - out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - -@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - return out - return out, None - -+ def _get_weight_tensors(self): -+ """Get the weight tensors of the module.""" -+ weight_tensors = super()._get_weight_tensors() -+ -+ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": -+ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) -+ -+ weight_tensors = [ -+ fake_int4_quantization_ste(w, group_size) -+ for w in weight_tensors -+ ] -+ -+ return weight_tensors -+ - def _encode_extra_state(self, state): - # TE 2.0 changed the format of extra_state to be a byte tensor - if is_te_min_version("2.0.0"): -diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -index 1fd5dcfae..c9aeef1f0 100644 ---- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py -+++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( - cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - -- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads -- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads -- mask = kv_off < head_num * stride_kv_nheads -- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] -- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] -- k = tl.load(KV_ptr + k_in_off, mask=mask) -- v = tl.load(KV_ptr + v_in_off, mask=mask) -+ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ k_off = ki_range * stride_kv_nheads + kj_range -+ if v_dim > 0: -+ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ v = tl.load(KV_ptr + v_off, mask=mask_v) -+ else: -+ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) -+ k = tl.load(KV_ptr + k_off, mask=mask_k) - -- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads -- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads -+ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads -+ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads - -- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] -- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] -- tl.store(K_ptr + k_out_off, k, mask=mask) -- tl.store(V_ptr + v_out_off, v, mask=mask) -+ k_out_off = ki_range * stride_k_nheads + kj_range -+ tl.store(K_ptr + k_out_off, k, mask=mask_k) -+ if v_dim > 0: -+ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] -+ tl.store(V_ptr + v_out_off, v, mask=mask_v) - - EMB = K_POS_EMB + pid_m * stride_emb_seq - # x1 = t[..., 0::2], x2 = t[..., 1::2] -@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( - x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - -+ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ mask_x = x_range < head_num - x_left_off = ( -- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads -+ x_range * stride_k_nheads - + k_dim - + tl.arange(0, emb_dim // 2)[None, :] - ) - x_right_off = x_left_off + emb_dim // 2 -- tl.store(K_ptr + x_left_off, x_left, mask=mask) -- tl.store(K_ptr + x_right_off, x_right, mask=mask) -+ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) -+ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) - - - @triton.autotune( -@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( - else: - token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) - -- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads -- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads -- mask = dkv_off < head_num * stride_dkv_nheads -- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] -- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] -- -- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads -- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads -- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] -- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -- dk = tl.load(dK_ptr + dk_in_off, mask=mask) -- dv = tl.load(dV_ptr + dv_in_off, mask=mask) -- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) -- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) -+ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ dk_out_off = ki_range * stride_dkv_nheads + kj_range -+ -+ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads -+ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads -+ dk_in_off = ki_range * stride_dk_nheads + kj_range -+ -+ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) -+ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) -+ -+ if v_dim > 0: -+ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -+ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) -+ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) - - if pid_head == 0: - x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): -- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads -- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim -+ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads -+ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads - mask = x_off < head_num * stride_dk_nheads - x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] - x_right_off = x_left_off + emb_dim // 2 -@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) - o_value = kv.new_empty(total_seqlen, nheads, v_dim) -+ k_dim_ceil = triton.next_power_of_2(k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_fwd_kv_kernel[grid]( -@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - emb_dim, - k_dim, -+ k_dim_ceil, - v_dim, - nheads, - batch_size, -@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) - d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) -+ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_bwd_kv_kernel[grid]( -@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - ctx.emb_dim, - ctx.k_dim, -+ k_dim_ceil, - ctx.v_dim, - nheads, - batch_size, -diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py -index 13d74aa52..060898a7a 100644 ---- a/megatron/core/models/common/language_module/language_module.py -+++ b/megatron/core/models/common/language_module/language_module.py -@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): - assert ( - column_parallel_linear is not None - ), "column_parallel_linear cannot be None when not using fused linear cross entropy." -- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) -+ # output -+ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} -+ output_layer_buffers = dict(column_parallel_linear.named_buffers()) -+ logits, _ = torch.func.functional_call( -+ column_parallel_linear, -+ {**output_layer_params, **output_layer_buffers}, -+ (hidden,), -+ col_linear_kwargs, -+ ) - - return self.compute_language_model_loss(labels, logits) - -diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index e21127b87..712793853 100755 ---- a/megatron/core/models/gpt/gpt_layer_specs.py -+++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( - use_kitchen: bool = False, - use_te_activation_func: bool = False, - fallback_to_eager_attn: bool = False, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). - -@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( - mlp=mlp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - normalization=normalization, -+ post_self_attn_layernorm=post_self_attn_layernorm, -+ post_mlp_layernorm=post_mlp_layernorm, - ) - - -@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( - mlp: ModuleSpec, - sharded_state_dict_keys_map: Optional[dict] = None, - normalization: Optional[str] = None, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Helper function to get module spec for TransformerLayer""" - -@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( - input_layernorm=input_layernorm, - self_attention=attention, - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=pre_mlp_layernorm, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - ), - ) -diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index a1230568c..1fd52f65a 100644 ---- a/megatron/core/models/gpt/gpt_model.py -+++ b/megatron/core/models/gpt/gpt_model.py -@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[Tensor] = None, -+ mtp_kwargs: Optional[dict] = {}, - ) -> Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoder and finally into the post -@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, -+ mtp_kwargs=mtp_kwargs, - ) - - def _postprocess( -@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=None, - extra_block_kwargs=None, - inference_context=None, -+ mtp_kwargs={}, - ): - """Postprocesses decoder hidden states to generate logits or compute loss. - -@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() -- if mtp_in_postprocess: -+ -+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, -@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): - return hidden_states - - # Skip when mtp_num_layers is None or 0 -- if self.config.mtp_num_layers: -- mtp_labels = labels.clone() -+ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: -+ mtp_labels = mtp_kwargs['mtp_labels'].clone() -+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) -+ - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) -+ else: -+ # Otherwise, roll the loss_mask to keep up with the mtp_labels -+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( -@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ -- 'weight': output_weight, -+ 'weight': output_weight.detach() if output_weight else None, - 'runtime_gather_output': runtime_gather_output, - }, - ) -diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index 6e093f96f..eac21a3ea 100644 ---- a/megatron/core/optimizer/distrib_optimizer.py -+++ b/megatron/core/optimizer/distrib_optimizer.py -@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - # TE FusedAdam will not accumulate step for empty param groups, so we need to - # align the step across param groups. - param_group["step"] = int(step) -+ if "step" in param_group and param_group["step"] is None: -+ del param_group["step"] - - # Grad scaler state. - if self.grad_scaler: -@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - if key == 'padding': - tensors[key] = LocalNonpersistentObject(tensors[key]) - continue -+ if key == 'step': -+ continue - assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( - tensors[key].shape, - gbuf_local_start, -diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index a273002b9..4f821cfd5 100644 ---- a/megatron/core/parallel_state.py -+++ b/megatron/core/parallel_state.py -@@ -11,6 +11,7 @@ from typing import Callable, List, Optional - - import numpy as np - import torch -+import torch.distributed as dist - - from .utils import GlobalMemoryBuffer, is_torch_min_version - -diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index ac839c21f..f18309217 100644 ---- a/megatron/core/pipeline_parallel/p2p_communication.py -+++ b/megatron/core/pipeline_parallel/p2p_communication.py -@@ -26,22 +26,22 @@ def _batched_p2p_ops( - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group -+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group -+ torch.distributed.isend, tensor_send_next, next_pipeline_rank, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, - ) - ops.append(recv_next_op) - if len(ops) > 0: -diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py -index 28cff06f5..48c9c1a25 100644 ---- a/megatron/core/transformer/moe/moe_utils.py -+++ b/megatron/core/transformer/moe/moe_utils.py -@@ -587,6 +587,9 @@ def topk_routing_with_score_function( - else: - return torch.topk(scores, k=topk, dim=1) - -+ from miles.utils.routing_replay import get_routing_replay_compute_topk -+ compute_topk = get_routing_replay_compute_topk(compute_topk) -+ - if score_function == "softmax": - if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) -diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py -index 16fc9d9af..3e95858a6 100644 ---- a/megatron/core/transformer/moe/router.py -+++ b/megatron/core/transformer/moe/router.py -@@ -201,6 +201,9 @@ class TopKRouter(Router): - self.global_tokens_per_expert = None - self.ga_steps = None - -+ from miles.utils.routing_replay import register_routing_replay -+ register_routing_replay(self) -+ - def _maintain_float32_expert_bias(self): - """ - Maintain the expert bias in float32. -diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py -index a8f4abfcd..f33f6f05e 100755 ---- a/megatron/core/transformer/multi_token_prediction.py -+++ b/megatron/core/transformer/multi_token_prediction.py -@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union - - import torch - from torch import Tensor -+import warnings - - from megatron.core import InferenceParams, parallel_state, tensor_parallel - from megatron.core.dist_checkpointing.mapping import ShardedStateDict -@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) -- position_ids, _ = roll_tensor( -- position_ids, -- shifts=-1, -- dims=-1, -- cp_group=self.cp_group, -- packed_seq_params=packed_seq_params, -- ) -+ if position_ids is not None: -+ position_ids, _ = roll_tensor( -+ position_ids, -+ shifts=-1, -+ dims=-1, -+ cp_group=self.cp_group, -+ packed_seq_params=packed_seq_params, -+ ) - # embedding - decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) -+ decoder_input = decoder_input.detach() - -- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) -+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) - - return input_ids, position_ids, decoder_input, hidden_states - -@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): - return hidden_states - - def _checkpointed_forward(self, forward_func, *args, **kwargs): -+ """Wrap `forward_func` with activation checkpointing while only passing tensors. -+ -+ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so -+ that checkpoint implementations never receive them directly, avoiding save_for_backward -+ issues with non-tensor inputs. -+ """ -+ -+ # TODO(jiajun): Is there any better implementation here? -+ positional_specs = [] -+ kw_specs = [] -+ tensor_args: List[torch.Tensor] = [] -+ -+ for arg in args: -+ if torch.is_tensor(arg): -+ positional_specs.append(('tensor', len(tensor_args))) -+ tensor_args.append(arg) -+ else: -+ positional_specs.append(('const', arg)) -+ -+ for key, value in kwargs.items(): -+ if torch.is_tensor(value): -+ kw_specs.append((key, ('tensor', len(tensor_args)))) -+ tensor_args.append(value) -+ else: -+ kw_specs.append((key, ('const', value))) -+ -+ def run(*flat_tensor_args): -+ rebuilt_args = [] -+ for spec_type, payload in positional_specs: -+ if spec_type == 'tensor': -+ rebuilt_args.append(flat_tensor_args[payload]) -+ else: -+ rebuilt_args.append(payload) -+ -+ rebuilt_kwargs = {} -+ for key, (spec_type, payload) in kw_specs: -+ if spec_type == 'tensor': -+ rebuilt_kwargs[key] = flat_tensor_args[payload] -+ else: -+ rebuilt_kwargs[key] = payload -+ -+ return forward_func(*rebuilt_args, **rebuilt_kwargs) -+ -+ tensor_args_tuple = tuple(tensor_args) -+ - def checkpoint_handler(): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: -@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), -- *args, -- **kwargs, -+ *tensor_args_tuple, - ) - else: - return tensor_parallel.checkpoint( -- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() -+ run, self.config.distribute_saved_activations, *tensor_args_tuple - ) - - if self.config.recompute_method == 'uniform': -diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index e2705bd9f..a0aa109b5 100644 ---- a/megatron/core/transformer/transformer_config.py -+++ b/megatron/core/transformer/transformer_config.py -@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): - attention_output_gate: bool = False - """Whether to apply output gate to the attention layers.""" - -+ post_self_attn_layernorm: bool = False -+ post_mlp_layernorm: bool = False -+ - test_mode: bool = False - """Whether to run real-time tests.""" - -diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 3ea405770..5a42001b9 100644 ---- a/megatron/core/transformer/transformer_layer.py -+++ b/megatron/core/transformer/transformer_layer.py -@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: - input_layernorm: Union[ModuleSpec, type] = IdentityOp - self_attention: Union[ModuleSpec, type] = IdentityOp - self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - - pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: - pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - mlp: Union[ModuleSpec, type] = IdentityOp - mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - - # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method - sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - # [Module 3: BiasDropoutFusion] - self.self_attn_bda = build_module(submodules.self_attn_bda) - -+ self.post_self_attn_layernorm = build_module( -+ submodules.post_self_attn_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon, -+ ) -+ - # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, -@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - - self.is_moe_layer = isinstance(self.mlp, MoELayer) - -+ self.post_mlp_layernorm = build_module( -+ submodules.post_mlp_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon -+ ) -+ - self.recompute_input_layernorm = False - self.recompute_pre_mlp_layernorm = False - self.recompute_mlp = False -@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - attention_output_with_bias[0] - ) - -+ attention_output, attention_output_bias = attention_output_with_bias -+ attention_output = self.post_self_attn_layernorm(attention_output) -+ attention_output_with_bias = (attention_output, attention_output_bias) -+ - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - nvtx_range_push(suffix="self_attn_bda") -@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - else: - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - -+ mlp_output, mlp_output_bias = mlp_output_with_bias -+ mlp_output = self.post_mlp_layernorm(mlp_output) -+ mlp_output_with_bias = (mlp_output, mlp_output_bias) -+ - if self.recompute_pre_mlp_layernorm: - # discard the output of the pre-mlp layernorm and register the recompute - # as a gradient hook of mlp_output_with_bias[0] -diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index b267c8a81..83736acdc 100644 ---- a/megatron/training/arguments.py -+++ b/megatron/training/arguments.py -@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): - - kw_args['inference_sampling_seed'] = args.seed - -+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm -+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm -+ - # handle quantization config - # NOTE: Kitchen arguments are only added to the namespace when - # Kitchen library is available. -@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') -+ group.add_argument('--post-self-attn-layernorm', action='store_true', -+ help='If set, use post self attention layernorm.') -+ group.add_argument('--post-mlp-layernorm', action='store_true', -+ help='If set, use post MLP layernorm.') -+ group.add_argument('--use-gated-attention', action='store_true', -+ help='If set, use gated attention as in Qwen3Next') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' -diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py -index 13b7526ca..6c590f653 100644 ---- a/megatron/training/tokenizer/tokenizer.py -+++ b/megatron/training/tokenizer/tokenizer.py -@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): - # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there - self._tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path, -- trust_remote_code=trust_remote_code, -+ trust_remote_code=True, - **kwargs, - ) - self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.0rc0/sglang.patch b/docker/amd_patch/sglv0.5.0rc0/sglang.patch deleted file mode 100644 index 990c2e6289..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/sglang.patch +++ /dev/null @@ -1,203 +0,0 @@ -diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index bdb124e51..3edf30ab1 100644 ---- a/python/sglang/srt/configs/model_config.py -+++ b/python/sglang/srt/configs/model_config.py -@@ -454,14 +454,14 @@ class ModelConfig: - ).lower() - - # Detect which checkpoint is it -- for _, method in QUANTIZATION_METHODS.items(): -- quantization_override = method.override_quantization_method( -- quant_cfg, self.quantization -- ) -- if quantization_override: -- quant_method = quantization_override -- self.quantization = quantization_override -- break -+ # for _, method in QUANTIZATION_METHODS.items(): -+ # quantization_override = method.override_quantization_method( -+ # quant_cfg, self.quantization -+ # ) -+ # if quantization_override: -+ # quant_method = quantization_override -+ # self.quantization = quantization_override -+ # break - - # Verify quantization configurations. - if self.quantization is None: -diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 2dd2c75f1..f2adb18f8 100644 ---- a/python/sglang/srt/entrypoints/http_server.py -+++ b/python/sglang/srt/entrypoints/http_server.py -@@ -264,6 +264,10 @@ async def validate_json_request(raw_request: Request): - - - @app.get("/health") -+async def health(request: Request) -> Response: -+ return Response(status_code=200) -+ -+ - @app.get("/health_generate") - async def health_generate(request: Request) -> Response: - """ -diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -index 372717bf9..40665cc90 100644 ---- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -@@ -190,6 +190,7 @@ class DeepEPBuffer: - f"Consider using --deepep-config to change the behavior." - ) - -+ num_qps_per_rank = 20 - cls._buffer = Buffer( - group, - num_nvl_bytes, -diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py -index 956264fc9..69f729336 100644 ---- a/python/sglang/srt/layers/quantization/fp8.py -+++ b/python/sglang/srt/layers/quantization/fp8.py -@@ -351,10 +351,10 @@ class Fp8LinearMethod(LinearMethodBase): - return - else: - weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data -- layer.weight = torch.nn.Parameter(weight, requires_grad=False) -- layer.weight_scale_inv = torch.nn.Parameter( -- weight_scale, requires_grad=False -- ) -+ # layer.weight = torch.nn.Parameter(weight, requires_grad=False) -+ # layer.weight_scale_inv = torch.nn.Parameter( -+ # weight_scale, requires_grad=False -+ # ) - return - - layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) -diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 95a529c89..758fbfd5f 100644 ---- a/python/sglang/srt/managers/scheduler.py -+++ b/python/sglang/srt/managers/scheduler.py -@@ -1359,7 +1359,7 @@ class Scheduler( - - if memory_leak: - msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" -- raise ValueError(msg) -+ # raise ValueError(msg) - - if self.disaggregation_mode == DisaggregationMode.DECODE: - req_total_size = ( -@@ -1374,7 +1374,7 @@ class Scheduler( - f"available_size={len(self.req_to_token_pool.free_slots)}, " - f"total_size={self.req_to_token_pool.size}\n" - ) -- raise ValueError(msg) -+ # raise ValueError(msg) - - if ( - self.enable_metrics -@@ -1830,6 +1830,7 @@ class Scheduler( - deepep_mode=DeepEPMode(self.server_args.deepep_mode), - require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), - disable_overlap_schedule=self.server_args.disable_overlap_schedule, -+ offload_tags=self.offload_tags, - ) - - def handle_dp_balance_data(self, local_batch: ScheduleBatch): -@@ -1927,6 +1928,7 @@ class Scheduler( - deepep_mode: DeepEPMode, - require_mlp_tp_gather: bool, - disable_overlap_schedule: bool, -+ offload_tags: set[str], - ): - # Check if other DP workers have running batches - if local_batch is None: -@@ -1957,7 +1959,7 @@ class Scheduler( - ) - - tbo_preparer = TboDPAttentionPreparer() -- if disable_overlap_schedule: -+ if len(offload_tags) == 0 and disable_overlap_schedule: - group = tp_group.device_group - device = tp_group.device - else: -diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 58220b1d6..3c3d081a8 100644 ---- a/python/sglang/srt/managers/tokenizer_manager.py -+++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -1044,10 +1044,15 @@ class TokenizerManager: - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() -- assert ( -- self.server_args.dp_size == 1 -- ), "dp_size must be 1 for init parameter update group" -- result = (await self.init_weights_update_group_communicator(obj))[0] -+ results = await self.init_weights_update_group_communicator(obj) -+ if self.server_args.dp_size == 1: -+ result = results[0] -+ return result.success, result.message -+ else: -+ all_success = all([r.success for r in results]) -+ all_message = [r.message for r in results] -+ all_message = " | ".join(all_message) -+ return all_success, all_message - return result.success, result.message - - async def update_weights_from_distributed( -@@ -1056,9 +1061,6 @@ class TokenizerManager: - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() -- assert ( -- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention -- ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed" - - if obj.abort_all_requests: - self.abort_request(abort_all=True) -@@ -1066,8 +1068,15 @@ class TokenizerManager: - # This means that weight sync - # cannot run while requests are in progress. - async with self.model_update_lock.writer_lock: -- result = (await self.update_weights_from_distributed_communicator(obj))[0] -- return result.success, result.message -+ results = await self.update_weights_from_distributed_communicator(obj) -+ if self.server_args.dp_size == 1: -+ result = results[0] -+ return result.success, result.message -+ else: -+ all_success = all([r.success for r in results]) -+ all_message = [r.message for r in results] -+ all_message = " | ".join(all_message) -+ return all_success, all_message - - async def update_weights_from_tensor( - self, -diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 5222bff0a..ff0bbc62a 100644 ---- a/python/sglang/srt/model_executor/model_runner.py -+++ b/python/sglang/srt/model_executor/model_runner.py -@@ -22,6 +22,7 @@ import os - import time - from dataclasses import dataclass - from typing import List, Optional, Tuple, Union -+from contextlib import nullcontext - - import torch - import torch.distributed as dist -@@ -675,7 +676,7 @@ class ModelRunner: - monkey_patch_vllm_parallel_state() - monkey_patch_isinstance_for_vllm_base_layer() - -- with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS): -+ with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS) if not self.is_draft_worker else nullcontext(): - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, -diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index e0f0b373d..a18ac10f1 100644 ---- a/python/sglang/srt/models/glm4_moe.py -+++ b/python/sglang/srt/models/glm4_moe.py -@@ -1108,5 +1108,4 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): - ) - weight_loader(param, loaded_weight) - -- - EntryClass = [Glm4MoeForCausalLM] diff --git a/docker/amd_patch/sglv0.5.10/megatron.patch b/docker/amd_patch/sglv0.5.10/megatron.patch new file mode 100644 index 0000000000..acd64149b7 --- /dev/null +++ b/docker/amd_patch/sglv0.5.10/megatron.patch @@ -0,0 +1,20 @@ +diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py +--- a/megatron/legacy/fused_kernels/__init__.py ++++ b/megatron/legacy/fused_kernels/__init__.py +@@ -3,6 +3,7 @@ + import os + import pathlib + import subprocess ++import torch + + from torch.utils import cpp_extension + +@@ -15,6 +16,8 @@ + + + def load(args): ++ if not torch.version.cuda: ++ return + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] diff --git a/docker/amd_patch/sglv0.5.7/megatron.patch b/docker/amd_patch/sglv0.5.7/megatron.patch deleted file mode 100644 index f6efca346d..0000000000 --- a/docker/amd_patch/sglv0.5.7/megatron.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 ---- a/megatron/legacy/fused_kernels/__init__.py -+++ b/megatron/legacy/fused_kernels/__init__.py -@@ -3,6 +3,7 @@ - import os - import pathlib - import subprocess -+import torch - - from torch.utils import cpp_extension - -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - - def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') - -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.7/sglang.patch b/docker/amd_patch/sglv0.5.7/sglang.patch deleted file mode 100644 index b103263070..0000000000 --- a/docker/amd_patch/sglv0.5.7/sglang.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -index 6e7ea07e7..73b512f51 100644 ---- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -@@ -64,6 +64,7 @@ class CustomAllreduce: - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=_MAX_CAR_SIZE, -+ enable_register_for_capturing: bool = True, - ) -> None: - """ - Args: -@@ -410,6 +411,8 @@ class CustomAllreduce: - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - if _is_hip: -+ if self.tms_cudagraph: -+ return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) -diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index c3ca1e4f3..2bb763b6a 100644 ---- a/python/sglang/srt/distributed/parallel_state.py -+++ b/python/sglang/srt/distributed/parallel_state.py -@@ -351,10 +351,12 @@ class GroupCoordinator: - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - try: -+ tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( - group=self.cpu_group, - device=self.device, -+ enable_register_for_capturing=not tms_cudagraph, - ) - except Exception as e: - logger.warning( diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/amd/run-qwen3-4B-amd.sh similarity index 67% rename from scripts/run-qwen3-30B-A3B.sh rename to scripts/amd/run-qwen3-4B-amd.sh index 19bc70927d..bc6d4d40c0 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/amd/run-qwen3-4B-amd.sh @@ -9,30 +9,34 @@ pkill -9 python sleep 3 pkill -9 ray pkill -9 python -pkill -9 redis set -ex +# keep Ray from blanking HIP/CUDA visibility for the job entrypoint. +export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} +export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES:-"1"} + # will prevent ray from buffering stdout/stderr export PYTHONBUFFERED=16 -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 +if [[ -n "${HIP_VISIBLE_DEVICES:-}" ]]; then + export CUDA_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES}" +fi + +NUM_GPUS=${NUM_GPUS:-8} +if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + IFS=',' read -r -a visible_gpu_ids <<< "${CUDA_VISIBLE_DEVICES}" + NUM_GPUS=${#visible_gpu_ids[@]} fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-4B.sh" CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-30B-A3B - #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 - --ref-load /root/Qwen3-30B-A3B_torch_dist - --load /root/Qwen3-30B-A3B_miles/ - --save /root/Qwen3-30B-A3B_miles/ + --hf-checkpoint /root/Qwen3-4B + --ref-load /root/Qwen3-4B_torch_dist + --load /root/Qwen3-4B_miles/ + --save /root/Qwen3-4B_miles/ --save-interval 20 ) @@ -48,7 +52,6 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 1 - --global-batch-size 256 --balance-data ) @@ -62,11 +65,11 @@ EVAL_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 - --expert-model-parallel-size 8 + --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 --recompute-granularity full @@ -75,7 +78,7 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size - --max-tokens-per-gpu 20480 + --max-tokens-per-gpu 9216 ) GRPO_ARGS=( @@ -95,23 +98,18 @@ OPTIMIZER_ARGS=( --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer ) WANDB_ARGS=( - #--use-wandb + # --use-wandb # --wandb-project miles-dev - # --wandb-group qwen3-30B-A3B-test + # --wandb-group qwen3-4B-test # --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 + --rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.7 - --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) MISC_ARGS=( @@ -127,14 +125,13 @@ MISC_ARGS=( # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ \"env_vars\": { \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } }" @@ -142,7 +139,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/scripts/run-llama3.2-3B-Instruct-amd.sh b/scripts/run-llama3.2-3B-Instruct-amd.sh deleted file mode 100644 index eb5d5709ce..0000000000 --- a/scripts/run-llama3.2-3B-Instruct-amd.sh +++ /dev/null @@ -1,180 +0,0 @@ -#!/bin/bash - -# hf download meta-llama/Llama-3.2-3B-Instruct --local-dir /root/Llama-3.2-3B-Instruct - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -euxo pipefail - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -# NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -# if [ "$NVLINK_COUNT" -gt 0 ]; then -# HAS_NVLINK=1 -# else -# HAS_NVLINK=0 -# fi -# echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/llama3.2-3B-Instruct-amd.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Llama-3.2-3B-Instruct - --ref-load ${MODEL_DIR}/Llama-3.2-3B-Instruct_torch_dist - --load ${MODEL_DIR}/Llama-3.2-3B-Instruct_miles/ - --save ${MODEL_DIR}/Llama-3.2-3B-Instruct_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type math - --num-epoch 1 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 16384 - --rollout-temperature 1 - - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 10 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 8 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-dev - # --wandb-group llama3.2-3B - # --wandb-key ${WANDB_API_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.4 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - ################### -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -# Build the runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/workspace/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} - - -####clear after training - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python \ No newline at end of file diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh deleted file mode 100755 index 44257cc77f..0000000000 --- a/scripts/run-qwen3-4B-amd.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - -set -euxo pipefail - - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/root}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/root}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/root}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-4B.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Qwen3-4B - --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist - --load ${MODEL_DIR}/Qwen3-4B_miles/ - --save ${MODEL_DIR}/Qwen3-4B_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - ################### -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - - -# Dynamically detect Megatron-LM installation path -MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="{ - \"env_vars\": { - \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" - } - }" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} diff --git a/scripts/run-qwen3-8B-amd.sh b/scripts/run-qwen3-8B-amd.sh deleted file mode 100644 index 979ffa18e0..0000000000 --- a/scripts/run-qwen3-8B-amd.sh +++ /dev/null @@ -1,194 +0,0 @@ -#!/bin/bash - - -# bash scripts/run-qwen3-4B-amd.sh - - -####clear before training -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - -set -euxo pipefail - - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -# Current Model convert script on AMD GPU has some issue, please download the converted model from here: https://huggingface.co/zyzshishui0627/models - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-8B.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Qwen3-8B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load ${MODEL_DIR}/Qwen3-8B_torch_dist - # --ref-load ${MODEL_DIR}/Qwen3-8B_torch_dist_amd_new - --load ${MODEL_DIR}/Qwen3-8B_miles/ - --save ${MODEL_DIR}/Qwen3-8B_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - #--use-wandb - # --wandb-project miles-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 -) -#################### - - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - - -# "PYTHONPATH": "/workspace/Megatron-LM/", -MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json='{ - "env_vars": { - "PYTHONPATH": "/workspace/Megatron-LM/", - "CUDA_DEVICE_MAX_CONNECTIONS": "1" - } - }' \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} - - - -####clear after training - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - - - - - - - - - From ef7481ae3bfbcc641d031e7e6113b646bb764382 Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Mon, 13 Apr 2026 11:59:32 -0700 Subject: [PATCH 33/44] switch model to actor (#756) --- miles/backends/megatron_utils/actor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 8b3ea9975c..2cfd373f49 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -177,9 +177,8 @@ def init( # empty cache after initialization clear_memory() + self._switch_model("actor") if self.args.offload_train: - # recover to actor in the end. - self._switch_model("actor") self.sleep() self.rollout_engines = None From 85fe6519ae87482fb816b8646eec85e21d1de5da Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:08:33 -0700 Subject: [PATCH 34/44] [fix] support general logic to bypass fp32 downcast and fix qwen35 A_log dtype (#975) Co-authored-by: yueming-yuan --- .../advanced/arch-support-beyond-megatron.md | 29 ++ .../megatron_utils/fp32_param_utils.py | 52 ++++ miles/backends/megatron_utils/model.py | 5 + miles_plugins/mbridge/qwen3_5.py | 6 + miles_plugins/models/qwen3_5.py | 7 +- .../megatron_utils/test_fp32_param_utils.py | 262 ++++++++++++++++++ tools/convert_hf_to_torch_dist.py | 25 +- 7 files changed, 362 insertions(+), 24 deletions(-) create mode 100644 miles/backends/megatron_utils/fp32_param_utils.py create mode 100644 tests/fast/backends/megatron_utils/test_fp32_param_utils.py diff --git a/docs/en/advanced/arch-support-beyond-megatron.md b/docs/en/advanced/arch-support-beyond-megatron.md index 0db8c8a40a..4b3e2a02ca 100644 --- a/docs/en/advanced/arch-support-beyond-megatron.md +++ b/docs/en/advanced/arch-support-beyond-megatron.md @@ -27,6 +27,35 @@ miles leverages this mechanism by **hijacking the spec generation stage to repla Through the coordination of these three components, we can successfully run a complex model architecture not natively supported by Megatron—using its HuggingFace implementation as the vehicle—on top of Megatron's parallel framework. This is achieved while fully retaining all key capabilities like model parallelism, MoE acceleration, and pipeline scheduling. +## Mixed-Precision: Preserving fp32 Parameters in bf16 Models + +Some model architectures require specific parameters to remain in fp32 even when the rest of the model runs in bf16. For example, Qwen3.5's `A_log` parameter must stay fp32 — if rounded to bf16, Megatron-side activations diverge from sglang's fp32 `A_log` on the rollout side, causing precision drift. + +Megatron's training stack has **three implicit cast points** that silently round fp32 parameters to bf16: `Float16Module` construction, `Bridge._weight_to_mcore_format`, and `Bridge.load_weights`. Both steps below are required — doing only one leaves a silent precision trap where the final dtype *looks* correct (fp32) but values were already rounded to bf16 precision. + +### Step 1: Mark the parameter in your model definition + +```python +from miles.backends.megatron_utils.fp32_param_utils import mark_param_dtype + +# In your model's __init__: +self.A_log = nn.Parameter(torch.log(A).to(torch.float32)) +mark_param_dtype(self.A_log, torch.float32) +``` + +`enforce_marked_param_dtypes(model)` — already wired into training and checkpoint conversion entry points — restores tagged params to fp32 after `Float16Module` casts the entire model to bf16. + +### Step 2: Override the Bridge to bypass bf16 pre-cast during weight loading + +```python +class Qwen3_5Bridge(Qwen2MoEBridge): + def _weight_to_mcore_format(self, mcore_weights_name, hf_weights): + if mcore_weights_name.endswith("self_attention.linear_attn.A_log"): + assert len(hf_weights) == 1 + return hf_weights[0].to(dtype=torch.float32).contiguous() + return super()._weight_to_mcore_format(mcore_weights_name, hf_weights) +``` + ## Current Limitations * This approach does not currently support Tensor Parallelism (TP) within the replaced module itself (e.g., the Attention layer in this case). diff --git a/miles/backends/megatron_utils/fp32_param_utils.py b/miles/backends/megatron_utils/fp32_param_utils.py new file mode 100644 index 0000000000..afd6bde7f0 --- /dev/null +++ b/miles/backends/megatron_utils/fp32_param_utils.py @@ -0,0 +1,52 @@ +import logging +from collections.abc import Sequence + +import torch +import torch.distributed as dist + +logger = logging.getLogger(__name__) + + +# Parameter attribute used by model definitions to pin parameter dtype. +FORCED_PARAM_DTYPE_ATTR = "_miles_forced_param_dtype" + + +def mark_param_dtype(param: torch.nn.Parameter, dtype: torch.dtype) -> None: + """Mark a parameter with its required runtime dtype.""" + setattr(param, FORCED_PARAM_DTYPE_ATTR, dtype) + + +def enforce_marked_param_dtypes(model_chunks: Sequence[torch.nn.Module]) -> list[str]: + """Apply dtype overrides declared on parameters via ``mark_param_dtype``. + + This keeps the policy in model definitions and avoids model-name checks in + the training/conversion mainline. + + Motivation: Megatron's ``Float16Module`` unconditionally casts every + floating-point parameter to bf16/fp16 at wrap time, and there is no + declarative opt-out in nn.Module or Megatron. Megatron's MoE router hits the + same problem and solves it with ``_maintain_float32_expert_bias`` (see + ``megatron/core/transformer/moe/router.py``), which post-hoc casts the + expert_bias back to fp32. This function generalizes that pattern: callers + mark params with their required dtype at the model-definition site, and we + re-cast after ``get_model`` so the rest of the stack (optimizer, DDP, mbridge + load path) sees the intended dtype. + """ + updated_names: list[str] = [] + for chunk in model_chunks: + for name, param in chunk.named_parameters(): + target_dtype = getattr(param, FORCED_PARAM_DTYPE_ATTR, None) + if target_dtype is None: + continue + + if param.dtype != target_dtype: + # Keep Parameter identity to avoid breaking optimizer/DDP maps. + param.data = param.data.to(dtype=target_dtype) + updated_names.append(name) + + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + if rank == 0 and updated_names: + logger.info("Enforced marked parameter dtypes for %d tensors.", len(updated_names)) + return updated_names diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index e95158bbfc..c7ca14b486 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -36,6 +36,7 @@ compute_model_hashes_by_layer, save_model_hashes, ) +from .fp32_param_utils import enforce_marked_param_dtypes from .initialize import is_megatron_main_rank from .lora_utils import is_lora_enabled, is_lora_model from .model_provider import get_model_provider_func @@ -125,6 +126,9 @@ def setup_model_and_optimizer( else: model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + # Apply parameter-level dtype overrides declared in model definitions. + enforce_marked_param_dtypes(model) + # Optimizer kwargs = {} for f in dataclasses.fields(OptimizerConfig): @@ -132,6 +136,7 @@ def setup_model_and_optimizer( kwargs[f.name] = getattr(args, f.name) config = OptimizerConfig(**kwargs) config.timers = None + optimizer = get_megatron_optimizer( config=config, model_chunks=model, diff --git a/miles_plugins/mbridge/qwen3_5.py b/miles_plugins/mbridge/qwen3_5.py index ee629d009f..8da5b7204b 100644 --- a/miles_plugins/mbridge/qwen3_5.py +++ b/miles_plugins/mbridge/qwen3_5.py @@ -254,6 +254,12 @@ def _convert_mtp_param(self, name: str) -> list[str]: def _weight_to_mcore_format( self, mcore_weights_name: str, hf_weights: list[torch.Tensor] ) -> tuple[list[str], list[torch.Tensor]]: + if mcore_weights_name.endswith("self_attention.linear_attn.A_log"): + assert len(hf_weights) == 1 + # Keep A_log in fp32 before TP scatter; this avoids precision loss + # from Bridge's global pre-cast to self.dtype. + return hf_weights[0].to(dtype=torch.float32).contiguous() + if "self_attention.linear_qkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name: # merge qkv assert len(hf_weights) == 3 diff --git a/miles_plugins/models/qwen3_5.py b/miles_plugins/models/qwen3_5.py index 794cf73808..5c43d732dc 100644 --- a/miles_plugins/models/qwen3_5.py +++ b/miles_plugins/models/qwen3_5.py @@ -15,6 +15,7 @@ except ImportError: pass +from miles.backends.megatron_utils.fp32_param_utils import mark_param_dtype from miles.backends.training_utils.cp_utils import build_gdn_cp_context from .hf_attention import HuggingfaceAttention, _load_hf_config @@ -71,8 +72,12 @@ def __init__(self, config, layer_idx: int): self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) A = torch.empty(self.num_v_heads).uniform_(0, 16) - self.A_log = nn.Parameter(torch.log(A)) + self.A_log = nn.Parameter(torch.log(A).to(torch.float32)) + mark_param_dtype(self.A_log, torch.float32) + # HF stores this norm in fp32, but unlike A_log its precision impact is + # negligible and sglang runs it in bf16 on the rollout side — follow + # config.dtype (bf16) to stay equivalent to rollout. self.norm = FusedRMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, diff --git a/tests/fast/backends/megatron_utils/test_fp32_param_utils.py b/tests/fast/backends/megatron_utils/test_fp32_param_utils.py new file mode 100644 index 0000000000..c98b2c3e5d --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_fp32_param_utils.py @@ -0,0 +1,262 @@ +"""Tests for the A_log fp32 preservation chain. + +Feature: Qwen3.5's ``A_log`` must end up as fp32 in the Megatron parameter +after hf->mcore conversion, because the chunk-gated-delta-rule kernel relies +on that precision. Two complementary pieces keep this invariant: + +- Downstream — ``enforce_marked_param_dtypes`` (this module): + Megatron's ``Float16Module`` unconditionally casts every floating-point + parameter to bf16/fp16 at wrap time. There is no declarative opt-out in + nn.Module or Megatron; even Megatron's own MoE router uses the same + post-hoc ``.data = ...to(float32)`` pattern in + ``_maintain_float32_expert_bias``. We generalize that by letting model + definitions declare intent via ``mark_param_dtype`` and re-casting after + ``get_model`` returns. +- Upstream — ``Qwen3_5Bridge._weight_to_mcore_format``: + mbridge's base ``_weight_to_mcore_format`` pre-casts every HF tensor to + ``self.dtype`` (bf16) before TP scatter. For A_log that pre-cast rounds + the fp32 HF value. The override returns A_log as fp32 early, bypassing + that pre-cast entirely. + +The end-to-end test ties both halves together and checks bit-exact equality +with the HF fp32 source — this is the regression guard against the original +``patch_weight_to_mcore_format_preserve_fp32`` failure mode, where only the +upstream cast was intercepted and the downstream ``t.to(param.dtype)`` in +``Bridge.load_weights`` still demoted A_log back to bf16. +""" + +import pytest +import torch +import torch.nn as nn + +from miles.backends.megatron_utils.fp32_param_utils import ( + FORCED_PARAM_DTYPE_ATTR, + enforce_marked_param_dtypes, + mark_param_dtype, +) + + +# --------------------------------------------------------------------------- +# Downstream: mark_param_dtype + enforce_marked_param_dtypes +# --------------------------------------------------------------------------- + + +class _ToyModule(nn.Module): + """Minimal stand-in for Qwen3_5GatedDeltaNet: one marked fp32 param plus + one regular bf16-target param, so we can check the collateral damage + boundary of ``enforce_marked_param_dtypes``.""" + + def __init__(self, num_heads: int = 8): + super().__init__() + A = torch.empty(num_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A).to(torch.float32)) + mark_param_dtype(self.A_log, torch.float32) + self.in_proj = nn.Linear(16, num_heads, bias=False) + + +class TestMarkParamDtype: + def test_attaches_expected_attribute(self): + p = nn.Parameter(torch.zeros(4)) + mark_param_dtype(p, torch.float32) + assert getattr(p, FORCED_PARAM_DTYPE_ATTR) is torch.float32 + + def test_overwrites_previous_mark(self): + p = nn.Parameter(torch.zeros(4)) + mark_param_dtype(p, torch.float32) + mark_param_dtype(p, torch.float64) + assert getattr(p, FORCED_PARAM_DTYPE_ATTR) is torch.float64 + + +class TestEnforceMarkedParamDtypes: + def test_recasts_marked_param_back_to_fp32_after_float16_wrap(self): + """Simulates the full Megatron path: construct -> bfloat16() (what + ``Float16Module(...)`` does) -> enforce. A_log must come out fp32.""" + m = _ToyModule() + assert m.A_log.dtype == torch.float32 + + # Simulate Float16Module(config, m) — module.bfloat16() in the ctor + # demotes every floating param including the marked one. + m.bfloat16() + assert m.A_log.dtype == torch.bfloat16 + + enforce_marked_param_dtypes([m]) + assert m.A_log.dtype == torch.float32 + + def test_preserves_parameter_identity(self): + """Optimizer and DDP bucket parameters by Python identity, set up + AFTER ``enforce_marked_param_dtypes`` runs. If we re-bind via + ``self.A_log = nn.Parameter(...)`` the id changes and the optimizer + map breaks. We must only mutate ``.data``.""" + m = _ToyModule() + m.bfloat16() + before_id = id(m.A_log) + before_param_obj = m.A_log + + enforce_marked_param_dtypes([m]) + + assert id(m.A_log) == before_id + assert m.A_log is before_param_obj + + def test_leaves_unmarked_params_alone(self): + m = _ToyModule() + m.bfloat16() + assert m.in_proj.weight.dtype == torch.bfloat16 + + enforce_marked_param_dtypes([m]) + assert m.in_proj.weight.dtype == torch.bfloat16 + + def test_is_noop_when_already_target_dtype(self): + """Idempotency — second call must not re-allocate or change anything. + Guards against accidental double-work when the hook is called on + both the training and conversion entrypoints in the same process.""" + m = _ToyModule() + m.bfloat16() + enforce_marked_param_dtypes([m]) + + data_before = m.A_log.data + updated = enforce_marked_param_dtypes([m]) + assert m.A_log.dtype == torch.float32 + # ``.data`` should be the same tensor object (no unnecessary realloc). + assert m.A_log.data.data_ptr() == data_before.data_ptr() + # Name is still reported even on the no-realloc path — this is by + # design so the rank-0 log line reflects policy coverage, not churn. + assert any(n.endswith("A_log") for n in updated) + + def test_walks_multiple_model_chunks(self): + """``setup_model_and_optimizer`` passes a list of model chunks (for + virtual pipeline parallelism). The helper must iterate all of them.""" + chunks = [_ToyModule(), _ToyModule()] + for c in chunks: + c.bfloat16() + + enforce_marked_param_dtypes(chunks) + for c in chunks: + assert c.A_log.dtype == torch.float32 + + def test_returns_empty_when_no_marks(self): + m = nn.Linear(4, 4) + m.bfloat16() + assert enforce_marked_param_dtypes([m]) == [] + + +# --------------------------------------------------------------------------- +# Upstream: Qwen3_5Bridge._weight_to_mcore_format +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def bridge_stub(): + """Build a ``Qwen3_5Bridge`` without invoking ``__init__`` — ``__init__`` + needs a real HF config. The A_log branch only reads ``self.dtype``, which + we set directly, so skipping init is safe and lets this test stay + CPU-only and dep-free.""" + pytest.importorskip("mbridge") + from miles_plugins.mbridge.qwen3_5 import Qwen3_5Bridge + + bridge = Qwen3_5Bridge.__new__(Qwen3_5Bridge) + return bridge + + +class TestQwen3_5BridgeALogOverride: + A_LOG_NAME = "decoder.layers.0.self_attention.linear_attn.A_log" + + def test_returns_fp32_when_bridge_dtype_is_bf16(self, bridge_stub): + """The override must bypass mbridge's ``w.to(self.dtype)`` pre-cast + that would otherwise round HF fp32 to bf16 here.""" + bridge_stub.dtype = torch.bfloat16 + hf_tensor = torch.randn(32, dtype=torch.float32) + + out = bridge_stub._weight_to_mcore_format(self.A_LOG_NAME, [hf_tensor]) + + assert out.dtype == torch.float32 + assert torch.equal(out, hf_tensor) + assert out.is_contiguous() + + def test_upcasts_when_hf_input_is_bf16(self, bridge_stub): + """A_log arriving as bf16 (non-canonical ckpt) is still forced to + fp32 — the invariant is the output dtype, not the input's.""" + bridge_stub.dtype = torch.bfloat16 + hf_tensor = torch.randn(32, dtype=torch.bfloat16) + + out = bridge_stub._weight_to_mcore_format(self.A_LOG_NAME, [hf_tensor]) + + assert out.dtype == torch.float32 + + def test_mtp_layer_a_log_also_matches(self, bridge_stub): + """The override uses ``endswith`` so MTP-layer A_log + (``mtp.layers.{idx}...``) also matches — MTP is a real Qwen3.5 + variant and must not silently skip the override.""" + bridge_stub.dtype = torch.bfloat16 + hf_tensor = torch.randn(32, dtype=torch.float32) + + out = bridge_stub._weight_to_mcore_format("mtp.layers.0.self_attention.linear_attn.A_log", [hf_tensor]) + assert out.dtype == torch.float32 + + +# --------------------------------------------------------------------------- +# End-to-end: the two halves together, matching ``Bridge.load_weights``. +# --------------------------------------------------------------------------- + + +class TestALogLoadPathEndToEnd: + """Replays the dtype-relevant subset of ``Bridge.load_weights`` on a toy + model, as documented in ``tools/debug_a_log_old_flow.py``. No distributed + or real safetensor IO — only the two cast points we care about. + + Expected outcome: HF fp32 value lands in the Megatron A_log param + bit-exactly. Regression target: the OLD ``patch_weight_to_mcore_format_preserve_fp32`` + failed here because ``bridge.py:246`` still cast down to ``param.dtype == bf16``. + """ + + def test_lossless_roundtrip(self, bridge_stub): + a_log_name = "decoder.layers.0.self_attention.linear_attn.A_log" + hf_tensor = torch.randn(32, dtype=torch.float32) + + # 1. Build model (A_log marked fp32 at definition site). + model = _ToyModule(num_heads=32) + + # 2. Megatron wraps with Float16Module → .bfloat16(). + model.bfloat16() + + # 3. enforce_marked_param_dtypes restores A_log to fp32 BEFORE + # load_weights runs, so ``param.dtype`` at bridge.py:246 is fp32. + enforce_marked_param_dtypes([model]) + assert model.A_log.dtype == torch.float32 + + # 4. mbridge: _weight_to_mcore_format (with override → fp32). + bridge_stub.dtype = torch.bfloat16 # would demote without override + mcore_weight = bridge_stub._weight_to_mcore_format(a_log_name, [hf_tensor]) + assert mcore_weight.dtype == torch.float32 + + # 5. mbridge bridge.py:246 — ``t.to(param.device, dtype=param.dtype)``. + param = model.A_log + staged = mcore_weight.to(param.device, dtype=param.dtype).contiguous() + assert staged.dtype == torch.float32 # no-op cast + + # 6. mbridge bridge.py:258 — ``param.copy_(param_to_load)``. + param.data.copy_(staged) + + # Bit-exact round-trip: both halves were required to get here. + assert model.A_log.dtype == torch.float32 + assert torch.equal(model.A_log.data, hf_tensor) + + def test_old_patch_only_regresses_without_enforce(self, bridge_stub): + """Negative control: if we DROP ``enforce_marked_param_dtypes`` and + only keep the upstream override (the shape of the old patch), the + downstream ``t.to(param.dtype)`` still rounds to bf16. This pins the + old failure mode so it cannot be re-introduced by accident.""" + a_log_name = "decoder.layers.0.self_attention.linear_attn.A_log" + # Use a value where bf16 rounding is observable. + hf_tensor = torch.tensor([0.970378123] * 8, dtype=torch.float32) + + model = _ToyModule(num_heads=8) + model.bfloat16() # A_log is bf16; no enforce call here on purpose. + + bridge_stub.dtype = torch.bfloat16 + mcore_weight = bridge_stub._weight_to_mcore_format(a_log_name, [hf_tensor]) + assert mcore_weight.dtype == torch.float32 + + staged = mcore_weight.to(model.A_log.device, dtype=model.A_log.dtype).contiguous() + # Regression check: demoted to bf16 because param.dtype is bf16. + assert staged.dtype == torch.bfloat16 + assert not torch.equal(staged.to(torch.float32), hf_tensor) diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index 354c216a65..a0682cf665 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -1,7 +1,6 @@ import gc import os import shutil -from functools import wraps import torch import torch.distributed as dist @@ -12,8 +11,8 @@ import miles_plugins.mbridge # noqa: F401 from mbridge import AutoBridge -from mbridge.core.bridge import Bridge from miles.backends.megatron_utils.arguments import set_default_megatron_args +from miles.backends.megatron_utils.fp32_param_utils import enforce_marked_param_dtypes from miles.backends.megatron_utils.initialize import init from miles.backends.megatron_utils.model_provider import get_model_provider_func from miles.utils.logging_utils import configure_logger @@ -21,24 +20,6 @@ from miles_plugins.models.hf_attention import _load_hf_config -def patch_weight_to_mcore_format_preserve_fp32(): - - original_method = Bridge._weight_to_mcore_format - - @wraps(original_method) - def patched_method(self, mcore_weights_name, hf_weights): - original_dtype = getattr(self, "dtype", None) - self.dtype = None - try: - result = original_method(self, mcore_weights_name, hf_weights) - finally: - self.dtype = original_dtype - return result - - Bridge._weight_to_mcore_format = patched_method - print("[Patch] Applied patch to preserve FP32 precision in _weight_to_mcore_format") - - def add_convertion_args(parser): """Add conversion arguments to the parser""" parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path") @@ -129,6 +110,7 @@ def main(): args = get_args() init(args) model = get_model(get_model_provider_func(args), ModelType.encoder_or_decoder, wrap_with_ddp=False) + enforce_marked_param_dtypes(model) # Load model hf_model_path = args.hf_checkpoint @@ -138,9 +120,6 @@ def main(): # Fallback for configs with model_type unknown to installed transformers. bridge = AutoBridge.from_config(_load_hf_config(hf_model_path)) - # Patch to preserve FP32 precision for _keep_fp32 params - patch_weight_to_mcore_format_preserve_fp32() - bridge.load_weights(model, hf_model_path, memory_efficient=True) print(f"Model loaded: {hf_model_path}") From 6cc3feb812801fd68abc4fff778148ae75f47697 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:17:05 -0700 Subject: [PATCH 35/44] fix: populate prefix_cache_info in OpenAI/session rollout path (#960) --- .../generate_utils/openai_endpoint_utils.py | 1 + .../test_openai_endpoint_utils.py | 77 ++++++++++++++++++- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index d054cf2c52..2b3016fd4f 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -199,6 +199,7 @@ def _compute_sample_from_openai_record( case "abort": sample.status = Sample.Status.ABORTED + sample.prefix_cache_info.add(choice.get("meta_info", {})) if "weight_version" in choice["meta_info"]: sample.weight_versions.append(choice["meta_info"]["weight_version"]) diff --git a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py index e8cb2eb340..6d90fdd38b 100644 --- a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py +++ b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py @@ -49,6 +49,8 @@ def _make_record( output_token_ids: list[int], output_log_probs: list[float] | None = None, finish_reason: str = "stop", + cached_tokens: int | None = None, + prompt_tokens: int | None = None, ) -> SessionRecord: """Build a minimal session record mimicking SGLang's response format. @@ -62,6 +64,14 @@ def _make_record( logprobs_content = [ {"logprob": lp, "token": f"t{tid}"} for tid, lp in zip(output_token_ids, output_log_probs, strict=True) ] + meta_info = { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": len(output_token_ids), + } + if cached_tokens is not None: + meta_info["cached_tokens"] = cached_tokens + if prompt_tokens is not None: + meta_info["prompt_tokens"] = prompt_tokens return SessionRecord( timestamp=0.0, method="POST", @@ -75,10 +85,7 @@ def _make_record( "message": {"role": "assistant", "content": "response"}, "finish_reason": finish_reason, "logprobs": {"content": logprobs_content}, - "meta_info": { - "output_token_logprobs": output_token_logprobs, - "completion_tokens": len(output_token_ids), - }, + "meta_info": meta_info, } ] }, @@ -557,3 +564,65 @@ def test_no_thinking_tokens_prefix_chain_holds(self): merged = merge_samples(samples, tok) assert merged.tokens == [1, 2, 3, 10, 11, 20, 21, 30, 31] + + +# ── test: prefix cache info population ──────────────────────────────── + + +class TestPrefixCacheInfo: + """Validate that prefix cache statistics from meta_info are collected.""" + + def test_single_record_with_cache_stats(self): + """cached_tokens and prompt_tokens from meta_info populate prefix_cache_info.""" + tok = _mock_tokenizer() + record = _make_record( + prompt_token_ids=[1, 2, 3], + output_token_ids=[10, 11], + cached_tokens=2, + prompt_tokens=3, + ) + input_sample = _make_input_sample() + samples = compute_samples_from_openai_records(_ARGS, input_sample, [record], tok) + + assert samples[0].prefix_cache_info.cached_tokens == 2 + assert samples[0].prefix_cache_info.total_prompt_tokens == 3 + + def test_multi_turn_cache_stats_accumulate_after_merge(self): + """After merge_samples, prefix_cache_info sums across turns.""" + tok = _mock_tokenizer() + records = [ + _make_record( + prompt_token_ids=[1, 2, 3], + output_token_ids=[10, 11], + output_log_probs=[-0.1, -0.2], + cached_tokens=0, + prompt_tokens=3, + ), + _make_record( + prompt_token_ids=[1, 2, 3, 10, 11, 20, 21], + output_token_ids=[30, 31], + output_log_probs=[-0.3, -0.4], + cached_tokens=5, + prompt_tokens=7, + ), + ] + input_sample = _make_input_sample() + samples = compute_samples_from_openai_records(_ARGS, input_sample, records, tok) + merged = merge_samples(samples, tok) + + assert merged.prefix_cache_info.cached_tokens == 0 + 5 + assert merged.prefix_cache_info.total_prompt_tokens == 3 + 7 + assert merged.prefix_cache_info.prefix_cache_hit_rate == 5 / 10 + + def test_missing_cache_fields_default_to_zero(self): + """Records without cached_tokens/prompt_tokens give zero prefix_cache_info (regression).""" + tok = _mock_tokenizer() + record = _make_record( + prompt_token_ids=[1, 2, 3], + output_token_ids=[10, 11], + ) + input_sample = _make_input_sample() + samples = compute_samples_from_openai_records(_ARGS, input_sample, [record], tok) + + assert samples[0].prefix_cache_info.cached_tokens == 0 + assert samples[0].prefix_cache_info.total_prompt_tokens == 0 From 6706c7346d3f9b932dc2e5f000845d4e0bc055b8 Mon Sep 17 00:00:00 2001 From: Shi-Dong Date: Tue, 14 Apr 2026 15:35:00 -0700 Subject: [PATCH 36/44] Remove prepare_harbor_tasks.py; use harbor-private adapters (#982) --- examples/experimental/swe-agent-v2/README.md | 21 +- .../swe-agent-v2/prepare_harbor_tasks.py | 225 ------------------ 2 files changed, 15 insertions(+), 231 deletions(-) delete mode 100644 examples/experimental/swe-agent-v2/prepare_harbor_tasks.py diff --git a/examples/experimental/swe-agent-v2/README.md b/examples/experimental/swe-agent-v2/README.md index 4190c80292..c589cba3de 100644 --- a/examples/experimental/swe-agent-v2/README.md +++ b/examples/experimental/swe-agent-v2/README.md @@ -50,7 +50,6 @@ Docker Network (swe-net) | `swe_agent_function.py` | Custom agent function — dispatches to Harbor server, returns env metadata | | `generate.py` | Reward function, agent metrics aggregation, `RolloutFn` | | `download_and_process_data.py` | Download from HuggingFace or local JSONL, convert to Miles format | -| `prepare_harbor_tasks.py` | Convert Miles JSONL to Harbor task directories (generic fallback) | ## Step-by-Step Setup @@ -117,6 +116,20 @@ pip install harbor ### Step 4: Prepare data and Harbor task directories +Harbor task directories are prepared on the agent server side using **harbor adapters**. Each adapter converts a specific dataset into Harbor's 4-file task format. For example, to prepare SWE-bench tasks: + +```bash +# On the agent server (CPU machine), inside the harbor repo: +cd $CWD/harbor/adapters/swebench && uv sync + +# Generate Harbor task directories for all SWE-bench Verified instances +uv run run_adapter.py --task-dir $HARBOR_TASKS_DIR --all +``` + +This uses the `swebench` Python package to produce correct Docker image names and Dockerfiles for each instance. Other adapters (e.g. `adapters/swe-gym`) follow the same pattern. + +To prepare training data on the Miles side: + ```bash # Inside miles container: @@ -127,10 +140,6 @@ python download_and_process_data.py --input /data/tb.jsonl --output tb.jsonl \ # Merge into one mixed JSONL cat swe.jsonl tb.jsonl > mixed.jsonl - -# Create Harbor task dirs (for custom data without a Harbor adapter) -python prepare_harbor_tasks.py --input my.jsonl --output /root/harbor_tasks/ \ - --docker-network swe-net ``` Each Harbor task directory contains 4 files: @@ -294,7 +303,7 @@ Agent containers need to resolve the Miles container's hostname. Ensure: ### `TaskNotFound` error -The task directory for the given `instance_id` doesn't exist under `HARBOR_TASKS_DIR`. Run the appropriate Harbor adapter or `prepare_harbor_tasks.py` first. +The task directory for the given `instance_id` doesn't exist under `HARBOR_TASKS_DIR`. Run the appropriate harbor adapter first (e.g. `adapters/swebench/run_adapter.py` for SWE-bench tasks). ### SGLang engines OOM (`Not enough memory`) diff --git a/examples/experimental/swe-agent-v2/prepare_harbor_tasks.py b/examples/experimental/swe-agent-v2/prepare_harbor_tasks.py deleted file mode 100644 index ec452d980b..0000000000 --- a/examples/experimental/swe-agent-v2/prepare_harbor_tasks.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Convert training data to Harbor task directories (generic fallback). - -Reads a Miles JSONL (produced by ``download_and_process_data.py``) and -creates one Harbor task directory per instance. Each task directory is -self-contained — Harbor treats all tasks identically regardless of -their origin (SWE-bench, Terminal-Bench, custom, etc.). - -For standard benchmarks, prefer using Harbor's official adapters or -``harbor run -d `` to generate task directories — they -produce the exact grading harness used upstream. This script is a -generic fallback for custom datasets. - -Usage: - - python prepare_harbor_tasks.py \\ - --input /root/custom_train.jsonl \\ - --output /root/harbor_tasks/ \\ - --docker-network swe-net - -Required metadata fields per record: - - instance_id: unique task identifier (becomes directory name) - -Optional metadata fields (read if present): - - problem_statement / instruction / prompt: task text -> instruction.md - - docker_image: base Docker image (default: ubuntu:24.04) - - setup_commands: extra Dockerfile RUN commands (str or list) - - test_script: content of tests/test.sh - - timeout: verifier timeout in seconds (default: 1800) - - repo, version: included in task.toml if present - - patch: oracle solution -> solution/solve.sh -""" - -import argparse -import json -import logging -import os -import textwrap -from pathlib import Path - -logger = logging.getLogger(__name__) - - -def _get_instruction(metadata: dict) -> str: - for key in ("problem_statement", "instruction", "prompt"): - val = metadata.get(key, "") - if val: - return val - return "" - - -def _create_task_dir( - instance_id: str, - metadata: dict, - output_dir: Path, - docker_network: str | None = None, -) -> Path: - """Create a Harbor task directory from metadata.""" - task_dir = output_dir / instance_id - task_dir.mkdir(parents=True, exist_ok=True) - - (task_dir / "instruction.md").write_text(_get_instruction(metadata)) - - repo = metadata.get("repo", "") - version = metadata.get("version", "") - timeout = metadata.get("timeout", 1800) - - toml_lines = [ - "[task]", - f'name = "{instance_id}"', - ] - if repo: - toml_lines.append(f'repo = "{repo}"') - if version: - toml_lines.append(f'version = "{version}"') - toml_lines += [ - "", - "[limits]", - f"timeout = {timeout}", - ] - (task_dir / "task.toml").write_text("\n".join(toml_lines) + "\n") - - env_dir = task_dir / "environment" - env_dir.mkdir(exist_ok=True) - - docker_image = metadata.get("docker_image", "ubuntu:24.04") - setup_cmds = metadata.get("setup_commands", "") - if isinstance(setup_cmds, list): - setup_cmds = " && ".join(setup_cmds) - setup_block = f"RUN {setup_cmds}\n" if setup_cmds else "" - - (env_dir / "Dockerfile").write_text(f"FROM {docker_image}\n{setup_block}") - - if docker_network: - compose_yaml = textwrap.dedent( - f"""\ - services: - main: - networks: - - {docker_network} - networks: - {docker_network}: - external: true - """ - ) - (env_dir / "docker-compose.yaml").write_text(compose_yaml) - - tests_dir = task_dir / "tests" - tests_dir.mkdir(exist_ok=True) - - test_script = metadata.get("test_script", "") - if test_script: - test_sh = f"#!/bin/bash\n{test_script}\n" - else: - test_sh = textwrap.dedent( - """\ - #!/bin/bash - echo 0 > /logs/verifier/reward.txt - """ - ) - - (tests_dir / "test.sh").write_text(test_sh) - os.chmod(tests_dir / "test.sh", 0o755) - - patch = metadata.get("patch", "") - if patch: - sol_dir = task_dir / "solution" - sol_dir.mkdir(exist_ok=True) - (sol_dir / "fix.patch").write_text(patch) - solve_sh = textwrap.dedent( - """\ - #!/bin/bash - git apply "$(dirname "$0")/fix.patch" - """ - ) - (sol_dir / "solve.sh").write_text(solve_sh) - os.chmod(sol_dir / "solve.sh", 0o755) - - return task_dir - - -def convert( - input_path: str, - output_dir: str, - docker_network: str | None = None, -) -> int: - """Convert all instances from JSONL to Harbor task directories. - - Returns the number of tasks created. - """ - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - - records: list[dict] = [] - - with open(input_path) as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - except json.JSONDecodeError as e: - logger.warning(f"Skipping line {line_num}: {e}") - continue - - metadata = data.get("metadata", data) - instance_id = metadata.get("instance_id", "") - if not instance_id: - logger.warning(f"Skipping line {line_num}: no instance_id") - continue - - records.append(metadata) - - if not records: - logger.warning("No valid records found") - return 0 - - count = 0 - for metadata in records: - instance_id = metadata["instance_id"] - _create_task_dir( - instance_id, - metadata, - output_path, - docker_network=docker_network, - ) - count += 1 - if count % 100 == 0: - logger.info(f"Created {count} task directories...") - - logger.info(f"Created {count} task directories in {output_dir}") - return count - - -def main(): - parser = argparse.ArgumentParser( - description="Convert training JSONL to Harbor task directories", - ) - parser.add_argument( - "--input", - required=True, - help="Path to training JSONL", - ) - parser.add_argument( - "--output", - required=True, - help="Output directory for Harbor tasks", - ) - parser.add_argument( - "--docker-network", - default=None, - help="External Docker network for containers to join " "(e.g. swe-net)", - ) - args = parser.parse_args() - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(name)s %(levelname)s %(message)s", - ) - convert(args.input, args.output, docker_network=args.docker_network) - - -if __name__ == "__main__": - main() From f1449617168d18ef1c25ebc3190e30ab6d89b13c Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Wed, 15 Apr 2026 11:51:24 -0700 Subject: [PATCH 37/44] [fix] Skip flush_cache in in_place mode and add fully async example (#974) Co-authored-by: Claude Opus 4.6 (1M context) --- .../run_qwen3_30b_a3b_fully_async.py | 176 ++++++++++++++++++ .../update_weight_from_distributed/mixin.py | 3 +- 2 files changed, 178 insertions(+), 1 deletion(-) create mode 100644 examples/fully_async/run_qwen3_30b_a3b_fully_async.py diff --git a/examples/fully_async/run_qwen3_30b_a3b_fully_async.py b/examples/fully_async/run_qwen3_30b_a3b_fully_async.py new file mode 100644 index 0000000000..0a2dbca924 --- /dev/null +++ b/examples/fully_async/run_qwen3_30b_a3b_fully_async.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +# in_place + broadcast +# python run_qwen3_30b_a3b_fully_async.py + +# retract + p2p +# python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode p2p + +# retract + broadcast +# python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode broadcast + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = U.create_run_id() + model_name: str = "Qwen3-30B-A3B" + megatron_model_type: str = "qwen3-30B-A3B" + num_gpus_per_node: int = 8 + data_dir: str = "/root/datasets" + model_dir: str = "/root/models" + megatron_path: str = "/root/Megatron-LM" + pause_generation_mode: Literal["in_place", "retract"] = "in_place" + update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast" + extra_args: str = "" + + +def prepare(args: ScriptArgs): + U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") + U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}") + U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + dir_dst=args.model_dir, + hf_checkpoint=f"{args.model_dir}/{args.model_name}", + megatron_path=args.megatron_path, + ) + + +def execute(args: ScriptArgs): + if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p": + raise ValueError( + "in_place + p2p is not supported: P2P transfer engine conflicts with " + "active NCCL inference. Use broadcast with in_place, or retract with p2p." + ) + + ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist" + load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints" + + ckpt_args = ( + f"--hf-checkpoint {args.model_dir}/{args.model_name}/ " + f"--ref-load {ref_load_path} " + f"--load {load_save_path} " + ) + + rollout_args = ( + "--rollout-function-path fully_async_rollout.generate_rollout_fully_async " + f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type dapo " + "--reward-key score " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + "--global-batch-size 256 " + "--balance-data " + f"--pause-generation-mode {args.pause_generation_mode} " + ) + + perf_args = ( + "--tensor-model-parallel-size 8 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-tis " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_extra = "" + if args.update_weight_transfer_mode == "p2p": + sglang_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine " + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + f"--sglang-mem-fraction-static 0.7 {sglang_extra}" + "--sglang-cuda-graph-max-bs 512 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + f"--attention-backend flash --update-weight-transfer-mode {args.update_weight_transfer_mode} " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {args.num_gpus_per_node} " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " + f"{perf_args} " + f"{sglang_args} " + f"{misc_args} " + f"{args.extra_args} " + ) + + import os + + fully_async_dir = os.path.join(os.path.dirname(os.path.abspath(__file__))) + U.execute_train( + train_args=train_args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + train_script="train_async.py", + megatron_path=args.megatron_path, + extra_env_vars={ + "FLASHINFER_DISABLE_VERSION_CHECK": "1", + "PYTHONPATH": f"{args.megatron_path}:{fully_async_dir}", + }, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py index 006f15516d..6707993822 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py @@ -141,7 +141,8 @@ def _pause_and_prepare_engines(self) -> None: if dist.get_rank() == 0: mode = self.args.pause_generation_mode ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + if mode not in ("in_place"): + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) # int4/fp4 pre_process if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: From c271e14f7916e47299f74d00280aaa7cfec0bdde Mon Sep 17 00:00:00 2001 From: maocheng23 <35615230+maocheng23@users.noreply.github.com> Date: Wed, 15 Apr 2026 19:00:03 -0700 Subject: [PATCH 38/44] GLM47 full cmd for async and sync reasoning (#986) Co-authored-by: Claude Opus 4.6 (1M context) --- .../swe-agent-v2/run-glm47-reasoning-async.py | 354 ++++++++++++++++++ .../swe-agent-v2/run-glm47-reasoning.py | 308 +++++++++++++++ .../update_weight_from_distributed/p2p.py | 6 + 3 files changed, 668 insertions(+) create mode 100644 examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py create mode 100644 examples/experimental/swe-agent-v2/run-glm47-reasoning.py diff --git a/examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py b/examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py new file mode 100644 index 0000000000..104e2ad2c7 --- /dev/null +++ b/examples/experimental/swe-agent-v2/run-glm47-reasoning-async.py @@ -0,0 +1,354 @@ +"""GLM-4.7 Full (355B-A32B) fully-async reasoning training with GSM8K data. + +Disaggregated fully-async variant of run-glm47-reasoning.py: training and +rollout run on separate nodes concurrently. Uses train_async.py and the +fully_async_rollout module so that weight updates do not block generation. + +Default split: 4 nodes training + 12 nodes inference (configurable via +--train-num-nodes). Same model architecture as GLM-4.5-355B-A32B. +Targets 16 x 8-GPU H200 nodes. + +Usage: + python run-glm47-reasoning-async.py --num-nodes 16 + python run-glm47-reasoning-async.py --num-nodes 16 --train-num-nodes 8 + python run-glm47-reasoning-async.py --num-nodes 16 --rollout-fp8 + python run-glm47-reasoning-async.py --num-nodes 16 --pause-generation-mode retract --update-weight-transfer-mode p2p + python run-glm47-reasoning-async.py --num-nodes 16 --skip-prepare +""" + +import os +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +SCRIPT_DIR = Path(__file__).resolve().parent +FULLY_ASYNC_DIR = (Path(__file__).resolve().parent.parent.parent / "fully_async").resolve() + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_rollout_only"] = "normal" + run_id: str = U.create_run_id() + megatron_model_type: str = "glm4.5-355B-A32B" + num_gpus_per_node: int = 8 + megatron_path: str = "/root/Megatron-LM" + + # Paths + skip_prepare: bool = False + model_name: str = "GLM-4.7" + hf_checkpoint: str = "/models/zai-org/GLM-4.7" + ref_load: str = "/models/zai-org/GLM-4.7_torch_dist" + save_dir: str = "/root/GLM-4.7-Full_reasoning_async/" + prompt_data: str = "/root/datasets/gsm8k/train.parquet" + rollout_max_response_len: int = 1024 + + # Rollout precision + rollout_fp8: bool = False + rollout_health_check_first_wait: int = 1800 + + # Disaggregated fully-async settings + train_num_nodes: int = 4 + pause_generation_mode: Literal["in_place", "retract"] = "in_place" + update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast" + accumulate_allreduce_grads_in_fp32: bool = False + max_tokens_per_gpu: int = 2048 + optimizer_cpu_offload: bool = True + use_precision_aware_optimizer: bool = True + + # W&B settings + wandb_key: str = os.environ.get("WANDB_KEY", os.environ.get("WANDB_API_KEY", "")) + wandb_project: str = os.environ.get("WANDB_PROJECT", "glm47-full-reasoning-async") + wandb_team: str = os.environ.get("WANDB_TEAM", "") + wandb_run_name: str = "glm47-full-gsm8k-async" + + # Prometheus settings + use_prometheus: bool = True + prometheus_port: int = 9090 + prometheus_run_name: str = "glm47-full-gsm8k-async" + + +def cleanup(): + """Kill old Ray jobs and stale processes to free GPU resources.""" + my_pid = os.getpid() + ppid = os.getppid() + print(f"Cleanup starting (pid={my_pid}, ppid={ppid})") + targets = ["sglang", "train.py", "train_async.py", "MegatronTrain"] + exclude = f"grep -v '^{my_pid}$' | grep -v '^{ppid}$'" + for t in targets: + subprocess.run( + f"pgrep -f '{t}' | {exclude} | xargs -r kill 2>/dev/null || true", + shell=True, + ) + time.sleep(5) + print(f"Cleanup complete (pid={my_pid}) — old processes killed.") + + +def _convert_hf_to_fp8(args: ScriptArgs): + """Convert HF bf16 checkpoint to block-wise FP8 for SGLang rollout.""" + fp8_dir = f"{args.hf_checkpoint}-FP8" + if Path(fp8_dir).exists(): + print(f"FP8 checkpoint already exists at {fp8_dir}, skipping conversion.") + return + U.exec_command( + "python tools/convert_hf_to_fp8.py " + f"--model-dir {args.hf_checkpoint} " + f"--save-dir {fp8_dir} " + "--strategy block --block-size 128 128 " + "--max-workers 4" + ) + + +def prepare(args: ScriptArgs): + """Download GSM8K data and convert HF checkpoint to torch_dist format.""" + U.hf_download_dataset("zhuzilin/gsm8k") + + max_convert_nodes = 92 // args.num_gpus_per_node + convert_nodes = min(args.num_nodes, max_convert_nodes) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + multinode=True, + num_nodes=convert_nodes, + dir_dst=str(Path(args.ref_load).parent), + hf_checkpoint=args.hf_checkpoint, + megatron_path=args.megatron_path, + ) + + if args.rollout_fp8: + _convert_hf_to_fp8(args) + + +def execute(args: ScriptArgs): + if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p": + raise ValueError( + "in_place + p2p is not supported: P2P transfer engine conflicts with " + "active NCCL inference. Use broadcast with in_place, or retract with p2p." + ) + + hf_checkpoint = f"{args.hf_checkpoint}-FP8" if args.rollout_fp8 else args.hf_checkpoint + ckpt_args = ( + f"--hf-checkpoint {hf_checkpoint} " + f"--ref-load {args.ref_load} " + f"--save {args.save_dir} " + "--save-interval 100 " + ) + + rollout_args = ( + "--rollout-function-path fully_async_rollout.generate_rollout_fully_async " + f"--prompt-data {args.prompt_data} " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 4 " + "--rollout-temperature 0.8 " + f"--rollout-max-response-len {args.rollout_max_response_len} " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 64 " + "--balance-data " + f"--pause-generation-mode {args.pause_generation_mode} " + ) + + eval_args = ( + # "--eval-interval 20 " + # "--skip-eval-before-train " + # "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + # "--n-samples-per-eval-prompt 1 " + # "--eval-max-response-len 1024 " + # "--eval-top-k 1 " + ) + + # Disaggregated split: training on train_num_nodes, inference on the rest. + rollout_num_nodes = args.num_nodes - args.train_num_nodes + assert rollout_num_nodes > 0, ( + f"train_num_nodes ({args.train_num_nodes}) must be less than " + f"num_nodes ({args.num_nodes}) to leave room for inference" + ) + train_gpus = args.train_num_nodes * args.num_gpus_per_node + rollout_gpus = rollout_num_nodes * args.num_gpus_per_node + print( + f"Disagg split: {args.train_num_nodes} nodes ({train_gpus} GPUs) training, " + f"{rollout_num_nodes} nodes ({rollout_gpus} GPUs) inference" + ) + + # Training parallelism: TP=4, PP=2, EP chosen as largest divisor of 160 that fits. + tp, pp = 4, 2 + dp = train_gpus // (tp * pp) + assert train_gpus % (tp * pp) == 0, f"train GPUs ({train_gpus}) must be divisible by TP*PP ({tp * pp})" + num_experts = 160 + ep = max(d for d in range(1, dp + 1) if num_experts % d == 0 and dp % d == 0) + + perf_args = ( + f"--tensor-model-parallel-size {tp} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {pp} " + "--context-parallel-size 1 " + f"--expert-model-parallel-size {ep} " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {args.max_tokens_per_gpu} " + ) + if args.optimizer_cpu_offload: + perf_args += "--optimizer-cpu-offload --overlap-cpu-optimizer-d2h-h2d " + if args.use_precision_aware_optimizer: + perf_args += "--use-precision-aware-optimizer " + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.01 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.0 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + # SGLang: 4 nodes/engine with full EP + DP-attention on dedicated rollout nodes. + # 355B across 32 GPUs → ~22GB/GPU (bf16) or ~11GB/GPU (FP8) for weights. + # EP=32 with 160 experts → 5 experts/GPU. DP-attention keeps attention + # within a single node (attn_tp=8). + sglang_nodes_per_engine = min(4, rollout_num_nodes) + sglang_world_size = sglang_nodes_per_engine * args.num_gpus_per_node + num_engines = rollout_num_nodes // sglang_nodes_per_engine + assert rollout_num_nodes % sglang_nodes_per_engine == 0, ( + f"rollout nodes ({rollout_num_nodes}) must be divisible by " + f"sglang_nodes_per_engine ({sglang_nodes_per_engine})" + ) + print(f"Inference: {num_engines} engines x {sglang_world_size} GPUs/engine") + sglang_decode_max_bs = 256 + sglang_attn_tp_size = min(args.num_gpus_per_node, sglang_world_size) + sglang_attn_dp_size = sglang_world_size // sglang_attn_tp_size + + sglang_p2p_extra = "" + if args.update_weight_transfer_mode == "p2p": + sglang_p2p_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine " + + sglang_args = ( + f"--rollout-num-gpus-per-engine {sglang_world_size} " + "--sglang-mem-fraction-static 0.80 " + f"--sglang-tp-size {sglang_world_size} " + f"--sglang-ep-size {sglang_world_size} " + "--sglang-enable-dp-attention " + f"--sglang-dp-size {sglang_attn_dp_size} " + "--sglang-moe-dense-tp-size 1 " + "--sglang-enable-dp-lm-head " + "--sglang-moe-a2a-backend deepep " + "--sglang-deepep-mode low_latency " + f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " + f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " + f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " + f"{sglang_p2p_extra}" + ) + if args.rollout_fp8: + sglang_args += "--sglang-moe-runner-backend deep_gemm " + sglang_extra_env_vars = { + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": f"{sglang_decode_max_bs}", + } + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + f"--update-weight-transfer-mode {args.update_weight_transfer_mode} " + f"--update-weight-buffer-size {2 * 1024 ** 3} " + f"--actor-num-nodes {args.train_num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {rollout_gpus} " + "--grad-reduce-in-bf16 " + "--use-fault-tolerance " + f"--rollout-health-check-first-wait {args.rollout_health_check_first_wait} " + ) + if args.accumulate_allreduce_grads_in_fp32: + misc_args += "--accumulate-allreduce-grads-in-fp32 " + + debug_args = "--debug-rollout-only " if args.mode == "debug_rollout_only" else "" + + wandb_args = "" + if args.wandb_key: + wandb_args = ( + "--use-wandb " + f"--wandb-project {args.wandb_project} " + f"--wandb-group {args.wandb_run_name} " + f"--wandb-key {args.wandb_key} " + ) + if args.wandb_team: + wandb_args += f"--wandb-team {args.wandb_team} " + + prometheus_args = "" + if args.use_prometheus: + prometheus_args = ( + "--use-prometheus " + f"--prometheus-port {args.prometheus_port} " + f"--prometheus-run-name {args.prometheus_run_name} " + ) + + train_args = ( + f"{ckpt_args}" + f"{rollout_args}" + f"{eval_args}" + f"{optimizer_args}" + f"{grpo_args}" + f"{wandb_args}" + f"{prometheus_args}" + f"{perf_args}" + f"{sglang_args}" + f"{misc_args}" + f"{debug_args}" + ) + + miles_root = U.repo_base_dir + + extra_env_vars = { + "PYTHONPATH": f"{args.megatron_path}:{SCRIPT_DIR}:{FULLY_ASYNC_DIR}:{miles_root}", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "NCCL_NVLS_ENABLE": "0", + "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "false", + **sglang_extra_env_vars, + } + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + train_script="train_async.py", + megatron_path=args.megatron_path, + extra_env_vars=extra_env_vars, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + cleanup() + if not args.skip_prepare: + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/experimental/swe-agent-v2/run-glm47-reasoning.py b/examples/experimental/swe-agent-v2/run-glm47-reasoning.py new file mode 100644 index 0000000000..e133f53570 --- /dev/null +++ b/examples/experimental/swe-agent-v2/run-glm47-reasoning.py @@ -0,0 +1,308 @@ +"""GLM-4.7 Full (355B-A32B) reasoning training with GSM8K data. + +Debug script: uses math (GSM8K) data instead of agentic tool use to verify +that the training pipeline produces nonzero rewards and learns successfully. + +Same model architecture and parallelism as run-glm47-full.py. +Targets 16 x 8-GPU H200 nodes (sci-h200). + +Usage: + python run-glm47-reasoning.py --num-nodes 16 + python run-glm47-reasoning.py --num-nodes 16 --rollout-fp8 + python run-glm47-reasoning.py --num-nodes 16 --skip-prepare + python run-glm47-reasoning.py --num-nodes 16 --mode debug_rollout_only +""" + +import os +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +SCRIPT_DIR = Path(__file__).resolve().parent + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_rollout_only"] = "normal" + run_id: str = U.create_run_id() + megatron_model_type: str = "glm4.5-355B-A32B" + num_gpus_per_node: int = 8 + megatron_path: str = "/root/Megatron-LM" + + # Paths + skip_prepare: bool = False + model_name: str = "GLM-4.7" + hf_checkpoint: str = "/models/zai-org/GLM-4.7" + ref_load: str = "/models/zai-org/GLM-4.7_torch_dist" + save_dir: str = "/root/GLM-4.7-Full_reasoning/" + prompt_data: str = "/root/datasets/gsm8k/train.parquet" + rollout_max_response_len: int = 1024 + + # Rollout precision + rollout_fp8: bool = False + + # W&B settings + wandb_key: str = os.environ.get("WANDB_KEY", os.environ.get("WANDB_API_KEY", "")) + wandb_project: str = os.environ.get("WANDB_PROJECT", "glm47-full-reasoning") + wandb_team: str = os.environ.get("WANDB_TEAM", "") + wandb_run_name: str = "glm47-full-gsm8k" + + # Prometheus settings + use_prometheus: bool = True + prometheus_port: int = 9090 + prometheus_run_name: str = "glm47-full-gsm8k" + + +def cleanup(): + """Kill old Ray jobs and stale processes to free GPU resources.""" + my_pid = os.getpid() + ppid = os.getppid() + print(f"Cleanup starting (pid={my_pid}, ppid={ppid})") + targets = ["sglang", "train.py", "MegatronTrain"] + exclude = f"grep -v '^{my_pid}$' | grep -v '^{ppid}$'" + for t in targets: + subprocess.run( + f"pgrep -f '{t}' | {exclude} | xargs -r kill 2>/dev/null || true", + shell=True, + ) + time.sleep(5) + print(f"Cleanup complete (pid={my_pid}) — old processes killed.") + + +def _convert_hf_to_fp8(args: ScriptArgs): + """Convert HF bf16 checkpoint to block-wise FP8 for SGLang rollout.""" + fp8_dir = f"{args.hf_checkpoint}-FP8" + if Path(fp8_dir).exists(): + print(f"FP8 checkpoint already exists at {fp8_dir}, skipping conversion.") + return + U.exec_command( + "python tools/convert_hf_to_fp8.py " + f"--model-dir {args.hf_checkpoint} " + f"--save-dir {fp8_dir} " + "--strategy block --block-size 128 128 " + "--max-workers 4" + ) + + +def prepare(args: ScriptArgs): + """Download GSM8K data and convert HF checkpoint to torch_dist format.""" + # Download GSM8K dataset + U.hf_download_dataset("zhuzilin/gsm8k") + + # Convert checkpoint (multinode for 355B) + # The conversion tool requires world_size <= num_layers (92 for this model). + max_convert_nodes = 92 // args.num_gpus_per_node # 11 for 8 GPUs/node + convert_nodes = min(args.num_nodes, max_convert_nodes) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + multinode=True, + num_nodes=convert_nodes, + dir_dst=str(Path(args.ref_load).parent), + hf_checkpoint=args.hf_checkpoint, + megatron_path=args.megatron_path, + ) + + if args.rollout_fp8: + _convert_hf_to_fp8(args) + + +def execute(args: ScriptArgs): + hf_checkpoint = f"{args.hf_checkpoint}-FP8" if args.rollout_fp8 else args.hf_checkpoint + ckpt_args = ( + f"--hf-checkpoint {hf_checkpoint} " + f"--ref-load {args.ref_load} " + f"--save {args.save_dir} " + "--save-interval 100 " + ) + + rollout_args = ( + f"--prompt-data {args.prompt_data} " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 4 " + "--rollout-temperature 0.8 " + f"--rollout-max-response-len {args.rollout_max_response_len} " + "--over-sampling-batch-size 64 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 64 " + ) + + eval_args = ( + "--eval-interval 20 " + "--skip-eval-before-train " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + # Training parallelism: TP=4, PP=2, EP chosen as largest divisor of 160 that fits. + tp, pp = 4, 2 + total_gpus = args.num_nodes * args.num_gpus_per_node + dp = total_gpus // (tp * pp) + assert total_gpus % (tp * pp) == 0, f"total GPUs ({total_gpus}) must be divisible by TP*PP ({tp * pp})" + num_experts = 160 + ep = max(d for d in range(1, dp + 1) if num_experts % d == 0) + + perf_args = ( + f"--tensor-model-parallel-size {tp} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {pp} " + "--context-parallel-size 1 " + f"--expert-model-parallel-size {ep} " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 2048 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.01 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.0 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + # SGLang: 4 nodes/engine with full EP + DP-attention. + # 355B across 32 GPUs → ~22GB/GPU (bf16) or ~11GB/GPU (FP8) for weights, + # leaving plenty for KV cache. EP=32 with 160 experts → 5 experts/GPU. + # DP-attention keeps attention within a single node (attn_tp=8). + sglang_nodes_per_engine = min(4, args.num_nodes) + sglang_world_size = sglang_nodes_per_engine * args.num_gpus_per_node + assert ( + total_gpus % sglang_world_size == 0 + ), f"total GPUs ({total_gpus}) must be divisible by sglang_world_size ({sglang_world_size})" + sglang_decode_max_bs = 256 + sglang_attn_tp_size = min(args.num_gpus_per_node, sglang_world_size) + sglang_attn_dp_size = sglang_world_size // sglang_attn_tp_size + sglang_args = ( + f"--rollout-num-gpus-per-engine {sglang_world_size} " + "--sglang-mem-fraction-static 0.80 " + f"--sglang-tp-size {sglang_world_size} " + f"--sglang-ep-size {sglang_world_size} " + "--sglang-enable-dp-attention " + f"--sglang-dp-size {sglang_attn_dp_size} " + "--sglang-moe-dense-tp-size 1 " + "--sglang-enable-dp-lm-head " + "--sglang-moe-a2a-backend deepep " + "--sglang-deepep-mode low_latency " + f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " + f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " + f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " + ) + if args.rollout_fp8: + sglang_args += "--sglang-moe-runner-backend deep_gemm " + sglang_extra_env_vars = { + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": f"{sglang_decode_max_bs}", + } + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--colocate " + f"--update-weight-buffer-size {2 * 1024 ** 3} " + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {total_gpus} " + "--use-fault-tolerance " + ) + + debug_args = "--debug-rollout-only " if args.mode == "debug_rollout_only" else "" + + wandb_args = "" + if args.wandb_key: + wandb_args = ( + "--use-wandb " + f"--wandb-project {args.wandb_project} " + f"--wandb-group {args.wandb_run_name} " + f"--wandb-key {args.wandb_key} " + ) + if args.wandb_team: + wandb_args += f"--wandb-team {args.wandb_team} " + + prometheus_args = "" + if args.use_prometheus: + prometheus_args = ( + "--use-prometheus " + f"--prometheus-port {args.prometheus_port} " + f"--prometheus-run-name {args.prometheus_run_name} " + ) + + train_args = ( + f"{ckpt_args}" + f"{rollout_args}" + f"{eval_args}" + f"{optimizer_args}" + f"{grpo_args}" + f"{wandb_args}" + f"{prometheus_args}" + f"{perf_args}" + f"{sglang_args}" + f"{misc_args}" + f"{debug_args}" + ) + + miles_root = U.repo_base_dir + + extra_env_vars = { + "PYTHONPATH": f"{args.megatron_path}:{SCRIPT_DIR}:{miles_root}", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "NCCL_NVLS_ENABLE": "0", + "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "false", + **sglang_extra_env_vars, + } + + U.execute_train( + train_args=train_args, + config=args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + megatron_path=args.megatron_path, + extra_env_vars=extra_env_vars, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + cleanup() + if not args.skip_prepare: + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py index fe287f72c4..ba6cebde7e 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/p2p.py @@ -11,6 +11,9 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed.parallel_state import ParallelismContext, RankParallelismConfig +from sglang.srt.layers.moe import initialize_moe_config +from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config +from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config from sglang.srt.model_loader import get_model from sglang.srt.model_loader.parameter_mapper import ParameterMapper from sglang.srt.server_args import ServerArgs @@ -259,6 +262,9 @@ def create_cpu_replica( rl_quant_profile=server_args.rl_quant_profile, ) server_args_module._global_server_args = server_args + initialize_moe_config(server_args) + initialize_fp8_gemm_config(server_args) + initialize_fp4_gemm_config(server_args) with ParallelismContext(parallelism_config): model = get_model( model_config=ModelConfig(model_path), From 7b7efa922b47f3cb2b3d1e1b6c549ae6d7aa640e Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Thu, 16 Apr 2026 07:13:14 -0700 Subject: [PATCH 39/44] fix(rollout): guard round(None) in zero-std metric aggregation _compute_zero_std_metrics crashes with TypeError when any zero-std group's leading sample has a None reward (typical for Status.ABORTED trials): File "miles/ray/rollout.py", line 1266, in _compute_zero_std_metrics interesting_rewards = [str(round(g[0].get_reward_value(args), 1)) ...] TypeError: type NoneType doesn't define __round__ method This crash fires on RolloutManager.generate() inside _log_rollout_data, after the rollout collection + dynamic sampling filter have already accepted the batch. With agentic tasks where some trials routinely abort (Daytona sandbox timeout, tool-invocation loops, etc.), the trainer never receives the batch and optimizer.step() never fires, so async RL training silently stalls. Fix: extract a _reward_label helper that buckets None-reward samples under a dedicated 'none' label instead of passing None to round(). This keeps the metric informative (zero_std/count_none shows the aborted-group count) and preserves the existing behavior for numeric rewards. Observed on LLM360/RL360 #76 FAST_ITER smoke runs (job 1559799) with GLM-4.7-Flash on agentic terminal-bench tasks. --- miles/ray/rollout.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 2a75d492b9..af11accfc3 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1260,10 +1260,18 @@ def _is_zero_std(samples: list[Sample]): rewards = [sample.get_reward_value(args) for sample in samples] return len(rewards) == 0 or all(rewards[0] == r for r in rewards) + def _reward_label(sample: Sample) -> str: + # Aborted / None-reward samples have no numeric reward to round; bucket + # them under a dedicated label so downstream round() never sees None. + reward = sample.get_reward_value(args) + if reward is None: + return "none" + return str(round(reward, 1)) + all_sample_groups = group_by(all_samples, lambda s: s.group_index) interesting_sample_groups = [g for g in all_sample_groups.values() if _is_zero_std(g)] - interesting_rewards = [str(round(g[0].get_reward_value(args), 1)) for g in interesting_sample_groups] + interesting_rewards = [_reward_label(g[0]) for g in interesting_sample_groups] return {f"zero_std/count_{reward}": len(items) for reward, items in group_by(interesting_rewards).items()} From f0c9d3cc1f9f9a9ed98723e9462f8e1a3465a428 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Thu, 16 Apr 2026 07:15:09 -0700 Subject: [PATCH 40/44] feat(sglang_engine): allow PD worker_type on /add_worker registration path The old sglang_router (<=0.2.1) and the miles-router both use the single-arg /add_worker?url=... endpoint for engine registration. Previously, the Miles engine asserted worker_type=='regular' before hitting that endpoint, so any attempt to stand up prefill/decode workers via the miles-router path (including the sgl-model-gateway that mirrors it) fail-fasts at engine init: AssertionError: pd disaggregation is not supported in old router or miles router. This blocks PD disagg throughput scaling in any deployment that uses the miles-router path, even when the receiving router (e.g. sgl-model-gateway with a PD-aware shim) can handle worker_type on /add_worker. Relax the assertion: forward worker_type (and bootstrap_port for prefill) as extra query params. Routers that honor them get PD registration; routers that only accept the single-arg form ignore the extras and register as regular, with a warning logged so the fallback is visible. The companion server-side change is on the receiving router: - sgl-model-gateway must accept ?worker_type=&bootstrap_port= on /add_worker - Or deployments can use the newer /workers endpoint (non-miles path). Context: LLM360/RL360 #76. Track G (job 1559336) showed full PD KV transfer via mooncake works with SGLang's own mini_lb; this unblocks the same flow through Miles-driven rollouts. --- miles/backends/sglang_utils/sglang_engine.py | 32 +++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 8b567a744d..6650930a20 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -216,12 +216,34 @@ def _init_normal(self, server_args_dict): if self.node_rank == 0 and self.router_ip and self.router_port: if parse(sglang_router.__version__) <= parse("0.2.1") or self.args.use_miles_router: - assert ( - self.worker_type == "regular" - ), "pd disaggregation is not supported in old router or miles router." - response = requests.post( - f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}" + # Old sglang_router (<=0.2.1) and miles-router use the single-arg + # /add_worker?url=... endpoint. For PD disaggregation, forward + # worker_type (and bootstrap_port for prefill) as extra query + # params so a router that supports PD via /add_worker can act on + # them. Routers that only understand the regular form will see + # the extra params, ignore them, and register the worker as + # regular -- so PD routing through such a router still needs a + # server-side update. This at least removes the unconditional + # assert that would fail-fast before the request is ever sent. + add_worker_url = ( + f"http://{self.router_ip}:{self.router_port}/add_worker" + f"?url=http://{self.server_host}:{self.server_port}" ) + if self.worker_type != "regular": + add_worker_url += f"&worker_type={self.worker_type}" + if self.worker_type == "prefill": + bootstrap_port = server_args_dict.get("disaggregation_bootstrap_port") + if bootstrap_port is not None: + add_worker_url += f"&bootstrap_port={bootstrap_port}" + logger.warning( + "Registering a '%s' worker via /add_worker on the " + "old-style router path. PD disaggregation requires the " + "router to honor worker_type on this endpoint; if it " + "only accepts the single-arg form, workers will be " + "treated as regular and PD routing will not function.", + self.worker_type, + ) + response = requests.post(add_worker_url) else: payload = { "url": f"http://{self.server_host}:{self.server_port}", From 779839cb2569f411a373dac9a229b469aa0aa991 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Thu, 16 Apr 2026 09:53:57 -0700 Subject: [PATCH 41/44] fix(rollout): propagate PYTHONPATH to Ray remote actors SGLangEngine Ray actors created via RolloutRayActor.options(runtime_env= {'env_vars': env_vars}) only inherit the env vars explicitly listed in env_vars. PYTHONPATH is not included, so the actor process uses only the container's default PYTHONPATH. In deployments that install miles as a pip package in the container image (e.g. the radixark/miles:dev overlay), 'import miles' in the actor imports from /root/miles (site-packages pointer), silently bypassing any MILES_OVERRIDE prepended to PYTHONPATH on the driver. This means local-patched miles code on the driver's PYTHONPATH is not executed in actors, so per-cluster patches (e.g. kept in a shared clone under SHARED_DIR/miles) never reach SGLangEngine. Observed on LLM360/RL360#76: the driver had the patched miles from LLM360/miles:deploy visible via PYTHONPATH, but SGLangEngine actors still hit the assert from an earlier revision because they imported from the container's /root/miles. Fix: copy os.environ['PYTHONPATH'] into env_vars when it is set. No-op when the driver doesn't have PYTHONPATH exported (container defaults apply; same behavior as before). --- miles/ray/rollout.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 2a75d492b9..1c78111d5d 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -139,6 +139,13 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis }.items() } env_vars.update(dumper_utils.get_sglang_env(self.args)) + # Propagate PYTHONPATH so Ray remote actors import the same miles + # / sglang modules as the driver. Without this, SGLangEngine + # actors inherit only the container's default PYTHONPATH and + # `import miles` falls back to a pip-installed /root/miles, + # silently bypassing any MILES_OVERRIDE on the driver's PYTHONPATH. + if "PYTHONPATH" in os.environ: + env_vars["PYTHONPATH"] = os.environ["PYTHONPATH"] rollout_engine = RolloutRayActor.options( num_cpus=num_cpus, From 9a0ef97613294d854b60e98e2202f58a7734936f Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Fri, 17 Apr 2026 12:07:39 -0700 Subject: [PATCH 42/44] fix(session-server): strip stale hop-by-hop headers when re-emitting proxy response build_proxy_response forwards upstream gateway headers into either JSONResponse or Response. Both re-serialize / re-frame the body: * JSONResponse runs json.dumps over the parsed content, so whitespace and unicode escape behavior may produce a different byte count than the upstream did. * Response may be re-framed by Starlette with chunked transfer encoding. Forwarding the upstream content-length, transfer-encoding, or content-encoding in these cases causes a mismatch between declared framing and the bytes Starlette actually writes. Clients (e.g. Miles's own http_utils.post) then error with h11 LocalProtocolError 'Too much data for declared Content-Length' or 'peer closed connection without sending complete message body' and retry. Observed: on a mock-agent FAST_ITER run with PD disaggregation through a gateway that serializes merged prefill+decode logprobs, ~200 of 332 chat completions hit this error before mock retries salvaged training progress. Strip the three hop-by-hop headers before building the outgoing Response; Starlette / hyper then recompute the correct framing from the actual body. --- miles/rollout/session/session_server.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/miles/rollout/session/session_server.py b/miles/rollout/session/session_server.py index bc2633350e..fd0d5698e0 100644 --- a/miles/rollout/session/session_server.py +++ b/miles/rollout/session/session_server.py @@ -81,7 +81,20 @@ async def do_proxy( def build_proxy_response(self, result: dict) -> Response: content = result["response_body"] status_code = result["status_code"] - headers = result["headers"] + # Strip hop-by-hop headers that become stale when we re-emit the body. + # JSONResponse re-serializes via json.dumps (different whitespace/unicode + # escape behavior than the upstream may have used) and Response may be + # re-framed with chunked encoding. Forwarding the upstream content-length, + # transfer-encoding, or content-encoding produces a mismatch between the + # declared framing and the bytes Starlette actually writes, surfacing to + # clients as h11 LocalProtocolError ("Too much data for declared + # Content-Length") or "peer closed connection without sending complete + # message body". Starlette/hyper will recompute the correct values from + # the actual body when we omit these headers. + headers = { + k: v for k, v in result["headers"].items() + if k.lower() not in ("content-length", "transfer-encoding", "content-encoding") + } content_type = headers.get("content-type", "") try: data = json.loads(content) From af3de7036ea9f730e0a01d77467e8681be4d6b6e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 17 Apr 2026 23:13:33 +0000 Subject: [PATCH 43/44] state:c271e14f7916e47299f74d00280aaa7cfec0bdde|fix/allow-pd-worker-type-on-miles-router=f0c9d3cc1f9f9a9ed98723e9462f8e1a3465a428,fix/guard-round-none-in-zero-std-metrics=7b7efa922b47f3cb2b3d1e1b6c549ae6d7aa640e,fix/propagate-pythonpath-to-ray-remote-actors=779839cb2569f411a373dac9a229b469aa0aa991,fix/rollback-error-recovery=dd188aaddb887e5c1de066277e7c9f10eec297b0,fix/session-auto-create=29a0dcad45e36985c7af6f7c1dcf653e0eace4f2,fix/session-server-strip-stale-content-length-clean=9a0ef97613294d854b60e98e2202f58a7734936f,fix/truncate-routed-experts=25645357a7d51daa307a9fa25081095bc3cfb2a1 From c15c70487c93a69248f327550399f14890ecf13c Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Sat, 18 Apr 2026 12:38:23 -0700 Subject: [PATCH 44/44] arguments: allow 'assistant' in --tito-allowed-append-roles choices Multi-turn agent harnesses such as Harbor's terminus-2 append their own planning or self-reflection assistant messages to the conversation before the next tool/user turn. TITO's session-server validates the appended role against this allowlist; without 'assistant' in the choices, the agent's 400 surfaces as: litellm.BadRequestError: OpenAIException - Error code: 400 - {'error': "appended message at index N has role='assistant', allowed=[...]; to allow more roles use --tito-allowed-append-roles"} and the user cannot fix it via CLI because argparse rejects 'assistant' with 'invalid choice'. This widens the allowlist; the default remains ['tool'], so no change for existing callers. --- miles/utils/arguments.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 375cd6c2c2..d9650996f2 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1617,9 +1617,12 @@ def add_session_arguments(parser): "--tito-allowed-append-roles", nargs="+", default=["tool"], - choices=["tool", "user", "system"], + choices=["tool", "user", "system", "assistant"], help="Message roles allowed to be appended after the pretokenized " - "assistant prefix in TITO sessions (default: tool).", + "assistant prefix in TITO sessions (default: tool). Include " + "'assistant' for multi-turn agents (e.g. terminus-2) that " + "append their own planning/self-reflection turns before the " + "next tool or user message.", ) return parser