diff --git a/eval_protocol/log_utils/rollout_context.py b/eval_protocol/log_utils/rollout_context.py index a7b5dcc2..a5e062bf 100644 --- a/eval_protocol/log_utils/rollout_context.py +++ b/eval_protocol/log_utils/rollout_context.py @@ -1,4 +1,5 @@ import logging +import os from contextlib import asynccontextmanager from typing import List, Optional @@ -25,6 +26,9 @@ class ContextRolloutIdFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: # type: ignore[override] rollout_id = current_rollout_id.get() + if not rollout_id: + # Allow explicit rollout IDs on the record or via environment fallback. + rollout_id = getattr(record, "rollout_id", None) or os.getenv("EP_ROLLOUT_ID") if not rollout_id: # No correlation context → do not emit to external sink return False diff --git a/tests/logging/test_rollout_context_logging.py b/tests/logging/test_rollout_context_logging.py index 1c8fa37a..4be0d4b6 100644 --- a/tests/logging/test_rollout_context_logging.py +++ b/tests/logging/test_rollout_context_logging.py @@ -1,16 +1,64 @@ import asyncio import logging +import sys from typing import Any, Dict, List +import importlib.util +from pathlib import Path + import pytest -from eval_protocol.log_utils.fireworks_tracing_http_handler import ( - FireworksTracingHttpHandler, -) -from eval_protocol.log_utils.rollout_context import ( - ContextRolloutIdFilter, - rollout_logging_context, + +def _load_module(module_name: str, relative_path: str): + root = Path(__file__).resolve().parents[2] + spec = importlib.util.spec_from_file_location(module_name, root / relative_path) + if spec is None or spec.loader is None: # pragma: no cover - defensive + raise ImportError(f"Unable to load module {module_name} from {relative_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +FireworksTracingHttpHandler = _load_module( + "eval_protocol.log_utils.fireworks_tracing_http_handler", + "eval_protocol/log_utils/fireworks_tracing_http_handler.py", +).FireworksTracingHttpHandler + +_rollout_context_module = _load_module( + "eval_protocol.log_utils.rollout_context", "eval_protocol/log_utils/rollout_context.py" ) +ContextRolloutIdFilter = _rollout_context_module.ContextRolloutIdFilter +rollout_logging_context = _rollout_context_module.rollout_logging_context + + +def _make_record(message: str = "msg") -> logging.LogRecord: + return logging.LogRecord( + name="test", level=logging.INFO, pathname=__file__, lineno=0, msg=message, args=(), exc_info=None + ) + + +def test_context_filter_respects_explicit_rollout_id() -> None: + record = _make_record() + record.rollout_id = "explicit-rid" + + filt = ContextRolloutIdFilter() + + assert filt.filter(record) + assert record.rollout_id == "explicit-rid" + + +def test_context_filter_respects_environment_rollout_id(monkeypatch) -> None: + monkeypatch.setenv("EP_ROLLOUT_ID", "env-rid") + record = _make_record() + + filt = ContextRolloutIdFilter() + + try: + assert filt.filter(record) + assert record.rollout_id == "env-rid" + finally: + monkeypatch.delenv("EP_ROLLOUT_ID", raising=False) @pytest.mark.asyncio