Skip to content

Commit c31771c

Browse files
committed
update test
1 parent f7755d5 commit c31771c

File tree

5 files changed

+42
-34
lines changed

5 files changed

+42
-34
lines changed

.github/workflows/rollout.yml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@ on:
1313
description: 'JSON serialized metadata object'
1414
required: true
1515
type: string
16-
messages:
17-
description: 'JSON serialized messages array'
18-
required: true
19-
type: string
20-
tools:
21-
description: 'JSON serialized tools array'
22-
required: false
23-
type: string
2416
model_base_url:
2517
description: 'Base URL for the model API'
2618
required: true
@@ -51,6 +43,4 @@ jobs:
5143
python tests/github_actions/rollout_worker.py \
5244
--model "${{ inputs.model }}" \
5345
--metadata '${{ inputs.metadata }}' \
54-
--messages '${{ inputs.messages }}' \
55-
--tools '${{ inputs.tools }}' \
5646
--model-base-url "${{ inputs.model_base_url }}"

eval_protocol/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,9 @@ class EvaluationRow(BaseModel):
598598
model_config = ConfigDict(extra="allow")
599599

600600
# Core OpenAI ChatCompletion compatible conversation data
601-
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
601+
messages: List[Message] = Field(
602+
default_factory=list, description="List of messages in the conversation. Also known as a trajectory."
603+
)
602604

603605
# Tool and function call information
604606
tools: Optional[List[Dict[str, Any]]] = Field(

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,6 @@ def _dispatch_workflow():
162162
"inputs": {
163163
"model": init_request.model,
164164
"metadata": init_request.metadata.model_dump_json(),
165-
"messages": json.dumps(init_request.messages),
166-
"tools": json.dumps(init_request.tools),
167165
"model_base_url": init_request.model_base_url,
168166
},
169167
}

tests/github_actions/rollout_worker.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import os
1212

1313
from openai import OpenAI
14-
from eval_protocol.types.remote_rollout_processor import InitRequest
1514

1615

1716
def main():
@@ -20,44 +19,42 @@ def main():
2019
# Required arguments from workflow inputs
2120
parser.add_argument("--model", required=True, help="Model to use")
2221
parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
23-
parser.add_argument("--messages", required=True, help="JSON serialized messages array")
24-
parser.add_argument("--tools", required=False, help="JSON serialized tools array")
2522
parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")
2623

2724
args = parser.parse_args()
2825

29-
# Parse the JSON inputs
26+
# Parse the metadata
3027
try:
3128
metadata = json.loads(args.metadata)
32-
messages = json.loads(args.messages)
33-
tools = json.loads(args.tools) if args.tools else None
3429
except Exception as e:
35-
print(f"❌ Failed to parse JSON inputs: {e}")
30+
print(f"❌ Failed to parse metadata: {e}")
3631
exit(1)
3732

3833
rollout_id = metadata["rollout_id"]
34+
row_id = metadata["row_id"]
35+
3936
print(f"🚀 Starting rollout {rollout_id}")
4037
print(f" Model: {args.model}")
41-
print(f" Messages: {len(messages)} messages")
38+
print(f" Row ID: {row_id}")
39+
40+
dataset = [ # In this example, worker has access to the dataset and we use index to associate rows.
41+
"What is the capital of France?",
42+
"What is the capital of Germany?",
43+
"What is the capital of Italy?",
44+
]
4245

43-
# Perform the rollout
44-
conversation = messages.copy()
46+
user_content = dataset[int(row_id)]
47+
messages = [{"role": "user", "content": user_content}]
48+
49+
print(f" Messages: {len(messages)} messages")
4550

4651
try:
4752
completion_kwargs = {"model": args.model, "messages": messages}
48-
if tools:
49-
completion_kwargs["tools"] = tools
5053

5154
client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
5255

5356
print("📡 Calling OpenAI completion...")
5457
completion = client.chat.completions.create(**completion_kwargs)
55-
print("✅ Received response")
56-
57-
# Add assistant response to conversation
58-
if completion.choices and completion.choices[0].message:
59-
assistant_message = completion.choices[0].message.model_dump()
60-
conversation.append(assistant_message)
6158

6259
print(f"✅ Rollout {rollout_id} completed successfully")
6360

tests/github_actions/test_github_actions_rollout.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
import pytest
1919

2020
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
21-
from eval_protocol.models import EvaluationRow, Message
21+
from eval_protocol.models import EvaluationRow, InputMetadata
2222
from eval_protocol.pytest import evaluation_test
2323
from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor
24+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
25+
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
26+
from eval_protocol.quickstart.utils import filter_longest_conversation
2427

2528
ROLLOUT_IDS = set()
2629

@@ -35,9 +38,26 @@ def check_rollout_coverage():
3538
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
3639

3740

41+
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
    """Record the rollout id from ``config`` and fetch its rows from Fireworks tracing.

    Args:
        config: Loader configuration; ``rollout_id`` and ``model_base_url`` are read.

    Returns:
        The result of ``FireworksTracingAdapter.get_evaluation_rows`` for the tag
        ``rollout_id:<rollout_id>``, requested with ``max_retries=5``.
    """
    # Track all rollout_ids we've seen. The module-level set is mutated in place
    # and never rebound, so no ``global`` declaration is needed here.
    ROLLOUT_IDS.add(config.rollout_id)

    # Use the default tracing URL when ``config.model_base_url`` is falsy.
    base_url = config.model_base_url or "https://tracing.fireworks.ai"
    adapter = FireworksTracingAdapter(base_url=base_url)
    return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
48+
49+
50+
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
    """Build a ``DynamicDataLoader`` with one generator that fetches traces for ``config``.

    The fetch is deferred: ``fetch_fireworks_traces(config)`` runs only when the
    generator callable is invoked. ``filter_longest_conversation`` is passed as the
    loader's ``preprocess_fn``.
    """

    def _load_traces():
        # Zero-argument callable closing over ``config``.
        return fetch_fireworks_traces(config)

    return DynamicDataLoader(
        generators=[_load_traces],
        preprocess_fn=filter_longest_conversation,
    )
54+
55+
3856
def rows() -> List[EvaluationRow]:
39-
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
40-
return [row, row, row]
57+
return [
58+
EvaluationRow(input_metadata=InputMetadata(row_id=str(i)))
59+
for i in range(3) # In this example we use index to associate rows.
60+
]
4161

4262

4363
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@@ -52,6 +72,7 @@ def rows() -> List[EvaluationRow]:
5272
workflow_id="rollout.yml", # or you can use numeric ID like "12345678"
5373
ref=os.getenv("GITHUB_REF", "main"),
5474
timeout_seconds=300,
75+
output_data_loader=fireworks_output_data_loader,
5576
),
5677
)
5778
async def test_github_actions_rollout_direct_artifacts(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)