diff --git a/.github/workflows/maintain-deploy.yml b/.github/workflows/maintain-deploy.yml new file mode 100644 index 000000000000..6bf55d92322b --- /dev/null +++ b/.github/workflows/maintain-deploy.yml @@ -0,0 +1,131 @@ +name: Maintain deploy branch + +on: + schedule: + - cron: '*/15 * * * *' + workflow_dispatch: + inputs: + force: + description: 'Rebuild even if no changes detected' + type: boolean + default: false + +env: + BASE_BRANCH: llm360-main + +jobs: + check: + runs-on: ubuntu-latest + outputs: + skip: ${{ steps.fingerprint.outputs.skip }} + state: ${{ steps.fingerprint.outputs.state }} + branches: ${{ steps.fingerprint.outputs.branches }} + steps: + - name: Compute desired state and compare + id: fingerprint + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + FORCE: ${{ inputs.force }} + REPO: ${{ github.repository }} + run: | + BASE_SHA=$(gh api "repos/${REPO}/commits/${BASE_BRANCH}" --jq '.sha' 2>/dev/null || echo "unknown") + + PR_STATE=$(gh api "repos/${REPO}/pulls?state=open&base=${BASE_BRANCH}&per_page=100" \ + --jq "sort_by(.head.ref) | map(.head.ref + \"=\" + .head.sha) | join(\",\")") + + BRANCHES=$(gh api "repos/${REPO}/pulls?state=open&base=${BASE_BRANCH}&per_page=100" \ + --jq '.[].head.ref' | tr '\n' ' ') + + STATE="${BASE_SHA}|${PR_STATE}" + echo "state=$STATE" >> "$GITHUB_OUTPUT" + echo "branches=$BRANCHES" >> "$GITHUB_OUTPUT" + echo "Desired state: $STATE" + + CURRENT=$(gh api "repos/${REPO}/commits/deploy" \ + --jq '.commit.message' 2>/dev/null \ + | grep '^state:' | head -1 | cut -d: -f2- || echo "") + + echo "Current state: $CURRENT" + + if [ "$CURRENT" = "$STATE" ] && [ "$FORCE" != "true" ]; then + echo "skip=true" >> "$GITHUB_OUTPUT" + echo "::notice::No changes detected, skipping rebuild" + else + echo "skip=false" >> "$GITHUB_OUTPUT" + fi + + rebuild: + needs: check + if: needs.check.outputs.skip != 'true' + runs-on: ubuntu-latest + permissions: + contents: write + issues: write + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Build deploy branch + env: + PR_BRANCHES: ${{ needs.check.outputs.branches }} + run: | + git checkout -B deploy "origin/${BASE_BRANCH}" + + MERGED="" + FAILED="" + for branch in $PR_BRANCHES; do + echo "Merging $branch..." + if git merge "origin/$branch" --no-edit -m "Deploy: merge $branch"; then + MERGED="$MERGED $branch" + else + echo "::error::Merge conflict on $branch, skipping" + git merge --abort + FAILED="$FAILED $branch" + fi + done + + echo "Successfully merged:${MERGED:-}" + if [ -n "$FAILED" ]; then + echo "::warning::Failed to merge (conflicts):$FAILED" + fi + echo "FAILED_BRANCHES=$FAILED" >> "$GITHUB_ENV" + + - name: Stamp fingerprint and push + env: + STATE: ${{ needs.check.outputs.state }} + run: | + git commit --allow-empty -m "state:${STATE}" + git push origin deploy --force + + - name: Report merge conflicts + if: always() && needs.check.outputs.skip != 'true' + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + ISSUE_TITLE="Deploy: merge conflict" + EXISTING=$(gh issue list --label deploy-conflict --state open --json number --jq '.[0].number' 2>/dev/null || echo "") + + if [ -z "$FAILED_BRANCHES" ]; then + if [ -n "$EXISTING" ]; then + gh issue close "$EXISTING" --comment "Resolved: all branches merged cleanly." 
+ fi + exit 0 + fi + + BODY=$(printf "The following branches failed to merge into deploy:\n\n") + for b in $FAILED_BRANCHES; do + BODY=$(printf "%s\n- \`%s\`" "$BODY" "$b") + done + BODY=$(printf "%s\n\nThis issue auto-closes when the next build merges all branches cleanly." "$BODY") + + if [ -n "$EXISTING" ]; then + gh issue edit "$EXISTING" --body "$BODY" + else + gh issue create --title "$ISSUE_TITLE" --body "$BODY" --label deploy-conflict + fi diff --git a/.github/workflows/nightly-test-nvidia.yml b/.github/workflows/nightly-test-nvidia.yml index 0373ea2bce57..f3b33b8cbc9f 100644 --- a/.github/workflows/nightly-test-nvidia.yml +++ b/.github/workflows/nightly-test-nvidia.yml @@ -683,6 +683,7 @@ jobs: if: always() env: GH_PAT_FOR_NIGHTLY_CI_DATA: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} + GH_TOKEN: ${{ github.token }} run: | python3 scripts/ci/utils/diffusion/generate_diffusion_dashboard.py \ --results comparison-results.json \ diff --git a/python/sglang/cli/utils.py b/python/sglang/cli/utils.py index be8981b132e1..612f01d8c3f2 100644 --- a/python/sglang/cli/utils.py +++ b/python/sglang/cli/utils.py @@ -78,6 +78,9 @@ def get_is_diffusion_model(model_path: str) -> bool: if is_known_non_diffusers_diffusion_model(model_path): return True + if _is_registered_diffusion_model(model_path): + return True + try: if envs.SGLANG_USE_MODELSCOPE.get(): from modelscope import model_file_download diff --git a/python/sglang/multimodal_gen/runtime/launch_server.py b/python/sglang/multimodal_gen/runtime/launch_server.py index 318ed604215b..5b60c844ae5e 100644 --- a/python/sglang/multimodal_gen/runtime/launch_server.py +++ b/python/sglang/multimodal_gen/runtime/launch_server.py @@ -88,7 +88,7 @@ def launch_server(server_args: ServerArgs, launch_http_server: bool = True): result_pipes_from_slaves_w.append(w) # Launch all worker processes - master_port = server_args.master_port or (server_args.master_port + 100) + master_port = server_args.master_port scheduler_pipe_readers = [] scheduler_pipe_writers = [] diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index f8578a495713..0d0c82cfb303 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -188,8 +188,7 @@ class ServerArgs: ) # Master port for distributed inference - # TODO: do not hard code - master_port: int | None = None + master_port: int = 30005 # http server endpoint config host: str | None = "127.0.0.1" @@ -386,36 +385,27 @@ def _adjust_warmup(self): "Warmup enabled, the launch time is expected to be longer than usual" ) + @staticmethod + def _require_port(port: int, name: str) -> None: + """Raise if *port* is occupied (used under ``--strict-ports``).""" + if not is_port_available(port): + raise RuntimeError( + f"{name} port {port} is unavailable and --strict-ports is enabled. " + f"Either use a different port or disable --strict-ports." + ) + def _adjust_network_ports(self): if self.strict_ports: - # Strict mode: fail if port is unavailable - if not is_port_available(self.port): - raise RuntimeError( - f"Port {self.port} is unavailable and --strict-ports is enabled. " - f"Either use a different port or remove --strict-ports to allow auto-selection." - ) - if not is_port_available(self.scheduler_port): - raise RuntimeError( - f"Scheduler port {self.scheduler_port} is unavailable and --strict-ports is enabled. " - f"Either use a different port or remove --strict-ports to allow auto-selection." 
- ) - if self.master_port is not None and not is_port_available(self.master_port): - raise RuntimeError( - f"Master port {self.master_port} is unavailable and --strict-ports is enabled. " - f"Either use a different port or remove --strict-ports to allow auto-selection." - ) + self._require_port(self.port, "HTTP") + self._require_port(self.scheduler_port, "Scheduler") + self._require_port(self.master_port, "Master") else: self.port = self.settle_port(self.port) initial_scheduler_port = self.scheduler_port + ( random.randint(0, 100) if self.scheduler_port == 5555 else 0 ) self.scheduler_port = self.settle_port(initial_scheduler_port) - initial_master_port = ( - self.master_port - if self.master_port is not None - else (30005 + random.randint(0, 100)) - ) - self.master_port = self.settle_port(initial_master_port, 37) + self.master_port = self.settle_port(self.master_port, 37) def _adjust_parallelism(self): if self.tp_size is None: diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index 4178a779754e..700d4d6b8b18 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -28,17 +28,21 @@ "Qwen-Image", ) + +def _discover_unit_tests() -> list[str]: + """Auto-discover all test_*.py files in the unit/ directory.""" + unit_dir = Path(__file__).resolve().parent / "unit" + if not unit_dir.is_dir(): + return [] + return sorted( + f"../unit/{f.name}" for f in unit_dir.glob("test_*.py") if f.is_file() + ) + + SUITES = { # no GPU required; safe to run on any CPU-only runner - "unit": [ - "../unit/test_sampling_params.py", - "../unit/test_storage.py", - "../unit/test_lora_format_adapter.py", - "../unit/test_server_args.py", - "../unit/test_input_validation.py", - "../unit/test_resolve_prompts.py", - # add new unit tests here - ], + # Auto-discovered from test/unit/test_*.py + "unit": _discover_unit_tests(), "1-gpu": [ "test_server_a.py", "test_server_b.py", diff --git a/python/sglang/multimodal_gen/test/server/test_server_common.py b/python/sglang/multimodal_gen/test/server/test_server_common.py index dbc21b0d02cd..61bf4f5730c2 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_common.py +++ b/python/sglang/multimodal_gen/test/server/test_server_common.py @@ -102,6 +102,10 @@ def diffusion_server(case: DiffusionTestCase) -> ServerContext: if server_args.enable_warmup: extra_args += " --warmup" + # Strict ports: fail immediately if port is occupied instead of silently + # picking another one (which causes the test client to connect to the wrong server). + extra_args += " --strict-ports" + for arg in server_args.extras: extra_args += f" {arg}" diff --git a/python/sglang/multimodal_gen/test/server/test_server_utils.py b/python/sglang/multimodal_gen/test/server/test_server_utils.py index ec8340327ed7..7525a6a38acf 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_utils.py +++ b/python/sglang/multimodal_gen/test/server/test_server_utils.py @@ -375,8 +375,10 @@ def start(self) -> ServerContext: # Apply custom environment variables env.update(self.env_vars) - # TODO: unify with run_command - logger.info(f"Running command: {shlex.join(command)}") + cmd_str = shlex.join(command) + # Use print (not logger) so the command always appears in CI output + # regardless of log-level configuration. 
+ print(f"[server-test] Running command: {cmd_str}", flush=True) process = subprocess.Popen( command, @@ -412,11 +414,10 @@ def _log_pipe(pipe: Any, file: Any) -> None: log_thread.daemon = True log_thread.start() - logger.info( - "[server-test] Starting server pid=%s, model=%s, log=%s", - process.pid, - self.model, - stdout_path, + print( + f"[server-test] Starting server pid={process.pid}, " + f"model={self.model}, log={stdout_path}", + flush=True, ) self._wait_for_ready(process, stdout_path) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index f54c882cc2ec..d94f845bea50 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -398,6 +398,16 @@ def _init_kv_manager(self) -> CommonKVManager: ) return kv_manager + def release_memory_occupation(self): + self.queue.clear() + self.retracted_queue.clear() + if hasattr(self.kv_manager, "deregister_buffer_to_engine"): + self.kv_manager.deregister_buffer_to_engine() + + def resume_memory_occupation(self): + if hasattr(self.kv_manager, "register_buffer_to_engine"): + self.kv_manager.register_buffer_to_engine() + def add(self, req: Req, is_retracted: bool = False) -> None: """Add a request to the pending queue.""" if self._check_if_req_exceed_kv_capacity(req): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 64d97f5c6966..88df241064e1 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -567,6 +567,19 @@ def send_kvcache_staged( ) return ret + def deregister_buffer_to_engine(self): + # Batch deregister KV data buffers + if self.kv_args.kv_data_ptrs: + self.engine.batch_deregister(self.kv_args.kv_data_ptrs) + + # Batch deregister auxiliary data buffers + if self.kv_args.aux_data_ptrs: + self.engine.batch_deregister(self.kv_args.aux_data_ptrs) + + # Batch deregister state/extra pool data buffers + if self.kv_args.state_data_ptrs: + self.engine.batch_deregister(self.kv_args.state_data_ptrs) + def _transfer_data(self, mooncake_session_id, transfer_blocks): if not transfer_blocks: return 0 diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 8eadf8195421..0ea94f9ae9c8 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -346,6 +346,15 @@ def pop_bootstrapped( else: return bootstrapped_reqs, failed_reqs + def release_memory_occupation(self): + self.queue.clear() + if hasattr(self.kv_manager, "deregister_buffer_to_engine"): + self.kv_manager.deregister_buffer_to_engine() + + def resume_memory_occupation(self): + if hasattr(self.kv_manager, "register_buffer_to_engine"): + self.kv_manager.register_buffer_to_engine() + class SchedulerDisaggregationPrefillMixin: """ diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index f0cba2189c0e..51434a9ca3e4 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -30,11 +30,11 @@ import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import timedelta from multiprocessing import shared_memory from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from unittest.mock import patch +from unittest.mock 
import MagicMock, patch import torch import torch.distributed @@ -2111,7 +2111,10 @@ def get_tensor_model_parallel_world_size(): def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" - return get_tp_group().rank_in_group + try: + return get_tp_group().rank_in_group + except Exception: + return 0 # ATTN_TP @@ -2348,3 +2351,179 @@ def monkey_patch_vllm_parallel_state(reverse: bool = False): setattr(vllm_parrlel_state, "get_pp_group", get_pp_group) setattr(vllm_parrlel_state, "get_tp_group", get_tp_group) setattr(vllm_parrlel_state, "get_world_group", get_world_group) + + +@dataclass +class RankParallelismConfig: + """ + Complete parallelism configuration for a single inference rank. + + This configuration captures all the parallelism settings needed to recreate + a model shard outside of sglang. It supports: + - TP/PP/EP for model parallelism + - MoE-TP/Attn-TP/Attn-DP for MoE and DP attention. + """ + + tp_size: int = 1 + tp_rank: int = 0 + pp_size: int = 1 + pp_rank: int = 0 + ep_size: int = 1 + ep_rank: int = 0 + moe_tp_size: int = 1 + moe_tp_rank: int = 0 + attn_tp_size: int = 1 + attn_tp_rank: int = 0 + attn_dp_size: int = 1 + attn_dp_rank: int = 0 + attn_cp_size: int = 1 + attn_cp_rank: int = 0 + moe_dp_size: int = 1 + moe_dp_rank: int = 0 + + world_size: int = 1 + global_rank: int = 0 + local_rank: int = 0 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RankParallelismConfig": + """Create from dictionary, filtering unknown fields.""" + import dataclasses + + valid_fields = {f.name for f in dataclasses.fields(cls)} + filtered_data = {k: v for k, v in data.items() if k in valid_fields} + return cls(**filtered_data) + + @classmethod + def from_parallel_state(cls, local_rank: int = 0) -> "RankParallelismConfig": + """Extract current parallelism settings from the global parallel state.""" + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + # Import dp_attention lazily to avoid circular imports + from sglang.srt.layers.dp_attention import ( + get_attention_cp_rank, + get_attention_cp_size, + get_attention_dp_rank, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, + ) + + return cls( + tp_size=tp_size, + tp_rank=tp_rank, + pp_size=get_pipeline_model_parallel_world_size(), + pp_rank=get_pipeline_model_parallel_rank(), + ep_size=get_moe_expert_parallel_world_size(), + ep_rank=get_moe_expert_parallel_rank(), + moe_tp_size=get_moe_tensor_parallel_world_size(), + moe_tp_rank=get_moe_tensor_parallel_rank(), + attn_tp_size=get_attention_tp_size(), + attn_tp_rank=get_attention_tp_rank(), + attn_dp_size=get_attention_dp_size(), + attn_dp_rank=get_attention_dp_rank(), + attn_cp_size=get_attention_cp_size(), + attn_cp_rank=get_attention_cp_rank(), + moe_dp_size=get_moe_data_parallel_world_size(), + moe_dp_rank=get_moe_data_parallel_rank(), + world_size=( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ), + global_rank=( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ), + local_rank=local_rank, + ) + + +# Globals on parallel_state module to save/restore +_PS_GLOBALS = ("_TP", "_PP", "_MOE_EP", "_MOE_TP", "_ATTN_TP", "_ATTN_CP", "_MOE_DP") +# Globals on dp_attention module to save/restore +_DA_GLOBALS = ("_ATTN_DP_RANK", "_ATTN_DP_SIZE", "_ENABLE_DP_ATTENTION_FLAG") + + +class 
ParallelismContext: + """ + Context manager for creating model replicas with specific parallelism settings. + + Temporarily sets global variables to allow creating model shards outside of a + real distributed environment. + Usage: + with ParallelismContext(RankParallelismConfig.from_dict(parallelism_info)): + model = get_model(...) + """ + + def __init__(self, parallelism_config: RankParallelismConfig): + self.config = parallelism_config + self._original_globals: Dict[str, Any] = {} + + def _create_mock_group(self, world_size: int, rank_in_group: int): + """Create a mock group coordinator with all necessary properties.""" + mock_group = MagicMock() + mock_group.world_size = world_size + mock_group.rank_in_group = rank_in_group + mock_group.rank = rank_in_group + mock_group.local_rank = rank_in_group + mock_group.ranks = list(range(world_size)) + mock_group.first_rank = 0 + mock_group.last_rank = world_size - 1 + mock_group.is_first_rank = rank_in_group == 0 + mock_group.is_last_rank = rank_in_group == world_size - 1 + mock_group.next_rank = mock_group.ranks[(rank_in_group + 1) % world_size] + mock_group.prev_rank = mock_group.ranks[(rank_in_group - 1) % world_size] + return mock_group + + def __enter__(self): + conf = self.config + + from sglang.srt.distributed import parallel_state + from sglang.srt.layers import dp_attention + + # Save original globals + for name in _PS_GLOBALS: + self._original_globals[name] = getattr(parallel_state, name, None) + for name in _DA_GLOBALS: + self._original_globals[name] = getattr(dp_attention, name, None) + + # Build and set mock group objects on parallel_state + _ps_new_values = { + "_TP": self._create_mock_group(conf.tp_size, conf.tp_rank), + "_PP": self._create_mock_group(conf.pp_size, conf.pp_rank), + "_MOE_EP": self._create_mock_group(conf.ep_size, conf.ep_rank), + "_MOE_TP": self._create_mock_group(conf.moe_tp_size, conf.moe_tp_rank), + "_ATTN_TP": self._create_mock_group(conf.attn_tp_size, conf.attn_tp_rank), + "_ATTN_CP": self._create_mock_group(conf.attn_cp_size, conf.attn_cp_rank), + "_MOE_DP": self._create_mock_group(conf.moe_dp_size, conf.moe_dp_rank), + } + for name, value in _ps_new_values.items(): + setattr(parallel_state, name, value) + + # Set dp_attention scalar globals + dp_attention._ATTN_DP_RANK = conf.attn_dp_rank + dp_attention._ATTN_DP_SIZE = conf.attn_dp_size + dp_attention._ENABLE_DP_ATTENTION_FLAG = conf.attn_dp_size > 1 + + logger.info(f"[ParallelismContext] Activated: {conf}") + return self + + def __exit__(self, *args): + from sglang.srt.distributed import parallel_state + from sglang.srt.layers import dp_attention + + # Restore original globals + for name in _PS_GLOBALS: + setattr(parallel_state, name, self._original_globals.get(name)) + for name in _DA_GLOBALS: + setattr(dp_attention, name, self._original_globals.get(name)) + + logger.info("[ParallelismContext] Deactivated") + return False diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index aca1adc26f93..971762f4369c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -69,6 +69,7 @@ LoadLoRAAdapterReqInput, MultimodalDataInputFormat, OpenSessionReqInput, + PostProcessWeightsReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, RpcReqInput, @@ -957,6 +958,33 @@ def update_weights_from_ipc( self.tokenizer_manager.update_weights_from_ipc(obj, None) ) + def post_process_weights( + self, + restore_weights_before_load: bool = False, + post_process_quantization: 
bool = False, + post_load_weights: bool = False, + ): + """ + Optional post-processing for updated weights (e.g., Marlin conversion). + Should be called after weight update is finished. + + Args: + restore_weights_before_load: Restore weights to pre-quantization state. + post_process_quantization: Re-apply quantization post-processing. + post_load_weights: Call model.post_load_weights() for models that + need post-load decomposition (e.g., DeepSeek MLA kv_b_proj + decomposition into w_kc/w_vc tensors after RDMA weight transfer). + """ + obj = PostProcessWeightsReqInput( + restore_weights_before_load=restore_weights_before_load, + post_process_quantization=post_process_quantization, + post_load_weights=post_load_weights, + ) + + return self.loop.run_until_complete( + self.tokenizer_manager.post_process_weights(obj, None) + ) + def get_weights_by_name(self, name: str, truncate_size: int = 100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py b/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py index 77de7fc7d030..88e48075a644 100644 --- a/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py +++ b/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py @@ -31,7 +31,8 @@ class EngineInfoBootstrapServer: accesses the collected info directly in-process; external consumers can query via HTTP GET. - Currently supports transfer engine memory registration info. + Currently supports transfer engine memory registration info and + per-rank parallelism configuration. """ def __init__(self, host: str, port: int): @@ -40,6 +41,8 @@ def __init__(self, host: str, port: int): # Storage: {tp_rank: (session_id, weights_info_dict)} self.transfer_engine_info: Dict[int, Tuple] = {} + # Storage: {tp_rank: parallelism_config_dict} + self.parallelism_config: Dict[int, dict] = {} self.lock = threading.Lock() app = FastAPI() @@ -89,6 +92,38 @@ def get_transfer_engine_info(rank: int): config = uvicorn.Config(app, host=host, port=port, log_level="warning") self._server = uvicorn.Server(config) + + @app.put("/register_parallelism_config") + def register_parallelism_config(data: dict): + try: + tp_rank = data["tp_rank"] + config = data["parallelism_config"] + + with self.lock: + self.parallelism_config[tp_rank] = config + + logger.info(f"Registered parallelism config for tp_rank={tp_rank}") + return PlainTextResponse("OK") + except Exception as e: + logger.error(f"Failed to register parallelism config: {e}") + raise HTTPException(status_code=400, detail=str(e)) + + @app.get("/get_parallelism_config") + def get_parallelism_config(rank: int): + if rank < 0: + raise HTTPException(status_code=400, detail="Invalid rank parameter") + + with self.lock: + config = self.parallelism_config.get(rank) + + if config is None: + raise HTTPException( + status_code=404, + detail=f"No parallelism config for rank {rank}", + ) + + return config + self._thread = threading.Thread( target=self._server.run, daemon=True, @@ -103,3 +138,7 @@ def close(self): def get_transfer_engine_info(self, rank: int) -> Optional[Tuple]: """Direct in-process access for co-located HTTP server (no HTTP round-trip).""" return self.transfer_engine_info.get(rank) + + def get_parallelism_config_info(self, rank: int) -> Optional[dict]: + """Direct in-process access for parallelism config (no HTTP round-trip).""" + return self.parallelism_config.get(rank) diff --git a/python/sglang/srt/entrypoints/http_server.py 
b/python/sglang/srt/entrypoints/http_server.py index 6978e0c062e8..288f2a24c305 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -127,6 +127,7 @@ OpenSessionReqInput, ParseFunctionCallReq, PauseGenerationReqInput, + PostProcessWeightsReqInput, ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -1032,6 +1033,32 @@ async def remote_instance_transfer_engine_info(rank: int = None): ) +@app.get("/parallelism_config") +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def parallelism_config(rank: int = None): + """Get per-rank parallelism config from the bootstrap server.""" + if rank is None or rank < 0: + return ORJSONResponse( + {"error": {"message": "Missing or invalid rank parameter"}}, + status_code=HTTPStatus.BAD_REQUEST, + ) + + server_args = _global_state.tokenizer_manager.server_args + try: + + resp = requests.get( + f"{server_args.engine_info_bootstrap_url}/get_parallelism_config", + params={"rank": rank}, + timeout=5, + ) + if resp.status_code == 200: + return resp.json() + except Exception: + pass + + return Response(status_code=HTTPStatus.BAD_REQUEST) + + @app.post("/init_weights_update_group") @auth_level(AuthLevel.ADMIN_OPTIONAL) async def init_weights_update_group( @@ -1121,6 +1148,22 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) +@app.post("/post_process_weights") +async def post_process_weights(req: PostProcessWeightsReqInput, request: Request): + """ + Optional post-processing for updated weights (e.g., Marlin conversion). + This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`. + """ + success, message = await _global_state.tokenizer_manager.post_process_weights( + req, request + ) + + content = {"success": success, "message": message} + return ORJSONResponse( + content, status_code=200 if success else HTTPStatus.BAD_REQUEST + ) + + @app.post("/update_weight_version") @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index c40bd37d9f58..3fd5df68353a 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -589,6 +589,9 @@ class ChatCompletionRequest(BaseModel): return_hidden_states: bool = False return_routed_experts: bool = False return_cached_tokens_details: bool = False + return_prompt_token_ids: bool = False + return_completion_token_ids: bool = False + return_meta_info: bool = False reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = Field( default=None, description="Constrains effort on reasoning for reasoning models. " @@ -625,6 +628,11 @@ class ChatCompletionRequest(BaseModel): custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None custom_params: Optional[Dict] = None + # Pre-computed prompt token IDs: when provided, bypasses chat template + # tokenization entirely. Messages are still used to derive stop tokens + # and tool_call_constraint. + input_ids: Optional[List[int]] = None + # For request id rid: Optional[Union[List[str], str]] = None # Extra key for classifying the request (e.g. 
cache_salt) @@ -845,12 +853,21 @@ class ChatCompletionResponseChoice(BaseModel): ] = None matched_stop: Union[None, int, str] = None hidden_states: Optional[object] = None + prompt_token_ids: Optional[List[int]] = None + completion_token_ids: Optional[List[int]] = None + meta_info: Optional[Dict[str, Any]] = None @model_serializer(mode="wrap") def _serialize(self, handler): data = handler(self) if self.hidden_states is None: data.pop("hidden_states", None) + if self.prompt_token_ids is None: + data.pop("prompt_token_ids", None) + if self.completion_token_ids is None: + data.pop("completion_token_ids", None) + if self.meta_info is None: + data.pop("meta_info", None) return data diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index d0d05e3a483e..df682c9dfcd5 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -258,11 +258,22 @@ def _convert_to_internal_request( request.reasoning_effort = reasoning_effort """Convert OpenAI chat completion request to internal format""" + if request.return_prompt_token_ids and request.stream: + raise ValueError( + "return_prompt_token_ids is not supported with streaming. " + "Please set stream=false when using return_prompt_token_ids=true." + ) + + if request.return_completion_token_ids and request.stream: + raise ValueError( + "return_completion_token_ids is not supported with streaming. " + "Please set stream=false when using return_completion_token_ids=true." + ) + is_multimodal = self.tokenizer_manager.model_config.is_multimodal # Process messages and apply chat template processed_messages = self._process_messages(request, is_multimodal) - # Build sampling parameters sampling_params = request.to_sampling_params( stop=processed_messages.stop, @@ -322,6 +333,7 @@ def _convert_to_internal_request( image_max_dynamic_patch=img_max_dynamic_patch, video_max_dynamic_patch=vid_max_dynamic_patch, max_dynamic_patch=getattr(request, "max_dynamic_patch", None), + return_prompt_token_ids=request.return_prompt_token_ids, ) return adapted_request, request @@ -367,8 +379,19 @@ def _process_messages( ) tool_call_constraint = ("json_schema", json_schema) - # Use chat template - if self.template_manager.chat_template_name is None: + # When input_ids are provided, skip template tokenization entirely; + # only stop tokens and tool_call_constraint are needed. + if request.input_ids is not None: + result = MessageProcessingResult( + prompt=self.tokenizer_manager.tokenizer.decode(request.input_ids), + prompt_ids=request.input_ids, + image_data=None, + audio_data=None, + video_data=None, + modalities=[], + stop=request.stop or [], + ) + elif self.template_manager.chat_template_name is None: result = self._apply_jinja_template(request, tools, is_multimodal) else: result = self._apply_conversation_template(request, is_multimodal) @@ -1013,11 +1036,29 @@ def _build_chat_response( history_tool_calls_cnt, ) + # Extract prompt_token_ids if requested + choice_prompt_token_ids = ( + ret_item.get("prompt_token_ids") + if request.return_prompt_token_ids + else None + ) + + # Extract completion_token_ids if requested + choice_completion_token_ids = ( + ret_item.get("output_ids") + if request.return_completion_token_ids + else None + ) + + choice_meta_info = ( + ret_item["meta_info"] if request.return_meta_info else None + ) + # NOTE: content should not be None but empty string to make sure retokenize consistency. 
choice_data = ChatCompletionResponseChoice( index=idx, message=ChatMessage( role="assistant", - content=text if text else None, + content=text if text else "", tool_calls=tool_calls, reasoning_content=reasoning_text if reasoning_text else None, ), @@ -1029,6 +1070,9 @@ def _build_chat_response( else None ), hidden_states=hidden_states, + prompt_token_ids=choice_prompt_token_ids, + completion_token_ids=choice_completion_token_ids, + meta_info=choice_meta_info, ) choices.append(choice_data) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 205b5fadf8ce..722c598b6926 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -871,7 +871,7 @@ def forward_extend( max_kv_len=self.max_context_len, bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, - batch_size=forward_batch.batch_size, + batch_size=self.forward_metadata.cu_seqlens_q.shape[0] - 1, cum_seq_lens_q=self.forward_metadata.cu_seqlens_q, cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k, window_left=layer.sliding_window_size, diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index e4960bdb42d6..d59ba6ebaf26 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -117,7 +117,7 @@ def _forward_with_allreduce_fusion( if world_size > 1: if post_residual_addition is not None: - residual = residual + post_residual_addition + x = x + post_residual_addition # Prefer AITER fused AR+RMSNorm when enabled on AMD. if _use_aiter: @@ -152,20 +152,17 @@ def __init__( eps: float = 1e-6, var_hidden_size: Optional[int] = None, cast_x_before_out_mul: bool = False, - fp32_residual: bool = False, + fp32_residual: bool = True, has_weight: bool = True, - weight_dtype: Optional = None, - override_orig_dtype: Optional = None, ) -> None: super().__init__() self.has_weight = has_weight self.cast_x_before_out_mul = cast_x_before_out_mul self.fp32_residual = fp32_residual - self.override_orig_dtype = override_orig_dtype if self.has_weight: - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) + self.weight = nn.Parameter(torch.ones(hidden_size)) else: - self.weight = torch.ones(hidden_size, dtype=weight_dtype) + self.weight = torch.ones(hidden_size) self.variance_epsilon = eps self.hidden_size = hidden_size self.variance_size_override = ( @@ -229,10 +226,10 @@ def forward_aiter( post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: - residual_out = torch.empty_like(x) - output = torch.empty_like(x) if post_residual_addition is not None: residual = residual + post_residual_addition + residual_out = torch.empty_like(x) + output = torch.empty_like(x) fused_add_rms_norm( output, x, @@ -258,10 +255,10 @@ def forward_hip( # NOTE: Remove this if aiter kernel supports discontinuous input x = x.contiguous() if residual is not None: - out = torch.empty_like(x) - residual_out = torch.empty_like(x) if post_residual_addition is not None: residual = residual + post_residual_addition + out = torch.empty_like(x) + residual_out = torch.empty_like(x) fused_add_rms_norm( out, x, residual_out, residual, self.weight.data, self.variance_epsilon ) @@ -278,16 +275,19 @@ def forward_native( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if not x.is_contiguous(): x = x.contiguous() - orig_dtype = self.override_orig_dtype or x.dtype + 
orig_dtype = x.dtype + + if residual is not None and not self.fp32_residual: + x = x + residual + if post_residual_addition is not None: + x = x + post_residual_addition + residual = x.clone() x = x.to(torch.float32) - if residual is not None: + if residual is not None and self.fp32_residual: x = x + residual.to(torch.float32) if post_residual_addition is not None: x = x + post_residual_addition.to(torch.float32) - if self.fp32_residual: - residual = x.clone() - else: - residual = x.to(orig_dtype) + residual = x.to(orig_dtype) hidden_size = x.shape[-1] if hidden_size != self.hidden_size: @@ -477,7 +477,7 @@ def forward_native( orig_dtype = x.dtype if residual is not None: if post_residual_addition is not None: - residual = residual + post_residual_addition + x = x + post_residual_addition x = x + residual residual = x @@ -566,7 +566,7 @@ def forward_npu( return self.forward_native(x, residual) if residual is not None: if post_residual_addition is not None: - residual = residual + post_residual_addition + x = x + post_residual_addition x = x + residual residual = x diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 662cc25191cb..a323cf7c81c7 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -885,11 +885,6 @@ def _compute_lm_head( None, # bias True, # is_vnni ) - elif get_global_server_args().rl_on_policy_target is not None: - # Due to tie-weight, we may not be able to change lm_head's weight dtype - logits = torch.matmul( - hidden_states.bfloat16(), lm_head.weight.T.bfloat16() - ) else: logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index f2621473cb49..a8494d06961b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -14,6 +14,7 @@ import triton.language as tl from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -644,7 +645,10 @@ def fused_experts_impl( ).squeeze(dim=1) else: # According to micro benchmark results, torch.compile can get better performance for small token. 
- if tokens_in_chunk <= 32: + if ( + not get_global_server_args().enable_deterministic_inference + and tokens_in_chunk <= 32 + ): moe_sum_reduce_torch_compile( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 72483f4ea6f5..a9e26d9a6e7c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -702,6 +702,7 @@ def _weight_loader_impl( "CompressedTensorsWNA16TritonMoE", ] ) + and "zero" not in weight_name else loaded_weight ) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 4494f195f4dd..e568579bb7e8 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -518,8 +518,15 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( # Move kernel call outside context manager to avoid graph breaks # during torch.compile for piecewise cuda graph. # Use custom op wrapper for torch.compile compatibility. + + # The DeepSeekV3 routing method requires float32 router logits. + if routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + else: + router_logits = router_logits.to(torch.bfloat16) + output = trtllm_fp8_per_tensor_scale_moe_wrapper( - routing_logits=router_logits.to(torch.bfloat16), + routing_logits=router_logits, routing_bias=routing_bias_cast, hidden_states=a_q, gemm1_weights=quant_info.w13_weight, diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py index 00bd68755587..12d5577af2ba 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -8,10 +8,15 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather_into_tensor, get_attention_dp_rank, + get_attention_tp_size, get_dp_local_info, is_dp_attention_enabled, ) +from sglang.srt.layers.moe import ( + get_moe_a2a_backend, +) from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import get_global_server_args @@ -181,13 +186,26 @@ def __init__( device=device, ) + if get_moe_a2a_backend().is_deepep(): + attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 + self.gather_buffer = torch.empty( + ( + self.device_cache.buffer.shape[0] * attn_tp_size, + self.device_cache.buffer.shape[2], + ), + dtype=torch.int32, + device=device, + ) + def _sync_fwd_experts_buffer_DtoH( self, forward_batch: ForwardBatch, can_run_graph: bool, cuda_graph_batch: int, ): - if is_dp_attention_enabled(): + # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer + # contains data from all DP ranks. We should not slice by DP rank in this case. 
+ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) # handle with cuda graph padding if can_run_graph: @@ -206,6 +224,12 @@ def _sync_fwd_experts_buffer_DtoH( ].cpu() def capture(self, layer_id: int, topk_ids: torch.Tensor): + if get_moe_a2a_backend().is_deepep(): + local_topk_ids = topk_ids + topk_ids = self.gather_buffer[ + : local_topk_ids.size(0) * get_attention_tp_size() + ] + attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) def get_routed_experts( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index a13c53af4d2d..c68623f7bcf3 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -500,7 +500,7 @@ def _is_wNa16_group_channel( ) is_static = not weight_quant.dynamic - return is_channel_group and input_quant_none and is_symmetric and is_static + return is_channel_group and input_quant_none and is_static def _is_mxint4a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None @@ -969,6 +969,10 @@ def __init__(self, quantization_config: CompressedTensorsConfig): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) + def restore_weights_before_loading(self, layer: torch.nn.Module) -> None: + if hasattr(layer.scheme, "restore_weights_before_loading"): + layer.scheme.restore_weights_before_loading(layer) + def create_weights( self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py index 7a8fb6542189..8e52f3d832b5 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -17,7 +17,10 @@ CompressedTensorsMoEScheme, ) from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack -from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales +from sglang.srt.layers.quantization.marlin_utils import ( + marlin_moe_permute_scales, + moe_awq_to_marlin_zero_points, +) from sglang.srt.layers.quantization.utils import replace_parameter from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs @@ -64,7 +67,7 @@ def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1): self.strategy = config.strategy self.group_size = config.group_size self.actorder = config.actorder - assert config.symmetric, "Only symmetric quantization is supported for MoE" + self.sym = config.symmetric if not ( self.quant_config.quant_format == CompressionFormat.pack_quantized.value @@ -124,7 +127,7 @@ def create_weights( # In the case where we have actorder/g_idx, # we do not partition the w2 scales - load_full_w2 = self.actorder and self.group_size != -1 + load_full_w2 = (self.actorder != "static") and self.group_size != -1 if load_full_w2: w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size @@ -172,6 +175,32 @@ def create_weights( layer.register_parameter("w13_weight_shape", 
w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) + # add zero param + if not self.sym: + w13_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_zero_point", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_zero_point", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, @@ -225,14 +254,16 @@ def create_weights( # Force record: these are the target GPTQ shapes for rollback. layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) - layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) + layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) + if not self.sym: + layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape - # Also record the shapes of the scales. + layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) - layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) + if not self.sym: + layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Skip if the layer is already converted to Marlin format to prevent double-packing. if getattr(layer, "is_marlin_converted", False): return @@ -334,11 +365,28 @@ def replace_tensor(name, new_t): ) replace_tensor("w2_weight_scale", marlin_w2_scales) + # Repack zero + if not self.sym: + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_weight_zero_point, + size_k=layer.w13_weight_zero_point.shape[1], + size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor, + num_bits=self.num_bits, + ) + replace_tensor("w13_weight_zero_point", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_weight_zero_point, + size_k=layer.w2_weight_zero_point.shape[1], + size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor, + num_bits=self.num_bits, + ) + replace_tensor("w2_weight_zero_point", marlin_w2_zp) + layer.is_marlin_converted = True def restore_weights_before_loading(self, layer: torch.nn.Module): """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" - if not hasattr(layer, "_original_shapes"): return @@ -399,6 +447,8 @@ def apply_weights( g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, + w1_zeros=layer.w13_weight_zero_point if not self.sym else None, + w2_zeros=layer.w2_weight_zero_point if not self.sym else None, num_bits=self.num_bits, is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 99a3f11ca05f..9263249a8531 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -101,9 +101,6 @@ def __init__( if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native - self._apply_rotary_emb_wrapped = 
torch.compile(dynamic=True)( - apply_rotary_emb - ) self.position_cos, self.position_sin = None, None def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index e947a48cbde8..d1bfaee6d027 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -120,17 +120,13 @@ def forward( if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: original_logprobs = torch.log_softmax(logits, dim=-1) + # Post process logits + logits.div_(sampling_info.temperatures) + # In RL on-policy mode, we use log_softmax to compute logprobs to match the trainer. logprobs_via_logsoftmax_kernel = None if self.rl_on_policy_target is not None: - # TODO: use more inplace ops to save memory - logits_div_temperature = ( - logits.bfloat16().div(sampling_info.temperatures).bfloat16() - ) - logprobs_via_logsoftmax_kernel = torch.log_softmax( - logits_div_temperature, dim=-1 - ) - del logits_div_temperature + logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) if self.use_ascend_backend: # Ascend backend: sample from logits directly. @@ -152,8 +148,6 @@ def forward( logprobs = logprobs_via_logsoftmax_kernel else: # Standard path: do softmax and sample from probs. - logits.div_(sampling_info.temperatures) - # In-place op to save memory logits[:] = torch.softmax(logits, dim=-1) probs = logits diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index bd979653458a..7958556d659d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -231,6 +231,9 @@ class GenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False + # Whether to return prompt token IDs without computing logprobs + return_prompt_token_ids: bool = False + # Propagates trace context via Engine.generate/async_generate external_trace_header: Optional[Dict] = None received_time: Optional[float] = None @@ -1344,6 +1347,23 @@ class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): message: str +@dataclass +class PostProcessWeightsReqInput(BaseReq): + # Whether to restore weights before loading new weights + restore_weights_before_load: bool = False + # Whether to enable quantization post-processing + post_process_quantization: bool = False + # Whether to call model.post_load_weights() after weight update + # (e.g., DeepSeek MLA kv_b_proj decomposition into w_kc/w_vc tensors) + post_load_weights: bool = False + + +@dataclass +class PostProcessWeightsReqOutput(BaseReq): + success: bool + message: str + + @dataclass class SendWeightsToRemoteInstanceReqInput(BaseReq): # The master address diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0b26be6c6d18..09254f314b0c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1972,8 +1972,10 @@ def retract_decode( while first_iter or ( not self.check_decode_mem(selected_indices=sorted_indices) ): - if len(sorted_indices) == 1: - # Always keep at least one request + # We should allow all requests to be retracted in decode disaggregation mode + # because there can be prealloc prefill requests. 
+ num_minimum_reqs = 0 if server_args.disaggregation_mode == "decode" else 1 + if len(sorted_indices) == num_minimum_reqs: break first_iter = False diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 67af2d0de943..122ddb387486 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -120,6 +120,7 @@ LoadLoRAAdapterReqOutput, OpenSessionReqInput, PauseGenerationReqInput, + PostProcessWeightsReqInput, ProfileReq, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -1232,6 +1233,7 @@ def init_request_dispatcher(self): ), (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), + (PostProcessWeightsReqInput, self.post_process_weights), (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 496cd96656e5..64b57ccb2d0c 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1155,6 +1155,8 @@ def stream_output_generation( # Send to detokenizer if reqs or is_idle_batch: + if getattr(self.model_config, "is_multimodal_gen", False): + return self.send_to_detokenizer.send_output( BatchTokenIDOutput( rids=rids, diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index abcda6794674..0976b7ddab8d 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -12,6 +12,7 @@ GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, ) +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.managers.io_struct import ( CheckWeightsReqInput, CheckWeightsReqOutput, @@ -21,6 +22,8 @@ GetWeightsByNameReqOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, + PostProcessWeightsReqInput, + PostProcessWeightsReqOutput, ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, @@ -117,6 +120,11 @@ def update_weights_from_ipc( torch.distributed.barrier(group=self.tp_cpu_group) return UpdateWeightsFromIPCReqOutput(success, message) + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): + """Optional post-processing for updated weights (e.g., Marlin conversion).""" + success, message = self.tp_worker.post_process_weights(recv_req) + return PostProcessWeightsReqOutput(success, message) + def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return GetWeightsByNameReqOutput(parameter) @@ -140,6 +148,13 @@ def release_memory_occupation( self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) self.flush_cache() + if self.disaggregation_mode == DisaggregationMode.DECODE: + if hasattr(self, "disagg_decode_prealloc_queue"): + self.disagg_decode_prealloc_queue.release_memory_occupation() + elif self.disaggregation_mode == DisaggregationMode.PREFILL: + if hasattr(self, "disagg_prefill_bootstrap_queue"): + self.disagg_prefill_bootstrap_queue.release_memory_occupation() + if GPU_MEMORY_TYPE_WEIGHTS in tags: self.stashed_model_static_state = 
_export_static_state( self.tp_worker.model_runner.model @@ -180,6 +195,13 @@ def resume_memory_occupation( if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + if self.disaggregation_mode == DisaggregationMode.DECODE: + if hasattr(self, "disagg_decode_prealloc_queue"): + self.disagg_decode_prealloc_queue.resume_memory_occupation() + elif self.disaggregation_mode == DisaggregationMode.PREFILL: + if hasattr(self, "disagg_prefill_bootstrap_queue"): + self.disagg_prefill_bootstrap_queue.resume_memory_occupation() + return ResumeMemoryOccupationReqOutput() def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 544c6094014c..50a1c4045adb 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -59,6 +59,8 @@ LoadLoRAAdapterReqOutput, LoRAUpdateOutput, OpenSessionReqInput, + PostProcessWeightsReqInput, + PostProcessWeightsReqOutput, ProfileReq, ProfileReqOutput, ProfileReqType, @@ -187,6 +189,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.update_weights_from_ipc_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.post_process_weights_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -272,6 +277,10 @@ def _get_communicator_dispatcher(self: TokenizerManager): UpdateWeightsFromIPCReqOutput, self.update_weights_from_ipc_communicator.handle_recv, ), + ( + PostProcessWeightsReqOutput, + self.post_process_weights_communicator.handle_recv, + ), ( GetWeightsByNameReqOutput, self.get_weights_by_name_communicator.handle_recv, @@ -616,6 +625,17 @@ async def update_weights_from_ipc( return success, message + async def post_process_weights( + self: TokenizerManager, + obj: PostProcessWeightsReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).""" + self.auto_create_handle_loop() + async with self.model_update_lock.writer_lock: + results = await self.post_process_weights_communicator(obj) + return _Communicator.merge_results(results) + async def _unload_lora_adapter_locked( self: TokenizerManager, obj: UnloadLoRAAdapterReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 81424329a0c1..9a3499b556ac 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -166,6 +166,9 @@ class ReqState: input_token_ids_logprobs: List[Any] = dataclasses.field(default_factory=list) output_token_ids_logprobs: List[Any] = dataclasses.field(default_factory=list) + # For return_prompt_token_ids: stores prompt token IDs captured after tokenization + prompt_token_ids: Optional[List[int]] = None + class InputFormat(Enum): """Input format types for tokenization handling.""" @@ -518,6 +521,8 @@ async def generate_request( tokenized_obj = await self._tokenize_one_request(obj) state = self.rid_to_state[obj.rid] self._send_one_request(tokenized_obj) + if getattr(obj, "return_prompt_token_ids", False): + state.prompt_token_ids = list(tokenized_obj.input_ids) async for response in self._wait_one_response(obj, state, request): 
yield response else: @@ -1280,6 +1285,8 @@ async def _handle_batch_request( tmp_obj = obj[i] state = self.rid_to_state[tmp_obj.rid] state.obj = tmp_obj + if getattr(tmp_obj, "return_prompt_token_ids", False): + state.prompt_token_ids = list(tokenized_objs[i].input_ids) generators.append(self._wait_one_response(tmp_obj, state, request)) rids.append(tmp_obj.rid) else: @@ -1295,6 +1302,8 @@ async def _handle_batch_request( state = self.rid_to_state[tmp_obj.rid] state.obj = tmp_obj self._send_one_request(tokenized_obj) + if getattr(tmp_obj, "return_prompt_token_ids", False): + state.prompt_token_ids = list(tokenized_obj.input_ids) generators.append( self._wait_one_response(tmp_obj, state, request) ) @@ -1338,6 +1347,8 @@ async def _handle_batch_request( state = self.rid_to_state[tmp_obj.rid] tokenized_obj.time_stats = state.time_stats self._send_one_request(tokenized_obj) + if getattr(tmp_obj, "return_prompt_token_ids", False): + state.prompt_token_ids = list(tokenized_objs[i].input_ids) generators.append(self._wait_one_response(tmp_obj, state, request)) rids.append(tmp_obj.rid) @@ -1622,6 +1633,8 @@ def _handle_batch_output( "output_ids": output_token_ids, "meta_info": meta_info, } + if state.prompt_token_ids is not None: + out_dict["prompt_token_ids"] = state.prompt_token_ids elif isinstance(recv_obj, BatchTokenIDOutput): is_stream = getattr(state.obj, "stream", False) @@ -1637,6 +1650,8 @@ def _handle_batch_output( "output_ids": output_token_ids, "meta_info": meta_info, } + if state.prompt_token_ids is not None: + out_dict["prompt_token_ids"] = state.prompt_token_ids else: assert isinstance(recv_obj, BatchEmbeddingOutput) out_dict = { diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 7f63610da8ee..fb56de158374 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ InitWeightsUpdateGroupReqInput, LoadLoRAAdapterFromTensorsReqInput, LoadLoRAAdapterReqInput, + PostProcessWeightsReqInput, SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, @@ -170,6 +171,11 @@ def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput): success, message = self.model_runner.update_weights_from_ipc(recv_req) return success, message + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): + """Perform optional post-processing on the updated model weights (e.g., Marlin conversion).""" + success, message = self.model_runner.post_process_weights(recv_req) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a59742b94354..22f0f3e8a91c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -72,7 +72,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) -from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.distributed.parallel_state import ( + RankParallelismConfig, + monkey_patch_vllm_parallel_state, +) from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.elastic_ep.expert_backup_client import ExpertBackupClient from sglang.srt.environ import envs @@ -351,6 +354,7 @@ def __init__( 
self.remote_instance_transfer_engine = None self.remote_instance_transfer_engine_session_id = "" self.remote_instance_transfer_engine_weight_info = None + self.parallelism_config = None # auxiliary hidden capture mode. TODO: expose this to server args? self.eagle_use_aux_hidden_state = False if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: @@ -467,6 +471,9 @@ def initialize(self, pre_model_load_memory: float): if self.server_args.remote_instance_weight_loader_use_transfer_engine(): self.remote_instance_init_transfer_engine() + self.parallelism_config = RankParallelismConfig.from_parallel_state( + self.tp_rank + ) if not self.is_draft_worker: set_global_expert_location_metadata( @@ -527,6 +534,13 @@ def initialize(self, pre_model_load_memory: float): ) self._register_to_engine_info_bootstrap() + # Register parallelism config with the bootstrap server + if ( + self.server_args.remote_instance_weight_loader_use_transfer_engine() + and self.parallelism_config is not None + ): + self._register_parallelism_config_to_bootstrap() + # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to # determine the number of layers. @@ -646,7 +660,8 @@ def initialize(self, pre_model_load_memory: float): ) # Init routed experts capturer - self.init_routed_experts_capturer() + if not self.is_draft_worker: + self.init_routed_experts_capturer() if self.device == "cuda" or self.device == "musa": self.init_cublas() @@ -701,6 +716,9 @@ def remote_instance_init_transfer_engine(self): "Please install mooncake for using remote instance transfer engine: pip install mooncake" ) return + + from sglang.srt.utils import get_local_ip_auto + self.remote_instance_transfer_engine = TransferEngine() local_ip = get_local_ip_auto() self.remote_instance_transfer_engine.initialize( @@ -824,6 +842,92 @@ def _publish_modelexpress_metadata(self): finally: mx_client.close() + def _register_to_engine_info_bootstrap(self): + """Register transfer engine info with the EngineInfoBootstrapServer via HTTP PUT. + + The bootstrap server runs on node_rank==0. For multi-node setups, the + host is derived from dist_init_addr. For single-node, use 127.0.0.1. 
+ """ + import requests as http_requests + + bootstrap_url = self._get_bootstrap_url() + url = f"{bootstrap_url}/register_transfer_engine_info" + + payload = { + "tp_rank": self.tp_rank, + "transfer_engine_info": { + "session_id": self.remote_instance_transfer_engine_session_id, + "weights_info_dict": self.remote_instance_transfer_engine_weight_info, + }, + } + + try: + resp = http_requests.put(url, json=payload, timeout=5) + if resp.status_code == 200: + logger.info( + f"Registered transfer engine info for tp_rank={self.tp_rank} " + f"with bootstrap server at {bootstrap_url}" + ) + else: + logger.error( + f"Failed to register transfer engine info for tp_rank={self.tp_rank}: " + f"{resp.status_code}, {resp.text}" + ) + except Exception as e: + logger.error( + f"Failed to register transfer engine info for tp_rank={self.tp_rank}: {e}" + ) + + def _register_parallelism_config_to_bootstrap(self): + """Register parallelism config with the EngineInfoBootstrapServer via HTTP PUT.""" + import requests as http_requests + + bootstrap_url = self._get_bootstrap_url() + url = f"{bootstrap_url}/register_parallelism_config" + + payload = { + "tp_rank": self.tp_rank, + "parallelism_config": self.parallelism_config.to_dict(), + } + + try: + resp = http_requests.put(url, json=payload, timeout=5) + if resp.status_code == 200: + logger.info( + f"Registered parallelism config for tp_rank={self.tp_rank} " + f"with bootstrap server at {bootstrap_url}" + ) + else: + logger.error( + f"Failed to register parallelism config for tp_rank={self.tp_rank}: " + f"{resp.status_code}, {resp.text}" + ) + except Exception as e: + logger.error( + f"Failed to register parallelism config for tp_rank={self.tp_rank}: {e}" + ) + + def _get_bootstrap_url(self): + """Get the base URL for the EngineInfoBootstrapServer.""" + if self.server_args.dist_init_addr: + # Multi-node: bootstrap server is on the head node (node_rank==0). + # Derive host from dist_init_addr (shared across all nodes). + import socket + + host_part = self.server_args.dist_init_addr.rsplit(":", 1)[0] + try: + # Resolve hostname to IP if needed + bootstrap_host = socket.getaddrinfo( + host_part, None, socket.AF_UNSPEC, 0, 0, socket.AI_ADDRCONFIG + )[0][4][0] + except socket.gaierror: + bootstrap_host = host_part + else: + bootstrap_host = "127.0.0.1" + + bootstrap_port = self.server_args.engine_info_bootstrap_port + return f"http://{bootstrap_host}:{bootstrap_port}" + def model_specific_adjustment(self): server_args = self.server_args @@ -2767,11 +2871,19 @@ def forward( output.expert_distribution_metrics = recorder_outputs.get("metrics") # Copy cached routing experts' buffers back to CPU cache - get_global_experts_capturer().on_forward_end( - forward_batch=forward_batch, - can_run_graph=output.can_run_graph, - cuda_graph_batch=getattr(self.graph_runner, "bs", None), - ) + if not self.is_draft_worker: + # In speculative decoding, num_tokens_per_bs > 1, so we need to pass + # the actual number of tokens per dp rank in cuda graph, not batch size. 
+ cuda_graph_num_tokens = None + if getattr(self.graph_runner, "bs", None): + cuda_graph_num_tokens = ( + self.graph_runner.bs * self.graph_runner.num_tokens_per_bs + ) + get_global_experts_capturer().on_forward_end( + forward_batch=forward_batch, + can_run_graph=output.can_run_graph, + cuda_graph_batch=cuda_graph_num_tokens, + ) if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() @@ -3021,6 +3133,50 @@ def prealloc_symmetric_memory_pool(self): device=self.device, ) + def post_process_weights(self, recv_req): + """ + Execute post-processing logic for model weights, such as Marlin quantization format conversion + and model-specific post_load_weights hooks (e.g., DeepSeek MLA kv_b_proj decomposition). + """ + from sglang.srt.model_loader.loader import device_loading_context + + target_device = torch.device("cuda", torch.cuda.current_device()) + + if recv_req.post_load_weights: + # Call model.post_load_weights() if available (e.g., for DeepSeek MLA + # models that need to decompose kv_b_proj.weight into w_kc/w_vc tensors + # after RDMA weight transfer) + if hasattr(self.model, "post_load_weights"): + self.model.post_load_weights() + + if recv_req.restore_weights_before_load: + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + + # Check if the module supports restoring weights + if quant_method is not None and hasattr( + quant_method, "restore_weights_before_loading" + ): + + with device_loading_context(module, target_device): + quant_method.restore_weights_before_loading(module) + + if recv_req.post_process_quantization: + # Iterate through all modules to apply specific post-loading processing + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + + # Check if the module supports quantization post-processing + if quant_method is not None and hasattr( + quant_method, "process_weights_after_loading" + ): + + # Apply the post-processing (e.g., repacking weights for Marlin kernel) + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + return True, "Success" + def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/model_loader/parameter_mapper.py b/python/sglang/srt/model_loader/parameter_mapper.py new file mode 100644 index 000000000000..56056d570843 --- /dev/null +++ b/python/sglang/srt/model_loader/parameter_mapper.py @@ -0,0 +1,261 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Parameter mapping from HuggingFace checkpoint names to SGLang model parameters. + +This module provides utilities for translating weight names between HuggingFace +checkpoint format and SGLang's internal parameter naming, handling: + +1. 
Stacked Parameter Fusion + - gate_proj + up_proj → gate_up_proj (num_shards=2) + - q_proj + k_proj + v_proj → qkv_proj (num_shards=3) + - q_a_proj + kv_a_proj_with_mqa → fused_qkv_a_proj_with_mqa (DeepSeek MLA) + +2. Expert Parameter Sharding (MoE models) + - experts.{id}.gate_proj + experts.{id}.up_proj → experts.w13_weight (num_shards=2) + - experts.{id}.down_proj → experts.w2_weight (num_shards=1) + - Handles expert parallelism: num_local_experts = n_routed // ep_size + shared + +3. Scale Remapping (Quantized models) + - k_proj.k_scale → attn.k_scale + - v_proj.v_scale → attn.v_scale + - Quark-specific: output_scale → per-component scales + +Supported Models: + Dense: Llama, Qwen2, Qwen3, GLM4 + MoE: DeepSeekV2/V3/R1, Qwen3-MoE, GLM4-MoE, GLM4-MoE-Lite (GLM-4.7) + +Example: + >>> mapper = ParameterMapper.from_model(model) + >>> result = mapper.map("model.layers.0.mlp.gate_proj.weight") + >>> result.sglang_name # "model.layers.0.mlp.gate_up_proj.weight" + >>> result.shard_id # 0 + >>> result.num_shards # 2 +""" + +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +StackedParamsEntry = Tuple[str, str, Union[int, str]] +ExpertParamsEntry = Tuple[str, str, int, Union[int, str]] + + +@dataclass +class MappingResult: + """Result of mapping a HuggingFace checkpoint weight name to SGLang parameter.""" + + sglang_name: str + shard_id: Optional[Union[int, str]] + num_shards: int + expert_id: Optional[int] + num_local_experts: Optional[int] + + +# Standard FP8 scale remapping patterns +_SCALE_REMAP_PATTERNS: List[Tuple[str, str, str]] = [ + (".k_scale", ".self_attn.k_proj.k_scale", ".self_attn.attn.k_scale"), + (".v_scale", ".self_attn.v_proj.v_scale", ".self_attn.attn.v_scale"), + (".k_scale", ".k_scale", ".attn.k_scale"), + (".v_scale", ".v_scale", ".attn.v_scale"), +] + +# Quark quantization scale remapping +_QUARK_SCALE_REMAP: Dict[str, str] = { + ".q_proj.output_scale": ".attn.q_scale", + ".k_proj.output_scale": ".attn.k_scale", + ".v_proj.output_scale": ".attn.v_scale", + "self_attn.prob_output_scale": ".attn.prob_scale", +} + + +class ParameterMapper: + """Maps HuggingFace checkpoint weight names to SGLang model parameters. + + This class pre-computes lookup tables at initialization for efficient + repeated mapping. It handles: + - Stacked/fused parameter mapping (gate_up_proj, qkv_proj, etc.) + - Expert parameter mapping with shard information + - Scale remapping for quantized models + - Model-specific weight name mutations + """ + + def __init__( + self, + stacked_params_mapping: List[StackedParamsEntry], + expert_params_mapping: List[ExpertParamsEntry], + num_local_experts: int = 0, + mutate_weight_preload: Optional[Callable[[str], str]] = None, + custom_scale_remap: Optional[Callable[[str], str]] = None, + ): + """Initialize the parameter mapper with model-specific configuration. + + Args: + stacked_params_mapping: List of (sglang_name, hf_name, shard_id) tuples. + Example: [("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1)] + expert_params_mapping: List of (sglang_name, hf_name, expert_id, shard_id) tuples. + Example: [("w13_weight", "experts.0.gate_proj.weight", 0, 0), ...] + num_local_experts: Number of experts in the current model rank. + For EP=1: num_local_experts = n_routed_experts + num_fused_shared_experts + For EP>1: num_local_experts = n_routed_experts // ep_size + num_fused_shared_experts + mutate_weight_preload: Optional function to transform weight names before mapping. 
+ Used for shared expert fusion in DeepSeek (shared_experts → experts.{n_routed}). + custom_scale_remap: Optional function for model-specific scale remapping. + Used for DeepSeek k_proj/v_proj → attn_mqa scale mapping. + """ + self.num_local_experts = num_local_experts + self._mutate_weight_preload = mutate_weight_preload + self._custom_scale_remap = custom_scale_remap + + self._stacked_lookup, self._stacked_num_shards = self._build_stacked_lookup( + stacked_params_mapping + ) + self._expert_lookup, self._expert_num_shards = self._build_expert_lookup( + expert_params_mapping + ) + + @staticmethod + def _build_stacked_lookup( + mapping: List[StackedParamsEntry], + ) -> Tuple[Dict[str, Tuple[str, Union[int, str]]], Dict[str, int]]: + """Build lookup table and num_shards from stacked params mapping.""" + lookup: Dict[str, Tuple[str, Union[int, str]]] = {} + shard_counts: Dict[str, int] = {} + + for sglang_name, hf_name, shard_id in mapping: + lookup[hf_name] = (sglang_name, shard_id) + shard_counts[sglang_name] = shard_counts.get(sglang_name, 0) + 1 + + return lookup, shard_counts + + @staticmethod + def _build_expert_lookup( + mapping: List[ExpertParamsEntry], + ) -> Tuple[Dict[str, Tuple[str, int, Union[int, str]]], Dict[str, int]]: + """Build lookup table and num_shards from expert params mapping.""" + lookup: Dict[str, Tuple[str, int, Union[int, str]]] = {} + shard_counts: Dict[str, int] = {} + + for sglang_name, hf_name, expert_id, shard_id in mapping: + lookup[hf_name] = (sglang_name, expert_id, shard_id) + + for sglang_name, _, _, shard_id in mapping: + key = sglang_name + if key not in shard_counts: + unique_shards = set( + s_id for s_name, _, _, s_id in mapping if s_name == sglang_name + ) + shard_counts[key] = len(unique_shards) + + return lookup, shard_counts + + def _apply_scale_remap(self, name: str) -> str: + """Apply standard and Quark scale remapping patterns.""" + for suffix, pattern, replacement in _SCALE_REMAP_PATTERNS: + if name.endswith(suffix) and pattern in name: + return name.replace(pattern, replacement) + + for quark_suffix, replacement in _QUARK_SCALE_REMAP.items(): + if name.endswith(quark_suffix): + return name.replace(quark_suffix, replacement) + + return name + + def map(self, hf_weight_name: str) -> MappingResult: + """Map a HuggingFace checkpoint weight name to SGLang parameter info. + + Args: + hf_weight_name: The weight name from HuggingFace checkpoint. + + Returns: + MappingResult with mapped name and sharding information. 
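A quick sketch of how `map` resolves an expert weight, using a hand-written mapping table (the names, expert count, and shard layout below are illustrative, not what `from_model` would produce for a real checkpoint):

```python
mapper = ParameterMapper(
    stacked_params_mapping=[
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ],
    expert_params_mapping=[
        # (sglang_name, hf_name, expert_id, shard_id)
        ("mlp.experts.w13_weight", "mlp.experts.0.gate_proj.weight", 0, 0),
        ("mlp.experts.w13_weight", "mlp.experts.0.up_proj.weight", 0, 1),
        ("mlp.experts.w2_weight", "mlp.experts.0.down_proj.weight", 0, 0),
    ],
    num_local_experts=1,
)

result = mapper.map("model.layers.3.mlp.experts.0.up_proj.weight")
# result.sglang_name == "model.layers.3.mlp.experts.w13_weight"
# result.shard_id == 1, result.num_shards == 2, result.expert_id == 0
```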
+ """ + name = hf_weight_name + + if self._mutate_weight_preload is not None: + name = self._mutate_weight_preload(name) + + if "scale" in name: + if self._custom_scale_remap is not None: + remapped = self._custom_scale_remap(name) + if remapped != name: + name = remapped + else: + name = self._apply_scale_remap(name) + else: + name = self._apply_scale_remap(name) + + for hf_pattern, ( + sglang_name, + expert_id, + shard_id, + ) in self._expert_lookup.items(): + if hf_pattern in name: + mapped_name = name.replace(hf_pattern, sglang_name) + return MappingResult( + sglang_name=mapped_name, + shard_id=shard_id, + num_shards=self._expert_num_shards.get(sglang_name, 1), + expert_id=expert_id, + num_local_experts=self.num_local_experts, + ) + + for hf_pattern, (sglang_name, shard_id) in self._stacked_lookup.items(): + if hf_pattern in name: + mapped_name = name.replace(hf_pattern, sglang_name) + return MappingResult( + sglang_name=mapped_name, + shard_id=shard_id, + num_shards=self._stacked_num_shards.get(sglang_name, 1), + expert_id=None, + num_local_experts=None, + ) + + return MappingResult( + sglang_name=name, + shard_id=None, + num_shards=1, + expert_id=None, + num_local_experts=None, + ) + + @classmethod + def from_model(cls, model) -> "ParameterMapper": + """Create a ParameterMapper from a model instance; currently supports + DeepseekV2ForCausalLM, Glm4ForCausalLM, Glm4MoeForCausalLM, + Glm4MoeLiteForCausalLM, LlamaForCausalLM, Qwen2ForCausalLM, + Qwen3ForCausalLM, Qwen3MoeForCausalLM.""" + stacked_mapping = list(getattr(model, "stacked_params_mapping", []) or []) + expert_mapping = list(getattr(model, "expert_params_mapping", []) or []) + + num_local_experts = 0 + if hasattr(model, "num_local_experts"): + num_local_experts = model.num_local_experts + elif expert_mapping: + expert_ids = set(entry[2] for entry in expert_mapping) + num_local_experts = len(expert_ids) + + mutate_fn = None + if hasattr(model, "mutate_weight_preload"): + mutate_fn = model.mutate_weight_preload + + scale_fn = None + if hasattr(model, "custom_scale_remap"): + scale_fn = model.custom_scale_remap + + return cls( + stacked_params_mapping=stacked_mapping, + expert_params_mapping=expert_mapping, + num_local_experts=num_local_experts, + mutate_weight_preload=mutate_fn, + custom_scale_remap=scale_fn, + ) diff --git a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py index cb83a13e9af0..fc3fd6c19ce7 100644 --- a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py +++ b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py @@ -91,6 +91,7 @@ def forward_absorb_prepare( forward_batch: ForwardBatch, zero_allocator: BumpAllocator, llama_4_scaling: Optional[torch.Tensor] = None, + prev_topk_indices: Optional[torch.Tensor] = None, ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode @@ -182,18 +183,7 @@ def forward_absorb_prepare( q = self.q_b_proj(q)[0].view( -1, self.num_local_heads, self.qk_head_dim ) - topk_indices = self.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=self.layer_id, - ) - current_stream.wait_stream(self.alt_stream) - else: - k_nope = k_nope.unsqueeze(1) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) - if q_lora is not None: + if not self.skip_topk or prev_topk_indices is None: topk_indices = self.indexer( 
x=hidden_states, q_lora=q_lora, @@ -201,6 +191,23 @@ def forward_absorb_prepare( forward_batch=forward_batch, layer_id=self.layer_id, ) + else: + topk_indices = prev_topk_indices + current_stream.wait_stream(self.alt_stream) + else: + k_nope = k_nope.unsqueeze(1) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + if q_lora is not None: + if not self.skip_topk or prev_topk_indices is None: + topk_indices = self.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=self.layer_id, + ) + else: + topk_indices = prev_topk_indices else: q = self.q_proj(hidden_states)[0].view( -1, self.num_local_heads, self.qk_head_dim @@ -557,7 +564,14 @@ def forward_absorb_core( ) output, _ = self.o_proj(attn_bmm_output) - return output + if self.next_skip_topk is None: + return output + + # Return topk_indices for the next layer when enabling index cache + if not self.next_skip_topk: + return output, None + else: + return output, topk_indices def _fuse_rope_for_trtllm_mla( self: DeepseekV2AttentionMLA, forward_batch: ForwardBatch diff --git a/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py b/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py index b72e8290d773..6cee718600d5 100644 --- a/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py +++ b/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py @@ -25,7 +25,6 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator from sglang.srt.environ import envs from sglang.srt.layers import deep_gemm_wrapper -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import ( block_quant_dequant, @@ -91,6 +90,18 @@ class DeepseekV2WeightLoaderMixin: quant_config: Optional[QuantizationConfig] pp_group: GroupCoordinator num_fused_shared_experts: int + # Weight mapping relationships determined at model initialization time. + fuse_qkv_a_proj: bool + stacked_params_mapping: List[Tuple[str, str, int]] + expert_params_mapping: List[Tuple[str, str, int, int]] + + def mutate_weight_preload(self, name: str) -> str: + """Override in subclass for model-specific weight name mutations.""" + return name + + def custom_scale_remap(self, name: str) -> str: + """Override in subclass for model-specific scale remapping.""" + return name def do_load_weights( self, @@ -109,33 +120,7 @@ def do_load_weights( weights, NVFP4_CKPT_FP8_ATTN_QUANT_MODULES, nextn_conf ) - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, - ) - # Params for special naming rules in mixed-precision models, for example: - # model.layers.xx.mlp.experts.xx.w1.input_scale. For details, - # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. 
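The two loader hooks introduced above (`mutate_weight_preload`, `custom_scale_remap`) are meant to be overridden per model; a schematic sketch of that pattern follows, using a made-up stand-in class and expert count (the real DeepSeek and GLM overrides appear later in this diff):

```python
# Schematic only: a stand-in for a model class that mixes in
# DeepseekV2WeightLoaderMixin; the expert count is a made-up config value.
class MyMoEWeightNameHooks:
    n_routed_experts = 64
    num_fused_shared_experts = 1

    def mutate_weight_preload(self, name: str) -> str:
        # Redirect the shared expert's checkpoint weights to the extra fused
        # expert slot appended after the routed experts.
        if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
            return name.replace(
                "mlp.shared_experts", f"mlp.experts.{self.n_routed_experts}"
            )
        return name

    def custom_scale_remap(self, name: str) -> str:
        # MLA checkpoints keep k/v scales under k_proj/v_proj, while the
        # runtime parameter lives on attn_mqa.
        for scale in ("k_scale", "v_scale"):
            if scale in name:
                return name.replace(f"{scale[0]}_proj", "attn_mqa")
        return name


hooks = MyMoEWeightNameHooks()
print(hooks.mutate_weight_preload("model.layers.3.mlp.shared_experts.up_proj.weight"))
# -> model.layers.3.mlp.experts.64.up_proj.weight
```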
- if self.quant_config and self.quant_config.get_name() == "w4afp8": - expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( - num_experts=self.config.n_routed_experts - ) - - # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None - fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( - self.config.q_lora_rank is not None - ) - cached_a_proj = {} if fuse_qkv_a_proj else None + cached_a_proj = {} if self.fuse_qkv_a_proj else None if self.num_fused_shared_experts > 0: assert self.num_fused_shared_experts == 1 @@ -157,11 +142,7 @@ def do_load_weights( ) ): continue - if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: - name = name.replace( - "mlp.shared_experts", - f"mlp.experts.{self.config.n_routed_experts}", - ) + name = self.mutate_weight_preload(name) weight_names.append(name) @@ -197,7 +178,7 @@ def do_load_weights( if "rotary_emb.inv_freq" in name: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -226,7 +207,7 @@ def do_load_weights( ) break else: - for mapping in expert_params_mapping: + for mapping in self.expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue @@ -263,7 +244,7 @@ def do_load_weights( # Skip loading norm if not last rank in pipeline parallelism if ".norm." in name and not self.pp_group.is_last_rank: continue - if fuse_qkv_a_proj and ( + if self.fuse_qkv_a_proj and ( "q_a_proj" in name or "kv_a_proj_with_mqa" in name ): cached_a_proj[name] = loaded_weight @@ -331,13 +312,7 @@ def do_load_weights( if ( "k_scale" in name or "v_scale" in name ) and name not in params_dict: - # modelopt attn kv scale is named differently - for scale in ["k_scale", "v_scale"]: - if scale in name: - name = name.replace( - f"{scale[0]}_proj", "attn_mqa" - ) - break + name = self.custom_scale_remap(name) if name not in params_dict: # modelopt ckpt contains not needed weights for MTP module: # model.decoder.self_attn.attn_mqa.v_scale and diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 28029a0c75e9..f83db2d477a4 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -185,7 +185,7 @@ def forward( positions = cp_split_and_rebuild_position(forward_batch, positions) residual = None with get_global_expert_distribution_recorder().disable_this_region(): - hidden_states, residual = self.decoder( + hidden_states, residual, topk_indices = self.decoder( positions, hidden_states, forward_batch, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a7fbabdc05ff..8b111f5215a1 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -152,6 +152,7 @@ make_layers, use_intel_amx_backend, ) +from sglang.srt.utils.custom_op import register_custom_op if _use_aiter: from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm @@ -167,8 +168,6 @@ if _is_cuda: from flashinfer.gemm import mm_M1_16_K7168_N256 as _raw_dsv3_router_gemm from sgl_kernel import dsv3_fused_a_gemm, dsv3_router_gemm - - from sglang.srt.utils.custom_op import register_custom_op elif _is_npu: from sglang.srt.hardware_backend.npu.modules.deepseek_v2_attention_mla_npu import ( forward_dsa_core_npu, @@ 
-327,7 +326,7 @@ def forward( and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384) and _device_sm >= 90 ): - if _device_sm >= 100 and self.weight.shape[0] == 256: + if _device_sm == 100 and self.weight.shape[0] == 256: # router gemm output float32 logits = torch.empty( hidden_states.shape[0], @@ -1106,6 +1105,7 @@ def __init__( prefix: str = "", alt_stream: Optional[torch.cuda.Stream] = None, skip_rope: bool = False, + is_nextn: bool = False, ) -> None: super().__init__() self.layer_id = layer_id @@ -1175,6 +1175,8 @@ def __init__( prefix=add_prefix("kv_a_proj_with_mqa", prefix), ) + self.skip_topk = None + self.next_skip_topk = None if self.use_nsa: is_neox_style = not getattr(config, "indexer_rope_interleave", False) self.indexer = Indexer( @@ -1195,6 +1197,26 @@ def __init__( layer_id=layer_id, alt_stream=alt_stream, ) + # Refer: https://arxiv.org/abs/2603.12201 for more details. + # skip_topk: when True, this layer will skip computation and reuse previous layer's topk indices. + # next_skip_topk: when True, the next layer will skip computation and reuse this layer's topk indices. + if is_nextn: + self.skip_topk = False + self.next_skip_topk = False + else: + self.index_topk_freq = getattr(config, "index_topk_freq", 1) + self.index_topk_pattern = getattr(config, "index_topk_pattern", None) + if self.index_topk_pattern is None: + self.skip_topk = max(layer_id - 1, 0) % self.index_topk_freq != 0 + self.next_skip_topk = layer_id % self.index_topk_freq != 0 + else: + self.skip_topk = self.index_topk_pattern[layer_id] == "S" + if layer_id < len(self.index_topk_pattern) - 1: + self.next_skip_topk = ( + self.index_topk_pattern[layer_id + 1] == "S" + ) + else: + self.next_skip_topk = False self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, @@ -1326,9 +1348,14 @@ def op_prepare(self, state): ) def op_core(self, state): - state.hidden_states_after_attn = self.forward_core( - state.pop("attn_intermediate_state") - ) + result = self.forward_core(state.pop("attn_intermediate_state")) + # forward_core may return (hidden_states, topk_indices) for NSA models + # with index cache enabled. In the TBO path, topk_indices is not + # propagated between layers, so we discard it here. 
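To make the index-reuse schedule concrete, here is a small worked example of the `skip_topk` / `next_skip_topk` formulas above; the frequency value is illustrative.

```python
index_topk_freq = 2  # illustrative; the default of 1 makes every layer compute
for layer_id in range(6):
    skip_topk = max(layer_id - 1, 0) % index_topk_freq != 0
    next_skip_topk = layer_id % index_topk_freq != 0
    print(layer_id,
          "reuse previous indices" if skip_topk else "compute indices",
          "| next layer reuses:", next_skip_topk)
# Layers 0 and 1 compute their own top-k indices; after that, every
# even-numbered layer reuses the indices produced by the odd layer before it.
```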
+ if isinstance(result, tuple): + state.hidden_states_after_attn = result[0] + else: + state.hidden_states_after_attn = result def forward( self, @@ -1338,6 +1365,7 @@ def forward( zero_allocator: BumpAllocator, layer_scatter_modes: LayerScatterModes = None, llama_4_scaling: Optional[torch.Tensor] = None, + prev_topk_indices: Optional[torch.Tensor] = None, ): s = self.forward_prepare( positions=positions, @@ -1346,6 +1374,7 @@ def forward( zero_allocator=zero_allocator, layer_scatter_modes=layer_scatter_modes, llama_4_scaling=llama_4_scaling, + prev_topk_indices=prev_topk_indices, ) return self.forward_core(s) @@ -1357,6 +1386,7 @@ def forward_prepare( zero_allocator: BumpAllocator, layer_scatter_modes: LayerScatterModes = None, llama_4_scaling: Optional[torch.Tensor] = None, + prev_topk_indices: Optional[torch.Tensor] = None, ): if self.attn_mha.kv_b_proj is None: self.attn_mha.kv_b_proj = self.kv_b_proj @@ -1396,7 +1426,12 @@ def forward_prepare( ) elif attn_forward_method == AttnForwardMethod.MLA: inner_state = self.forward_absorb_prepare( - positions, hidden_states, forward_batch, zero_allocator, llama_4_scaling + positions, + hidden_states, + forward_batch, + zero_allocator, + llama_4_scaling, + prev_topk_indices, ) elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_ROCM: inner_state = self.forward_absorb_fused_mla_rope_prepare( @@ -1557,6 +1592,7 @@ def __init__( reduce_results=False, prefix=add_prefix("self_attn", prefix), alt_stream=alt_stream, + is_nextn=is_nextn, ) if not hasattr(config, "q_lora_rank") and envs.SGLANG_USE_AG_AFTER_QLORA.get(): raise ValueError( @@ -1643,6 +1679,7 @@ def forward( zero_allocator: BumpAllocator, gemm_output_zero_allocator: BumpAllocator = None, llama_4_scaling: Optional[torch.Tensor] = None, + prev_topk_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: quant_format = ( "mxfp4" @@ -1685,7 +1722,12 @@ def forward( zero_allocator=zero_allocator, llama_4_scaling=llama_4_scaling, layer_scatter_modes=self.layer_scatter_modes, + prev_topk_indices=prev_topk_indices, ) + if isinstance(hidden_states, tuple): + hidden_states, topk_indices = hidden_states + else: + topk_indices = None hidden_states, residual = self.layer_communicator.prepare_mlp( hidden_states, residual, forward_batch @@ -1721,7 +1763,7 @@ def forward( hidden_states, residual, forward_batch ) - return hidden_states, residual + return hidden_states, residual, topk_indices def op_comm_prepare_attn( self, @@ -1998,6 +2040,7 @@ def forward( elif self.first_k_dense_replace < normal_start_layer: normal_end_layer = normal_start_layer = 0 aux_hidden_states = [] + topk_indices = None for i in range(normal_start_layer, normal_end_layer): # NOTE: torch dynamo does not support graph break in context manager ctx = ( @@ -2015,7 +2058,7 @@ def forward( else: aux_hidden_states.append(hidden_states + residual) layer = self.layers[i] - hidden_states, residual = layer( + hidden_states, residual, topk_indices = layer( positions, hidden_states, forward_batch, @@ -2023,6 +2066,7 @@ def forward( zero_allocator, gemm_output_zero_allocator, llama_4_scaling, + prev_topk_indices=topk_indices, ) if normal_end_layer != self.end_layer: @@ -2115,6 +2159,37 @@ def __init__( self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config) + # Stacked params mapping for unified weight loading API + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # Add A-proj fusion mapping when q_lora_rank is 
enabled + # q_a_proj + kv_a_proj_with_mqa -> fused_qkv_a_proj_with_mqa + if self.fuse_qkv_a_proj: + self.stacked_params_mapping.extend( + [ + ("fused_qkv_a_proj_with_mqa", "q_a_proj", 0), + ("fused_qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", 1), + ] + ) + self.expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + # Params for special naming rules in mixed-precision models, for example: + # model.layers.xx.mlp.experts.xx.w1.input_scale. For details, + # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. + if self.quant_config and self.quant_config.get_name() == "w4afp8": + self.expert_params_mapping += ( + FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + ) + self._routed_experts_weights_of_layer = LazyValue( lambda: { layer_id: layer.mlp.get_moe_weights() @@ -2134,6 +2209,22 @@ def __init__( q_lora_rank = config.q_lora_rank if hasattr(config, "q_lora_rank") else None get_attn_tp_context().init_context(q_lora_rank, is_deepseek_nsa(config)) + def mutate_weight_preload(self, name: str) -> str: + """DeepSeek V2: shared expert fusion.""" + if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: + return name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + return name + + def custom_scale_remap(self, name: str) -> str: + """DeepSeek V2: k_proj -> attn_mqa when k_scale in name, v_proj -> attn_mqa when v_scale in name.""" + for scale in ["k_scale", "v_scale"]: + if scale in name: + return name.replace(f"{scale[0]}_proj", "attn_mqa") + return name + @property def routed_experts_weights_of_layer(self): return self._routed_experts_weights_of_layer.value diff --git a/python/sglang/srt/models/glm4.py b/python/sglang/srt/models/glm4.py index 016941b4b6c6..c9dd59f9cec1 100644 --- a/python/sglang/srt/models/glm4.py +++ b/python/sglang/srt/models/glm4.py @@ -463,6 +463,17 @@ def __init__( self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + # Stacked params mapping for unified weight loading API + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + # For EAGLE3 support self.capture_aux_hidden_states = False @@ -557,14 +568,7 @@ def end_layer(self): return self.model.end_layer def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".up_proj", 1), - (".gate_up_proj", ".gate_proj", 0), - ] + stacked_params_mapping = self.stacked_params_mapping params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 7673c49c957e..a80d6301929d 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -15,7 +15,6 @@ """Inference-only GLM-4.5, GLM-4.6 and GLM-4.7 model compatible with HuggingFace weights""" import logging -import re from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch @@ 
-1130,9 +1129,37 @@ def __init__( ) self.logits_processor = LogitsProcessor(config) + # Stacked params mapping for unified weight loading API + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + self.expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + # For EAGLE3 support self.capture_aux_hidden_states = False + def mutate_weight_preload(self, name: str) -> str: + """GLM4-MoE: shared expert fusion.""" + if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: + return name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + return name + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + def determine_num_fused_shared_experts(self): if get_global_server_args().disable_shared_experts_fusion: return @@ -1220,43 +1247,8 @@ def load_weights( else: raise ValueError("num_nextn_predict_layers is not in the config") - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - if self.num_fused_shared_experts > 0: - assert self.num_fused_shared_experts == 1 - - def iter_weights_with_fused_shared_experts( - weights: Iterable[Tuple[str, torch.Tensor]], - ) -> Iterable[Tuple[str, torch.Tensor]]: - - pattern = re.compile( - r"^model\.layers\.(\d+)\.mlp\.shared_experts\.(.+)$" - ) - for name, weight in weights: - match = pattern.match(name) - if match: - layer_id = int(match.group(1)) - suffix = match.group(2) - name = f"model.layers.{layer_id}.mlp.experts.{self.config.n_routed_experts}.{suffix}" - yield name, weight - - weights = iter_weights_with_fused_shared_experts(weights) - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, - ) + stacked_params_mapping = self.stacked_params_mapping + expert_params_mapping = self.expert_params_mapping if is_nextn: nextn_layer_prefix = f"model.layers.{nextn_layer_id}" @@ -1277,6 +1269,8 @@ def iter_weights_with_fused_shared_experts( for name, loaded_weight in weights: weight_names.append(name) + name = self.mutate_weight_preload(name) + if not is_nextn: if hasattr(self.config, "num_nextn_predict_layers"): num_nextn_layers = self.config.num_nextn_predict_layers diff --git a/python/sglang/srt/models/glm4_moe_lite.py b/python/sglang/srt/models/glm4_moe_lite.py index 03ff3d201d45..9753314fef2b 100644 --- a/python/sglang/srt/models/glm4_moe_lite.py +++ b/python/sglang/srt/models/glm4_moe_lite.py @@ -499,6 +499,33 @@ def __init__( ) self.capture_aux_hidden_states = False + # Weight loading mappings for ParameterMapper compatibility + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + 
("gate_up_proj", "up_proj", 1), + ] + # Add A-proj fusion mapping when q_lora_rank is enabled (MLA) + self.fuse_qkv_a_proj = hasattr(config, "q_lora_rank") and ( + config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.stacked_params_mapping.extend( + [ + ("fused_qkv_a_proj_with_mqa", "q_a_proj", 0), + ("fused_qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", 1), + ] + ) + self.expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=config.n_routed_experts + self.num_fused_shared_experts, + ) + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() if self.nsa_enable_prefill_cp: self.cp_rank = get_attention_tp_rank() @@ -535,6 +562,22 @@ def determine_num_fused_shared_experts( self.num_fused_shared_experts = self.config.n_shared_experts + def mutate_weight_preload(self, name: str) -> str: + """GLM4-MoE-Lite: shared expert fusion.""" + if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: + return name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + return name + + def custom_scale_remap(self, name: str) -> str: + """GLM4-MoE-Lite: k_proj/v_proj -> attn_mqa for MLA kv scale.""" + for s in ["k_scale", "v_scale"]: + if s in name: + return name.replace(f"{s[0]}_proj", "attn_mqa") + return name + def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py index 6d2ef89724da..946af33fee00 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -28,6 +28,7 @@ from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -148,6 +149,22 @@ def __init__( 0 if get_global_server_args().disable_shared_experts_fusion else 1 ) + # Weight loading mappings (must match parent for load_weights to work) + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + self.expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + @torch.no_grad() def forward( self, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f955ac750d34..e23484dcf195 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -493,9 +493,21 @@ def __init__( (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + # Llama-specific scale remapping patterns (suffix, pattern, replacement) + self._llama_scale_remap_patterns = [ + (".activation_scale", ".activation_scale", ".input_scale"), + (".weight_scale_inv", ".weight_scale_inv", ".weight_scale"), + ] self.capture_aux_hidden_states = False + def custom_scale_remap(self, name: str) -> str: + """Llama: activation_scale->input_scale, weight_scale_inv->weight_scale.""" + for suffix, 
pattern, replacement in self._llama_scale_remap_patterns: + if name.endswith(suffix) and pattern in name: + return name.replace(pattern, replacement) + return name + def _init_model( self, config: LlamaConfig, @@ -606,23 +618,11 @@ def get_num_params(self): return len(params_dict) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if name.endswith(".activation_scale"): - name = name.replace(".activation_scale", ".input_scale") - if name.endswith(".weight_scale_inv"): - name = name.replace(".weight_scale_inv", ".weight_scale") - + name = self.custom_scale_remap(name) layer_id = get_layer_id(name) if ( layer_id is not None @@ -649,7 +649,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name is None: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 84cfa46f1c77..555e5e799510 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -92,9 +92,6 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - if get_global_server_args().rl_on_policy_target is not None: - x = x.bfloat16() - gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -280,11 +277,6 @@ def __init__( quant_config=quant_config, use_attn_tp_group=is_dp_attention_enabled(), prefix=add_prefix("embed_tokens", prefix), - params_dtype=( - torch.float32 - if get_global_server_args().rl_on_policy_target is not None - else None - ), ) else: self.embed_tokens = PPMissingLayer() @@ -307,10 +299,8 @@ def __init__( if self.pp_group.is_last_rank: norm_kwargs = ( dict( - weight_dtype=torch.float32, cast_x_before_out_mul=True, - override_orig_dtype=torch.float32, - fp32_residual=True, + fp32_residual=False, ) if get_global_server_args().rl_on_policy_target is not None else {} @@ -460,6 +450,17 @@ def __init__( self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + # Stacked params mapping for unified weight loading API + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # For EAGLE3 support self.capture_aux_hidden_states = False @@ -554,14 +555,7 @@ def end_layer(self): return self.model.end_layer def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] + stacked_params_mapping = self.stacked_params_mapping params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 8f3475a24323..723c6c3ba6b7 100644 --- 
a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -618,7 +618,17 @@ def __init__( prefix=add_prefix("layers", prefix), ) if self.pp_group.is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + norm_kwargs = ( + dict( + cast_x_before_out_mul=True, + fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} + ) + self.norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) else: self.norm = PPMissingLayer(return_tuple=True) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 7e557a8d5b31..8fcf148dae9b 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -103,8 +103,8 @@ def __init__( norm_kwargs = ( dict( - weight_dtype=torch.float32, cast_x_before_out_mul=True, + fp32_residual=False, ) if get_global_server_args().rl_on_policy_target is not None else {} @@ -345,10 +345,8 @@ def __init__( norm_kwargs = ( dict( - weight_dtype=torch.float32, cast_x_before_out_mul=True, - override_orig_dtype=torch.float32, - fp32_residual=True, + fp32_residual=False, ) if get_global_server_args().rl_on_policy_target is not None else {} @@ -471,6 +469,16 @@ def __init__( config, quant_config=quant_config, prefix=add_prefix("model", prefix) ) + # Stacked params mapping for unified weight loading API + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # handle the lm head on different pp ranks if self.pp_group.is_last_rank: if self.pp_group.world_size == 1 and config.tie_word_embeddings: @@ -582,15 +590,7 @@ def end_layer(self): return self.model.end_layer def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - + stacked_params_mapping = self.stacked_params_mapping params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if not name.startswith("model.") and ( diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 912891b6a7eb..b37594d8cf7b 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar import torch +import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig @@ -54,7 +55,7 @@ ) from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK from sglang.srt.layers.moe.utils import ( RoutingMethodType, filter_moe_weight_param_global_expert, @@ -322,7 +323,20 @@ def forward_normal( # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - topk_output = self.topk(hidden_states, router_logits) + if get_global_server_args().rl_on_policy_target is not None: + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= 
routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + topk_output = StandardTopKOutput( + topk_weights=routing_weights, + topk_ids=selected_experts, + router_logits=router_logits, + ) + else: + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(hidden_states, topk_output) if self.ep_size > 1 and not should_allreduce_fusion: @@ -509,7 +523,7 @@ def __init__( ) self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, MRotaryEmbedding) else True - ) + ) and (get_global_server_args().rl_on_policy_target is None) self.compatible_with_fused_qk_norm_rope = not isinstance( self.rotary_emb, MRotaryEmbedding ) and self.head_dim in (64, 128, 256) @@ -524,6 +538,7 @@ def __init__( torch.bfloat16, _yarn_factor != 1.0, ) + and (get_global_server_args().rl_on_policy_target is None) ) self._used_fused_qk_norm_rope_last_call = False @@ -536,8 +551,16 @@ def __init__( prefix=add_prefix("attn", prefix), ) - self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + norm_kwargs = ( + dict( + cast_x_before_out_mul=True, + fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} + ) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) self.alt_stream = alt_stream def op_prepare(self, state): @@ -781,9 +804,19 @@ def __init__( quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + norm_kwargs = ( + dict( + cast_x_before_out_mul=True, + fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} + ) + self.input_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ) self.layer_communicator = LayerCommunicator( @@ -957,6 +990,23 @@ def __init__( use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) + + # Stacked params mapping for unified weight loading API + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + self.expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + self.capture_aux_hidden_states = False self.attn_cp_size = get_attn_context_model_parallel_world_size() @@ -1080,21 +1130,8 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - 
num_experts=self.config.num_experts, - ) + stacked_params_mapping = self.stacked_params_mapping + expert_params_mapping = self.expert_params_mapping # Pre-define `params_dict` to avoid repeated expensive traversal of model parameters. params_dict = dict(self.named_parameters()) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 7746b2445999..78e1c5fa250e 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -431,75 +431,73 @@ def rot_pos_emb( return cos_combined, sin_combined - def _get_interpolation_indices(self, dim_size: int) -> torch.Tensor: - """ - Compute continuous interpolation indices for a single dimension. + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + num_grid_per_side = int(self.num_position_embeddings**0.5) + device = self.pos_embed.weight.device - Returns continuous indices. - """ - if self.align_corners: - indices = np.linspace( - 0, self.num_grid_per_side - 1, dim_size, dtype=np.float32 - ) - else: - indices = (np.arange(dim_size, dtype=np.float32) + 0.5) * ( - self.num_grid_per_side / dim_size - ) - 0.5 - indices = np.clip(indices, 0, self.num_grid_per_side - 1) - return indices + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] - def _calculate_indices_and_weights(self, h_idxs, w_idxs): - """ - Compute bilinear interpolation indices and weights. + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w) - Returns tuple of (indices, weights), each as 4 numpy arrays for the 4 corner points. - """ - h_f = np.floor(h_idxs).astype(np.int64) - h_c = np.clip(h_f + 1, 0, self.num_grid_per_side - 1) - dh = h_idxs - h_f + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) - w_f = np.floor(w_idxs).astype(np.int64) - w_c = np.clip(w_f + 1, 0, self.num_grid_per_side - 1) - dw = w_idxs - w_f + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor - side = self.num_grid_per_side + base_h = h_idxs_floor * num_grid_per_side + base_h_ceil = h_idxs_ceil * num_grid_per_side - indices = [ - (h_f[:, None] * side + w_f).flatten(), - (h_f[:, None] * side + w_c).flatten(), - (h_c[:, None] * side + w_f).flatten(), - (h_c[:, None] * side + w_c).flatten(), - ] - weights = [ - ((1 - dh)[:, None] * (1 - dw)).flatten(), - ((1 - dh)[:, None] * dw).flatten(), - (dh[:, None] * (1 - dw)).flatten(), - (dh[:, None] * dw).flatten(), - ] - return indices, weights + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] - def _get_position_embedding(self, patch_pos_embeds, grid_ts, grid_hs, grid_ws): - """ - Tile and reorganize position embeddings to align with the token sequence. 
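A one-dimensional slice of the bilinear resampling used above, with illustrative sizes, shows how the floor/ceil indices and blend weights fall out of `torch.linspace`:

```python
import torch

# Illustrative sizes: a 4x4 grid of learned position embeddings resized to 6 rows.
num_grid_per_side, h = 4, 6
h_idxs = torch.linspace(0, num_grid_per_side - 1, h)    # [0.0, 0.6, 1.2, 1.8, 2.4, 3.0]
h_floor = h_idxs.int()                                  # [0, 0, 1, 1, 2, 3]
h_ceil = (h_floor + 1).clip(max=num_grid_per_side - 1)  # [1, 1, 2, 2, 3, 3]
dh = h_idxs - h_floor                                   # blend weight toward the ceil row
# Output row i = (1 - dh[i]) * grid_row[h_floor[i]] + dh[i] * grid_row[h_ceil[i]];
# the 2-D case in fast_pos_embed_interpolate combines the four corner products.
```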
- """ - result_parts = [] - merge_size = self.spatial_merge_size + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) - h_merge = h // merge_size - w_merge = w // merge_size + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + patch_pos_embeds_permute = [] + merge_size = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) pos_embed = ( - pos_embed.view(t, h_merge, merge_size, w_merge, merge_size, -1) + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) .permute(0, 1, 3, 2, 4, 5) .flatten(0, 4) ) - result_parts.append(pos_embed) + patch_pos_embeds_permute.append(pos_embed) - return torch.cat(result_parts, dim=0) + return torch.cat(patch_pos_embeds_permute) def _torch_interp_indices( self, dim_size: int, device: torch.device @@ -607,61 +605,6 @@ def bucket_flashinfer_max_seqlen(self, real_max_seqlen: int) -> int: round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]), ) - def fast_pos_embed_interpolate(self, grid_thw): - """Interpolate position embeddings for (batch, 3) size input dimensions. - - Performs bilinear interpolation on spatial dimensions (height, width) and replicates - along temporal dimension. The result is reorganized according to spatial_merge_size. - - Args: - grid_thw: Tensor of shape [batch_size, 3] with (temporal, height, width) dimensions - in patches for each sample. - - Returns: - Interpolated position embeddings tensor. 
- """ - grid_thw_cpu = grid_thw.cpu().numpy() - - # transfer data to CPU before loop - temporal_dims = grid_thw_cpu[:, 0].tolist() - height_dims = grid_thw_cpu[:, 1].tolist() - width_dims = grid_thw_cpu[:, 2].tolist() - - device = self.pos_embed.weight.device - dtype = self.pos_embed.weight.dtype - - patches_size = [h * w for h, w in zip(height_dims, width_dims)] - total_patches = sum(patches_size) - all_indices_np = np.zeros((4, total_patches), dtype=np.int64) - all_weights_np = np.zeros((4, total_patches), dtype=np.float32) - - current_idx = 0 - - # calculate indices and weights on CPU - for t, h, w in zip(temporal_dims, height_dims, width_dims): - h_idxs = self._get_interpolation_indices(h) - w_idxs = self._get_interpolation_indices(w) - - indices, weights = self._calculate_indices_and_weights(h_idxs, w_idxs) - - end_idx = current_idx + h * w - for i in range(4): - all_indices_np[i, current_idx:end_idx] = indices[i] - all_weights_np[i, current_idx:end_idx] = weights[i] - current_idx = end_idx - - idx_tensor = torch.from_numpy(all_indices_np).to(device) - weight_tensor = torch.from_numpy(all_weights_np).to(dtype=dtype, device=device) - - # calculate interpolation - pos_embeds = self.pos_embed(idx_tensor.view(-1)) - pos_embeds = pos_embeds.view(4, total_patches, -1) - patch_pos_embeds = (pos_embeds * weight_tensor.unsqueeze(-1)).sum(dim=0) - patch_pos_embeds = patch_pos_embeds.split(patches_size) - return self._get_position_embedding( - patch_pos_embeds, temporal_dims, height_dims, width_dims - ) - def compute_flashinfer_batch_offsets_packed( self, token_cu_seqlens: np.ndarray, @@ -1005,14 +948,19 @@ def forward( hidden_states + residual if residual is not None else hidden_states ) + deepstack_embeds = None + if input_deepstack_embeds is not None: + prev_layer_idx = layer_idx - 1 + if prev_layer_idx in self.deepstack_embed_to_decoder_layer: + sep = self.hidden_size * prev_layer_idx + deepstack_embeds = input_deepstack_embeds[ + :, sep : sep + self.hidden_size + ] + # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack # The order matters because addition with different tensors is not associative in practice. - # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. - deepstack_embeds = self.get_deepstack_embeds( - layer_idx - 1, input_deepstack_embeds - ) hidden_states, residual = layer( positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/lfm2_vl.py b/python/sglang/srt/multimodal/processors/lfm2_vl.py index fc8700e7ff72..0d57dd9ddb47 100644 --- a/python/sglang/srt/multimodal/processors/lfm2_vl.py +++ b/python/sglang/srt/multimodal/processors/lfm2_vl.py @@ -12,9 +12,9 @@ # limitations under the License. 
"""Multimodal processor for LFM2-VL models with SigLip2 NaFlex support.""" -from typing import Any, Dict, List, Optional, Union +from typing import List, Union -from sglang.srt.managers.schedule_batch import Modality +from sglang.srt.managers.schedule_batch import Modality, MultimodalProcessorOutput from sglang.srt.models.lfm2_vl import Lfm2VlForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, @@ -56,7 +56,7 @@ async def process_mm_data_async( input_text: str, request_obj, **kwargs, - ) -> Optional[Dict[str, Any]]: + ): if not image_data: input_ids = self._tokenizer( input_text, return_tensors="pt", add_special_tokens=False @@ -77,8 +77,8 @@ async def process_mm_data_async( base_output, self.mm_tokens ) - return { - "input_ids": input_ids.tolist(), - "mm_items": mm_items, - "im_token_id": self.IMAGE_TOKEN_ID, - } + return MultimodalProcessorOutput( + input_ids=input_ids.tolist(), + mm_items=mm_items, + im_token_id=self.IMAGE_TOKEN_ID, + ) diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 3f102567d01a..6fb3899021d3 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -499,7 +499,7 @@ async def process_mm_data_async( **kwargs, ): entry_time = time.perf_counter() - base_output = self.load_mm_data( + base_output = self.legacy_load_mm_data( prompt=input_text, image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d91ced805f5d..756bdf618464 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -610,6 +610,7 @@ class ServerArgs: cuda_graph_max_bs: Optional[int] = None cuda_graph_bs: Optional[List[int]] = None disable_cuda_graph: bool = False + disable_draft_cuda_graph: bool = False disable_cuda_graph_padding: bool = False enable_profile_cuda_graph: bool = False enable_cudagraph_gc: bool = False @@ -5354,6 +5355,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable cuda graph.", ) + parser.add_argument( + "--disable-draft-cuda-graph", + action="store_true", + help="Disable cuda graph for draft model in speculative decoding.", + ) parser.add_argument( "--disable-cuda-graph-padding", action="store_true", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 40e859b2d6d6..2604ae037c55 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -377,6 +377,10 @@ def replay(self, forward_batch: ForwardBatch): buffers.seq_lens.fill_(self.seq_len_fill_value) buffers.out_cache_loc.zero_() buffers.positions.zero_() + buffers.topk_p.zero_() + buffers.topk_index.zero_() + buffers.hidden_states.zero_() + buffers.req_pool_indices.zero_() num_tokens = bs * self.num_tokens_per_bs @@ -386,8 +390,12 @@ def replay(self, forward_batch: ForwardBatch): forward_batch.out_cache_loc ) buffers.positions[:raw_num_token].copy_(forward_batch.positions) - buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) - buffers.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) + buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) + buffers.topk_index[:raw_bs].copy_( + forward_batch.spec_info.topk_index.clamp( + 0, 
self.model_runner.model_config.vocab_size - 1 + ) + ) buffers.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index dbb91f555ecf..8747368331a1 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -776,6 +776,10 @@ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True self.topk_index = self.topk_index[: len(new_indices)] self.hidden_states = self.hidden_states[: len(new_indices)] self.verified_id = self.verified_id[: len(new_indices)] + if self.accept_length is not None: + self.accept_length = self.accept_length[: len(new_indices)] + if self.accept_length_cpu is not None: + self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] else: # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` self.topk_p = self.topk_p[new_indices] @@ -807,6 +811,27 @@ def merge_batch(self, spec_info: "EagleDraftInput"): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) + if self.accept_length is not None and spec_info.accept_length is not None: + self.accept_length = torch.cat( + [self.accept_length, spec_info.accept_length] + ) + self.accept_length_cpu = self.accept_length.tolist() + elif self.accept_length is not None: + zeros = torch.zeros( + [spec_info.verified_id.shape[0]], + dtype=self.accept_length.dtype, + device=self.accept_length.device, + ) + self.accept_length = torch.cat([self.accept_length, zeros]) + self.accept_length_cpu = self.accept_length.tolist() + elif spec_info.accept_length is not None: + zeros = torch.zeros( + [self.verified_id.shape[0]], + dtype=spec_info.accept_length.dtype, + device=spec_info.accept_length.device, + ) + self.accept_length = torch.cat([zeros, spec_info.accept_length]) + self.accept_length_cpu = self.accept_length.tolist() @dataclass diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 59c63c17ca5e..0e966a05cb7f 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -235,7 +235,10 @@ def init_cuda_graphs(self): self.cuda_graph_runner = None self.cuda_graph_runner_for_draft_extend = None - if self.server_args.disable_cuda_graph: + if ( + self.server_args.disable_cuda_graph + or self.server_args.disable_draft_cuda_graph + ): return Device2DraftCudaGraphRunner = { diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 3be16446e0b5..ce209c1a0da2 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -38,6 +38,8 @@ def _snapshot(self): def _reset_tensors(self): for name, param in self._model_state(): + if "cos_sin_cache" in name or "freqs_cis" in name: + continue param.copy_(_random_like(param)) def _compare(self): @@ -123,6 +125,21 @@ def _postprocess_tensors( skip_compare_names = [] + # Skip non-persistent buffers like cos_sin_cache + # These buffers are registered with persistent=False and are not saved in checkpoints + # They should be recomputed after loading weights, so we don't compare them here + non_persistent_buffer_patterns = [ + "cos_sin_cache", # RoPE cache + "inv_freq", # 
RoPE inverse frequency (if it exists as buffer) + ] + + for name in raw: + for pattern in non_persistent_buffer_patterns: + if pattern in name: + skip_compare_names.append(name) + logger.info(f"[check_tensors] Skipping non-persistent buffer: {name}") + break + # dequant fp8 quant_names = [ name @@ -131,18 +148,20 @@ def _postprocess_tensors( if name.endswith("weight") and name.replace("weight", "weight_scale_inv") in raw ] skip_compare_names += quant_names + skip_compare_names += [ + name.replace("weight", "weight_scale_inv") for name in quant_names + ] for name in quant_names: w_q = raw[name] w_s = raw[name.replace("weight", "weight_scale_inv")] try: - # TODO this is only needed for Blackwell - w_s_inverse_transformed = inverse_transform_scale_ue8m0( - w_s, mn=w_q.shape[-2] - ) + if w_s.dtype == torch.int32: + # UE8M0 packed format (Blackwell DeepGEMM) + w_s = inverse_transform_scale_ue8m0(w_s, mn=w_q.shape[-2]) w_dequant = block_quant_dequant( w_q, - w_s_inverse_transformed, + w_s, # TODO do not hardcode block_size=[128, 128], dtype=torch.bfloat16, diff --git a/scripts/ci/utils/diffusion/comparison_configs.json b/scripts/ci/utils/diffusion/comparison_configs.json index a6f1be874fc8..b1b766591c4f 100644 --- a/scripts/ci/utils/diffusion/comparison_configs.json +++ b/scripts/ci/utils/diffusion/comparison_configs.json @@ -1,5 +1,5 @@ { - "_comment": "Per-model comparison config. Only frameworks listed under each case are tested. vLLM-Omni disabled until dep install issues resolved.", + "_comment": "Per-model comparison config. Sampling params omitted where model defaults are correct — only override resolution, seed, and params that differ from defaults.", "test_image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png", "cases": [ { @@ -9,8 +9,6 @@ "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "width": 1024, "height": 1024, - "num_inference_steps": 50, - "guidance_scale": 4.0, "seed": 42, "num_gpus": 1, "frameworks": { @@ -27,8 +25,6 @@ "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "width": 1024, "height": 1024, - "num_inference_steps": 50, - "guidance_scale": 4.0, "seed": 42, "num_gpus": 1, "frameworks": { @@ -45,8 +41,6 @@ "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "width": 1024, "height": 1024, - "num_inference_steps": 50, - "guidance_scale": 4.0, "seed": 42, "num_gpus": 1, "frameworks": { @@ -64,8 +58,6 @@ "reference_image": true, "width": 1024, "height": 1024, - "num_inference_steps": 50, - "guidance_scale": 4.0, "seed": 42, "num_gpus": 1, "frameworks": { @@ -82,8 +74,6 @@ "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "width": 1024, "height": 1024, - "num_inference_steps": 9, - "guidance_scale": 4.0, "seed": 42, "num_gpus": 1, "frameworks": { @@ -101,8 +91,6 @@ "width": 1280, "height": 720, "num_frames": 81, - "num_inference_steps": 2, - "guidance_scale": 5.0, "seed": 42, "num_gpus": 4, "frameworks": { @@ -121,8 +109,6 @@ "width": 1280, "height": 720, "num_frames": 81, - "num_inference_steps": 50, - "guidance_scale": 5.0, "seed": 42, "num_gpus": 1, "frameworks": { @@ -132,6 +118,23 @@ } } }, + { + "id": "ltx2_twostage_t2v", + "model": "Lightricks/LTX-2", + "task": "text-to-video", + "prompt": "A cat and a dog baking a cake together in a kitchen.", + "width": 768, + "height": 512, + "num_frames": 121, + "seed": 42, + "num_gpus": 2, + "frameworks": { + "sglang": { + 
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --pipeline-class-name LTX2TwoStagePipeline", + "extra_env": {} + } + } + }, { "id": "wan22_i2v_a14b_720p", "model": "Wan-AI/Wan2.2-I2V-A14B-Diffusers", @@ -141,8 +144,6 @@ "width": 1280, "height": 720, "num_frames": 81, - "num_inference_steps": 2, - "guidance_scale": 5.0, "seed": 42, "num_gpus": 4, "frameworks": { diff --git a/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py b/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py index bb223fbe6dde..bce9f31be340 100644 --- a/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py +++ b/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py @@ -239,9 +239,12 @@ def generate_dashboard( current: dict, history: list[dict], charts_dir: str | None = None, -) -> str: +) -> tuple[str, list[str]]: """Generate full markdown dashboard. + Returns (markdown_string, alert_reasons) where alert_reasons is a list of + human-readable strings for cases that need attention (empty if all is well). + If charts_dir is provided, saves chart PNGs as files to that directory and references them via raw.githubusercontent URLs. Otherwise, charts are omitted. @@ -342,45 +345,7 @@ def generate_dashboard( row += f" {_fmt_speedup(sg_lat, case_fws.get(ofw))} |" lines.append(row) - # ---- Section 2: SGLang Performance Trend ---- - if history: - lines.append(f"\n## SGLang Performance Trend (Last {len(history) + 1} Runs)\n") - - # Build header - header = "| Date | Commit |" - sep = "|------|--------|" - for cid in case_ids: - header += f" {cid} (s) |" - sep += "---------|" - header += " Trend |" - sep += "-------|" - lines.append(header) - lines.append(sep) - - # Current run first - all_runs = [current] + history - for i, run in enumerate(all_runs): - run_cases = _extract_case_results(run) - date = _short_date(run.get("timestamp", "")) - sha_s = _short_sha(run.get("commit_sha", "")) - row = f"| {date} | `{sha_s}` |" - for cid in case_ids: - lat = run_cases.get(cid, {}).get("sglang") - row += f" {_fmt_latency(lat)} |" - # Trend vs next (older) run - if i + 1 < len(all_runs): - prev_cases = _extract_case_results(all_runs[i + 1]) - emojis = [] - for cid in case_ids: - cur = run_cases.get(cid, {}).get("sglang") - prev = prev_cases.get(cid, {}).get("sglang") - emojis.append(_trend_emoji(cur, prev)) - row += " ".join(emojis) + " |" - else: - row += " -- |" - lines.append(row) - - # ---- Section 3: Cross-Framework Speedup Trend (only if multiple frameworks) ---- + # ---- Section 2: Cross-Framework Speedup Trend (only if multiple frameworks) ---- if history and other_frameworks: lines.append("\n## SGLang vs vLLM-Omni Speedup Over Time\n") @@ -562,6 +527,41 @@ def _chart_label(run: dict) -> str: except ImportError: lines.append("\n*Charts unavailable (matplotlib not installed)*\n") + # ---- SGLang Performance Trend (raw data table, at the end) ---- + if history: + lines.append(f"\n## SGLang Performance Trend (Last {len(history) + 1} Runs)\n") + + header = "| Date | Commit |" + sep = "|------|--------|" + for cid in case_ids: + header += f" {cid} (s) |" + sep += "---------|" + header += " Trend |" + sep += "-------|" + lines.append(header) + lines.append(sep) + + all_runs = [current] + history + for i, run in enumerate(all_runs): + run_cases = _extract_case_results(run) + date = _short_date(run.get("timestamp", "")) + sha_s = _short_sha(run.get("commit_sha", "")) + row = f"| {date} | `{sha_s}` |" + for cid in case_ids: + lat = run_cases.get(cid, {}).get("sglang") + row += f" {_fmt_latency(lat)} 
|" + if i + 1 < len(all_runs): + prev_cases = _extract_case_results(all_runs[i + 1]) + emojis = [] + for cid in case_ids: + cur = run_cases.get(cid, {}).get("sglang") + prev = prev_cases.get(cid, {}).get("sglang") + emojis.append(_trend_emoji(cur, prev)) + row += " ".join(emojis) + " |" + else: + row += " -- |" + lines.append(row) + # ---- Risk Notification ---- alert_cases = [ (cid, emoji, reason) @@ -575,8 +575,7 @@ def _chart_label(run: dict) -> str: lines.append("> The following cases need attention:") for _cid, _emoji, reason in alert_cases: lines.append(f"> - {reason}") - lines.append(">") - lines.append("> cc @mickqian @bbuf @yhyang201\n") + lines.append("") # Footer lines.append("\n---") @@ -584,7 +583,164 @@ def _chart_label(run: dict) -> str: "*Generated by `generate_diffusion_dashboard.py` in SGLang nightly CI.*" ) - return "\n".join(lines) + "\n" + alert_reasons = [reason for _, _, reason in alert_cases] + return "\n".join(lines) + "\n", alert_reasons + + +ALERT_ASSIGNEES = ["mickqian", "bbuf", "yhyang201"] +ALERT_LABEL = "perf-regression" + + +ALERT_ISSUE_TITLE = "[Diffusion CI] Performance regression tracker" + + +def _find_alert_issue(repo: str) -> tuple[str | None, bool]: + """Find the perf-regression tracker issue (open OR closed). + + Returns (issue_number, is_open). Prefers an open issue; if none, + returns the most recent closed one so it can be reopened. + """ + import subprocess + + for state in ("open", "closed"): + result = subprocess.run( + [ + "gh", + "issue", + "list", + "--repo", + repo, + "--label", + ALERT_LABEL, + "--state", + state, + "--json", + "number", + "--limit", + "1", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0 or not result.stdout.strip(): + continue + issues = json.loads(result.stdout) + if issues: + return str(issues[0]["number"]), state == "open" + return None, False + + +def _create_alert_issue(alert_reasons: list[str]) -> None: + """Create or update the single perf-regression tracker issue. + + Logic: + - If an open issue exists → add a comment with the new alert. + - If a closed issue exists → reopen it, then add a comment. + - If no issue exists → create one. + + This guarantees at most one tracker issue ever exists. + + Uses `gh` (GitHub CLI) which is available in all GitHub Actions runners. + Falls back silently outside CI. 
+ """ + import subprocess + + run_url = "" + run_id = os.environ.get("GITHUB_RUN_ID", "") + repo = os.environ.get("GITHUB_REPOSITORY", "sgl-project/sglang") + server_url = os.environ.get("GITHUB_SERVER_URL", "https://github.com") + if run_id: + run_url = f"{server_url}/{repo}/actions/runs/{run_id}" + + date = datetime.now(timezone.utc).strftime("%Y-%m-%d") + + body_lines = [ + f"## Performance Alert — {date}", + "", + "The nightly diffusion benchmark detected the following issue(s):", + "", + ] + for reason in alert_reasons: + body_lines.append(f"- {reason}") + if run_url: + body_lines += ["", f"**CI Run:** {run_url}"] + body = "\n".join(body_lines) + + try: + existing, is_open = _find_alert_issue(repo) + + if existing: + # Reopen if closed + if not is_open: + subprocess.run( + [ + "gh", + "issue", + "reopen", + existing, + "--repo", + repo, + ], + capture_output=True, + text=True, + timeout=30, + ) + print(f"Reopened alert issue #{existing}") + + # Add comment + result = subprocess.run( + [ + "gh", + "issue", + "comment", + existing, + "--repo", + repo, + "--body", + body, + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + print(f"Commented on alert issue #{existing}") + else: + print( + f"Warning: failed to comment on issue #{existing} " + f"(rc={result.returncode}): {result.stderr.strip()}" + ) + else: + # Create a new issue + cmd = [ + "gh", + "issue", + "create", + "--repo", + repo, + "--title", + ALERT_ISSUE_TITLE, + "--body", + body, + "--label", + ALERT_LABEL, + ] + for user in ALERT_ASSIGNEES: + cmd += ["--assignee", user] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode == 0: + print(f"Created alert issue: {result.stdout.strip()}") + else: + print( + f"Warning: failed to create alert issue " + f"(rc={result.returncode}): {result.stderr.strip()}" + ) + except FileNotFoundError: + print("Warning: `gh` CLI not found — skipping alert issue creation") + except Exception as e: + print(f"Warning: failed to create/update alert issue: {e}") # --------------------------------------------------------------------------- @@ -649,7 +805,9 @@ def main(): print(f"Loaded {len(history)} historical run(s) from {args.history_dir}") # Generate dashboard - markdown = generate_dashboard(current, history, charts_dir=args.charts_dir) + markdown, alert_reasons = generate_dashboard( + current, history, charts_dir=args.charts_dir + ) # Write output os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) @@ -667,6 +825,12 @@ def main(): else: print("Warning: $GITHUB_STEP_SUMMARY not set, skipping") + # Create GitHub Issue for performance alerts (so assignees get notified) + if alert_reasons: + _create_alert_issue(alert_reasons) + else: + print("No performance alerts — skipping issue creation.") + if __name__ == "__main__": main() diff --git a/test/registered/8-gpu-models/test_deepseek_v32_indexcache.py b/test/registered/8-gpu-models/test_deepseek_v32_indexcache.py new file mode 100644 index 000000000000..4b769ec570a5 --- /dev/null +++ b/test/registered/8-gpu-models/test_deepseek_v32_indexcache.py @@ -0,0 +1,117 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + 
write_github_step_summary, +) + +register_cuda_ci(est_time=360, suite="stage-c-test-8-gpu-h200") + +DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2" + + +class TestDeepseekV32IndexTopkPattern(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEEPSEEK_V32_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--tp", + "8", + "--enable-dp-attention", + "--model-loader-extra-config", + '{"enable_multithread_load": true, "num_threads": 64}', + "--json-model-override-args", + '{"index_topk_pattern": "FFSFSSSFSSFFFSSSFFFSFSSSSSSFFSFFSFFSSFFFFFFSFFFFFSFFSSSSSSFSF"}', + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_a_gsm8k( + self, + ): # Append an "a" to make this test run first (alphabetically) to warm up the server + args = SimpleNamespace( + num_shots=20, + data_path=None, + num_questions=1400, + parallel=1400, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + if is_in_ci(): + write_github_step_summary( + f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n' + ) + self.assertGreater(metrics["accuracy"], 0.935) + + +class TestDeepseekV32IndexFreq(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEEPSEEK_V32_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--tp", + "8", + "--model-loader-extra-config", + '{"enable_multithread_load": true, "num_threads": 64}', + "--json-model-override-args", + '{"index_topk_freq": 4}', + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_a_gsm8k( + self, + ): # Append an "a" to make this test run first (alphabetically) to warm up the server + args = SimpleNamespace( + num_shots=20, + data_path=None, + num_questions=1400, + parallel=1400, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + if is_in_ci(): + write_github_step_summary( + f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n' + ) + self.assertGreater(metrics["accuracy"], 0.935) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/8-gpu-models/test_nvidia_nemotron_3_super_nightly.py b/test/registered/8-gpu-models/test_nvidia_nemotron_3_super_nightly.py index 5ae30c3f62bb..608dbbe6c5ec 100644 --- a/test/registered/8-gpu-models/test_nvidia_nemotron_3_super_nightly.py +++ b/test/registered/8-gpu-models/test_nvidia_nemotron_3_super_nightly.py @@ -42,7 +42,7 @@ ] # Accuracy threshold -GSM8K_BASELINE = 0.96 +GSM8K_BASELINE = 0.935 class TestNvidiaNemotron3SuperNightly(unittest.TestCase): diff --git a/test/registered/8-gpu-models/test_qwen3_235b.py b/test/registered/8-gpu-models/test_qwen3_235b.py index f72c51bf4f90..70420bbed64a 100644 --- a/test/registered/8-gpu-models/test_qwen3_235b.py +++ b/test/registered/8-gpu-models/test_qwen3_235b.py @@ -4,7 +4,7 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests 
-from sglang.test.test_utils import ModelLaunchSettings
+from sglang.test.test_utils import ModelLaunchSettings, is_blackwell_system

 # Runs on both H200 and B200 via nightly-8-gpu-common suite
 register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True)
@@ -70,6 +70,7 @@ def test_qwen3_235b_fp8_all_variants(self):
             ),
         )

+    @unittest.skipIf(is_blackwell_system(), "Requires H200 system")
     def test_qwen3_235b_fp8_cp(self):
         """Run performance and accuracy for Qwen3-235B-FP8 with context parallelism."""

diff --git a/test/registered/distributed/test_parallelism_context_integration.py b/test/registered/distributed/test_parallelism_context_integration.py
new file mode 100644
index 000000000000..fad3a10a004b
--- /dev/null
+++ b/test/registered/distributed/test_parallelism_context_integration.py
@@ -0,0 +1,271 @@
+"""
+Integration tests for ParallelismContext with real sglang servers.
+
+Tests that ParallelismContext can instantiate models with correct tensor parallel
+sharding by comparing parameter names and sizes against a running sglang server.
+
+Run with:
+    pytest test/registered/distributed/test_parallelism_context_integration.py -v
+
+Full test suite (non-CI):
+    - TP=2 small model (Qwen2.5-1.5B-Instruct)
+    - EP=2 small MoE model (DeepSeek-Coder-V2-Lite-Instruct)
+    - MLA model with hybrid dp attention (DeepSeek-Coder-V2-Lite-Instruct)
+
+CI test (reduced):
+    - TP=2 small model only
+"""
+
+import dataclasses
+import gc
+from typing import Dict, List, Tuple
+
+import pytest
+import requests
+import torch
+
+from sglang.srt.distributed.parallel_state import RankParallelismConfig
+from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    is_in_ci,
+    popen_launch_server,
+)
+from sglang.utils import terminate_process
+
+
+def get_transfer_engine_info(url: str, rank: int) -> Dict:
+    """Get transfer engine info (parameter names and sizes) for a rank."""
+    response = requests.get(
+        f"{url}/remote_instance_transfer_engine_info",
+        params={"rank": rank},
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def get_parallelism_config(url: str, rank: int) -> Dict:
+    """Get parallelism config for a rank."""
+    response = requests.get(f"{url}/parallelism_config", params={"rank": rank})
+    response.raise_for_status()
+    return response.json()
+
+
+def get_server_info(url: str) -> Dict:
+    """Get server info."""
+    response = requests.get(f"{url}/server_info")
+    response.raise_for_status()
+    return response.json()
+
+
+def verify_model_params_match_for_rank(
+    url: str,
+    rank: int,
+    server_info: Dict,
+    test_gpu_id: int,
+):
+    """Verify model parameters match for a specific rank by recreating a model shard."""
+    transfer_info = get_transfer_engine_info(url, rank)
+    server_weights_info = transfer_info["remote_instance_transfer_engine_info"][1]
+
+    # Get parallelism config from running server
+    parallelism_config_data = get_parallelism_config(url, rank)
+    parallelism_config = RankParallelismConfig.from_dict(parallelism_config_data)
+    # Get server args from server info
+    from sglang.srt.server_args import ServerArgs
+
+    valid_fields = {f.name for f in dataclasses.fields(ServerArgs)}
+    filtered_info = {k: v for k, v in server_info.items() if k in valid_fields}
+    filtered_info.pop("model_config", None)
+    server_args = ServerArgs(**filtered_info)
+
+    from sglang.srt import server_args as server_args_module
+    from sglang.srt.distributed.parallel_state import ParallelismContext
+ + original_global_server_args = server_args_module._global_server_args + + try: + # In a Mock ParallelismContext, instantiate the model for this rank. + # Use a separate GPU (test_gpu_id) to avoid memory conflicts with the running server. + server_args_module._global_server_args = server_args + with ParallelismContext(parallelism_config): + from sglang.srt.configs.device_config import DeviceConfig + from sglang.srt.configs.load_config import LoadConfig + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.model_loader import get_model + + model_config = ModelConfig.from_server_args(server_args) + load_config = LoadConfig(load_format="dummy") + device_config = DeviceConfig(device="cuda", gpu_id=test_gpu_id) + + torch.cuda.set_device(test_gpu_id) + model = get_model( + model_config=model_config, + load_config=load_config, + device_config=device_config, + ) + model_params = {} + for name, param in model.named_parameters(): + model_params[name] = param.numel() * param.element_size() + + # Verify all server parameters exist in model with same size + mismatches = [] + missing = [] + for param_name, (ptr, numel, elem_size) in server_weights_info.items(): + expected_size = numel * elem_size + if param_name not in model_params: + missing.append(param_name) + elif model_params[param_name] != expected_size: + mismatches.append( + f"{param_name}: model={model_params[param_name]}, server={expected_size}" + ) + + assert not missing, f"Rank {rank}: Missing parameters: {missing}" + assert not mismatches, f"Rank {rank}: Size mismatches: {mismatches}" + del model + torch.cuda.empty_cache() + + finally: + server_args_module._global_server_args = original_global_server_args + + +TEST_CONFIGS: List[Tuple[str, str, int, List[str], int]] = [ + # Basic TP=2 test (CI only) + ( + "tp2_small", + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, + 2, + [], + 2, + ), + # EP=2: MoE experts split across 2 groups, moe_tp=1 per group + ( + "mla_ep2", + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + 2, + ["--ep-size", "2"], + 2, + ), + ( + "mla_dp2_tp4", + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + 4, + ["--enable-dp-attention", "--dp", "2"], + 4, + ), + ( + "mla_dp2_ep2_tp4", + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + 4, + ["--enable-dp-attention", "--dp", "2", "--ep-size", "2"], + 4, + ), + ( + "mla_dp2_ep4_tp4", + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + 4, + ["--enable-dp-attention", "--dp", "2", "--ep-size", "4"], + 4, + ), + ( + "mla_dp4_ep2_tp4", + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + 4, + ["--enable-dp-attention", "--dp", "4", "--ep-size", "2"], + 4, + ), +] + + +def get_test_configs(): + if is_in_ci(): + return [TEST_CONFIGS[0]] + else: + return TEST_CONFIGS + + +def _get_test_params(): + """Generate pytest parameters based on test configs.""" + configs = get_test_configs() + params = [] + ids = [] + for ( + test_id, + model_name, + tp_size, + extra_args, + min_gpus, + ) in configs: + params.append( + pytest.param( + (model_name, tp_size, extra_args, min_gpus), + id=test_id, + ) + ) + return params + + +class TestParallelismContextIntegration: + """ + Test that ParallelismContext can instantiate models with the same + parameter names and sizes as the sglang server engine. + """ + + @pytest.mark.parametrize("config", _get_test_params()) + def test_model_instantiation_matches_server(self, config): + """ + Test that a model instantiated with ParallelismContext has the same + parameter names and sizes as the model in the sglang server. + + This test: + 1. Starts a server with specified parallelism config + 2. 
Gets transfer_engine_info for all ranks (contains param names and sizes) + 3. Gets parallelism_config and server_info + 4. Uses ParallelismContext to instantiate a model for each rank + 5. Compares the parameter names and sizes + """ + model_name, tp_size, extra_args, min_gpus = config + url = DEFAULT_URL_FOR_TEST + + # Need min_gpus for server + 1 extra GPU for test model instantiation + required_gpus = min_gpus + 1 + if torch.cuda.device_count() < required_gpus: + pytest.skip( + f"Need at least {required_gpus} GPUs (server={min_gpus} + test=1), have {torch.cuda.device_count()}" + ) + test_gpu_id = min_gpus # e.g., if server uses 0-1, test uses 2 + + # Build server args + other_args = [ + "--tp-size", + str(tp_size), + "--remote-instance-weight-loader-start-seed-via-transfer-engine", + "--trust-remote-code", + ] + other_args.extend(extra_args) + + process = None + try: + process = popen_launch_server( + model_name, + url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + server_info = get_server_info(url) + + for rank in range(tp_size): + verify_model_params_match_for_rank(url, rank, server_info, test_gpu_id) + + finally: + if process is not None: + terminate_process(process) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/registered/lora/test_lora_qwen3.py b/test/registered/lora/test_lora_qwen3.py index 115597fcfe3f..f88babe3facf 100644 --- a/test/registered/lora/test_lora_qwen3.py +++ b/test/registered/lora/test_lora_qwen3.py @@ -15,7 +15,7 @@ import multiprocessing as mp import unittest -from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.ci.ci_register import register_amd_ci from sglang.test.lora_utils import ( LORA_MODELS_QWEN3, run_lora_multiple_batch_on_model_cases, @@ -27,7 +27,6 @@ suite="stage-b-test-1-gpu-small-amd", disabled="see https://github.com/sgl-project/sglang/issues/13107", ) -register_cuda_ci(est_time=97, suite="nightly-1-gpu", nightly=True) class TestLoRAQwen3(CustomTestCase): diff --git a/test/registered/openai_server/basic/test_protocol.py b/test/registered/openai_server/basic/test_protocol.py index bdbdff6a1cea..e1bad744e4b5 100644 --- a/test/registered/openai_server/basic/test_protocol.py +++ b/test/registered/openai_server/basic/test_protocol.py @@ -337,6 +337,74 @@ def test_hidden_states_included_when_not_none(self): self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3]) +class TestCompletionTokenIds(unittest.TestCase): + """Cover the token-in-token-out (TITO) response-side additions: + + * `return_completion_token_ids` on `ChatCompletionRequest` (flag default False). + * `completion_token_ids` on `ChatCompletionResponseChoice` (Optional[List[int]]). + * `_serialize` drops the field when None, keeps it when populated. 
+ """ + + def test_request_accepts_return_completion_token_ids_flag(self): + req = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + return_completion_token_ids=True, + ) + self.assertTrue(req.return_completion_token_ids) + + def test_request_flag_defaults_false(self): + req = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + self.assertFalse(req.return_completion_token_ids) + + def test_response_choice_accepts_completion_token_ids(self): + choice = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content="ok"), + finish_reason="stop", + completion_token_ids=[10, 20, 30], + ) + self.assertEqual(choice.completion_token_ids, [10, 20, 30]) + + def test_completion_token_ids_dropped_when_none(self): + """_serialize pops completion_token_ids when None, mirroring the + existing behavior for hidden_states and prompt_token_ids.""" + choice = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content="ok"), + finish_reason="stop", + completion_token_ids=None, + ) + response = ChatCompletionResponse( + id="test-id", + model="test-model", + choices=[choice], + usage=UsageInfo(prompt_tokens=3, completion_tokens=1, total_tokens=4), + ) + data = response.model_dump() + self.assertNotIn("completion_token_ids", data["choices"][0]) + + def test_completion_token_ids_kept_when_set(self): + choice = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content="ok"), + finish_reason="stop", + completion_token_ids=[10, 20, 30], + ) + response = ChatCompletionResponse( + id="test-id", + model="test-model", + choices=[choice], + usage=UsageInfo(prompt_tokens=3, completion_tokens=3, total_tokens=6), + ) + data = response.model_dump() + self.assertIn("completion_token_ids", data["choices"][0]) + self.assertEqual(data["choices"][0]["completion_token_ids"], [10, 20, 30]) + + class TestValidationEdgeCases(unittest.TestCase): """Test edge cases and validation scenarios""" diff --git a/test/registered/perf/test_dpsk_r1_fp4_4gpu_perf.py b/test/registered/perf/test_dpsk_v3_fp4_4gpu_perf.py similarity index 83% rename from test/registered/perf/test_dpsk_r1_fp4_4gpu_perf.py rename to test/registered/perf/test_dpsk_v3_fp4_4gpu_perf.py index b03c34337d26..e2e9ddd5eee5 100644 --- a/test/registered/perf/test_dpsk_r1_fp4_4gpu_perf.py +++ b/test/registered/perf/test_dpsk_v3_fp4_4gpu_perf.py @@ -9,11 +9,11 @@ # Runs on B200 via nightly-4-gpu-b200 suite register_cuda_ci(est_time=2000, suite="nightly-4-gpu-b200", nightly=True) -DEEPSEEK_R1_FP4_MODEL_PATH = "nvidia/DeepSeek-R1-0528-NVFP4-v2" +FULL_DEEPSEEK_V3_FP4_MODEL_PATH = "nvidia/DeepSeek-V3-0324-FP4" class TestDeepseekR1FP4Unified(unittest.TestCase): - """Unified test class for DeepSeek-R1-0528-NVFP4-v2 performance and accuracy. + """Unified test class for DeepSeek-V3-0324-FP4 performance and accuracy. 
Two variants: - basic: Standard TP=4 @@ -44,28 +44,29 @@ def test_deepseek_r1_fp4_all_variants(self): variants = [ # Variant: "basic" - Standard TP=4 ModelLaunchSettings( - DEEPSEEK_R1_FP4_MODEL_PATH, + FULL_DEEPSEEK_V3_FP4_MODEL_PATH, tp_size=4, extra_args=base_args, variant="TP4", ), # Variant: "mtp" - TP=4 + EAGLE speculative decoding ModelLaunchSettings( - DEEPSEEK_R1_FP4_MODEL_PATH, + FULL_DEEPSEEK_V3_FP4_MODEL_PATH, tp_size=4, extra_args=base_args + mtp_args, variant="TP4+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, ), ] run_combined_tests( models=variants, - test_name="DeepSeek-R1-0528-NVFP4-v2 Unified", + test_name="DeepSeek-V3-0324-FP4 Unified", accuracy_params=AccuracyTestParams( dataset="gsm8k", baseline_accuracy=0.935 ), performance_params=PerformanceTestParams( - profile_dir="performance_profiles_deepseek_r1_fp4", + profile_dir="performance_profiles_deepseek_v3_fp4", ), ) diff --git a/test/srt/models/test_params_mapping.py b/test/srt/models/test_params_mapping.py new file mode 100644 index 000000000000..e064266cab6a --- /dev/null +++ b/test/srt/models/test_params_mapping.py @@ -0,0 +1,292 @@ +"""Unit tests for ParameterMapper.""" + +from types import SimpleNamespace + +import pytest + +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.model_loader.parameter_mapper import ParameterMapper + +_DEEPSEEK_N_ROUTED = 4 +_DEEPSEEK_N_LOCAL = _DEEPSEEK_N_ROUTED + 1 # +1 fused shared expert +_QWEN3MOE_N = 4 +_GLM4LITE_N_ROUTED = 4 +_GLM4LITE_N_LOCAL = _GLM4LITE_N_ROUTED + 1 # +1 fused shared expert + + +def _make_model(**kwargs): + """Create a stub model object for ParameterMapper.from_model().""" + return SimpleNamespace(**kwargs) + + +def _deepseek_mutate(name): + if "mlp.shared_experts" in name: + return name.replace("mlp.shared_experts", f"mlp.experts.{_DEEPSEEK_N_ROUTED}") + return name + + +def _deepseek_scale_remap(name): + for s in ["k_scale", "v_scale"]: + if s in name: + return name.replace(f"{s[0]}_proj", "attn_mqa") + return name + + +def _glm4lite_mutate(name): + if "mlp.shared_experts" in name: + return name.replace("mlp.shared_experts", f"mlp.experts.{_GLM4LITE_N_ROUTED}") + return name + + +_LLAMA_SCALE_PATTERNS = [ + (".activation_scale", ".activation_scale", ".input_scale"), + (".weight_scale_inv", ".weight_scale_inv", ".weight_scale"), +] + + +def _llama_scale_remap(name): + for suffix, pattern, replacement in _LLAMA_SCALE_PATTERNS: + if name.endswith(suffix) and pattern in name: + return name.replace(pattern, replacement) + return name + + +@pytest.fixture +def qwen_mapper(): + """Qwen2/Qwen3 (dense): QKV fusion, gate/up fusion, no experts.""" + return ParameterMapper.from_model( + _make_model( + stacked_params_mapping=[ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ], + ) + ) + + +@pytest.fixture +def llama_mapper(): + """Llama/GLM4 (dense): dot-prefixed stacked params, custom scale remap.""" + return ParameterMapper.from_model( + _make_model( + stacked_params_mapping=[ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ], + custom_scale_remap=_llama_scale_remap, + ) + ) + + +@pytest.fixture +def qwen3moe_mapper(): + """Qwen3-MoE: QKV fusion + experts, no shared expert fusion.""" + return ParameterMapper.from_model( + _make_model( + stacked_params_mapping=[ + ("qkv_proj", "q_proj", "q"), + 
("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ], + expert_params_mapping=FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=_QWEN3MOE_N, + ), + ) + ) + + +@pytest.fixture +def deepseek_mapper(): + """DeepSeek V2/V3: MLA A-proj fusion, shared expert fusion, custom scale remap.""" + return ParameterMapper.from_model( + _make_model( + stacked_params_mapping=[ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj_with_mqa", "q_a_proj", 0), + ("fused_qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", 1), + ], + expert_params_mapping=FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=_DEEPSEEK_N_LOCAL, + ), + mutate_weight_preload=_deepseek_mutate, + custom_scale_remap=_deepseek_scale_remap, + ) + ) + + +@pytest.fixture +def glm4lite_mapper(): + """GLM4-MoE-Lite (GLM-4.7): QKV fusion, MLA A-proj fusion, shared expert fusion, custom scale remap.""" + return ParameterMapper.from_model( + _make_model( + stacked_params_mapping=[ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj_with_mqa", "q_a_proj", 0), + ("fused_qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", 1), + ], + expert_params_mapping=FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=_GLM4LITE_N_LOCAL, + ), + mutate_weight_preload=_glm4lite_mutate, + custom_scale_remap=_deepseek_scale_remap, + ) + ) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def to_expect(name, shard=None, n=1, expert=None, n_exp=None): + """Shorthand for expected MappingResult fields.""" + return (name, shard, n, expert, n_exp) + + +def _assert(mapper, ckpt, expected): + r = mapper.map(ckpt) + name, shard, n, expert, n_exp = expected + assert ( + r.sglang_name, + r.shard_id, + r.num_shards, + r.expert_id, + r.num_local_experts, + ) == (name, shard, n, expert, n_exp), f"map({ckpt!r}) = {r}" + + +# ── Tests ──────────────────────────────────────────────────────────────────── + +# fmt: off +_QWEN_CASES = [ + # QKV fusion (Qwen2, Qwen3, GLM4-MoE) + ("layers.0.attn.q_proj.weight", to_expect("layers.0.attn.qkv_proj.weight", "q", 3)), + ("layers.0.attn.k_proj.weight", to_expect("layers.0.attn.qkv_proj.weight", "k", 3)), + ("layers.0.attn.v_proj.weight", to_expect("layers.0.attn.qkv_proj.weight", "v", 3)), + # Gate/Up fusion + ("layers.0.mlp.gate_proj.weight", to_expect("layers.0.mlp.gate_up_proj.weight", 0, 2)), + ("layers.0.mlp.up_proj.weight", to_expect("layers.0.mlp.gate_up_proj.weight", 1, 2)), + # Pass-through + ("layers.0.mlp.down_proj.weight", to_expect("layers.0.mlp.down_proj.weight")), + ("embed_tokens.weight", to_expect("embed_tokens.weight")), + # Standard scale remap (no custom_scale_remap) + ("model.layers.0.self_attn.k_scale", to_expect("model.layers.0.self_attn.attn.k_scale")), + ("model.layers.0.self_attn.v_scale", to_expect("model.layers.0.self_attn.attn.v_scale")), +] + +_LLAMA_CASES = [ + # Dot-prefixed QKV fusion (Llama, GLM4) + ("model.layers.0.self_attn.q_proj.weight", to_expect("model.layers.0.self_attn.qkv_proj.weight", "q", 3)), + 
("model.layers.0.self_attn.k_proj.weight", to_expect("model.layers.0.self_attn.qkv_proj.weight", "k", 3)), + # Dot-prefixed gate/up + ("model.layers.0.mlp.gate_proj.weight", to_expect("model.layers.0.mlp.gate_up_proj.weight", 0, 2)), + # Llama-specific scale remap + stacked (scales follow their weights) + ("model.layers.0.mlp.gate_proj.activation_scale", to_expect("model.layers.0.mlp.gate_up_proj.input_scale", 0, 2)), + ("model.layers.0.mlp.gate_proj.weight_scale_inv", to_expect("model.layers.0.mlp.gate_up_proj.weight_scale", 0, 2)), + # Pass-through + ("model.layers.0.mlp.down_proj.weight", to_expect("model.layers.0.mlp.down_proj.weight")), +] + +_QWEN3MOE_CASES = [ + # QKV fusion + ("model.layers.0.self_attn.q_proj.weight", to_expect("model.layers.0.self_attn.qkv_proj.weight", "q", 3)), + # Expert mapping (no shared expert fusion) + ("model.layers.0.mlp.experts.0.gate_proj.weight", to_expect("model.layers.0.mlp.experts.w13_weight", "w1", 2, 0, _QWEN3MOE_N)), + ("model.layers.0.mlp.experts.3.down_proj.weight", to_expect("model.layers.0.mlp.experts.w2_weight", "w2", 1, 3, _QWEN3MOE_N)), + # shared_experts falls through to stacked mapping (no mutate_weight_preload) + ("model.layers.0.mlp.shared_experts.gate_proj.weight", to_expect("model.layers.0.mlp.shared_experts.gate_up_proj.weight", 0, 2)), +] + +_DEEPSEEK_CASES = [ + # MLA A-proj fusion + ("model.layers.0.self_attn.q_a_proj.weight", to_expect("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight", 0, 2)), + ("model.layers.0.self_attn.kv_a_proj_with_mqa.weight", to_expect("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight", 1, 2)), + # Shared expert fusion via mutate_weight_preload + ("model.layers.0.mlp.shared_experts.gate_proj.weight", to_expect("model.layers.0.mlp.experts.w13_weight", "w1", 2, _DEEPSEEK_N_ROUTED, _DEEPSEEK_N_LOCAL)), + ("model.layers.0.mlp.shared_experts.down_proj.weight", to_expect("model.layers.0.mlp.experts.w2_weight", "w2", 1, _DEEPSEEK_N_ROUTED, _DEEPSEEK_N_LOCAL)), + # Custom scale remap (k_proj/v_proj -> attn_mqa, NOT double-remapped) + ("model.layers.0.self_attn.k_proj.k_scale", to_expect("model.layers.0.self_attn.attn_mqa.k_scale")), + ("model.layers.0.self_attn.v_proj.v_scale", to_expect("model.layers.0.self_attn.attn_mqa.v_scale")), + # kv_b_proj pass-through (decomposed in post_load_weights) + ("model.layers.0.self_attn.kv_b_proj.weight", to_expect("model.layers.0.self_attn.kv_b_proj.weight")), +] + +_GLM4LITE_CASES = [ + # QKV fusion (GLM-4.7 uses standard QKV unlike DeepSeek which uses MLA-only) + ("model.layers.0.self_attn.q_proj.weight", to_expect("model.layers.0.self_attn.qkv_proj.weight", "q", 3)), + ("model.layers.0.self_attn.k_proj.weight", to_expect("model.layers.0.self_attn.qkv_proj.weight", "k", 3)), + ("model.layers.0.self_attn.v_proj.weight", to_expect("model.layers.0.self_attn.qkv_proj.weight", "v", 3)), + # MLA A-proj fusion (GLM-4.7 also uses MLA with q_lora_rank) + ("model.layers.0.self_attn.q_a_proj.weight", to_expect("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight", 0, 2)), + ("model.layers.0.self_attn.kv_a_proj_with_mqa.weight", to_expect("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight", 1, 2)), + # Gate/Up fusion (non-expert layers) + ("model.layers.0.mlp.gate_proj.weight", to_expect("model.layers.0.mlp.gate_up_proj.weight", 0, 2)), + ("model.layers.0.mlp.up_proj.weight", to_expect("model.layers.0.mlp.gate_up_proj.weight", 1, 2)), + # Expert mapping + ("model.layers.0.mlp.experts.0.gate_proj.weight", to_expect("model.layers.0.mlp.experts.w13_weight", 
"w1", 2, 0, _GLM4LITE_N_LOCAL)), + ("model.layers.0.mlp.experts.0.up_proj.weight", to_expect("model.layers.0.mlp.experts.w13_weight", "w3", 2, 0, _GLM4LITE_N_LOCAL)), + ("model.layers.0.mlp.experts.3.down_proj.weight", to_expect("model.layers.0.mlp.experts.w2_weight", "w2", 1, 3, _GLM4LITE_N_LOCAL)), + # Shared expert fusion via mutate_weight_preload + ("model.layers.0.mlp.shared_experts.gate_proj.weight", to_expect("model.layers.0.mlp.experts.w13_weight", "w1", 2, _GLM4LITE_N_ROUTED, _GLM4LITE_N_LOCAL)), + ("model.layers.0.mlp.shared_experts.down_proj.weight", to_expect("model.layers.0.mlp.experts.w2_weight", "w2", 1, _GLM4LITE_N_ROUTED, _GLM4LITE_N_LOCAL)), + # Custom scale remap (same as DeepSeek: k_proj/v_proj -> attn_mqa) + ("model.layers.0.self_attn.k_proj.k_scale", to_expect("model.layers.0.self_attn.attn_mqa.k_scale")), + ("model.layers.0.self_attn.v_proj.v_scale", to_expect("model.layers.0.self_attn.attn_mqa.v_scale")), + # Pass-through + ("model.layers.0.mlp.down_proj.weight", to_expect("model.layers.0.mlp.down_proj.weight")), + ("model.layers.0.self_attn.kv_b_proj.weight", to_expect("model.layers.0.self_attn.kv_b_proj.weight")), +] +# fmt: on + + +@pytest.mark.parametrize("ckpt,expected", _QWEN_CASES, ids=[c[0] for c in _QWEN_CASES]) +def test_qwen(qwen_mapper, ckpt, expected): + _assert(qwen_mapper, ckpt, expected) + + +@pytest.mark.parametrize( + "ckpt,expected", _LLAMA_CASES, ids=[c[0] for c in _LLAMA_CASES] +) +def test_llama(llama_mapper, ckpt, expected): + _assert(llama_mapper, ckpt, expected) + + +@pytest.mark.parametrize( + "ckpt,expected", _QWEN3MOE_CASES, ids=[c[0] for c in _QWEN3MOE_CASES] +) +def test_qwen3moe(qwen3moe_mapper, ckpt, expected): + _assert(qwen3moe_mapper, ckpt, expected) + + +@pytest.mark.parametrize( + "ckpt,expected", _DEEPSEEK_CASES, ids=[c[0] for c in _DEEPSEEK_CASES] +) +def test_deepseek(deepseek_mapper, ckpt, expected): + _assert(deepseek_mapper, ckpt, expected) + + +@pytest.mark.parametrize( + "ckpt,expected", _GLM4LITE_CASES, ids=[c[0] for c in _GLM4LITE_CASES] +) +def test_glm4lite(glm4lite_mapper, ckpt, expected): + _assert(glm4lite_mapper, ckpt, expected)