diff --git a/eval_protocol/mcp/execution/policy.py b/eval_protocol/mcp/execution/policy.py index 1adb9b95..777c4f7e 100644 --- a/eval_protocol/mcp/execution/policy.py +++ b/eval_protocol/mcp/execution/policy.py @@ -5,15 +5,14 @@ Rewritten to use LiteLLM for unified retry logic, caching, and provider support. """ -import asyncio -import json import logging import os -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional import litellm -from litellm import acompletion, completion +from litellm import acompletion +from litellm.types.utils import ModelResponse +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper from litellm.caching.caching import Cache from litellm.caching.dual_cache import DualCache from litellm.caching.in_memory_cache import InMemoryCache @@ -194,7 +193,20 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[ request_params["tools"] = tools try: - response = await acompletion(model=self.model_id, **request_params) + if request_params.get("stream") is True: + chunks = [] + stream = await acompletion(model=self.model_id, **request_params) + + assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper" + + async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues] + chunks.append(chunk) + response = litellm.stream_chunk_builder(chunks, messages) + else: + response = await acompletion(model=self.model_id, **request_params) + + assert response is not None, "Response is None" + assert isinstance(response, ModelResponse), "Response should be ModelResponse" # Log cache hit/miss for monitoring hidden = getattr(response, "_hidden_params", {}) diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 9173e6f9..2d01b6c1 100644 --- 
a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -223,13 +223,26 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> try: self.server.start() + model_id = str( + (config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini" + ) + temperature = (config.completion_params or {}).get("temperature", 0.0) + max_tokens = (config.completion_params or {}).get("max_tokens", 4096) + + # Pass all other completion_params (e.g. stream=True) via kwargs + other_params = { + k: v + for k, v in (config.completion_params or {}).items() + if k not in ["model", "temperature", "max_tokens", "extra_body"] + } + extra_body = (config.completion_params or {}).get("extra_body", {}) or {} + self.policy = ep.LiteLLMPolicy( - model_id=str( - (config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini" - ), - temperature=config.completion_params.get("temperature", 0.0), - max_tokens=config.completion_params.get("max_tokens", 4096), - **(config.completion_params.get("extra_body", {}) or {}), + model_id=model_id, + temperature=temperature, + max_tokens=max_tokens, + **extra_body, + **other_params, ) except Exception as e: diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 2b4bf893..d98ab042 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -4,8 +4,10 @@ import time from typing import List +import litellm from litellm import acompletion -from typing import Dict +from litellm.types.utils import ModelResponse, Choices +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper from eval_protocol.dataset_logger import default_logger from eval_protocol.models import EvaluationRow, Message @@ -62,12 +64,21 @@ async def process_row(row: EvaluationRow) ->
EvaluationRow: if row.tools is not None: request_params["tools"] = row.tools - # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet - import importlib + if request_params.get("stream") is True: + chunks = [] + stream = await acompletion(**request_params) - _litellm = importlib.import_module("litellm") - acompletion = getattr(_litellm, "acompletion") - response = await acompletion(**request_params) + assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper" + + async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues] + chunks.append(chunk) + response = litellm.stream_chunk_builder(chunks, messages_payload) + else: + response = await acompletion(**request_params) + + assert response is not None, "Response is None" + assert isinstance(response, ModelResponse), "Response should be ModelResponse" + assert isinstance(response.choices[0], Choices), "Response choice should be a Choices" assistant_content = response.choices[0].message.content or "" tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None @@ -110,11 +121,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: tool_calls=converted_tool_calls, ) ] - - row.execution_metadata.usage = CompletionUsage( - prompt_tokens=response.usage.prompt_tokens, - completion_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, + row.execution_metadata.usage = ( + CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field + prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue] + completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue] + total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue] + ) ) row.messages = messages