Skip to content

Commit 1d07878

Browse files
authored
add finish reason (#421)
* add finish reason * remove * add test * fix * add * fix
1 parent 2d0e9e1 commit 1d07878

File tree

5 files changed

+96
-5
lines changed

5 files changed

+96
-5
lines changed

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import logging
99
import requests
1010
from datetime import datetime
11-
from typing import Any, Dict, List, Optional, Protocol
11+
import ast
12+
import json
1213
import os
14+
from typing import Any, Dict, List, Optional, Protocol
1315

1416
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
1517
from .base import BaseAdapter
@@ -44,6 +46,43 @@ def __call__(
4446
...
4547

4648

49+
def extract_openai_response(observations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
50+
"""Attempt to extract and parse attributes from raw_gen_ai_request observation. This only works when stored in OTEL format.
51+
52+
Args:
53+
observations: List of observation dictionaries from the trace
54+
55+
Returns:
56+
Dict with all attributes parsed. Or None if not found.
57+
"""
58+
for obs in observations:
59+
if obs.get("name") == "raw_gen_ai_request" and obs.get("type") == "SPAN":
60+
metadata = obs.get("metadata") or {}
61+
attributes = metadata.get("attributes") or {}
62+
63+
result: Dict[str, Any] = {}
64+
65+
for key, value in attributes.items():
66+
# Try to parse stringified objects (could be Python repr or JSON)
67+
if isinstance(value, str) and value.startswith(("[", "{")):
68+
try:
69+
result[key] = ast.literal_eval(value)
70+
except Exception as e:
71+
logger.debug("Failed to parse %s with ast.literal_eval: %s", key, e)
72+
try:
73+
result[key] = json.loads(value)
74+
except Exception as e:
75+
logger.debug("Failed to parse %s with json.loads: %s", key, e)
76+
result[key] = value
77+
else:
78+
result[key] = value
79+
80+
if result:
81+
return result
82+
83+
return None
84+
85+
4786
def convert_trace_dict_to_evaluation_row(
4887
trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None
4988
) -> Optional[EvaluationRow]:
@@ -96,6 +135,14 @@ def convert_trace_dict_to_evaluation_row(
96135
):
97136
break # Break early if we've found all the metadata we need
98137

138+
observations = trace.get("observations") or []
139+
# We can only extract when stored in OTEL format.
140+
openai_response = extract_openai_response(observations)
141+
if openai_response:
142+
choices = openai_response.get("llm.openai.choices")
143+
if choices and len(choices) > 0:
144+
execution_metadata.finish_reason = choices[0].get("finish_reason")
145+
99146
return EvaluationRow(
100147
messages=messages,
101148
tools=tools,
@@ -160,7 +207,7 @@ def extract_messages_from_trace_dict(
160207
# Fallback: use the last GENERATION observation which typically contains full chat history
161208
if not messages:
162209
try:
163-
all_observations = trace.get("observations", [])
210+
all_observations = trace.get("observations") or []
164211
gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"]
165212
if gens:
166213
gens.sort(key=lambda x: x.get("start_time", ""))
@@ -186,7 +233,7 @@ def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) ->
186233
The final generation dictionary, or None if not found
187234
"""
188235
# Get all observations from the trace
189-
all_observations = trace.get("observations", [])
236+
all_observations = trace.get("observations") or []
190237

191238
# Find a span with the given name that has generation children
192239
parent_span = None

eval_protocol/proxy/proxy_core/langfuse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]:
5050
"input": getattr(obs, "input", None),
5151
"output": getattr(obs, "output", None),
5252
"parent_observation_id": getattr(obs, "parent_observation_id", None),
53+
"metadata": getattr(obs, "metadata", None),
5354
}
5455
for obs in getattr(trace_full, "observations", [])
5556
]

eval_protocol/reward_function.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from .models import EvaluateResult, MetricResult
1313
from .typed_interface import reward_function
1414

15-
logging.basicConfig(level=logging.INFO)
1615
logger = logging.getLogger(__name__)
1716

1817
T = TypeVar("T", bound=Callable[..., EvaluateResult])

tests/remote_server/remote_server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
app = FastAPI()
1515

16+
# Configure logging for the remote server (required for INFO-level logs to be emitted)
17+
logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")
18+
1619
# Attach Fireworks tracing handler to root logger
1720
fireworks_handler = FireworksTracingHttpHandler()
1821
logging.getLogger().addHandler(fireworks_handler)

tests/remote_server/test_remote_fireworks.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test
22

3+
import logging
34
import subprocess
45
import socket
56
import time
@@ -19,10 +20,23 @@
1920
ROLLOUT_IDS = set()
2021

2122

23+
class StatusLogCaptureHandler(logging.Handler):
24+
"""Custom handler to capture status log messages."""
25+
26+
def __init__(self):
27+
super().__init__()
28+
self.status_100_messages: List[str] = []
29+
30+
def emit(self, record):
31+
msg = record.getMessage() # Use getMessage(), not .message attribute
32+
if "Found Fireworks log" in msg and "with status code 100" in msg:
33+
self.status_100_messages.append(msg)
34+
35+
2236
@pytest.fixture(autouse=True)
2337
def check_rollout_coverage(monkeypatch):
2438
"""
25-
Ensure we attempted to fetch remote traces for each rollout.
39+
Ensure we attempted to fetch remote traces for each rollout and received status logs.
2640
2741
This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
2842
and tracks rollout_ids passed through its DataLoaderConfig.
@@ -37,9 +51,32 @@ def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
3751
return original_loader(config)
3852

3953
monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
54+
55+
# Add custom handler to capture status logs
56+
status_handler = StatusLogCaptureHandler()
57+
status_handler.setLevel(logging.INFO)
58+
rrp_logger = logging.getLogger("eval_protocol.pytest.remote_rollout_processor")
59+
rrp_logger.addHandler(status_handler)
60+
# Ensure the logger level allows INFO messages through
61+
original_level = rrp_logger.level
62+
rrp_logger.setLevel(logging.INFO)
63+
4064
yield
65+
66+
# Cleanup handler and restore level
67+
rrp_logger.removeHandler(status_handler)
68+
rrp_logger.setLevel(original_level)
69+
70+
# After test completes, verify we saw status logs for all 3 rollouts
4171
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
4272

73+
# Check that we received "Found Fireworks log ... with status code 100" for each rollout
74+
assert len(status_handler.status_100_messages) == 3, (
75+
f"Expected 3 'Found Fireworks log ... with status code 100' messages, but found {len(status_handler.status_100_messages)}. "
76+
f"This means the status logs from the remote server were not received. "
77+
f"Messages captured: {status_handler.status_100_messages}"
78+
)
79+
4380

4481
def find_available_port() -> int:
4582
"""Find an available port on localhost"""
@@ -141,4 +178,8 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
141178
assert "data_loader_type" in row.input_metadata.dataset_info
142179
assert "data_loader_num_rows" in row.input_metadata.dataset_info
143180

181+
assert row.execution_metadata.finish_reason == "stop", (
182+
f"Expected finish_reason='stop', got {row.execution_metadata.finish_reason}"
183+
)
184+
144185
return row

0 commit comments

Comments
 (0)