Skip to content

Commit 7e76369

Browse files
committed
Fix rollout logging filter fallbacks
1 parent dd1853f commit 7e76369

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

eval_protocol/log_utils/rollout_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
from contextlib import asynccontextmanager
34
from typing import List, Optional
45

@@ -25,6 +26,9 @@ class ContextRolloutIdFilter(logging.Filter):
2526

2627
def filter(self, record: logging.LogRecord) -> bool: # type: ignore[override]
2728
rollout_id = current_rollout_id.get()
29+
if not rollout_id:
30+
# Allow explicit rollout IDs on the record or via environment fallback.
31+
rollout_id = getattr(record, "rollout_id", None) or os.getenv("EP_ROLLOUT_ID")
2832
if not rollout_id:
2933
# No correlation context → do not emit to external sink
3034
return False

tests/logging/test_rollout_context_logging.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,64 @@
11
import asyncio
22
import logging
3+
import sys
34
from typing import Any, Dict, List
45

6+
import importlib.util
7+
from pathlib import Path
8+
59
import pytest
610

7-
from eval_protocol.log_utils.fireworks_tracing_http_handler import (
8-
FireworksTracingHttpHandler,
9-
)
10-
from eval_protocol.log_utils.rollout_context import (
11-
ContextRolloutIdFilter,
12-
rollout_logging_context,
11+
12+
def _load_module(module_name: str, relative_path: str):
13+
root = Path(__file__).resolve().parents[2]
14+
spec = importlib.util.spec_from_file_location(module_name, root / relative_path)
15+
if spec is None or spec.loader is None: # pragma: no cover - defensive
16+
raise ImportError(f"Unable to load module {module_name} from {relative_path}")
17+
module = importlib.util.module_from_spec(spec)
18+
sys.modules[module_name] = module
19+
spec.loader.exec_module(module)
20+
return module
21+
22+
23+
FireworksTracingHttpHandler = _load_module(
24+
"eval_protocol.log_utils.fireworks_tracing_http_handler",
25+
"eval_protocol/log_utils/fireworks_tracing_http_handler.py",
26+
).FireworksTracingHttpHandler
27+
28+
_rollout_context_module = _load_module(
29+
"eval_protocol.log_utils.rollout_context", "eval_protocol/log_utils/rollout_context.py"
1330
)
31+
ContextRolloutIdFilter = _rollout_context_module.ContextRolloutIdFilter
32+
rollout_logging_context = _rollout_context_module.rollout_logging_context
33+
34+
35+
def _make_record(message: str = "msg") -> logging.LogRecord:
36+
return logging.LogRecord(
37+
name="test", level=logging.INFO, pathname=__file__, lineno=0, msg=message, args=(), exc_info=None
38+
)
39+
40+
41+
def test_context_filter_respects_explicit_rollout_id() -> None:
42+
record = _make_record()
43+
record.rollout_id = "explicit-rid"
44+
45+
filt = ContextRolloutIdFilter()
46+
47+
assert filt.filter(record)
48+
assert record.rollout_id == "explicit-rid"
49+
50+
51+
def test_context_filter_respects_environment_rollout_id(monkeypatch) -> None:
52+
monkeypatch.setenv("EP_ROLLOUT_ID", "env-rid")
53+
record = _make_record()
54+
55+
filt = ContextRolloutIdFilter()
56+
57+
try:
58+
assert filt.filter(record)
59+
assert record.rollout_id == "env-rid"
60+
finally:
61+
monkeypatch.delenv("EP_ROLLOUT_ID", raising=False)
1462

1563

1664
@pytest.mark.asyncio

0 commit comments

Comments
 (0)