Skip to content

Commit 8ec64ce

Browse files
committed
fix test
1 parent 015e856 commit 8ec64ce

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tests/remote_server/test_remote_fireworks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def rows() -> List[EvaluationRow]:
9898
return [row, row, row]
9999

100100

101-
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
102101
@pytest.mark.parametrize(
103102
"completion_params",
104103
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],

tests/remote_server/test_remote_fireworks_propagate_status.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from eval_protocol.models import EvaluationRow, Message, Status
2424
from eval_protocol.pytest import evaluation_test
2525
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
26+
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
27+
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
28+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
2629

2730

2831
def find_available_port() -> int:
@@ -75,6 +78,18 @@ def setup_remote_server():
7578
process.wait()
7679

7780

81+
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
82+
base_url = config.model_base_url or "https://tracing.fireworks.ai"
83+
adapter = FireworksTracingAdapter(base_url=base_url)
84+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)
85+
86+
87+
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
88+
return DynamicDataLoader(
89+
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
90+
)
91+
92+
7893
def rows() -> List[EvaluationRow]:
7994
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
8095
return [row]
@@ -88,6 +103,7 @@ def rows() -> List[EvaluationRow]:
88103
rollout_processor=RemoteRolloutProcessor(
89104
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
90105
timeout_seconds=120,
106+
output_data_loader=fireworks_output_data_loader,
91107
),
92108
)
93109
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)