Skip to content

Commit 0cbb564

Browse files
committed
switch to prompt
1 parent 0ca752b commit 0cbb564

File tree

3 files changed

+22
-56
lines changed

3 files changed

+22
-56
lines changed

.github/workflows/rollout.yml

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,10 @@ on:
1111
description: 'Rollout ID for tracking'
1212
required: true
1313
type: string
14-
messages_b64:
15-
description: 'Base64 encoded JSON messages array'
14+
prompt:
15+
description: 'User prompt for the rollout'
1616
required: true
1717
type: string
18-
tools_b64:
19-
description: 'Base64 encoded JSON tools array (optional)'
20-
required: false
21-
type: string
2218

2319
jobs:
2420
rollout:
@@ -46,8 +42,7 @@ jobs:
4642
python tests/github_actions/rollout_worker.py \
4743
--model "${{ inputs.model }}" \
4844
--rollout-id "${{ inputs.rollout_id }}" \
49-
--messages-b64 "${{ inputs.messages_b64 }}" \
50-
${{ inputs.tools_b64 && format('--tools-b64 "{0}"', inputs.tools_b64) || '' }}
45+
--prompt "${{ inputs.prompt }}"
5146
5247
- name: Upload rollout trace
5348
uses: actions/upload-artifact@v4

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 15 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import base64
32
import json
43
import os
54
import tempfile
@@ -63,32 +62,27 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
6362
if model is None:
6463
raise ValueError("Model must be provided")
6564

66-
# Clean and encode messages
67-
allowed_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
68-
clean_messages = []
69-
for m in row.messages:
70-
if hasattr(m, "model_dump"):
71-
md = m.model_dump()
72-
elif isinstance(m, dict):
73-
md = m
74-
else:
75-
md = {
76-
"role": getattr(m, "role", None),
77-
"content": getattr(m, "content", None),
78-
"tool_calls": getattr(m, "tool_calls", None),
79-
"tool_call_id": getattr(m, "tool_call_id", None),
80-
"name": getattr(m, "name", None),
81-
}
82-
clean_messages.append({k: v for k, v in md.items() if k in allowed_fields and v is not None})
65+
# Extract user prompt (first user message)
66+
user_prompt = None
67+
for msg in row.messages:
68+
if hasattr(msg, "role"):
69+
if msg.role == "user":
70+
user_prompt = msg.content
71+
break
72+
elif isinstance(msg, dict):
73+
if msg.get("role") == "user":
74+
user_prompt = msg.get("content")
75+
break
76+
77+
if not user_prompt:
78+
raise ValueError("At least one user message is required")
8379

8480
# Prepare workflow inputs
8581
inputs = {
8682
"model": model,
8783
"rollout_id": row.execution_metadata.rollout_id,
88-
"messages_b64": base64.b64encode(json.dumps(clean_messages).encode()).decode(),
84+
"prompt": user_prompt,
8985
}
90-
if row.tools:
91-
inputs["tools_b64"] = base64.b64encode(json.dumps(row.tools).encode()).decode()
9286

9387
# Dispatch workflow
9488
def _dispatch():

tests/github_actions/rollout_worker.py

Lines changed: 4 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
"""
88

99
import argparse
10-
import base64
1110
import json
1211
import os
1312

@@ -20,44 +19,24 @@ def main():
2019
# Required arguments from workflow inputs
2120
parser.add_argument("--model", required=True, help="Model to use")
2221
parser.add_argument("--rollout-id", required=True, help="Rollout ID for tracking")
23-
parser.add_argument("--messages-b64", required=True, help="Base64 encoded JSON messages")
24-
parser.add_argument("--tools-b64", required=False, help="Base64 encoded JSON tools (optional)")
22+
parser.add_argument("--prompt", required=True, help="User prompt for the rollout")
2523

2624
args = parser.parse_args()
2725

2826
print(f"🚀 Starting rollout {args.rollout_id}")
2927
print(f" Model: {args.model}")
28+
print(f" Prompt: {args.prompt}")
3029

31-
# Decode messages and tools
32-
try:
33-
messages = json.loads(base64.b64decode(args.messages_b64).decode("utf-8"))
34-
tools = None
35-
if args.tools_b64:
36-
tools = json.loads(base64.b64decode(args.tools_b64).decode("utf-8"))
37-
except Exception as e:
38-
print(f"❌ Failed to decode inputs: {e}")
39-
# Save error trace
40-
error_data = {
41-
"status": "error",
42-
"rollout_id": args.rollout_id,
43-
"model": args.model,
44-
"messages": [],
45-
"error": f"Failed to decode inputs: {e}",
46-
}
47-
with open(f"rollout_trace_{args.rollout_id}.json", "w") as f:
48-
json.dump(error_data, f, indent=2)
49-
exit(1)
30+
# Build messages array
31+
messages = [{"role": "user", "content": args.prompt}]
5032

5133
print(f" Messages: {len(messages)} messages")
52-
print(f" Tools: {len(tools) if tools else 0} tools")
5334

5435
# Perform the rollout
5536
conversation = messages.copy()
5637

5738
try:
5839
completion_kwargs = {"model": args.model, "messages": messages}
59-
if tools:
60-
completion_kwargs["tools"] = tools
6140

6241
client = OpenAI(api_key=os.environ.get("FIREWORKS_API_KEY"))
6342

@@ -76,7 +55,6 @@ def main():
7655
"rollout_id": args.rollout_id,
7756
"model": args.model,
7857
"messages": conversation,
79-
"tools": tools,
8058
"usage": completion.usage.model_dump() if completion.usage else None,
8159
}
8260

@@ -91,7 +69,6 @@ def main():
9169
"rollout_id": args.rollout_id,
9270
"model": args.model,
9371
"messages": conversation,
94-
"tools": tools,
9572
"error": str(e),
9673
}
9774

0 commit comments

Comments
 (0)