Skip to content

Commit 886fa38

Browse files
committed
Add in data loader and pulling from tags
1 parent 671c882 commit 886fa38

File tree

7 files changed

+151
-110
lines changed

7 files changed

+151
-110
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING, cast
1313

1414
from langfuse.api.resources.commons.types.observations_view import ObservationsView
15-
from eval_protocol.models import EvaluationRow, InputMetadata, Message
15+
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
1616
from .base import BaseAdapter
1717
from .utils import extract_messages_from_data
1818

@@ -82,14 +82,44 @@ def convert_trace_to_evaluation_row(
8282
if not messages:
8383
return None
8484

85+
execution_metadata = ExecutionMetadata()
86+
row_id = None
87+
88+
if trace.observations:
89+
for obs in trace.observations:
90+
if obs.metadata and "requester_metadata" in obs.metadata:
91+
req_meta = obs.metadata["requester_metadata"]
92+
if isinstance(req_meta, dict):
93+
execution_metadata.invocation_id = req_meta.get("invocation_id")
94+
execution_metadata.experiment_id = req_meta.get("experiment_id")
95+
execution_metadata.rollout_id = req_meta.get("rollout_id")
96+
execution_metadata.run_id = req_meta.get("run_id")
97+
row_id = req_meta.get("row_id")
98+
break # Only need to get first observation
99+
100+
if trace.tags:
101+
for tag in trace.tags:
102+
if tag.startswith("invocation_id:") and not execution_metadata.invocation_id:
103+
execution_metadata.invocation_id = tag.split(":", 1)[1]
104+
elif tag.startswith("experiment_id:") and not execution_metadata.experiment_id:
105+
execution_metadata.experiment_id = tag.split(":", 1)[1]
106+
elif tag.startswith("rollout_id:") and not execution_metadata.rollout_id:
107+
execution_metadata.rollout_id = tag.split(":", 1)[1]
108+
elif tag.startswith("run_id:") and not execution_metadata.run_id:
109+
execution_metadata.run_id = tag.split(":", 1)[1]
110+
elif tag.startswith("row_id:") and not row_id:
111+
row_id = tag.split(":", 1)[1]
112+
85113
return EvaluationRow(
86114
messages=messages,
87115
tools=tools,
88116
input_metadata=InputMetadata(
117+
row_id=row_id,
89118
session_data={
90119
"langfuse_trace_id": trace.id, # Store the trace ID here
91-
}
120+
},
92121
),
122+
execution_metadata=execution_metadata,
93123
)
94124

95125
except (AttributeError, ValueError, KeyError) as e:
@@ -332,16 +362,18 @@ def get_evaluation_rows(
332362
to_timestamp=to_timestamp,
333363
order_by="timestamp.desc",
334364
)
365+
366+
# If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse
367+
if traces and hasattr(traces, "meta") and traces.meta.total_items == 0 and page == 1:
368+
raise Exception("Empty results - indexing delay")
369+
335370
break
336371
except Exception as e:
337372
list_retries += 1
338-
if "429" in str(e) and list_retries < max_retries:
373+
if list_retries < max_retries and ("429" in str(e) or "Empty results" in str(e)):
339374
sleep_time = 2**list_retries # Exponential backoff
340375
logger.warning(
341-
"Rate limit hit on trace.list(), retrying in %ds (attempt %d/%d)",
342-
sleep_time,
343-
list_retries,
344-
max_retries,
376+
"Retrying in %ds (attempt %d/%d): %s", sleep_time, list_retries, max_retries, str(e)
345377
)
346378
time.sleep(sleep_time)
347379
else:

eval_protocol/pytest/evaluation_test.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,15 +402,33 @@ async def _execute_groupwise_eval_with_semaphore(
402402
return results
403403

404404
if mode == "pointwise":
405-
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
406405
pointwise_tasks: list[asyncio.Task[EvaluationRow]] = []
407-
# Use wrapper that handles retry logic internally
408-
async for row in rollout_processor_with_retry(
409-
rollout_processor, fresh_dataset, config, run_idx
410-
):
411-
pointwise_tasks.append(
412-
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
413-
)
406+
407+
if rollout_processor.supports_pipelining:
408+
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
409+
# Use wrapper that handles retry logic internally
410+
async for row in rollout_processor_with_retry(
411+
rollout_processor, fresh_dataset, config, run_idx
412+
):
413+
pointwise_tasks.append(
414+
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
415+
)
416+
else:
417+
# Non-pipelined mode: collect all rollout results first, then postprocess, then evaluate
418+
collected_rollout_rows: list[EvaluationRow] = []
419+
async for row in rollout_processor_with_retry(
420+
rollout_processor, fresh_dataset, config, run_idx
421+
):
422+
collected_rollout_rows.append(row)
423+
424+
# Post-process rollout results to get evaluation inputs
425+
eval_input_rows = rollout_processor.postprocess(collected_rollout_rows)
426+
427+
# Now evaluate all the post-processed rows
428+
for row in eval_input_rows:
429+
pointwise_tasks.append(
430+
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
431+
)
414432

415433
# Run evaluation tasks with progress bar
416434
results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx)
@@ -453,9 +471,13 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
453471
lst.append(copied_row) # pyright: ignore[reportUnknownMemberType]
454472
tasks.append(asyncio.create_task(_collect_result(config, lst))) # pyright: ignore[reportUnknownArgumentType]
455473
rollout_results = await asyncio.gather(*tasks)
456-
for result in rollout_results:
457-
for row in result:
458-
row_groups[row.input_metadata.row_id].append(row) # pyright: ignore[reportUnknownMemberType]
474+
475+
# Flatten and postprocess all rollout results
476+
all_rollout_rows = [row for result in rollout_results for row in result]
477+
processed_rows = rollout_processor.postprocess(all_rollout_rows)
478+
479+
for row in processed_rows:
480+
row_groups[row.input_metadata.row_id].append(row)
459481
tasks = []
460482
for _, rows in row_groups.items(): # pyright: ignore[reportUnknownVariableType]
461483
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) # pyright: ignore[reportUnknownArgumentType]
@@ -471,6 +493,9 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
471493
rollout_processor, fresh_dataset, config, run_idx
472494
):
473495
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]
496+
497+
input_dataset = rollout_processor.postprocess(input_dataset)
498+
474499
# NOTE: we will still evaluate errored rows (give users control over this)
475500
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
476501
results = await execute_pytest(

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import asyncio
22
import time
3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List, Optional, Callable
44

55
import requests
66

77
from eval_protocol.models import EvaluationRow
8+
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
89
from .rollout_processor import RolloutProcessor
910
from .types import RolloutProcessorConfig
1011

@@ -35,20 +36,24 @@ class RemoteRolloutProcessor(RolloutProcessor):
3536
Returns: {"terminated": bool, "info": {...}?}
3637
"""
3738

39+
supports_pipelining: bool = False # Remote rollout processor cannot pipeline - must wait for all rollouts to complete before fetching results.
40+
3841
def __init__(
3942
self,
4043
*,
4144
remote_base_url: Optional[str] = None,
4245
num_turns: int = 2,
4346
poll_interval: float = 1.0,
4447
timeout_seconds: float = 120.0,
48+
output_data_loader: Callable[[str], DynamicDataLoader],
4549
):
4650
# Prefer constructor-provided configuration. These can be overridden via
4751
# config.kwargs at call time for backward compatibility.
4852
self._remote_base_url = remote_base_url
4953
self._num_turns = num_turns
5054
self._poll_interval = poll_interval
5155
self._timeout_seconds = timeout_seconds
56+
self._output_data_loader = output_data_loader
5257

5358
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
5459
tasks: List[asyncio.Task[EvaluationRow]] = []
@@ -158,5 +163,20 @@ def _get_status() -> Dict[str, Any]:
158163

159164
return tasks
160165

166+
def postprocess(self, finished_rollout_rows: List[EvaluationRow]) -> List[EvaluationRow]:
167+
"""Fetch actual evaluation rows from Langfuse using the output_data_loader."""
168+
invocation_id = finished_rollout_rows[0].execution_metadata.invocation_id
169+
if not invocation_id:
170+
raise ValueError("Invocation ID is required in RemoteRolloutProcessor")
171+
172+
data_loader = self._output_data_loader(invocation_id)
173+
174+
results = data_loader.load()
175+
output_rows: List[EvaluationRow] = []
176+
for result in results:
177+
output_rows.extend(result.rows)
178+
179+
return output_rows
180+
161181
def cleanup(self) -> None:
162182
return None

eval_protocol/pytest/rollout_processor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,19 @@ class RolloutProcessor(ABC):
1010
Abstract base class for all rollout processor strategies.
1111
"""
1212

13+
supports_pipelining: bool = (
14+
True # Whether this processor supports pipelined evaluation (evaluate rows as rollouts complete)
15+
)
16+
1317
@abstractmethod
1418
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
1519
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""
1620
pass
1721

22+
def postprocess(self, finished_rollout_rows: list[EvaluationRow]) -> list[EvaluationRow]:
23+
"""Post-process rollout results to produce evaluation inputs. Only available for processors that return False from supports_pipelining."""
24+
return finished_rollout_rows
25+
1826
def cleanup(self) -> None:
1927
"""Cleanup resources. Override in subclasses if cleanup is needed."""
2028
pass

eval_protocol/pytest/utils.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import sys
66
from dataclasses import replace
7-
from typing import Any, Literal
7+
from typing import Any, Literal, Callable, AsyncGenerator
88

99
from litellm.cost_calculator import cost_per_token
1010
from tqdm import tqdm
@@ -33,7 +33,9 @@
3333
AggregationMethod = Literal["mean", "max", "min", "bootstrap"]
3434

3535

36-
async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
36+
async def run_tasks_with_eval_progress(
37+
pointwise_tasks: list[asyncio.Task[EvaluationRow]], run_idx: int
38+
) -> list[EvaluationRow]:
3739
"""
3840
Run evaluation tasks with a progress bar and proper cancellation handling.
3941
@@ -58,7 +60,7 @@ async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
5860
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
5961
) as eval_pbar:
6062

61-
async def task_with_progress(task):
63+
async def task_with_progress(task: asyncio.Task[EvaluationRow]) -> EvaluationRow:
6264
try:
6365
result = await task
6466
return result
@@ -77,7 +79,9 @@ async def task_with_progress(task):
7779
raise
7880

7981

80-
async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
82+
async def run_tasks_with_run_progress(
83+
execute_run_func: Callable[[int, RolloutProcessorConfig], Any], num_runs: int, config: RolloutProcessorConfig
84+
) -> None:
8185
"""
8286
Run tasks with a parallel runs progress bar, preserving original logic.
8387
@@ -98,12 +102,12 @@ async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
98102
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
99103
) as run_pbar:
100104

101-
async def execute_run_with_progress(run_idx: int, config):
105+
async def execute_run_with_progress(run_idx: int, config: RolloutProcessorConfig) -> Any:
102106
result = await execute_run_func(run_idx, config)
103107
run_pbar.update(1)
104108
return result
105109

106-
tasks = []
110+
tasks: list[asyncio.Task[Any]] = []
107111
for run_idx in range(num_runs):
108112
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
109113
try:
@@ -274,7 +278,7 @@ async def rollout_processor_with_retry(
274278
fresh_dataset: list[EvaluationRow],
275279
config: RolloutProcessorConfig,
276280
run_idx: int = 0,
277-
):
281+
) -> AsyncGenerator[EvaluationRow, None]:
278282
"""
279283
Wrapper around rollout_processor that handles retry logic using the Python backoff library.
280284
@@ -304,13 +308,13 @@ async def rollout_processor_with_retry(
304308

305309
# Create a single backoff-decorated retry function that can be reused
306310
@exception_config.get_backoff_decorator() # pyright: ignore[reportUntypedFunctionDecorator]
307-
async def execute_row_with_backoff_retry(row: EvaluationRow):
311+
async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
308312
"""Execute rollout for a single row with backoff retry."""
309313
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
310314
retry_tasks = rollout_processor([row], retry_config)
311315
return await retry_tasks[0]
312316

313-
async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
317+
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
314318
"""Execute a single row task with backoff retry."""
315319

316320
try:
@@ -344,7 +348,9 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
344348
row.rollout_status = Status.rollout_error(repr(e))
345349
return row
346350

347-
async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
351+
async def execute_row_with_backoff_and_log(
352+
task: asyncio.Task[EvaluationRow], row: EvaluationRow
353+
) -> EvaluationRow:
348354
"""Execute a single row task with backoff retry and logging."""
349355
result = await execute_row_with_backoff(task, row)
350356
# Log the row after execution completes (success or failure)
@@ -386,7 +392,7 @@ def sanitize_filename(text: str) -> str:
386392
return safe[:120]
387393

388394

389-
def extract_effort_tag(params: dict) -> str | None: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
395+
def extract_effort_tag(params: dict[str, Any]) -> str | None:
390396
"""
391397
Extract effort tag from completion parameters for use in file naming.
392398

tests/chinook/langfuse/remote_server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def _worker():
5757

5858
# Prepare metadata payload to attach for Langfuse filtering
5959
metadata = {
60+
"tags": [
61+
f"invocation_id:{req.metadata.get('invocation_id')}",
62+
f"experiment_id:{req.metadata.get('experiment_id')}",
63+
f"rollout_id:{req.metadata.get('rollout_id')}",
64+
f"run_id:{req.metadata.get('run_id')}",
65+
f"row_id:{req.metadata.get('row_id')}",
66+
],
6067
"invocation_id": req.metadata.get("invocation_id"),
6168
"experiment_id": req.metadata.get("experiment_id"),
6269
"rollout_id": req.metadata.get("rollout_id"),

0 commit comments

Comments
 (0)