@@ -17,13 +17,24 @@ def main():
1717 parser = argparse .ArgumentParser (description = "GitHub Actions rollout worker" )
1818
1919 # Required arguments from workflow inputs
20- parser .add_argument ("--model " , required = True , help = "Model to use " )
20+ parser .add_argument ("--completion-params " , required = True , help = "JSON completion params (includes model) " )
2121 parser .add_argument ("--metadata" , required = True , help = "JSON serialized metadata object" )
2222 parser .add_argument ("--model-base-url" , required = True , help = "Base URL for the model API" )
2323
2424 args = parser .parse_args ()
2525
26- # Parse the metadata
26+ # Parse completion_params
27+ try :
28+ completion_params = json .loads (args .completion_params )
29+ except Exception as e :
30+ print (f"❌ Failed to parse completion_params: { e } " )
31+ exit (1 )
32+
33+ model = completion_params .get ("model" )
34+ if not model :
35+ print ("Error: model is required in completion_params" )
36+ exit (1 )
37+
2738 try :
2839 metadata = json .loads (args .metadata )
2940 except Exception as e :
@@ -34,7 +45,7 @@ def main():
3445 row_id = metadata ["row_id" ]
3546
3647 print (f"🚀 Starting rollout { rollout_id } " )
37- print (f" Model: { args . model } " )
48+ print (f" Model: { model } " )
3849 print (f" Row ID: { row_id } " )
3950
4051 dataset = [ # In this example, worker has access to the dataset and we use index to associate rows.
@@ -49,11 +60,13 @@ def main():
4960 print (f" Messages: { len (messages )} messages" )
5061
5162 try :
52- completion_kwargs = {"model" : args .model , "messages" : messages }
63+ # Build completion kwargs from completion_params
64+ completion_kwargs = {"messages" : messages , ** completion_params }
5365
5466 client = OpenAI (base_url = args .model_base_url , api_key = os .environ .get ("FIREWORKS_API_KEY" ))
5567
5668 print ("📡 Calling OpenAI completion..." )
69+ print (f" Completion kwargs: { completion_kwargs } " )
5770 completion = client .chat .completions .create (** completion_kwargs )
5871
5972 print (f"✅ Rollout { rollout_id } completed successfully" )
0 commit comments