Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions eval_protocol/mcp/execution/base_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ async def _generate_live_tool_calls(
if message.get("tool_calls"):
assistant_message_for_history["tool_calls"] = message["tool_calls"]

# Preserve specific fields from provider_specific_fields if present
if message.get("provider_specific_fields"):
if message["provider_specific_fields"].get("reasoning_details"):
assistant_message_for_history["reasoning_details"] = message["provider_specific_fields"][
"reasoning_details"
]

# Add to actual conversation history
conversation_history.append(assistant_message_for_history)

Expand Down
16 changes: 8 additions & 8 deletions eval_protocol/mcp/execution/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
Clean messages with only OpenAI API compatible fields
"""
# Standard OpenAI message fields
allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_details"}

clean_messages = []
for msg in messages:
Expand Down Expand Up @@ -217,12 +217,15 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
logger.debug(f"🔄 API call for model: {self.model_id}")

# LiteLLM already returns OpenAI-compatible format
message_obj = getattr(response.choices[0], "message", object())

return {
"choices": [
{
"message": {
"role": getattr(getattr(response.choices[0], "message", object()), "role", "assistant"),
"content": getattr(getattr(response.choices[0], "message", object()), "content", None),
"role": getattr(message_obj, "role", "assistant"),
"content": getattr(message_obj, "content", None),
"provider_specific_fields": getattr(message_obj, "provider_specific_fields", None),
"tool_calls": (
[
{
Expand All @@ -233,12 +236,9 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
},
}
for tc in (
getattr(getattr(response.choices[0], "message", object()), "tool_calls", [])
or []
)
for tc in (getattr(message_obj, "tool_calls", []) or [])
]
if getattr(getattr(response.choices[0], "message", object()), "tool_calls", None)
if getattr(message_obj, "tool_calls", None)
else []
),
},
Expand Down
95 changes: 95 additions & 0 deletions tests/test_litellm_policy_provider_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import types

import pytest

import eval_protocol.mcp.execution.policy as policy_mod
from eval_protocol.mcp.execution.policy import LiteLLMPolicy


@pytest.mark.asyncio
async def test_litellm_policy_surfaces_provider_specific_reasoning_details(monkeypatch):
    """
    Ensure that provider_specific_fields from the LiteLLM message object are
    preserved on the returned message dict from LiteLLMPolicy._make_llm_call.

    Strategy: stub out both the ModelResponse base class (so isinstance-style
    checks in the policy module accept our fake) and `acompletion` (so no
    network call is made), then assert on the shape of the returned dict.
    """

    # Define a fake ModelResponse base class and patch the module's ModelResponse.
    # Use monkeypatch (not direct assignment) so the original attribute is
    # restored at teardown and the patch cannot leak into other tests.
    class FakeModelResponseBase: ...

    monkeypatch.setattr(policy_mod, "ModelResponse", FakeModelResponseBase, raising=False)

    async def fake_acompletion(*args, **kwargs):
        # This mimics the LiteLLM Message object shape we rely on in policy._make_llm_call
        message_obj = types.SimpleNamespace(
            role="assistant",
            content="",
            tool_calls=[
                types.SimpleNamespace(
                    id="tool_get_reservation_details_123",
                    type="function",
                    function=types.SimpleNamespace(
                        name="get_reservation_details",
                        arguments='{"reservation_id":"EHGLP3"}',
                    ),
                )
            ],
            provider_specific_fields={
                "reasoning_details": [{"id": "tool_get_reservation_details_123", "type": "reasoning.encrypted"}],
                "custom_field": "keep_me",
            },
        )

        class FakeModelResponse(FakeModelResponseBase):
            def __init__(self) -> None:
                self.choices = [
                    types.SimpleNamespace(
                        finish_reason="tool_calls",
                        index=0,
                        message=message_obj,
                    )
                ]
                self.usage = types.SimpleNamespace(
                    prompt_tokens=10,
                    completion_tokens=5,
                    total_tokens=15,
                )

        return FakeModelResponse()

    # Patch acompletion so we don't hit the network
    monkeypatch.setattr(policy_mod, "acompletion", fake_acompletion)

    # Use a concrete policy instance; base_url/model_id values don't matter for this unit test
    policy = LiteLLMPolicy(model_id="openrouter/google/gemini-3-pro-preview", use_caching=False)

    messages = [
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "tool_get_reservation_details_123",
                    "type": "function",
                    "function": {"name": "get_reservation_details", "arguments": '{"reservation_id":"EHGLP3"}'},
                }
            ],
        }
    ]

    # No tools are needed for this test – we only care about the returned message shape
    result = await policy._make_llm_call(messages, tools=[])

    assert "choices" in result
    assert len(result["choices"]) == 1
    msg = result["choices"][0]["message"]

    # Core fields should be present
    assert msg["role"] == "assistant"
    assert isinstance(msg.get("tool_calls"), list)

    # provider_specific_fields should be preserved on the message
    ps = msg.get("provider_specific_fields")
    assert isinstance(ps, dict)
    assert ps["reasoning_details"] == [{"id": "tool_get_reservation_details_123", "type": "reasoning.encrypted"}]
    # Non-core provider_specific_fields should also be preserved
    assert ps.get("custom_field") == "keep_me"
Loading