Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions eval_protocol/mcp/execution/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
Rewritten to use LiteLLM for unified retry logic, caching, and provider support.
"""

import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional

import litellm
from litellm import acompletion, completion
from litellm import acompletion
from litellm.types.utils import ModelResponse
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.caching.caching import Cache
from litellm.caching.dual_cache import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
Expand Down Expand Up @@ -194,7 +193,20 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
request_params["tools"] = tools

try:
response = await acompletion(model=self.model_id, **request_params)
if request_params.get("stream") is True:
chunks = []
stream = await acompletion(model=self.model_id, **request_params)

assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"

async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues]
chunks.append(chunk)
response = litellm.stream_chunk_builder(chunks, messages)
else:
response = await acompletion(model=self.model_id, **request_params)

assert response is not None, "Response is None"
assert isinstance(response, ModelResponse), "Response should be ModelResponse"

# Log cache hit/miss for monitoring
hidden = getattr(response, "_hidden_params", {})
Expand Down
25 changes: 19 additions & 6 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,26 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
try:
self.server.start()

model_id = str(
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
)
temperature = config.completion_params.get("temperature", 0.0)
max_tokens = config.completion_params.get("max_tokens", 4096)

# Pass all other completion_params (e.g. stream=True) via kwargs
other_params = {
k: v
for k, v in (config.completion_params or {}).items()
if k not in ["model", "temperature", "max_tokens", "extra_body"]
}
extra_body = config.completion_params.get("extra_body", {}) or {}

self.policy = ep.LiteLLMPolicy(
model_id=str(
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
),
temperature=config.completion_params.get("temperature", 0.0),
max_tokens=config.completion_params.get("max_tokens", 4096),
**(config.completion_params.get("extra_body", {}) or {}),
model_id=model_id,
temperature=temperature,
max_tokens=max_tokens,
**extra_body,
**other_params,
)

except Exception as e:
Expand Down
34 changes: 23 additions & 11 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import time
from typing import List

import litellm
from litellm import acompletion
from typing import Dict
from litellm.types.utils import ModelResponse, Choices
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
Expand Down Expand Up @@ -62,12 +64,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
if row.tools is not None:
request_params["tools"] = row.tools

# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
import importlib
if request_params.get("stream") is True:
chunks = []
stream = await acompletion(**request_params)

_litellm = importlib.import_module("litellm")
acompletion = getattr(_litellm, "acompletion")
response = await acompletion(**request_params)
assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"

async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues]
chunks.append(chunk)
response = litellm.stream_chunk_builder(chunks, messages_payload)
else:
response = await acompletion(**request_params)

assert response is not None, "Response is None"
assert isinstance(response, ModelResponse), "Response should be ModelResponse"
assert isinstance(response.choices[0], Choices), "Response choice should be a Choices"

assistant_content = response.choices[0].message.content or ""
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
Expand Down Expand Up @@ -110,11 +121,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
tool_calls=converted_tool_calls,
)
]

row.execution_metadata.usage = CompletionUsage(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
row.execution_metadata.usage = (
CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue]
completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue]
total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue]
)
)

row.messages = messages
Expand Down
Loading