Skip to content

Commit a04f57e

Browse files
committed
Merge branch 'main' into derekx/vercel-server-example
2 parents 60792be + 695632c commit a04f57e

File tree

12 files changed: +517 additions, −48 deletions

.github/workflows/rollout.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}
55
on:
66
workflow_dispatch:
77
inputs:
8-
model:
9-
description: 'Model to use'
8+
completion_params:
9+
description: 'JSON completion params (optional, includes model_kwargs)'
1010
required: true
1111
type: string
1212
metadata:
@@ -18,6 +18,7 @@ on:
1818
required: true
1919
type: string
2020

21+
2122
jobs:
2223
rollout:
2324
runs-on: ubuntu-latest
@@ -41,6 +42,6 @@ jobs:
4142
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
4243
run: |
4344
python tests/github_actions/rollout_worker.py \
44-
--model "${{ inputs.model }}" \
45+
--completion-params '${{ inputs.completion_params }}' \
4546
--metadata '${{ inputs.metadata }}' \
4647
--model-base-url "${{ inputs.model_base_url }}"

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
from typing import Any, Callable, Dict, List, Optional
5-
5+
import json
66
import requests
77
from datetime import datetime, timezone, timedelta
88
from eval_protocol.models import EvaluationRow, Status
@@ -87,10 +87,14 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8787

8888
def _dispatch_workflow():
8989
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
90+
91+
model = init_request.completion_params.get("model")
92+
if not model:
93+
raise ValueError("model is required in completion_params")
9094
payload = {
9195
"ref": self.ref,
9296
"inputs": {
93-
"model": init_request.model,
97+
"completion_params": json.dumps(init_request.completion_params),
9498
"metadata": init_request.metadata.model_dump_json(),
9599
"model_base_url": init_request.model_base_url,
96100
},

eval_protocol/pytest/tracing_utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,25 @@ def build_init_request(
8080
row_id=row.input_metadata.row_id,
8181
)
8282

83-
# Extract model
84-
model: Optional[str] = None
83+
# Build completion_params from row and config
84+
completion_params_dict: Dict[str, Any] = {}
85+
86+
# Start with config-level completion_params
87+
if config.completion_params and isinstance(config.completion_params, dict):
88+
completion_params_dict.update(config.completion_params)
89+
90+
# Override with row-specific completion_params
8591
if row.input_metadata and row.input_metadata.completion_params:
86-
model = row.input_metadata.completion_params.get("model")
87-
if model is None and config.completion_params:
88-
model = config.completion_params.get("model")
89-
if model is None:
90-
raise ValueError("Model must be provided in row.input_metadata.completion_params or config.completion_params")
92+
row_cp = row.input_metadata.completion_params
93+
if isinstance(row_cp, dict):
94+
completion_params_dict.update(row_cp)
95+
96+
# Validate model is present
97+
if not completion_params_dict.get("model"):
98+
raise ValueError("Model must be provided in completion_params")
9199

92100
# Extract base_url from completion_params
93-
completion_params_base_url: Optional[str] = None
94-
if row.input_metadata and row.input_metadata.completion_params:
95-
completion_params_base_url = row.input_metadata.completion_params.get("base_url")
96-
if completion_params_base_url is None and config.completion_params:
97-
completion_params_base_url = config.completion_params.get("base_url")
101+
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
98102

99103
# Strip non-OpenAI fields from messages
100104
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
@@ -127,7 +131,7 @@ def build_init_request(
127131
api_key = os.environ.get("FIREWORKS_API_KEY")
128132

129133
return InitRequest(
130-
model=model,
134+
completion_params=completion_params_dict,
131135
messages=clean_messages,
132136
tools=row.tools,
133137
metadata=meta,

eval_protocol/types/remote_rollout_processor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ class DataLoaderConfig(BaseModel):
4444
class InitRequest(BaseModel):
4545
"""Request model for POST /init endpoint."""
4646

47-
model: str
47+
completion_params: Dict[str, Any] = Field(
48+
default_factory=dict,
49+
description="Completion parameters including model and optional model_kwargs, temperature, etc.",
50+
)
4851
messages: Optional[List[Message]] = None
4952
tools: Optional[List[Dict[str, Any]]] = None
5053

tests/github_actions/rollout_worker.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,24 @@ def main():
1717
parser = argparse.ArgumentParser(description="GitHub Actions rollout worker")
1818

1919
# Required arguments from workflow inputs
20-
parser.add_argument("--model", required=True, help="Model to use")
20+
parser.add_argument("--completion-params", required=True, help="JSON completion params (includes model)")
2121
parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
2222
parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")
2323

2424
args = parser.parse_args()
2525

26-
# Parse the metadata
26+
# Parse completion_params
27+
try:
28+
completion_params = json.loads(args.completion_params)
29+
except Exception as e:
30+
print(f"❌ Failed to parse completion_params: {e}")
31+
exit(1)
32+
33+
model = completion_params.get("model")
34+
if not model:
35+
print("Error: model is required in completion_params")
36+
exit(1)
37+
2738
try:
2839
metadata = json.loads(args.metadata)
2940
except Exception as e:
@@ -34,7 +45,7 @@ def main():
3445
row_id = metadata["row_id"]
3546

3647
print(f"🚀 Starting rollout {rollout_id}")
37-
print(f" Model: {args.model}")
48+
print(f" Model: {model}")
3849
print(f" Row ID: {row_id}")
3950

4051
dataset = [ # In this example, worker has access to the dataset and we use index to associate rows.
@@ -49,11 +60,13 @@ def main():
4960
print(f" Messages: {len(messages)} messages")
5061

5162
try:
52-
completion_kwargs = {"model": args.model, "messages": messages}
63+
# Build completion kwargs from completion_params
64+
completion_kwargs = {"messages": messages, **completion_params}
5365

5466
client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
5567

5668
print("📡 Calling OpenAI completion...")
69+
print(f" Completion kwargs: {completion_kwargs}")
5770
completion = client.chat.completions.create(**completion_kwargs)
5871

5972
print(f"✅ Rollout {rollout_id} completed successfully")

tests/github_actions/test_github_actions_rollout.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def rows() -> List[EvaluationRow]:
5454

5555

5656
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
57-
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
57+
@pytest.mark.parametrize(
58+
"completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}]
59+
)
5860
@evaluation_test(
5961
data_loaders=DynamicDataLoader(
6062
generators=[rows],

tests/remote_server/remote_server.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,21 @@ def _worker():
3333
if not req.messages:
3434
raise ValueError("messages is required")
3535

36-
completion_kwargs = {
37-
"model": req.model,
38-
"messages": req.messages,
39-
}
36+
model = req.completion_params.get("model")
37+
if not model:
38+
raise ValueError("model is required in completion_params")
39+
40+
# Spread all completion_params (model, temperature, max_tokens, etc.)
41+
completion_kwargs = {"messages": req.messages, **req.completion_params}
4042

4143
if req.tools:
4244
completion_kwargs["tools"] = req.tools
4345

46+
logger.info(f"Final completion_kwargs: {completion_kwargs}")
47+
4448
client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
4549

46-
logger.info(f"Sending completion request to model {req.model}")
50+
logger.info(f"Sending completion request to model {model}")
4751
completion = client.chat.completions.create(**completion_kwargs)
4852
logger.info(f"Completed response: {completion}")
4953

tests/remote_server/remote_server_multi_turn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def _worker():
2929
if not req.messages:
3030
raise ValueError("messages is required")
3131

32+
model = req.completion_params.get("model")
33+
if not model:
34+
raise ValueError("model is required in completion_params")
35+
3236
client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
3337

3438
# Build up conversation over 6 turns (3 user messages + 3 assistant responses)
@@ -41,10 +45,10 @@ def _worker():
4145
]
4246

4347
# First completion (turns 1-2: initial user message + assistant response)
44-
logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}")
48+
logger.info(f"Turn 1-2: Sending initial completion request to model {model}")
4549
completion = client.chat.completions.create(
46-
model=req.model,
4750
messages=conversation_history, # type: ignore
51+
**req.completion_params,
4852
)
4953
assistant_message = completion.choices[0].message
5054
assistant_content = assistant_message.content or ""
@@ -55,8 +59,8 @@ def _worker():
5559
conversation_history.append({"role": "user", "content": follow_up_questions[0]})
5660
logger.info(f"Turn 3: User asks: {follow_up_questions[0]}")
5761
completion = client.chat.completions.create(
58-
model=req.model,
5962
messages=conversation_history, # type: ignore
63+
**req.completion_params,
6064
)
6165
assistant_message = completion.choices[0].message
6266
assistant_content = assistant_message.content or ""
@@ -67,8 +71,8 @@ def _worker():
6771
conversation_history.append({"role": "user", "content": follow_up_questions[1]})
6872
logger.info(f"Turn 5: User asks: {follow_up_questions[1]}")
6973
completion = client.chat.completions.create(
70-
model=req.model,
7174
messages=conversation_history, # type: ignore
75+
**req.completion_params,
7276
)
7377
assistant_message = completion.choices[0].message
7478
assistant_content = assistant_message.content or ""

tests/remote_server/test_remote_fireworks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ def rows() -> List[EvaluationRow]:
9898
return [row, row, row]
9999

100100

101-
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
101+
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
102+
@pytest.mark.parametrize(
103+
"completion_params",
104+
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
105+
)
102106
@evaluation_test(
103107
data_loaders=DynamicDataLoader(
104108
generators=[rows],
@@ -122,5 +126,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
122126
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
123127
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
124128
)
129+
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"
125130

126131
return row

0 commit comments

Comments (0)