-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_single_turn_rollout_process.py
More file actions
197 lines (164 loc) · 8.98 KB
/
default_single_turn_rollout_process.py
File metadata and controls
197 lines (164 loc) · 8.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import asyncio
import json
import logging
import os
import time
from typing import List
import litellm
from litellm import acompletion
from litellm.types.utils import ModelResponse, Choices
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
logger = logging.getLogger(__name__)
class SingleTurnRolloutProcessor(RolloutProcessor):
    """Single turn rollout processor for direct LLM calls.

    Produces one asyncio.Task per input row; each task performs a single
    chat-completion call (optionally streamed), appends the assistant reply
    to the row's messages, records execution metadata, and logs the row.
    """

    def __init__(self, *, drop_trailing_assistant_messages: bool = True) -> None:
        """
        Args:
            drop_trailing_assistant_messages: When True (default), strip any trailing
                assistant messages from the input conversation before calling the model.
                This helps when datasets include previous assistant turns and you want
                the model to answer the latest user query.
        """
        self.drop_trailing_assistant_messages = drop_trailing_assistant_messages

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate single turn rollout tasks and return them for external handling.

        Args:
            rows: Evaluation rows to roll out; each must contain at least one message.
            config: Carries `completion_params` forwarded to LiteLLM and a
                `semaphore` bounding concurrent in-flight requests.

        Returns:
            One `asyncio.Task[EvaluationRow]` per input row, in input order.
        """

        # Do not modify global LiteLLM cache. Disable caching per-request instead.
        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row asynchronously."""
            start_time = time.perf_counter()

            if len(row.messages) == 0:
                raise ValueError("Messages is empty. Please provide a non-empty dataset")

            # Optionally drop trailing assistant messages for single-turn prompts
            messages_for_request: List[Message] = list(row.messages)
            if self.drop_trailing_assistant_messages:
                while messages_for_request and messages_for_request[-1].role == "assistant":
                    messages_for_request.pop()

            # Filter out fields that are not supported by OpenAI/LiteLLM APIs
            # (e.g., weight, control_plane_step, reasoning_content).
            # NOTE(review): the method name contains a typo ("mdoel") but it matches
            # the Message API as it exists today — do not rename the call here
            # without also renaming the method on Message.
            messages_payload = [message.dump_mdoel_for_chat_completion_request() for message in messages_for_request]

            request_params = {"messages": messages_payload, **config.completion_params}
            # Ensure caching is disabled only for this request (review feedback)
            request_params["cache"] = {"no-cache": True}

            # Optional API base override from the environment (first env var wins).
            api_base = os.getenv("EP_LLM_API_BASE") or os.getenv("EP_LLM_BASE_URL")
            if api_base and "api_base" not in request_params:
                request_params["api_base"] = api_base

            # Single-level reasoning effort: expect `reasoning_effort` only.
            # Accept it either flat in completion_params or nested in extra_body;
            # always forward it under extra_body so LiteLLM routes it to the
            # provider-specific parameter set.
            effort_val = None
            if (
                "reasoning_effort" in config.completion_params
                and config.completion_params["reasoning_effort"] is not None
            ):
                effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
            elif (
                isinstance(config.completion_params.get("extra_body"), dict)
                and "reasoning_effort" in config.completion_params["extra_body"]
                and config.completion_params["extra_body"]["reasoning_effort"] is not None
            ):
                # Accept if user passed it directly inside extra_body
                effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body

            # Ensure the unsupported top-level key is never sent, even when the
            # configured value was None/falsy (previously it was only removed
            # when a truthy effort value had been extracted).
            request_params.pop("reasoning_effort", None)
            if effort_val:
                request_params.setdefault("extra_body", {})
                request_params["extra_body"]["reasoning_effort"] = effort_val

            if row.tools is not None:
                request_params["tools"] = row.tools

            if request_params.get("stream") is True:
                # Streaming path: accumulate chunks, then rebuild a full response.
                chunks = []
                stream = await acompletion(**request_params)
                assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"
                async for chunk in stream:  # pyright: ignore[reportGeneralTypeIssues]
                    chunks.append(chunk)
                response = litellm.stream_chunk_builder(chunks, messages_payload)
            else:
                response = await acompletion(**request_params)

            assert response is not None, "Response is None"
            assert isinstance(response, ModelResponse), "Response should be ModelResponse"
            assert isinstance(response.choices[0], Choices), "Response choice should be a Choices"

            assistant_message = response.choices[0].message
            finish_reason = getattr(response.choices[0], "finish_reason", None)

            # Extract content
            assistant_content = assistant_message.content or ""

            # Extract reasoning content (if present); providers expose it under
            # either `reasoning_content` or `reasoning`.
            reasoning_content = getattr(assistant_message, "reasoning_content", None)
            if reasoning_content is None:
                reasoning_content = getattr(assistant_message, "reasoning", None)
            if reasoning_content is not None and not isinstance(reasoning_content, str):
                # Normalize structured reasoning payloads to a string.
                try:
                    reasoning_content = json.dumps(reasoning_content)
                except Exception:
                    reasoning_content = str(reasoning_content)

            # Extract tool calls
            tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None
            converted_tool_calls = None
            if tool_calls:
                converted_tool_calls = []
                for tool_call in tool_calls:
                    try:
                        converted_tool_calls.append(
                            {
                                "id": tool_call.id,
                                "type": tool_call.type,
                                "function": {
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                            }
                        )
                    except Exception:
                        # best-effort: fallback to dict form with defensive defaults
                        try:
                            converted_tool_calls.append(
                                {
                                    "id": getattr(tool_call, "id", "toolcall_0"),
                                    "type": getattr(tool_call, "type", "function"),
                                    "function": {
                                        "name": getattr(getattr(tool_call, "function", None), "name", "tool"),
                                        "arguments": getattr(getattr(tool_call, "function", None), "arguments", "{}"),
                                    },
                                }
                            )
                        except Exception:
                            # Still best-effort, but surface the drop instead of
                            # silently losing the tool call.
                            logger.debug("Dropping unconvertible tool call", exc_info=True)

            messages = list(messages_for_request) + [
                Message(
                    role="assistant",
                    content=assistant_content,
                    reasoning_content=reasoning_content,
                    tool_calls=converted_tool_calls,
                )
            ]

            row.execution_metadata.finish_reason = str(finish_reason) if finish_reason is not None else None
            row.execution_metadata.tool_call_count = (
                len(converted_tool_calls) if converted_tool_calls is not None else 0
            )

            usage = getattr(response, "usage", None)
            if usage:
                row.execution_metadata.usage = (
                    CompletionUsage(  # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
                        prompt_tokens=getattr(usage, "prompt_tokens", 0),
                        completion_tokens=getattr(usage, "completion_tokens", 0),
                        total_tokens=getattr(usage, "total_tokens", 0),
                    )
                )
            else:
                row.execution_metadata.usage = None

            row.messages = messages
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time
            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrency: at most `semaphore._value` rows in flight at once.
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]