Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions eval_protocol/adapters/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING, cast

from langfuse.api.resources.commons.types.observations_view import ObservationsView
from eval_protocol.models import EvaluationRow, InputMetadata, Message
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from .base import BaseAdapter
from .utils import extract_messages_from_data

Expand Down Expand Up @@ -82,14 +82,53 @@ def convert_trace_to_evaluation_row(
if not messages:
return None

execution_metadata = ExecutionMetadata()
row_id = None

if trace.observations:
for obs in trace.observations:
if obs.metadata and "requester_metadata" in obs.metadata:
req_meta = obs.metadata["requester_metadata"]
if isinstance(req_meta, dict):
execution_metadata.invocation_id = req_meta.get("invocation_id")
execution_metadata.experiment_id = req_meta.get("experiment_id")
execution_metadata.rollout_id = req_meta.get("rollout_id")
execution_metadata.run_id = req_meta.get("run_id")
row_id = req_meta.get("row_id")
break # Only need to get first observation

if trace.tags:
for tag in trace.tags:
if tag.startswith("invocation_id:") and not execution_metadata.invocation_id:
execution_metadata.invocation_id = tag.split(":", 1)[1]
elif tag.startswith("experiment_id:") and not execution_metadata.experiment_id:
execution_metadata.experiment_id = tag.split(":", 1)[1]
elif tag.startswith("rollout_id:") and not execution_metadata.rollout_id:
execution_metadata.rollout_id = tag.split(":", 1)[1]
elif tag.startswith("run_id:") and not execution_metadata.run_id:
execution_metadata.run_id = tag.split(":", 1)[1]
elif tag.startswith("row_id:") and not row_id:
row_id = tag.split(":", 1)[1]

if (
execution_metadata.invocation_id
and execution_metadata.experiment_id
and execution_metadata.rollout_id
and execution_metadata.run_id
and row_id
):
break # Break early if we've found all the metadata we need

return EvaluationRow(
messages=messages,
tools=tools,
input_metadata=InputMetadata(
row_id=row_id,
session_data={
"langfuse_trace_id": trace.id, # Store the trace ID here
}
},
),
execution_metadata=execution_metadata,
)

except (AttributeError, ValueError, KeyError) as e:
Expand Down Expand Up @@ -332,16 +371,18 @@ def get_evaluation_rows(
to_timestamp=to_timestamp,
order_by="timestamp.desc",
)

# If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse
if traces and traces.meta and traces.meta.total_items == 0 and page == 1:
raise Exception("Empty results - indexing delay")

break
except Exception as e:
list_retries += 1
if "429" in str(e) and list_retries < max_retries:
if list_retries < max_retries and ("429" in str(e) or "Empty results" in str(e)):
sleep_time = 2**list_retries # Exponential backoff
logger.warning(
"Rate limit hit on trace.list(), retrying in %ds (attempt %d/%d)",
sleep_time,
list_retries,
max_retries,
"Retrying in %ds (attempt %d/%d): %s", sleep_time, list_retries, max_retries, str(e)
)
time.sleep(sleep_time)
else:
Expand Down
47 changes: 36 additions & 11 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,33 @@ async def _execute_groupwise_eval_with_semaphore(
return results

if mode == "pointwise":
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
pointwise_tasks: list[asyncio.Task[EvaluationRow]] = []
# Use wrapper that handles retry logic internally
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, run_idx
):
pointwise_tasks.append(
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
)

if rollout_processor.supports_pipelining:
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
# Use wrapper that handles retry logic internally
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, run_idx
):
pointwise_tasks.append(
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
)
else:
# Non-pipelined mode: collect all rollout results first, then postprocess, then evaluate
collected_rollout_rows: list[EvaluationRow] = []
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, run_idx
):
collected_rollout_rows.append(row)

# Post-process rollout results to get evaluation inputs
eval_input_rows = rollout_processor.postprocess(collected_rollout_rows)

# Now evaluate all the post-processed rows
for row in eval_input_rows:
pointwise_tasks.append(
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
)

# Run evaluation tasks with progress bar
results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx)
Expand Down Expand Up @@ -453,9 +471,13 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
lst.append(copied_row) # pyright: ignore[reportUnknownMemberType]
tasks.append(asyncio.create_task(_collect_result(config, lst))) # pyright: ignore[reportUnknownArgumentType]
rollout_results = await asyncio.gather(*tasks)
for result in rollout_results:
for row in result:
row_groups[row.input_metadata.row_id].append(row) # pyright: ignore[reportUnknownMemberType]

# Flatten and postprocess all rollout results
all_rollout_rows = [row for result in rollout_results for row in result]
processed_rows = rollout_processor.postprocess(all_rollout_rows)

for row in processed_rows:
row_groups[row.input_metadata.row_id].append(row)
tasks = []
for _, rows in row_groups.items(): # pyright: ignore[reportUnknownVariableType]
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) # pyright: ignore[reportUnknownArgumentType]
Expand All @@ -471,6 +493,9 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
rollout_processor, fresh_dataset, config, run_idx
):
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]

input_dataset = rollout_processor.postprocess(input_dataset)

# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
results = await execute_pytest(
Expand Down
22 changes: 21 additions & 1 deletion eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Callable

import requests

from eval_protocol.models import EvaluationRow
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig

Expand Down Expand Up @@ -35,20 +36,24 @@ class RemoteRolloutProcessor(RolloutProcessor):
Returns: {"terminated": bool, "info": {...}?}
"""

supports_pipelining: bool = False # Remote rollout processor cannot pipeline - must wait for all rollouts to complete before fetching results.

def __init__(
self,
*,
remote_base_url: Optional[str] = None,
num_turns: int = 2,
poll_interval: float = 1.0,
timeout_seconds: float = 120.0,
output_data_loader: Callable[[str], DynamicDataLoader],
):
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
self._remote_base_url = remote_base_url
self._num_turns = num_turns
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
tasks: List[asyncio.Task[EvaluationRow]] = []
Expand Down Expand Up @@ -158,5 +163,20 @@ def _get_status() -> Dict[str, Any]:

return tasks

def postprocess(self, finished_rollout_rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Fetch the actual evaluation rows for a finished remote run.

    The remote rollout processor cannot pipeline: the remote service tags
    everything it produces with an invocation id, and this hook reads that
    id off the finished rows, builds a data loader for the invocation via
    ``self._output_data_loader``, and returns whatever rows the loader
    finds.

    Args:
        finished_rollout_rows: Rows returned once the remote rollouts have
            completed. Only the first row's
            ``execution_metadata.invocation_id`` is consulted (assumed to
            be shared by every row of the run -- confirm with the remote
            service's tagging behavior).

    Returns:
        The evaluation rows produced by the data loader, flattened across
        its result groups. An empty input yields an empty list.

    Raises:
        ValueError: If the first finished row carries no invocation id.
    """
    # Guard: nothing finished means nothing to fetch. Without this the
    # [0] index below raises IndexError on an empty run.
    if not finished_rollout_rows:
        return []

    invocation_id = finished_rollout_rows[0].execution_metadata.invocation_id
    if not invocation_id:
        raise ValueError("Invocation ID is required in RemoteRolloutProcessor")

    data_loader = self._output_data_loader(invocation_id)

    results = data_loader.load()
    output_rows: List[EvaluationRow] = []
    for result in results:
        output_rows.extend(result.rows)

    return output_rows

def cleanup(self) -> None:
    """Release resources held by the processor (a no-op for the remote case)."""
    # The remote processor keeps no sockets, files, or background tasks
    # open between calls, so there is nothing to tear down here.
    return
8 changes: 8 additions & 0 deletions eval_protocol/pytest/rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@ class RolloutProcessor(ABC):
Abstract base class for all rollout processor strategies.
"""

supports_pipelining: bool = (
True # Whether this processor supports pipelined evaluation (evaluate rows as rollouts complete)
)

@abstractmethod
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
    """Schedule rollouts for the given rows and return the pending tasks.

    Must be implemented by subclasses. Implementations return
    asyncio.Task objects without awaiting them, so callers can pipeline
    evaluation as tasks complete when ``supports_pipelining`` is True.

    Args:
        rows: Evaluation rows to roll out.
        config: Rollout processor configuration.

    Returns:
        A list of asyncio tasks, each resolving to a finished row.
    """
    pass

def postprocess(self, finished_rollout_rows: list[EvaluationRow]) -> list[EvaluationRow]:
    """Turn finished rollout rows into evaluation inputs.

    The base implementation is the identity transform. Processors that
    report ``supports_pipelining`` as False may override this to fetch or
    reshape results before evaluation runs.
    """
    return finished_rollout_rows

def cleanup(self) -> None:
    """Release any resources held by the processor.

    The base implementation does nothing; subclasses that own sockets,
    files, or background work should override this.
    """
    return None
28 changes: 17 additions & 11 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import sys
from dataclasses import replace
from typing import Any, Literal
from typing import Any, Literal, Callable, AsyncGenerator

from litellm.cost_calculator import cost_per_token
from tqdm import tqdm
Expand Down Expand Up @@ -33,7 +33,9 @@
AggregationMethod = Literal["mean", "max", "min", "bootstrap"]


async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
async def run_tasks_with_eval_progress(
pointwise_tasks: list[asyncio.Task[EvaluationRow]], run_idx: int
) -> list[EvaluationRow]:
"""
Run evaluation tasks with a progress bar and proper cancellation handling.

Expand All @@ -58,7 +60,7 @@ async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
) as eval_pbar:

async def task_with_progress(task):
async def task_with_progress(task: asyncio.Task[EvaluationRow]) -> EvaluationRow:
try:
result = await task
return result
Expand All @@ -77,7 +79,9 @@ async def task_with_progress(task):
raise


async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
async def run_tasks_with_run_progress(
execute_run_func: Callable[[int, RolloutProcessorConfig], Any], num_runs: int, config: RolloutProcessorConfig
) -> None:
"""
Run tasks with a parallel runs progress bar, preserving original logic.

Expand All @@ -98,12 +102,12 @@ async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
) as run_pbar:

async def execute_run_with_progress(run_idx: int, config):
async def execute_run_with_progress(run_idx: int, config: RolloutProcessorConfig) -> Any:
result = await execute_run_func(run_idx, config)
run_pbar.update(1)
return result

tasks = []
tasks: list[asyncio.Task[Any]] = []
for run_idx in range(num_runs):
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
try:
Expand Down Expand Up @@ -274,7 +278,7 @@ async def rollout_processor_with_retry(
fresh_dataset: list[EvaluationRow],
config: RolloutProcessorConfig,
run_idx: int = 0,
):
) -> AsyncGenerator[EvaluationRow, None]:
"""
Wrapper around rollout_processor that handles retry logic using the Python backoff library.

Expand Down Expand Up @@ -304,13 +308,13 @@ async def rollout_processor_with_retry(

# Create a single backoff-decorated retry function that can be reused
@exception_config.get_backoff_decorator() # pyright: ignore[reportUntypedFunctionDecorator]
async def execute_row_with_backoff_retry(row: EvaluationRow):
async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
"""Execute rollout for a single row with backoff retry."""
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
retry_tasks = rollout_processor([row], retry_config)
return await retry_tasks[0]

async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
"""Execute a single row task with backoff retry."""

try:
Expand Down Expand Up @@ -344,7 +348,9 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
row.rollout_status = Status.rollout_error(repr(e))
return row

async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
async def execute_row_with_backoff_and_log(
task: asyncio.Task[EvaluationRow], row: EvaluationRow
) -> EvaluationRow:
"""Execute a single row task with backoff retry and logging."""
result = await execute_row_with_backoff(task, row)
# Log the row after execution completes (success or failure)
Expand Down Expand Up @@ -386,7 +392,7 @@ def sanitize_filename(text: str) -> str:
return safe[:120]


def extract_effort_tag(params: dict) -> str | None: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
def extract_effort_tag(params: dict[str, Any]) -> str | None:
"""
Extract effort tag from completion parameters for use in file naming.

Expand Down
Loading
Loading