Skip to content

Commit ada2ef2

Browse files
committed
fix single turn rollout acompletion
1 parent d5f3b81 commit ada2ef2

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import List
66

77
from litellm import acompletion
8-
from typing import Dict
98

109
from eval_protocol.dataset_logger import default_logger
1110
from eval_protocol.models import EvaluationRow, Message
@@ -62,15 +61,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
6261
if row.tools is not None:
6362
request_params["tools"] = row.tools
6463

65-
# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
66-
import importlib
67-
68-
_litellm = importlib.import_module("litellm")
69-
acompletion = getattr(_litellm, "acompletion")
7064
response = await acompletion(**request_params)
7165

72-
assistant_content = response.choices[0].message.content or ""
73-
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
66+
assistant_content = response.choices[0].message.content or "" # pyright: ignore[reportAttributeAccessIssue]
67+
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None # pyright: ignore[reportAttributeAccessIssue]
7468

7569
converted_tool_calls = None
7670
if tool_calls:
@@ -112,9 +106,9 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
112106
]
113107

114108
row.execution_metadata.usage = CompletionUsage(
115-
prompt_tokens=response.usage.prompt_tokens,
116-
completion_tokens=response.usage.completion_tokens,
117-
total_tokens=response.usage.total_tokens,
109+
prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue]
110+
completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue]
111+
total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue]
118112
)
119113

120114
row.messages = messages

0 commit comments

Comments
 (0)