Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 155 additions & 17 deletions eval_protocol/adapters/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
to EvaluationRow format for use in evaluation pipelines.
"""

from langfuse.api.resources.commons.types.observations_view import ObservationsView
from __future__ import annotations

import logging
import random
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING

from eval_protocol.models import EvaluationRow, InputMetadata, Message
from .base import BaseAdapter
Expand Down Expand Up @@ -46,14 +47,15 @@ def __call__(

try:
from langfuse import get_client # pyright: ignore[reportPrivateImportUsage]
from langfuse.api.resources.trace.types.traces import Traces
from langfuse.api.resources.commons.types.trace import Trace
from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails

LANGFUSE_AVAILABLE = True
except ImportError:
LANGFUSE_AVAILABLE = False

if TYPE_CHECKING:
from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails
from langfuse.api.resources.commons.types.observations_view import ObservationsView


def convert_trace_to_evaluation_row(
trace: "TraceWithFullDetails", include_tool_calls: bool = True, span_name: Optional[str] = None
Expand All @@ -64,7 +66,6 @@ def convert_trace_to_evaluation_row(
trace: Langfuse trace object
include_tool_calls: Whether to include tool calling information
span_name: If provided, extract messages from generations within this named span
converter: Optional custom converter implementing TraceConverter protocol

Returns:
EvaluationRow or None if conversion fails
Expand Down Expand Up @@ -97,7 +98,7 @@ def convert_trace_to_evaluation_row(


def extract_messages_from_trace(
trace: TraceWithFullDetails, include_tool_calls: bool = True, span_name: Optional[str] = None
trace: "TraceWithFullDetails", include_tool_calls: bool = True, span_name: Optional[str] = None
) -> List[Message]:
"""Extract messages from Langfuse trace input and output.

Expand All @@ -114,7 +115,7 @@ def extract_messages_from_trace(
if span_name: # Look for a generation tied to a span name
try:
# Find the final generation in the named span
gen: ObservationsView | None = get_final_generation_in_span(trace, span_name)
gen: "ObservationsView | None" = get_final_generation_in_span(trace, span_name)
if not gen:
return messages

Expand All @@ -140,10 +141,27 @@ def extract_messages_from_trace(
except (AttributeError, ValueError, KeyError) as e:
logger.warning("Error processing trace %s: %s", trace.id, e)

# Fallback: use the last GENERATION observation which typically contains full chat history
if not messages:
try:
all_observations = getattr(trace, "observations", None) or []
gens: List[ObservationsView] = [
obs for obs in all_observations if getattr(obs, "type", None) == "GENERATION"
]
if gens:
gens.sort(key=lambda x: x.start_time)
last_gen = gens[-1]
if getattr(last_gen, "input", None):
messages.extend(extract_messages_from_data(getattr(last_gen, "input"), include_tool_calls))
if getattr(last_gen, "output", None):
messages.extend(extract_messages_from_data(getattr(last_gen, "output"), include_tool_calls))
except Exception as e:
logger.warning("Failed to extract from last generation for trace %s: %s", trace.id, e)

return messages


def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) -> ObservationsView | None:
def get_final_generation_in_span(trace: "TraceWithFullDetails", span_name: str) -> "ObservationsView | None":
"""Get the final generation within a named span that contains full message history.

Args:
Expand Down Expand Up @@ -173,7 +191,7 @@ def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) ->
return None

# Find all generations within this span
generations: List[ObservationsView] = []
generations: List["ObservationsView"] = []
for obs in all_observations:
if obs.type == "GENERATION" and obs.parent_observation_id == parent_span.id:
generations.append(obs)
Expand Down Expand Up @@ -241,6 +259,9 @@ def get_evaluation_rows(
max_retries: int = 3,
span_name: Optional[str] = None,
converter: Optional[TraceConverter] = None,
metadata: Optional[Dict[str, Any]] = None,
requester_metadata: Optional[Dict[str, Any]] = None,
requester_metadata_contains: Optional[str] = None,
) -> List[EvaluationRow]:
"""Pull traces from Langfuse and convert to EvaluationRow format.

Expand Down Expand Up @@ -275,6 +296,10 @@ def get_evaluation_rows(
to_timestamp = datetime.now()
from_timestamp = to_timestamp - timedelta(hours=hours_back)

# If filtering by metadata/requester_metadata, prefer fetching metadata fields
if (metadata is not None or requester_metadata is not None or requester_metadata_contains) and not fields:
fields = "core,metadata,observations"

# Collect trace summaries via pagination (up to limit)
all_traces = []
page = 1
Expand Down Expand Up @@ -354,6 +379,74 @@ def get_evaluation_rows(
selected_traces = all_traces
logger.debug("Processing all %d collected traces (no sampling)", len(all_traces))

# Helper to check if a trace matches provided metadata filters. We look in multiple places
# to account for Langfuse moving fields (e.g., metadata vs requester_metadata) and SDK shape.
def _trace_matches_metadata_filters(trace_obj: Any) -> bool:
if metadata is None and requester_metadata is None:
return True

def _as_dict(val: Any) -> Dict[str, Any]:
if val is None:
return {}
if isinstance(val, dict):
return val
# Some SDK objects expose .model_dump() or behave like pydantic models
dump = getattr(val, "model_dump", None)
if callable(dump):
try:
return dump() # type: ignore[no-any-return]
except Exception:
return {}
return {}

# Try common locations for metadata on full trace
trace_meta = _as_dict(getattr(trace_obj, "metadata", None))
trace_req_meta = _as_dict(getattr(trace_obj, "requester_metadata", None))
# Some Langfuse deployments nest requester_metadata inside metadata
nested_req_meta = {}
try:
if isinstance(trace_meta, dict) and isinstance(trace_meta.get("requester_metadata"), dict):
nested_req_meta = _as_dict(trace_meta.get("requester_metadata"))
except Exception:
nested_req_meta = {}

# Fallbacks: sometimes metadata is embedded in input
input_meta = {}
try:
inp = getattr(trace_obj, "input", None)
if isinstance(inp, dict):
input_meta = _as_dict(inp.get("metadata"))
except Exception:
input_meta = {}

# Combine for matching convenience (later keys override earlier for equality check only)
combined_meta = {**trace_meta, **input_meta}
combined_req_meta = {**trace_req_meta}

# Also merge nested requester metadata when present
if nested_req_meta:
combined_req_meta = {**combined_req_meta, **nested_req_meta}

def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool:
for k, v in needle.items():
if haystack.get(k) != v:
return False
return True

ok_meta = True
ok_req_meta = True

if metadata is not None:
# Accept match if found either in metadata or requester_metadata buckets
ok_meta = _is_subset(metadata, combined_meta) or _is_subset(metadata, combined_req_meta)

if requester_metadata is not None:
ok_req_meta = _is_subset(requester_metadata, combined_req_meta) or _is_subset(
requester_metadata, combined_meta
)

return ok_meta and ok_req_meta

# Process each selected trace with sleep and retry logic
for trace_info in selected_traces:
# Sleep between gets to avoid rate limits
Expand All @@ -365,6 +458,7 @@ def get_evaluation_rows(
detail_retries = 0
while detail_retries < max_retries:
try:
# Some SDKs don't support fields= on get; call without it
trace_full = self.client.api.trace.get(trace_info.id)
break
except Exception as e:
Expand All @@ -379,11 +473,49 @@ def get_evaluation_rows(
max_retries,
)
time.sleep(sleep_time)
elif "Not Found" in str(e) or "404" in str(e):
# Skip missing traces quickly
logger.debug("Trace %s not found, skipping", trace_info.id)
trace_full = None
break
else:
logger.warning("Failed to fetch trace %s after %d retries: %s", trace_info.id, max_retries, e)
break # Skip this trace

if trace_full:
# If metadata filters are provided, skip non-matching traces early
try:
if not _trace_matches_metadata_filters(trace_full):
continue
except Exception:
# Be permissive on filter errors; treat as non-match
continue

# If observations carry requester_metadata, allow substring filtering
if requester_metadata_contains:
contains_val = requester_metadata_contains
found_match = False
try:
for obs in getattr(trace_full, "observations", []) or []:
obs_rmd = getattr(obs, "requester_metadata", None)
if isinstance(obs_rmd, dict) and any(
(isinstance(v, str) and contains_val in v) for v in obs_rmd.values()
):
found_match = True
break
obs_md = getattr(obs, "metadata", None)
if isinstance(obs_md, dict):
nested = obs_md.get("requester_metadata")
if isinstance(nested, dict) and any(
(isinstance(v, str) and contains_val in v) for v in nested.values()
):
found_match = True
break
except Exception:
found_match = False
if not found_match:
continue

try:
if converter:
eval_row = converter(trace_full, include_tool_calls, span_name)
Expand Down Expand Up @@ -451,16 +583,22 @@ def upload_scores(self, rows: List[EvaluationRow], model_name: str, mean_score:
"""
try:
for trace_id in set(
row.input_metadata.session_data["langfuse_trace_id"]
(row.input_metadata.session_data or {}).get("langfuse_trace_id")
for row in rows
if row.evaluation_result and row.input_metadata and row.input_metadata.session_data
if row.input_metadata and row.input_metadata.session_data
):
if trace_id:
self.client.create_score(
trace_id=trace_id,
name=model_name,
value=mean_score,
)
try:
self.client.api.score.create(
trace_id=trace_id,
name=model_name,
value=mean_score,
)
except Exception:
# Fallback to legacy client if available in some environments
create_score = getattr(self.client, "create_score", None)
if callable(create_score):
create_score(trace_id=trace_id, name=model_name, value=mean_score)
except Exception as e:
logger.warning("Failed to push scores to Langfuse: %s", e)

Expand Down
6 changes: 4 additions & 2 deletions eval_protocol/adapters/langsmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class LangSmithAdapter(BaseAdapter):
- outputs: { messages: [...] } | { content } | { result } | { answer } | { output } | str | list[dict]
"""

def __init__(self, client: Optional[Client] = None) -> None:
def __init__(self, client: Optional[Any] = None) -> None:
if not LANGSMITH_AVAILABLE:
raise ImportError("LangSmith not installed. Install with: pip install 'eval-protocol[langsmith]'")
self.client = client or Client()
# Client is provided by langsmith package; typing is relaxed to Any to avoid
# static analysis issues when stubs aren't available.
self.client = client or Client() # type: ignore[reportCallIssue]

def get_evaluation_rows(
self,
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
from .remote_rollout_processor import RemoteRolloutProcessor
from .evaluation_test import evaluation_test
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
from .rollout_processor import RolloutProcessor
Expand Down Expand Up @@ -31,6 +32,7 @@
"MCPGymRolloutProcessor",
"RolloutProcessor",
"SingleTurnRolloutProcessor",
"RemoteRolloutProcessor",
"NoOpRolloutProcessor",
"default_dataset_adapter",
"RolloutProcessorConfig",
Expand Down
Loading
Loading