Skip to content

Commit 289abc5

Browse files
xzrderek and Dylan Huang authored
Redis Bug Fix (#254)
* Fireworks Tracing * update path * various changes * add dataloaderconfig * use get * validated using remote_server_multi_turn.py * Fireworks Tracing * various changes * validated using remote_server_multi_turn.py --------- Co-authored-by: Dylan Huang <dhuang@fireworks.ai>
1 parent cc8666e commit 289abc5

File tree

4 files changed

+162
-17
lines changed

4 files changed

+162
-17
lines changed

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88
import logging
99
import requests
10+
import time
1011
from datetime import datetime
1112
from typing import Any, Dict, List, Optional, Protocol
1213

@@ -280,8 +281,9 @@ def get_evaluation_rows(
280281
from_timestamp: Optional[datetime] = None,
281282
to_timestamp: Optional[datetime] = None,
282283
include_tool_calls: bool = True,
283-
sleep_between_gets: float = 2.5,
284-
max_retries: int = 3,
284+
backend_sleep_between_gets: float = 0.1,
285+
backend_max_retries: int = 3,
286+
proxy_max_retries: int = 3,
285287
span_name: Optional[str] = None,
286288
converter: Optional[TraceDictConverter] = None,
287289
) -> List[EvaluationRow]:
@@ -303,8 +305,9 @@ def get_evaluation_rows(
303305
from_timestamp: Explicit start time (ISO format)
304306
to_timestamp: Explicit end time (ISO format)
305307
include_tool_calls: Whether to include tool calling traces
306-
sleep_between_gets: Sleep time between trace.get() calls (handled by proxy)
307-
max_retries: Maximum retries for rate limit errors (handled by proxy)
308+
backend_sleep_between_gets: Sleep time between backend trace fetches (passed to proxy)
309+
backend_max_retries: Maximum retries for backend operations (passed to proxy)
310+
proxy_max_retries: Maximum retries when proxy returns 404 (client-side retries with exponential backoff)
308311
span_name: If provided, extract messages from generations within this named span
309312
converter: Optional custom converter implementing TraceDictConverter protocol.
310313
If provided, this will be used instead of the default conversion logic.
@@ -336,25 +339,60 @@ def get_evaluation_rows(
336339
"hours_back": hours_back,
337340
"from_timestamp": from_timestamp.isoformat() if from_timestamp else None,
338341
"to_timestamp": to_timestamp.isoformat() if to_timestamp else None,
339-
"sleep_between_gets": sleep_between_gets,
340-
"max_retries": max_retries,
342+
"sleep_between_gets": backend_sleep_between_gets,
343+
"max_retries": backend_max_retries,
341344
}
342345

343346
# Remove None values
344347
params = {k: v for k, v in params.items() if v is not None}
345348

346-
# Make request to proxy
349+
# Make request to proxy with retry logic
347350
if self.project_id:
348351
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces"
349352
else:
350353
url = f"{self.base_url}/v1/traces"
351354

352-
try:
353-
response = requests.get(url, params=params, timeout=self.timeout)
354-
response.raise_for_status()
355-
result = response.json()
356-
except requests.exceptions.RequestException as e:
357-
logger.error("Failed to fetch traces from proxy: %s", e)
355+
# Retry loop for handling backend indexing delays (proxy returns 404)
356+
result = None
357+
for attempt in range(proxy_max_retries):
358+
try:
359+
response = requests.get(url, params=params, timeout=self.timeout)
360+
response.raise_for_status()
361+
result = response.json()
362+
break # Success, exit retry loop
363+
except requests.exceptions.HTTPError as e:
364+
error_msg = str(e)
365+
should_retry = False
366+
367+
# Try to extract detail message from response
368+
if e.response is not None:
369+
try:
370+
error_detail = e.response.json().get("detail", "")
371+
error_msg = error_detail or e.response.text
372+
373+
# Retry on 404 if it's due to incomplete/missing traces (backend still indexing)
374+
if e.response.status_code == 404 and (
375+
"Incomplete traces" in error_detail or "No traces found" in error_detail
376+
):
377+
should_retry = True
378+
except Exception:
379+
error_msg = e.response.text
380+
381+
if should_retry and attempt < proxy_max_retries - 1:
382+
sleep_time = 2 ** (attempt + 1)
383+
logger.warning(error_msg)
384+
time.sleep(sleep_time)
385+
else:
386+
# Final retry or non-retryable error
387+
logger.error("Failed to fetch traces from proxy: %s", error_msg)
388+
return eval_rows
389+
except requests.exceptions.RequestException as e:
390+
# Non-HTTP errors (network issues, timeouts, etc.)
391+
logger.error("Failed to fetch traces from proxy: %s", str(e))
392+
return eval_rows
393+
394+
if result is None:
395+
logger.error("Failed to fetch traces after %d retries", proxy_max_retries)
358396
return eval_rows
359397

360398
# Extract traces from response

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
5858
def fetch_traces() -> List[EvaluationRow]:
5959
base_url = config.model_base_url or "https://tracing.fireworks.ai"
6060
adapter = FireworksTracingAdapter(base_url=base_url)
61-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
61+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5)
6262

6363
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
6464

@@ -188,7 +188,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
188188
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
189189

190190
final_model_base_url = model_base_url
191-
if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"):
191+
if model_base_url and (
192+
model_base_url.startswith("https://tracing.fireworks.ai")
193+
or model_base_url.startswith("http://localhost")
194+
):
192195
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta)
193196

194197
init_payload: InitRequest = InitRequest(
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import os
2+
import random
3+
import threading
4+
5+
import uvicorn
6+
from fastapi import FastAPI
7+
from openai import OpenAI
8+
import logging
9+
10+
from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter
11+
12+
13+
app = FastAPI()
14+
15+
# attach handler to root logger
16+
handler = ElasticsearchDirectHttpHandler()
17+
logging.getLogger().addHandler(handler)
18+
19+
20+
@app.post("/init")
def init(req: InitRequest):
    """Start a background 6-turn chat rollout and return immediately.

    Configures the Elasticsearch log handler when the request carries a
    config, attaches a rollout-id filter to a per-rollout logger, then
    spawns a daemon thread that drives a 3-round conversation
    (3 user messages + 3 assistant responses) against the configured model.

    Args:
        req: InitRequest carrying the model endpoint, initial messages,
            rollout metadata, and optional Elasticsearch config.
    """
    if req.elastic_search_config:
        handler.configure(req.elastic_search_config)

    # Per-rollout logger so every record is tagged with the rollout_id.
    logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}")
    logger.addFilter(RolloutIdFilter(req.metadata.rollout_id))

    def _send_turn(client: OpenAI, history: list, response_turn: int) -> None:
        # One completion round: call the model with the running history,
        # append the assistant reply, and log a truncated preview.
        # (Replaces three copy-pasted completion blocks in the original.)
        completion = client.chat.completions.create(
            model=req.model,
            messages=history,  # type: ignore
        )
        assistant_content = completion.choices[0].message.content or ""
        history.append({"role": "assistant", "content": assistant_content})
        logger.info(f"Turn {response_turn} response: {assistant_content[:100]}...")

    # Kick off worker thread that does a multi-turn chat (6 turns total)
    def _worker():
        try:
            if not req.messages:
                raise ValueError("messages is required")

            client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

            # Convert Message objects to dicts for the OpenAI API.
            conversation_history = [{"role": m.role, "content": m.content} for m in req.messages]

            follow_up_questions = [
                "Tell me more about that.",
                "What else can you share about this topic?",
            ]

            # Turns 1-2: initial user message + assistant response.
            logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}")
            _send_turn(client, conversation_history, response_turn=2)

            # Turns 3-6: two follow-up user messages, each answered by the model.
            for i, question in enumerate(follow_up_questions):
                user_turn = 3 + 2 * i  # 3, then 5 — matches the original logs
                conversation_history.append({"role": "user", "content": question})
                logger.info(f"Turn {user_turn}: User asks: {question}")
                _send_turn(client, conversation_history, response_turn=user_turn + 1)

            logger.info(f"Completed 6-turn conversation with {len(conversation_history)} messages total")
        except Exception as e:
            # Best-effort: swallow the error; the finally block still emits
            # the finished status so pollers are unblocked even on failure.
            print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}")
        finally:
            logger.info(
                f"Rollout {req.metadata.rollout_id} completed",
                extra={"status": Status.rollout_finished()},
            )

    t = threading.Thread(target=_worker, daemon=True)
    t.start()
95+
96+
97+
def main():
    """Launch the FastAPI app under uvicorn, with host/port from env vars."""
    uvicorn.run(
        app,
        host=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"),
        port=int(os.getenv("REMOTE_SERVER_PORT", "3000")),
    )
101+
102+
103+
if __name__ == "__main__":
104+
main()

tests/remote_server/test_remote_fireworks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
4343

4444
base_url = config.model_base_url or "https://tracing.fireworks.ai"
4545
adapter = FireworksTracingAdapter(base_url=base_url)
46-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
46+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5)
4747

4848

4949
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
@@ -65,7 +65,7 @@ def rows() -> List[EvaluationRow]:
6565
),
6666
rollout_processor=RemoteRolloutProcessor(
6767
remote_base_url="http://127.0.0.1:3000",
68-
timeout_seconds=30,
68+
timeout_seconds=180,
6969
output_data_loader=fireworks_output_data_loader,
7070
),
7171
)

0 commit comments

Comments
 (0)