Skip to content

Commit a2e412b

Browse files
committed
fix types
1 parent 02d4d14 commit a2e412b

File tree

3 files changed

+52
-20
lines changed

3 files changed

+52
-20
lines changed

eval_protocol/adapters/openai_responses.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def _create_messages(self, input_items: SyncCursorPage[ResponseItem]) -> Iterabl
169169
raise NotImplementedError(f"Unsupported content type: {content_item.type}")
170170
elif item.type == "function_call_output":
171171
# Collect tool call outputs to add before assistant message
172-
tool_call_outputs.append(Message(role="tool", content=item.output, tool_call_id=item.call_id))
172+
tool_call_outputs.append(
173+
Message(role="tool", content=self._coerce_tool_output(item.output), tool_call_id=item.call_id)
174+
)
173175
elif item.type == "function_call":
174176
tool_call = ChatCompletionMessageToolCall(
175177
id=item.call_id, type="function", function=Function(name=item.name, arguments=item.arguments)
@@ -186,3 +188,29 @@ def _create_messages(self, input_items: SyncCursorPage[ResponseItem]) -> Iterabl
186188
messages.append(Message(role="assistant", tool_calls=current_tool_calls))
187189

188190
return reversed(messages)
191+
192+
def _coerce_tool_output(self, output: Any) -> str:
193+
"""Coerce OpenAI Responses tool output into a string for Message.content.
194+
195+
The Responses API may return structured content lists. For our purposes,
196+
we stringify non-string outputs to satisfy the Message.content type.
197+
"""
198+
if isinstance(output, str):
199+
return output
200+
try:
201+
# Attempt to join list of objects with any 'text' fields
202+
if isinstance(output, list):
203+
parts: list[str] = []
204+
for part in output:
205+
text = None
206+
if isinstance(part, dict):
207+
text = part.get("text")
208+
if text:
209+
parts.append(str(text))
210+
else:
211+
parts.append(str(part))
212+
return "\n".join(parts)
213+
# Fallback to string conversion
214+
return str(output)
215+
except Exception:
216+
return str(output)

eval_protocol/log_utils/fireworks_tracing_http_handler.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import threading
44
from datetime import datetime, timezone
5-
from typing import Optional, Any, Dict, List
5+
from typing import Optional, Any, Dict, List, cast
66

77
import requests
88

@@ -33,27 +33,28 @@ def emit(self, record: logging.LogRecord) -> None:
3333
self.handleError(record)
3434

3535
def _get_rollout_id(self, record: logging.LogRecord) -> Optional[str]:
36-
if hasattr(record, "rollout_id") and record.rollout_id is not None: # type: ignore
37-
return str(record.rollout_id) # type: ignore
36+
if hasattr(record, "rollout_id") and cast(Any, getattr(record, "rollout_id")) is not None:
37+
return str(cast(Any, getattr(record, "rollout_id")))
3838
return os.getenv(self.rollout_id_env)
3939

4040
def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str, Any]:
4141
timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
4242
message = record.getMessage()
4343
tags: List[str] = [f"rollout_id:{rollout_id}"]
4444
# Optional additional tags
45-
if hasattr(record, "experiment_id") and record.experiment_id:
46-
tags.append(f"experiment_id:{record.experiment_id}") # type: ignore
47-
if hasattr(record, "run_id") and record.run_id:
48-
tags.append(f"run_id:{record.run_id}") # type: ignore
49-
program = getattr(record, "program", None) or "eval_protocol"
50-
status = getattr(record, "status", None)
45+
if hasattr(record, "experiment_id") and cast(Any, getattr(record, "experiment_id")):
46+
tags.append(f"experiment_id:{cast(Any, getattr(record, 'experiment_id'))}")
47+
if hasattr(record, "run_id") and cast(Any, getattr(record, "run_id")):
48+
tags.append(f"run_id:{cast(Any, getattr(record, 'run_id'))}")
49+
program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
50+
status_val = cast(Any, getattr(record, "status", None))
51+
status = status_val if isinstance(status_val, str) else None
5152
return {
5253
"program": program,
53-
"status": status if isinstance(status, str) else None,
54+
"status": status,
5455
"message": message,
5556
"tags": tags,
56-
"metadata": getattr(record, "metadata", None),
57+
"metadata": cast(Any, getattr(record, "metadata", None)),
5758
"extras": {
5859
"logger_name": record.name,
5960
"level": record.levelname,

scripts/validate_dev_tracing.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ def main():
1919

2020
logger = logging.getLogger("eval_protocol.sdk.validate")
2121

22-
logger.info("SDK sending structured log to dev gateway", extra={
23-
"rollout_id": rollout_id,
24-
"program": "eval_protocol",
25-
"status": "completed",
26-
"experiment_id": "dev-exp",
27-
"run_id": "dev-run",
28-
"metadata": {"source": "sdk-validate"},
29-
})
22+
logger.info(
23+
"SDK sending structured log to dev gateway",
24+
extra={
25+
"rollout_id": rollout_id,
26+
"program": "eval_protocol",
27+
"status": "completed",
28+
"experiment_id": "dev-exp",
29+
"run_id": "dev-run",
30+
"metadata": {"source": "sdk-validate"},
31+
},
32+
)
3033

3134
# Poll fetch with retries for indexing
3235
params = {

0 commit comments

Comments
 (0)