diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index c18ee329..0ca8c749 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -35,6 +35,8 @@ from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler from .log_utils.rollout_id_filter import RolloutIdFilter from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler +from .log_utils.fireworks_tracing_http_handler import FireworksTracingHttpHandler + from .types.remote_rollout_processor import ( InitRequest, @@ -89,6 +91,7 @@ "BraintrustAdapter", "create_braintrust_adapter", "LangSmithAdapter", + "FireworksTracingHttpHandler", # Core interfaces "Message", "MetricResult", diff --git a/eval_protocol/adapters/openai_responses.py b/eval_protocol/adapters/openai_responses.py index e24489f5..657f66a5 100644 --- a/eval_protocol/adapters/openai_responses.py +++ b/eval_protocol/adapters/openai_responses.py @@ -169,7 +169,9 @@ def _create_messages(self, input_items: SyncCursorPage[ResponseItem]) -> Iterabl raise NotImplementedError(f"Unsupported content type: {content_item.type}") elif item.type == "function_call_output": # Collect tool call outputs to add before assistant message - tool_call_outputs.append(Message(role="tool", content=item.output, tool_call_id=item.call_id)) + tool_call_outputs.append( + Message(role="tool", content=self._coerce_tool_output(item.output), tool_call_id=item.call_id) + ) elif item.type == "function_call": tool_call = ChatCompletionMessageToolCall( id=item.call_id, type="function", function=Function(name=item.name, arguments=item.arguments) @@ -186,3 +188,29 @@ def _create_messages(self, input_items: SyncCursorPage[ResponseItem]) -> Iterabl messages.append(Message(role="assistant", tool_calls=current_tool_calls)) return reversed(messages) + + def _coerce_tool_output(self, output: Any) -> str: + """Coerce OpenAI Responses tool output into a string for Message.content. + + The Responses API may return structured content lists. For our purposes, + we stringify non-string outputs to satisfy the Message.content type. + """ + if isinstance(output, str): + return output + try: + # Attempt to join list of objects with any 'text' fields + if isinstance(output, list): + parts: list[str] = [] + for part in output: + text = None + if isinstance(part, dict): + text = part.get("text") + if text: + parts.append(str(text)) + else: + parts.append(str(part)) + return "\n".join(parts) + # Fallback to string conversion + return str(output) + except Exception: + return str(output) diff --git a/eval_protocol/log_utils/fireworks_tracing_http_handler.py b/eval_protocol/log_utils/fireworks_tracing_http_handler.py new file mode 100644 index 00000000..0e8dfdf5 --- /dev/null +++ b/eval_protocol/log_utils/fireworks_tracing_http_handler.py @@ -0,0 +1,63 @@ +import logging +import os +import threading +from datetime import datetime, timezone +from typing import Optional, Any, Dict, List, cast + +import requests + + +class FireworksTracingHttpHandler(logging.Handler): + """Logging handler that posts structured logs to tracing.fireworks gateway /logs endpoint.""" + + def __init__(self, gateway_base_url: Optional[str] = None, rollout_id_env: str = "EP_ROLLOUT_ID") -> None: + super().__init__() + self.gateway_base_url = gateway_base_url or os.getenv("FW_TRACING_GATEWAY_BASE_URL") + self.rollout_id_env = rollout_id_env + self._session = requests.Session() + self._lock = threading.Lock() + + def emit(self, record: logging.LogRecord) -> None: + try: + if not self.gateway_base_url: + return + rollout_id = self._get_rollout_id(record) + if not rollout_id: + return + payload = self._build_payload(record, rollout_id) + url = f"{self.gateway_base_url.rstrip('/')}/logs" + with self._lock: + self._session.post(url, json=payload, timeout=5) + except Exception: + # Avoid raising exceptions from logging + self.handleError(record) + + def _get_rollout_id(self, record: logging.LogRecord) -> Optional[str]: + if hasattr(record, "rollout_id") and cast(Any, getattr(record, "rollout_id")) is not None: + return str(cast(Any, getattr(record, "rollout_id"))) + return os.getenv(self.rollout_id_env) + + def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str, Any]: + timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ") + message = record.getMessage() + tags: List[str] = [f"rollout_id:{rollout_id}"] + # Optional additional tags + if hasattr(record, "experiment_id") and cast(Any, getattr(record, "experiment_id")): + tags.append(f"experiment_id:{cast(Any, getattr(record, 'experiment_id'))}") + if hasattr(record, "run_id") and cast(Any, getattr(record, "run_id")): + tags.append(f"run_id:{cast(Any, getattr(record, 'run_id'))}") + program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol" + status_val = cast(Any, getattr(record, "status", None)) + status = status_val if isinstance(status_val, str) else None + return { + "program": program, + "status": status, + "message": message, + "tags": tags, + "metadata": cast(Any, getattr(record, "metadata", None)), + "extras": { + "logger_name": record.name, + "level": record.levelname, + "timestamp": timestamp, + }, + } diff --git a/scripts/validate_dev_tracing.py b/scripts/validate_dev_tracing.py new file mode 100644 index 00000000..cb8ab3c0 --- /dev/null +++ b/scripts/validate_dev_tracing.py @@ -0,0 +1,58 @@ +import os +import time +import logging +import requests + +from eval_protocol import FireworksTracingHttpHandler + + +def main(): + gateway = os.getenv("FW_TRACING_GATEWAY_BASE_URL") + if not gateway: + # default to deployed dev gateway + gateway = "https://metadata-gateway-dev-644257448872.us-central1.run.app" + rollout_id = os.getenv("EP_ROLLOUT_ID", f"sdk-dev-{int(time.time())}") + + root = logging.getLogger() + root.setLevel(logging.INFO) + root.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway)) + + logger = logging.getLogger("eval_protocol.sdk.validate") + + logger.info( + "SDK sending structured log to dev gateway", + extra={ + "rollout_id": rollout_id, + "program": "eval_protocol", + "status": "completed", + "experiment_id": "dev-exp", + "run_id": "dev-run", + "metadata": {"source": "sdk-validate"}, + }, + ) + + # Poll fetch with retries for indexing + params = { + "tags": [f"rollout_id:{rollout_id}"], + "program": "eval_protocol", + "limit": 10, + "hours_back": 1, + } + total = 0 + for _ in range(20): + r = requests.get(f"{gateway}/logs", params=params, timeout=30) + r.raise_for_status() + data = r.json() + total = int(data.get("total_entries") or 0) + if total > 0: + print("Fetched entries:", total) + for e in data.get("entries", []): + print({k: e.get(k) for k in ["timestamp", "severity", "program", "status", "message", "tags"]}) + break + time.sleep(3) + if total == 0: + print("Fetched entries: 0 (after retries)") + + +if __name__ == "__main__": + main()