Skip to content

Commit e446e98

Browse files
authored
Merge branch 'main' into pytest-plugin-replacement
2 parents 6295fd9 + 5d7e5cb commit e446e98

File tree

14 files changed

+574
-55
lines changed

14 files changed

+574
-55
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def evaluation_test(
8181
aggregation_method: AggregationMethod = "mean",
8282
passed_threshold: EvaluationThreshold | float | EvaluationThresholdDict | None = None,
8383
num_runs: int = 1,
84+
filtered_row_ids: Sequence[str] | None = None,
8485
max_dataset_rows: int | None = None,
8586
mcp_config_path: str | None = None,
8687
max_concurrent_rollouts: int = 8,
@@ -148,6 +149,7 @@ def evaluation_test(
148149
Success rate must be above success, and if set, standard error must be below standard_error.
149150
Success rate +/- one standard_error is equivalent to 68% confidence interval.
150151
num_runs: Number of times to repeat the rollout and evaluations.
152+
filtered_row_ids: List of row_ids to filter for the evaluation. If provided, only the rows with the given row_ids will be evaluated.
151153
max_dataset_rows: Limit dataset to the first N rows.
152154
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
153155
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
@@ -272,6 +274,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
272274
results = data_loader.load()
273275
for result in results:
274276
data.extend(result.rows)
277+
# Apply max_dataset_rows limit to data from data loaders
278+
if max_dataset_rows is not None:
279+
data = data[:max_dataset_rows]
275280
elif "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
276281
ds_arg: list[str] = kwargs["dataset_path"]
277282
# Support either a single path or a list of paths; if a list is provided,
@@ -293,6 +298,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
293298
else:
294299
raise ValueError("No input dataset, input messages, or input rows provided")
295300

301+
if filtered_row_ids is not None:
302+
data = [row for row in data if row.input_metadata.row_id in filtered_row_ids]
303+
296304
"""
297305
data_loaders handles preprocess_fn internally so we want
298306
to specially handle data_loaders here so we don't double

eval_protocol/pytest/handle_persist_flow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import re
88
from typing import Any
99

10+
from eval_protocol.directory_utils import find_eval_protocol_dir
1011
from eval_protocol.models import EvaluationRow
1112
from eval_protocol.pytest.store_experiment_link import store_experiment_link
1213
import requests
@@ -25,7 +26,8 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
2526
if row.execution_metadata and row.execution_metadata.experiment_id:
2627
experiments[row.execution_metadata.experiment_id].append(row)
2728

28-
exp_dir = pathlib.Path("experiment_results")
29+
eval_protocol_dir = find_eval_protocol_dir()
30+
exp_dir = pathlib.Path(eval_protocol_dir) / "experiment_results"
2931
exp_dir.mkdir(parents=True, exist_ok=True)
3032

3133
# Create one JSONL file per experiment_id

eval_protocol/pytest/parameterize.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
7373
and decorator.func.value.attr == "mark"
7474
and decorator.func.attr == "parametrize"
7575
):
76+
# Validate argvalues if present
77+
_validate_parametrize_argvalues(decorator)
78+
7679
# Check positional arguments first (argnames is typically the first positional arg)
7780
if len(decorator.args) > 0:
7881
argnames_arg = decorator.args[0]
@@ -88,6 +91,90 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
8891
return False
8992

9093

94+
def _ast_dict_to_string(dict_node: ast.Dict) -> str:
95+
"""
96+
Convert an AST Dict node to its string representation.
97+
98+
Args:
99+
dict_node: AST node representing a dictionary
100+
101+
Returns:
102+
String representation of the dictionary
103+
"""
104+
if not dict_node.keys:
105+
return "{}"
106+
107+
pairs = []
108+
for key, value in zip(dict_node.keys, dict_node.values):
109+
if key is not None:
110+
key_str = _ast_node_to_string(key)
111+
value_str = _ast_node_to_string(value)
112+
pairs.append(f"{key_str}: {value_str}")
113+
114+
return "{" + ", ".join(pairs) + "}"
115+
116+
117+
def _ast_node_to_string(node: ast.expr) -> str:
118+
"""
119+
Convert an AST node to its string representation.
120+
121+
Args:
122+
node: AST node to convert
123+
124+
Returns:
125+
String representation of the node
126+
"""
127+
if isinstance(node, ast.Constant):
128+
if isinstance(node.value, str):
129+
return repr(node.value)
130+
else:
131+
return str(node.value)
132+
elif isinstance(node, ast.Name):
133+
return node.id
134+
elif isinstance(node, ast.Dict):
135+
return _ast_dict_to_string(node)
136+
elif isinstance(node, ast.List):
137+
elements = [_ast_node_to_string(elt) for elt in node.elts]
138+
return "[" + ", ".join(elements) + "]"
139+
elif isinstance(node, ast.Tuple):
140+
elements = [_ast_node_to_string(elt) for elt in node.elts]
141+
return "(" + ", ".join(elements) + ")"
142+
else:
143+
# For complex expressions, return a simplified representation
144+
return "<complex expression>"
145+
146+
147+
def _validate_parametrize_argvalues(decorator: ast.Call) -> None:
148+
"""
149+
Validate that pytest.mark.parametrize argvalues is a list/tuple, not a dict.
150+
151+
Args:
152+
decorator: AST node representing the pytest.mark.parametrize decorator call
153+
154+
Raises:
155+
ValueError: If argvalues is a dict instead of a list/tuple
156+
"""
157+
# Check positional arguments (argvalues is typically the second positional arg)
158+
if len(decorator.args) > 1:
159+
argvalues_arg = decorator.args[1]
160+
if isinstance(argvalues_arg, ast.Dict):
161+
dict_repr = _ast_dict_to_string(argvalues_arg)
162+
raise ValueError(
163+
f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. "
164+
f"Use [{dict_repr}] instead of {dict_repr}."
165+
)
166+
167+
# Check keyword arguments for argvalues
168+
for keyword in decorator.keywords:
169+
if keyword.arg == "argvalues":
170+
if isinstance(keyword.value, ast.Dict):
171+
dict_repr = _ast_dict_to_string(keyword.value)
172+
raise ValueError(
173+
f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. "
174+
f"Use [{dict_repr}] instead of {dict_repr}."
175+
)
176+
177+
91178
def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
92179
"""
93180
Check if an argnames AST node contains "completion_params".

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from eval_protocol.models import EvaluationRow, Status
88
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
9+
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
910
from .rollout_processor import RolloutProcessor
1011
from .types import RolloutProcessorConfig
1112
import os
@@ -15,31 +16,14 @@ class RemoteRolloutProcessor(RolloutProcessor):
1516
"""
1617
Rollout processor that triggers a remote HTTP server to perform the rollout.
1718
18-
Expected remote API:
19-
- POST {remote_base_url}/init
20-
Body: {
21-
"rollout_id": str,
22-
"model": str,
23-
"messages": list[dict],
24-
"tools": list[dict] | null,
25-
"metadata": {
26-
"invocation_id": str,
27-
"experiment_id": str,
28-
"rollout_id": str,
29-
"run_id": str | null,
30-
"row_id": str | null
31-
},
32-
}
33-
Returns: {"ok": true}
34-
35-
- GET {remote_base_url}/status?rollout_id=...
36-
Returns: {"terminated": bool, "info": {...}?}
19+
See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
3720
"""
3821

3922
def __init__(
4023
self,
4124
*,
4225
remote_base_url: Optional[str] = None,
26+
model_base_url: Optional[str] = None,
4327
poll_interval: float = 1.0,
4428
timeout_seconds: float = 120.0,
4529
output_data_loader: Callable[[str], DynamicDataLoader],
@@ -58,6 +42,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
5842

5943
# Start with constructor values
6044
remote_base_url: Optional[str] = self._remote_base_url
45+
model_base_url: Optional[str] = self._model_base_url
6146
poll_interval: float = self._poll_interval
6247
timeout_seconds: float = self._timeout_seconds
6348

@@ -74,14 +59,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
7459
async def _process_row(row: EvaluationRow) -> EvaluationRow:
7560
start_time = time.perf_counter()
7661

62+
if row.execution_metadata.invocation_id is None:
63+
raise ValueError("Invocation ID is required in RemoteRolloutProcessor")
64+
if row.execution_metadata.experiment_id is None:
65+
raise ValueError("Experiment ID is required in RemoteRolloutProcessor")
66+
if row.execution_metadata.rollout_id is None:
67+
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
68+
if row.execution_metadata.run_id is None:
69+
raise ValueError("Run ID is required in RemoteRolloutProcessor")
70+
if row.input_metadata.row_id is None:
71+
raise ValueError("Row ID is required in RemoteRolloutProcessor")
72+
7773
# Build request metadata and payload
78-
meta: Dict[str, Any] = {
79-
"invocation_id": row.execution_metadata.invocation_id,
80-
"experiment_id": row.execution_metadata.experiment_id,
81-
"rollout_id": row.execution_metadata.rollout_id,
82-
"run_id": row.execution_metadata.run_id,
83-
"row_id": row.input_metadata.row_id,
84-
}
74+
meta: RolloutMetadata = RolloutMetadata(
75+
invocation_id=row.execution_metadata.invocation_id,
76+
experiment_id=row.execution_metadata.experiment_id,
77+
rollout_id=row.execution_metadata.rollout_id,
78+
run_id=row.execution_metadata.run_id,
79+
row_id=row.input_metadata.row_id,
80+
)
8581

8682
model: Optional[str] = None
8783
if row.input_metadata and row.input_metadata.completion_params:
@@ -113,19 +109,33 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
113109
}
114110
clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None})
115111

116-
init_payload: Dict[str, Any] = {
117-
"rollout_id": row.execution_metadata.rollout_id,
118-
"model": model,
119-
"messages": clean_messages,
120-
"tools": row.tools,
121-
"metadata": meta,
122-
}
112+
if row.execution_metadata.rollout_id is None:
113+
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
114+
115+
init_payload: InitRequest = InitRequest(
116+
model=model,
117+
messages=clean_messages,
118+
tools=row.tools,
119+
metadata=meta,
120+
model_base_url=model_base_url,
121+
)
123122

124123
# Fire-and-poll
125124
def _post_init() -> None:
126125
url = f"{remote_base_url}/init"
127-
r = requests.post(url, json=init_payload, timeout=30)
128-
r.raise_for_status()
126+
try:
127+
r = requests.post(url, json=init_payload.model_dump(), timeout=30)
128+
r.raise_for_status()
129+
except requests.exceptions.Timeout:
130+
raise TimeoutError(
131+
"The /init endpoint timed out after 30 seconds. "
132+
"CRITICAL: The /init endpoint must return immediately (within 30s) and NOT block on rollout execution. "
133+
"Your remote server should:\n"
134+
"1. Accept the /init request and return a 200 response immediately\n"
135+
"2. Process the actual rollout asynchronously in the background\n"
136+
"3. Use the /status endpoint to report progress\n"
137+
"For Python/Node.js: Start a separate process per rollout to avoid blocking the /init response."
138+
)
129139

130140
await asyncio.to_thread(_post_init)
131141

@@ -147,7 +157,13 @@ def _get_status() -> Dict[str, Any]:
147157
except Exception:
148158
# transient errors; continue polling
149159
pass
160+
150161
await asyncio.sleep(poll_interval)
162+
else:
163+
# Loop completed without breaking, which means we timed out
164+
row.rollout_status = Status.rollout_error(
165+
f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds"
166+
)
151167

152168
# Update duration, regardless of termination
153169
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
@@ -170,14 +186,28 @@ def _load_data():
170186
elif len(output_rows) == 1: # Return the Langfuse row
171187
langfuse_row = output_rows[0]
172188
langfuse_row.input_metadata.completion_params = row.input_metadata.completion_params
189+
# merge dataset_info dicts on input_metadata
190+
if langfuse_row.input_metadata.dataset_info and row.input_metadata.dataset_info:
191+
langfuse_row.input_metadata.dataset_info = {
192+
**row.input_metadata.dataset_info,
193+
**langfuse_row.input_metadata.dataset_info,
194+
}
195+
elif row.input_metadata.dataset_info:
196+
langfuse_row.input_metadata.dataset_info = row.input_metadata.dataset_info
173197
langfuse_row.eval_metadata = row.eval_metadata
198+
langfuse_row.ground_truth = row.ground_truth
174199
return langfuse_row
175200
else:
176201
raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.")
177202

178-
for r in rows:
179-
tasks.append(asyncio.create_task(_process_row(r)))
203+
semaphore = config.semaphore
204+
205+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
206+
async with semaphore:
207+
result = await _process_row(r)
208+
return result
180209

210+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
181211
return tasks
182212

183213
def cleanup(self) -> None:

eval_protocol/types/remote_rollout_processor.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Any, Dict, List, Optional
66
from pydantic import BaseModel, Field
7-
from eval_protocol.models import Message
7+
from eval_protocol.models import Message, Status
88

99

1010
class RolloutMetadata(BaseModel):
@@ -20,10 +20,17 @@ class RolloutMetadata(BaseModel):
2020
class InitRequest(BaseModel):
2121
"""Request model for POST /init endpoint."""
2222

23-
rollout_id: str
2423
model: str
25-
messages: List[Message] = Field(min_length=1)
24+
messages: Optional[List[Message]] = None
2625
tools: Optional[List[Dict[str, Any]]] = None
26+
27+
model_base_url: Optional[str] = None
28+
"""
29+
A Base URL that the remote server can use to make LLM calls. This is useful
30+
to configure on the eval-protocol side for flexibility in
31+
development/training.
32+
"""
33+
2734
metadata: RolloutMetadata
2835

2936

@@ -33,6 +40,12 @@ class StatusResponse(BaseModel):
3340
terminated: bool
3441
info: Optional[Dict[str, Any]] = None
3542

43+
status: Optional[Status] = None
44+
"""
45+
Optional status indicator for the rollout to be used by eval-protocol. This
46+
is useful to distinguish between successful and failed rollouts.
47+
"""
48+
3649

3750
def create_langfuse_config_tags(init_request: InitRequest) -> List[str]:
3851
"""Create Langfuse tags from InitRequest metadata."""

tests/chinook/langgraph/test_langgraph_chinook_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def build_graph_kwargs(cp: CompletionParams) -> Dict[str, Any]:
1919

2020

2121
@pytest.mark.asyncio
22+
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally since its not stable")
2223
@pytest.mark.skipif(os.getenv("FIREWORKS_API_KEY") in (None, ""), reason="FIREWORKS_API_KEY not set")
2324
@evaluation_test(
2425
input_messages=[[[Message(role="user", content="Use tools to count total tracks in the database.")]]],

0 commit comments

Comments
 (0)