Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
from .log_utils.rollout_id_filter import RolloutIdFilter
from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler
from .log_utils.fireworks_tracing_http_handler import FireworksTracingHttpHandler


from .types.remote_rollout_processor import (
InitRequest,
Expand Down Expand Up @@ -89,6 +91,7 @@
"BraintrustAdapter",
"create_braintrust_adapter",
"LangSmithAdapter",
"FireworksTracingHttpHandler",
# Core interfaces
"Message",
"MetricResult",
Expand Down
30 changes: 29 additions & 1 deletion eval_protocol/adapters/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def _create_messages(self, input_items: SyncCursorPage[ResponseItem]) -> Iterabl
raise NotImplementedError(f"Unsupported content type: {content_item.type}")
elif item.type == "function_call_output":
# Collect tool call outputs to add before assistant message
tool_call_outputs.append(Message(role="tool", content=item.output, tool_call_id=item.call_id))
tool_call_outputs.append(
Message(role="tool", content=self._coerce_tool_output(item.output), tool_call_id=item.call_id)
)
elif item.type == "function_call":
tool_call = ChatCompletionMessageToolCall(
id=item.call_id, type="function", function=Function(name=item.name, arguments=item.arguments)
Expand All @@ -186,3 +188,29 @@ def _create_messages(self, input_items: SyncCursorPage[ResponseItem]) -> Iterabl
messages.append(Message(role="assistant", tool_calls=current_tool_calls))

return reversed(messages)

def _coerce_tool_output(self, output: Any) -> str:
"""Coerce OpenAI Responses tool output into a string for Message.content.

The Responses API may return structured content lists. For our purposes,
we stringify non-string outputs to satisfy the Message.content type.
"""
if isinstance(output, str):
return output
try:
# Attempt to join list of objects with any 'text' fields
if isinstance(output, list):
parts: list[str] = []
for part in output:
text = None
if isinstance(part, dict):
text = part.get("text")
if text:
parts.append(str(text))
else:
parts.append(str(part))
return "\n".join(parts)
# Fallback to string conversion
return str(output)
except Exception:
return str(output)
63 changes: 63 additions & 0 deletions eval_protocol/log_utils/fireworks_tracing_http_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
import os
import threading
from datetime import datetime, timezone
from typing import Optional, Any, Dict, List, cast

import requests


class FireworksTracingHttpHandler(logging.Handler):
    """Logging handler that posts structured logs to tracing.fireworks gateway /logs endpoint.

    Each record is sent as a JSON body via HTTP POST to
    ``<gateway_base_url>/logs``. A record is skipped without error when no
    gateway URL is configured or when no rollout id can be resolved for it.
    """

    # Top-level logger names used by the HTTP client stack. Records from these
    # loggers are emitted *while* this handler performs its POST, so forwarding
    # them would trigger another POST for every POST.
    _HTTP_STACK_LOGGERS = ("urllib3", "requests")

    def __init__(self, gateway_base_url: Optional[str] = None, rollout_id_env: str = "EP_ROLLOUT_ID") -> None:
        """Create the handler.

        Args:
            gateway_base_url: Base URL of the gateway. When omitted, the
                ``FW_TRACING_GATEWAY_BASE_URL`` environment variable is used.
            rollout_id_env: Name of the environment variable consulted for the
                rollout id when a record does not carry one.
        """
        super().__init__()
        self.gateway_base_url = gateway_base_url or os.getenv("FW_TRACING_GATEWAY_BASE_URL")
        self.rollout_id_env = rollout_id_env
        self._session = requests.Session()
        self._lock = threading.Lock()

    def emit(self, record: logging.LogRecord) -> None:
        """POST ``record`` to the gateway; never raises.

        Any exception is routed to ``handleError`` as is conventional for
        logging handlers.
        """
        try:
            if not self.gateway_base_url:
                return
            # Do not forward the HTTP client's own log records (see
            # _HTTP_STACK_LOGGERS) to avoid feeding back on ourselves.
            if record.name.split(".", 1)[0] in self._HTTP_STACK_LOGGERS:
                return
            rollout_id = self._get_rollout_id(record)
            if not rollout_id:
                return
            payload = self._build_payload(record, rollout_id)
            url = f"{self.gateway_base_url.rstrip('/')}/logs"
            with self._lock:
                self._session.post(url, json=payload, timeout=5)
        except Exception:
            # Avoid raising exceptions from logging
            self.handleError(record)

    def close(self) -> None:
        """Release the HTTP session's pooled connections, then close the handler."""
        try:
            session = getattr(self, "_session", None)
            if session is not None:
                session.close()
        finally:
            super().close()

    def _get_rollout_id(self, record: logging.LogRecord) -> Optional[str]:
        """Return the record's ``rollout_id`` attribute, else the env-var value (may be None)."""
        rollout_id = getattr(record, "rollout_id", None)
        if rollout_id is not None:
            return str(rollout_id)
        return os.getenv(self.rollout_id_env)

    def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str, Any]:
        """Build the JSON body for one record.

        Optional record attributes (typically supplied through ``extra=``):
        ``experiment_id`` and ``run_id`` become tags when truthy, ``program``
        defaults to ``"eval_protocol"``, ``status`` is kept only when it is a
        string, and ``metadata`` is passed through unchanged.
        """
        timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
        tags: List[str] = [f"rollout_id:{rollout_id}"]
        # Optional additional tags
        for tag_name in ("experiment_id", "run_id"):
            tag_value = getattr(record, tag_name, None)
            if tag_value:
                tags.append(f"{tag_name}:{tag_value}")
        program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
        status_val = getattr(record, "status", None)
        return {
            "program": program,
            "status": status_val if isinstance(status_val, str) else None,
            "message": record.getMessage(),
            "tags": tags,
            "metadata": getattr(record, "metadata", None),
            "extras": {
                "logger_name": record.name,
                "level": record.levelname,
                "timestamp": timestamp,
            },
        }
58 changes: 58 additions & 0 deletions scripts/validate_dev_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import time
import logging
import requests

from eval_protocol import FireworksTracingHttpHandler


def main():
    """Send one structured log to the tracing gateway and poll until it can be fetched.

    The gateway comes from ``FW_TRACING_GATEWAY_BASE_URL`` (falling back to the
    deployed dev gateway) and the rollout id from ``EP_ROLLOUT_ID`` (falling
    back to a time-based id). Prints the fetched entries, or a zero-count
    message when nothing shows up after all retries.

    Raises:
        requests.HTTPError: If a polling request returns an error status.
    """
    gateway = os.getenv("FW_TRACING_GATEWAY_BASE_URL")
    if not gateway:
        # default to deployed dev gateway
        gateway = "https://metadata-gateway-dev-644257448872.us-central1.run.app"
    # Normalize the same way the handler does so a trailing slash in the
    # environment value does not produce a "//logs" polling URL.
    gateway = gateway.rstrip("/")
    rollout_id = os.getenv("EP_ROLLOUT_ID", f"sdk-dev-{int(time.time())}")

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    handler = FireworksTracingHttpHandler(gateway_base_url=gateway)
    root.addHandler(handler)

    try:
        logger = logging.getLogger("eval_protocol.sdk.validate")

        logger.info(
            "SDK sending structured log to dev gateway",
            extra={
                "rollout_id": rollout_id,
                "program": "eval_protocol",
                "status": "completed",
                "experiment_id": "dev-exp",
                "run_id": "dev-run",
                "metadata": {"source": "sdk-validate"},
            },
        )

        # Poll fetch with retries for indexing
        params = {
            "tags": [f"rollout_id:{rollout_id}"],
            "program": "eval_protocol",
            "limit": 10,
            "hours_back": 1,
        }
        total = 0
        for _ in range(20):
            r = requests.get(f"{gateway}/logs", params=params, timeout=30)
            r.raise_for_status()
            data = r.json()
            total = int(data.get("total_entries") or 0)
            if total > 0:
                print("Fetched entries:", total)
                for e in data.get("entries", []):
                    print({k: e.get(k) for k in ["timestamp", "severity", "program", "status", "message", "tags"]})
                break
            time.sleep(3)
        if total == 0:
            print("Fetched entries: 0 (after retries)")
    finally:
        # Detach the handler so it does not stay on the root logger after
        # validation finishes (or fails).
        root.removeHandler(handler)
        handler.close()


if __name__ == "__main__":
    main()
Loading