Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,16 @@ def create_py_executor(
logger.info(
"Tokenizer not provided; loading from checkpoint for guided decoding"
)
from tensorrt_llm.tokenizer import TransformersTokenizer
tokenizer = TransformersTokenizer.from_pretrained(
checkpoint_dir, trust_remote_code=llm_args.trust_remote_code)
if llm_args.custom_tokenizer:
from tensorrt_llm.tokenizer import load_custom_tokenizer
tokenizer = load_custom_tokenizer(
llm_args.custom_tokenizer,
checkpoint_dir,
trust_remote_code=llm_args.trust_remote_code)
else:
from tensorrt_llm.tokenizer import TransformersTokenizer
tokenizer = TransformersTokenizer.from_pretrained(
checkpoint_dir, trust_remote_code=llm_args.trust_remote_code)

guided_decoding_config = get_guided_decoding_config(
llm_args.guided_decoding_backend, tokenizer)
Expand Down
34 changes: 8 additions & 26 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2488,13 +2488,6 @@ def model_name(self) -> Union[str, Path]:
return self.model if isinstance(self.model, str) else None


# Short aliases for built-in custom tokenizer implementations.
TOKENIZER_ALIASES = {
'deepseek_v32': 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer',
'glm_moe_dsa': 'tensorrt_llm.tokenizer.glm_moe_dsa.GlmMoeDsaTokenizer',
}


class DwdpConfig(StrictBaseModel):
"""Configuration for Distributed Weight Data Parallelism (DWDP).

Expand Down Expand Up @@ -2895,26 +2888,15 @@ def validate_and_init_tokenizer(self):
"Please specify a tokenizer path or leave it as None to load from model path."
)

tokenizer_path = TOKENIZER_ALIASES.get(self.custom_tokenizer,
self.custom_tokenizer)
from tensorrt_llm.tokenizer import load_custom_tokenizer

# Dynamically import and use custom tokenizer
from importlib import import_module
try:
module_path, class_name = tokenizer_path.rsplit('.', 1)
module = import_module(module_path)
tokenizer_class = getattr(module, class_name)
# Use tokenizer path if specified, otherwise use model path
load_path = self.tokenizer if self.tokenizer else self.model
self.tokenizer = tokenizer_class.from_pretrained(
load_path,
trust_remote_code=self.trust_remote_code,
use_fast=self.tokenizer_mode != 'slow')
except (ValueError, ImportError, AttributeError) as e:
raise ValueError(
f"Failed to load custom tokenizer '{self.custom_tokenizer}': {e}. "
"Expected format: 'module.path.ClassName' or a recognized alias."
) from e
# Use tokenizer path if specified, otherwise use model path
load_path = self.tokenizer if self.tokenizer else self.model
self.tokenizer = load_custom_tokenizer(
self.custom_tokenizer,
load_path,
trust_remote_code=self.trust_remote_code,
use_fast=self.tokenizer_mode != 'slow')
else:
self.tokenizer = tokenizer_factory(
self.tokenizer,
Expand Down
6 changes: 5 additions & 1 deletion tensorrt_llm/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
# Public API of the tensorrt_llm.tokenizer package.
from .tokenizer import (
    TLLM_INCREMENTAL_DETOKENIZATION_BACKEND,
    TLLM_STREAM_INTERVAL_THRESHOLD,
    TOKENIZER_ALIASES,
    TokenizerBase,
    TransformersTokenizer,
    _llguidance_tokenizer_info,
    _xgrammar_tokenizer_info,
    load_custom_tokenizer,
    load_hf_tokenizer,
    tokenizer_factory,
)

# Keep __all__ in sync with the import list above.
# Fix: "load_hf_tokenizer" was listed twice; duplicates in __all__ are
# harmless at runtime but confuse linters and readers.
__all__ = [
    "TLLM_INCREMENTAL_DETOKENIZATION_BACKEND",
    "TLLM_STREAM_INTERVAL_THRESHOLD",
    "TOKENIZER_ALIASES",
    "TokenizerBase",
    "TransformersTokenizer",
    "_llguidance_tokenizer_info",
    "_xgrammar_tokenizer_info",
    "load_custom_tokenizer",
    "load_hf_tokenizer",
    "tokenizer_factory",
]
7 changes: 7 additions & 0 deletions tensorrt_llm/tokenizer/glm_moe_dsa/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,11 @@ def from_pretrained(
tokenizer_object=rust_tok,
**init_kwargs,
)

# Load chat template from chat_template.jinja
chat_template_path = path / "chat_template.jinja"
if chat_template_path.exists():
with open(chat_template_path, encoding="utf-8") as f:
hf_tokenizer.chat_template = f.read()

return cls(hf_tokenizer)
47 changes: 47 additions & 0 deletions tensorrt_llm/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
import pickle # nosec B403
from pathlib import Path
Expand All @@ -9,6 +10,12 @@
from .._utils import nvtx_range_debug
from ..logger import logger

# Short aliases for built-in custom tokenizer implementations. Each alias
# resolves to a fully-qualified "module.path.ClassName" import path that
# load_custom_tokenizer() imports dynamically; unknown identifiers are
# treated as import paths themselves.
TOKENIZER_ALIASES = {
    "deepseek_v32": "tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer",
    "glm_moe_dsa": "tensorrt_llm.tokenizer.glm_moe_dsa.GlmMoeDsaTokenizer",
}

TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = os.environ.get(
"TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", "HF")
TLLM_STREAM_INTERVAL_THRESHOLD = int(
Expand Down Expand Up @@ -450,3 +457,43 @@ def load_hf_tokenizer(model_dir: str,
f"Failed to load hf tokenizer from {model_dir}, encounter error: {e}"
)
return None


def load_custom_tokenizer(
    tokenizer_identifier: str,
    model_dir: Union[str, Path],
    trust_remote_code: bool = True,
    use_fast: bool = True,
) -> TokenizerBase:
    """Instantiate a custom tokenizer from an alias or a dotted import path.

    Args:
        tokenizer_identifier: Either a built-in alias (e.g., 'deepseek_v32')
            or a fully-qualified import path (e.g.,
            'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer').
        model_dir: The model directory to load the tokenizer from.
        trust_remote_code: Whether to trust remote code.
        use_fast: Whether to use the fast tokenizer.

    Returns:
        An instance of the custom tokenizer class.

    Raises:
        ValueError: If the identifier is malformed, the module cannot be
            imported, or the class is missing from the module.
    """
    # Known aliases map to fully-qualified paths; anything else is assumed
    # to already be a "module.path.ClassName" string.
    dotted_path = TOKENIZER_ALIASES.get(tokenizer_identifier,
                                        tokenizer_identifier)

    try:
        # A path with no dot makes rsplit yield one element, so the unpack
        # raises ValueError and is funneled into the error below.
        mod_name, cls_name = dotted_path.rsplit('.', 1)
        tok_cls = getattr(importlib.import_module(mod_name), cls_name)
        return tok_cls.from_pretrained(model_dir,
                                       trust_remote_code=trust_remote_code,
                                       use_fast=use_fast)
    except (ValueError, ImportError, AttributeError) as e:
        raise ValueError(
            f"Failed to load custom tokenizer '{tokenizer_identifier}': {e}. "
            "Expected format: 'module.path.ClassName' or a recognized alias."
        ) from e
20 changes: 16 additions & 4 deletions tests/unittest/llmapi/apps/_test_openai_chat_guided_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
'tiktoken_vocab')


@pytest.fixture(
scope="module",
params=["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
@pytest.fixture(scope="module",
params=[
"meta-llama/Llama-3.1-8B-Instruct",
"openai/gpt-oss-120b",
pytest.param("zai-org/GLM-5-FP8",
marks=pytest.mark.skip_less_device(8)),
])
def model_name(request):
return request.param

Expand All @@ -43,6 +47,8 @@ def temp_extra_llm_api_options_file(model_name: str):
"speculative_model_dir":
get_model_path("gpt_oss/gpt-oss-120b-Eagle3"),
}
elif model_name == "zai-org/GLM-5-FP8":
extra_llm_api_options_dict["custom_tokenizer"] = "glm_moe_dsa"

with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
Expand All @@ -59,11 +65,17 @@ def server(model_name: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path("llama-3.1-model/Llama-3.1-8B-Instruct")
elif model_name == "openai/gpt-oss-120b":
model_path = get_model_path("gpt_oss/gpt-oss-120b")
elif model_name == "zai-org/GLM-5-FP8":
model_path = get_model_path("GLM-5-FP8")

args = [
"--max_batch_size=8", "--max_seq_len=4096", "--max_num_tokens=4096",
f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
]

if model_name == "zai-org/GLM-5-FP8":
args.extend(["--tp_size=8", "--ep_size=8"])

with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server

Expand Down Expand Up @@ -282,7 +294,7 @@ def test_ebnf(client: openai.OpenAI, model_name: str):
},
{
"role": "user",
"content": "Give me the information of the capital of France.",
"content": "What's the capital of France?",
},
]
chat_completion = client.chat.completions.create(
Expand Down
Loading