11import os
22import threading
3- from typing import Any , Dict
3+ from typing import Any , Dict , List
44
55import uvicorn
66from fastapi import FastAPI , HTTPException
7- from pydantic import BaseModel
87from langfuse .openai import openai # pyright: ignore[reportPrivateImportUsage]
98
10-
11- app = FastAPI ()
9+ from eval_protocol .types .remote_rollout_processor import (
10+ InitRequest ,
11+ StatusResponse ,
12+ create_langfuse_config_tags ,
13+ )
14+ from eval_protocol .models import Message
1215
1316
14- class InitRequest (BaseModel ):
15- rollout_id : str
16- model : str
17- messages : list [dict ]
18- tools : list [dict ] | None = None
19- metadata : dict
20- num_turns : int = 2
17+ app = FastAPI ()
2118
2219
2320_STATE : Dict [str , Dict [str , Any ]] = {}
2421
25-
2622ALLOWED_MESSAGE_FIELDS = {"role" , "content" , "tool_calls" , "tool_call_id" , "name" }
2723
2824
29- def _clean_messages_for_api (messages : list [ dict ]) -> list [dict ]:
25+ def _clean_messages_for_api (messages : List [ Message ]) -> list [dict ]:
3026 cleaned : list [dict ] = []
3127 for msg in messages :
32- if not isinstance (msg , dict ):
33- continue
34- cm = {k : v for k , v in msg .items () if k in ALLOWED_MESSAGE_FIELDS and v is not None }
28+ msg_dict = msg .model_dump ()
29+ cm = {k : v for k , v in msg_dict .items () if k in ALLOWED_MESSAGE_FIELDS and v is not None }
3530 # Some providers dislike empty content on assistant messages; keep if present
3631 cleaned .append (cm )
3732 return cleaned
@@ -42,53 +37,25 @@ def init(req: InitRequest):
4237 # Persist state
4338 _STATE [req .rollout_id ] = {"terminated" : False }
4439
45- # Kick off worker thread that runs multi -turn chat via Langfuse OpenAI integration
40+ # Kick off worker thread that does a single -turn chat via Langfuse OpenAI integration
4641 def _worker ():
4742 try :
48- # Prepare tags for Langfuse filtering
49- metadata = {
50- "langfuse_tags" : [
51- f"invocation_id:{ req .metadata .get ('invocation_id' )} " ,
52- f"experiment_id:{ req .metadata .get ('experiment_id' )} " ,
53- f"rollout_id:{ req .metadata .get ('rollout_id' )} " ,
54- f"run_id:{ req .metadata .get ('run_id' )} " ,
55- f"row_id:{ req .metadata .get ('row_id' )} " ,
56- ]
43+ metadata = {"langfuse_tags" : create_langfuse_config_tags (req )}
44+
45+ completion_kwargs = {
46+ "model" : req .model ,
47+ "messages" : _clean_messages_for_api (req .messages ),
48+ "metadata" : metadata ,
5749 }
5850
59- messages = req .messages
60-
61- # Simulate N-1 assistant turns (single-shot or simple echo)
62- for _ in range (max (1 , req .num_turns )):
63- completion_kwargs = {
64- "model" : req .model ,
65- "messages" : _clean_messages_for_api (messages ),
66- "metadata" : metadata ,
67- }
68-
69- if req .tools :
70- completion_kwargs ["tools" ] = req .tools
71-
72- completion = openai .chat .completions .create (** completion_kwargs )
73- assistant_message = completion .choices [0 ].message
74-
75- # Convert to dict format for next turn
76- assistant_dict = {"role" : "assistant" , "content" : assistant_message .content }
77- if assistant_message .tool_calls :
78- assistant_dict ["tool_calls" ] = [
79- {
80- "id" : tc .id ,
81- "type" : tc .type ,
82- "function" : {"name" : tc .function .name , "arguments" : tc .function .arguments },
83- }
84- for tc in assistant_message .tool_calls
85- ]
86-
87- # Append assistant for next turn
88- messages = messages + [assistant_dict ]
89-
90- except Exception :
51+ if req .tools :
52+ completion_kwargs ["tools" ] = req .tools
53+
54+ completion = openai .chat .completions .create (** completion_kwargs )
55+
56+ except Exception as e :
9157 # Best-effort; mark as done even on error to unblock polling
58+ print (f"❌ Error in rollout { req .rollout_id } : { e } " )
9259 pass
9360 finally :
9461 _STATE [req .rollout_id ]["terminated" ] = True
@@ -98,12 +65,12 @@ def _worker():
9865 return {"ok" : True }
9966
10067
101- @app .get ("/status" )
102- def status (rollout_id : str ):
68+ @app .get ("/status" , response_model = StatusResponse )
69+ def status (rollout_id : str ) -> StatusResponse :
10370 st = _STATE .get (rollout_id )
10471 if not st :
10572 raise HTTPException (status_code = 404 , detail = "unknown rollout_id" )
106- return { " terminated" : bool (st .get ("terminated" , False ))}
73+ return StatusResponse ( terminated = bool (st .get ("terminated" , False )))
10774
10875
10976def main ():
0 commit comments