Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/rollout.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ on:
description: 'Base URL for the model API'
required: true
type: string
completion_params:
description: 'JSON completion params (optional, includes model_kwargs)'
required: false
type: string

jobs:
rollout:
Expand All @@ -43,4 +47,5 @@ jobs:
python tests/github_actions/rollout_worker.py \
--model "${{ inputs.model }}" \
--metadata '${{ inputs.metadata }}' \
--model-base-url "${{ inputs.model_base_url }}"
--model-base-url "${{ inputs.model_base_url }}" \
${{ inputs.completion_params && format('--completion-params ''{0}''', inputs.completion_params) || '' }}
9 changes: 7 additions & 2 deletions eval_protocol/pytest/github_action_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import time
from typing import Any, Callable, Dict, List, Optional

import json
import requests
from datetime import datetime, timezone, timedelta
from eval_protocol.models import EvaluationRow, Status
Expand Down Expand Up @@ -87,12 +87,17 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:

def _dispatch_workflow():
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"

model = init_request.completion_params.get("model")
if not model:
raise ValueError("model is required in completion_params")
payload = {
"ref": self.ref,
"inputs": {
"model": init_request.model,
"model": model,
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated
"metadata": init_request.metadata.model_dump_json(),
"model_base_url": init_request.model_base_url,
"completion_params": json.dumps(init_request.completion_params),
},
}
r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
Expand Down
30 changes: 17 additions & 13 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,25 @@ def build_init_request(
row_id=row.input_metadata.row_id,
)

# Extract model
model: Optional[str] = None
# Build completion_params from row and config
completion_params_dict: Dict[str, Any] = {}

# Start with config-level completion_params
if config.completion_params and isinstance(config.completion_params, dict):
completion_params_dict.update(config.completion_params)

# Override with row-specific completion_params
if row.input_metadata and row.input_metadata.completion_params:
model = row.input_metadata.completion_params.get("model")
if model is None and config.completion_params:
model = config.completion_params.get("model")
if model is None:
raise ValueError("Model must be provided in row.input_metadata.completion_params or config.completion_params")
row_cp = row.input_metadata.completion_params
if isinstance(row_cp, dict):
completion_params_dict.update(row_cp)

# Validate model is present
if not completion_params_dict.get("model"):
raise ValueError("Model must be provided in completion_params")

# Extract base_url from completion_params
completion_params_base_url: Optional[str] = None
if row.input_metadata and row.input_metadata.completion_params:
completion_params_base_url = row.input_metadata.completion_params.get("base_url")
if completion_params_base_url is None and config.completion_params:
completion_params_base_url = config.completion_params.get("base_url")
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")

# Strip non-OpenAI fields from messages
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
Expand Down Expand Up @@ -124,7 +128,7 @@ def build_init_request(
final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)

return InitRequest(
model=model,
completion_params=completion_params_dict,
messages=clean_messages,
tools=row.tools,
metadata=meta,
Expand Down
5 changes: 4 additions & 1 deletion eval_protocol/types/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class DataLoaderConfig(BaseModel):
class InitRequest(BaseModel):
"""Request model for POST /init endpoint."""

model: str
completion_params: Dict[str, Any] = Field(
default_factory=dict,
description="Completion parameters including model and optional model_kwargs, temperature, etc.",
)
elastic_search_config: Optional[ElasticsearchConfig] = None
messages: Optional[List[Message]] = None
tools: Optional[List[Dict[str, Any]]] = None
Expand Down
18 changes: 18 additions & 0 deletions tests/github_actions/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@ def main():

# Required arguments from workflow inputs
parser.add_argument("--model", required=True, help="Model to use")
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated
parser.add_argument("--completion-params", required=False, help="JSON completion params (optional)")
parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")

args = parser.parse_args()

# Parse the metadata
completion_params = {}
if args.completion_params:
try:
completion_params = json.loads(args.completion_params)
except Exception as e:
print(f"⚠️ Failed to parse completion_params: {e}")
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated

try:
metadata = json.loads(args.metadata)
except Exception as e:
Expand All @@ -50,10 +58,20 @@ def main():

try:
completion_kwargs = {"model": args.model, "messages": messages}
# Parse and apply completion_params if provided
if args.completion_params:
try:
cp = json.loads(args.completion_params)
if cp.get("model_kwargs"):
completion_kwargs.update(cp["model_kwargs"])
print(f" Applied model_kwargs: {cp.get('model_kwargs')}")
except Exception as e:
print(f"⚠️ Failed to parse completion_params: {e}")
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated

client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

print("📡 Calling OpenAI completion...")
print(f" Completion kwargs: {completion_kwargs}")
completion = client.chat.completions.create(**completion_kwargs)

print(f"✅ Rollout {rollout_id} completed successfully")
Expand Down
5 changes: 4 additions & 1 deletion tests/github_actions/test_github_actions_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def rows() -> List[EvaluationRow]:


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
@pytest.mark.parametrize(
"completion_params",
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "model_kwargs": {"temperature": 0.5}}],
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated
)
@evaluation_test(
data_loaders=DynamicDataLoader(
generators=[rows],
Expand Down
16 changes: 14 additions & 2 deletions tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,29 @@ def _worker():
if not req.messages:
raise ValueError("messages is required")

model = req.completion_params.get("model")
if not model:
raise ValueError("model is required in completion_params")

completion_kwargs = {
"model": req.model,
"model": model,
"messages": req.messages,
}

# Apply model_kwargs if present
if req.completion_params.get("model_kwargs"):
model_kwargs = req.completion_params["model_kwargs"]
if isinstance(model_kwargs, dict):
completion_kwargs.update(model_kwargs)

if req.tools:
completion_kwargs["tools"] = req.tools

logger.info(f"Final completion_kwargs: {completion_kwargs}")

client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

logger.info(f"Sending completion request to model {req.model}")
logger.info(f"Sending completion request to model {model}")
completion = client.chat.completions.create(**completion_kwargs)
logger.info(f"Completed response: {completion}")

Expand Down
20 changes: 15 additions & 5 deletions tests/remote_server/remote_server_multi_turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,17 @@ def _worker():
if not req.messages:
raise ValueError("messages is required")

model = req.completion_params.get("model")
if not model:
raise ValueError("model is required in completion_params")

client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

# Apply model_kwargs if present
if req.completion_params.get("model_kwargs"):
model_kwargs = req.completion_params["model_kwargs"]
if isinstance(model_kwargs, dict):
completion_kwargs.update(model_kwargs)
# Build up conversation over 6 turns (3 user messages + 3 assistant responses)
# Convert Message objects to dicts for OpenAI API
conversation_history = [{"role": m.role, "content": m.content} for m in req.messages]
Expand All @@ -44,10 +53,11 @@ def _worker():
]

# First completion (turns 1-2: initial user message + assistant response)
logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}")
logger.info(f"Turn 1-2: Sending initial completion request to model {model}")
completion = client.chat.completions.create(
model=req.model,
messages=conversation_history, # type: ignore
model=model,
messages=conversation_history, # type: ignore,
**completion_kwargs,
)
assistant_message = completion.choices[0].message
assistant_content = assistant_message.content or ""
Expand All @@ -58,7 +68,7 @@ def _worker():
conversation_history.append({"role": "user", "content": follow_up_questions[0]})
logger.info(f"Turn 3: User asks: {follow_up_questions[0]}")
completion = client.chat.completions.create(
model=req.model,
model=model,
messages=conversation_history, # type: ignore
)
assistant_message = completion.choices[0].message
Expand All @@ -70,7 +80,7 @@ def _worker():
conversation_history.append({"role": "user", "content": follow_up_questions[1]})
logger.info(f"Turn 5: User asks: {follow_up_questions[1]}")
completion = client.chat.completions.create(
model=req.model,
model=model,
messages=conversation_history, # type: ignore
)
assistant_message = completion.choices[0].message
Expand Down
8 changes: 7 additions & 1 deletion tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def rows() -> List[EvaluationRow]:


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
@pytest.mark.parametrize(
"completion_params",
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "model_kwargs": {"temperature": 0.5}}],
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated
)
@evaluation_test(
data_loaders=DynamicDataLoader(
generators=[rows],
Expand All @@ -82,5 +85,8 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
)
assert row.input_metadata.completion_params["model_kwargs"] == {"temperature": 0.5}, (
Comment thread
shreymodi1 marked this conversation as resolved.
Outdated
"Row should have correct model_kwargs"
)

return row
Loading
Loading