|
16 | 16 |
|
17 | 17 | logger = logging.getLogger(__name__) |
18 | 18 |
|
19 | | -litellm._turn_on_debug() # pyright: ignore[reportPrivateImportUsage] |
20 | | - |
21 | 19 |
|
22 | 20 | class SingleTurnRolloutProcessor(RolloutProcessor): |
23 | 21 | """Single turn rollout processor for direct LLM calls.""" |
@@ -66,26 +64,19 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: |
66 | 64 | if row.tools is not None: |
67 | 65 | request_params["tools"] = row.tools |
68 | 66 |
|
69 | | - # _litellm = importlib.import_module("litellm") |
70 | | - # acompletion = getattr(_litellm, "acompletion") |
71 | | - |
72 | | - # Handle streaming response |
73 | | - assistant_content = "" |
74 | | - tool_calls = None |
75 | | - usage_info = None |
| 67 | + chunks = [] |
76 | 68 |
|
77 | 69 | stream = await acompletion(**request_params) |
78 | | - async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues] |
79 | | - if chunk.choices and len(chunk.choices) > 0: |
80 | | - delta = chunk.choices[0].delta |
81 | | - if hasattr(delta, "content") and delta.content: |
82 | | - assistant_content += delta.content |
83 | | - if hasattr(delta, "tool_calls") and delta.tool_calls: |
84 | | - tool_calls = delta.tool_calls |
85 | | - |
86 | | - # Capture usage info from the final chunk |
87 | | - if hasattr(chunk, "usage") and chunk.usage: |
88 | | - usage_info = chunk.usage |
| 70 | + async for chunk in stream: |
| 71 | + chunks.append(chunk) |
| 72 | + |
| 73 | + response = litellm.stream_chunk_builder(chunks, messages_payload) |
| 74 | + |
| 75 | + if response is None: |
| 76 | + raise ValueError("Response is None") |
| 77 | + |
| 78 | + assistant_content = response.choices[0].message.content or "" |
| 79 | + tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None |
89 | 80 |
|
90 | 81 | converted_tool_calls = None |
91 | 82 | if tool_calls: |
@@ -125,20 +116,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: |
125 | 116 | tool_calls=converted_tool_calls, |
126 | 117 | ) |
127 | 118 | ] |
128 | | - |
129 | | - if usage_info: |
130 | | - row.execution_metadata.usage = CompletionUsage( |
131 | | - prompt_tokens=usage_info.prompt_tokens, |
132 | | - completion_tokens=usage_info.completion_tokens, |
133 | | - total_tokens=usage_info.total_tokens, |
134 | | - ) |
135 | | - else: |
136 | | - # Fallback if usage info not available from streaming |
137 | | - row.execution_metadata.usage = CompletionUsage( |
138 | | - prompt_tokens=0, |
139 | | - completion_tokens=0, |
140 | | - total_tokens=0, |
141 | | - ) |
| 119 | + row.execution_metadata.usage = CompletionUsage( |
| 120 | + prompt_tokens=response.usage.prompt_tokens, |
| 121 | + completion_tokens=response.usage.completion_tokens, |
| 122 | + total_tokens=response.usage.total_tokens, |
| 123 | + ) |
142 | 124 |
|
143 | 125 | row.messages = messages |
144 | 126 |
|
|
0 commit comments