-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_single_turn_rollout_process.py
More file actions
131 lines (109 loc) · 5.65 KB
/
default_single_turn_rollout_process.py
File metadata and controls
131 lines (109 loc) · 5.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import asyncio
import logging
import os
import time
from typing import List
from litellm import acompletion
from typing import Dict
from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
logger = logging.getLogger(__name__)
class SingleTurnRolloutProcessor(RolloutProcessor):
    """Single turn rollout processor for direct LLM calls."""

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate single turn rollout tasks and return them for external handling.

        Each row gets one LLM completion (via LiteLLM ``acompletion``); the
        assistant reply and token usage are appended to the row, the row is
        logged, and an ``asyncio.Task`` per row is returned to the caller.
        Concurrency is bounded by ``config.semaphore``.

        Raises:
            ValueError: if any row has an empty ``messages`` list.
        """
        # Do not modify global LiteLLM cache. Disable caching per-request instead.

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row asynchronously: one completion call per row."""
            if len(row.messages) == 0:
                raise ValueError("Messages is empty. Please provide a non-empty dataset")

            messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
            request_params = {"messages": messages_payload, **config.completion_params}
            # Ensure caching is disabled only for this request (review feedback)
            request_params["cache"] = {"no-cache": True}

            # Single-level reasoning effort: accept `reasoning_effort` either as a
            # flat completion param (preferred) or already nested in `extra_body`.
            effort_val = None
            if config.completion_params.get("reasoning_effort") is not None:
                effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
            else:
                extra = config.completion_params.get("extra_body")
                if isinstance(extra, dict) and extra.get("reasoning_effort") is not None:
                    effort_val = str(extra["reasoning_effort"])  # already in extra_body

            if effort_val:
                # Always forward under extra_body so LiteLLM maps it to the
                # provider-specific parameter set. Copy before mutating: the
                # shallow `**config.completion_params` spread means
                # request_params["extra_body"] may be the very dict object in
                # config, which is shared by every row processed concurrently.
                extra_body = dict(request_params.get("extra_body") or {})
                extra_body["reasoning_effort"] = effort_val
                request_params["extra_body"] = extra_body
            # The flat key is not a supported top-level completion parameter.
            request_params.pop("reasoning_effort", None)

            if row.tools is not None:
                request_params["tools"] = row.tools

            # `acompletion` is imported at module level; no dynamic import needed.
            response = await acompletion(**request_params)

            message = response.choices[0].message
            assistant_content = message.content or ""
            converted_tool_calls = None
            if message.tool_calls:
                converted_tool_calls = []
                for tool_call in message.tool_calls:
                    try:
                        converted_tool_calls.append(
                            {
                                "id": tool_call.id,
                                "type": tool_call.type,
                                "function": {
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                            }
                        )
                    except Exception:
                        # best-effort: fallback to dict form with defaults for
                        # providers that return partially-populated tool calls
                        try:
                            converted_tool_calls.append(
                                {
                                    "id": getattr(tool_call, "id", "toolcall_0"),
                                    "type": getattr(tool_call, "type", "function"),
                                    "function": {
                                        "name": getattr(getattr(tool_call, "function", None), "name", "tool"),
                                        "arguments": getattr(getattr(tool_call, "function", None), "arguments", "{}"),
                                    },
                                }
                            )
                        except Exception:
                            # Still best-effort, but surface the drop instead of
                            # silently swallowing it.
                            logger.warning("Dropping malformed tool call: %r", tool_call)

            messages = list(row.messages) + [
                Message(
                    role="assistant",
                    content=assistant_content,
                    tool_calls=converted_tool_calls,
                )
            ]
            row.execution_metadata.usage = CompletionUsage(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
            )
            row.messages = messages
            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrency across all rows with the caller-provided semaphore.
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]