Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ jobs:
--ignore=tests/test_tau_bench_airline_smoke.py \
--ignore=tests/pytest/test_svgbench.py \
--ignore=tests/pytest/test_livesvgbench.py \
--ignore=tests/remote_server/test_remote_fireworks.py \
--ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \
--ignore=tests/logging/test_elasticsearch_direct_http_handler.py \
--ignore=eval_protocol/benchmarks/ \
Expand Down
25 changes: 21 additions & 4 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,28 @@
WeaveAdapter = None

try:
from .proxy import create_app, AuthProvider, AccountInfo
from .proxy import create_app, AuthProvider, AccountInfo # pyright: ignore[reportAssignmentType]
except ImportError:
create_app = None
AuthProvider = None
AccountInfo = None

def create_app(*args, **kwargs):
    """Fallback for the proxy app factory when `.proxy` cannot be imported.

    Any positional or keyword arguments are accepted and ignored.

    Raises:
        ImportError: Always, with a hint on how to install the proxy extra.
    """
    install_hint = (
        "Proxy functionality requires additional dependencies. "
        "Please install with: pip install eval-protocol[proxy]"
    )
    raise ImportError(install_hint)

class AuthProvider:
    """Fallback `AuthProvider` defined when `.proxy` cannot be imported.

    The class can be referenced, but creating an instance always fails.
    """

    def __init__(self, *args, **kwargs):
        """Reject instantiation; all arguments are ignored.

        Raises:
            ImportError: Always, with a hint on how to install the proxy extra.
        """
        install_hint = (
            "Proxy functionality requires additional dependencies. "
            "Please install with: pip install eval-protocol[proxy]"
        )
        raise ImportError(install_hint)

class AccountInfo:
    """Fallback `AccountInfo` defined when `.proxy` cannot be imported.

    The class can be referenced, but creating an instance always fails.
    """

    def __init__(self, *args, **kwargs):
        """Reject instantiation; all arguments are ignored.

        Raises:
            ImportError: Always, with a hint on how to install the proxy extra.
        """
        install_hint = (
            "Proxy functionality requires additional dependencies. "
            "Please install with: pip install eval-protocol[proxy]"
        )
        raise ImportError(install_hint)


warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ langgraph_tools = [
"langchain-fireworks>=0.3.0",
]

proxy = [
"redis>=5.0.0",
"langfuse>=2.0.0",
"uuid6>=2025.0.0",
]

[project.scripts]
fireworks-reward = "eval_protocol.cli:main"
eval-protocol = "eval_protocol.cli:main"
Expand Down
1 change: 0 additions & 1 deletion tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def rows() -> List[EvaluationRow]:
return [row, row, row]


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@pytest.mark.parametrize(
"completion_params",
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
Expand Down
16 changes: 16 additions & 0 deletions tests/remote_server/test_remote_fireworks_propagate_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from eval_protocol.models import EvaluationRow, Message, Status
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig


def find_available_port() -> int:
Expand Down Expand Up @@ -75,6 +78,18 @@ def setup_remote_server():
process.wait()


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
    """Fetch the evaluation rows tagged with the rollout id in ``config``.

    Args:
        config: Supplies ``model_base_url`` (falls back to
            ``https://tracing.fireworks.ai`` when falsy) and ``rollout_id``.

    Returns:
        Whatever ``FireworksTracingAdapter.get_evaluation_rows`` returns for
        the ``rollout_id:<id>`` tag, requested with ``max_retries=7``.
    """
    tracing_url = config.model_base_url or "https://tracing.fireworks.ai"
    tracing_adapter = FireworksTracingAdapter(base_url=tracing_url)
    rollout_tag = f"rollout_id:{config.rollout_id}"
    return tracing_adapter.get_evaluation_rows(tags=[rollout_tag], max_retries=7)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
    """Build a ``DynamicDataLoader`` that lazily fetches traces for ``config``.

    The loader has a single generator, which calls ``fetch_fireworks_traces``
    with ``config`` when invoked, and uses ``filter_longest_conversation`` as
    its ``preprocess_fn``.
    """

    def _load_traces():
        # Deferred so the fetch only happens when the loader runs the generator.
        return fetch_fireworks_traces(config)

    return DynamicDataLoader(
        generators=[_load_traces],
        preprocess_fn=filter_longest_conversation,
    )


def rows() -> List[EvaluationRow]:
    """Return a one-element list holding a row with a single user message."""
    question = Message(role="user", content="What is the capital of France?")
    return [EvaluationRow(messages=[question])]
Expand All @@ -88,6 +103,7 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=120,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
Expand Down
31 changes: 30 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading