|
1 | 1 | import asyncio |
2 | 2 | import logging |
| 3 | +import sys |
3 | 4 | from typing import Any, Dict, List |
4 | 5 |
|
| 6 | +import importlib.util |
| 7 | +from pathlib import Path |
| 8 | + |
5 | 9 | import pytest |
6 | 10 |
|
7 | | -from eval_protocol.log_utils.fireworks_tracing_http_handler import ( |
8 | | - FireworksTracingHttpHandler, |
9 | | -) |
10 | | -from eval_protocol.log_utils.rollout_context import ( |
11 | | - ContextRolloutIdFilter, |
12 | | - rollout_logging_context, |
| 11 | + |
| 12 | +def _load_module(module_name: str, relative_path: str): |
| 13 | + root = Path(__file__).resolve().parents[2] |
| 14 | + spec = importlib.util.spec_from_file_location(module_name, root / relative_path) |
| 15 | + if spec is None or spec.loader is None: # pragma: no cover - defensive |
| 16 | + raise ImportError(f"Unable to load module {module_name} from {relative_path}") |
| 17 | + module = importlib.util.module_from_spec(spec) |
| 18 | + sys.modules[module_name] = module |
| 19 | + spec.loader.exec_module(module) |
| 20 | + return module |
| 21 | + |
| 22 | + |
| 23 | +FireworksTracingHttpHandler = _load_module( |
| 24 | + "eval_protocol.log_utils.fireworks_tracing_http_handler", |
| 25 | + "eval_protocol/log_utils/fireworks_tracing_http_handler.py", |
| 26 | +).FireworksTracingHttpHandler |
| 27 | + |
| 28 | +_rollout_context_module = _load_module( |
| 29 | + "eval_protocol.log_utils.rollout_context", "eval_protocol/log_utils/rollout_context.py" |
13 | 30 | ) |
| 31 | +ContextRolloutIdFilter = _rollout_context_module.ContextRolloutIdFilter |
| 32 | +rollout_logging_context = _rollout_context_module.rollout_logging_context |
| 33 | + |
| 34 | + |
| 35 | +def _make_record(message: str = "msg") -> logging.LogRecord: |
| 36 | + return logging.LogRecord( |
| 37 | + name="test", level=logging.INFO, pathname=__file__, lineno=0, msg=message, args=(), exc_info=None |
| 38 | + ) |
| 39 | + |
| 40 | + |
| 41 | +def test_context_filter_respects_explicit_rollout_id() -> None: |
| 42 | + record = _make_record() |
| 43 | + record.rollout_id = "explicit-rid" |
| 44 | + |
| 45 | + filt = ContextRolloutIdFilter() |
| 46 | + |
| 47 | + assert filt.filter(record) |
| 48 | + assert record.rollout_id == "explicit-rid" |
| 49 | + |
| 50 | + |
| 51 | +def test_context_filter_respects_environment_rollout_id(monkeypatch) -> None: |
| 52 | + monkeypatch.setenv("EP_ROLLOUT_ID", "env-rid") |
| 53 | + record = _make_record() |
| 54 | + |
| 55 | + filt = ContextRolloutIdFilter() |
| 56 | + |
| 57 | + try: |
| 58 | + assert filt.filter(record) |
| 59 | + assert record.rollout_id == "env-rid" |
| 60 | + finally: |
| 61 | + monkeypatch.delenv("EP_ROLLOUT_ID", raising=False) |
14 | 62 |
|
15 | 63 |
|
16 | 64 | @pytest.mark.asyncio |
|
0 commit comments