Skip to content

Commit 4cbfa15

Browse files
committed
chore: Refactor message handling and enhance tokenization logic
- Updated imports for clarity and consistency in `types.py`. - Added `drop_zero_advantage_trajectories` parameter to `tokenize_trajectory_groups` function in `tokenize.py` to control trajectory filtering. - Introduced `_normalize_message_or_choice` function in `server.py` to standardize message validation and conversion. - Enhanced `_message_or_choice_to_dict` function in `client.py` to utilize the new message adapter for improved validation.
1 parent dec6b3a commit 4cbfa15

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

src/art/preprocessing/tokenize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def tokenize_trajectory_groups(
155155
allow_training_without_logprobs: bool,
156156
scale_rewards: bool,
157157
shuffle_group_trajectories: bool = True,
158+
drop_zero_advantage_trajectories: bool = True,
158159
image_processor: BaseImageProcessor | None = None,
159160
) -> Generator["TokenizedResult", None, None]:
160161
for group in trajectory_groups:
@@ -172,8 +173,7 @@ def tokenize_trajectory_groups(
172173
advantage = trajectory.reward - reward_mean
173174
if scale_rewards:
174175
advantage /= reward_std + 1e-6
175-
# Skip trajectories with no advantage
176-
if advantage == 0:
176+
if advantage == 0 and drop_zero_advantage_trajectories:
177177
continue
178178
trajectory_results: list[TokenizedResult] = []
179179
for history in [

src/art/tinker/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,23 @@
1212
from openai.resources.models import AsyncModels
1313
from openai.types import Model
1414
from openai.types.chat.chat_completion import Choice
15+
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
1516
from openai.types.completion_usage import CompletionUsage
17+
from pydantic import TypeAdapter
1618

1719
from art.types import Message, MessageOrChoice, MessagesAndChoices, Tools
1820

1921
ParsedMessageOrChoice = Choice | Message
2022
ParsedMessagesAndChoices = list[ParsedMessageOrChoice]
23+
_MESSAGE_ADAPTER = TypeAdapter(ChatCompletionMessageParam)
2124

2225

2326
def _message_or_choice_to_dict(message_or_choice: MessageOrChoice) -> dict[str, Any]:
2427
if isinstance(message_or_choice, dict):
25-
return cast(dict[str, Any], message_or_choice)
28+
validated = _MESSAGE_ADAPTER.validate_python(message_or_choice)
29+
return cast(
30+
dict[str, Any], _MESSAGE_ADAPTER.dump_python(validated, mode="json")
31+
)
2632
if isinstance(message_or_choice, BaseModel):
2733
return cast(dict[str, Any], message_or_choice.to_dict())
2834
to_dict = getattr(message_or_choice, "to_dict", None)

src/art/tinker/server.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from openai.types.chat.completion_create_params import CompletionCreateParams
2828
from openai.types.completion_usage import CompletionUsage
29-
from pydantic import BaseModel, Field, SkipValidation
29+
from pydantic import BaseModel, Field, SkipValidation, TypeAdapter
3030
import tinker
3131
from transformers.tokenization_utils_base import BatchEncoding
3232
import uvicorn
@@ -49,6 +49,7 @@ class ModelUpsert(BaseModel):
4949

5050

5151
WireMessagesAndChoices = list[Choice | Message]
52+
_MESSAGE_ADAPTER = TypeAdapter(ChatCompletionMessageParam)
5253

5354

5455
class MessagesAndChoicesWithLogprobsArgs(BaseModel):
@@ -63,6 +64,14 @@ class MessagesAndChoicesWithLogprobs(BaseModel):
6364
usages: list[CompletionUsage]
6465

6566

67+
def _normalize_message_or_choice(
68+
message_or_choice: Choice | Message,
69+
) -> Choice | Message:
70+
if isinstance(message_or_choice, Choice):
71+
return message_or_choice
72+
return cast(Message, _MESSAGE_ADAPTER.validate_python(message_or_choice))
73+
74+
6675
def _normalize_qwen3_5_messages(
6776
base_model: str, messages: list[ChatCompletionMessageParam]
6877
) -> list[dict[str, Any]]:
@@ -264,7 +273,10 @@ async def add_logprobs(model: str, alias: str | None) -> CompletionUsage:
264273
]
265274
)
266275
return MessagesAndChoicesWithLogprobs(
267-
messages_and_choices=args.messages_and_choices,
276+
messages_and_choices=[
277+
_normalize_message_or_choice(item)
278+
for item in args.messages_and_choices
279+
],
268280
usages=usages,
269281
)
270282

src/art/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass, field
22
from typing import Annotated, Literal
33

4-
from openai.types.chat.chat_completion import Choice
4+
from openai.types.chat.chat_completion import Choice as Choice
55
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
66
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
77
import pydantic

0 commit comments

Comments
 (0)