Skip to content

Commit 9be5530

Browse files
committed
clean up
1 parent 5657586 commit 9be5530

File tree

3 files changed

+50
-82
lines changed

3 files changed

+50
-82
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -35,19 +35,34 @@ class RemoteRolloutProcessor(RolloutProcessor):
3535
Returns: {"terminated": bool, "info": {...}?}
3636
"""
3737

38-
def __init__(self):
39-
pass
38+
def __init__(
39+
self,
40+
*,
41+
remote_base_url: Optional[str] = None,
42+
num_turns: int = 2,
43+
poll_interval: float = 1.0,
44+
timeout_seconds: float = 120.0,
45+
):
46+
# Prefer constructor-provided configuration. These can be overridden via
47+
# config.kwargs at call time for backward compatibility.
48+
self._remote_base_url = remote_base_url
49+
self._num_turns = num_turns
50+
self._poll_interval = poll_interval
51+
self._timeout_seconds = timeout_seconds
4052

4153
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
4254
tasks: List[asyncio.Task[EvaluationRow]] = []
4355

44-
remote_base_url: Optional[str] = None
45-
num_turns: int = 2
46-
poll_interval: float = 1.0
47-
timeout_seconds: float = 120.0
56+
# Start with constructor values
57+
remote_base_url: Optional[str] = self._remote_base_url
58+
num_turns: int = self._num_turns
59+
poll_interval: float = self._poll_interval
60+
timeout_seconds: float = self._timeout_seconds
4861

62+
# Backward compatibility: allow overrides via config.kwargs
4963
if config.kwargs:
50-
remote_base_url = config.kwargs.get("remote_base_url")
64+
if remote_base_url is None:
65+
remote_base_url = config.kwargs.get("remote_base_url", remote_base_url)
5166
num_turns = int(config.kwargs.get("num_turns", num_turns))
5267
poll_interval = float(config.kwargs.get("poll_interval", poll_interval))
5368
timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds))

tests/chinook/langfuse/remote_server.py

Lines changed: 0 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -45,37 +45,6 @@ def init(req: InitRequest):
4545
# Kick off worker thread that runs multi-turn chat via LiteLLM proxy
4646
def _worker():
4747
try:
48-
# Try to set up Langfuse trace to guarantee observability, independent of proxy wiring
49-
langfuse = None
50-
trace = None
51-
try:
52-
from langfuse import get_client # pyright: ignore[reportPrivateImportUsage]
53-
54-
langfuse = get_client()
55-
id_tags = []
56-
try:
57-
id_tags = [
58-
f"inv:{req.metadata.get('invocation_id')}",
59-
f"exp:{req.metadata.get('experiment_id')}",
60-
f"rollout:{req.metadata.get('rollout_id')}",
61-
]
62-
except Exception:
63-
id_tags = []
64-
trace = langfuse.api.trace.create(
65-
name="remote_chinook_rollout",
66-
metadata=req.metadata,
67-
requester_metadata=req.metadata,
68-
tags=["chinook_remote", "chinook_sql", *[t for t in id_tags if t]],
69-
input={
70-
"messages": _clean_messages_for_api(req.messages),
71-
"tools": req.tools,
72-
"metadata": req.metadata,
73-
},
74-
)
75-
except Exception:
76-
langfuse = None
77-
trace = None
78-
7948
base_url = os.getenv(
8049
"LITELLM_BASE_URL",
8150
"https://litellm-cloud-proxy-prod-644257448872.us-central1.run.app",
@@ -110,48 +79,13 @@ def _worker():
11079
r.raise_for_status()
11180
data = r.json()
11281
assistant = data.get("choices", [{}])[0].get("message", {})
113-
# Optionally record a generation on Langfuse
114-
try:
115-
if langfuse and trace and getattr(langfuse.api, "generation", None):
116-
langfuse.api.generation.create(
117-
trace_id=trace.id,
118-
name="assistant",
119-
input={"messages": _clean_messages_for_api(messages)},
120-
output=assistant,
121-
)
122-
except Exception:
123-
pass
12482
# Append assistant for next turn
12583
messages = messages + [assistant]
12684

127-
# Update final trace output for easier adapter extraction
128-
try:
129-
if langfuse and trace:
130-
langfuse.api.trace.update(
131-
id=trace.id,
132-
output={
133-
"messages": _clean_messages_for_api(messages),
134-
"metadata": req.metadata,
135-
},
136-
)
137-
except Exception:
138-
pass
139-
14085
except Exception:
14186
# Best-effort; mark as done even on error to unblock polling
14287
pass
14388
finally:
144-
try:
145-
if "langfuse" in locals() and langfuse is not None:
146-
# Ensure buffered telemetry is sent
147-
flush = getattr(langfuse, "flush", None)
148-
if callable(flush):
149-
flush()
150-
shutdown = getattr(langfuse, "shutdown", None)
151-
if callable(shutdown):
152-
shutdown()
153-
except Exception:
154-
pass
15589
_STATE[req.rollout_id]["terminated"] = True
15690

15791
t = threading.Thread(target=_worker, daemon=True)

tests/chinook/langfuse/test_remote_langfuse_chinook.py

Lines changed: 28 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import atexit
77

88
import pytest
9+
import requests
910

1011
from eval_protocol.models import EvaluationRow, Message
1112
from eval_protocol.pytest import evaluation_test
@@ -23,17 +24,36 @@ def _start_remote_server():
2324

2425

2526
def _ensure_server_running():
27+
host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")
28+
port = int(os.getenv("REMOTE_SERVER_PORT", "7077"))
29+
base_url = f"http://{host}:{port}"
30+
31+
def _is_up() -> bool:
32+
try:
33+
r = requests.get(f"{base_url}/status", params={"rollout_id": "ping"}, timeout=1.0)
34+
return r.status_code in (200, 404)
35+
except Exception:
36+
return False
37+
38+
if _is_up():
39+
return None
40+
2641
# Launch in a background process
2742
proc = multiprocessing.Process(target=_start_remote_server, daemon=True)
2843
proc.start()
29-
# Give it a moment to boot
30-
time.sleep(1.5)
44+
45+
# Poll for readiness up to 10s
46+
deadline = time.time() + 10
47+
while time.time() < deadline:
48+
if _is_up():
49+
break
50+
time.sleep(0.5)
3151
return proc
3252

3353

3454
# Ensure server is running BEFORE rollouts start (evaluation_test triggers rollouts before test body)
3555
_SERVER_PROC = _ensure_server_running()
36-
atexit.register(lambda: (_SERVER_PROC.terminate() if _SERVER_PROC.is_alive() else None))
56+
atexit.register(lambda: (_SERVER_PROC and _SERVER_PROC.is_alive() and _SERVER_PROC.terminate()))
3757

3858

3959
def _make_input_rows() -> List[EvaluationRow]:
@@ -47,12 +67,11 @@ def _make_input_rows() -> List[EvaluationRow]:
4767
@evaluation_test(
4868
input_rows=[_make_input_rows()],
4969
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
50-
rollout_processor=RemoteRolloutProcessor(),
51-
rollout_processor_kwargs={
52-
"remote_base_url": "http://127.0.0.1:7077",
53-
"num_turns": 2,
54-
"timeout_seconds": 30,
55-
},
70+
rollout_processor=RemoteRolloutProcessor(
71+
remote_base_url="http://127.0.0.1:7077",
72+
num_turns=2,
73+
timeout_seconds=30,
74+
),
5675
mode="pointwise",
5776
)
5877
async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)