From e41ac5829706baff5ee6d3839b95f083c36c19f5 Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Tue, 3 Jun 2025 14:48:03 +0200 Subject: [PATCH 1/6] minor: fix structured output generation --- scratchpad/constrained/base_backend.py | 84 ++++--- .../constrained/triton_ops/bitmask_ops.py | 141 +++++++++++ scratchpad/constrained/xgrammar_backend.py | 222 +++++++++++++----- scratchpad/model_executor/forward_info.py | 2 +- scratchpad/scheduler/schedule_batch.py | 1 + scratchpad/scheduler/scheduler.py | 90 +++++-- scratchpad/utils/utils.py | 7 + tools/benchmark/bench_perf.py | 5 +- tools/benchmark/common.py | 100 +++++++- 9 files changed, 529 insertions(+), 123 deletions(-) create mode 100644 scratchpad/constrained/triton_ops/bitmask_ops.py diff --git a/scratchpad/constrained/base_backend.py b/scratchpad/constrained/base_backend.py index c7d6874..8ac1257 100644 --- a/scratchpad/constrained/base_backend.py +++ b/scratchpad/constrained/base_backend.py @@ -16,46 +16,66 @@ class BaseGrammarObject: pass +INVALID_GRAMMAR_OBJ: BaseGrammarObject = BaseGrammarObject() + + class BaseGrammarBackend: def __init__(self): self.executor = ThreadPoolExecutor() - self.cache = {} - self.cache_lock = Lock() - - def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject: - with self.cache_lock: - if key in self.cache: - cache_hit = True - entry = self.cache[key] - else: - cache_hit = False - entry = CacheEntry(None, Event()) - self.cache[key] = entry - - if cache_hit: - entry.event.wait() + self.cache: Dict[Tuple[str, str], CacheEntry] = {} + + def _not_supported(self, key_type: str, key_string: str) -> None: + logger.warning(f"Skip unsupported {key_type=}, {key_string=}") + + def dispatch_fallback( + self, key_type: str, key_string: str + ) -> Optional[BaseGrammarObject]: + """ + This function should not be reached in any case. + """ + raise ValueError(f"Invalid key_type: {key_type}={key_string}") + + def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("json", key_string) + + def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("regex", key_string) + + def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("ebnf", key_string) + + def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("structural_tag", key_string) + + def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: + key_type, key_string = key + if key_type == "json": + return self.dispatch_json(key_string) + elif key_type == "regex": + return self.dispatch_regex(key_string) + elif key_type == "ebnf": + return self.dispatch_ebnf(key_string) + elif key_type == "structural_tag": + return self.dispatch_structural_tag(key_string) + elif key_type == "structural_pattern": + return self.dispatch_structural_pattern(key_string) else: - entry.value = self.init_value_impl(key) - entry.event.set() - return entry.value.copy() if entry.value else None - - def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject: - raise NotImplementedError() + return self.dispatch_fallback(key_type, key_string) - def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: - with self.cache_lock: - entry = self.cache.get(key) - if not entry or not entry.event.is_set(): - return None - val = self.cache[key].value - return val.copy() if val else None + def get_cached_or_future_value( + self, key: Tuple[str, str] + ) -> Optional[BaseGrammarObject]: + value = self.cache.get(key) + if value: + return value.copy(), True + value = self.executor.submit(self._init_value_dispatch, key) + return value, False - def get_future_value(self, key: Tuple[str, str]) -> Future: - return self.executor.submit(self.init_value, key) + def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject): + self.cache[key] = value def reset(self): - with self.cache_lock: - self.cache.clear() + self.cache.clear() def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size): diff --git a/scratchpad/constrained/triton_ops/bitmask_ops.py b/scratchpad/constrained/triton_ops/bitmask_ops.py new file mode 100644 index 0000000..2ad8c89 --- /dev/null +++ b/scratchpad/constrained/triton_ops/bitmask_ops.py @@ -0,0 +1,141 @@ +# Adapt from +# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py + +from typing import List, Optional, Union + +import torch +import triton +import triton.language as tl + +from scratchpad.utils import get_device_core_count + + +@triton.jit +def apply_token_bitmask_inplace_kernel( + logits_ptr, + bitmask_ptr, + indices_ptr, + num_rows, + vocab_size, + logits_strides, + bitmask_strides, + NUM_SMS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor, + where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask, + the masked logits will be set to -inf. + + Parameters + ---------- + logits_ptr : tl.tensor + Pointer to the logits tensor to apply the bitmask to. + + bitmask_ptr : tl.tensor + Pointer to the bitmask tensor to apply. + + indices_ptr : Optional[tl.tensor] + Optional pointer to indices tensor specifying which rows to apply the mask to. + + num_rows : int + Number of rows to process. If indices_ptr is provided, this is the number of unique indices. + + vocab_size : int + Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the + same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary. + + logits_strides : int + Stride between rows in the logits tensor. + + bitmask_strides : int + Stride between rows in the bitmask tensor. + + NUM_SMS : int + Number of streaming multiprocessors to use. + + BLOCK_SIZE : int + Size of processing blocks. + """ + + pid = tl.program_id(0) + num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE) + for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS): + row_id = work_id // num_blocks + block_offset = (work_id % num_blocks) * BLOCK_SIZE + batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id) + offsets = block_offset + tl.arange(0, BLOCK_SIZE) + bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32) + vocab_mask = offsets < vocab_size + packed_bitmask_mask = bitmask_offsets < bitmask_strides + packed_bitmask = tl.load( + bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, + packed_bitmask_mask, + ) + bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0 + bitmask = bitmask.reshape(BLOCK_SIZE) + + tl.store( + logits_ptr + batch_id * logits_strides + offsets, + -float("inf"), + vocab_mask & bitmask, + ) + + +def apply_token_bitmask_inplace_triton( + logits: torch.Tensor, + bitmask: torch.Tensor, + indices: Optional[Union[List[int], torch.Tensor]] = None, +): + NUM_SMS = get_device_core_count() + BLOCK_SIZE = 4096 + BITS_PER_BLOCK = 32 + + # Check input dtype + assert bitmask.dtype == torch.int32, "bitmask must be of type int32" + + # Check input tensor shapes. + logits_shape = logits.shape + bitmask_shape = bitmask.shape + if logits.ndim == 1: + logits_shape = (1, logits_shape[0]) + if bitmask.ndim == 1: + bitmask_shape = (1, bitmask_shape[0]) + + required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK + assert required_bitmask_width >= bitmask_shape[1], ( + f"Bitmask width too large: allow at most {required_bitmask_width} int32s for " + f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}" + ) + + vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK) + + num_rows = None + if isinstance(indices, list) or isinstance(indices, torch.Tensor): + indices = torch.tensor(indices, dtype=torch.int32, device=logits.device) + num_rows = indices.shape[0] + else: + assert ( + logits_shape[0] == bitmask_shape[0] + ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}" + num_rows = logits_shape[0] + + if NUM_SMS > 0: + grid = (NUM_SMS,) + else: + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + grid = (num_rows * num_blocks,) + NUM_SMS = triton.next_power_of_2(grid[0]) + + apply_token_bitmask_inplace_kernel[grid]( + logits, + bitmask, + indices, + num_rows, + vocab_size, + logits_shape[1], + bitmask_shape[1], + NUM_SMS, + BLOCK_SIZE, + num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()), + num_stages=3, + ) diff --git a/scratchpad/constrained/xgrammar_backend.py b/scratchpad/constrained/xgrammar_backend.py index b4ddcc7..4317999 100644 --- a/scratchpad/constrained/xgrammar_backend.py +++ b/scratchpad/constrained/xgrammar_backend.py @@ -1,17 +1,42 @@ -from typing import List, Tuple +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constrained decoding with xgrammar backend.""" + +import json +import logging +from typing import List, Optional, Tuple, Union + import torch from xgrammar import ( CompiledGrammar, - Grammar, GrammarCompiler, GrammarMatcher, + StructuralTagItem, TokenizerInfo, allocate_token_bitmask, - apply_token_bitmask_inplace, ) -from .base_backend import BaseGrammarObject, BaseGrammarBackend -from scratchpad.utils import logger +from .base_backend import ( + INVALID_GRAMMAR_OBJ, + BaseGrammarBackend, + BaseGrammarObject, +) +from .triton_ops.bitmask_ops import ( + apply_token_bitmask_inplace_triton, +) + +logger = logging.getLogger(__name__) MAX_ROLLBACK_TOKENS = 200 @@ -19,17 +44,76 @@ class XGrammarGrammar(BaseGrammarObject): def __init__( - self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar + self, + matcher: GrammarMatcher, + vocab_size: int, + ctx: CompiledGrammar, + override_stop_tokens: Optional[Union[List[int], int]], + key_string: Optional[str] = None, # TODO (sk): for debugging, remove later ) -> None: self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx + self.override_stop_tokens = override_stop_tokens self.finished = False + self.accepted_tokens = [] + self.key_string = key_string def accept_token(self, token: int): - assert self.matcher.accept_token(token) + if not self.is_terminated(): + accepted = self.matcher.accept_token(token) + if not accepted: + # log for debugging + raise ValueError( + f"Tokens not accepted: {token}\n" + f"Accepted tokens: {self.accepted_tokens}\n" + f"Key string: {self.key_string}" + ) + else: + self.accepted_tokens.append(token) + + def rollback(self, k: int): + self.matcher.rollback(k) + self.accepted_tokens = self.accepted_tokens[:-k] + + def is_terminated(self): + return self.matcher.is_terminated() + + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + return allocate_token_bitmask(batch_size, vocab_size) + + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + self.matcher.fill_next_token_bitmask(vocab_mask, idx) + + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask.to(device, non_blocking=True) + + def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + if logits.device.type == "cuda": + apply_token_bitmask_inplace_triton(logits, vocab_mask) + elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu: + self.apply_vocab_mask_cpu(logits, vocab_mask) + else: + raise RuntimeError(f"Unsupported device: {logits.device.type}") + + def copy(self): + matcher = GrammarMatcher( + self.ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + override_stop_tokens=self.override_stop_tokens, + ) + return XGrammarGrammar( + matcher, + self.vocab_size, + self.ctx, + self.override_stop_tokens, + self.key_string, + ) - def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: s = self.matcher.find_jump_forward_string() if s: return [], s @@ -56,25 +140,8 @@ def jump_and_retokenize( for i in range(k, len(new_output_ids)): assert self.matcher.accept_token(new_output_ids[i]) - def allocate_vocab_mask( - self, vocab_size: int, batch_size: int, device - ) -> torch.Tensor: - return allocate_token_bitmask(batch_size, vocab_size) - - def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: - self.matcher.fill_next_token_bitmask(vocab_mask, idx) - - @staticmethod - def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: - return vocab_mask.to(device, non_blocking=True) - - @staticmethod - def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - apply_token_bitmask_inplace(logits, vocab_mask) - - def copy(self): - matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) - return XGrammarGrammar(matcher, self.vocab_size, self.ctx) + def __repr__(self): + return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})" class XGrammarGrammarBackend(BaseGrammarBackend): @@ -85,46 +152,73 @@ def __init__( ): super().__init__() - tokenizer_info = TokenizerInfo.from_huggingface( - tokenizer, vocab_size=vocab_size - ) + if True: + tokenizer_info = TokenizerInfo.from_huggingface( + tokenizer, vocab_size=vocab_size + ) + override_stop_tokens = None + self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) self.vocab_size = vocab_size + self.override_stop_tokens = override_stop_tokens - def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: - - key_type, key_string = key - if key_type == "json": - try: - if key_string == "$$ANY$$": - ctx = self.grammar_compiler.compile_builtin_json_grammar() - else: - ctx = self.grammar_compiler.compile_json_schema(schema=key_string) - except RuntimeError as e: - logger.warning( - f"Skip invalid json_schema: json_schema={key_string}, {e=}" - ) - return None - elif key_type == "ebnf": - try: - ctx = self.grammar_compiler.compile_grammar(key_string) - except RuntimeError as e: - logger.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") - return None - elif key_type == "regex": - try: - ctx = self.grammar_compiler.compile_grammar( - Grammar.from_regex(key_string) - ) - except RuntimeError as e: - logger.warning(f"Skip invalid regex: regex={key_string}, {e=}") - return None - else: - raise ValueError(f"Invalid key_type: {key_type}") + def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar: + matcher = GrammarMatcher( + ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + override_stop_tokens=self.override_stop_tokens, + ) + return XGrammarGrammar( + matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string + ) - matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) - return XGrammarGrammar(matcher, self.vocab_size, ctx) + def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + if key_string == "$$ANY$$": + # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root) + ctx = self.grammar_compiler.compile_builtin_json_grammar() + else: + ctx = self.grammar_compiler.compile_json_schema(schema=key_string) + + except (RuntimeError, json.decoder.JSONDecodeError) as e: + logging.error(f"Hit invalid json_schema: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string) + + def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + ctx = self.grammar_compiler.compile_grammar(key_string) + except RuntimeError as e: + logging.error(f"Hit invalid ebnf: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string) + + def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + ctx = self.grammar_compiler.compile_regex(key_string) + except RuntimeError as e: + logging.error(f"Hit invalid regex: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string) + + def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + structural_tag = json.loads(key_string) + tags = [ + StructuralTagItem( + begin=structure["begin"], + schema=json.dumps(structure["schema"]), + end=structure["end"], + ) + for structure in structural_tag["structures"] + ] + ctx = self.grammar_compiler.compile_structural_tag( + tags, structural_tag["triggers"] + ) + except (RuntimeError, json.decoder.JSONDecodeError) as e: + logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string) def reset(self): - if self.grammar_compiler: - self.grammar_compiler.clear_cache() + self.grammar_compiler.clear_cache() diff --git a/scratchpad/model_executor/forward_info.py b/scratchpad/model_executor/forward_info.py index 9dd1aae..145d682 100644 --- a/scratchpad/model_executor/forward_info.py +++ b/scratchpad/model_executor/forward_info.py @@ -60,7 +60,7 @@ def is_cuda_graph(self): return self.is_decode() or self.is_target_verify() or self.is_idle() def is_dummy_first(self): - return False + return self == 7 def is_decode_or_idle(self): return self.is_decode() or self.is_idle() diff --git a/scratchpad/scheduler/schedule_batch.py b/scratchpad/scheduler/schedule_batch.py index d903387..777b4ec 100644 --- a/scratchpad/scheduler/schedule_batch.py +++ b/scratchpad/scheduler/schedule_batch.py @@ -426,6 +426,7 @@ def __init__( # Constrained decoding self.grammar: Optional[BaseGrammarObject] = None + self.grammar_wait_ct = 0 # The number of cached tokens that were already cached in the KV cache self.cached_tokens = 0 diff --git a/scratchpad/scheduler/scheduler.py b/scratchpad/scheduler/scheduler.py index 993234c..667f0af 100644 --- a/scratchpad/scheduler/scheduler.py +++ b/scratchpad/scheduler/scheduler.py @@ -16,7 +16,10 @@ from typing import List, Optional, TYPE_CHECKING, Union from types import SimpleNamespace from scratchpad.config.model_config import ModelConfig -from scratchpad.constrained.base_backend import create_grammar_backend +from scratchpad.constrained.base_backend import ( + create_grammar_backend, + INVALID_GRAMMAR_OBJ, +) from scratchpad.nn.layers.logits_processor import LogitsProcessorOutput from scratchpad.scheduler.schedule_batch import ( FINISH_ABORT, @@ -74,6 +77,7 @@ # Crash on warning if we are running CI tests crash_on_warning = os.getenv("SP_IS_IN_CI", "false") == "true" TEST_RETRACT = os.getenv("SP_TEST_RETRACT", "false") == "true" +GRAMMAR_TIMEOUT = float(os.environ.get("SP_GRAMMAR_TIMEOUT", 300)) @dataclass @@ -383,9 +387,10 @@ def event_loop_normal(self): batch = self.get_next_batch_to_run() self.cur_batch = batch - if batch: - result = self.run_batch(batch) + result: EmbeddingBatchResult | GenerationBatchResult = self.run_batch( + batch + ) self.process_batch_result(batch, result) else: # When the server is idle, do self-check and re-init some states @@ -408,9 +413,10 @@ def event_loop_overlap(self): if batch: batch.launch_done = threading.Event() - result = self.run_batch(batch) + result: EmbeddingBatchResult | GenerationBatchResult = self.run_batch( + batch + ) self.result_queue.append((batch.copy(), result)) - if self.last_batch is None: # Create a dummy first batch to start the pipeline for overlap schedule. # It is now used for triggering the sampling_info_done event. @@ -631,10 +637,15 @@ def handle_generate_request( elif req.sampling_params.structural_tag: key = ("structural_tag", req.sampling_params.structural_tag) - req.grammar = self.grammar_backend.get_cached_value(key) - if not req.grammar: - req.grammar = self.grammar_backend.get_future_value(key) + value, cache_hit = self.grammar_backend.get_cached_or_future_value(key) + req.grammar = value + if not cache_hit: + req.grammar_key = key add_to_grammar_queue = True + else: + if value is INVALID_GRAMMAR_OBJ: + error_msg = f"invalid grammar object with cache hit for key={key}" + req.set_finish_with_abort(error_msg) if add_to_grammar_queue: self.grammar_queue.append(req) @@ -1038,6 +1049,7 @@ def run_batch( extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, bid=bid, ) + else: # embedding or reward model model_worker_batch = batch.get_model_worker_batch() embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) @@ -1065,11 +1077,12 @@ def process_batch_result( batch.next_batch_sampling_info.update_regex_vocab_mask() self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() + else: + logger.error(f"Unexpected forward mode: {batch.forward_mode}") if self.return_health_check_ct: # Return some signal for the health check. @@ -1612,26 +1625,67 @@ def get_idle_batch(self): def move_ready_grammar_requests(self): """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" + num_ready_reqs = 0 + num_timeout_reqs = 0 for req in self.grammar_queue: try: - req.grammar = req.grammar.result(timeout=0.05) + if req.finished(): # It is aborted by AbortReq + num_ready_reqs += 1 + continue + req.grammar = req.grammar.result(timeout=0.03) + self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) + if req.grammar is INVALID_GRAMMAR_OBJ: + req.set_finish_with_abort( + f"Invalid grammar request: {req.grammar_key=}" + ) num_ready_reqs += 1 except futures._base.TimeoutError: + req.grammar_wait_ct += 1 + # NOTE(lianmin): this timeout is the waiting time of the above line. It is + # not the waiting time from it enters the grammar queue. + if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03: + num_timeout_reqs = 1 break - if self.tp_size > 1: + if self.server_args.enable_dp_attention: + tp_size = self.attn_tp_size + tp_group = self.attn_tp_cpu_group + else: + tp_size = self.tp_size + tp_group = self.tp_cpu_group + + if tp_size > 1: # Sync across TP ranks to make sure they have the same number of ready requests - tensor = torch.tensor(num_ready_reqs, dtype=torch.int32) + tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32) torch.distributed.all_reduce( - tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group + tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group ) - num_ready_reqs_max = tensor.item() - for i in range(num_ready_reqs, num_ready_reqs_max): - self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result() - num_ready_reqs = num_ready_reqs_max + num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist() - self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) + for i in range(num_ready_reqs, num_ready_reqs_max): + req = self.grammar_queue[i] + if req.finished(): # It is aborted by AbortReq + continue + req.grammar = req.grammar.result() + self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) + if req.grammar is INVALID_GRAMMAR_OBJ: + req.set_finish_with_abort( + f"Invalid grammar request: {req.grammar_key=}" + ) + else: + num_ready_reqs_max = num_ready_reqs + num_timeout_reqs_max = num_timeout_reqs + + for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max): + req = self.grammar_queue[i] + req.grammar.cancel() + error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}" + req.set_finish_with_abort(error_msg) + self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ) + num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max + + self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] def flush_cache_wrapped(self, recv_req: FlushCacheReq): diff --git a/scratchpad/utils/utils.py b/scratchpad/utils/utils.py index 2f7dd11..3655ecc 100644 --- a/scratchpad/utils/utils.py +++ b/scratchpad/utils/utils.py @@ -475,3 +475,10 @@ def flatten_nested_list(nested_list): def get_compiler_backend() -> str: return "inductor" + + +def get_device_core_count(device_id: int = 0) -> int: + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return torch.cuda.get_device_properties(device_id).multi_processor_count + + return 0 diff --git a/tools/benchmark/bench_perf.py b/tools/benchmark/bench_perf.py index 52d0974..8e39937 100644 --- a/tools/benchmark/bench_perf.py +++ b/tools/benchmark/bench_perf.py @@ -13,6 +13,7 @@ RequestFuncOutput, RequestFuncInput, calculate_metrics, + async_request_openai_chat_completions, ) from tools.benchmark.report import print_benchmark, write_benchmark @@ -63,7 +64,6 @@ async def run_benchmark( goodput_config_dict: Dict[str, float], max_concurrency: Optional[int] = None, ): - # system_info = await async_request_sp_sysinfo(args.endpoint) pbar = tqdm(total=len(input_requests)) tasks: List[asyncio.Task] = [] semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None @@ -133,7 +133,8 @@ def benchmark(args): ) for req in bench_requests: req.model = args.model - request_func = async_request_openai_completions + # request_func = async_request_openai_completions + request_func = async_request_openai_chat_completions gootput_config_dict = check_goodput_args(args) # check if server is ready server_ready = False diff --git a/tools/benchmark/common.py b/tools/benchmark/common.py index f2d31b8..f35e6de 100644 --- a/tools/benchmark/common.py +++ b/tools/benchmark/common.py @@ -9,10 +9,10 @@ import traceback import warnings from dataclasses import dataclass, field -from typing import List, Optional, Union, Tuple, AsyncGenerator, Dict +from typing import List, Optional, Tuple, AsyncGenerator, Dict from tqdm.asyncio import tqdm from datasets import load_dataset, DatasetDict -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -101,7 +101,7 @@ async def async_request_openai_completions( "include_usage": True, }, } - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + headers = {"Authorization": f"Bearer {os.environ.get('RC_API_KEY')}"} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len generated_text = "" @@ -165,6 +165,95 @@ async def async_request_openai_completions( return output +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + ("chat/completions", "profile") + ), "OpenAI Chat Completions API URL must end with 'chat/completions'." + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + payload = { + "model": request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + async def get_request( input_requests: List[RequestFuncInput], request_rate: float, @@ -205,11 +294,10 @@ def construct_dataset( response = conversations[2 * i + 1]["content"] req = RequestFuncInput( prompt=prompt, - api_url=endpoint + "/v1/completions", + api_url=endpoint + "/v1/chat/completions", prompt_len=len(tokenizer(prompt)["input_ids"]), output_len=len(tokenizer(response)["input_ids"]), model="", - ignore_eos=True, ) requests.append(req) return requests @@ -289,7 +377,7 @@ def calculate_metrics( stacklevel=2, ) - metrics = BenchmarkMetrics( + metrics: BenchmarkMetrics = BenchmarkMetrics( completed=completed, total_input=total_input, total_output=sum(actual_output_lens), From 4ae9dfa093e7bb3555f137f3490aaad06806eb83 Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Wed, 4 Jun 2025 18:14:24 +0200 Subject: [PATCH 2/6] minor fix on tool parser --- scratchpad/server/args.py | 20 +++++++++++++------ .../server/openai_api/function_call_parser.py | 1 + 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/scratchpad/server/args.py b/scratchpad/server/args.py index 6142c2f..1c899d0 100644 --- a/scratchpad/server/args.py +++ b/scratchpad/server/args.py @@ -181,15 +181,23 @@ def translate_auto(self): ) if self.cuda_graph_bs is None: self.cuda_graph_max_bs = 160 + if self.tool_call_parser == "auto": - self.tool_call_parser = ( - "llama3" if "llama" in self.served_model_name.lower() else None - ) + if "llama" in self.served_model_name.lower(): + self.tool_call_parser = "llama3" + elif "qwen3" in self.served_model_name.lower(): + self.tool_call_parser = "qwen3" + else: + self.tool_call_parser = None logger.info(f"Using tool_call_parser: {self.tool_call_parser}") + if self.reasoning_parser == "auto": - self.reasoning_parser = ( - "qwen3" if "qwen3" in self.served_model_name.lower() else None - ) + if "qwen3" in self.served_model_name.lower(): + self.reasoning_parser = "qwen3" + elif "llama" in self.served_model_name.lower(): + self.reasoning_parser = "llama3" + else: + self.reasoning_parser = None logger.info(f"Using reasoning_parser: {self.reasoning_parser}") def update(self, args): diff --git a/scratchpad/server/openai_api/function_call_parser.py b/scratchpad/server/openai_api/function_call_parser.py index 027e664..e7ff4a5 100644 --- a/scratchpad/server/openai_api/function_call_parser.py +++ b/scratchpad/server/openai_api/function_call_parser.py @@ -764,6 +764,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "llama3": Llama32Detector, "qwen25": Qwen25Detector, + "qwen3": Qwen25Detector, "mistral": MistralDetector, "deepseekv3": DeepSeekV3Detector, "pythonic": PythonicDetector, From 26e6a6f311e422f8d7d16edc278ddf568841f315 Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Tue, 10 Jun 2025 00:40:44 +0200 Subject: [PATCH 3/6] fix default parser for llama & logprobs --- scratchpad/managers/detokenizer.py | 6 +- scratchpad/managers/tokenizer.py | 155 ++++++++++---- scratchpad/nn/layers/logits_processor.py | 16 +- scratchpad/scheduler/scheduler.py | 257 ++++++++++++++++------- scratchpad/server/args.py | 18 +- scratchpad/server/openai_api/handler.py | 8 +- scratchpad/utils/utils.py | 27 ++- tests/e2e/test_logprob.py | 18 ++ 8 files changed, 361 insertions(+), 144 deletions(-) create mode 100644 tests/e2e/test_logprob.py diff --git a/scratchpad/managers/detokenizer.py b/scratchpad/managers/detokenizer.py index 21b5630..3b58c89 100644 --- a/scratchpad/managers/detokenizer.py +++ b/scratchpad/managers/detokenizer.py @@ -50,10 +50,10 @@ def __init__( # Init inter-process communication context = zmq.Context(2) self.recv_from_scheduler = get_zmq_socket( - context, zmq.PULL, server_args.detokenizer_ipc_name + context, zmq.PULL, server_args.detokenizer_ipc_name, True ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, server_args.tokenizer_ipc_name + context, zmq.PUSH, server_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: @@ -228,6 +228,6 @@ def run_detokenizer_process( manager = DetokenizerManager(server_args) manager.event_loop() except Exception: - msg = get_exception_traceback() + msg: str = get_exception_traceback() logger.error(msg) kill_parent_process() diff --git a/scratchpad/managers/tokenizer.py b/scratchpad/managers/tokenizer.py index c4cddf0..95db3b6 100644 --- a/scratchpad/managers/tokenizer.py +++ b/scratchpad/managers/tokenizer.py @@ -51,17 +51,35 @@ class ReqState: """Store the state a request.""" - out_list: List + out_list: List[Dict[Any, Any]] finished: bool event: asyncio.Event - obj: Any + obj: Union[GenerateReqInput, EmbeddingReqInput] # For metrics created_time: float - first_token_time: Optional[float] = None + finished_time: float = 0.0 + first_token_time: float = 0.0 + last_time: float = 0.0 + last_completion_tokens: int = 1 # For streaming output last_output_offset: int = 0 + # For incremental state update. + text: str = "" + output_ids: List[int] = dataclasses.field(default_factory=list) + input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) + input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) + output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) + output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) + input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) + input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) + output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) + output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) class TokenizerManager: @@ -77,10 +95,16 @@ def __init__( # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = get_zmq_socket( - context, zmq.PULL, server_args.tokenizer_ipc_name + context, + zmq.PULL, + server_args.tokenizer_ipc_name, + True, ) self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, server_args.scheduler_input_ipc_name + context, + zmq.PUSH, + server_args.scheduler_input_ipc_name, + True, ) # Read model args @@ -503,6 +527,7 @@ async def handle_loop(self): recv_obj: Union[ BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput ] = await self.recv_from_detokenizer.recv_pyobj() + if isinstance(recv_obj, UpdateWeightReqOutput): if self.server_args.dp_size == 1: self.model_update_result.set_result(recv_obj) @@ -512,6 +537,7 @@ async def handle_loop(self): if len(self.model_update_tmp) == self.server_args.dp_size: self.model_update_result.set_result(self.model_update_tmp) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): if self.server_args.dp_size == 1: self.mem_pool_size.set_result(recv_obj) @@ -525,6 +551,7 @@ async def handle_loop(self): assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) ), f"Unexpected obj received: {type(recv_obj)}" + for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: @@ -537,13 +564,14 @@ async def handle_loop(self): if getattr(state.obj, "return_logprob", False): self.convert_logprob_style( meta_info, + state, state.obj.top_logprobs_num, state.obj.token_ids_logprob, - state.obj.return_text_in_logprobs, + state.obj.return_text_in_logprobs + and not self.server_args.skip_tokenizer_init, recv_obj, i, ) - if not isinstance(recv_obj, BatchEmbeddingOut): meta_info.update( { @@ -551,7 +579,6 @@ async def handle_loop(self): "cached_tokens": recv_obj.cached_tokens[i], } ) - if isinstance(recv_obj, BatchStrOut): out_dict = { "text": recv_obj.output_strs[i], @@ -569,6 +596,7 @@ async def handle_loop(self): "embedding": recv_obj.embeddings[i], "meta_info": meta_info, } + state.out_list.append(out_dict) state.finished = recv_obj.finished_reasons[i] is not None state.event.set() @@ -576,73 +604,124 @@ async def handle_loop(self): def convert_logprob_style( self, meta_info: dict, + state: ReqState, top_logprobs_num: int, token_ids_logprob: List[int], return_text_in_logprobs: bool, recv_obj: BatchStrOut, recv_obj_index: int, - ): + ) -> None: + if len(recv_obj.input_token_logprobs_val) > 0: + state.input_token_logprobs_val.extend( + recv_obj.input_token_logprobs_val[recv_obj_index] + ) + state.input_token_logprobs_idx.extend( + recv_obj.input_token_logprobs_idx[recv_obj_index] + ) + state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) + state.output_token_logprobs_idx.extend( + recv_obj.output_token_logprobs_idx[recv_obj_index] + ) meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( - recv_obj.input_token_logprobs_val[recv_obj_index], - recv_obj.input_token_logprobs_idx[recv_obj_index], + state.input_token_logprobs_val, + state.input_token_logprobs_idx, return_text_in_logprobs, ) meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( - recv_obj.output_token_logprobs_val[recv_obj_index], - recv_obj.output_token_logprobs_idx[recv_obj_index], + state.output_token_logprobs_val, + state.output_token_logprobs_idx, return_text_in_logprobs, ) if top_logprobs_num > 0: + if len(recv_obj.input_top_logprobs_val) > 0: + state.input_top_logprobs_val.extend( + recv_obj.input_top_logprobs_val[recv_obj_index] + ) + state.input_top_logprobs_idx.extend( + recv_obj.input_top_logprobs_idx[recv_obj_index] + ) + state.output_top_logprobs_val.extend( + recv_obj.output_top_logprobs_val[recv_obj_index] + ) + state.output_top_logprobs_idx.extend( + recv_obj.output_top_logprobs_idx[recv_obj_index] + ) meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.input_top_logprobs_val[recv_obj_index], - recv_obj.input_top_logprobs_idx[recv_obj_index], + state.input_top_logprobs_val, + state.input_top_logprobs_idx, return_text_in_logprobs, ) meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.output_top_logprobs_val[recv_obj_index], - recv_obj.output_top_logprobs_idx[recv_obj_index], + state.output_top_logprobs_val, + state.output_top_logprobs_idx, return_text_in_logprobs, ) if token_ids_logprob is not None: + if len(recv_obj.input_token_ids_logprobs_val) > 0: + state.input_token_ids_logprobs_val.extend( + recv_obj.input_token_ids_logprobs_val[recv_obj_index] + ) + state.input_token_ids_logprobs_idx.extend( + recv_obj.input_token_ids_logprobs_idx[recv_obj_index] + ) + state.output_token_ids_logprobs_val.extend( + recv_obj.output_token_ids_logprobs_val[recv_obj_index] + ) + state.output_token_ids_logprobs_idx.extend( + recv_obj.output_token_ids_logprobs_idx[recv_obj_index] + ) meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.input_token_ids_logprobs_val[recv_obj_index], - recv_obj.input_token_ids_logprobs_idx[recv_obj_index], + state.input_token_ids_logprobs_val, + state.input_token_ids_logprobs_idx, return_text_in_logprobs, ) meta_info[ "output_token_ids_logprobs" ] = self.detokenize_top_logprobs_tokens( - recv_obj.output_token_ids_logprobs_val[recv_obj_index], - recv_obj.output_token_ids_logprobs_idx[recv_obj_index], + state.output_token_ids_logprobs_val, + state.output_token_ids_logprobs_idx, return_text_in_logprobs, ) def detokenize_logprob_tokens( - self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, ): - # TODO(lianmin): This should run on DetokenizerManager if not decode_to_text: - return [(logprob, token_id, None) for logprob, token_id in token_logprobs] - - assert self.tokenizer is not None - token_ids = [tid for _, tid in token_logprobs] - token_texts = self.tokenizer.batch_decode(token_ids) - return [ - (logprob, token_id, token_text) - for (logprob, token_id), token_text in zip(token_logprobs, token_texts) - ] + return [ + (logprob, token_id, None) + for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) + ] + else: + assert self.tokenizer is not None + token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) - def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): + def detokenize_top_logprobs_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): # TODO: The current implementation only batches the detokenization for top-k tokens per single position. # We should batch all top-k tokens in all positions. - for i, token_top_logprobs in enumerate(top_logprobs): - if token_top_logprobs: - top_logprobs[i] = self.detokenize_logprob_tokens( - token_top_logprobs, decode_to_text + ret = [] + for i in range(len(token_logprobs_val)): + if token_logprobs_val[i]: + ret.append( + self.detokenize_logprob_tokens( + token_logprobs_val[i], token_logprobs_idx[i], decode_to_text + ) ) - return top_logprobs + else: + ret.append(None) + return ret class SignalHandler: diff --git a/scratchpad/nn/layers/logits_processor.py b/scratchpad/nn/layers/logits_processor.py index 149c23f..949e3ed 100644 --- a/scratchpad/nn/layers/logits_processor.py +++ b/scratchpad/nn/layers/logits_processor.py @@ -20,28 +20,32 @@ @dataclass class LogitsProcessorOutput: - ## Part 1: This part will be assigned in nn/layers/logits_processor.py::LogitsProcessor + ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor # Used by speculative decoding (EAGLE) # The last hidden layers hidden_states: Optional[torch.Tensor] = None - ## Part 2: This part will be assigned in nn/layers/sampler.py::Sampler + ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler # The logprobs of the next tokens. shape: [#seq] next_token_logprobs: Optional[torch.Tensor] = None # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] next_token_top_logprobs_val: Optional[List] = None next_token_top_logprobs_idx: Optional[List] = None + # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids) + next_token_token_ids_logprobs_val: Optional[List] = None + next_token_token_ids_logprobs_idx: Optional[List] = None - ## Part 3: Prefill-only. This part will be assigned in nn/layers/logits_processor.py::LogitsProcessor - # The normlaized logprobs of prompts. shape: [#seq] - normalized_prompt_logprobs: torch.Tensor = None + ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor # The logprobs of input tokens. shape: [#token] - input_token_logprobs: torch.Tensor = None + input_token_logprobs: Optional[torch.Tensor] = None # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] input_top_logprobs_val: List = None input_top_logprobs_idx: List = None + # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids) + input_token_ids_logprobs_val: Optional[List] = None + input_token_ids_logprobs_idx: Optional[List] = None @dataclass diff --git a/scratchpad/scheduler/scheduler.py b/scratchpad/scheduler/scheduler.py index 667f0af..610d674 100644 --- a/scratchpad/scheduler/scheduler.py +++ b/scratchpad/scheduler/scheduler.py @@ -20,7 +20,6 @@ create_grammar_backend, INVALID_GRAMMAR_OBJ, ) -from scratchpad.nn.layers.logits_processor import LogitsProcessorOutput from scratchpad.scheduler.schedule_batch import ( FINISH_ABORT, BaseFinishReason, @@ -73,6 +72,7 @@ if TYPE_CHECKING: from scratchpad.server.metric_types import StatLoggerBase + from scratchpad.nn.layers.logits_processor import LogitsProcessorOutput # Crash on warning if we are running CI tests crash_on_warning = os.getenv("SP_IS_IN_CI", "false") == "true" @@ -82,7 +82,7 @@ @dataclass class GenerationBatchResult: - logits_output: LogitsProcessorOutput + logits_output: "LogitsProcessorOutput" next_token_ids: List[int] extend_input_len_per_req: List[int] extend_logprob_start_len_per_req: List[int] @@ -131,18 +131,24 @@ def __init__( if self.tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( - context, zmq.PULL, server_args.scheduler_input_ipc_name + context, zmq.PULL, server_args.scheduler_input_ipc_name, False ) if server_args.skip_tokenizer_init: # Directly send to the tokenizer/api self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, server_args.tokenizer_ipc_name + context, + zmq.PUSH, + server_args.tokenizer_ipc_name, + False, ) else: # Send to the detokenizer self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, server_args.detokenizer_ipc_name + context, + zmq.PUSH, + server_args.detokenizer_ipc_name, + False, ) else: self.recv_from_tokenizer = None @@ -595,7 +601,6 @@ def handle_generate_request( # Copy more attributes if recv_req.logprob_start_len == -1 or not recv_req.return_logprob: - # By default, only return the logprobs for output tokens req.logprob_start_len = len(req.origin_input_ids) - 1 else: req.logprob_start_len = recv_req.logprob_start_len @@ -1250,6 +1255,171 @@ def process_batch_result_prefill( req.is_chunked -= 1 self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) + def add_input_logprob_return_values( + self: "Scheduler", + i: int, + req: Req, + output: "LogitsProcessorOutput", + logprob_pt: int, + num_input_logprobs: int, + last_prefill_chunk: bool, # If True, it means prefill is finished. + ): + """Incrementally add input logprobs to `req`. + + Args: + i: The request index in a batch. + req: The request. Input logprobs inside req are modified as a + consequence of the API + fill_ids: The prefill ids processed. + output: Logit processor output that's used to compute input logprobs + last_prefill_chunk: True if it is the last prefill (when chunked). + Some of input logprob operation should only happen at the last + prefill (e.g., computing input token logprobs). + """ + assert output.input_token_logprobs is not None + if req.input_token_logprobs is None: + req.input_token_logprobs = [] + if req.temp_input_top_logprobs_val is None: + req.temp_input_top_logprobs_val = [] + if req.temp_input_top_logprobs_idx is None: + req.temp_input_top_logprobs_idx = [] + if req.temp_input_token_ids_logprobs_val is None: + req.temp_input_token_ids_logprobs_val = [] + if req.temp_input_token_ids_logprobs_idx is None: + req.temp_input_token_ids_logprobs_idx = [] + + if req.input_token_logprobs_val is not None: + # The input logprob has been already computed. It only happens + # upon retract. + if req.top_logprobs_num > 0: + assert req.input_token_logprobs_val is not None + return + + # Important for the performance. + assert isinstance(output.input_token_logprobs, tuple) + input_token_logprobs: Tuple[int] = output.input_token_logprobs + input_token_logprobs = input_token_logprobs[ + logprob_pt : logprob_pt + num_input_logprobs + ] + req.input_token_logprobs.extend(input_token_logprobs) + + if req.top_logprobs_num > 0: + req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i]) + req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) + + if req.token_ids_logprob is not None: + req.temp_input_token_ids_logprobs_val.append( + output.input_token_ids_logprobs_val[i] + ) + req.temp_input_token_ids_logprobs_idx.append( + output.input_token_ids_logprobs_idx[i] + ) + + if last_prefill_chunk: + input_token_logprobs = req.input_token_logprobs + req.input_token_logprobs = None + assert req.input_token_logprobs_val is None + assert req.input_token_logprobs_idx is None + assert req.input_top_logprobs_val is None + assert req.input_top_logprobs_idx is None + + # Compute input_token_logprobs_val + # Always pad the first one with None. + req.input_token_logprobs_val = [None] + req.input_token_logprobs_val.extend(input_token_logprobs) + # The last input logprob is for sampling, so just pop it out. + req.input_token_logprobs_val.pop() + + # Compute input_token_logprobs_idx + input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :] + # Clip the padded hash values from image tokens. + # Otherwise, it will lead to detokenization errors. + input_token_logprobs_idx = [ + x if x < self.model_config.vocab_size - 1 else 0 + for x in input_token_logprobs_idx + ] + req.input_token_logprobs_idx = input_token_logprobs_idx + + if req.top_logprobs_num > 0: + req.input_top_logprobs_val = [None] + req.input_top_logprobs_idx = [None] + assert len(req.temp_input_token_ids_logprobs_val) == len( + req.temp_input_token_ids_logprobs_idx + ) + for val, idx in zip( + req.temp_input_top_logprobs_val, + req.temp_input_top_logprobs_idx, + strict=True, + ): + req.input_top_logprobs_val.extend(val) + req.input_top_logprobs_idx.extend(idx) + + # Last token is a sample token. + req.input_top_logprobs_val.pop() + req.input_top_logprobs_idx.pop() + req.temp_input_top_logprobs_idx = None + req.temp_input_top_logprobs_val = None + + if req.token_ids_logprob is not None: + req.input_token_ids_logprobs_val = [None] + req.input_token_ids_logprobs_idx = [None] + + for val, idx in zip( + req.temp_input_token_ids_logprobs_val, + req.temp_input_token_ids_logprobs_idx, + strict=True, + ): + req.input_token_ids_logprobs_val.extend(val) + req.input_token_ids_logprobs_idx.extend(idx) + + # Last token is a sample token. + req.input_token_ids_logprobs_val.pop() + req.input_token_ids_logprobs_idx.pop() + req.temp_input_token_ids_logprobs_idx = None + req.temp_input_token_ids_logprobs_val = None + + if req.return_logprob: + relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len + assert len(req.input_token_logprobs_val) == relevant_tokens_len + assert len(req.input_token_logprobs_idx) == relevant_tokens_len + if req.top_logprobs_num > 0: + assert len(req.input_top_logprobs_val) == relevant_tokens_len + assert len(req.input_top_logprobs_idx) == relevant_tokens_len + if req.token_ids_logprob is not None: + assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len + assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len + + def add_logprob_return_values( + self: "Scheduler", + i: int, + req: Req, + pt: int, + next_token_ids: List[int], + num_input_logprobs: int, + output: "LogitsProcessorOutput", + ): + """Attach logprobs to the return values.""" + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_ids[i]) + + self.add_input_logprob_return_values( + i, req, output, pt, num_input_logprobs, last_prefill_chunk=True + ) + + if req.top_logprobs_num > 0: + req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) + + if req.token_ids_logprob is not None: + req.output_token_ids_logprobs_val.append( + output.next_token_token_ids_logprobs_val[i] + ) + req.output_token_ids_logprobs_idx.append( + output.next_token_token_ids_logprobs_idx[i] + ) + + return num_input_logprobs + def process_batch_result_decode( self, batch: ScheduleBatch, @@ -1349,79 +1519,6 @@ def process_batch_result_decode( ): self.log_decode_stats() - def add_logprob_return_values( - self, - i: int, - req: Req, - pt: int, - next_token_ids: List[int], - output: LogitsProcessorOutput, - ): - """Attach logprobs to the return values.""" - req.output_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) - - # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. - num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len - - if req.normalized_prompt_logprob is None: - req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - - if req.input_token_logprobs is None: - input_token_logprobs = output.input_token_logprobs[ - pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] - input_token_ids = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1 : len(req.fill_ids) - - req.last_update_decode_tokens - ] - req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids)) - - if ( - req.logprob_start_len == 0 - ): # The first token does not have logprob, pad it. - req.input_token_logprobs = [ - (None, req.fill_ids[0]) - ] + req.input_token_logprobs - - if req.last_update_decode_tokens != 0: - # Some decode tokens are re-computed in an extend batch - req.output_token_logprobs.extend( - list( - zip( - output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 - ], - req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) - ], - ) - ) - ) - - if req.top_logprobs_num > 0: - if req.input_top_logprobs is None: - req.input_top_logprobs = output.input_top_logprobs[i] - if req.logprob_start_len == 0: - req.input_top_logprobs = [None] + req.input_top_logprobs - - if req.last_update_decode_tokens != 0: - req.output_top_logprobs.extend( - output.input_top_logprobs[i][-req.last_update_decode_tokens :] - ) - req.output_top_logprobs.append(output.output_top_logprobs[i]) - - return num_input_logprobs - def stream_output( self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None ): @@ -1782,6 +1879,6 @@ def run_scheduler_process( traceback.print_exc() logger.info(f"Scheduler process exited: {e}") except Exception: - msg = get_exception_traceback() + msg: str = get_exception_traceback() logger.error(msg) kill_parent_process() diff --git a/scratchpad/server/args.py b/scratchpad/server/args.py index 1c899d0..6afbd1e 100644 --- a/scratchpad/server/args.py +++ b/scratchpad/server/args.py @@ -157,15 +157,19 @@ def translate_auto(self): if self.sampling_backend is None: self.sampling_backend = "flashinfer" if self.random_seed is None: - self.random_seed = 0 # default seed + self.random_seed = 0 if self.scheduler_input_ipc_name == "auto": - self.scheduler_input_ipc_name = tempfile.NamedTemporaryFile( - delete=False - ).name + self.scheduler_input_ipc_name = ( + f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + ) if self.tokenizer_ipc_name == "auto": - self.tokenizer_ipc_name = tempfile.NamedTemporaryFile(delete=False).name + self.tokenizer_ipc_name = ( + f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + ) if self.detokenizer_ipc_name == "auto": - self.detokenizer_ipc_name = tempfile.NamedTemporaryFile(delete=False).name + self.detokenizer_ipc_name = ( + f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + ) try: self.json_model_override_args = json.loads(self.model_override_args) except Exception as e: @@ -194,8 +198,6 @@ def translate_auto(self): if self.reasoning_parser == "auto": if "qwen3" in self.served_model_name.lower(): self.reasoning_parser = "qwen3" - elif "llama" in self.served_model_name.lower(): - self.reasoning_parser = "llama3" else: self.reasoning_parser = None logger.info(f"Using reasoning_parser: {self.reasoning_parser}") diff --git a/scratchpad/server/openai_api/handler.py b/scratchpad/server/openai_api/handler.py index 0ba5922..4ebe262 100644 --- a/scratchpad/server/openai_api/handler.py +++ b/scratchpad/server/openai_api/handler.py @@ -994,7 +994,6 @@ def v1_chat_generate_response( reasoning_parser=None, ): choices = [] - for idx, ret_item in enumerate(ret): logprobs = False if isinstance(request, list) and request[idx].logprobs: @@ -1004,7 +1003,9 @@ def v1_chat_generate_response( if logprobs: logprobs = to_openai_style_logprobs( output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], - output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + output_top_logprobs=ret_item["meta_info"].get( + "output_top_logprobs", None + ), ) token_logprobs = [] for token_idx, (token, logprob) in enumerate( @@ -1032,8 +1033,7 @@ def v1_chat_generate_response( top_logprobs=top_logprobs, ) ) - - choice_logprobs = ChoiceLogprobs(content=token_logprobs) + choice_logprobs: ChoiceLogprobs = ChoiceLogprobs(content=token_logprobs) else: choice_logprobs = None diff --git a/scratchpad/utils/utils.py b/scratchpad/utils/utils.py index 3655ecc..ba4812c 100644 --- a/scratchpad/utils/utils.py +++ b/scratchpad/utils/utils.py @@ -389,7 +389,9 @@ def enable_show_time_cost(): show_time_cost = True -def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str): +def get_zmq_socket( + context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool +): mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 @@ -399,17 +401,32 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: buf_size = -1 socket = context.socket(socket_type) - if socket_type == zmq.PUSH: + if endpoint.find("[") != -1: + socket.setsockopt(zmq.IPV6, 1) + + def set_send_opt(): socket.setsockopt(zmq.SNDHWM, 0) socket.setsockopt(zmq.SNDBUF, buf_size) - socket.connect(f"ipc://{endpoint}") - elif socket_type == zmq.PULL: + + def set_recv_opt(): socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVBUF, buf_size) - socket.bind(f"ipc://{endpoint}") + + if socket_type == zmq.PUSH: + set_send_opt() + elif socket_type == zmq.PULL: + set_recv_opt() + elif socket_type == zmq.DEALER: + set_send_opt() + set_recv_opt() else: raise ValueError(f"Unsupported socket type: {socket_type}") + if bind: + socket.bind(endpoint) + else: + socket.connect(endpoint) + return socket diff --git a/tests/e2e/test_logprob.py b/tests/e2e/test_logprob.py new file mode 100644 index 0000000..c1c1a84 --- /dev/null +++ b/tests/e2e/test_logprob.py @@ -0,0 +1,18 @@ +import os +import openai + +client = openai.Client(api_key="test", base_url="http://localhost:8081/v1") +res = client.chat.completions.create( + model="meta-llama/Llama-3.2-1B-Instruct", + messages=[ + { + "content": "Who is Pablo Picasso?", + "role": "user", + } + ], + stream=True, + max_tokens=12, + logprobs=True, +) +for chunk in res: + print(chunk) From abf2fbc19418ddc143a6158918ae6e5986641b5e Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Wed, 11 Jun 2025 17:21:54 +0200 Subject: [PATCH 4/6] increase default watchdog timeout (for jit ops) --- meta/requirements.txt | 2 +- scratchpad/server/args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/meta/requirements.txt b/meta/requirements.txt index eaa6d4c..7906a27 100644 --- a/meta/requirements.txt +++ b/meta/requirements.txt @@ -25,5 +25,5 @@ nvidia-cuda-nvrtc-cu12 cuda-python setproctitle soundfile -triton==3.0.0 partial_json_parser +httpx diff --git a/scratchpad/server/args.py b/scratchpad/server/args.py index 6afbd1e..feab5f0 100644 --- a/scratchpad/server/args.py +++ b/scratchpad/server/args.py @@ -27,7 +27,7 @@ class ServerArgs: schedule_policy: str = "lpm" random_seed: Optional[int] = None stream_interval: int = 1 - watchdog_timeout: float = 20 + watchdog_timeout: float = 120 decode_log_interval: int = 10 # memory and scheduling chunked_prefill_size: int = 8192 From ad11191d2545d59773496e7a7033f3ab34402a5b Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Fri, 27 Jun 2025 12:10:28 -0400 Subject: [PATCH 5/6] Update Dockerfile.x86_64-cuda --- meta/docker/Dockerfile.x86_64-cuda | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/meta/docker/Dockerfile.x86_64-cuda b/meta/docker/Dockerfile.x86_64-cuda index 24aa18c..3683fb7 100644 --- a/meta/docker/Dockerfile.x86_64-cuda +++ b/meta/docker/Dockerfile.x86_64-cuda @@ -10,7 +10,7 @@ RUN apt update && apt upgrade -y WORKDIR /scratchpad COPY . . -COPY --from=ghcr.io/xiaozheyao/sp-builder:v0.1.6-x86 /wheels /wheels +COPY --from=ghcr.io/xiaozheyao/sp-builder:v0.1.6-x86_64 /wheels /wheels RUN pip install --no-cache-dir /wheels/flashinfer_python-0.2.3-cp38-abi3-linux_x86_64.whl && \ pip install --no-cache-dir /wheels/triteia-0.1.0-cp310-cp310-linux_x86_64.whl From 4381a45239a7de0a0ac7140208bd668c58acb8d4 Mon Sep 17 00:00:00 2001 From: dhia680 Date: Tue, 8 Jul 2025 19:57:19 +0200 Subject: [PATCH 6/6] correct import --- scratchpad/nn/models/swissai/config.py | 34 +++++--------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/scratchpad/nn/models/swissai/config.py b/scratchpad/nn/models/swissai/config.py index 0fe5948..c28522c 100644 --- a/scratchpad/nn/models/swissai/config.py +++ b/scratchpad/nn/models/swissai/config.py @@ -1,6 +1,6 @@ # copied from https://github.com/swiss-ai/transformers/blob/1e0881a41d4fda838ece30f730130ddf10ba0913/src/transformers/models/swissai/configuration_swissai.py -from transformers import PretrainedConfig, AutoConfig - +from transformers import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation class SwissAIConfig(PretrainedConfig): r""" @@ -128,34 +128,12 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self._rope_scaling_validation() + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.rms_norm_eps = rms_norm_eps - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if ( - rope_scaling_factor is None - or not isinstance(rope_scaling_factor, float) - or rope_scaling_factor <= 1.0 - ): - raise ValueError( - f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}" - ) +__all__ = ["SwissAIConfig"]