Skip to content

Commit 7ffa8b4

Browse files
authored
fix: Qwen3.5 tool-call chat-template tokenization (#634)
1 parent 1905677 commit 7ffa8b4

File tree

2 files changed

+136
-3
lines changed

2 files changed

+136
-3
lines changed

src/art/preprocessing/tokenize.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass, field
33
from functools import cached_property
44
from itertools import takewhile
5+
import json
56
import math
67
import random
78
from typing import Any, Generator, cast
@@ -31,6 +32,40 @@ def _normalize_tools_for_chat_template(tools: Any) -> list[ChatTemplateTool] | N
3132
return normalized_tools
3233

3334

35+
def _normalize_tool_call_arguments_for_chat_template(
36+
tokenizer: PreTrainedTokenizerBase,
37+
messages: list[dict[str, Any]],
38+
) -> list[dict[str, Any]]:
39+
chat_template = tokenizer.chat_template
40+
assert isinstance(chat_template, str)
41+
if "tool_call.arguments|items" not in chat_template:
42+
return messages
43+
44+
normalized_messages: list[dict[str, Any]] = []
45+
for message in messages:
46+
tool_calls = message.get("tool_calls")
47+
if tool_calls is None:
48+
normalized_messages.append(message)
49+
continue
50+
51+
assert isinstance(tool_calls, list)
52+
normalized_tool_calls = []
53+
for tool_call in tool_calls:
54+
assert isinstance(tool_call, dict)
55+
function = tool_call["function"]
56+
assert isinstance(function, dict)
57+
arguments_json = function["arguments"]
58+
assert isinstance(arguments_json, str)
59+
arguments = json.loads(arguments_json)
60+
assert isinstance(arguments, dict)
61+
normalized_tool_calls.append(
62+
{**tool_call, "function": {**function, "arguments": arguments}}
63+
)
64+
normalized_messages.append({**message, "tool_calls": normalized_tool_calls})
65+
66+
return normalized_messages
67+
68+
3469
@dataclass
3570
class TokenizedResult:
3671
advantage: float
@@ -223,20 +258,23 @@ def tokenize_trajectory(
223258
if last_assistant_index == -1:
224259
return None
225260
messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
226-
messages = get_messages(messages_and_choices)
261+
messages = cast(list[dict[str, Any]], get_messages(messages_and_choices))
262+
# Qwen3.5's chat template uses `tool_call.arguments|items`, so it needs a
263+
# mapping here instead of the OpenAI JSON string.
264+
messages = _normalize_tool_call_arguments_for_chat_template(tokenizer, messages)
227265
tools = _normalize_tools_for_chat_template(history.tools)
228266
chat = cast(
229267
str,
230268
tokenizer.apply_chat_template(
231-
cast(list[dict], messages),
269+
messages,
232270
tools=tools,
233271
continue_final_message=True,
234272
tokenize=False,
235273
),
236274
)
237275
original_token_ids = _apply_chat_template_token_ids(
238276
tokenizer,
239-
cast(list[dict[str, Any]], messages),
277+
messages,
240278
tools=tools,
241279
continue_final_message=True,
242280
)

tests/unit/test_preprocessing_tokenize.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import types
33
from typing import cast
44

5+
from openai.types.chat.chat_completion import Choice
56
import pytest
67
from transformers.tokenization_utils_base import BatchEncoding
78

@@ -15,6 +16,7 @@
1516

1617

1718
class _FakeTokenizer:
19+
chat_template = ""
1820
vocab_size = 256
1921
eos_token = "\x00"
2022
eos_token_id = 0
@@ -60,6 +62,38 @@ def convert_tokens_to_ids(self, tokens):
6062
return self.eos_token_id
6163

6264

65+
class _Qwen3_5FakeTokenizer(_FakeTokenizer):
    """Fake tokenizer mimicking Qwen3.5's mapping-based tool-call template.

    Its ``chat_template`` carries the ``tool_call.arguments|items``
    construct, and ``apply_chat_template`` asserts that every tool call's
    ``arguments`` already arrives as a mapping before delegating to the
    base fake tokenizer.
    """

    chat_template = (
        "{% for args_name, args_value in tool_call.arguments|items %}{% endfor %}"
    )

    def apply_chat_template(
        self,
        messages,
        tools=None,
        tokenize=True,
        return_dict=None,
        **kwargs,
    ):
        del kwargs
        # The real Qwen3.5 template would fail on string arguments, so
        # enforce the mapping shape here to make the tests meaningful.
        for entry in messages:
            calls = entry.get("tool_calls")
            if calls is None:
                continue
            assert isinstance(calls, list)
            for call in calls:
                assert isinstance(call, dict)
                fn = call["function"]
                assert isinstance(fn, dict)
                assert isinstance(fn["arguments"], dict)
        return super().apply_chat_template(
            messages,
            tools=tools,
            tokenize=tokenize,
            return_dict=return_dict,
        )
6397
def test_tokenize_trajectory_accepts_batchencoding_chat_template_output() -> None:
6498
tokenizer = _FakeTokenizer()
6599
messages = cast(
@@ -143,3 +177,64 @@ def _labels_fn(batch):
143177
[1] * len(expected_ids)
144178
]
145179
assert batch.num_trainable_tokens == len(expected_ids)
180+
181+
182+
def test_tokenize_trajectory_normalizes_mapping_tool_arguments_for_chat_template() -> (
    None
):
    """Tool-call arguments given as OpenAI JSON strings are decoded to
    mappings before a Qwen3.5-style chat template renders them."""
    tokenizer = _Qwen3_5FakeTokenizer()
    # Assistant choice whose tool call carries JSON-string arguments; the
    # fake tokenizer asserts they arrive as a dict (normalized).
    choice = Choice.model_validate(
        {
            "finish_reason": "stop",
            "index": 0,
            "logprobs": {
                "content": [
                    {
                        "token": "token_id:65",
                        "bytes": [65],
                        "logprob": -0.1,
                        "top_logprobs": [],
                    }
                ],
                "refusal": None,
            },
            "message": {
                "content": "",
                "refusal": None,
                "role": "assistant",
                "annotations": None,
                "audio": None,
                "function_call": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "function": {
                            # OpenAI wire format: arguments as a JSON string.
                            "arguments": '{"city": "San Francisco", "days": 3}',
                            "name": "lookup_weather",
                        },
                        "type": "function",
                    }
                ],
            },
        }
    )
    messages = cast(
        MessagesAndChoices,
        [
            {"role": "user", "content": "Weather?"},
            choice,
        ],
    )
    history = History(messages_and_choices=messages)
    trajectory = Trajectory(messages_and_choices=messages, reward=1.0)

    result = tokenize_trajectory(
        tokenizer=tokenizer,  # type: ignore[arg-type]
        image_processor=None,
        history=history,
        advantage=1.0,
        allow_training_without_logprobs=False,
        trajectory=trajectory,
    )

    # The fake tokenizer raises inside apply_chat_template if normalization
    # did not happen; a non-None result means tokenization succeeded.
    assert result is not None

0 commit comments

Comments
 (0)