-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathremote_server.py
More file actions
73 lines (52 loc) · 1.98 KB
/
remote_server.py
File metadata and controls
73 lines (52 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import threading
from typing import Any, Dict, List
import uvicorn
from fastapi import FastAPI, HTTPException
from langfuse.openai import openai # pyright: ignore[reportPrivateImportUsage]
from eval_protocol.types.remote_rollout_processor import (
InitRequest,
StatusResponse,
create_langfuse_config_tags,
)
from eval_protocol.models import Message
app = FastAPI()
_STATE: Dict[str, Dict[str, Any]] = {}
@app.post("/init")
def init(req: InitRequest):
# Persist state
_STATE[req.metadata.rollout_id] = {"terminated": False}
# Kick off worker thread that does a single-turn chat via Langfuse OpenAI integration
def _worker():
try:
metadata = {"langfuse_tags": create_langfuse_config_tags(req)}
if not req.messages:
raise ValueError("messages is required")
completion_kwargs = {
"model": req.model,
"messages": req.messages,
"metadata": metadata,
}
if req.tools:
completion_kwargs["tools"] = req.tools
completion = openai.chat.completions.create(**completion_kwargs)
except Exception as e:
# Best-effort; mark as done even on error to unblock polling
print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}")
pass
finally:
_STATE[req.metadata.rollout_id]["terminated"] = True
t = threading.Thread(target=_worker, daemon=True)
t.start()
@app.get("/status", response_model=StatusResponse)
def status(rollout_id: str) -> StatusResponse:
st = _STATE.get(rollout_id)
if not st:
raise HTTPException(status_code=404, detail="unknown rollout_id")
return StatusResponse(terminated=bool(st.get("terminated", False)))
def main():
host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")
port = int(os.getenv("REMOTE_SERVER_PORT", "3000"))
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
main()