Skip to content

Commit 3aa96ae

Browse files
xzrderek and Dylan Huang
authored
Fireworks Tracing (#252)
* Fireworks Tracing
* update path
* various changes
* add dataloaderconfig
* use get
* address comments
Co-authored-by: Dylan Huang <dhuang@fireworks.ai>
1 parent cf55313 commit 3aa96ae

File tree

8 files changed

+609
-35
lines changed

8 files changed

+609
-35
lines changed

eval_protocol/adapters/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Available adapters:
77
- BaseAdapter: Abstract base class for all adapters
88
- LangfuseAdapter: Pull data from Langfuse deployments
9+
- FireworksTracingAdapter: Pull data from Langfuse via Fireworks tracing proxy
910
- HuggingFaceAdapter: Load datasets from HuggingFace Hub
1011
- BigQueryAdapter: Query data from Google BigQuery
1112
- TRL integration (legacy)
@@ -24,6 +25,10 @@
2425
except ImportError:
2526
pass
2627

28+
from .fireworks_tracing import FireworksTracingAdapter
29+
30+
__all__.extend(["FireworksTracingAdapter"])
31+
2732
try:
2833
from .huggingface import (
2934
HuggingFaceAdapter,
Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
"""Fireworks Tracing adapter for Eval Protocol.
2+
3+
This adapter uses the Fireworks tracing proxy at tracing.fireworks.ai
4+
to pull data from Langfuse deployments with simplified retry logic handling.
5+
"""
6+
7+
from __future__ import annotations
8+
import logging
9+
import requests
10+
from datetime import datetime
11+
from typing import Any, Dict, List, Optional, Protocol
12+
13+
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
14+
from .base import BaseAdapter
15+
from .utils import extract_messages_from_data
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class TraceDictConverter(Protocol):
    """Signature required of a user-supplied trace converter.

    An implementation receives one raw trace dictionary together with the
    processing options and either produces an EvaluationRow or returns
    None to drop that trace from the results.
    """

    def __call__(
        self,
        trace: Dict[str, Any],
        include_tool_calls: bool,
        span_name: Optional[str],
    ) -> Optional[EvaluationRow]:
        """Turn one trace dictionary into an EvaluationRow (or skip it).

        Args:
            trace: Raw trace dictionary fetched from the proxy API.
            include_tool_calls: Whether tool-call content should be kept.
            span_name: Optional span whose generations supply the messages.

        Returns:
            The converted EvaluationRow, or None when the trace is skipped.
        """
        ...
44+
45+
46+
def convert_trace_dict_to_evaluation_row(
    trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None
) -> Optional[EvaluationRow]:
    """Build an EvaluationRow from a raw proxy trace dictionary.

    Args:
        trace: Trace dictionary returned by the Fireworks proxy API.
        include_tool_calls: Whether to include tool calling information.
        span_name: When set, messages come from generations inside this span.

    Returns:
        The converted EvaluationRow, or None when conversion is not possible.
    """
    try:
        messages = extract_messages_from_trace_dict(trace, include_tool_calls, span_name)
        if not messages:
            return None

        # Tool definitions, when present, live under trace["input"]["tools"].
        tools = None
        trace_input = trace.get("input")
        if include_tool_calls and isinstance(trace_input, dict) and "tools" in trace_input:
            tools = trace_input["tools"]

        execution_metadata = ExecutionMetadata()
        row_id = None

        # Execution metadata is encoded as "<key>:<value>" entries in the tag list.
        for tag in trace.get("tags", []) or []:
            key, sep, value = tag.partition(":")
            if not sep:
                continue
            if key == "invocation_id":
                execution_metadata.invocation_id = value
            elif key == "experiment_id":
                execution_metadata.experiment_id = value
            elif key == "rollout_id":
                execution_metadata.rollout_id = value
            elif key == "run_id":
                execution_metadata.run_id = value
            elif key == "row_id":
                row_id = value

            # Stop scanning once every metadata slot has been filled.
            if (
                execution_metadata.invocation_id
                and execution_metadata.experiment_id
                and execution_metadata.rollout_id
                and execution_metadata.run_id
                and row_id
            ):
                break

        return EvaluationRow(
            messages=messages,
            tools=tools,
            input_metadata=InputMetadata(
                row_id=row_id,
                session_data={
                    "langfuse_trace_id": trace.get("id"),  # Store the trace ID here
                },
            ),
            execution_metadata=execution_metadata,
        )

    except (AttributeError, ValueError, KeyError) as e:
        logger.error("Error converting trace %s: %s", trace.get("id"), e)
        return None
113+
114+
115+
def extract_messages_from_trace_dict(
    trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None
) -> List[Message]:
    """Pull the conversation messages out of a raw trace dictionary.

    Args:
        trace: Trace dictionary from the proxy API.
        include_tool_calls: Whether to include tool calling information.
        span_name: When set, messages come from the final generation inside
            the span with this name.

    Returns:
        The extracted Message objects (possibly an empty list).
    """
    messages: List[Message] = []

    if span_name:
        # Named-span path: messages come from the last generation in the span.
        try:
            generation = get_final_generation_in_span_dict(trace, span_name)
            if not generation:
                return messages
            for part in ("input", "output"):
                if generation.get(part):
                    messages.extend(extract_messages_from_data(generation[part], include_tool_calls))
            return messages
        except Exception as e:
            logger.error("Failed to extract messages from span '%s' in trace %s: %s", span_name, trace.get("id"), e)
            return messages

    # Default path: use the trace-level input/output payloads.
    try:
        for part in ("input", "output"):
            if trace.get(part):
                messages.extend(extract_messages_from_data(trace[part], include_tool_calls))
    except (AttributeError, ValueError, KeyError) as e:
        logger.warning("Error processing trace %s: %s", trace.get("id"), e)

    # Fallback: the last GENERATION observation typically carries the full
    # chat history when the trace-level payloads yielded nothing.
    if not messages:
        try:
            generations = [obs for obs in trace.get("observations", []) if obs.get("type") == "GENERATION"]
            if generations:
                generations.sort(key=lambda obs: obs.get("start_time", ""))
                newest = generations[-1]
                for part in ("input", "output"):
                    if newest.get(part):
                        messages.extend(extract_messages_from_data(newest[part], include_tool_calls))
        except Exception as e:
            logger.warning("Failed to extract from last generation for trace %s: %s", trace.get("id"), e)

    return messages
175+
176+
177+
def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> Optional[Dict[str, Any]]:
    """Locate the last generation under the named span of a trace.

    Args:
        trace: Trace dictionary.
        span_name: Name of the SPAN observation to search under.

    Returns:
        The chronologically last GENERATION child of the span, or None when
        no matching span (with generation children) exists.
    """
    observations = trace.get("observations", [])

    # Pick the first SPAN with this name that actually has generation children.
    parent = next(
        (
            obs
            for obs in observations
            if obs.get("name") == span_name
            and obs.get("type") == "SPAN"
            and any(
                child.get("type") == "GENERATION" and child.get("parent_observation_id") == obs.get("id")
                for child in observations
            )
        ),
        None,
    )

    if parent is None:
        logger.warning("No span named '%s' found in trace %s", span_name, trace.get("id"))
        return None

    generations = [
        obs
        for obs in observations
        if obs.get("type") == "GENERATION" and obs.get("parent_observation_id") == parent.get("id")
    ]

    if not generations:
        logger.warning("No generations found in span '%s' in trace %s", span_name, trace.get("id"))
        return None

    # Sort chronologically; the final generation holds the full message history.
    generations.sort(key=lambda obs: obs.get("start_time", ""))
    return generations[-1]
222+
223+
224+
class FireworksTracingAdapter(BaseAdapter):
    """Adapter to pull data from Langfuse via Fireworks tracing proxy.

    This adapter uses the Fireworks tracing proxy API which handles retry logic
    and rate limiting internally, simplifying the client-side implementation.

    Examples:
        Basic usage (default project):
        >>> adapter = FireworksTracingAdapter()
        >>> rows = list(adapter.get_evaluation_rows(tags=["rollout_id:xyz"], limit=10))

        With explicit project ID:
        >>> adapter = FireworksTracingAdapter(
        ...     project_id="your_project_id",
        ...     base_url="https://tracing.fireworks.ai"
        ... )
        >>> rows = list(adapter.get_evaluation_rows(tags=["production"], limit=10))

        Filter by specific criteria:
        >>> rows = list(adapter.get_evaluation_rows(
        ...     tags=["production"],
        ...     limit=50,
        ...     hours_back=24
        ... ))
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        base_url: str = "https://tracing.fireworks.ai",
        timeout: int = 300,
    ):
        """Initialize the Fireworks Tracing adapter.

        Args:
            project_id: Optional project ID. If not provided, uses the default
                project configured on the server.
            base_url: The base URL of the tracing proxy
                (default: https://tracing.fireworks.ai)
            timeout: Request timeout in seconds (default: 300)
        """
        self.project_id = project_id
        # Normalize so URL joining below never produces a double slash.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def get_evaluation_rows(
        self,
        tags: List[str],
        limit: int = 100,
        sample_size: Optional[int] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        name: Optional[str] = None,
        environment: Optional[str] = None,
        version: Optional[str] = None,
        release: Optional[str] = None,
        fields: Optional[str] = None,
        hours_back: Optional[int] = None,
        from_timestamp: Optional[datetime] = None,
        to_timestamp: Optional[datetime] = None,
        include_tool_calls: bool = True,
        sleep_between_gets: float = 2.5,
        max_retries: int = 3,
        span_name: Optional[str] = None,
        converter: Optional[TraceDictConverter] = None,
    ) -> List[EvaluationRow]:
        """Pull traces from Langfuse via proxy and convert to EvaluationRow format.

        Args:
            tags: REQUIRED - Filter by specific tags (prevents fetching all traces).
                Must provide at least one tag (e.g., ['rollout_id:xyz'], ['production'])
            limit: Max number of trace summaries to collect via pagination
            sample_size: Optional number of traces to randomly sample (if None, process all)
            user_id: Filter by user ID
            session_id: Filter by session ID
            name: Filter by trace name
            environment: Filter by environment (e.g., production, staging, development)
            version: Filter by trace version
            release: Filter by trace release
            fields: Comma-separated list of fields to include
            hours_back: Filter traces from this many hours ago
            from_timestamp: Explicit start time (ISO format)
            to_timestamp: Explicit end time (ISO format)
            include_tool_calls: Whether to include tool calling traces
            sleep_between_gets: Sleep time between trace.get() calls (handled by proxy)
            max_retries: Maximum retries for rate limit errors (handled by proxy)
            span_name: If provided, extract messages from generations within this named span
            converter: Optional custom converter implementing TraceDictConverter protocol.
                If provided, this will be used instead of the default conversion logic.

        Returns:
            List[EvaluationRow]: Converted evaluation rows

        Raises:
            ValueError: If tags list is empty
        """
        # Validate that tags are provided (security requirement).
        # `not tags` already covers both None and an empty list.
        if not tags:
            raise ValueError("At least one tag is required to fetch traces (security: prevents fetching all traces)")

        eval_rows: List[EvaluationRow] = []

        # Build query parameters for GET request; requests encodes the list
        # under "tags" as repeated query parameters.
        params = {
            "limit": limit,
            "sample_size": sample_size,
            "tags": tags,
            "user_id": user_id,
            "session_id": session_id,
            "name": name,
            "environment": environment,
            "version": version,
            "release": release,
            "fields": fields,
            "hours_back": hours_back,
            "from_timestamp": from_timestamp.isoformat() if from_timestamp else None,
            "to_timestamp": to_timestamp.isoformat() if to_timestamp else None,
            "sleep_between_gets": sleep_between_gets,
            "max_retries": max_retries,
        }

        # Remove None values so unset filters are not sent at all.
        params = {k: v for k, v in params.items() if v is not None}

        # Make request to proxy; project-scoped endpoint when a project ID is set.
        if self.project_id:
            url = f"{self.base_url}/v1/project_id/{self.project_id}/traces"
        else:
            url = f"{self.base_url}/v1/traces"

        try:
            response = requests.get(url, params=params, timeout=self.timeout)
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers JSON decode failures on older `requests`
            # versions, where response.json() raises a plain ValueError rather
            # than requests.exceptions.JSONDecodeError (a RequestException).
            logger.error("Failed to fetch traces from proxy: %s", e)
            return eval_rows

        # Extract traces from response and convert each to an EvaluationRow,
        # skipping (with a warning) any trace that fails to convert.
        convert = converter if converter else convert_trace_dict_to_evaluation_row
        for trace in result.get("traces", []):
            try:
                eval_row = convert(trace, include_tool_calls, span_name)
            except (AttributeError, ValueError, KeyError) as e:
                logger.warning("Failed to convert trace %s: %s", trace.get("id"), e)
                continue
            if eval_row:
                eval_rows.append(eval_row)

        logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows))
        return eval_rows

0 commit comments

Comments
 (0)