Skip to content

Commit 1221842

Browse files
committed
enable other stream
1 parent b4d149c commit 1221842

File tree

2 files changed

+37
-12
lines changed


eval_protocol/mcp/execution/policy.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
Rewritten to use LiteLLM for unified retry logic, caching, and provider support.
66
"""
77

8-
import asyncio
9-
import json
108
import logging
119
import os
12-
from abc import ABC, abstractmethod
13-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
10+
from typing import Any, Dict, List, Literal, Optional
1411

1512
import litellm
16-
from litellm import acompletion, completion
13+
from litellm import acompletion
14+
from litellm.types.utils import ModelResponse
15+
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
1716
from litellm.caching.caching import Cache
1817
from litellm.caching.dual_cache import DualCache
1918
from litellm.caching.in_memory_cache import InMemoryCache
@@ -194,7 +193,20 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
194193
request_params["tools"] = tools
195194

196195
try:
197-
response = await acompletion(model=self.model_id, **request_params)
196+
if request_params.get("stream") is True:
197+
chunks = []
198+
stream = await acompletion(model=self.model_id, **request_params)
199+
200+
assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"
201+
202+
async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues]
203+
chunks.append(chunk)
204+
response = litellm.stream_chunk_builder(chunks, messages)
205+
else:
206+
response = await acompletion(model=self.model_id, **request_params)
207+
208+
assert response is not None, "Response is None"
209+
assert isinstance(response, ModelResponse), "Response should be ModelResponse"
198210

199211
# Log cache hit/miss for monitoring
200212
hidden = getattr(response, "_hidden_params", {})

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,26 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
223223
try:
224224
self.server.start()
225225

226+
model_id = str(
227+
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
228+
)
229+
temperature = config.completion_params.get("temperature", 0.0)
230+
max_tokens = config.completion_params.get("max_tokens", 4096)
231+
232+
# Pass all other completion_params (e.g. stream=True) via kwargs
233+
other_params = {
234+
k: v
235+
for k, v in (config.completion_params or {}).items()
236+
if k not in ["model", "temperature", "max_tokens", "extra_body"]
237+
}
238+
extra_body = config.completion_params.get("extra_body", {}) or {}
239+
226240
self.policy = ep.LiteLLMPolicy(
227-
model_id=str(
228-
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
229-
),
230-
temperature=config.completion_params.get("temperature", 0.0),
231-
max_tokens=config.completion_params.get("max_tokens", 4096),
232-
**(config.completion_params.get("extra_body", {}) or {}),
241+
model_id=model_id,
242+
temperature=temperature,
243+
max_tokens=max_tokens,
244+
**extra_body,
245+
**other_params,
233246
)
234247

235248
except Exception as e:

0 commit comments

Comments (0)