Skip to content

Commit 8e78dba

Browse files
committed
fix logging sink
1 parent 2b64a3f commit 8e78dba

File tree

7 files changed

+263
-18
lines changed

7 files changed

+263
-18
lines changed

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,30 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
273273
if not tags:
274274
raise ValueError("At least one tag is required to fetch logs")
275275

276-
url = f"{self.base_url}/logs"
277276
headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"}
278277
params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"}
279278

280-
try:
281-
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
282-
response.raise_for_status()
283-
data = response.json() or {}
284-
except requests.exceptions.RequestException as e:
285-
logger.error("Failed to fetch logs from Fireworks /logs: %s", str(e))
279+
# Try /logs first, fall back to /v1/logs if not found
280+
urls_to_try = [f"{self.base_url}/logs", f"{self.base_url}/v1/logs"]
281+
data: Dict[str, Any] = {}
282+
last_error: Optional[str] = None
283+
for url in urls_to_try:
284+
try:
285+
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
286+
if response.status_code == 404:
287+
# Try next variant
288+
last_error = f"404 for {url}"
289+
continue
290+
response.raise_for_status()
291+
data = response.json() or {}
292+
break
293+
except requests.exceptions.RequestException as e:
294+
last_error = str(e)
295+
continue
296+
else:
297+
# All attempts failed
298+
if last_error:
299+
logger.error("Failed to fetch logs from Fireworks (tried %s): %s", urls_to_try, last_error)
286300
return []
287301

288302
entries: List[Dict[str, Any]] = data.get("entries", []) or []

eval_protocol/cli_commands/logs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ def logs_command(args):
3232

3333
# Setup backend configs
3434
elasticsearch_config = None
35-
fireworks_base_url = os.environ.get("FW_TRACING_GATEWAY_BASE_URL") or "https://tracing.fireworks.ai"
35+
# Prefer explicit FW_TRACING_GATEWAY_BASE_URL, then GATEWAY_URL from env (remote validation),
36+
# finally default to public tracing.fireworks.ai
37+
fireworks_base_url = (
38+
os.environ.get("FW_TRACING_GATEWAY_BASE_URL")
39+
or os.environ.get("GATEWAY_URL")
40+
or "https://tracing.fireworks.ai"
41+
)
3642
try:
3743
if not use_fireworks:
3844
if getattr(args, "use_env_elasticsearch_config", False):

eval_protocol/log_utils/fireworks_tracing_http_handler.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,42 @@ def emit(self, record: logging.LogRecord) -> None:
3434
if not rollout_id:
3535
return
3636
payload = self._build_payload(record, rollout_id)
37-
url = f"{self.gateway_base_url.rstrip('/')}/logs"
37+
base = self.gateway_base_url.rstrip("/")
38+
url = f"{base}/logs"
39+
# Optional debug prints to aid local diagnostics
40+
if os.environ.get("EP_DEBUG") == "true":
41+
try:
42+
tags_val = payload.get("tags")
43+
tags_len = len(tags_val) if isinstance(tags_val, list) else 0
44+
msg_val = payload.get("message")
45+
msg_preview = msg_val[:80] if isinstance(msg_val, str) else msg_val
46+
print(f"[FW_LOG] POST {url} rollout_id={rollout_id} tags={tags_len} msg={msg_preview}")
47+
except Exception:
48+
pass
3849
with self._lock:
39-
self._session.post(url, json=payload, timeout=5)
50+
resp = self._session.post(url, json=payload, timeout=5)
51+
if os.environ.get("EP_DEBUG") == "true":
52+
try:
53+
print(f"[FW_LOG] resp={resp.status_code}")
54+
except Exception:
55+
pass
56+
# Fallback to /v1/logs if /logs is not found
57+
if resp is not None and getattr(resp, "status_code", None) == 404:
58+
alt = f"{base}/v1/logs"
59+
if os.environ.get("EP_DEBUG") == "true":
60+
try:
61+
tags_val = payload.get("tags")
62+
tags_len = len(tags_val) if isinstance(tags_val, list) else 0
63+
print(f"[FW_LOG] RETRY POST {alt} rollout_id={rollout_id} tags={tags_len}")
64+
except Exception:
65+
pass
66+
with self._lock:
67+
resp2 = self._session.post(alt, json=payload, timeout=5)
68+
if os.environ.get("EP_DEBUG") == "true":
69+
try:
70+
print(f"[FW_LOG] retry resp={resp2.status_code}")
71+
except Exception:
72+
pass
4073
except Exception:
4174
# Avoid raising exceptions from logging
4275
self.handleError(record)

eval_protocol/log_utils/init.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def init_external_logging_from_env() -> None:
4141

4242
# Fireworks tracing: prefer if FIREWORKS_API_KEY is present; default base URL if not provided
4343
fw_key = _get_env("FIREWORKS_API_KEY")
44-
fw_url = _get_env("FW_TRACING_GATEWAY_BASE_URL") or "https://tracing.fireworks.ai"
44+
# Allow remote validation gateway to act as tracing base when provided
45+
fw_url = _get_env("FW_TRACING_GATEWAY_BASE_URL") or _get_env("GATEWAY_URL") or "https://tracing.fireworks.ai"
4546
if fw_key and "FireworksTracingHttpHandler" not in existing_handler_types:
4647
fw_handler = FireworksTracingHttpHandler(gateway_base_url=fw_url)
4748
fw_handler.setLevel(logging.INFO)

eval_protocol/proxy/proxy_core/redis_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import logging
6-
from typing import Set
6+
from typing import Set, cast
77
import redis
88

99
logger = logging.getLogger(__name__)
@@ -40,7 +40,16 @@ def get_insertion_ids(redis_client: redis.Redis, rollout_id: str) -> Set[str]:
4040
Set of insertion_id strings, empty set if none found or on error
4141
"""
4242
try:
43-
insertion_ids = redis_client.smembers(rollout_id)
43+
raw = redis_client.smembers(rollout_id)
44+
# Typing in redis stubs may be Awaitable[Set[Any]] | Set[Any]; at runtime this is a Set[bytes]
45+
raw_ids = cast(Set[object], raw)
46+
# Normalize to set[str]
47+
insertion_ids: Set[str] = set()
48+
for b in raw_ids:
49+
try:
50+
insertion_ids.add(b.decode("utf-8") if isinstance(b, (bytes, bytearray)) else cast(str, b))
51+
except Exception:
52+
continue
4453
logger.debug(f"Found {len(insertion_ids)} expected insertion_ids for rollout {rollout_id}")
4554
return insertion_ids
4655
except Exception as e:

scripts/validate_remote.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import os
2+
import sys
3+
import time
4+
import requests
5+
6+
7+
def require_env(var_name: str) -> str:
8+
value = os.getenv(var_name)
9+
if not value:
10+
print(f"Missing required env var: {var_name}", file=sys.stderr)
11+
sys.exit(1)
12+
return value
13+
14+
15+
def require_logs_endpoints(base_url: str) -> None:
16+
try:
17+
r = requests.get(f"{base_url}/openapi.json", timeout=30)
18+
if not r.ok:
19+
print("OpenAPI schema unavailable", file=sys.stderr)
20+
sys.exit(1)
21+
paths = r.json().get("paths", {})
22+
ok = any(p.startswith("/logs") or p.startswith("/v1/logs") for p in paths.keys())
23+
if not ok:
24+
print("/logs endpoints not present on deployment", file=sys.stderr)
25+
sys.exit(1)
26+
except Exception as e:
27+
print(f"Failed to check OpenAPI: {e}", file=sys.stderr)
28+
sys.exit(1)
29+
30+
31+
def post_chat_completion(base_url: str, api_key: str, rollout_id: str) -> None:
32+
headers = {"Authorization": f"Bearer {api_key}"}
33+
now = int(time.time())
34+
url = (
35+
f"{base_url}/rollout_id/{rollout_id}/"
36+
f"invocation_id/inv{now}/"
37+
f"experiment_id/remote-validate/"
38+
f"run_id/run-1/"
39+
f"row_id/row-1/"
40+
f"chat/completions"
41+
)
42+
body = {
43+
"model": "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct",
44+
"messages": [{"role": "user", "content": "Say 'ok' if you can read this."}],
45+
"temperature": 0.1,
46+
}
47+
r = requests.post(url, headers=headers, json=body, timeout=60)
48+
if r.status_code != 200:
49+
print(f"Chat completion failed: {r.status_code} {r.text[:500]}", file=sys.stderr)
50+
sys.exit(1)
51+
print("chat: ok")
52+
53+
54+
def wait_for_traces(base_url: str, api_key: str, rollout_id: str, max_attempts: int = 8) -> None:
55+
headers = {"Authorization": f"Bearer {api_key}"}
56+
params = {
57+
"tags": [f"rollout_id:{rollout_id}"],
58+
"limit": 10,
59+
"hours_back": 6,
60+
}
61+
url = f"{base_url}/traces"
62+
for attempt in range(1, max_attempts + 1):
63+
r = requests.get(url, headers=headers, params=params, timeout=30)
64+
if r.status_code == 200:
65+
data = r.json()
66+
total = int(data.get("total_traces") or 0)
67+
print(f"traces: ok total_traces={total}")
68+
if total > 0:
69+
return
70+
elif r.status_code != 404 and r.status_code != 401:
71+
print(f"Traces fetch failed: {r.status_code} {r.text[:500]}", file=sys.stderr)
72+
sys.exit(1)
73+
sleep_s = min(2 ** (attempt - 1), 10)
74+
time.sleep(sleep_s)
75+
print("Traces not available after retries (indexing delay?)", file=sys.stderr)
76+
sys.exit(1)
77+
78+
79+
def validate_logs_endpoints(base_url: str, rollout_id: str) -> None:
80+
require_logs_endpoints(base_url)
81+
82+
# Ingest a structured log
83+
payload = {
84+
"program": "eval_protocol",
85+
"status": "completed",
86+
"message": "Remote validation run finished",
87+
"tags": [f"rollout_id:{rollout_id}", "experiment_id:remote", "run_id:test"],
88+
"metadata": {"dataset": "AIME"},
89+
"extras": {"num_examples": 3},
90+
}
91+
r = requests.post(f"{base_url}/logs", json=payload, timeout=30)
92+
if r.status_code != 200:
93+
print(f"logs ingest failed: {r.status_code} {r.text[:500]}", file=sys.stderr)
94+
sys.exit(1)
95+
print("logs ingest: ok")
96+
97+
# Retrieve logs (retry for indexing)
98+
params = {
99+
"tags": [f"rollout_id:{rollout_id}"],
100+
"program": "eval_protocol",
101+
"hours_back": 1,
102+
"limit": 10,
103+
}
104+
total = 0
105+
for attempt in range(1, 12):
106+
rr = requests.get(f"{base_url}/logs", params=params, timeout=30)
107+
if rr.status_code == 200:
108+
data = rr.json()
109+
total = int(data.get("total_entries") or 0)
110+
if total > 0:
111+
print(f"logs fetch: ok total_entries={total}")
112+
break
113+
sleep_s = min(2 ** (attempt - 1), 10)
114+
time.sleep(sleep_s)
115+
if total == 0:
116+
print("logs fetch: no entries found within retry window", file=sys.stderr)
117+
sys.exit(1)
118+
119+
120+
def main():
121+
base_url = require_env("GATEWAY_URL")
122+
api_key = require_env("FIREWORKS_API_KEY")
123+
rollout_id = f"r{int(time.time())}"
124+
125+
print(f"Gateway: {base_url}")
126+
print(f"Rollout: rollout_id:{rollout_id}")
127+
128+
post_chat_completion(base_url, api_key, rollout_id)
129+
wait_for_traces(base_url, api_key, rollout_id)
130+
validate_logs_endpoints(base_url, rollout_id)
131+
132+
print("remote validation: SUCCESS")
133+
134+
135+
if __name__ == "__main__":
136+
main()

scripts/verify_logging_locally.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,30 @@ def _now_rollout_id() -> str:
1515
return f"verify-{int(time.time())}"
1616

1717

18+
def _detect_gateway_base_url() -> str:
19+
# Prefer explicit FW_TRACING_GATEWAY_BASE_URL, else GATEWAY_URL, else public default
20+
return os.getenv("FW_TRACING_GATEWAY_BASE_URL") or os.getenv("GATEWAY_URL") or "https://tracing.fireworks.ai"
21+
22+
23+
def _detect_logs_endpoint(base_url: str) -> str:
24+
# Inspect OpenAPI and choose the correct logs endpoint
25+
try:
26+
import requests
27+
28+
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=5)
29+
if r.ok:
30+
paths = (r.json() or {}).get("paths", {})
31+
if any(p.startswith("/v1/logs") for p in paths.keys()):
32+
return "/v1/logs"
33+
if any(p.startswith("/logs") for p in paths.keys()):
34+
return "/logs"
35+
except Exception:
36+
pass
37+
return "/logs"
38+
39+
1840
def verify_fireworks(rollout_id: str) -> int:
19-
base_url = os.getenv("FW_TRACING_GATEWAY_BASE_URL") or "https://tracing.fireworks.ai"
41+
base_url = _detect_gateway_base_url()
2042
api_key = os.getenv("FIREWORKS_API_KEY")
2143
if not api_key:
2244
print("FIREWORKS_API_KEY not set; cannot verify Fireworks")
@@ -26,6 +48,20 @@ def verify_fireworks(rollout_id: str) -> int:
2648
root = logging.getLogger()
2749
root.setLevel(logging.INFO)
2850
init_external_logging_from_env()
51+
# Detect and use the correct logs endpoint
52+
logs_ep = _detect_logs_endpoint(base_url)
53+
# Print handler info for diagnostics
54+
handlers = [type(h).__name__ for h in root.handlers]
55+
print(
56+
json.dumps(
57+
{
58+
"gateway_url": base_url,
59+
"logs_endpoint": logs_ep,
60+
"root_handlers": handlers,
61+
}
62+
)
63+
)
64+
2965
logger = logging.getLogger("ep.verify.fireworks")
3066
for i in range(2):
3167
logger.info(
@@ -47,12 +83,22 @@ def verify_fireworks(rollout_id: str) -> int:
4783
"limit": 50,
4884
"hours_back": 6,
4985
}
50-
url = f"{base_url.rstrip('/')}/logs"
86+
candidate_eps = [logs_ep, "/v1/logs" if logs_ep != "/v1/logs" else "/logs"]
5187
for _ in range(20):
5288
try:
53-
r = requests.get(url, headers=headers, params=params, timeout=15)
54-
r.raise_for_status()
55-
data: Dict[str, Any] = r.json() or {}
89+
data: Dict[str, Any] = {}
90+
last_err: str | None = None
91+
for ep in candidate_eps:
92+
url = f"{base_url.rstrip('/')}{ep}"
93+
r = requests.get(url, headers=headers, params=params, timeout=15)
94+
if r.status_code == 404:
95+
last_err = f"404 for {ep}"
96+
continue
97+
r.raise_for_status()
98+
data = r.json() or {}
99+
break
100+
else:
101+
raise Exception(last_err or "all endpoints failed")
56102
entries: List[Dict[str, Any]] = data.get("entries", []) or []
57103
matched = [e for e in entries if any(t == f"rollout_id:{rollout_id}" for t in e.get("tags", []))]
58104
if matched:

0 commit comments

Comments
 (0)