Skip to content

Commit e1d7512

Browse files
committed
fix test
1 parent c9462e9 commit e1d7512

File tree

3 files changed

+142
-33
lines changed

3 files changed

+142
-33
lines changed

tests/github_actions/test_github_actions_rollout.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,33 @@
1212
from eval_protocol.models import EvaluationRow, InputMetadata
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor
15+
import eval_protocol.pytest.github_action_rollout_processor as github_action_rollout_processor_module
16+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
17+
18+
19+
ROLLOUT_IDS = set()
20+
21+
22+
@pytest.fixture(autouse=True)
23+
def check_rollout_coverage(monkeypatch):
24+
"""
25+
Ensure we attempted to fetch remote traces for each rollout.
26+
27+
This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
28+
and tracks rollout_ids passed through its DataLoaderConfig.
29+
"""
30+
global ROLLOUT_IDS
31+
ROLLOUT_IDS.clear()
32+
33+
original_loader = github_action_rollout_processor_module.default_fireworks_output_data_loader
34+
35+
def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
36+
ROLLOUT_IDS.add(config.rollout_id)
37+
return original_loader(config)
38+
39+
monkeypatch.setattr(github_action_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
40+
yield
41+
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
1542

1643

1744
def rows() -> List[EvaluationRow]:

tests/remote_server/test_remote_fireworks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,33 @@
1212
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
15+
import eval_protocol.pytest.remote_rollout_processor as remote_rollout_processor_module
16+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
17+
18+
19+
ROLLOUT_IDS = set()
20+
21+
22+
@pytest.fixture(autouse=True)
23+
def check_rollout_coverage(monkeypatch):
24+
"""
25+
Ensure we attempted to fetch remote traces for each rollout.
26+
27+
This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
28+
and tracks rollout_ids passed through its DataLoaderConfig.
29+
"""
30+
global ROLLOUT_IDS
31+
ROLLOUT_IDS.clear()
32+
33+
original_loader = remote_rollout_processor_module.default_fireworks_output_data_loader
34+
35+
def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
36+
ROLLOUT_IDS.add(config.rollout_id)
37+
return original_loader(config)
38+
39+
monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
40+
yield
41+
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
1542

1643

1744
def find_available_port() -> int:
Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,88 @@
1-
import os
2-
from typing import List
3-
4-
import pytest
5-
6-
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
7-
from eval_protocol.models import EvaluationRow, Message
8-
from eval_protocol.pytest import evaluation_test
9-
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
10-
11-
12-
def rows() -> List[EvaluationRow]:
13-
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
14-
return [row, row, row]
15-
16-
17-
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
18-
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
19-
@evaluation_test(
20-
data_loaders=DynamicDataLoader(
21-
generators=[rows],
22-
),
23-
rollout_processor=RemoteRolloutProcessor(remote_base_url="http://127.0.0.1:3000", timeout_seconds=30),
24-
)
25-
async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:
26-
"""
27-
End-to-end test:
28-
- trigger remote rollout via RemoteRolloutProcessor (calls init/status)
29-
"""
30-
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
31-
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
32-
33-
return row
1+
# NOTE: This test is deprecated. We no longer support custom output data loaders, including pulling from Langfuse. We can revisit this in the future.
2+
3+
# # MANUAL SERVER STARTUP REQUIRED:
4+
# #
5+
# # For Python server testing, start:
6+
# # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000)
7+
# #
8+
# # For TypeScript server testing, start:
9+
# # cd tests/remote_server/typescript-server
10+
# # npm install
11+
# # npm start
12+
# #
13+
# # The TypeScript server should be running on http://127.0.0.1:3000
14+
# # You only need to start one of the servers!
15+
16+
# import os
17+
# from typing import List
18+
19+
# import pytest
20+
21+
# from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
22+
# from eval_protocol.models import EvaluationRow, Message
23+
# from eval_protocol.pytest import evaluation_test
24+
# from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
25+
# from eval_protocol.adapters.langfuse import create_langfuse_adapter
26+
# from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
27+
# from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
28+
29+
# ROLLOUT_IDS = set()
30+
31+
32+
# @pytest.fixture(autouse=True)
33+
# def check_rollout_coverage():
34+
# """Ensure we processed all expected rollout_ids"""
35+
# global ROLLOUT_IDS
36+
# ROLLOUT_IDS.clear()
37+
# yield
38+
39+
# assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}"
40+
41+
42+
# def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
43+
# global ROLLOUT_IDS # Track all rollout_ids we've seen
44+
# ROLLOUT_IDS.add(config.rollout_id)
45+
46+
# adapter = create_langfuse_adapter()
47+
# return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
48+
49+
50+
# def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
51+
# return DynamicDataLoader(
52+
# generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation
53+
# )
54+
55+
56+
# def rows() -> List[EvaluationRow]:
57+
# row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
58+
# return [row, row, row]
59+
60+
61+
# @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
62+
# @pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
63+
# @evaluation_test(
64+
# data_loaders=DynamicDataLoader(
65+
# generators=[rows],
66+
# ),
67+
# rollout_processor=RemoteRolloutProcessor(
68+
# remote_base_url="http://127.0.0.1:3000",
69+
# timeout_seconds=30,
70+
# output_data_loader=langfuse_output_data_loader,
71+
# model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk",
72+
# ),
73+
# )
74+
# async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:
75+
# """
76+
# End-to-end test:
77+
# - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server
78+
# - trigger remote rollout via RemoteRolloutProcessor (calls init/status)
79+
# - fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found
80+
# """
81+
# assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
82+
# assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
83+
84+
# assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
85+
# f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
86+
# )
87+
88+
# return row

0 commit comments

Comments
 (0)