-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_single_turn_rollout_process.py
More file actions
109 lines (88 loc) · 4.6 KB
/
default_single_turn_rollout_process.py
File metadata and controls
109 lines (88 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import asyncio
import logging
import os
import time
from typing import List
from litellm import acompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
logger = logging.getLogger(__name__)
class SingleTurnRolloutProcessor(RolloutProcessor):
    """Single turn rollout processor for direct LLM calls.

    For each input row, issues one chat-completion request (via LiteLLM's
    ``acompletion``), appends the assistant reply (including any tool calls)
    to the row's message history, logs the row, and returns the row.
    """

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate single turn rollout tasks and return them for external handling.

        Args:
            rows: Non-empty evaluation rows; each must carry at least one message.
            config: Rollout configuration whose ``completion_params`` are passed
                through to the completion call. ``max_concurrent_rollouts``
                (default 8) bounds in-flight requests.

        Returns:
            One ``asyncio.Task`` per input row, already scheduled on the running
            event loop; callers await/gather them externally.

        Raises:
            ValueError: (inside a task) if a row has an empty message list.
        """
        # Do not modify global LiteLLM cache state; caching is disabled
        # per-request inside process_row instead.

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Run one completion for *row*, append the reply, log and return it."""
            if not row.messages:
                raise ValueError("Messages is empty. Please provide a non-empty dataset")

            messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
            request_params = {"messages": messages_payload, **config.completion_params}
            # Ensure caching is disabled only for this request (review feedback).
            request_params["cache"] = {"no-cache": True}

            # Normalize `reasoning_effort`: accept it at the top level of
            # completion_params, or already nested inside extra_body. Either
            # way it is forwarded under extra_body so LiteLLM routes it to the
            # provider-specific parameter set.
            effort_val = None
            if config.completion_params.get("reasoning_effort") is not None:
                effort_val = str(config.completion_params["reasoning_effort"])
            else:
                extra_body = config.completion_params.get("extra_body")
                if isinstance(extra_body, dict) and extra_body.get("reasoning_effort") is not None:
                    effort_val = str(extra_body["reasoning_effort"])
            if effort_val:
                request_params.setdefault("extra_body", {})
                request_params["extra_body"]["reasoning_effort"] = effort_val
            # The top-level key is not supported by the request; drop it
            # unconditionally (pop with a default needs no membership check).
            request_params.pop("reasoning_effort", None)

            if row.tools is not None:
                request_params["tools"] = row.tools

            # `acompletion` is already imported at module level; the previous
            # importlib-based dynamic lookup was redundant (the module-level
            # `from litellm import acompletion` would have failed first).
            response = await acompletion(**request_params)

            assistant_content = response.choices[0].message.content or ""
            tool_calls = response.choices[0].message.tool_calls or None

            converted_tool_calls = None
            if tool_calls:
                # Re-wrap provider tool calls in the OpenAI type our Message
                # model expects.
                converted_tool_calls = [
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        type=tool_call.type,
                        function={
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        },
                    )
                    for tool_call in tool_calls
                ]

            row.messages = list(row.messages) + [
                Message(
                    role="assistant",
                    content=assistant_content,
                    tool_calls=converted_tool_calls,
                )
            ]
            default_logger.log(row)
            return row

        # Bound concurrency; `or 8` also guards against a falsy configured value.
        max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
        semaphore = asyncio.Semaphore(max_concurrent)

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling.
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]