import asyncio
import json
import logging
import os
import time
from datetime import datetime, timezone
from dataclasses import asdict, is_dataclass
from types import SimpleNamespace
from typing import Any, List

import litellm
from litellm import acompletion
from litellm.types.utils import ModelResponse, Choices
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from openai.types import CompletionUsage

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

logger = logging.getLogger(__name__)


def _serialize_logprobs(logprobs: Any) -> Any:
    """Best-effort conversion of provider logprobs into JSON-serializable data."""
    if logprobs is None:
        return None
    if hasattr(logprobs, "model_dump"):
        try:
            return logprobs.model_dump()
        except Exception:
            pass
    if is_dataclass(logprobs) and not isinstance(logprobs, type):
        return asdict(logprobs)
    if isinstance(logprobs, SimpleNamespace):
        return vars(logprobs)
    if isinstance(logprobs, dict):
        return logprobs
    try:
        return json.loads(json.dumps(logprobs, default=lambda o: getattr(o, "__dict__", str(o))))
    except Exception:
        return logprobs
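
# Illustrative fallback chain (hypothetical inputs):
#   _serialize_logprobs(None)                        -> None
#   _serialize_logprobs(SimpleNamespace(token="hi")) -> {"token": "hi"}
#   _serialize_logprobs({"content": []})             -> {"content": []}
# Pydantic objects go through model_dump(); anything else is round-tripped
# through json.dumps with a __dict__/str(...) fallback.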


class SingleTurnRolloutProcessor(RolloutProcessor):
    """Single turn rollout processor for direct LLM calls."""

    def __init__(self, *, drop_trailing_assistant_messages: bool = True) -> None:
        """
        Args:
            drop_trailing_assistant_messages: When True (default), strip any trailing
                assistant messages from the input conversation before calling the model.
                This helps when datasets include previous assistant turns and you want
                the model to answer the latest user query.
        """
        self.drop_trailing_assistant_messages = drop_trailing_assistant_messages
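
    # Example (hypothetical conversation): with the default setting, input
    # messages [system, user, assistant] are trimmed to [system, user] before
    # the request, so the model regenerates the final assistant turn.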

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate single turn rollout tasks and return them for external handling."""

        # Do not modify the global LiteLLM cache; disable caching per-request instead.
        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row asynchronously."""
            if row.execution_metadata.rollout_start_time is None:
                row.execution_metadata.rollout_start_time = datetime.now(timezone.utc)
            start_time = time.perf_counter()

            if len(row.messages) == 0:
                raise ValueError("Messages is empty. Please provide a non-empty dataset")

            # Optionally drop trailing assistant messages for single-turn prompts
            messages_for_request: List[Message] = list(row.messages)
            if self.drop_trailing_assistant_messages:
                while messages_for_request and messages_for_request[-1].role == "assistant":
                    messages_for_request.pop()

            # Filter out fields that are not supported by OpenAI/LiteLLM APIs
            # (e.g., weight, control_plane_step, reasoning_content) using the
            # Message class method that excludes unsupported fields.
            messages_payload = [message.dump_mdoel_for_chat_completion_request() for message in messages_for_request]

            request_params = {"messages": messages_payload, **config.completion_params}
            # Ensure caching is disabled only for this request
            request_params["cache"] = {"no-cache": True}
            api_base = os.getenv("EP_LLM_API_BASE") or os.getenv("EP_LLM_BASE_URL")
            if api_base and "api_base" not in request_params:
                request_params["api_base"] = api_base

            # Single-level reasoning effort: expect `reasoning_effort` only
            effort_val = None
            if (
                "reasoning_effort" in config.completion_params
                and config.completion_params["reasoning_effort"] is not None
            ):
                effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
            elif (
                isinstance(config.completion_params.get("extra_body"), dict)
                and "reasoning_effort" in config.completion_params["extra_body"]
                and config.completion_params["extra_body"]["reasoning_effort"] is not None
            ):
                # Accept if user passed it directly inside extra_body
                effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
            if effort_val:
                # Always under extra_body so LiteLLM forwards to provider-specific param set
                request_params.setdefault("extra_body", {})
                request_params["extra_body"]["reasoning_effort"] = effort_val
            # Ensure unsupported top-level keys are not present
            if "reasoning_effort" in request_params:
                request_params.pop("reasoning_effort", None)

            # Handle raw_output - move it to extra_body so LiteLLM forwards it
            if "raw_output" in request_params:
                request_params.setdefault("extra_body", {})
                request_params["extra_body"]["raw_output"] = request_params.pop("raw_output")

            if row.tools is not None:
                request_params["tools"] = row.tools

            if request_params.get("stream") is True:
                chunks = []
                stream = await acompletion(**request_params)
                assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"
                async for chunk in stream:  # pyright: ignore[reportGeneralTypeIssues]
                    chunks.append(chunk)
                response = litellm.stream_chunk_builder(chunks, messages_payload)
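                # stream_chunk_builder merges the streamed deltas back into a single
                # ModelResponse, so the extraction below handles streaming and
                # non-streaming responses identically.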
            else:
                response = await acompletion(**request_params)

            assert response is not None, "Response is None"
            assert isinstance(response, ModelResponse), "Response should be ModelResponse"
            assert isinstance(response.choices[0], Choices), "Response choice should be a Choices"

            assistant_message = response.choices[0].message
            finish_reason = getattr(response.choices[0], "finish_reason", None)
            assistant_logprobs = _serialize_logprobs(getattr(response.choices[0], "logprobs", None))

            # Extract content
            assistant_content = assistant_message.content or ""

            # Extract reasoning content (if present)
            reasoning_content = getattr(assistant_message, "reasoning_content", None)
            if reasoning_content is None:
                reasoning_content = getattr(assistant_message, "reasoning", None)
            if reasoning_content is not None and not isinstance(reasoning_content, str):
                try:
                    reasoning_content = json.dumps(reasoning_content)
                except Exception:
                    reasoning_content = str(reasoning_content)

            # Extract tool calls
            tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None
            converted_tool_calls = None
            if tool_calls:
                converted_tool_calls = []
                for tool_call in tool_calls:
                    try:
                        converted_tool_calls.append(
                            {
                                "id": tool_call.id,
                                "type": tool_call.type,
                                "function": {
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                            }
                        )
                    except Exception:
                        # Best effort: fall back to getattr lookups with placeholder defaults
                        try:
                            converted_tool_calls.append(
                                {
                                    "id": getattr(tool_call, "id", "toolcall_0"),
                                    "type": getattr(tool_call, "type", "function"),
                                    "function": {
                                        "name": getattr(getattr(tool_call, "function", None), "name", "tool"),
                                        "arguments": getattr(getattr(tool_call, "function", None), "arguments", "{}"),
                                    },
                                }
                            )
                        except Exception:
                            pass

            messages = list(messages_for_request) + [
                Message(
                    role="assistant",
                    content=assistant_content,
                    reasoning_content=reasoning_content,
                    tool_calls=converted_tool_calls,
                    logprobs=assistant_logprobs,
                )
            ]

            row.execution_metadata.finish_reason = str(finish_reason) if finish_reason is not None else None
            row.execution_metadata.tool_call_count = (
                len(converted_tool_calls) if converted_tool_calls is not None else 0
            )

            # Extract raw_output if present (when raw_output=True was passed to the API).
            # Note: raw_output is only captured for non-streaming requests;
            # LiteLLM stores extra fields in model_extra for non-streaming responses.
            choice = response.choices[0]
            raw_output = None
            # Check model_extra (where LiteLLM puts extra fields for non-streaming)
            if hasattr(choice, "model_extra") and choice.model_extra:
                raw_output = choice.model_extra.get("raw_output")
            # Fallback: check as a direct attribute
            if raw_output is None:
                raw_output = getattr(choice, "raw_output", None)
            if isinstance(raw_output, dict):
                row.execution_metadata.raw_output = raw_output

            usage = getattr(response, "usage", None)
            if usage:
                # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field
                row.execution_metadata.usage = CompletionUsage(
                    prompt_tokens=getattr(usage, "prompt_tokens", 0),
                    completion_tokens=getattr(usage, "completion_tokens", 0),
                    total_tokens=getattr(usage, "total_tokens", 0),
                )
            else:
                row.execution_metadata.usage = None

            row.messages = messages
            row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling
        tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
        return tasks
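

# Minimal usage sketch (hypothetical wiring; the constructor arguments for
# EvaluationRow and RolloutProcessorConfig are assumed from their use above,
# so verify them against eval_protocol.models and eval_protocol.pytest.types):
#
#     async def main() -> None:
#         rows = [EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")])]
#         config = RolloutProcessorConfig(
#             completion_params={"model": "openai/gpt-4o-mini", "temperature": 0.0},
#             semaphore=asyncio.Semaphore(8),
#         )
#         tasks = SingleTurnRolloutProcessor()(rows, config)
#         for finished in await asyncio.gather(*tasks):
#             print(finished.messages[-1].content)
#
#     asyncio.run(main())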