Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/rollout.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}
on:
workflow_dispatch:
inputs:
model:
description: 'Model to use'
completion_params:
description: 'JSON completion params (optional, includes model_kwargs)'
required: true
type: string
metadata:
Expand All @@ -18,6 +18,7 @@ on:
required: true
type: string


jobs:
rollout:
runs-on: ubuntu-latest
Expand All @@ -41,6 +42,6 @@ jobs:
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
run: |
python tests/github_actions/rollout_worker.py \
--model "${{ inputs.model }}" \
--completion-params '${{ inputs.completion_params }}' \
--metadata '${{ inputs.metadata }}' \
--model-base-url "${{ inputs.model_base_url }}"
8 changes: 6 additions & 2 deletions eval_protocol/pytest/github_action_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import time
from typing import Any, Callable, Dict, List, Optional

import json
import requests
from datetime import datetime, timezone, timedelta
from eval_protocol.models import EvaluationRow, Status
Expand Down Expand Up @@ -87,10 +87,14 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:

def _dispatch_workflow():
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"

model = init_request.completion_params.get("model")
if not model:
raise ValueError("model is required in completion_params")
payload = {
"ref": self.ref,
"inputs": {
"model": init_request.model,
"completion_params": json.dumps(init_request.completion_params),
"metadata": init_request.metadata.model_dump_json(),
"model_base_url": init_request.model_base_url,
},
Expand Down
30 changes: 17 additions & 13 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,25 @@ def build_init_request(
row_id=row.input_metadata.row_id,
)

# Extract model
model: Optional[str] = None
# Build completion_params from row and config
completion_params_dict: Dict[str, Any] = {}

# Start with config-level completion_params
if config.completion_params and isinstance(config.completion_params, dict):
completion_params_dict.update(config.completion_params)

# Override with row-specific completion_params
if row.input_metadata and row.input_metadata.completion_params:
model = row.input_metadata.completion_params.get("model")
if model is None and config.completion_params:
model = config.completion_params.get("model")
if model is None:
raise ValueError("Model must be provided in row.input_metadata.completion_params or config.completion_params")
row_cp = row.input_metadata.completion_params
if isinstance(row_cp, dict):
completion_params_dict.update(row_cp)

# Validate model is present
if not completion_params_dict.get("model"):
raise ValueError("Model must be provided in completion_params")

# Extract base_url from completion_params
completion_params_base_url: Optional[str] = None
if row.input_metadata and row.input_metadata.completion_params:
completion_params_base_url = row.input_metadata.completion_params.get("base_url")
if completion_params_base_url is None and config.completion_params:
completion_params_base_url = config.completion_params.get("base_url")
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")

# Strip non-OpenAI fields from messages
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
Expand Down Expand Up @@ -123,7 +127,7 @@ def build_init_request(
final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)

return InitRequest(
model=model,
completion_params=completion_params_dict,
messages=clean_messages,
tools=row.tools,
metadata=meta,
Expand Down
5 changes: 4 additions & 1 deletion eval_protocol/types/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class DataLoaderConfig(BaseModel):
class InitRequest(BaseModel):
"""Request model for POST /init endpoint."""

model: str
completion_params: Dict[str, Any] = Field(
default_factory=dict,
description="Completion parameters including model and optional model_kwargs, temperature, etc.",
)
elastic_search_config: Optional[ElasticsearchConfig] = None
messages: Optional[List[Message]] = None
tools: Optional[List[Dict[str, Any]]] = None
Expand Down
21 changes: 17 additions & 4 deletions tests/github_actions/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@ def main():
parser = argparse.ArgumentParser(description="GitHub Actions rollout worker")

# Required arguments from workflow inputs
parser.add_argument("--model", required=True, help="Model to use")
parser.add_argument("--completion-params", required=True, help="JSON completion params (includes model)")
parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")

args = parser.parse_args()

# Parse the metadata
# Parse completion_params
try:
completion_params = json.loads(args.completion_params)
except Exception as e:
print(f"❌ Failed to parse completion_params: {e}")
exit(1)

model = completion_params.get("model")
if not model:
print("Error: model is required in completion_params")
exit(1)

try:
metadata = json.loads(args.metadata)
except Exception as e:
Expand All @@ -34,7 +45,7 @@ def main():
row_id = metadata["row_id"]

print(f"🚀 Starting rollout {rollout_id}")
print(f" Model: {args.model}")
print(f" Model: {model}")
print(f" Row ID: {row_id}")

dataset = [ # In this example, worker has access to the dataset and we use index to associate rows.
Expand All @@ -49,11 +60,13 @@ def main():
print(f" Messages: {len(messages)} messages")

try:
completion_kwargs = {"model": args.model, "messages": messages}
# Build completion kwargs from completion_params
completion_kwargs = {"messages": messages, **completion_params}

client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

print("📡 Calling OpenAI completion...")
print(f" Completion kwargs: {completion_kwargs}")
completion = client.chat.completions.create(**completion_kwargs)

print(f"✅ Rollout {rollout_id} completed successfully")
Expand Down
4 changes: 3 additions & 1 deletion tests/github_actions/test_github_actions_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def rows() -> List[EvaluationRow]:


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
@pytest.mark.parametrize(
"completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}]
)
@evaluation_test(
data_loaders=DynamicDataLoader(
generators=[rows],
Expand Down
14 changes: 9 additions & 5 deletions tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,21 @@ def _worker():
if not req.messages:
raise ValueError("messages is required")

completion_kwargs = {
"model": req.model,
"messages": req.messages,
}
model = req.completion_params.get("model")
if not model:
raise ValueError("model is required in completion_params")

# Spread all completion_params (model, temperature, max_tokens, etc.)
completion_kwargs = {"messages": req.messages, **req.completion_params}

if req.tools:
completion_kwargs["tools"] = req.tools

logger.info(f"Final completion_kwargs: {completion_kwargs}")

client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

logger.info(f"Sending completion request to model {req.model}")
logger.info(f"Sending completion request to model {model}")
completion = client.chat.completions.create(**completion_kwargs)
logger.info(f"Completed response: {completion}")

Expand Down
12 changes: 8 additions & 4 deletions tests/remote_server/remote_server_multi_turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def _worker():
if not req.messages:
raise ValueError("messages is required")

model = req.completion_params.get("model")
if not model:
raise ValueError("model is required in completion_params")

client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

# Build up conversation over 6 turns (3 user messages + 3 assistant responses)
Expand All @@ -41,10 +45,10 @@ def _worker():
]

# First completion (turns 1-2: initial user message + assistant response)
logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}")
logger.info(f"Turn 1-2: Sending initial completion request to model {model}")
completion = client.chat.completions.create(
model=req.model,
messages=conversation_history, # type: ignore
**req.completion_params,
)
assistant_message = completion.choices[0].message
assistant_content = assistant_message.content or ""
Expand All @@ -55,8 +59,8 @@ def _worker():
conversation_history.append({"role": "user", "content": follow_up_questions[0]})
logger.info(f"Turn 3: User asks: {follow_up_questions[0]}")
completion = client.chat.completions.create(
model=req.model,
messages=conversation_history, # type: ignore
**req.completion_params,
)
assistant_message = completion.choices[0].message
assistant_content = assistant_message.content or ""
Expand All @@ -67,8 +71,8 @@ def _worker():
conversation_history.append({"role": "user", "content": follow_up_questions[1]})
logger.info(f"Turn 5: User asks: {follow_up_questions[1]}")
completion = client.chat.completions.create(
model=req.model,
messages=conversation_history, # type: ignore
**req.completion_params,
)
assistant_message = completion.choices[0].message
assistant_content = assistant_message.content or ""
Expand Down
7 changes: 6 additions & 1 deletion tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def rows() -> List[EvaluationRow]:
return [row, row, row]


@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@pytest.mark.parametrize(
"completion_params",
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
)
@evaluation_test(
data_loaders=DynamicDataLoader(
generators=[rows],
Expand All @@ -122,5 +126,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
)
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"

return row
Loading
Loading