Skip to content

Commit cb3bb41

Browse files
authored
Align chat completions endpoint with vLLM (#4063)
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
1 parent 150e37a commit cb3bb41

File tree

7 files changed

+141
-87
lines changed

7 files changed

+141
-87
lines changed

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import concurrent
55
import copy
66
import functools
7-
import inspect
87
from collections import defaultdict
98
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
109

@@ -35,7 +34,13 @@
3534
from megatron.core.transformer.moe.moe_layer import BaseMoELayer
3635
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
3736
from megatron.core.transformer.utils import set_model_to_sequence_parallel
38-
from megatron.core.utils import get_asyncio_loop, get_model_config, get_pg_size, unwrap_model
37+
from megatron.core.utils import (
38+
accepts_parameter,
39+
get_asyncio_loop,
40+
get_model_config,
41+
get_pg_size,
42+
unwrap_model,
43+
)
3944

4045
try:
4146
import transformer_engine as te # pylint: disable=unused-import
@@ -209,12 +214,7 @@ def detokenize(
209214
while tokens and tokens[-1] == tokenizer.eod:
210215
tokens = tokens[:-1]
211216

212-
sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
213-
detok_accepts_skip = any(
214-
p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
215-
for p in sig_params
216-
)
217-
if detok_accepts_skip:
217+
if accepts_parameter(tokenizer.detokenize, "skip_special_tokens"):
218218
return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
219219
else:
220220
return tokenizer.detokenize(tokens)

megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py

Lines changed: 35 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,6 @@ def _get_field(obj, key, default=None):
7676
return getattr(obj, key, default)
7777

7878

79-
_TRANSFER_TOOL_NAME = "transfer_to_human_agents"
80-
_TRANSFER_HOLD_MESSAGE = "YOU ARE BEING TRANSFERRED TO A HUMAN AGENT. PLEASE HOLD ON."
81-
_RESERVATION_UPDATE_TOOLS = {
82-
"update_reservation_flights",
83-
"update_reservation_passengers",
84-
"update_reservation_baggages",
85-
}
86-
_RESERVATION_DESTRUCTIVE_TOOLS = {"cancel_reservation", "book_reservation"}
87-
88-
8979
def _try_parse_jsonish(value):
9080
if not isinstance(value, str):
9181
return value
@@ -199,48 +189,21 @@ def _normalize_tool_calls(tool_calls, tools=None):
199189
"function": {"name": str(fn_name), "arguments": fn_args},
200190
}
201191
)
202-
return _apply_tool_call_guardrails(normalized)
192+
return normalized
203193

204194

205-
def _apply_tool_call_guardrails(tool_calls):
206-
"""Apply conservative post-parse guardrails to tool call lists.
195+
def _maybe_filter_parallel_tool_calls(tool_calls, parallel_tool_calls):
196+
"""Filter to first tool call only when parallel_tool_calls is False.
207197
208-
If update-style reservation tools are already present in the same response,
209-
suppress cancel+book style calls to avoid destructive replanning patterns.
198+
Matches vLLM's maybe_filter_parallel_tool_calls behavior.
210199
"""
211-
if not isinstance(tool_calls, list):
200+
if parallel_tool_calls:
212201
return tool_calls
213-
214-
call_names = {
215-
_get_field(_get_field(call, "function", {}), "name")
216-
for call in tool_calls
217-
if isinstance(call, dict)
218-
}
219-
if call_names & _RESERVATION_UPDATE_TOOLS:
220-
return [
221-
call
222-
for call in tool_calls
223-
if _get_field(_get_field(call, "function", {}), "name")
224-
not in _RESERVATION_DESTRUCTIVE_TOOLS
225-
]
202+
if tool_calls:
203+
return tool_calls[:1]
226204
return tool_calls
227205

228206

229-
def _normalize_assistant_content(message_text, tool_calls):
230-
"""Normalize assistant content for policy-sensitive tool transitions."""
231-
if not isinstance(message_text, str):
232-
message_text = "" if message_text is None else str(message_text)
233-
234-
tool_names = {
235-
_get_field(_get_field(call, "function", {}), "name")
236-
for call in (tool_calls or [])
237-
if isinstance(call, dict)
238-
}
239-
if _TRANSFER_TOOL_NAME in tool_names:
240-
return _TRANSFER_HOLD_MESSAGE
241-
return message_text
242-
243-
244207
def _coerce_arguments_mapping(arguments):
245208
"""Coerce function.arguments to a mapping for HF/Jinja chat templates.
246209
@@ -446,7 +409,9 @@ async def chat_completions():
446409

447410
req = await request.get_json()
448411
tools = req.get("tools", None)
449-
tools_requested = bool(tools)
412+
tool_choice = req.get("tool_choice", None)
413+
parallel_tool_calls = req.get("parallel_tool_calls", True)
414+
tools_requested = bool(tools) and tool_choice != "none"
450415
messages = req.get("messages")
451416
chat_template_kwargs = req.get("chat_template_kwargs", {})
452417
if not isinstance(chat_template_kwargs, dict):
@@ -699,10 +664,22 @@ async def chat_completions():
699664
)
700665

701666
normalized_tool_calls = metadata.get("tool_calls", [])
702-
message = {
703-
"role": "assistant",
704-
"content": _normalize_assistant_content(message_text, normalized_tool_calls),
705-
}
667+
668+
# Apply parallel_tool_calls filtering (matches vLLM behavior)
669+
normalized_tool_calls = _maybe_filter_parallel_tool_calls(
670+
normalized_tool_calls, parallel_tool_calls
671+
)
672+
673+
# Determine content based on tool_choice (matches vLLM behavior):
674+
# - Named tool choice or "required": content is empty string
675+
# - Otherwise: content is the parsed message text
676+
is_named_tool_choice = isinstance(tool_choice, dict) and "function" in tool_choice
677+
if normalized_tool_calls and (is_named_tool_choice or tool_choice == "required"):
678+
content = ""
679+
else:
680+
content = message_text if message_text is not None else ""
681+
682+
message = {"role": "assistant", "content": content}
706683
if normalized_tool_calls:
707684
message["tool_calls"] = normalized_tool_calls
708685
if "reasoning" in metadata:
@@ -714,12 +691,19 @@ async def chat_completions():
714691
message["generation_log_probs"] = result.get("generated_log_probs", [])
715692
return_log_probs = sampling_params.return_log_probs
716693

717-
finish_reason = "tool_calls" if metadata.get("tool_calls", []) else "stop"
694+
# Determine finish_reason following vLLM conventions:
695+
# - "tool_calls" for auto or required tool choice when tools are called
696+
# - "stop" for named tool choice (even when tools are called)
697+
# - "length" when max tokens is reached
718698
if (
719699
len(result["generated_tokens"])
720700
>= result["sampling_params"]["num_tokens_to_generate"]
721701
):
722702
finish_reason = "length"
703+
elif normalized_tool_calls and not is_named_tool_choice:
704+
finish_reason = "tool_calls"
705+
else:
706+
finish_reason = "stop"
723707

724708
choice_data = {
725709
"index": request_idx,
@@ -759,7 +743,7 @@ async def chat_completions():
759743

760744
prompt_token_count = max(prompt_tokens_counts) if prompt_tokens_counts else 0
761745
response = {
762-
"id": str(uuid.uuid4()),
746+
"id": f"chatcmpl-{uuid.uuid4().hex}",
763747
"created": int(time.time()),
764748
"model": "EMPTY",
765749
"object": "chat.completion",

megatron/core/inference/text_generation_server/run_mcore_engine.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
22

3-
import inspect
4-
53
from megatron.core import mpu
64
from megatron.core.inference.communication_utils import broadcast_float_list
75
from megatron.core.inference.inference_request import InferenceRequest
86
from megatron.core.inference.sampling_params import SamplingParams
97
from megatron.core.inference.text_generation_server.tokenization import tokenize_prompts
8+
from megatron.core.utils import accepts_parameter
109

1110

1211
def run_mcore_engine(
@@ -60,18 +59,11 @@ def run_mcore_engine(
6059
for p, l in zip(context_tokens_tensor, context_length_tensor):
6160
tokenized_prompts.append(p[:l].cpu().numpy().tolist())
6261

63-
# detect if detokenize supports skip_special_tokens or **kwargs
64-
sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
65-
accepts_skip = any(
66-
p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
67-
for p in sig_params
68-
)
69-
7062
# Detokenize prompts into strings to pass through the engine
7163
detokenized_prompts = [
7264
(
7365
tokenizer.detokenize(p, skip_special_tokens=True)
74-
if accepts_skip
66+
if accepts_parameter(tokenizer.detokenize, "skip_special_tokens")
7567
else tokenizer.detokenize(p)
7668
)
7769
for p in tokenized_prompts

megatron/core/tokenizers/text/parsers/qwen3_coder_tool_parser.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
import re
99
import uuid
10-
from types import SimpleNamespace
1110
from typing import Any
1211

1312
from megatron.core.tokenizers.text.parsers.base_parser import BaseParser
@@ -48,21 +47,20 @@ def _get_arguments_config(
4847
if tools is None:
4948
return {}
5049
for config in tools:
51-
config = SimpleNamespace(**config) # Convert to SimpleNamespace for ease of access
52-
if not hasattr(config, "type") or not (
53-
hasattr(config, "function") and hasattr(config.function, "name")
54-
):
50+
if not isinstance(config, dict):
51+
continue
52+
fn = config.get("function", {})
53+
if not isinstance(fn, dict):
5554
continue
56-
if config.type == "function" and config.function.name == func_name:
57-
if not hasattr(config.function, "parameters"):
58-
return {}
59-
params = config.function.parameters
60-
if isinstance(params, dict) and "properties" in params:
61-
return params["properties"]
62-
elif isinstance(params, dict):
63-
return params
64-
else:
65-
return {}
55+
if config.get("type") != "function" or fn.get("name") != func_name:
56+
continue
57+
params = fn.get("parameters", {})
58+
if isinstance(params, dict) and "properties" in params:
59+
return params["properties"]
60+
elif isinstance(params, dict):
61+
return params
62+
else:
63+
return {}
6664
logger.debug("Tool '%s' is not defined in the tools list.", func_name)
6765
return {}
6866

@@ -87,6 +85,9 @@ def _convert_param_value(
8785

8886
if isinstance(param_config[param_name], dict) and "type" in param_config[param_name]:
8987
param_type = str(param_config[param_name]["type"]).strip().lower()
88+
elif isinstance(param_config[param_name], dict) and "anyOf" in param_config[param_name]:
89+
# anyOf has no top-level "type"; treat as object to trigger json.loads.
90+
param_type = "object"
9091
else:
9192
param_type = "string"
9293
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
@@ -173,7 +174,9 @@ def _parse_xml_function_call(
173174
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
174175
) -> ToolCall | None:
175176
# Extract function name
176-
end_index = function_call_str.index(">")
177+
end_index = function_call_str.find(">")
178+
if end_index == -1:
179+
return None
177180
function_name = function_call_str[:end_index]
178181
param_config = self._get_arguments_config(function_name, tools)
179182
parameters = function_call_str[end_index + 1 :]
@@ -236,6 +239,7 @@ def extract_tool_calls(
236239
self._parse_xml_function_call(function_call_str, tools)
237240
for function_call_str in function_calls
238241
]
242+
tool_calls = [tc for tc in tool_calls if tc is not None]
239243

240244
# Extract content before tool calls
241245
content_index = model_output.find(self.tool_call_start_token)

megatron/core/tokenizers/text/text_tokenizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from megatron.core.tokenizers.base_tokenizer import MegatronTokenizerBase
77
from megatron.core.tokenizers.text.libraries.abstract_tokenizer import MegatronTokenizerTextAbstract
8+
from megatron.core.utils import accepts_parameter
89

910
TOKENIZER_MAPPING_LIBRARIES = OrderedDict(
1011
[
@@ -75,17 +76,23 @@ def tokenize(self, text: str) -> List[int]:
7576

7677
return self._tokenizer.text_to_ids(text)
7778

78-
def detokenize(self, ids: List[int]) -> str:
79+
def detokenize(self, ids: List[int], skip_special_tokens: Optional[bool] = None) -> str:
7980
"""
8081
Text detokenization.
8182
8283
Args:
83-
ids (list): text to be tokenized.
84+
ids (list): token IDs to be detokenized.
85+
skip_special_tokens (bool): Whether to strip special tokens
86+
(e.g. <|im_end|>) from the output. Defaults to True.
8487
8588
Returns:
8689
text: detokenized text.
8790
"""
8891

92+
if skip_special_tokens is not None and accepts_parameter(
93+
self._tokenizer.ids_to_text, "remove_special_tokens"
94+
):
95+
return self._tokenizer.ids_to_text(ids, remove_special_tokens=skip_special_tokens)
8996
return self._tokenizer.ids_to_text(ids)
9097

9198
def apply_chat_template(

megatron/core/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,12 @@ def is_flashinfer_min_version(version, check_equality=True):
491491
return flashinver_version > PkgVersion(version)
492492

493493

494+
def accepts_parameter(func: Callable, name: str) -> bool:
495+
"""Check if a callable accepts a parameter with the given name or **kwargs."""
496+
params = inspect.signature(func).parameters.values()
497+
return any(p.name == name or p.kind == inspect.Parameter.VAR_KEYWORD for p in params)
498+
499+
494500
def ensure_divisibility(numerator, denominator):
495501
"""Ensure that numerator is divisible by the denominator."""
496502
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)

0 commit comments

Comments
 (0)