Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions extension/llm/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
7 changes: 7 additions & 0 deletions extension/llm/server/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""OpenAI-compatible server for ExecuTorch LLMs (Python implementation)."""
144 changes: 144 additions & 0 deletions extension/llm/server/python/chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Render OpenAI chat messages into a single prompt string.

The ExecuTorch runner tokenizes a plain prompt; chat formatting is the server's
job (control plane). We require the model's own Hugging Face ``chat_template``
(via ``--hf-tokenizer``) for correct, tool-aware, reasoning-aware formatting.
The generic ChatML fallback is opt-in only (``allow_fallback``): it is
approximate and cannot reproduce model-specific controls (e.g. enable_thinking),
so it must be a deliberate choice rather than a silent default.
"""

import json
import logging
from typing import Any, Optional

from .protocol import ChatMessage

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Default special-token strings. ChatTemplate.special_tokens() returns a copy of
# this list when no HF tokenizer is loaded.
_DEFAULT_SPECIAL_TOKENS = ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "<|end|>"]


def _decode_tool_call_arguments(messages: list[dict[str, Any]]) -> None:
"""In-place: parse each tool call's ``function.arguments`` from a JSON string
into an object.

OpenAI sends assistant tool-call arguments as a JSON-encoded string, but HF
chat templates expect a mapping (e.g. Qwen renders ``arguments|items`` into
``<parameter=…>`` tags). Without this, a multi-turn tool conversation makes
the template raise "Can only get item pairs from a mapping". Left as-is if
the value isn't valid JSON, so a template that wants the raw string still works.
"""
for m in messages:
for tc in m.get("tool_calls") or []:
fn = tc.get("function")
if not isinstance(fn, dict):
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
fn["arguments"] = json.loads(args)
except (ValueError, TypeError):
pass


class ChatTemplate:
    """Turn a list of chat messages into a single prompt string.

    Two modes, chosen at construction time:

    * HF mode: a Hugging Face tokenizer with a ``chat_template`` is loaded and
      ``apply_chat_template`` does the formatting.
    * Fallback mode (only with ``allow_fallback=True``): an approximate ChatML
      rendering that ignores tools and template kwargs.
    """

    def __init__(
        self,
        hf_tokenizer_path: Optional[str] = None,
        default_template_kwargs: Optional[dict[str, Any]] = None,
        allow_fallback: bool = False,
    ):
        """Load the tokenizer, or opt into the approximate fallback.

        Args:
            hf_tokenizer_path: Path/identifier passed to
                ``AutoTokenizer.from_pretrained``. Falsy means "no tokenizer".
            default_template_kwargs: Server-level template kwargs; per-request
                kwargs passed to :meth:`render` override them.
            allow_fallback: Permit approximate ChatML when no usable chat
                template is available.

        Raises:
            ValueError: No usable chat template and ``allow_fallback`` is False.
        """
        # Server-level defaults (e.g. {"enable_thinking": False}); per-request
        # chat_template_kwargs override these.
        self._defaults = default_template_kwargs or {}
        self._hf = None
        if hf_tokenizer_path:
            # Imported lazily so transformers is only needed when a tokenizer
            # path is actually supplied.
            from transformers import AutoTokenizer

            self._hf = AutoTokenizer.from_pretrained(hf_tokenizer_path)
            if self._hf.chat_template is None:
                # Without a template the tokenizer is dropped entirely, which
                # switches every method of this class to fallback behaviour.
                self._hf = None
                if not allow_fallback:
                    raise ValueError(
                        f"HF tokenizer at {hf_tokenizer_path} has no chat_template; "
                        "pass an explicit fallback flag to use approximate ChatML."
                    )
                logger.warning(
                    "No chat_template at %s; using approximate ChatML.",
                    hf_tokenizer_path,
                )
        elif not allow_fallback:
            raise ValueError(
                "A chat template is required: pass --hf-tokenizer for the model's own "
                "template, or opt into approximate ChatML with --allow-chatml-fallback."
            )
        else:
            logger.warning(
                "No --hf-tokenizer; using approximate ChatML (no thinking control)."
            )

    def render(
        self,
        messages: list[ChatMessage],
        tools: Optional[list[dict[str, Any]]] = None,
        template_kwargs: Optional[dict[str, Any]] = None,
    ) -> str:
        """Render ``messages`` into a prompt ending with a generation prompt.

        Args:
            messages: Conversation so far.
            tools: Tool definitions forwarded to the HF template. Ignored in
                fallback mode.
            template_kwargs: Per-request template kwargs; these win over the
                server-level defaults. Ignored in fallback mode.

        Returns:
            The rendered prompt string.
        """
        kwargs = {**self._defaults, **(template_kwargs or {})}
        if self._hf is None:
            return self._fallback(messages)

        dumped = [m.model_dump(exclude_none=True) for m in messages]
        _decode_tool_call_arguments(dumped)
        # NOTE(review): a kwargs key named "tools", "add_generation_prompt" or
        # "tokenize" would collide with the explicit keywords below and raise
        # TypeError; confirm callers validate chat_template_kwargs.
        return self._hf.apply_chat_template(
            dumped,
            tools=tools,
            add_generation_prompt=True,
            tokenize=False,
            **kwargs,
        )

    def chat_template_str(self) -> Optional[str]:
        """Raw chat-template string (for tool-format auto-detection), if available."""
        if self._hf is None:
            return None
        return getattr(self._hf, "chat_template", None)

    def count_tokens(self, prompt: str) -> Optional[int]:
        """Token count for the rendered prompt, or None if no tokenizer is available."""
        if self._hf is None:
            return None
        # The prompt is already rendered (apply_chat_template includes the
        # control tokens), so encode without re-adding BOS/EOS — matching the
        # session/prefix-cache paths, so the count isn't inflated and
        # near-limit requests aren't falsely rejected under --max-context.
        return len(self._hf.encode(prompt, add_special_tokens=False))

    def special_tokens(self) -> list[str]:
        """Special-token strings whose appearance ends the visible content.

        From the HF tokenizer when available (model-accurate), else a default set
        covering common chat models.
        """
        if self._hf is None:
            return list(_DEFAULT_SPECIAL_TOKENS)
        tokens = getattr(self._hf, "all_special_tokens", []) or []
        # Keep only non-empty strings.
        return [t for t in tokens if isinstance(t, str) and t]

    @staticmethod
    def _content_text(content: Any) -> str:
        """Flatten message content to plain text for the ChatML fallback.

        A string is returned unchanged and ``None`` becomes ``""``. For a list
        of content parts, the ``text`` of every text part is joined with
        newlines; non-text parts (e.g. images) are dropped because the
        fallback has no way to represent them.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        texts = [
            part["text"]
            for part in content
            if isinstance(part, dict)
            and part.get("type", "text") == "text"
            and isinstance(part.get("text"), str)
        ]
        return "\n".join(texts)

    @staticmethod
    def _fallback(messages: list[ChatMessage]) -> str:
        """Approximate ChatML rendering used when no HF template is loaded."""
        # Approximate ChatML. Provide --hf-tokenizer for model-correct formatting
        # (including reasoning controls like enable_thinking, which the fallback
        # cannot reproduce).
        parts = [
            f"<|im_start|>{m.role}\n{ChatTemplate._content_text(m.content)}<|im_end|>"
            for m in messages
        ]
        # Open an assistant turn so the model continues from here.
        parts.append("<|im_start|>assistant\n")
        return "\n".join(parts)
62 changes: 62 additions & 0 deletions extension/llm/server/python/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""OpenAI-shaped API errors.

Raising these lets the server return a structured `{"error": {...}}` body with
the right HTTP status instead of dropping the connection.
"""

from typing import Optional


class APIError(Exception):
    """Exception carrying an HTTP status and the fields of an OpenAI-shaped error."""

    def __init__(
        self, status: int, message: str, err_type: str, code: Optional[str] = None
    ):
        """Store the status and error fields; ``message`` is also the exception text."""
        super().__init__(message)
        self.status, self.message = status, message
        self.err_type, self.code = err_type, code

    def body(self) -> dict:
        """Return the ``{"error": {...}}`` payload built from this error's fields."""
        detail = {
            "message": self.message,
            "type": self.err_type,
            "code": self.code,
        }
        return {"error": detail}


class ContextLengthExceeded(APIError):
    """400 ``invalid_request_error`` with code ``context_length_exceeded``."""

    def __init__(self, num_tokens: int, max_context: int, completion_tokens: int = 0):
        """Build the message for either an oversized prompt or prompt + completion.

        Args:
            num_tokens: Number of prompt tokens.
            max_context: The model's maximum context length, in tokens.
            completion_tokens: Requested completion size. A positive value
                selects the "prompt + completion" wording.
        """
        if completion_tokens > 0:
            # completion_tokens > 0: the prompt fits but prompt + requested
            # max_tokens would run past the window — reject up front rather
            # than fail (or truncate) mid-generation.
            text = (
                f"This model's maximum context length is {max_context} tokens. "
                f"However, you requested {num_tokens + completion_tokens} tokens "
                f"({num_tokens} in the messages, {completion_tokens} in the "
                f"completion). Please reduce the length of the messages or "
                f"completion."
            )
        else:
            text = (
                f"This model's maximum context length is {max_context} tokens, "
                f"but the request has {num_tokens} prompt tokens."
            )
        super().__init__(
            status=400,
            message=text,
            err_type="invalid_request_error",
            code="context_length_exceeded",
        )


class GenerationError(APIError):
    """500 ``server_error`` raised with a "Generation failed: ..." message."""

    def __init__(self, detail: str):
        """Prefix ``detail`` with "Generation failed: " and pass it to the base class."""
        text = f"Generation failed: {detail}"
        super().__init__(status=500, message=text, err_type="server_error")
148 changes: 148 additions & 0 deletions extension/llm/server/python/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""OpenAI-compatible request/response schemas for the ExecuTorch LLM server.

This is the Python view of the contract defined in ``extension/llm/server/spec``.
Any language server must serialize to the same shapes; the conformance suite in
``extension/llm/server/conformance`` validates them.
"""

import time
import uuid
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field


def _new_id(prefix: str) -> str:
return f"{prefix}-{uuid.uuid4().hex}"


class FunctionCall(BaseModel):
    # Function part of a tool call. Both fields are optional.
    # `arguments` is carried as a string, not a parsed object.
    name: Optional[str] = None
    arguments: Optional[str] = None


class ToolCall(BaseModel):
    # One tool call. Only `function` is required; `type` is fixed to "function".
    index: Optional[int] = None
    id: Optional[str] = None
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(BaseModel):
    # One request message. Only `role` is required.
    role: str
    # Either a plain string or a list of dict-shaped content parts.
    content: Union[str, list[dict[str, Any]], None] = None
    name: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None
    tool_call_id: Optional[str] = None


class StreamOptions(BaseModel):
    # Streaming options on the request; `include_usage` is off by default.
    include_usage: bool = False


class ChatCompletionRequest(BaseModel):
    # Request body for a chat completion. Only `messages` is required.
    model: Optional[str] = None
    messages: list[ChatMessage]
    stream: bool = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Two spellings of the completion budget; see resolved_max_tokens().
    max_tokens: Optional[int] = None
    max_completion_tokens: Optional[int] = None
    stop: Union[str, list[str], None] = None
    n: int = 1
    seed: Optional[int] = None
    # Sampling knobs that change generation output. They are not plumbed
    # through, so they are modeled (not dropped) in order to be rejected with a
    # clear error rather than silently ignored — see serving_chat's
    # unsupported-parameter check.
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    top_k: Optional[int] = None
    logit_bias: Optional[dict[str, float]] = None
    # Output-contract fields: modeled (not dropped) so the ones that can't be
    # honored are rejected rather than answered with an output that violates
    # what was asked.
    response_format: Optional[dict[str, Any]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    parallel_tool_calls: Optional[bool] = None
    # Per-request chat-template controls, e.g. {"enable_thinking": false} for Qwen3.
    chat_template_kwargs: Optional[dict[str, Any]] = None
    # Accepted now so the contract is stable; parsing/enforcement land in M2/M5.
    tools: Optional[list[dict[str, Any]]] = None
    tool_choice: Union[str, dict[str, Any], None] = None
    reasoning_effort: Optional[str] = None

    def resolved_max_tokens(self) -> int:
        """Return the requested completion budget, or -1 when neither field is set.

        ``max_completion_tokens`` takes precedence over ``max_tokens``.
        """
        # Compare against None (not truthiness): an explicit 0 must not be
        # treated as unset. Callers validate positivity; -1 means "unset / auto".
        for limit in (self.max_completion_tokens, self.max_tokens):
            if limit is not None:
                return limit
        return -1


class Usage(BaseModel):
    # Token counters; each defaults to 0.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ResponseMessage(BaseModel):
    # Message inside a non-streaming choice; `role` defaults to "assistant".
    role: str = "assistant"
    content: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None


class Choice(BaseModel):
    # One choice of a non-streaming response; `message` is required.
    index: int = 0
    message: ResponseMessage
    finish_reason: Optional[str] = None


class ChatCompletionResponse(BaseModel):
    # Non-streaming response body; `model` and `choices` are required.
    # A fresh "chatcmpl-..." id is generated per instance unless one is given.
    id: str = Field(default_factory=lambda: _new_id("chatcmpl"))
    object: Literal["chat.completion"] = "chat.completion"
    # Whole-second Unix timestamp taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[Choice]
    usage: Usage = Field(default_factory=Usage)


class DeltaMessage(BaseModel):
    # Delta carried by a streaming chunk choice; every field is optional.
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None


class ChunkChoice(BaseModel):
    # One choice of a streaming chunk; `delta` is required.
    index: int = 0
    delta: DeltaMessage
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    # Streaming chunk body. Unlike ChatCompletionResponse, `id` has no default
    # and must be supplied by the caller.
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Whole-second Unix timestamp taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChunkChoice]
    # None unless the caller attaches usage to this chunk.
    usage: Optional[Usage] = None


class ModelCard(BaseModel):
    # One model entry; `id` is required and `owned_by` defaults to "executorch".
    id: str
    object: Literal["model"] = "model"
    # Whole-second Unix timestamp taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "executorch"


class ModelList(BaseModel):
    # List wrapper around ModelCard entries; `object` is fixed to "list".
    object: Literal["list"] = "list"
    data: list[ModelCard]
5 changes: 5 additions & 0 deletions extension/llm/server/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi>=0.110
uvicorn[standard]>=0.27
pydantic>=2.0
# Optional but recommended for model-correct chat templating (--hf-tokenizer):
# transformers>=4.40
Loading
Loading