Skip to content

Commit 2f044dd

Browse files
committed
feat(server): add chat_template_kwargs model setting
1 parent a83926a commit 2f044dd

File tree

4 files changed

+131
-4
lines changed

4 files changed

+131
-4
lines changed

llama_cpp/server/cli.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import argparse
4+
import json
45

5-
from typing import List, Literal, Union, Any, Type, TypeVar
6+
from typing import List, Literal, Union, Any, Type, TypeVar, Dict
67

78
from pydantic import BaseModel
89

@@ -40,6 +41,17 @@ def _contains_list_type(annotation: Type[Any] | None) -> bool:
4041
return False
4142

4243

44+
def _contains_dict_type(annotation: Type[Any] | None) -> bool:
45+
origin = getattr(annotation, "__origin__", None)
46+
47+
if origin is dict or origin is Dict:
48+
return True
49+
elif origin in (Literal, Union):
50+
return any(_contains_dict_type(arg) for arg in annotation.__args__) # type: ignore
51+
else:
52+
return False
53+
54+
4355
def _parse_bool_arg(arg: str | bytes | bool) -> bool:
4456
if isinstance(arg, bytes):
4557
arg = arg.decode("utf-8")
@@ -57,6 +69,16 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
5769
raise ValueError(f"Invalid boolean argument: {arg}")
5870

5971

72+
def _parse_json_object_arg(arg: str | bytes) -> dict[str, Any]:
73+
if isinstance(arg, bytes):
74+
arg = arg.decode("utf-8")
75+
76+
value = json.loads(arg)
77+
if not isinstance(value, dict):
78+
raise ValueError(f"Invalid JSON object argument: {arg}")
79+
return value
80+
81+
6082
def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
6183
"""Add arguments from a pydantic model to an argparse parser."""
6284

@@ -68,7 +90,15 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel])
6890
_get_base_type(field.annotation) if field.annotation is not None else str
6991
)
7092
list_type = _contains_list_type(field.annotation)
71-
if base_type is not bool:
93+
dict_type = _contains_dict_type(field.annotation)
94+
if dict_type:
95+
parser.add_argument(
96+
f"--{name}",
97+
dest=name,
98+
type=_parse_json_object_arg,
99+
help=description,
100+
)
101+
elif base_type is not bool:
72102
parser.add_argument(
73103
f"--{name}",
74104
dest=name,

llama_cpp/server/model.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44

5-
from typing import Dict, Optional, Union, List
5+
from typing import Any, Dict, Optional, Union, List
66

77
import llama_cpp
88
import llama_cpp.llama_speculative as llama_speculative
@@ -11,6 +11,29 @@
1111
from llama_cpp.server.settings import ModelSettings
1212

1313

14+
def _resolve_chat_handler(
15+
model: llama_cpp.Llama,
16+
) -> llama_cpp.llama_chat_format.LlamaChatCompletionHandler:
17+
chat_handler = (
18+
model.chat_handler
19+
or model._chat_handlers.get(model.chat_format)
20+
or llama_cpp.llama_chat_format.get_chat_completion_handler(model.chat_format)
21+
)
22+
return chat_handler
23+
24+
25+
def _chat_handler_with_kwargs(
26+
chat_handler: llama_cpp.llama_chat_format.LlamaChatCompletionHandler,
27+
chat_template_kwargs: Dict[str, Any],
28+
) -> llama_cpp.llama_chat_format.LlamaChatCompletionHandler:
29+
def handler(*args: Any, **kwargs: Any):
30+
merged_kwargs = dict(chat_template_kwargs)
31+
merged_kwargs.update(kwargs)
32+
return chat_handler(*args, **merged_kwargs)
33+
34+
return handler
35+
36+
1437
class LlamaProxy:
1538
def __init__(self, models: List[ModelSettings]) -> None:
1639
assert len(models) > 0, "No models provided!"
@@ -299,6 +322,10 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
299322
# Misc
300323
verbose=settings.verbose,
301324
)
325+
if settings.chat_template_kwargs:
326+
_model.chat_handler = _chat_handler_with_kwargs(
327+
_resolve_chat_handler(_model), settings.chat_template_kwargs
328+
)
302329
if settings.cache:
303330
if settings.cache_type == "disk":
304331
if settings.verbose:

llama_cpp/server/settings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import multiprocessing
44

5-
from typing import Optional, List, Literal, Union, Dict, cast
5+
from typing import Any, Optional, List, Literal, Union, Dict, cast
66
from typing_extensions import Self
77

88
from pydantic import Field, model_validator
@@ -131,6 +131,10 @@ class ModelSettings(BaseSettings):
131131
default=None,
132132
description="Chat format to use.",
133133
)
134+
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
135+
default=None,
136+
description="Extra keyword arguments forwarded to chat templates at model load time. Matches llama.cpp server `chat_template_kwargs`.",
137+
)
134138
clip_model_path: Optional[str] = Field(
135139
default=None,
136140
description="Path to a CLIP model to use for multi-modal chat completion.",

tests/test_server_model.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import argparse
2+
3+
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
4+
from llama_cpp.server.model import _chat_handler_with_kwargs
5+
from llama_cpp.server.settings import ModelSettings
6+
7+
8+
def test_model_settings_accepts_chat_template_kwargs():
    """ModelSettings should store a chat_template_kwargs mapping verbatim."""
    expected = {
        "enable_thinking": True,
        "reasoning_effort": "low",
    }
    settings = ModelSettings(
        model="test.gguf",
        chat_template_kwargs={
            "enable_thinking": True,
            "reasoning_effort": "low",
        },
    )
    assert settings.chat_template_kwargs == expected
21+
22+
23+
def test_cli_parses_chat_template_kwargs_json():
    """--chat_template_kwargs should accept a JSON object string on the CLI."""
    parser = argparse.ArgumentParser()
    add_args_from_model(parser, ModelSettings)

    argv = [
        "--model",
        "test.gguf",
        "--chat_template_kwargs",
        '{"enable_thinking": true, "reasoning_effort": "low"}',
    ]
    settings = parse_model_from_args(ModelSettings, parser.parse_args(argv))

    assert settings.chat_template_kwargs == {
        "enable_thinking": True,
        "reasoning_effort": "low",
    }
41+
42+
43+
def test_chat_handler_with_kwargs_merges_defaults_and_request_kwargs():
    """Request-time kwargs must override defaults; extra kwargs pass through."""
    captured = {}

    def base_handler(*args, **kwargs):
        captured["args"] = args
        captured["kwargs"] = kwargs
        return "ok"

    defaults = {
        "enable_thinking": True,
        "reasoning_effort": "medium",
    }
    wrapped = _chat_handler_with_kwargs(base_handler, defaults)

    assert wrapped(reasoning_effort="high", extra_flag="x") == "ok"
    # Default survives; request value wins; unrelated kwarg is forwarded.
    assert captured["kwargs"] == {
        "enable_thinking": True,
        "reasoning_effort": "high",
        "extra_flag": "x",
    }

0 commit comments

Comments
 (0)