Skip to content

Commit e0923cf

Browse files
authored
allow reasoning details to pass through (#355)
* allow reasoning details to pass through
* update
* address comments
* make generic
* update
* more generic storing of provider_specific_fields
* add a test
* put message back in dict
1 parent 01bc8e9 commit e0923cf

File tree

3 files changed

+110
-8
lines changed

3 files changed

+110
-8
lines changed

eval_protocol/mcp/execution/base_policy.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ async def _generate_live_tool_calls(
199199
if message.get("tool_calls"):
200200
assistant_message_for_history["tool_calls"] = message["tool_calls"]
201201

202+
# Preserve specific fields from provider_specific_fields if present
203+
if message.get("provider_specific_fields"):
204+
if message["provider_specific_fields"].get("reasoning_details"):
205+
assistant_message_for_history["reasoning_details"] = message["provider_specific_fields"][
206+
"reasoning_details"
207+
]
208+
202209
# Add to actual conversation history
203210
conversation_history.append(assistant_message_for_history)
204211

eval_protocol/mcp/execution/policy.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
146146
Clean messages with only OpenAI API compatible fields
147147
"""
148148
# Standard OpenAI message fields
149-
allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
149+
allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_details"}
150150

151151
clean_messages = []
152152
for msg in messages:
@@ -217,12 +217,15 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
217217
logger.debug(f"🔄 API call for model: {self.model_id}")
218218

219219
# LiteLLM already returns OpenAI-compatible format
220+
message_obj = getattr(response.choices[0], "message", object())
221+
220222
return {
221223
"choices": [
222224
{
223225
"message": {
224-
"role": getattr(getattr(response.choices[0], "message", object()), "role", "assistant"),
225-
"content": getattr(getattr(response.choices[0], "message", object()), "content", None),
226+
"role": getattr(message_obj, "role", "assistant"),
227+
"content": getattr(message_obj, "content", None),
228+
"provider_specific_fields": getattr(message_obj, "provider_specific_fields", None),
226229
"tool_calls": (
227230
[
228231
{
@@ -233,12 +236,9 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
233236
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
234237
},
235238
}
236-
for tc in (
237-
getattr(getattr(response.choices[0], "message", object()), "tool_calls", [])
238-
or []
239-
)
239+
for tc in (getattr(message_obj, "tool_calls", []) or [])
240240
]
241-
if getattr(getattr(response.choices[0], "message", object()), "tool_calls", None)
241+
if getattr(message_obj, "tool_calls", None)
242242
else []
243243
),
244244
},
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import types
2+
3+
import pytest
4+
5+
import eval_protocol.mcp.execution.policy as policy_mod
6+
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_litellm_policy_surfaces_provider_specific_reasoning_details(monkeypatch):
11+
"""
12+
Ensure that provider_specific_fields from the LiteLLM message object are
13+
preserved on the returned message dict from LiteLLMPolicy._make_llm_call.
14+
"""
15+
16+
# Define a fake ModelResponse base class and patch the module's ModelResponse
17+
class FakeModelResponseBase: ...
18+
19+
policy_mod.ModelResponse = FakeModelResponseBase
20+
21+
async def fake_acompletion(*args, **kwargs):
22+
# This mimics the LiteLLM Message object shape we rely on in policy._make_llm_call
23+
message_obj = types.SimpleNamespace(
24+
role="assistant",
25+
content="",
26+
tool_calls=[
27+
types.SimpleNamespace(
28+
id="tool_get_reservation_details_123",
29+
type="function",
30+
function=types.SimpleNamespace(
31+
name="get_reservation_details",
32+
arguments='{"reservation_id":"EHGLP3"}',
33+
),
34+
)
35+
],
36+
provider_specific_fields={
37+
"reasoning_details": [{"id": "tool_get_reservation_details_123", "type": "reasoning.encrypted"}],
38+
"custom_field": "keep_me",
39+
},
40+
)
41+
42+
class FakeModelResponse(FakeModelResponseBase):
43+
def __init__(self) -> None:
44+
self.choices = [
45+
types.SimpleNamespace(
46+
finish_reason="tool_calls",
47+
index=0,
48+
message=message_obj,
49+
)
50+
]
51+
self.usage = types.SimpleNamespace(
52+
prompt_tokens=10,
53+
completion_tokens=5,
54+
total_tokens=15,
55+
)
56+
57+
return FakeModelResponse()
58+
59+
# Patch acompletion so we don't hit the network
60+
monkeypatch.setattr(policy_mod, "acompletion", fake_acompletion)
61+
62+
# Use a concrete policy instance; base_url/model_id values don't matter for this unit test
63+
policy = LiteLLMPolicy(model_id="openrouter/google/gemini-3-pro-preview", use_caching=False)
64+
65+
messages = [
66+
{
67+
"role": "assistant",
68+
"content": "",
69+
"tool_calls": [
70+
{
71+
"id": "tool_get_reservation_details_123",
72+
"type": "function",
73+
"function": {"name": "get_reservation_details", "arguments": '{"reservation_id":"EHGLP3"}'},
74+
}
75+
],
76+
}
77+
]
78+
79+
# No tools are needed for this test – we only care about the returned message shape
80+
result = await policy._make_llm_call(messages, tools=[])
81+
82+
assert "choices" in result
83+
assert len(result["choices"]) == 1
84+
msg = result["choices"][0]["message"]
85+
86+
# Core fields should be present
87+
assert msg["role"] == "assistant"
88+
assert isinstance(msg.get("tool_calls"), list)
89+
90+
# provider_specific_fields should be preserved on the message
91+
ps = msg.get("provider_specific_fields")
92+
assert isinstance(ps, dict)
93+
assert ps["reasoning_details"] == [{"id": "tool_get_reservation_details_123", "type": "reasoning.encrypted"}]
94+
# Non-core provider_specific_fields should also be preserved
95+
assert ps.get("custom_field") == "keep_me"

0 commit comments

Comments
 (0)