diff --git a/eval_protocol/mcp/execution/base_policy.py b/eval_protocol/mcp/execution/base_policy.py
index bdced48a..9d44d02c 100644
--- a/eval_protocol/mcp/execution/base_policy.py
+++ b/eval_protocol/mcp/execution/base_policy.py
@@ -199,6 +199,13 @@ async def _generate_live_tool_calls(
             if message.get("tool_calls"):
                 assistant_message_for_history["tool_calls"] = message["tool_calls"]
 
+            # Preserve specific fields from provider_specific_fields if present
+            if message.get("provider_specific_fields"):
+                if message["provider_specific_fields"].get("reasoning_details"):
+                    assistant_message_for_history["reasoning_details"] = message["provider_specific_fields"][
+                        "reasoning_details"
+                    ]
+
             # Add to actual conversation history
             conversation_history.append(assistant_message_for_history)
 
diff --git a/eval_protocol/mcp/execution/policy.py b/eval_protocol/mcp/execution/policy.py
index 777c4f7e..0b4aac4e 100644
--- a/eval_protocol/mcp/execution/policy.py
+++ b/eval_protocol/mcp/execution/policy.py
@@ -146,7 +146,7 @@ def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
         Clean messages with only OpenAI API compatible fields
         """
         # Standard OpenAI message fields
-        allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
+        allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_details"}
 
         clean_messages = []
         for msg in messages:
@@ -217,12 +217,15 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
             logger.debug(f"🔄 API call for model: {self.model_id}")
 
             # LiteLLM already returns OpenAI-compatible format
+            message_obj = getattr(response.choices[0], "message", object())
+
             return {
                 "choices": [
                     {
                         "message": {
-                            "role": getattr(getattr(response.choices[0], "message", object()), "role", "assistant"),
-                            "content": getattr(getattr(response.choices[0], "message", object()), "content", None),
+                            "role": getattr(message_obj, "role", "assistant"),
+                            "content": getattr(message_obj, "content", None),
+                            "provider_specific_fields": getattr(message_obj, "provider_specific_fields", None),
                             "tool_calls": (
                                 [
                                     {
@@ -233,12 +236,9 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
                                             "arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
                                         },
                                     }
-                                    for tc in (
-                                        getattr(getattr(response.choices[0], "message", object()), "tool_calls", [])
-                                        or []
-                                    )
+                                    for tc in (getattr(message_obj, "tool_calls", []) or [])
                                 ]
-                                if getattr(getattr(response.choices[0], "message", object()), "tool_calls", None)
+                                if getattr(message_obj, "tool_calls", None)
                                 else []
                             ),
                         },
diff --git a/tests/test_litellm_policy_provider_fields.py b/tests/test_litellm_policy_provider_fields.py
new file mode 100644
index 00000000..6812ef9b
--- /dev/null
+++ b/tests/test_litellm_policy_provider_fields.py
@@ -0,0 +1,95 @@
+import types
+
+import pytest
+
+import eval_protocol.mcp.execution.policy as policy_mod
+from eval_protocol.mcp.execution.policy import LiteLLMPolicy
+
+
+@pytest.mark.asyncio
+async def test_litellm_policy_surfaces_provider_specific_reasoning_details(monkeypatch):
+    """
+    Ensure that provider_specific_fields from the LiteLLM message object are
+    preserved on the returned message dict from LiteLLMPolicy._make_llm_call.
+    """
+
+    # Define a fake ModelResponse base class and patch the module's ModelResponse
+    class FakeModelResponseBase: ...
+
+    policy_mod.ModelResponse = FakeModelResponseBase
+
+    async def fake_acompletion(*args, **kwargs):
+        # This mimics the LiteLLM Message object shape we rely on in policy._make_llm_call
+        message_obj = types.SimpleNamespace(
+            role="assistant",
+            content="",
+            tool_calls=[
+                types.SimpleNamespace(
+                    id="tool_get_reservation_details_123",
+                    type="function",
+                    function=types.SimpleNamespace(
+                        name="get_reservation_details",
+                        arguments='{"reservation_id":"EHGLP3"}',
+                    ),
+                )
+            ],
+            provider_specific_fields={
+                "reasoning_details": [{"id": "tool_get_reservation_details_123", "type": "reasoning.encrypted"}],
+                "custom_field": "keep_me",
+            },
+        )
+
+        class FakeModelResponse(FakeModelResponseBase):
+            def __init__(self) -> None:
+                self.choices = [
+                    types.SimpleNamespace(
+                        finish_reason="tool_calls",
+                        index=0,
+                        message=message_obj,
+                    )
+                ]
+                self.usage = types.SimpleNamespace(
+                    prompt_tokens=10,
+                    completion_tokens=5,
+                    total_tokens=15,
+                )
+
+        return FakeModelResponse()
+
+    # Patch acompletion so we don't hit the network
+    monkeypatch.setattr(policy_mod, "acompletion", fake_acompletion)
+
+    # Use a concrete policy instance; base_url/model_id values don't matter for this unit test
+    policy = LiteLLMPolicy(model_id="openrouter/google/gemini-3-pro-preview", use_caching=False)
+
+    messages = [
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [
+                {
+                    "id": "tool_get_reservation_details_123",
+                    "type": "function",
+                    "function": {"name": "get_reservation_details", "arguments": '{"reservation_id":"EHGLP3"}'},
+                }
+            ],
+        }
+    ]
+
+    # No tools are needed for this test – we only care about the returned message shape
+    result = await policy._make_llm_call(messages, tools=[])
+
+    assert "choices" in result
+    assert len(result["choices"]) == 1
+    msg = result["choices"][0]["message"]
+
+    # Core fields should be present
+    assert msg["role"] == "assistant"
+    assert isinstance(msg.get("tool_calls"), list)
+
+    # provider_specific_fields should be preserved on the message
+    ps = msg.get("provider_specific_fields")
+    assert isinstance(ps, dict)
+    assert ps["reasoning_details"] == [{"id": "tool_get_reservation_details_123", "type": "reasoning.encrypted"}]
+    # Non-core provider_specific_fields should also be preserved
+    assert ps.get("custom_field") == "keep_me"