-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathremote_server.py
More file actions
122 lines (95 loc) · 4.33 KB
/
remote_server.py
File metadata and controls
122 lines (95 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import random
import threading
import argparse
import uvicorn
from fastapi import FastAPI
from openai import OpenAI
import logging
from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter
# Message for a forced rollout error. Stays None unless main() overwrites it
# from the --force-early-error command-line flag; the /init worker reads it.
force_early_error_message = None

# ASGI application that the /init route is registered on and uvicorn serves.
app = FastAPI()

# Install the Fireworks tracing handler on the root logger so that records
# from every logger that propagates to the root (including the per-rollout
# loggers created in /init) reach it.
fireworks_handler = FireworksTracingHttpHandler()
logging.getLogger().addHandler(fireworks_handler)
def _message_to_payload(message):
    """Convert one rollout message into a chat-completion message dict.

    Fields whose value is None are dropped (the upstream API rejects extra
    keys even when they are null).

    Args:
        message: A message object or a plain dict.

    Returns:
        A dict suitable for the ``messages`` list of a chat completion call.
    """
    if hasattr(message, "dump_mdoel_for_chat_completion_request"):
        # NOTE(review): the spelling "mdoel" must match the attribute exposed
        # by the message objects (presumably eval_protocol's Message) -- do
        # not "correct" it here without checking that API.
        return message.dump_mdoel_for_chat_completion_request()  # type: ignore[attr-defined]
    if hasattr(message, "model_dump"):
        return message.model_dump(exclude_none=True)  # type: ignore[call-arg]
    if isinstance(message, dict):
        return {k: v for k, v in message.items() if v is not None}
    # Last resort: duck-type an arbitrary object with role/content attributes.
    fallback = {
        "role": getattr(message, "role", None),
        "content": getattr(message, "content", None),
    }
    return {k: v for k, v in fallback.items() if v is not None}


@app.post("/init")
def init(req: InitRequest):
    """Start a rollout in a background thread and return immediately.

    A logger named after the rollout id is created and tagged with a
    ``RolloutIdFilter``. A daemon thread then performs a single chat
    completion through the ``OpenAI`` client pointed at
    ``req.model_base_url`` and reports the outcome through log records that
    carry a ``status`` extra.

    Exactly one terminal status is logged per rollout:

    * ``Status.rollout_error(...)`` when the module-level
      ``force_early_error_message`` is set and the completion call succeeded;
    * ``Status.rollout_finished()`` in every other case, including failures,
      so that whoever is waiting on the rollout is never left without a
      terminal record.

    Args:
        req: The init request; ``metadata.rollout_id``, ``messages``,
            ``completion_params``, ``tools`` and ``model_base_url`` are read.
    """
    rollout_id = req.metadata.rollout_id

    # Per-rollout logger so every record can be tagged with the rollout id.
    logger = logging.getLogger(f"{__name__}.{rollout_id}")
    logger.addFilter(RolloutIdFilter(rollout_id))

    def _worker():
        # Tracks whether a rollout_error status record was actually emitted.
        # The global flag alone is not enough: with forced-error mode enabled
        # the worker can still fail *before* reaching the forced-error branch,
        # and then no terminal status would ever be logged.
        error_status_logged = False
        try:
            if not req.messages:
                raise ValueError("messages is required")

            model = req.completion_params.get("model")
            if not model:
                raise ValueError("model is required in completion_params")

            messages_payload = [_message_to_payload(m) for m in req.messages]

            # Spread all completion_params (model, temperature, max_tokens, etc.)
            completion_kwargs = {"messages": messages_payload, **req.completion_params}
            if req.tools:
                completion_kwargs["tools"] = req.tools

            logger.info(f"Final completion_kwargs: {completion_kwargs}")

            client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

            logger.info(f"Sending completion request to model {model}")
            completion = client.chat.completions.create(**completion_kwargs)
            logger.info(f"Completed response: {completion}")

            # Forced-error mode (set from the command line): report a rollout
            # error after the completion call instead of a normal finish.
            if force_early_error_message:
                logger.error(
                    force_early_error_message,
                    extra={"status": Status.rollout_error(force_early_error_message)},
                )
                error_status_logged = True
                raise RuntimeError(force_early_error_message)

        except Exception as e:
            # Best-effort: the rollout is still marked as done below so the
            # waiting side is unblocked. exc_info keeps the traceback, which
            # would otherwise be lost because this runs in a background thread.
            logger.error(f"❌ Error in rollout {rollout_id}: {e}", exc_info=True)
        finally:
            if not error_status_logged:
                logger.info(
                    f"Rollout {rollout_id} completed",
                    extra={"status": Status.rollout_finished()},
                )

    # Daemon thread: must not keep the server process alive on shutdown.
    threading.Thread(target=_worker, daemon=True).start()
def main():
    """Parse command-line options and serve the app with uvicorn.

    Options (each defaulting as shown):
        --host               REMOTE_SERVER_HOST env var, else 127.0.0.1
        --port               REMOTE_SERVER_PORT env var, else 3000
        --force-early-error  None; when given, the value is stored in the
                             module-level ``force_early_error_message``.

    The call blocks inside ``uvicorn.run``.
    """
    global force_early_error_message

    # Environment-derived defaults are resolved once, up front.
    default_host = os.environ.get("REMOTE_SERVER_HOST", "127.0.0.1")
    default_port = int(os.environ.get("REMOTE_SERVER_PORT", "3000"))

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        (
            "--host",
            {
                "type": str,
                "default": default_host,
                "help": "Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)",
            },
        ),
        (
            "--port",
            {
                "type": int,
                "default": default_port,
                "help": "Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)",
            },
        ),
        (
            "--force-early-error",
            {
                "type": str,
                "default": None,
                # NOTE(review): the /init worker in this file issues the
                # completion request before it logs the forced error --
                # confirm this help text matches the intended behavior.
                "help": "If set, /init will immediately return after logging a rollout_error with this message",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Run the remote server for evaluation protocol")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    # Publish the forced-error message for the /init worker to read.
    force_early_error_message = options.force_early_error

    uvicorn.run(app, host=options.host, port=options.port)


if __name__ == "__main__":
    main()