Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions eval_protocol/log_utils/elasticsearch_direct_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ def emit(self, record: logging.LogRecord) -> None:
if status_info:
data.update(status_info)

# Optional correlation enrichment
experiment_id = getattr(record, "experiment_id", None)
if experiment_id is not None:
data["experiment_id"] = experiment_id
run_id = getattr(record, "run_id", None)
if run_id is not None:
data["run_id"] = run_id
rollout_ids = getattr(record, "rollout_ids", None)
if rollout_ids is not None:
data["rollout_ids"] = rollout_ids

# Schedule the HTTP request to run asynchronously
self._schedule_async_send(data, record)
except Exception as e:
Expand Down
23 changes: 22 additions & 1 deletion eval_protocol/log_utils/fireworks_tracing_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,36 @@ def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str
tags.append(f"experiment_id:{cast(Any, getattr(record, 'experiment_id'))}")
if hasattr(record, "run_id") and cast(Any, getattr(record, "run_id")):
tags.append(f"run_id:{cast(Any, getattr(record, 'run_id'))}")
# Groupwise list of rollout_ids
if hasattr(record, "rollout_ids") and cast(Any, getattr(record, "rollout_ids")):
try:
for rid in cast(List[str], getattr(record, "rollout_ids")):
tags.append(f"rollout_id:{rid}")
except Exception:
pass
program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
status_val = cast(Any, getattr(record, "status", None))
status = status_val if isinstance(status_val, str) else None
# Capture optional structured status fields if present
metadata: Dict[str, Any] = {}
status_code = cast(Any, getattr(record, "status_code", None))
if isinstance(status_code, int):
metadata["status_code"] = status_code
status_message = cast(Any, getattr(record, "status_message", None))
if isinstance(status_message, str):
metadata["status_message"] = status_message
status_details = getattr(record, "status_details", None)
if status_details is not None:
metadata["status_details"] = status_details
extra_metadata = cast(Any, getattr(record, "metadata", None))
if isinstance(extra_metadata, dict):
metadata.update(extra_metadata)
return {
"program": program,
"status": status,
"message": message,
"tags": tags,
"metadata": cast(Any, getattr(record, "metadata", None)),
"metadata": metadata or None,
"extras": {
"logger_name": record.name,
"level": record.levelname,
Expand Down
61 changes: 61 additions & 0 deletions eval_protocol/log_utils/init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import logging
import os
from typing import Optional

from eval_protocol.log_utils.fireworks_tracing_http_handler import (
FireworksTracingHttpHandler,
)
from eval_protocol.log_utils.elasticsearch_direct_http_handler import (
ElasticsearchDirectHttpHandler,
)
from eval_protocol.log_utils.rollout_context import ContextRolloutIdFilter
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig


# Module-level guard that makes init_external_logging_from_env() idempotent:
# flipped to True after the first successful initialization.
_INITIALIZED = False


def _get_env(name: str) -> Optional[str]:
val = os.getenv(name)
return val if val and val.strip() else None


def init_external_logging_from_env() -> None:
    """
    Initialize external logging sinks (Fireworks tracing, optional Elasticsearch) from env vars.

    Idempotent: safe to call multiple times.

    Environment variables:
    - FW_TRACING_GATEWAY_BASE_URL: enable Fireworks tracing handler when set
    - EP_ELASTICSEARCH_URL, EP_ELASTICSEARCH_API_KEY, EP_ELASTICSEARCH_INDEX: enable ES when all set
    """
    global _INITIALIZED
    if _INITIALIZED:
        return

    root = logging.getLogger()
    # Guard against double-registration when handlers were attached elsewhere.
    attached = {type(h).__name__ for h in root.handlers}

    def _attach(handler: logging.Handler) -> None:
        # Every external sink emits at INFO+ and ships only records that
        # carry rollout correlation context (the filter drops the rest).
        handler.setLevel(logging.INFO)
        handler.addFilter(ContextRolloutIdFilter())
        root.addHandler(handler)

    # Fireworks tracing sink
    tracing_url = _get_env("FW_TRACING_GATEWAY_BASE_URL")
    if tracing_url and "FireworksTracingHttpHandler" not in attached:
        _attach(FireworksTracingHttpHandler(gateway_base_url=tracing_url))

    # Elasticsearch sink — requires all three settings to be present.
    es_url = _get_env("EP_ELASTICSEARCH_URL")
    es_api_key = _get_env("EP_ELASTICSEARCH_API_KEY")
    es_index = _get_env("EP_ELASTICSEARCH_INDEX")
    if (
        es_url
        and es_api_key
        and es_index
        and "ElasticsearchDirectHttpHandler" not in attached
    ):
        config = ElasticsearchConfig(url=es_url, api_key=es_api_key, index_name=es_index)
        _attach(ElasticsearchDirectHttpHandler(elasticsearch_config=config))

    _INITIALIZED = True
84 changes: 84 additions & 0 deletions eval_protocol/log_utils/rollout_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
import os
from contextlib import asynccontextmanager
from typing import List, Optional

import contextvars


# Context variables used to correlate logs with rollouts under concurrency
current_rollout_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_rollout_id", default=None)
current_rollout_ids: contextvars.ContextVar[Optional[List[str]]] = contextvars.ContextVar(
    "ep_rollout_ids", default=None
)
current_experiment_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_experiment_id", default=None)
current_run_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_run_id", default=None)


class ContextRolloutIdFilter(logging.Filter):
    """
    Logging filter that copies correlation fields from ContextVars onto a LogRecord.

    Attach this ONLY to external sink handlers (e.g., Fireworks or
    Elasticsearch). Records with no active rollout context are rejected so
    that uncorrelated logs never reach an external sink.
    """

    def filter(self, record: logging.LogRecord) -> bool:  # type: ignore[override]
        # Resolve the primary rollout id: ContextVar first, then an explicit
        # attribute on the record, then the EP_ROLLOUT_ID env fallback.
        resolved = (
            current_rollout_id.get()
            or getattr(record, "rollout_id", None)
            or os.getenv("EP_ROLLOUT_ID")
        )
        if not resolved:
            # No correlation context → suppress emission for this handler.
            return False

        record.rollout_id = resolved

        # Copy any secondary correlation fields that are set in context.
        group_ids = current_rollout_ids.get()
        if group_ids:
            record.rollout_ids = group_ids

        exp_id = current_experiment_id.get()
        if exp_id:
            record.experiment_id = exp_id

        active_run = current_run_id.get()
        if active_run:
            record.run_id = active_run

        return True


@asynccontextmanager
async def rollout_logging_context(
    rollout_id: str,
    *,
    experiment_id: Optional[str] = None,
    run_id: Optional[str] = None,
    rollout_ids: Optional[List[str]] = None,
):
    """
    Async context manager to set correlation ContextVars for the current task.

    Args:
        rollout_id: Primary rollout identifier for correlation.
        experiment_id: Optional experiment ID for tagging.
        run_id: Optional run ID for tagging.
        rollout_ids: Optional list of related rollout IDs (e.g., groupwise mode).
    """
    # Set ALL four vars unconditionally — including when the value is None.
    # Previously a None argument left the var untouched, so a nested context
    # inherited stale correlation values (e.g. an outer group's rollout_ids)
    # from the enclosing context. Resetting every token in `finally` restores
    # the outer state exactly on exit.
    tokens = (
        current_rollout_id.set(rollout_id),
        current_rollout_ids.set(rollout_ids),
        current_experiment_id.set(experiment_id),
        current_run_id.set(run_id),
    )
    try:
        yield
    finally:
        current_rollout_id.reset(tokens[0])
        current_rollout_ids.reset(tokens[1])
        current_experiment_id.reset(tokens[2])
        current_run_id.reset(tokens[3])
62 changes: 48 additions & 14 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
run_tasks_with_run_progress,
)
from eval_protocol.utils.show_results_url import store_local_ui_results_url, generate_invocation_filter_url
from eval_protocol.log_utils.init import init_external_logging_from_env
from eval_protocol.log_utils.rollout_context import rollout_logging_context
from eval_protocol.utils.browser_utils import is_logs_server_running, open_browser_tab

from ..common_utils import load_jsonl
Expand Down Expand Up @@ -254,6 +256,9 @@ def create_wrapper_with_signature() -> Callable[[], None]:
async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
nonlocal browser_opened_for_invocation

# Initialize external logging sinks (Fireworks/ES) from env (idempotent)
init_external_logging_from_env()

# Store URL for viewing results (after all postprocessing is complete)
store_local_ui_results_url(invocation_id)

Expand Down Expand Up @@ -419,11 +424,16 @@ async def _execute_pointwise_eval_with_semaphore(
) -> EvaluationRow:
async with semaphore:
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
result = await execute_pytest(
test_func,
processed_row=row,
evaluation_test_kwargs=evaluation_test_kwargs,
)
async with rollout_logging_context(
row.execution_metadata.rollout_id or "",
experiment_id=experiment_id,
run_id=run_id,
):
result = await execute_pytest(
test_func,
processed_row=row,
evaluation_test_kwargs=evaluation_test_kwargs,
)
if not isinstance(result, EvaluationRow):
raise ValueError(
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
Expand All @@ -435,11 +445,21 @@ async def _execute_groupwise_eval_with_semaphore(
) -> list[EvaluationRow]:
async with semaphore:
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
results = await execute_pytest(
test_func,
processed_dataset=rows,
evaluation_test_kwargs=evaluation_test_kwargs,
)
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
group_rollout_ids = [
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
]
async with rollout_logging_context(
primary_rollout_id or "",
experiment_id=experiment_id,
run_id=run_id,
rollout_ids=group_rollout_ids or None,
):
results = await execute_pytest(
test_func,
processed_dataset=rows,
evaluation_test_kwargs=evaluation_test_kwargs,
)
if not isinstance(results, list):
raise ValueError(
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
Expand Down Expand Up @@ -516,11 +536,25 @@ async def _collect_result(config, lst):
input_dataset.append(row)
# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
results = await execute_pytest(
test_func,
processed_dataset=input_dataset,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
primary_rollout_id = (
input_dataset[0].execution_metadata.rollout_id if input_dataset else None
)
group_rollout_ids = [
r.execution_metadata.rollout_id
for r in input_dataset
if r.execution_metadata.rollout_id
]
async with rollout_logging_context(
primary_rollout_id or "",
experiment_id=experiment_id,
run_id=run_id,
rollout_ids=group_rollout_ids or None,
):
results = await execute_pytest(
test_func,
processed_dataset=input_dataset,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if (
results is None
or not isinstance(results, list)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Awaitable, Callable
import logging
import os
from typing_extensions import cast
from pydantic_ai import Agent
Expand Down Expand Up @@ -36,7 +37,7 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent:
input_rows=[collect_dataset()],
completion_params=[
{
"model": "gpt-5",
"model": "gpt-5-nano",
},
],
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
Expand All @@ -45,6 +46,19 @@ async def test_pydantic_complex_queries_responses(row: EvaluationRow) -> Evaluat
"""
Evaluation of complex queries for the Chinook database using PydanticAI
"""

logger = logging.getLogger("tests.chinook.pydantic.complex_queries_responses")
logger.info(
"Starting chinook responses evaluation",
extra={"status": {"code": 101, "message": "RUNNING"}},
)

casted_evaluation_test = cast(Callable[[EvaluationRow], Awaitable[EvaluationRow]], eval)
evaluated_row = await casted_evaluation_test(row)

logger.info(
"Finished chinook responses evaluation",
extra={"status": {"code": 100, "message": "FINISHED"}},
)

return evaluated_row
Loading
Loading