Skip to content

Commit 903584b

Browse files
committed
pipelined
1 parent 0a9a9a4 commit 903584b

File tree

5 files changed

+75
-84
lines changed

5 files changed

+75
-84
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -402,33 +402,15 @@ async def _execute_groupwise_eval_with_semaphore(
402402
return results
403403

404404
if mode == "pointwise":
405+
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
405406
pointwise_tasks: list[asyncio.Task[EvaluationRow]] = []
406-
407-
if rollout_processor.supports_pipelining:
408-
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
409-
# Use wrapper that handles retry logic internally
410-
async for row in rollout_processor_with_retry(
411-
rollout_processor, fresh_dataset, config, run_idx
412-
):
413-
pointwise_tasks.append(
414-
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
415-
)
416-
else:
417-
# Non-pipelined mode: collect all rollout results first, then postprocess, then evaluate
418-
collected_rollout_rows: list[EvaluationRow] = []
419-
async for row in rollout_processor_with_retry(
420-
rollout_processor, fresh_dataset, config, run_idx
421-
):
422-
collected_rollout_rows.append(row)
423-
424-
# Post-process rollout results to get evaluation inputs
425-
eval_input_rows = rollout_processor.postprocess(collected_rollout_rows)
426-
427-
# Now evaluate all the post-processed rows
428-
for row in eval_input_rows:
429-
pointwise_tasks.append(
430-
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
431-
)
407+
# Use wrapper that handles retry logic internally
408+
async for row in rollout_processor_with_retry(
409+
rollout_processor, fresh_dataset, config, run_idx
410+
):
411+
pointwise_tasks.append(
412+
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
413+
)
432414

433415
# Run evaluation tasks with progress bar
434416
results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx)
@@ -471,13 +453,9 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
471453
lst.append(copied_row) # pyright: ignore[reportUnknownMemberType]
472454
tasks.append(asyncio.create_task(_collect_result(config, lst))) # pyright: ignore[reportUnknownArgumentType]
473455
rollout_results = await asyncio.gather(*tasks)
474-
475-
# Flatten and postprocess all rollout results
476-
all_rollout_rows = [row for result in rollout_results for row in result]
477-
processed_rows = rollout_processor.postprocess(all_rollout_rows)
478-
479-
for row in processed_rows:
480-
row_groups[row.input_metadata.row_id].append(row)
456+
for result in rollout_results:
457+
for row in result:
458+
row_groups[row.input_metadata.row_id].append(row) # pyright: ignore[reportUnknownMemberType]
481459
tasks = []
482460
for _, rows in row_groups.items(): # pyright: ignore[reportUnknownVariableType]
483461
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) # pyright: ignore[reportUnknownArgumentType]
@@ -494,8 +472,6 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
494472
):
495473
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]
496474

497-
input_dataset = rollout_processor.postprocess(input_dataset)
498-
499475
# NOTE: we will still evaluate errored rows (give users control over this)
500476
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
501477
results = await execute_pytest(

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class RemoteRolloutProcessor(RolloutProcessor):
3636
Returns: {"terminated": bool, "info": {...}?}
3737
"""
3838

39-
supports_pipelining: bool = False # Remote rollout processor cannot pipeline - must wait for all rollouts to complete before fetching results.
40-
4139
def __init__(
4240
self,
4341
*,
@@ -156,27 +154,30 @@ def _get_status() -> Dict[str, Any]:
156154

157155
# Update duration, regardless of termination
158156
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
159-
return row
160157

161-
for r in rows:
162-
tasks.append(asyncio.create_task(_process_row(r)))
158+
if row.execution_metadata.rollout_id is None:
159+
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
163160

164-
return tasks
161+
data_loader = self._output_data_loader(row.execution_metadata.rollout_id)
162+
163+
def _load_data():
164+
return data_loader.load()
165+
166+
results = await asyncio.to_thread(_load_data)
165167

166-
def postprocess(self, finished_rollout_rows: List[EvaluationRow]) -> List[EvaluationRow]:
167-
"""Fetch actual evaluation rows from Langfuse using the output_data_loader."""
168-
invocation_id = finished_rollout_rows[0].execution_metadata.invocation_id
169-
if not invocation_id:
170-
raise ValueError("Invocation ID is required in RemoteRolloutProcessor")
168+
output_rows: List[EvaluationRow] = [row for result in results for row in result.rows]
171169

172-
data_loader = self._output_data_loader(invocation_id)
170+
assert len(output_rows) == 1, "Dataloader used for RemoteRolloutProcessor should have exactly one row"
173171

174-
results = data_loader.load()
175-
output_rows: List[EvaluationRow] = []
176-
for result in results:
177-
output_rows.extend(result.rows)
172+
langfuse_row = output_rows[0]
173+
langfuse_row.input_metadata.completion_params = row.input_metadata.completion_params
178174

179-
return output_rows
175+
return langfuse_row
176+
177+
for r in rows:
178+
tasks.append(asyncio.create_task(_process_row(r)))
179+
180+
return tasks
180181

181182
def cleanup(self) -> None:
182183
return None

eval_protocol/pytest/rollout_processor.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,11 @@ class RolloutProcessor(ABC):
1010
Abstract base class for all rollout processor strategies.
1111
"""
1212

13-
supports_pipelining: bool = (
14-
True # Whether this processor supports pipelined evaluation (evaluate rows as rollouts complete)
15-
)
16-
1713
@abstractmethod
1814
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
1915
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""
2016
pass
2117

22-
def postprocess(self, finished_rollout_rows: list[EvaluationRow]) -> list[EvaluationRow]:
23-
"""Post-process rollout results to produce evaluation inputs. Only available for processors that return False from supports_pipelining."""
24-
return finished_rollout_rows
25-
2618
def cleanup(self) -> None:
2719
"""Cleanup resources. Override in subclasses if cleanup is needed."""
2820
pass

eval_protocol/quickstart/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,28 @@ def assistant_to_ground_truth(data: list[EvaluationRow]) -> list[EvaluationRow]:
186186
return processed_rows
187187

188188

189+
def filter_longest_conversation(data: list[EvaluationRow]) -> list[EvaluationRow]:
190+
"""
191+
Select only the longest conversation from a list of evaluation rows that share the same rollout_id.
192+
193+
Args:
194+
data: List of EvaluationRow objects that share the same rollout_id
195+
196+
Returns:
197+
List containing only the EvaluationRow with the most messages (longest conversation)
198+
"""
199+
if not data:
200+
return data
201+
202+
if len(data) == 1:
203+
return data
204+
205+
# Find the row with the most messages (longest conversation)
206+
longest_row = max(data, key=lambda row: len(row.messages))
207+
208+
return [longest_row]
209+
210+
189211
async def run_single_judgment(
190212
question_text: str, answer_a: str, answer_b: str, tools, judge_config, client
191213
) -> Optional[Dict[str, Any]]:

tests/chinook/langfuse/test_remote_langfuse_chinook.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,37 @@
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
1515
from eval_protocol.adapters.langfuse import create_langfuse_adapter
16+
from eval_protocol.quickstart.utils import filter_longest_conversation
1617

17-
INVOCATION_ID = ""
18-
ASSERTION_EXECUTED = False
18+
ROLLOUT_IDS = set()
1919

2020

2121
@pytest.fixture(autouse=True)
22-
def check_assertion_executed():
23-
"""Ensure the test actually executed the Langfuse validation"""
24-
global ASSERTION_EXECUTED
25-
ASSERTION_EXECUTED = False # Reset before test
22+
def check_rollout_coverage():
23+
"""Ensure we processed all expected rollout_ids"""
24+
global ROLLOUT_IDS
25+
ROLLOUT_IDS.clear()
2626
yield
27-
# After test completes, verify the assertion was executed
28-
assert ASSERTION_EXECUTED, (
29-
"Test passed but never validated Langfuse data - check if output_data_loader returned empty results"
27+
28+
# Verify we've seen the expected number of rollout_ids after test is done
29+
expected_rollout_count = 3
30+
assert len(ROLLOUT_IDS) == expected_rollout_count, (
31+
f"Expected to see {expected_rollout_count} rollout_ids, but only saw {len(ROLLOUT_IDS)}: {ROLLOUT_IDS}"
3032
)
3133

3234

33-
def fetch_trajectories(invocation_id: str) -> List[EvaluationRow]:
34-
global INVOCATION_ID # This is just to verify the invocation_id is set correctly in the test
35-
INVOCATION_ID = invocation_id
35+
def fetch_langfuse_traces(rollout_id: str) -> List[EvaluationRow]:
36+
global ROLLOUT_IDS # Track all rollout_ids we've seen
37+
ROLLOUT_IDS.add(rollout_id)
3638

3739
adapter = create_langfuse_adapter()
38-
return adapter.get_evaluation_rows(tags=[f"invocation_id:{invocation_id}"])
40+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"])
3941

4042

41-
def create_output_data_loader(invocation_id: str) -> DynamicDataLoader:
42-
return DynamicDataLoader(generators=[lambda: fetch_trajectories(invocation_id)])
43+
def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader:
44+
return DynamicDataLoader(
45+
generators=[lambda: fetch_langfuse_traces(rollout_id)], preprocess_fn=filter_longest_conversation
46+
)
4347

4448

4549
def _start_remote_server():
@@ -87,7 +91,7 @@ def remote_langfuse_data_generator() -> List[EvaluationRow]:
8791

8892
# Minimal single-user-turn message to trigger a response
8993
row = EvaluationRow(messages=[Message(role="user", content="Hello there! Please say hi back.")])
90-
return [row]
94+
return [row, row, row]
9195

9296

9397
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
@@ -100,7 +104,7 @@ def remote_langfuse_data_generator() -> List[EvaluationRow]:
100104
remote_base_url="http://127.0.0.1:7077",
101105
num_turns=2,
102106
timeout_seconds=30,
103-
output_data_loader=create_output_data_loader,
107+
output_data_loader=langfuse_output_data_loader,
104108
),
105109
)
106110
async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:
@@ -110,13 +114,9 @@ async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> Evaluati
110114
- trigger remote rollout via RemoteRolloutProcessor (calls init/status)
111115
- fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found
112116
"""
113-
global ASSERTION_EXECUTED
114-
115-
# Sanity check: row should have an invocation_id since it came from Langfuse via output_data_loader
116117
assert row.messages[0].content == "Hello there! Please say hi back.", "Row should have correct message content"
117-
assert row.execution_metadata.invocation_id == INVOCATION_ID, "Row should have correct invocation_id set"
118-
119-
ASSERTION_EXECUTED = True
120-
print(f"✅ Successfully received row from Langfuse with invocation_id: {row.execution_metadata.invocation_id}")
118+
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
119+
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
120+
)
121121

122122
return row

0 commit comments

Comments
 (0)