Skip to content

Commit 80794bd

Browse files
committed
working for my own chinook trace, changing adapter now
1 parent da4023d commit 80794bd

File tree

5 files changed

+188
-159
lines changed

5 files changed

+188
-159
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 23 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
"""
66

77
import logging
8-
from datetime import datetime
8+
from datetime import datetime, timedelta
99
from typing import Any, Dict, Iterator, List, Optional, cast
1010

1111
from eval_protocol.models import EvaluationRow, InputMetadata, Message
1212

1313
logger = logging.getLogger(__name__)
1414

1515
try:
16-
from langfuse import Langfuse # pyright: ignore[reportPrivateImportUsage]
16+
from langfuse import get_client # pyright: ignore[reportPrivateImportUsage]
17+
from langfuse.api.resources.trace.types.traces import Traces
18+
from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails
1719

1820
LANGFUSE_AVAILABLE = True
1921
except ImportError:
@@ -45,35 +47,20 @@ class LangfuseAdapter:
4547
... ))
4648
"""
4749

48-
def __init__(
49-
self,
50-
public_key: str,
51-
secret_key: str,
52-
host: str = "https://cloud.langfuse.com",
53-
project_id: Optional[str] = None,
54-
):
55-
"""Initialize the Langfuse adapter.
56-
57-
Args:
58-
public_key: Langfuse public key
59-
secret_key: Langfuse secret key
60-
host: Langfuse host URL (default: https://cloud.langfuse.com)
61-
project_id: Optional project ID to filter traces
62-
"""
50+
def __init__(self):
51+
"""Initialize the Langfuse adapter."""
6352
if not LANGFUSE_AVAILABLE:
6453
raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'")
6554

66-
self.client = cast(Any, Langfuse)(public_key=public_key, secret_key=secret_key, host=host)
67-
self.project_id = project_id
55+
self.client = get_client()
6856

6957
def get_evaluation_rows(
7058
self,
7159
limit: int = 100,
7260
tags: Optional[List[str]] = None,
7361
user_id: Optional[str] = None,
7462
session_id: Optional[str] = None,
75-
from_timestamp: Optional[datetime] = None,
76-
to_timestamp: Optional[datetime] = None,
63+
hours_back: Optional[int] = None,
7764
include_tool_calls: bool = True,
7865
) -> List[EvaluationRow]:
7966
"""Pull traces from Langfuse and convert to EvaluationRow format.
@@ -83,16 +70,23 @@ def get_evaluation_rows(
8370
tags: Filter by specific tags
8471
user_id: Filter by user ID
8572
session_id: Filter by session ID
86-
from_timestamp: Filter traces after this timestamp
87-
to_timestamp: Filter traces before this timestamp
73+
hours_back: Filter traces from this many hours ago
8874
include_tool_calls: Whether to include tool calling traces
8975
9076
Yields:
9177
EvaluationRow: Converted evaluation rows
9278
"""
9379
# Get traces from Langfuse using new API
80+
81+
if hours_back:
82+
to_timestamp = datetime.now()
83+
from_timestamp = to_timestamp - timedelta(hours=hours_back)
84+
else:
85+
to_timestamp = None
86+
from_timestamp = None
87+
9488
eval_rows = []
95-
traces = self.client.api.trace.list(
89+
traces: Traces = self.client.api.trace.list(
9690
limit=limit,
9791
tags=tags,
9892
user_id=user_id,
@@ -128,7 +122,7 @@ def get_evaluation_rows_by_ids(
128122
eval_rows = []
129123
for trace_id in trace_ids:
130124
try:
131-
trace = self.client.api.trace.get(trace_id)
125+
trace: TraceWithFullDetails = self.client.api.trace.get(trace_id)
132126
eval_row = self._convert_trace_to_evaluation_row(trace, include_tool_calls)
133127
if eval_row:
134128
eval_rows.append(eval_row)
@@ -147,10 +141,10 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool
147141
Returns:
148142
EvaluationRow or None if conversion fails
149143
"""
150-
# TODO: move this logic into an adapter in llm_judge.py. langfuse.py should just return traces
151144
try:
152145
# Get observations (generations, spans) from the trace
153146
observations_response = self.client.api.observations.get_many(trace_id=trace.id, limit=100)
147+
# print(observations_response)
154148
observations = (
155149
observations_response.data if hasattr(observations_response, "data") else list(observations_response)
156150
)
@@ -406,7 +400,6 @@ def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMe
406400
"trace_id": trace.id,
407401
"trace_name": getattr(trace, "name", None),
408402
"trace_tags": getattr(trace, "tags", []),
409-
"langfuse_project_id": self.project_id,
410403
}
411404

412405
# Add trace metadata if available
@@ -418,9 +411,6 @@ def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMe
418411
"session_id": getattr(trace, "session_id", None),
419412
"user_id": getattr(trace, "user_id", None),
420413
"timestamp": getattr(trace, "timestamp", None),
421-
"langfuse_trace_url": (
422-
f"{self.client.host}/project/{self.project_id}/traces/{trace.id}" if self.project_id else None
423-
),
424414
}
425415

426416
return InputMetadata(
@@ -497,26 +487,7 @@ def _extract_tools(self, observations: List[Any], trace: Any = None) -> Optional
497487
return tools if tools else None
498488

499489

500-
def create_langfuse_adapter(
501-
public_key: str,
502-
secret_key: str,
503-
host: str = "https://cloud.langfuse.com",
504-
project_id: Optional[str] = None,
505-
) -> LangfuseAdapter:
506-
"""Factory function to create a Langfuse adapter.
490+
def create_langfuse_adapter() -> LangfuseAdapter:
491+
"""Factory function to create a Langfuse adapter."""
507492

508-
Args:
509-
public_key: Langfuse public key
510-
secret_key: Langfuse secret key
511-
host: Langfuse host URL
512-
project_id: Optional project ID
513-
514-
Returns:
515-
LangfuseAdapter instance
516-
"""
517-
return LangfuseAdapter(
518-
public_key=public_key,
519-
secret_key=secret_key,
520-
host=host,
521-
project_id=project_id,
522-
)
493+
return LangfuseAdapter()

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
parse_ep_num_runs,
5959
parse_ep_passed_threshold,
6060
rollout_processor_with_retry,
61-
split_multi_turn_rows,
6261
)
6362

6463
from ..common_utils import load_jsonl
@@ -85,7 +84,7 @@ def evaluation_test(
8584
steps: int = 30,
8685
mode: EvaluationTestMode = "pointwise",
8786
combine_datasets: bool = True,
88-
split_multi_turn: bool = False,
87+
preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = None,
8988
logger: DatasetLogger | None = None,
9089
exception_handler_config: ExceptionHandlerConfig | None = None,
9190
) -> Callable[[TestFunction], TestFunction]:
@@ -152,9 +151,9 @@ def evaluation_test(
152151
mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
153152
"groupwise" applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo).
154153
"all" applies test function to the whole dataset.
155-
split_multi_turn: If True, splits multi-turn conversations into individual evaluation rows
156-
for each assistant response. Each row will contain the conversation context up to that point
157-
and the assistant's response as ground truth. Useful for Arena-Hard-Auto style evaluations.
154+
preprocess_fn: Optional preprocessing function that takes a list of EvaluationRow objects
155+
and returns a modified list. Useful for transformations like splitting multi-turn conversations,
156+
filtering data, or other preprocessing steps before rollout execution.
158157
logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
159158
exception_handler_config: Configuration for exception handling and backoff retry logic.
160159
If not provided, a default configuration will be used with common retryable exceptions.
@@ -249,8 +248,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
249248
else:
250249
raise ValueError("No input dataset, input messages, or input rows provided")
251250

252-
if split_multi_turn:
253-
data = split_multi_turn_rows(data)
251+
if preprocess_fn:
252+
data = preprocess_fn(data)
254253

255254
for row in data:
256255
# generate a stable row_id for each row

eval_protocol/pytest/utils.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -352,42 +352,3 @@ def add_cost_metrics(row: EvaluationRow) -> None:
352352
output_cost=output_cost,
353353
total_cost=total_cost,
354354
)
355-
356-
357-
def split_multi_turn_rows(data: list[EvaluationRow]) -> list[EvaluationRow]:
358-
"""
359-
Split multi-turn conversation rows into individual evaluation rows for each assistant message.
360-
361-
Args:
362-
data: List of EvaluationRow objects
363-
364-
Returns:
365-
List of expanded EvaluationRow objects, one for each assistant message
366-
"""
367-
expanded_rows = []
368-
369-
for row in data:
370-
messages = row.messages
371-
tools = row.tools
372-
input_metadata = row.input_metadata
373-
374-
assistant_positions = []
375-
for i, message in enumerate(messages):
376-
if message.role == "assistant":
377-
assistant_positions.append(i)
378-
379-
# Create separate evaluation rows on each assistant message (where the comparison model will respond)
380-
for assistant_pos in assistant_positions:
381-
messages_before_assistant = messages[:assistant_pos]
382-
ground_truth_message = messages[assistant_pos].content
383-
384-
expanded_rows.append(
385-
EvaluationRow(
386-
messages=messages_before_assistant,
387-
tools=tools,
388-
input_metadata=input_metadata,
389-
ground_truth=ground_truth_message,
390-
)
391-
)
392-
393-
return expanded_rows

0 commit comments

Comments (0)