Skip to content

Commit 4759b68

Browse files
committed
new LiteLLMPolicy
1 parent 62824b1 commit 4759b68

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 24 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -222,29 +222,6 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
222 222
try:
223 223
self.server.start()
224 224

225-
model_id = str(
226-
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
227-
)
228-
print("model_id from eval_protocol: ", model_id)
229-
temperature = config.completion_params.get("temperature", 0.0)
230-
max_tokens = config.completion_params.get("max_tokens", 4096)
231-
232-
# Pass all other completion_params (e.g. stream=True) via kwargs
233-
other_params = {
234-
k: v
235-
for k, v in (config.completion_params or {}).items()
236-
if k not in ["model", "temperature", "max_tokens", "extra_body"]
237-
}
238-
extra_body = config.completion_params.get("extra_body", {}) or {}
239-
240-
self.policy = ep.LiteLLMPolicy(
241-
model_id=model_id,
242-
temperature=temperature,
243-
max_tokens=max_tokens,
244-
**extra_body,
245-
**other_params,
246-
)
247-
248 225
except Exception as e:
249 226
if self.server:
250 227
self.server.stop()
@@ -254,13 +231,35 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
254 231

255 232
else:
256 233
# Reuse existing MCP environments for retry
257-
if not self.server or not self.policy:
234+
if not self.server:
258 235
raise RuntimeError(
259 236
"Cannot retry without existing server/environments. Call with start_server=True first."
260 237
)
261 238

239+
240+
model_id = str(
241+
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
242+
)
243+
print("model_id from eval_protocol: ", model_id)
244+
temperature = config.completion_params.get("temperature", 0.0)
245+
max_tokens = config.completion_params.get("max_tokens", 4096)
246+
247+
# Pass all other completion_params (e.g. stream=True) via kwargs
248+
other_params = {
249+
k: v
250+
for k, v in (config.completion_params or {}).items()
251+
if k not in ["model", "temperature", "max_tokens", "extra_body"]
252+
}
253+
extra_body = config.completion_params.get("extra_body", {}) or {}
254+
255+
self.policy = ep.LiteLLMPolicy(
256+
model_id=model_id,
257+
temperature=temperature,
258+
max_tokens=max_tokens,
259+
**extra_body,
260+
**other_params,
261+
)
262 262
# Create MCP environments directly from evaluation_rows
263-
assert self.policy is not None, "Policy must be initialized before rollout"
264 263
envs = ep.make(
265 264
"http://localhost:9700/mcp/",
266 265
evaluation_rows=rows,

0 commit comments

Comments
 (0)