diff --git a/eval_protocol/log_utils/elasticsearch_direct_http_handler.py b/eval_protocol/log_utils/elasticsearch_direct_http_handler.py index 735869b0..af14268b 100644 --- a/eval_protocol/log_utils/elasticsearch_direct_http_handler.py +++ b/eval_protocol/log_utils/elasticsearch_direct_http_handler.py @@ -60,6 +60,17 @@ def emit(self, record: logging.LogRecord) -> None: if status_info: data.update(status_info) + # Optional correlation enrichment + experiment_id = getattr(record, "experiment_id", None) + if experiment_id is not None: + data["experiment_id"] = experiment_id + run_id = getattr(record, "run_id", None) + if run_id is not None: + data["run_id"] = run_id + rollout_ids = getattr(record, "rollout_ids", None) + if rollout_ids is not None: + data["rollout_ids"] = rollout_ids + # Schedule the HTTP request to run asynchronously self._schedule_async_send(data, record) except Exception as e: diff --git a/eval_protocol/log_utils/fireworks_tracing_http_handler.py b/eval_protocol/log_utils/fireworks_tracing_http_handler.py index 0e8dfdf5..df53a921 100644 --- a/eval_protocol/log_utils/fireworks_tracing_http_handler.py +++ b/eval_protocol/log_utils/fireworks_tracing_http_handler.py @@ -46,15 +46,36 @@ def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str tags.append(f"experiment_id:{cast(Any, getattr(record, 'experiment_id'))}") if hasattr(record, "run_id") and cast(Any, getattr(record, "run_id")): tags.append(f"run_id:{cast(Any, getattr(record, 'run_id'))}") + # Groupwise list of rollout_ids + if hasattr(record, "rollout_ids") and cast(Any, getattr(record, "rollout_ids")): + try: + for rid in cast(List[str], getattr(record, "rollout_ids")): + tags.append(f"rollout_id:{rid}") + except Exception: + pass program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol" status_val = cast(Any, getattr(record, "status", None)) status = status_val if isinstance(status_val, str) else None + # Capture optional structured status fields if present + metadata: Dict[str, Any] = {} + status_code = cast(Any, getattr(record, "status_code", None)) + if isinstance(status_code, int): + metadata["status_code"] = status_code + status_message = cast(Any, getattr(record, "status_message", None)) + if isinstance(status_message, str): + metadata["status_message"] = status_message + status_details = getattr(record, "status_details", None) + if status_details is not None: + metadata["status_details"] = status_details + extra_metadata = cast(Any, getattr(record, "metadata", None)) + if isinstance(extra_metadata, dict): + metadata.update(extra_metadata) return { "program": program, "status": status, "message": message, "tags": tags, - "metadata": cast(Any, getattr(record, "metadata", None)), + "metadata": metadata or None, "extras": { "logger_name": record.name, "level": record.levelname, diff --git a/eval_protocol/log_utils/init.py b/eval_protocol/log_utils/init.py new file mode 100644 index 00000000..1699b744 --- /dev/null +++ b/eval_protocol/log_utils/init.py @@ -0,0 +1,61 @@ +import logging +import os +from typing import Optional + +from eval_protocol.log_utils.fireworks_tracing_http_handler import ( + FireworksTracingHttpHandler, +) +from eval_protocol.log_utils.elasticsearch_direct_http_handler import ( + ElasticsearchDirectHttpHandler, +) +from eval_protocol.log_utils.rollout_context import ContextRolloutIdFilter +from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig + + +_INITIALIZED = False + + +def _get_env(name: str) -> Optional[str]: + val = os.getenv(name) + return val if val and val.strip() else None + + +def init_external_logging_from_env() -> None: + """ + Initialize external logging sinks (Fireworks tracing, optional Elasticsearch) from env vars. + + Idempotent: safe to call multiple times. + + Environment variables: + - FW_TRACING_GATEWAY_BASE_URL: enable Fireworks tracing handler when set + - EP_ELASTICSEARCH_URL, EP_ELASTICSEARCH_API_KEY, EP_ELASTICSEARCH_INDEX: enable ES when all set + """ + global _INITIALIZED + if _INITIALIZED: + return + + root_logger = logging.getLogger() + + # Ensure we do not add duplicate handlers if already present + existing_handler_types = {type(h).__name__ for h in root_logger.handlers} + + # Fireworks tracing + fw_url = _get_env("FW_TRACING_GATEWAY_BASE_URL") + if fw_url and "FireworksTracingHttpHandler" not in existing_handler_types: + fw_handler = FireworksTracingHttpHandler(gateway_base_url=fw_url) + fw_handler.setLevel(logging.INFO) + fw_handler.addFilter(ContextRolloutIdFilter()) + root_logger.addHandler(fw_handler) + + # Elasticsearch + es_url = _get_env("EP_ELASTICSEARCH_URL") + es_api_key = _get_env("EP_ELASTICSEARCH_API_KEY") + es_index = _get_env("EP_ELASTICSEARCH_INDEX") + if es_url and es_api_key and es_index and "ElasticsearchDirectHttpHandler" not in existing_handler_types: + es_config = ElasticsearchConfig(url=es_url, api_key=es_api_key, index_name=es_index) + es_handler = ElasticsearchDirectHttpHandler(elasticsearch_config=es_config) + es_handler.setLevel(logging.INFO) + es_handler.addFilter(ContextRolloutIdFilter()) + root_logger.addHandler(es_handler) + + _INITIALIZED = True diff --git a/eval_protocol/log_utils/rollout_context.py b/eval_protocol/log_utils/rollout_context.py new file mode 100644 index 00000000..a5e062bf --- /dev/null +++ b/eval_protocol/log_utils/rollout_context.py @@ -0,0 +1,84 @@ +import logging +import os +from contextlib import asynccontextmanager +from typing import List, Optional + +import contextvars + + +# Context variables used to correlate logs with rollouts under concurrency +current_rollout_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_rollout_id", default=None) +current_rollout_ids: contextvars.ContextVar[Optional[List[str]]] = contextvars.ContextVar( + "ep_rollout_ids", default=None +) +current_experiment_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_experiment_id", default=None) +current_run_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_run_id", default=None) + + +class ContextRolloutIdFilter(logging.Filter): + """ + Logging filter that injects correlation fields into a LogRecord from ContextVars. + + The filter is intended to be attached ONLY to external sink handlers (e.g., + Fireworks or Elasticsearch). If there is no active rollout context, it drops + the record for that handler to avoid shipping uncorrelated logs. + """ + + def filter(self, record: logging.LogRecord) -> bool: # type: ignore[override] + rollout_id = current_rollout_id.get() + if not rollout_id: + # Allow explicit rollout IDs on the record or via environment fallback. + rollout_id = getattr(record, "rollout_id", None) or os.getenv("EP_ROLLOUT_ID") + if not rollout_id: + # No correlation context → do not emit to external sink + return False + + # Inject primary correlation fields + setattr(record, "rollout_id", rollout_id) + + rollout_ids = current_rollout_ids.get() + if rollout_ids: + setattr(record, "rollout_ids", rollout_ids) + + experiment_id = current_experiment_id.get() + if experiment_id: + setattr(record, "experiment_id", experiment_id) + + run_id = current_run_id.get() + if run_id: + setattr(record, "run_id", run_id) + + return True + + +@asynccontextmanager +async def rollout_logging_context( + rollout_id: str, + *, + experiment_id: Optional[str] = None, + run_id: Optional[str] = None, + rollout_ids: Optional[List[str]] = None, +): + """ + Async context manager to set correlation ContextVars for the current task. + + Args: + rollout_id: Primary rollout identifier for correlation. + experiment_id: Optional experiment ID for tagging. + run_id: Optional run ID for tagging. + rollout_ids: Optional list of related rollout IDs (e.g., groupwise mode). + """ + t_rollout = current_rollout_id.set(rollout_id) + t_rollouts = current_rollout_ids.set(rollout_ids) if rollout_ids is not None else None + t_experiment = current_experiment_id.set(experiment_id) if experiment_id is not None else None + t_run = current_run_id.set(run_id) if run_id is not None else None + try: + yield + finally: + current_rollout_id.reset(t_rollout) + if t_rollouts is not None: + current_rollout_ids.reset(t_rollouts) + if t_experiment is not None: + current_experiment_id.reset(t_experiment) + if t_run is not None: + current_run_id.reset(t_run) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 57f36a9f..26239013 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -63,6 +63,8 @@ run_tasks_with_run_progress, ) from eval_protocol.utils.show_results_url import store_local_ui_results_url, generate_invocation_filter_url +from eval_protocol.log_utils.init import init_external_logging_from_env +from eval_protocol.log_utils.rollout_context import rollout_logging_context from eval_protocol.utils.browser_utils import is_logs_server_running, open_browser_tab from ..common_utils import load_jsonl @@ -254,6 +256,9 @@ def create_wrapper_with_signature() -> Callable[[], None]: async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None: nonlocal browser_opened_for_invocation + # Initialize external logging sinks (Fireworks/ES) from env (idempotent) + init_external_logging_from_env() + # Store URL for viewing results (after all postprocessing is complete) store_local_ui_results_url(invocation_id) @@ -419,11 +424,16 @@ async def _execute_pointwise_eval_with_semaphore( ) -> EvaluationRow: async with semaphore: evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} - result = await execute_pytest( - test_func, - processed_row=row, - evaluation_test_kwargs=evaluation_test_kwargs, - ) + async with rollout_logging_context( + row.execution_metadata.rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + ): + result = await execute_pytest( + test_func, + processed_row=row, + evaluation_test_kwargs=evaluation_test_kwargs, + ) if not isinstance(result, EvaluationRow): raise ValueError( f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." @@ -435,11 +445,21 @@ async def _execute_groupwise_eval_with_semaphore( ) -> list[EvaluationRow]: async with semaphore: evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} - results = await execute_pytest( - test_func, - processed_dataset=rows, - evaluation_test_kwargs=evaluation_test_kwargs, - ) + primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None + group_rollout_ids = [ + r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + ] + async with rollout_logging_context( + primary_rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + rollout_ids=group_rollout_ids or None, + ): + results = await execute_pytest( + test_func, + processed_dataset=rows, + evaluation_test_kwargs=evaluation_test_kwargs, + ) if not isinstance(results, list): raise ValueError( f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." @@ -516,11 +536,25 @@ async def _collect_result(config, lst): input_dataset.append(row) # NOTE: we will still evaluate errored rows (give users control over this) # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func - results = await execute_pytest( - test_func, - processed_dataset=input_dataset, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + primary_rollout_id = ( + input_dataset[0].execution_metadata.rollout_id if input_dataset else None ) + group_rollout_ids = [ + r.execution_metadata.rollout_id + for r in input_dataset + if r.execution_metadata.rollout_id + ] + async with rollout_logging_context( + primary_rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + rollout_ids=group_rollout_ids or None, + ): + results = await execute_pytest( + test_func, + processed_dataset=input_dataset, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) if ( results is None or not isinstance(results, list) diff --git a/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py b/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py index 5ed1e377..2bb4b622 100644 --- a/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py +++ b/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py @@ -1,4 +1,5 @@ from collections.abc import Awaitable, Callable +import logging import os from typing_extensions import cast from pydantic_ai import Agent @@ -36,7 +37,7 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent: input_rows=[collect_dataset()], completion_params=[ { - "model": "gpt-5", + "model": "gpt-5-nano", }, ], rollout_processor=PydanticAgentRolloutProcessor(agent_factory), @@ -45,6 +46,19 @@ async def test_pydantic_complex_queries_responses(row: EvaluationRow) -> Evaluat """ Evaluation of complex queries for the Chinook database using PydanticAI """ + + logger = logging.getLogger("tests.chinook.pydantic.complex_queries_responses") + logger.info( + "Starting chinook responses evaluation", + extra={"status": {"code": 101, "message": "RUNNING"}}, + ) + casted_evaluation_test = cast(Callable[[EvaluationRow], Awaitable[EvaluationRow]], eval) evaluated_row = await casted_evaluation_test(row) + + logger.info( + "Finished chinook responses evaluation", + extra={"status": {"code": 100, "message": "FINISHED"}}, + ) + return evaluated_row diff --git a/tests/logging/test_rollout_context_logging.py b/tests/logging/test_rollout_context_logging.py new file mode 100644 index 00000000..4be0d4b6 --- /dev/null +++ b/tests/logging/test_rollout_context_logging.py @@ -0,0 +1,146 @@ +import asyncio +import logging +import sys +from typing import Any, Dict, List + +import importlib.util +from pathlib import Path + +import pytest + + +def _load_module(module_name: str, relative_path: str): + root = Path(__file__).resolve().parents[2] + spec = importlib.util.spec_from_file_location(module_name, root / relative_path) + if spec is None or spec.loader is None: # pragma: no cover - defensive + raise ImportError(f"Unable to load module {module_name} from {relative_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +FireworksTracingHttpHandler = _load_module( + "eval_protocol.log_utils.fireworks_tracing_http_handler", + "eval_protocol/log_utils/fireworks_tracing_http_handler.py", +).FireworksTracingHttpHandler + +_rollout_context_module = _load_module( + "eval_protocol.log_utils.rollout_context", "eval_protocol/log_utils/rollout_context.py" +) +ContextRolloutIdFilter = _rollout_context_module.ContextRolloutIdFilter +rollout_logging_context = _rollout_context_module.rollout_logging_context + + +def _make_record(message: str = "msg") -> logging.LogRecord: + return logging.LogRecord( + name="test", level=logging.INFO, pathname=__file__, lineno=0, msg=message, args=(), exc_info=None + ) + + +def test_context_filter_respects_explicit_rollout_id() -> None: + record = _make_record() + record.rollout_id = "explicit-rid" + + filt = ContextRolloutIdFilter() + + assert filt.filter(record) + assert record.rollout_id == "explicit-rid" + + +def test_context_filter_respects_environment_rollout_id(monkeypatch) -> None: + monkeypatch.setenv("EP_ROLLOUT_ID", "env-rid") + record = _make_record() + + filt = ContextRolloutIdFilter() + + try: + assert filt.filter(record) + assert record.rollout_id == "env-rid" + finally: + monkeypatch.delenv("EP_ROLLOUT_ID", raising=False) + + +@pytest.mark.asyncio +async def test_context_filter_correlates_concurrent_logs() -> None: + logger = logging.getLogger("ep.test.fireworks") + logger.setLevel(logging.INFO) + + handler = FireworksTracingHttpHandler(gateway_base_url="http://localhost:1") + captured: List[Dict[str, Any]] = [] + + # Monkeypatch the requests call used by the handler + def fake_post(url: str, json: Dict[str, Any], timeout: int) -> Any: # type: ignore[override] + captured.append(json) + + class _Resp: + status_code = 200 + + return _Resp() + + handler._session.post = fake_post # type: ignore[attr-defined] + handler.addFilter(ContextRolloutIdFilter()) + logger.addHandler(handler) + + try: + + async def _emit(rollout_id: str, message_prefix: str) -> None: + async with rollout_logging_context(rollout_id, experiment_id="exp", run_id="run"): + logger.info(f"{message_prefix}-1") + await asyncio.sleep(0) + logger.info(f"{message_prefix}-2") + await asyncio.sleep(0) + logger.info(f"{message_prefix}-3") + + await asyncio.gather( + _emit("rid-A", "A"), + _emit("rid-B", "B"), + ) + + # We expect 6 captured payloads + assert len(captured) == 6 + + # Ensure each payload includes the correct rollout tag and message + tags_sets = [set(entry.get("tags", [])) for entry in captured] + messages = [entry.get("message", "") for entry in captured] + + assert any("rollout_id:rid-A" in tags for tags in tags_sets) + assert any("rollout_id:rid-B" in tags for tags in tags_sets) + assert any(msg.startswith("A-") for msg in messages) + assert any(msg.startswith("B-") for msg in messages) + finally: + logger.removeHandler(handler) + + +@pytest.mark.asyncio +async def test_context_filter_groupwise_rollout_ids_tagged() -> None: + logger = logging.getLogger("ep.test.fireworks.group") + logger.setLevel(logging.INFO) + + handler = FireworksTracingHttpHandler(gateway_base_url="http://localhost:1") + captured: List[Dict[str, Any]] = [] + + def fake_post(url: str, json: Dict[str, Any], timeout: int) -> Any: # type: ignore[override] + captured.append(json) + + class _Resp: + status_code = 200 + + return _Resp() + + handler._session.post = fake_post # type: ignore[attr-defined] + handler.addFilter(ContextRolloutIdFilter()) + logger.addHandler(handler) + + try: + group_ids = ["rid-1", "rid-2", "rid-3"] + async with rollout_logging_context(group_ids[0], experiment_id="exp2", run_id="run2", rollout_ids=group_ids): + logger.info("group-message") + + assert len(captured) == 1 + tags = set(captured[0].get("tags", [])) + # All rollout_ids should be present as tags + for rid in group_ids: + assert f"rollout_id:{rid}" in tags + finally: + logger.removeHandler(handler)