Skip to content

Commit 95307e1

Browse files
committed
2 parents 92ccb22 + a3341a2 commit 95307e1

File tree

3 files changed

+60
-23
lines changed

3 files changed

+60
-23
lines changed

eval_protocol/mcp/execution/policy.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
Rewritten to use LiteLLM for unified retry logic, caching, and provider support.
66
"""
77

8-
import asyncio
9-
import json
108
import logging
119
import os
12-
from abc import ABC, abstractmethod
13-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
10+
from typing import Any, Dict, List, Literal, Optional
1411

1512
import litellm
16-
from litellm import acompletion, completion
13+
from litellm import acompletion
14+
from litellm.types.utils import ModelResponse
15+
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
1716
from litellm.caching.caching import Cache
1817
from litellm.caching.dual_cache import DualCache
1918
from litellm.caching.in_memory_cache import InMemoryCache
@@ -194,7 +193,20 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
194193
request_params["tools"] = tools
195194

196195
try:
197-
response = await acompletion(model=self.model_id, **request_params)
196+
if request_params.get("stream") is True:
197+
chunks = []
198+
stream = await acompletion(model=self.model_id, **request_params)
199+
200+
assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"
201+
202+
async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues]
203+
chunks.append(chunk)
204+
response = litellm.stream_chunk_builder(chunks, messages)
205+
else:
206+
response = await acompletion(model=self.model_id, **request_params)
207+
208+
assert response is not None, "Response is None"
209+
assert isinstance(response, ModelResponse), "Response should be ModelResponse"
198210

199211
# Log cache hit/miss for monitoring
200212
hidden = getattr(response, "_hidden_params", {})

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,26 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
223223
try:
224224
self.server.start()
225225

226+
model_id = str(
227+
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
228+
)
229+
temperature = config.completion_params.get("temperature", 0.0)
230+
max_tokens = config.completion_params.get("max_tokens", 4096)
231+
232+
# Pass all other completion_params (e.g. stream=True) via kwargs
233+
other_params = {
234+
k: v
235+
for k, v in (config.completion_params or {}).items()
236+
if k not in ["model", "temperature", "max_tokens", "extra_body"]
237+
}
238+
extra_body = config.completion_params.get("extra_body", {}) or {}
239+
226240
self.policy = ep.LiteLLMPolicy(
227-
model_id=str(
228-
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
229-
),
230-
temperature=config.completion_params.get("temperature", 0.0),
231-
max_tokens=config.completion_params.get("max_tokens", 4096),
232-
**(config.completion_params.get("extra_body", {}) or {}),
241+
model_id=model_id,
242+
temperature=temperature,
243+
max_tokens=max_tokens,
244+
**extra_body,
245+
**other_params,
233246
)
234247

235248
except Exception as e:

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import time
55
from typing import List
66

7+
import litellm
78
from litellm import acompletion
8-
from typing import Dict
9+
from litellm.types.utils import ModelResponse, Choices
10+
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
911

1012
from eval_protocol.dataset_logger import default_logger
1113
from eval_protocol.models import EvaluationRow, Message
@@ -62,12 +64,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
6264
if row.tools is not None:
6365
request_params["tools"] = row.tools
6466

65-
# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
66-
import importlib
67+
if request_params.get("stream") is True:
68+
chunks = []
69+
stream = await acompletion(**request_params)
6770

68-
_litellm = importlib.import_module("litellm")
69-
acompletion = getattr(_litellm, "acompletion")
70-
response = await acompletion(**request_params)
71+
assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"
72+
73+
async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues]
74+
chunks.append(chunk)
75+
response = litellm.stream_chunk_builder(chunks, messages_payload)
76+
else:
77+
response = await acompletion(**request_params)
78+
79+
assert response is not None, "Response is None"
80+
assert isinstance(response, ModelResponse), "Response should be ModelResponse"
81+
assert isinstance(response.choices[0], Choices), "Response choice should be a Choices"
7182

7283
assistant_content = response.choices[0].message.content or ""
7384
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
@@ -110,11 +121,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
110121
tool_calls=converted_tool_calls,
111122
)
112123
]
113-
114-
row.execution_metadata.usage = CompletionUsage(
115-
prompt_tokens=response.usage.prompt_tokens,
116-
completion_tokens=response.usage.completion_tokens,
117-
total_tokens=response.usage.total_tokens,
124+
row.execution_metadata.usage = (
125+
CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
126+
prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue]
127+
completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue]
128+
total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue]
129+
)
118130
)
119131

120132
row.messages = messages

0 commit comments

Comments (0)