Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import concurrent
import copy
import functools
import inspect
from collections import defaultdict
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union

Expand Down Expand Up @@ -35,7 +34,13 @@
from megatron.core.transformer.moe.moe_layer import BaseMoELayer
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.transformer.utils import set_model_to_sequence_parallel
from megatron.core.utils import get_asyncio_loop, get_model_config, get_pg_size, unwrap_model
from megatron.core.utils import (
accepts_parameter,
get_asyncio_loop,
get_model_config,
get_pg_size,
unwrap_model,
)

try:
import transformer_engine as te # pylint: disable=unused-import
Expand Down Expand Up @@ -209,12 +214,7 @@ def detokenize(
while tokens and tokens[-1] == tokenizer.eod:
tokens = tokens[:-1]

sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
detok_accepts_skip = any(
p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
for p in sig_params
)
if detok_accepts_skip:
if accepts_parameter(tokenizer.detokenize, "skip_special_tokens"):
return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
else:
return tokenizer.detokenize(tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,6 @@ def _get_field(obj, key, default=None):
return getattr(obj, key, default)


_TRANSFER_TOOL_NAME = "transfer_to_human_agents"
_TRANSFER_HOLD_MESSAGE = "YOU ARE BEING TRANSFERRED TO A HUMAN AGENT. PLEASE HOLD ON."
_RESERVATION_UPDATE_TOOLS = {
"update_reservation_flights",
"update_reservation_passengers",
"update_reservation_baggages",
}
_RESERVATION_DESTRUCTIVE_TOOLS = {"cancel_reservation", "book_reservation"}


def _try_parse_jsonish(value):
if not isinstance(value, str):
return value
Expand Down Expand Up @@ -199,48 +189,21 @@ def _normalize_tool_calls(tool_calls, tools=None):
"function": {"name": str(fn_name), "arguments": fn_args},
}
)
return _apply_tool_call_guardrails(normalized)
return normalized


def _apply_tool_call_guardrails(tool_calls):
"""Apply conservative post-parse guardrails to tool call lists.
def _maybe_filter_parallel_tool_calls(tool_calls, parallel_tool_calls):
"""Filter to first tool call only when parallel_tool_calls is False.

If update-style reservation tools are already present in the same response,
suppress cancel+book style calls to avoid destructive replanning patterns.
Matches vLLM's maybe_filter_parallel_tool_calls behavior.
"""
if not isinstance(tool_calls, list):
if parallel_tool_calls:
return tool_calls

call_names = {
_get_field(_get_field(call, "function", {}), "name")
for call in tool_calls
if isinstance(call, dict)
}
if call_names & _RESERVATION_UPDATE_TOOLS:
return [
call
for call in tool_calls
if _get_field(_get_field(call, "function", {}), "name")
not in _RESERVATION_DESTRUCTIVE_TOOLS
]
if tool_calls:
return tool_calls[:1]
return tool_calls


def _normalize_assistant_content(message_text, tool_calls):
"""Normalize assistant content for policy-sensitive tool transitions."""
if not isinstance(message_text, str):
message_text = "" if message_text is None else str(message_text)

tool_names = {
_get_field(_get_field(call, "function", {}), "name")
for call in (tool_calls or [])
if isinstance(call, dict)
}
if _TRANSFER_TOOL_NAME in tool_names:
return _TRANSFER_HOLD_MESSAGE
return message_text


def _coerce_arguments_mapping(arguments):
"""Coerce function.arguments to a mapping for HF/Jinja chat templates.

Expand Down Expand Up @@ -446,7 +409,9 @@ async def chat_completions():

req = await request.get_json()
tools = req.get("tools", None)
tools_requested = bool(tools)
tool_choice = req.get("tool_choice", None)
parallel_tool_calls = req.get("parallel_tool_calls", True)
Comment thread
santhnm2 marked this conversation as resolved.
tools_requested = bool(tools) and tool_choice != "none"
messages = req.get("messages")
chat_template_kwargs = req.get("chat_template_kwargs", {})
if not isinstance(chat_template_kwargs, dict):
Expand Down Expand Up @@ -699,10 +664,22 @@ async def chat_completions():
)

normalized_tool_calls = metadata.get("tool_calls", [])
message = {
"role": "assistant",
"content": _normalize_assistant_content(message_text, normalized_tool_calls),
}

# Apply parallel_tool_calls filtering (matches vLLM behavior)
normalized_tool_calls = _maybe_filter_parallel_tool_calls(
normalized_tool_calls, parallel_tool_calls
)

# Determine content based on tool_choice (matches vLLM behavior):
# - Named tool choice or "required": content is empty string
# - Otherwise: content is the parsed message text
is_named_tool_choice = isinstance(tool_choice, dict) and "function" in tool_choice
if normalized_tool_calls and (is_named_tool_choice or tool_choice == "required"):
content = ""
else:
content = message_text if message_text is not None else ""

message = {"role": "assistant", "content": content}
if normalized_tool_calls:
message["tool_calls"] = normalized_tool_calls
if "reasoning" in metadata:
Expand All @@ -714,12 +691,19 @@ async def chat_completions():
message["generation_log_probs"] = result.get("generated_log_probs", [])
return_log_probs = sampling_params.return_log_probs

finish_reason = "tool_calls" if metadata.get("tool_calls", []) else "stop"
# Determine finish_reason following vLLM conventions:
# - "tool_calls" for auto or required tool choice when tools are called
# - "stop" for named tool choice (even when tools are called)
# - "length" when max tokens is reached
if (
len(result["generated_tokens"])
>= result["sampling_params"]["num_tokens_to_generate"]
):
finish_reason = "length"
elif normalized_tool_calls and not is_named_tool_choice:
finish_reason = "tool_calls"
else:
finish_reason = "stop"

choice_data = {
"index": request_idx,
Expand Down Expand Up @@ -759,7 +743,7 @@ async def chat_completions():

prompt_token_count = max(prompt_tokens_counts) if prompt_tokens_counts else 0
response = {
"id": str(uuid.uuid4()),
"id": f"chatcmpl-{uuid.uuid4().hex}",
"created": int(time.time()),
"model": "EMPTY",
"object": "chat.completion",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import inspect

from megatron.core import mpu
from megatron.core.inference.communication_utils import broadcast_float_list
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_server.tokenization import tokenize_prompts
from megatron.core.utils import accepts_parameter


def run_mcore_engine(
Expand Down Expand Up @@ -60,18 +59,11 @@ def run_mcore_engine(
for p, l in zip(context_tokens_tensor, context_length_tensor):
tokenized_prompts.append(p[:l].cpu().numpy().tolist())

# detect if detokenize supports skip_special_tokens or **kwargs
sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
accepts_skip = any(
p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
for p in sig_params
)

# Detokenize prompts into strings to pass through the engine
detokenized_prompts = [
(
tokenizer.detokenize(p, skip_special_tokens=True)
if accepts_skip
if accepts_parameter(tokenizer.detokenize, "skip_special_tokens")
else tokenizer.detokenize(p)
)
for p in tokenized_prompts
Expand Down
36 changes: 20 additions & 16 deletions megatron/core/tokenizers/text/parsers/qwen3_coder_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import re
import uuid
from types import SimpleNamespace
from typing import Any

from megatron.core.tokenizers.text.parsers.base_parser import BaseParser
Expand Down Expand Up @@ -48,21 +47,20 @@ def _get_arguments_config(
if tools is None:
return {}
for config in tools:
config = SimpleNamespace(**config) # Convert to SimpleNamespace for ease of access
if not hasattr(config, "type") or not (
hasattr(config, "function") and hasattr(config.function, "name")
):
if not isinstance(config, dict):
continue
fn = config.get("function", {})
if not isinstance(fn, dict):
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
if config.get("type") != "function" or fn.get("name") != func_name:
continue
params = fn.get("parameters", {})
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.debug("Tool '%s' is not defined in the tools list.", func_name)
return {}

Expand All @@ -87,6 +85,9 @@ def _convert_param_value(

if isinstance(param_config[param_name], dict) and "type" in param_config[param_name]:
param_type = str(param_config[param_name]["type"]).strip().lower()
elif isinstance(param_config[param_name], dict) and "anyOf" in param_config[param_name]:
# anyOf has no top-level "type"; treat as object to trigger json.loads.
param_type = "object"
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
Expand Down Expand Up @@ -173,7 +174,9 @@ def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.index(">")
end_index = function_call_str.find(">")
if end_index == -1:
return None
function_name = function_call_str[:end_index]
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :]
Expand Down Expand Up @@ -236,6 +239,7 @@ def extract_tool_calls(
self._parse_xml_function_call(function_call_str, tools)
for function_call_str in function_calls
]
tool_calls = [tc for tc in tool_calls if tc is not None]

# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
Expand Down
11 changes: 9 additions & 2 deletions megatron/core/tokenizers/text/text_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from megatron.core.tokenizers.base_tokenizer import MegatronTokenizerBase
from megatron.core.tokenizers.text.libraries.abstract_tokenizer import MegatronTokenizerTextAbstract
from megatron.core.utils import accepts_parameter

TOKENIZER_MAPPING_LIBRARIES = OrderedDict(
[
Expand Down Expand Up @@ -75,17 +76,23 @@ def tokenize(self, text: str) -> List[int]:

return self._tokenizer.text_to_ids(text)

def detokenize(self, ids: List[int]) -> str:
def detokenize(self, ids: List[int], skip_special_tokens: Optional[bool] = None) -> str:
    """
    Text detokenization.

    Args:
        ids (list): token IDs to be detokenized.
        skip_special_tokens (bool, optional): Whether to strip special tokens
            (e.g. <|im_end|>) from the output. Defaults to None, which keeps
            the underlying tokenizer's default behavior.

    Returns:
        text: detokenized text.
    """

    # Only forward the flag when the caller set it AND the underlying
    # library supports it (older tokenizers lack `remove_special_tokens`).
    if skip_special_tokens is not None and accepts_parameter(
        self._tokenizer.ids_to_text, "remove_special_tokens"
    ):
        return self._tokenizer.ids_to_text(ids, remove_special_tokens=skip_special_tokens)
    return self._tokenizer.ids_to_text(ids)

def apply_chat_template(
Expand Down
6 changes: 6 additions & 0 deletions megatron/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,12 @@ def is_flashinfer_min_version(version, check_equality=True):
return flashinver_version > PkgVersion(version)


def accepts_parameter(func: Callable, name: str) -> bool:
    """Check if a callable accepts a parameter with the given name or **kwargs."""
    for param in inspect.signature(func).parameters.values():
        if param.name == name or param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
    return False


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Args:
        numerator: value that must be a multiple of ``denominator``.
        denominator: the divisor.

    Raises:
        AssertionError: if ``numerator`` is not divisible by ``denominator``.
    """
    # Raise explicitly instead of using `assert` so the check is not
    # silently stripped under `python -O`; the exception type and message
    # are unchanged for callers that catch AssertionError.
    if numerator % denominator != 0:
        raise AssertionError("{} is not divisible by {}".format(numerator, denominator))
Expand Down
Loading
Loading