Skip to content

Commit 21db77c

Browse files
committed
add types
1 parent 5125e5a commit 21db77c

File tree

5 files changed

+96
-120
lines changed

5 files changed

+96
-120
lines changed

eval_protocol/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
except ImportError:
6363
LangSmithAdapter = None
6464

65+
# Remote server types
66+
from .types.remote_rollout_processor import (
67+
InitRequest,
68+
RolloutMetadata,
69+
StatusResponse,
70+
create_langfuse_config_tags,
71+
)
6572

6673
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
6774

@@ -110,6 +117,11 @@
110117
# Submodules
111118
"rewards",
112119
"mcp",
120+
# Remote server types
121+
"InitRequest",
122+
"RolloutMetadata",
123+
"StatusResponse",
124+
"create_langfuse_config_tags",
113125
]
114126

115127
from . import _version

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class RemoteRolloutProcessor(RolloutProcessor):
2828
"run_id": str | null,
2929
"row_id": str | null
3030
},
31-
"num_turns": int
3231
}
3332
Returns: {"ok": true}
3433
@@ -40,15 +39,13 @@ def __init__(
4039
self,
4140
*,
4241
remote_base_url: Optional[str] = None,
43-
num_turns: int = 2,
4442
poll_interval: float = 1.0,
4543
timeout_seconds: float = 120.0,
4644
output_data_loader: Callable[[str], DynamicDataLoader],
4745
):
4846
# Prefer constructor-provided configuration. These can be overridden via
4947
# config.kwargs at call time for backward compatibility.
5048
self._remote_base_url = remote_base_url
51-
self._num_turns = num_turns
5249
self._poll_interval = poll_interval
5350
self._timeout_seconds = timeout_seconds
5451
self._output_data_loader = output_data_loader
@@ -58,15 +55,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
5855

5956
# Start with constructor values
6057
remote_base_url: Optional[str] = self._remote_base_url
61-
num_turns: int = self._num_turns
6258
poll_interval: float = self._poll_interval
6359
timeout_seconds: float = self._timeout_seconds
6460

6561
# Backward compatibility: allow overrides via config.kwargs
6662
if config.kwargs:
6763
if remote_base_url is None:
6864
remote_base_url = config.kwargs.get("remote_base_url", remote_base_url)
69-
num_turns = int(config.kwargs.get("num_turns", num_turns))
7065
poll_interval = float(config.kwargs.get("poll_interval", poll_interval))
7166
timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds))
7267

@@ -121,7 +116,6 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
121116
"messages": clean_messages,
122117
"tools": row.tools,
123118
"metadata": meta,
124-
"num_turns": num_turns,
125119
}
126120

127121
# Fire-and-poll
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
Request and response models for remote rollout processor servers.
3+
"""
4+
5+
from typing import Any, Dict, List, Optional
6+
from pydantic import BaseModel, Field
7+
from eval_protocol.models import Message
8+
9+
10+
class RolloutMetadata(BaseModel):
11+
"""Metadata for rollout execution."""
12+
13+
invocation_id: str
14+
experiment_id: str
15+
rollout_id: str
16+
run_id: str
17+
row_id: str
18+
19+
20+
class InitRequest(BaseModel):
21+
"""Request model for POST /init endpoint."""
22+
23+
rollout_id: str
24+
model: str
25+
messages: List[Message] = Field(min_length=1)
26+
tools: Optional[List[Dict[str, Any]]] = None
27+
metadata: RolloutMetadata
28+
29+
30+
class StatusResponse(BaseModel):
31+
"""Response model for GET /status endpoint."""
32+
33+
terminated: bool
34+
35+
36+
def create_langfuse_config_tags(init_request: InitRequest) -> List[str]:
37+
"""Create Langfuse tags from InitRequest metadata."""
38+
metadata = init_request.metadata
39+
return [
40+
f"invocation_id:{metadata.invocation_id}",
41+
f"experiment_id:{metadata.experiment_id}",
42+
f"rollout_id:{metadata.rollout_id}",
43+
f"run_id:{metadata.run_id}",
44+
f"row_id:{metadata.row_id}",
45+
]

tests/chinook/langfuse/remote_server.py

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,32 @@
11
import os
22
import threading
3-
from typing import Any, Dict
3+
from typing import Any, Dict, List
44

55
import uvicorn
66
from fastapi import FastAPI, HTTPException
7-
from pydantic import BaseModel
87
from langfuse.openai import openai # pyright: ignore[reportPrivateImportUsage]
98

10-
11-
app = FastAPI()
9+
from eval_protocol.types.remote_rollout_processor import (
10+
InitRequest,
11+
StatusResponse,
12+
create_langfuse_config_tags,
13+
)
14+
from eval_protocol.models import Message
1215

1316

14-
class InitRequest(BaseModel):
15-
rollout_id: str
16-
model: str
17-
messages: list[dict]
18-
tools: list[dict] | None = None
19-
metadata: dict
20-
num_turns: int = 2
17+
app = FastAPI()
2118

2219

2320
_STATE: Dict[str, Dict[str, Any]] = {}
2421

25-
2622
ALLOWED_MESSAGE_FIELDS = {"role", "content", "tool_calls", "tool_call_id", "name"}
2723

2824

29-
def _clean_messages_for_api(messages: list[dict]) -> list[dict]:
25+
def _clean_messages_for_api(messages: List[Message]) -> list[dict]:
3026
cleaned: list[dict] = []
3127
for msg in messages:
32-
if not isinstance(msg, dict):
33-
continue
34-
cm = {k: v for k, v in msg.items() if k in ALLOWED_MESSAGE_FIELDS and v is not None}
28+
msg_dict = msg.model_dump()
29+
cm = {k: v for k, v in msg_dict.items() if k in ALLOWED_MESSAGE_FIELDS and v is not None}
3530
# Some providers dislike empty content on assistant messages; keep if present
3631
cleaned.append(cm)
3732
return cleaned
@@ -42,53 +37,25 @@ def init(req: InitRequest):
4237
# Persist state
4338
_STATE[req.rollout_id] = {"terminated": False}
4439

45-
# Kick off worker thread that runs multi-turn chat via Langfuse OpenAI integration
40+
# Kick off worker thread that does a single-turn chat via Langfuse OpenAI integration
4641
def _worker():
4742
try:
48-
# Prepare tags for Langfuse filtering
49-
metadata = {
50-
"langfuse_tags": [
51-
f"invocation_id:{req.metadata.get('invocation_id')}",
52-
f"experiment_id:{req.metadata.get('experiment_id')}",
53-
f"rollout_id:{req.metadata.get('rollout_id')}",
54-
f"run_id:{req.metadata.get('run_id')}",
55-
f"row_id:{req.metadata.get('row_id')}",
56-
]
43+
metadata = {"langfuse_tags": create_langfuse_config_tags(req)}
44+
45+
completion_kwargs = {
46+
"model": req.model,
47+
"messages": _clean_messages_for_api(req.messages),
48+
"metadata": metadata,
5749
}
5850

59-
messages = req.messages
60-
61-
# Simulate N-1 assistant turns (single-shot or simple echo)
62-
for _ in range(max(1, req.num_turns)):
63-
completion_kwargs = {
64-
"model": req.model,
65-
"messages": _clean_messages_for_api(messages),
66-
"metadata": metadata,
67-
}
68-
69-
if req.tools:
70-
completion_kwargs["tools"] = req.tools
71-
72-
completion = openai.chat.completions.create(**completion_kwargs)
73-
assistant_message = completion.choices[0].message
74-
75-
# Convert to dict format for next turn
76-
assistant_dict = {"role": "assistant", "content": assistant_message.content}
77-
if assistant_message.tool_calls:
78-
assistant_dict["tool_calls"] = [
79-
{
80-
"id": tc.id,
81-
"type": tc.type,
82-
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
83-
}
84-
for tc in assistant_message.tool_calls
85-
]
86-
87-
# Append assistant for next turn
88-
messages = messages + [assistant_dict]
89-
90-
except Exception:
51+
if req.tools:
52+
completion_kwargs["tools"] = req.tools
53+
54+
completion = openai.chat.completions.create(**completion_kwargs)
55+
56+
except Exception as e:
9157
# Best-effort; mark as done even on error to unblock polling
58+
print(f"❌ Error in rollout {req.rollout_id}: {e}")
9259
pass
9360
finally:
9461
_STATE[req.rollout_id]["terminated"] = True
@@ -98,12 +65,12 @@ def _worker():
9865
return {"ok": True}
9966

10067

101-
@app.get("/status")
102-
def status(rollout_id: str):
68+
@app.get("/status", response_model=StatusResponse)
69+
def status(rollout_id: str) -> StatusResponse:
10370
st = _STATE.get(rollout_id)
10471
if not st:
10572
raise HTTPException(status_code=404, detail="unknown rollout_id")
106-
return {"terminated": bool(st.get("terminated", False))}
73+
return StatusResponse(terminated=bool(st.get("terminated", False)))
10774

10875

10976
def main():

tests/chinook/langfuse/test_remote_langfuse_chinook.py

Lines changed: 11 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
# MANUAL SERVER STARTUP REQUIRED:
2+
# Before running this test, start the remote server manually:
3+
# cd <path-to-your-python-sdk-checkout>
4+
# python -m tests.chinook.langfuse.remote_server
5+
#
6+
# The server should be running on http://127.0.0.1:7077
7+
18
import os
2-
import multiprocessing
3-
import time
4-
from datetime import datetime, timedelta
59
from typing import List
6-
import atexit
710

811
import pytest
9-
import requests
1012

1113
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
1214
from eval_protocol.models import EvaluationRow, Message
@@ -33,7 +35,7 @@ def fetch_langfuse_traces(rollout_id: str) -> List[EvaluationRow]:
3335
ROLLOUT_IDS.add(rollout_id)
3436

3537
adapter = create_langfuse_adapter()
36-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"])
38+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5)
3739

3840

3941
def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader:
@@ -42,51 +44,8 @@ def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader:
4244
)
4345

4446

45-
def _start_remote_server():
46-
# Starts FastAPI server defined in remote_server.py using absolute import
47-
import importlib
48-
49-
os.environ.setdefault("REMOTE_SERVER_HOST", "127.0.0.1")
50-
os.environ.setdefault("REMOTE_SERVER_PORT", "7077")
51-
mod = importlib.import_module("tests.chinook.langfuse.remote_server")
52-
mod.main()
53-
54-
55-
def _ensure_server_running():
56-
host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")
57-
port = int(os.getenv("REMOTE_SERVER_PORT", "7077"))
58-
base_url = f"http://{host}:{port}"
59-
60-
def _is_up() -> bool:
61-
try:
62-
r = requests.get(f"{base_url}/status", params={"rollout_id": "ping"}, timeout=1.0)
63-
return r.status_code in (200, 404)
64-
except Exception:
65-
return False
66-
67-
if _is_up():
68-
return None
69-
70-
# Launch in a background process
71-
proc = multiprocessing.Process(target=_start_remote_server, daemon=True)
72-
proc.start()
73-
74-
# Poll for readiness up to 10s
75-
deadline = time.time() + 10
76-
while time.time() < deadline:
77-
if _is_up():
78-
break
79-
time.sleep(0.5)
80-
return proc
81-
82-
8347
def remote_langfuse_data_generator() -> List[EvaluationRow]:
84-
# Ensure server is running BEFORE rollouts start (evaluation_test triggers rollouts before test body)
85-
_SERVER_PROC = _ensure_server_running()
86-
atexit.register(lambda: (_SERVER_PROC and _SERVER_PROC.is_alive() and _SERVER_PROC.terminate()))
87-
88-
# Minimal single-user-turn message to trigger a response
89-
row = EvaluationRow(messages=[Message(role="user", content="Hello there! Please say hi back.")])
48+
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
9049
return [row, row, row]
9150

9251

@@ -98,19 +57,18 @@ def remote_langfuse_data_generator() -> List[EvaluationRow]:
9857
),
9958
rollout_processor=RemoteRolloutProcessor(
10059
remote_base_url="http://127.0.0.1:7077",
101-
num_turns=2,
10260
timeout_seconds=30,
10361
output_data_loader=langfuse_output_data_loader,
10462
),
10563
)
10664
async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:
10765
"""
10866
End-to-end test:
109-
- remote server started at import time
67+
- REQUIRES MANUAL SERVER STARTUP: python -m tests.chinook.langfuse.remote_server
11068
- trigger remote rollout via RemoteRolloutProcessor (calls init/status)
11169
- fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found
11270
"""
113-
assert row.messages[0].content == "Hello there! Please say hi back.", "Row should have correct message content"
71+
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
11472
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
11573

11674
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (

0 commit comments

Comments
 (0)