Skip to content

Commit fcadf33

Browse files
committed
HTTP remote rollout server support
1 parent 7e75290 commit fcadf33

File tree

6 files changed

+627
-19
lines changed

6 files changed

+627
-19
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 155 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
to EvaluationRow format for use in evaluation pipelines.
55
"""
66

7-
from langfuse.api.resources.commons.types.observations_view import ObservationsView
7+
from __future__ import annotations
8+
89
import logging
910
import random
1011
import time
1112
from datetime import datetime, timedelta
12-
from typing import Any, Dict, List, Optional, Protocol
13+
from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING
1314

1415
from eval_protocol.models import EvaluationRow, InputMetadata, Message
1516
from .base import BaseAdapter
@@ -46,14 +47,15 @@ def __call__(
4647

4748
try:
4849
from langfuse import get_client # pyright: ignore[reportPrivateImportUsage]
49-
from langfuse.api.resources.trace.types.traces import Traces
50-
from langfuse.api.resources.commons.types.trace import Trace
51-
from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails
5250

5351
LANGFUSE_AVAILABLE = True
5452
except ImportError:
5553
LANGFUSE_AVAILABLE = False
5654

55+
if TYPE_CHECKING:
56+
from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails
57+
from langfuse.api.resources.commons.types.observations_view import ObservationsView
58+
5759

5860
def convert_trace_to_evaluation_row(
5961
trace: "TraceWithFullDetails", include_tool_calls: bool = True, span_name: Optional[str] = None
@@ -64,7 +66,6 @@ def convert_trace_to_evaluation_row(
6466
trace: Langfuse trace object
6567
include_tool_calls: Whether to include tool calling information
6668
span_name: If provided, extract messages from generations within this named span
67-
converter: Optional custom converter implementing TraceConverter protocol
6869
6970
Returns:
7071
EvaluationRow or None if conversion fails
@@ -97,7 +98,7 @@ def convert_trace_to_evaluation_row(
9798

9899

99100
def extract_messages_from_trace(
100-
trace: TraceWithFullDetails, include_tool_calls: bool = True, span_name: Optional[str] = None
101+
trace: "TraceWithFullDetails", include_tool_calls: bool = True, span_name: Optional[str] = None
101102
) -> List[Message]:
102103
"""Extract messages from Langfuse trace input and output.
103104
@@ -114,7 +115,7 @@ def extract_messages_from_trace(
114115
if span_name: # Look for a generation tied to a span name
115116
try:
116117
# Find the final generation in the named span
117-
gen: ObservationsView | None = get_final_generation_in_span(trace, span_name)
118+
gen: "ObservationsView | None" = get_final_generation_in_span(trace, span_name)
118119
if not gen:
119120
return messages
120121

@@ -140,10 +141,27 @@ def extract_messages_from_trace(
140141
except (AttributeError, ValueError, KeyError) as e:
141142
logger.warning("Error processing trace %s: %s", trace.id, e)
142143

144+
# Fallback: use the last GENERATION observation which typically contains full chat history
145+
if not messages:
146+
try:
147+
all_observations = getattr(trace, "observations", None) or []
148+
gens: List[ObservationsView] = [
149+
obs for obs in all_observations if getattr(obs, "type", None) == "GENERATION"
150+
]
151+
if gens:
152+
gens.sort(key=lambda x: x.start_time)
153+
last_gen = gens[-1]
154+
if getattr(last_gen, "input", None):
155+
messages.extend(extract_messages_from_data(getattr(last_gen, "input"), include_tool_calls))
156+
if getattr(last_gen, "output", None):
157+
messages.extend(extract_messages_from_data(getattr(last_gen, "output"), include_tool_calls))
158+
except Exception as e:
159+
logger.warning("Failed to extract from last generation for trace %s: %s", trace.id, e)
160+
143161
return messages
144162

145163

146-
def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) -> ObservationsView | None:
164+
def get_final_generation_in_span(trace: "TraceWithFullDetails", span_name: str) -> "ObservationsView | None":
147165
"""Get the final generation within a named span that contains full message history.
148166
149167
Args:
@@ -173,7 +191,7 @@ def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) ->
173191
return None
174192

175193
# Find all generations within this span
176-
generations: List[ObservationsView] = []
194+
generations: List["ObservationsView"] = []
177195
for obs in all_observations:
178196
if obs.type == "GENERATION" and obs.parent_observation_id == parent_span.id:
179197
generations.append(obs)
@@ -241,6 +259,9 @@ def get_evaluation_rows(
241259
max_retries: int = 3,
242260
span_name: Optional[str] = None,
243261
converter: Optional[TraceConverter] = None,
262+
metadata: Optional[Dict[str, Any]] = None,
263+
requester_metadata: Optional[Dict[str, Any]] = None,
264+
requester_metadata_contains: Optional[str] = None,
244265
) -> List[EvaluationRow]:
245266
"""Pull traces from Langfuse and convert to EvaluationRow format.
246267
@@ -275,6 +296,10 @@ def get_evaluation_rows(
275296
to_timestamp = datetime.now()
276297
from_timestamp = to_timestamp - timedelta(hours=hours_back)
277298

299+
# If filtering by metadata/requester_metadata, prefer fetching metadata fields
300+
if (metadata is not None or requester_metadata is not None or requester_metadata_contains) and not fields:
301+
fields = "core,metadata,observations"
302+
278303
# Collect trace summaries via pagination (up to limit)
279304
all_traces = []
280305
page = 1
@@ -354,6 +379,74 @@ def get_evaluation_rows(
354379
selected_traces = all_traces
355380
logger.debug("Processing all %d collected traces (no sampling)", len(all_traces))
356381

382+
# Helper to check if a trace matches provided metadata filters. We look in multiple places
383+
# because Langfuse may store these fields in different locations (e.g., metadata vs requester_metadata) and SDK object shapes vary.
384+
def _trace_matches_metadata_filters(trace_obj: Any) -> bool:
385+
if metadata is None and requester_metadata is None:
386+
return True
387+
388+
def _as_dict(val: Any) -> Dict[str, Any]:
389+
if val is None:
390+
return {}
391+
if isinstance(val, dict):
392+
return val
393+
# Some SDK objects expose .model_dump() or behave like pydantic models
394+
dump = getattr(val, "model_dump", None)
395+
if callable(dump):
396+
try:
397+
return dump() # type: ignore[no-any-return]
398+
except Exception:
399+
return {}
400+
return {}
401+
402+
# Try common locations for metadata on full trace
403+
trace_meta = _as_dict(getattr(trace_obj, "metadata", None))
404+
trace_req_meta = _as_dict(getattr(trace_obj, "requester_metadata", None))
405+
# Some Langfuse deployments nest requester_metadata inside metadata
406+
nested_req_meta = {}
407+
try:
408+
if isinstance(trace_meta, dict) and isinstance(trace_meta.get("requester_metadata"), dict):
409+
nested_req_meta = _as_dict(trace_meta.get("requester_metadata"))
410+
except Exception:
411+
nested_req_meta = {}
412+
413+
# Fallbacks: sometimes metadata is embedded in input
414+
input_meta = {}
415+
try:
416+
inp = getattr(trace_obj, "input", None)
417+
if isinstance(inp, dict):
418+
input_meta = _as_dict(inp.get("metadata"))
419+
except Exception:
420+
input_meta = {}
421+
422+
# Merge the metadata sources into one dict for matching; on duplicate keys the later source wins (input-embedded metadata overrides trace-level metadata)
423+
combined_meta = {**trace_meta, **input_meta}
424+
combined_req_meta = {**trace_req_meta}
425+
426+
# Also merge nested requester metadata when present
427+
if nested_req_meta:
428+
combined_req_meta = {**combined_req_meta, **nested_req_meta}
429+
430+
def _is_subset(needle: Dict[str, Any], haystack: Dict[str, Any]) -> bool:
431+
for k, v in needle.items():
432+
if haystack.get(k) != v:
433+
return False
434+
return True
435+
436+
ok_meta = True
437+
ok_req_meta = True
438+
439+
if metadata is not None:
440+
# Accept match if found either in metadata or requester_metadata buckets
441+
ok_meta = _is_subset(metadata, combined_meta) or _is_subset(metadata, combined_req_meta)
442+
443+
if requester_metadata is not None:
444+
ok_req_meta = _is_subset(requester_metadata, combined_req_meta) or _is_subset(
445+
requester_metadata, combined_meta
446+
)
447+
448+
return ok_meta and ok_req_meta
449+
357450
# Process each selected trace with sleep and retry logic
358451
for trace_info in selected_traces:
359452
# Sleep between gets to avoid rate limits
@@ -365,6 +458,7 @@ def get_evaluation_rows(
365458
detail_retries = 0
366459
while detail_retries < max_retries:
367460
try:
461+
# Some SDKs don't support fields= on get; call without it
368462
trace_full = self.client.api.trace.get(trace_info.id)
369463
break
370464
except Exception as e:
@@ -379,11 +473,49 @@ def get_evaluation_rows(
379473
max_retries,
380474
)
381475
time.sleep(sleep_time)
476+
elif "Not Found" in str(e) or "404" in str(e):
477+
# Skip missing traces quickly
478+
logger.debug("Trace %s not found, skipping", trace_info.id)
479+
trace_full = None
480+
break
382481
else:
383482
logger.warning("Failed to fetch trace %s after %d retries: %s", trace_info.id, max_retries, e)
384483
break # Skip this trace
385484

386485
if trace_full:
486+
# If metadata filters are provided, skip non-matching traces early
487+
try:
488+
if not _trace_matches_metadata_filters(trace_full):
489+
continue
490+
except Exception:
491+
# Be conservative on filter errors: treat the trace as a non-match and skip it
492+
continue
493+
494+
# If observations carry requester_metadata, allow substring filtering
495+
if requester_metadata_contains:
496+
contains_val = requester_metadata_contains
497+
found_match = False
498+
try:
499+
for obs in getattr(trace_full, "observations", []) or []:
500+
obs_rmd = getattr(obs, "requester_metadata", None)
501+
if isinstance(obs_rmd, dict) and any(
502+
(isinstance(v, str) and contains_val in v) for v in obs_rmd.values()
503+
):
504+
found_match = True
505+
break
506+
obs_md = getattr(obs, "metadata", None)
507+
if isinstance(obs_md, dict):
508+
nested = obs_md.get("requester_metadata")
509+
if isinstance(nested, dict) and any(
510+
(isinstance(v, str) and contains_val in v) for v in nested.values()
511+
):
512+
found_match = True
513+
break
514+
except Exception:
515+
found_match = False
516+
if not found_match:
517+
continue
518+
387519
try:
388520
if converter:
389521
eval_row = converter(trace_full, include_tool_calls, span_name)
@@ -451,16 +583,22 @@ def upload_scores(self, rows: List[EvaluationRow], model_name: str, mean_score:
451583
"""
452584
try:
453585
for trace_id in set(
454-
row.input_metadata.session_data["langfuse_trace_id"]
586+
(row.input_metadata.session_data or {}).get("langfuse_trace_id")
455587
for row in rows
456-
if row.evaluation_result and row.input_metadata and row.input_metadata.session_data
588+
if row.input_metadata and row.input_metadata.session_data
457589
):
458590
if trace_id:
459-
self.client.create_score(
460-
trace_id=trace_id,
461-
name=model_name,
462-
value=mean_score,
463-
)
591+
try:
592+
self.client.api.score.create(
593+
trace_id=trace_id,
594+
name=model_name,
595+
value=mean_score,
596+
)
597+
except Exception:
598+
# Fallback to legacy client if available in some environments
599+
create_score = getattr(self.client, "create_score", None)
600+
if callable(create_score):
601+
create_score(trace_id=trace_id, name=model_name, value=mean_score)
464602
except Exception as e:
465603
logger.warning("Failed to push scores to Langfuse: %s", e)
466604

eval_protocol/adapters/langsmith.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ class LangSmithAdapter(BaseAdapter):
3535
- outputs: { messages: [...] } | { content } | { result } | { answer } | { output } | str | list[dict]
3636
"""
3737

38-
def __init__(self, client: Optional[Client] = None) -> None:
38+
def __init__(self, client: Optional[Any] = None) -> None:
3939
if not LANGSMITH_AVAILABLE:
4040
raise ImportError("LangSmith not installed. Install with: pip install 'eval-protocol[langsmith]'")
41-
self.client = client or Client()
41+
# Client is provided by langsmith package; typing is relaxed to Any to avoid
42+
# static analysis issues when stubs aren't available.
43+
self.client = client or Client() # type: ignore[reportCallIssue]
4244

4345
def get_evaluation_rows(
4446
self,

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
44
from .default_no_op_rollout_processor import NoOpRolloutProcessor
55
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
6+
from .remote_rollout_processor import RemoteRolloutProcessor
67
from .evaluation_test import evaluation_test
78
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
89
from .rollout_processor import RolloutProcessor
@@ -31,6 +32,7 @@
3132
"MCPGymRolloutProcessor",
3233
"RolloutProcessor",
3334
"SingleTurnRolloutProcessor",
35+
"RemoteRolloutProcessor",
3436
"NoOpRolloutProcessor",
3537
"default_dataset_adapter",
3638
"RolloutProcessorConfig",

0 commit comments

Comments
 (0)