Skip to content

Commit 5adf5f2

Browse files
authored
filter fields in rollout (#309)
1 parent e15d855 commit 5adf5f2

File tree

4 files changed

+78
-4
lines changed

4 files changed

+78
-4
lines changed

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,11 @@ async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPart
133133

134134
async def _call_model(self, messages: list[Message], tools: Optional[List[dict[str, Any]]]) -> Message:
135135
# Convert Message models to plain dicts for LLM call
136+
# Filter out fields that are not supported by OpenAI/LiteLLM APIs (e.g., weight, control_plane_step, reasoning_content)
136137
messages_payload: List[Dict[str, Any]] = [
137-
message.model_dump() if hasattr(message, "model_dump") else message # type: ignore[misc]
138+
message.dump_mdoel_for_chat_completion_request()
139+
if hasattr(message, "dump_mdoel_for_chat_completion_request")
140+
else (message.model_dump() if hasattr(message, "model_dump") else message) # type: ignore[misc]
138141
for message in messages
139142
]
140143
# Normalize tool definitions into OpenAI-compatible dicts

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
4848
while messages_for_request and messages_for_request[-1].role == "assistant":
4949
messages_for_request.pop()
5050

51-
messages_payload = [message.model_dump() for message in messages_for_request]
51+
# Filter out fields that are not supported by OpenAI/LiteLLM APIs (e.g., weight, control_plane_step, reasoning_content)
52+
# Use the Message class method that excludes unsupported fields
53+
messages_payload = [message.dump_mdoel_for_chat_completion_request() for message in messages_for_request]
5254

5355
request_params = {"messages": messages_payload, **config.completion_params}
5456
# Ensure caching is disabled only for this request (review feedback)

eval_protocol/pytest/tracing_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,14 @@ def build_init_request(
101101
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
102102

103103
# Strip non-OpenAI fields from messages
104-
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
104+
# Use dump_mdoel_for_chat_completion_request() to automatically exclude unsupported fields (weight, control_plane_step, reasoning_content)
105105
clean_messages = []
106106
for m in row.messages:
107107
md: Dict[str, Any]
108-
if hasattr(m, "model_dump"):
108+
if hasattr(m, "dump_mdoel_for_chat_completion_request"):
109+
# Use the Message method that automatically filters unsupported fields
110+
md = m.dump_mdoel_for_chat_completion_request()
111+
elif hasattr(m, "model_dump"):
109112
md = m.model_dump()
110113
elif isinstance(m, dict):
111114
md = m
@@ -118,6 +121,8 @@ def build_init_request(
118121
"tool_call_id": getattr(m, "tool_call_id", None),
119122
"name": getattr(m, "name", None),
120123
}
124+
# Additional filtering to ensure only allowed fields are kept (already handled by dump_mdoel_for_chat_completion_request for Message objects)
125+
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
121126
clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None})
122127

123128
# Build final model base URL with tracing metadata
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Test to verify that message fields are properly filtered before sending to API.
3+
4+
This test verifies that unsupported fields like 'weight', 'control_plane_step',
5+
and 'reasoning_content' are excluded from messages when preparing API requests.
6+
"""
7+
8+
from eval_protocol.models import Message
9+
10+
11+
def test_dump_model_excludes_unsupported_fields():
12+
"""Test that dump_mdoel_for_chat_completion_request excludes unsupported fields."""
13+
# Create a message with all possible fields including unsupported ones
14+
message = Message(
15+
role="user",
16+
content="Hello",
17+
weight=0,
18+
control_plane_step={"step": 1},
19+
reasoning_content="Some reasoning",
20+
name="test_user",
21+
)
22+
23+
# Get the filtered dictionary
24+
filtered = message.dump_mdoel_for_chat_completion_request()
25+
26+
# Verify unsupported fields are excluded
27+
assert "weight" not in filtered, "weight field should be excluded"
28+
assert "control_plane_step" not in filtered, "control_plane_step field should be excluded"
29+
assert "reasoning_content" not in filtered, "reasoning_content field should be excluded"
30+
31+
# Verify supported fields are included
32+
assert "role" in filtered, "role field should be included"
33+
assert "content" in filtered, "content field should be included"
34+
assert filtered["role"] == "user"
35+
assert filtered["content"] == "Hello"
36+
37+
# Verify name is included (it's a supported field for tool calls)
38+
assert "name" in filtered
39+
assert filtered["name"] == "test_user"
40+
41+
42+
def test_dump_model_with_only_supported_fields():
43+
"""Test that supported fields are preserved."""
44+
message = Message(
45+
role="assistant",
46+
content="I can help you",
47+
tool_calls=None,
48+
tool_call_id=None,
49+
)
50+
51+
filtered = message.dump_mdoel_for_chat_completion_request()
52+
53+
# Should only contain supported fields
54+
assert filtered["role"] == "assistant"
55+
assert filtered["content"] == "I can help you"
56+
57+
# Should not contain unsupported fields even if None
58+
assert "weight" not in filtered
59+
60+
61+
if __name__ == "__main__":
62+
import pytest
63+
64+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)