Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ jobs:
--ignore=tests/test_tau_bench_airline_smoke.py \
--ignore=tests/pytest/test_svgbench.py \
--ignore=tests/pytest/test_livesvgbench.py \
--ignore=tests/remote_server/test_remote_fireworks.py \
--ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \
--ignore=tests/logging/test_elasticsearch_direct_http_handler.py \
--ignore=eval_protocol/benchmarks/ \
Expand Down
25 changes: 21 additions & 4 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,28 @@
WeaveAdapter = None

try:
from .proxy import create_app, AuthProvider, AccountInfo
from .proxy import create_app, AuthProvider, AccountInfo # pyright: ignore[reportAssignmentType]
except ImportError:
create_app = None
AuthProvider = None
AccountInfo = None

def create_app(*args, **kwargs):
raise ImportError(
"Proxy functionality requires additional dependencies. "
"Please install with: pip install eval-protocol[proxy]"
)

class AuthProvider:
def __init__(self, *args, **kwargs):
raise ImportError(
"Proxy functionality requires additional dependencies. "
"Please install with: pip install eval-protocol[proxy]"
)

class AccountInfo:
def __init__(self, *args, **kwargs):
raise ImportError(
"Proxy functionality requires additional dependencies. "
"Please install with: pip install eval-protocol[proxy]"
)


warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
Expand Down
4 changes: 4 additions & 0 deletions eval_protocol/proxy/proxy_core/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ async def handle_chat_completion(
# Forward to LiteLLM
litellm_url = f"{config.litellm_url}/chat/completions"

print("litellm_url: ", litellm_url)
print("data: ", data)
print("headers: ", headers)
Comment thread
xzrderek marked this conversation as resolved.
Outdated

response = await client.post(
litellm_url,
json=data, # httpx will serialize and set correct Content-Length
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ langgraph_tools = [
"langchain-fireworks>=0.3.0",
]

proxy = [
"redis>=5.0.0",
"langfuse>=2.0.0",
"uuid6>=2025.0.0",
]

[project.scripts]
fireworks-reward = "eval_protocol.cli:main"
eval-protocol = "eval_protocol.cli:main"
Expand Down
1 change: 0 additions & 1 deletion tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def rows() -> List[EvaluationRow]:
return [row, row, row]


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@pytest.mark.parametrize(
"completion_params",
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
Expand Down
16 changes: 16 additions & 0 deletions tests/remote_server/test_remote_fireworks_propagate_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from eval_protocol.models import EvaluationRow, Message, Status
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig


def find_available_port() -> int:
Expand Down Expand Up @@ -75,6 +78,18 @@ def setup_remote_server():
process.wait()


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)
Comment thread
xzrderek marked this conversation as resolved.


def rows() -> List[EvaluationRow]:
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
return [row]
Expand All @@ -88,6 +103,7 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=120,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
Expand Down
31 changes: 30 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading