Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@
except ImportError:
LangSmithAdapter = None

# Remote server types
from .types.remote_rollout_processor import (
InitRequest,
RolloutMetadata,
StatusResponse,
create_langfuse_config_tags,
)

warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

Expand Down Expand Up @@ -110,6 +117,11 @@
# Submodules
"rewards",
"mcp",
# Remote server types
"InitRequest",
"RolloutMetadata",
"StatusResponse",
"create_langfuse_config_tags",
]

from . import _version
Expand Down
151 changes: 36 additions & 115 deletions eval_protocol/adapters/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING, cast

from langfuse.api.resources.commons.types.observations_view import ObservationsView
from eval_protocol.models import EvaluationRow, InputMetadata, Message
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from .base import BaseAdapter
from .utils import extract_messages_from_data

Expand Down Expand Up @@ -82,14 +82,41 @@ def convert_trace_to_evaluation_row(
if not messages:
return None

execution_metadata = ExecutionMetadata()
row_id = None

if trace.tags:
for tag in trace.tags:
if tag.startswith("invocation_id:"):
execution_metadata.invocation_id = tag.split(":", 1)[1]
elif tag.startswith("experiment_id:"):
execution_metadata.experiment_id = tag.split(":", 1)[1]
elif tag.startswith("rollout_id:"):
execution_metadata.rollout_id = tag.split(":", 1)[1]
elif tag.startswith("run_id:"):
execution_metadata.run_id = tag.split(":", 1)[1]
elif tag.startswith("row_id:"):
row_id = tag.split(":", 1)[1]

if (
execution_metadata.invocation_id
and execution_metadata.experiment_id
and execution_metadata.rollout_id
and execution_metadata.run_id
and row_id
):
break # Break early if we've found all the metadata we need

return EvaluationRow(
messages=messages,
tools=tools,
input_metadata=InputMetadata(
row_id=row_id,
session_data={
"langfuse_trace_id": trace.id, # Store the trace ID here
}
},
),
execution_metadata=execution_metadata,
)

except (AttributeError, ValueError, KeyError) as e:
Expand Down Expand Up @@ -259,9 +286,6 @@ def get_evaluation_rows(
max_retries: int = 3,
span_name: Optional[str] = None,
converter: Optional[TraceConverter] = None,
metadata: Optional[Dict[str, Any]] = None,
requester_metadata: Optional[Dict[str, Any]] = None,
requester_metadata_contains: Optional[str] = None,
) -> List[EvaluationRow]:
"""Pull traces from Langfuse and convert to EvaluationRow format.

Expand Down Expand Up @@ -296,10 +320,6 @@ def get_evaluation_rows(
to_timestamp = datetime.now()
from_timestamp = to_timestamp - timedelta(hours=hours_back)

# If filtering by metadata/requester_metadata, prefer fetching metadata fields
if (metadata is not None or requester_metadata is not None or requester_metadata_contains) and not fields:
fields = "core,metadata,observations"

# Collect trace summaries via pagination (up to limit)
all_traces = []
page = 1
Expand Down Expand Up @@ -332,16 +352,18 @@ def get_evaluation_rows(
to_timestamp=to_timestamp,
order_by="timestamp.desc",
)

# If no results, this is possibly due to an indexing delay — the remote rollout processor may have just finished pushing rows to Langfuse
if traces and traces.meta and traces.meta.total_items == 0 and page == 1:
raise Exception("Empty results - indexing delay")

break
except Exception as e:
list_retries += 1
if "429" in str(e) and list_retries < max_retries:
if list_retries < max_retries and ("429" in str(e) or "Empty results" in str(e)):
sleep_time = 2**list_retries # Exponential backoff
logger.warning(
"Rate limit hit on trace.list(), retrying in %ds (attempt %d/%d)",
sleep_time,
list_retries,
max_retries,
"Retrying in %ds (attempt %d/%d): %s", sleep_time, list_retries, max_retries, str(e)
)
time.sleep(sleep_time)
else:
Expand Down Expand Up @@ -379,74 +401,6 @@ def get_evaluation_rows(
selected_traces = all_traces
logger.debug("Processing all %d collected traces (no sampling)", len(all_traces))

# Helper to check if a trace matches provided metadata filters. We look in multiple places
# to account for Langfuse moving fields (e.g., metadata vs requester_metadata) and SDK shape.
def _trace_matches_metadata_filters(trace_obj: Any) -> bool:
if metadata is None and requester_metadata is None:
return True

def _as_dict(val: Any) -> Dict[str, Any]:
if val is None:
return {}
if isinstance(val, dict):
return val
# Some SDK objects expose .model_dump() or behave like pydantic models
dump = getattr(val, "model_dump", None)
if callable(dump):
try:
return dump() # type: ignore[no-any-return]
except Exception:
return {}
return {}

# Try common locations for metadata on full trace
trace_meta = _as_dict(getattr(trace_obj, "metadata", None))
trace_req_meta = _as_dict(getattr(trace_obj, "requester_metadata", None))
# Some Langfuse deployments nest requester_metadata inside metadata
nested_req_meta = {}
try:
if isinstance(trace_meta, dict) and isinstance(trace_meta.get("requester_metadata"), dict):
nested_req_meta = _as_dict(trace_meta.get("requester_metadata"))
except Exception:
nested_req_meta = {}

# Fallbacks: sometimes metadata is embedded in input
input_meta = {}
try:
inp = getattr(trace_obj, "input", None)
if isinstance(inp, dict):
input_meta = _as_dict(inp.get("metadata"))
except Exception:
input_meta = {}

# Combine for matching convenience (later keys override earlier for equality check only)
combined_meta = {**trace_meta, **input_meta}
combined_req_meta = {**trace_req_meta}

# Also merge nested requester metadata when present
if nested_req_meta:
combined_req_meta = {**combined_req_meta, **nested_req_meta}

def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool:
for k, v in needle.items():
if haystack.get(k) != v:
return False
return True

ok_meta = True
ok_req_meta = True

if metadata is not None:
# Accept match if found either in metadata or requester_metadata buckets
ok_meta = _is_subset(metadata, combined_meta) or _is_subset(metadata, combined_req_meta)

if requester_metadata is not None:
ok_req_meta = _is_subset(requester_metadata, combined_req_meta) or _is_subset(
requester_metadata, combined_meta
)

return ok_meta and ok_req_meta

# Process each selected trace with sleep and retry logic
for trace_info in selected_traces:
# Sleep between gets to avoid rate limits
Expand Down Expand Up @@ -483,39 +437,6 @@ def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool:
break # Skip this trace

if trace_full:
# If metadata filters are provided, skip non-matching traces early
try:
if not _trace_matches_metadata_filters(trace_full):
continue
except Exception:
# Be permissive on filter errors; treat as non-match
continue

# If observations carry requester_metadata, allow substring filtering
if requester_metadata_contains:
contains_val = requester_metadata_contains
found_match = False
try:
for obs in getattr(trace_full, "observations", []) or []:
obs_rmd = getattr(obs, "requester_metadata", None)
if isinstance(obs_rmd, dict) and any(
(isinstance(v, str) and contains_val in v) for v in obs_rmd.values()
):
found_match = True
break
obs_md = getattr(obs, "metadata", None)
if isinstance(obs_md, dict):
nested = obs_md.get("requester_metadata")
if isinstance(nested, dict) and any(
(isinstance(v, str) and contains_val in v) for v in nested.values()
):
found_match = True
break
except Exception:
found_match = False
if not found_match:
continue

try:
if converter:
eval_row = converter(trace_full, include_tool_calls, span_name)
Expand Down
38 changes: 18 additions & 20 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def evaluation_test(
input_dataset: Sequence[DatasetPathParam] | None = None,
input_rows: Sequence[list[EvaluationRow]] | None = None,
data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter,
rollout_processor: RolloutProcessor | None = None,
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None,
rollout_processor_kwargs: RolloutProcessorInputParam | None = None,
Expand Down Expand Up @@ -418,9 +418,7 @@ async def _execute_groupwise_eval_with_semaphore(
all_results[run_idx] = results
elif mode == "groupwise":
# rollout all the completion_params for the same row at once, and then send the output to the test_func
row_groups = defaultdict( # pyright: ignore[reportUnknownVariableType]
list
) # key: row_id, value: list of rollout_result
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
tasks: list[asyncio.Task[list[EvaluationRow]]] = []
# completion_groups = []
for idx, cp in enumerate(original_completion_params):
Expand All @@ -435,13 +433,13 @@ async def _execute_groupwise_eval_with_semaphore(
)
lst = []

async def _collect_result(config, lst): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
async def _collect_result(config, lst):
result = []
async for row in rollout_processor_with_retry(
rollout_processor, lst, config, run_idx
): # pyright: ignore[reportUnknownArgumentType]
result.append(row) # pyright: ignore[reportUnknownMemberType]
return result # pyright: ignore[reportUnknownVariableType]
result.append(row)
return result

for ori_row in fresh_dataset:
copied_row = ori_row.model_copy(deep=True)
Expand All @@ -450,32 +448,32 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
str(ori_row.execution_metadata.rollout_id) + "_" + str(idx)
)
copied_row.input_metadata.completion_params = cp if cp is not None else {}
lst.append(copied_row) # pyright: ignore[reportUnknownMemberType]
tasks.append(asyncio.create_task(_collect_result(config, lst))) # pyright: ignore[reportUnknownArgumentType]
lst.append(copied_row)
tasks.append(asyncio.create_task(_collect_result(config, lst)))
rollout_results = await asyncio.gather(*tasks)
for result in rollout_results:
for row in result:
row_groups[row.input_metadata.row_id].append(row) # pyright: ignore[reportUnknownMemberType]
row_groups[row.input_metadata.row_id].append(row)
tasks = []
for _, rows in row_groups.items(): # pyright: ignore[reportUnknownVariableType]
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) # pyright: ignore[reportUnknownArgumentType]
for _, rows in row_groups.items():
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
results = []
for task in tasks:
res = await task
results.extend(res) # pyright: ignore[reportUnknownMemberType]
results.extend(res)
all_results[run_idx] = results
else:
# Batch mode: collect all results first, then evaluate (no pipelining)
input_dataset = []
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, run_idx
):
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]
input_dataset.append(row)
# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
results = await execute_pytest(
test_func,
processed_dataset=input_dataset, # pyright: ignore[reportUnknownArgumentType]
processed_dataset=input_dataset,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if (
Expand Down Expand Up @@ -538,16 +536,16 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
# rollout_id is used to differentiate the result from different completion_params
if mode == "groupwise":
results_by_group = [ # pyright: ignore[reportUnknownVariableType]
results_by_group = [
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params))
]
for i_run, result in enumerate(all_results):
for r in result:
completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) # pyright: ignore[reportOptionalMemberAccess]
results_by_group[completion_param_idx][i_run].append(r) # pyright: ignore[reportUnknownMemberType]
for rollout_id, result in enumerate(results_by_group): # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
results_by_group[completion_param_idx][i_run].append(r)
for rollout_id, result in enumerate(results_by_group):
postprocess(
result, # pyright: ignore[reportUnknownArgumentType]
result,
aggregation_method,
passed_threshold,
active_logger,
Expand Down Expand Up @@ -599,7 +597,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)

# Create the dual mode wrapper
dual_mode_wrapper = create_dual_mode_wrapper( # pyright: ignore[reportUnknownVariableType]
dual_mode_wrapper = create_dual_mode_wrapper(
test_func, mode, max_concurrent_rollouts, max_concurrent_evaluations, pytest_wrapper
)

Expand Down
Loading
Loading