Skip to content

Commit 114d714

Browse files
authored
Merge branch 'main' into completionparams
2 parents f3e37d5 + e769b6a commit 114d714

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+2986
-308
lines changed

.github/workflows/rollout.yml

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
name: Eval Protocol Rollout
22

3-
run-name: rollout:${{ inputs.rollout_id }}
3+
run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}
44

55
on:
66
workflow_dispatch:
77
inputs:
88
model:
9-
description: 'Model to use for the rollout'
9+
description: 'Model to use'
1010
required: true
1111
type: string
12-
rollout_id:
13-
description: 'Rollout ID for tracking'
12+
metadata:
13+
description: 'JSON serialized metadata object'
1414
required: true
1515
type: string
16-
prompt:
17-
description: 'User prompt for the rollout'
16+
model_base_url:
17+
description: 'Base URL for the model API'
1818
required: true
1919
type: string
2020

2121
jobs:
2222
rollout:
2323
runs-on: ubuntu-latest
24-
name: rollout-${{ inputs.rollout_id }}
2524

2625
steps:
2726
- name: Checkout code
@@ -43,13 +42,5 @@ jobs:
4342
run: |
4443
python tests/github_actions/rollout_worker.py \
4544
--model "${{ inputs.model }}" \
46-
--rollout-id "${{ inputs.rollout_id }}" \
47-
--prompt "${{ inputs.prompt }}"
48-
49-
- name: Upload rollout trace
50-
uses: actions/upload-artifact@v4
51-
if: always() # Upload even if the rollout failed
52-
with:
53-
name: rollout-trace-${{ inputs.rollout_id }}
54-
path: rollout_trace_${{ inputs.rollout_id }}.json
55-
retention-days: 7
45+
--metadata '${{ inputs.metadata }}' \
46+
--model-base-url "${{ inputs.model_base_url }}"

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ env.bak/
105105
venv.bak/
106106
*.backup
107107

108+
# Secrets
109+
secrets.yaml
110+
108111
# Spyder project settings
109112
.spyderproject
110113
.spyproject

eval_protocol/__init__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@
2929
from .resources import create_llm_resource
3030
from .reward_function import RewardFunction
3131
from .typed_interface import reward_function
32-
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
33-
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
32+
from .quickstart.aha_judge import aha_judge
33+
from .utils.evaluation_row_utils import (
34+
multi_turn_assistant_to_ground_truth,
35+
assistant_to_ground_truth,
36+
filter_longest_conversation,
37+
)
38+
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
3439
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
3540
from .pytest.parameterize import DefaultParameterIdGenerator
3641
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
@@ -74,6 +79,14 @@
7479
except ImportError:
7580
WeaveAdapter = None
7681

82+
try:
83+
from .proxy import create_app, AuthProvider, AccountInfo
84+
except ImportError:
85+
create_app = None
86+
AuthProvider = None
87+
AccountInfo = None
88+
89+
7790
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
7891

7992
__all__ = [
@@ -85,6 +98,7 @@
8598
"DataLoaderConfig",
8699
"Status",
87100
"RemoteRolloutProcessor",
101+
"GithubActionRolloutProcessor",
88102
"InputMetadata",
89103
"EvaluationRow",
90104
"DefaultParameterIdGenerator",
@@ -93,6 +107,7 @@
93107
"aha_judge",
94108
"multi_turn_assistant_to_ground_truth",
95109
"assistant_to_ground_truth",
110+
"filter_longest_conversation",
96111
"evaluation_test",
97112
"SingleTurnRolloutProcessor",
98113
"OpenAIResponsesAdapter",
@@ -137,6 +152,10 @@
137152
"RolloutMetadata",
138153
"StatusResponse",
139154
"create_langfuse_config_tags",
155+
# Proxy
156+
"create_app",
157+
"AuthProvider",
158+
"AccountInfo",
140159
]
141160

142161
from . import _version

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from __future__ import annotations
88
import logging
99
import requests
10-
import time
1110
from datetime import datetime
1211
from typing import Any, Dict, List, Optional, Protocol
12+
import os
1313

1414
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
1515
from .base import BaseAdapter
@@ -343,15 +343,17 @@ def get_evaluation_rows(
343343
# Remove None values
344344
params = {k: v for k, v in params.items() if v is not None}
345345

346-
# Make request to proxy
346+
# Make request to proxy (using pointwise for efficiency)
347347
if self.project_id:
348-
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces"
348+
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces/pointwise"
349349
else:
350-
url = f"{self.base_url}/v1/traces"
350+
url = f"{self.base_url}/v1/traces/pointwise"
351+
352+
headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"}
351353

352354
result = None
353355
try:
354-
response = requests.get(url, params=params, timeout=self.timeout)
356+
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
355357
response.raise_for_status()
356358
result = response.json()
357359
except requests.exceptions.HTTPError as e:
@@ -365,7 +367,7 @@ def get_evaluation_rows(
365367
except Exception: # In case e.response.json() fails
366368
error_msg = f"Proxy error: {e.response.text}"
367369

368-
logger.error("Failed to fetch traces from proxy: %s", error_msg)
370+
logger.error("Failed to fetch traces from proxy (HTTP %s): %s", e.response.status_code, error_msg)
369371
return eval_rows
370372
except requests.exceptions.RequestException as e:
371373
# Non-HTTP errors (network issues, timeouts, etc.)

eval_protocol/mcp/execution/policy.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
Rewritten to use LiteLLM for unified retry logic, caching, and provider support.
66
"""
77

8-
import asyncio
9-
import json
108
import logging
119
import os
12-
from abc import ABC, abstractmethod
13-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
10+
from typing import Any, Dict, List, Literal, Optional
1411

1512
import litellm
16-
from litellm import acompletion, completion
13+
from litellm import acompletion
14+
from litellm.types.utils import ModelResponse
15+
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
1716
from litellm.caching.caching import Cache
1817
from litellm.caching.dual_cache import DualCache
1918
from litellm.caching.in_memory_cache import InMemoryCache
@@ -194,7 +193,20 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[
194193
request_params["tools"] = tools
195194

196195
try:
197-
response = await acompletion(model=self.model_id, **request_params)
196+
if request_params.get("stream") is True:
197+
chunks = []
198+
stream = await acompletion(model=self.model_id, **request_params)
199+
200+
assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper"
201+
202+
async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues]
203+
chunks.append(chunk)
204+
response = litellm.stream_chunk_builder(chunks, messages)
205+
else:
206+
response = await acompletion(model=self.model_id, **request_params)
207+
208+
assert response is not None, "Response is None"
209+
assert isinstance(response, ModelResponse), "Response should be ModelResponse"
198210

199211
# Log cache hit/miss for monitoring
200212
hidden = getattr(response, "_hidden_params", {})

eval_protocol/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,9 @@ class EvaluationRow(BaseModel):
598598
model_config = ConfigDict(extra="allow")
599599

600600
# Core OpenAI ChatCompletion compatible conversation data
601-
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
601+
messages: List[Message] = Field(
602+
default_factory=list, description="List of messages in the conversation. Also known as a trajectory."
603+
)
602604

603605
# Tool and function call information
604606
tools: Optional[List[Dict[str, Any]]] = Field(

eval_protocol/proxy/.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# To configure API keys for other model providers used by the proxy, copy this file, rename the copy to .env, and fill in the values below
2+
OPENAI_API_KEY=sk-proj-xxx
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Metadata Extraction Gateway - Sits in front of LiteLLM
2+
FROM python:3.11-slim
3+
4+
WORKDIR /app
5+
6+
# Prevent Python from buffering stdout/stderr
7+
ENV PYTHONUNBUFFERED=1
8+
9+
# Copy requirements file
10+
COPY ./requirements.txt /app/requirements.txt
11+
12+
# Install dependencies
13+
RUN pip install --no-cache-dir -r requirements.txt
14+
15+
# Copy the proxy package
16+
COPY ./proxy_core /app/proxy_core
17+
18+
# Expose port
19+
EXPOSE 4000
20+
21+
# Run the gateway as a module
22+
# LITELLM_URL will be set by the environment (docker-compose or Cloud Run)
23+
CMD ["python", "-m", "proxy_core.main"]

0 commit comments

Comments
 (0)