Skip to content

Commit 55c0aea

Browse files
committed
persist flow
1 parent d951083 commit 55c0aea

File tree

2 files changed

+252
-18
lines changed

2 files changed

+252
-18
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 210 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
import asyncio
2+
import configparser
23
import copy
4+
import functools
35
import inspect
46
import json
57
import math
68
import os
79
import pathlib
810
import re
11+
import requests
912
import statistics
1013
import time
1114
from dataclasses import replace
1215
from typing import Any, Callable, Dict, List, Literal, Optional, Union
1316
from collections import defaultdict
17+
from pathlib import Path
1418
import hashlib
1519
import ast
1620
from mcp.types import Completion
1721
import pytest
1822

23+
1924
from eval_protocol.dataset_logger import default_logger
2025
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
2126
from eval_protocol.human_id import generate_id, num_combinations
@@ -65,6 +70,41 @@
6570

6671
from ..common_utils import load_jsonl
6772

73+
from pytest import StashKey
74+
75+
76+
EXPERIMENT_LINKS_STASH_KEY = StashKey[list]()
77+
78+
79+
def _store_experiment_link(experiment_id: str, job_link: str, status: str) -> None:
    """Record an experiment link/status in the pytest session stash.

    Walks up the interpreter call stack looking for a frame that holds the
    pytest ``Session`` in a local named ``session`` (recognised by it having
    a ``stash`` attribute), then appends a record under
    ``EXPERIMENT_LINKS_STASH_KEY`` for ``pytest_sessionfinish`` to print.

    Args:
        experiment_id: Identifier of the experiment the link belongs to.
        job_link: Dashboard URL on success, or an error description on failure.
        status: ``"success"`` or ``"failure"``.

    Best effort: if no session frame is found, or anything raises, the link
    is silently dropped — evaluation must never fail because of reporting.
    """
    try:
        import sys

        # Walk up the call stack to find the pytest session object.
        session = None
        frame = sys._getframe()
        while frame is not None:
            candidate = frame.f_locals.get("session")
            if candidate is not None and hasattr(candidate, "stash"):
                session = candidate
                break
            frame = frame.f_back

        if session is None:
            # No pytest session in scope (e.g. direct function call) — drop.
            return

        if EXPERIMENT_LINKS_STASH_KEY not in session.stash:
            session.stash[EXPERIMENT_LINKS_STASH_KEY] = []
        session.stash[EXPERIMENT_LINKS_STASH_KEY].append(
            {"experiment_id": experiment_id, "job_link": job_link, "status": status}
        )
    except Exception:
        # Best-effort reporting; never propagate failures to the evaluation.
        pass
107+
68108

69109
def postprocess(
70110
all_results: List[List[EvaluationRow]],
@@ -213,22 +253,176 @@ def postprocess(
213253
# Do not fail evaluation if summary writing fails
214254
pass
215255

216-
# # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary
217-
# try:
218-
# if active_logger is not None:
219-
# rows = active_logger.read()
220-
# # Write to a .jsonl file alongside the summary file
221-
# jsonl_path = "logs.jsonl"
222-
# import json
223-
224-
# with open(jsonl_path, "w", encoding="utf-8") as f_jsonl:
225-
# for row in rows:
226-
# json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl)
227-
# f_jsonl.write("\n")
228-
# except Exception as e:
229-
# # Do not fail evaluation if log writing fails
230-
# print(e)
231-
# pass
256+
try:
257+
# Default is to save and upload experiment JSONL files, unless explicitly disabled
258+
should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1"
259+
260+
if should_save_and_upload:
261+
current_run_rows = [item for sublist in all_results for item in sublist]
262+
if current_run_rows:
263+
experiments: Dict[str, List[EvaluationRow]] = defaultdict(list)
264+
for row in current_run_rows:
265+
if row.execution_metadata and row.execution_metadata.experiment_id:
266+
experiments[row.execution_metadata.experiment_id].append(row)
267+
268+
exp_dir = pathlib.Path("experiment_results")
269+
exp_dir.mkdir(parents=True, exist_ok=True)
270+
271+
# Create one JSONL file per experiment_id
272+
for experiment_id, exp_rows in experiments.items():
273+
if not experiment_id or not exp_rows:
274+
continue
275+
276+
# Generate dataset name (sanitize for Fireworks API compatibility)
277+
# API requires: lowercase a-z, 0-9, and hyphen (-) only
278+
safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower()
279+
safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower()
280+
dataset_name = f"{safe_test_func_name}-{safe_experiment_id}"
281+
282+
exp_file = exp_dir / f"{experiment_id}.jsonl"
283+
with open(exp_file, "w", encoding="utf-8") as f:
284+
for row in exp_rows:
285+
row_data = row.model_dump(exclude_none=True, mode="json")
286+
287+
if row.evaluation_result:
288+
row_data["evals"] = {"score": row.evaluation_result.score}
289+
290+
row_data["eval_details"] = {
291+
"score": row.evaluation_result.score,
292+
"is_score_valid": row.evaluation_result.is_score_valid,
293+
"reason": row.evaluation_result.reason or "",
294+
"metrics": {
295+
name: metric.model_dump() if metric else {}
296+
for name, metric in (row.evaluation_result.metrics or {}).items()
297+
},
298+
}
299+
else:
300+
# Default values if no evaluation result
301+
row_data["evals"] = {"score": 0}
302+
row_data["eval_details"] = {
303+
"score": 0,
304+
"is_score_valid": False,
305+
"reason": "No evaluation result",
306+
"metrics": {},
307+
}
308+
309+
json.dump(row_data, f, ensure_ascii=False)
310+
f.write("\n")
311+
312+
def get_auth_value(key):
    """Return *key* from ``~/.fireworks/auth.ini``, falling back to the environment.

    Looks in the special DEFAULT section first, then in an ``[auth]``
    section. Any I/O or parse error is swallowed and the value of the
    environment variable named *key* (or ``None``) is returned instead.
    """
    try:
        config_path = Path.home() / ".fireworks" / "auth.ini"
        if config_path.exists():
            config = configparser.ConfigParser()
            config.read(config_path)
            # configparser.has_section() always returns False for the
            # special DEFAULT section, so its values must be read via
            # defaults() — the original has_section("DEFAULT") check
            # could never match.
            defaults = config.defaults()
            if key in defaults:
                return defaults[key]
            if config.has_section("auth") and config.has_option("auth", key):
                return config.get("auth", key)
    except Exception:
        # Best effort: any problem reading the file falls through to the env.
        pass
    return os.getenv(key)
325+
326+
fireworks_api_key = get_auth_value("FIREWORKS_API_KEY")
327+
fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID")
328+
329+
if fireworks_api_key and fireworks_account_id:
330+
headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"}
331+
332+
# Make dataset first
333+
dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets"
334+
335+
dataset_payload = {
336+
"dataset": {
337+
"displayName": dataset_name,
338+
"evalProtocol": {},
339+
"format": "FORMAT_UNSPECIFIED",
340+
},
341+
"datasetId": dataset_name,
342+
}
343+
344+
dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers)
345+
346+
# Skip if dataset creation failed
347+
if dataset_response.status_code not in [200, 201]:
348+
_store_experiment_link(
349+
experiment_id,
350+
f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}",
351+
"failure",
352+
)
353+
continue
354+
355+
dataset_data = dataset_response.json()
356+
dataset_id = dataset_data.get("datasetId", dataset_name)
357+
358+
# Upload the JSONL file content
359+
upload_url = (
360+
f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
361+
)
362+
upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"}
363+
364+
with open(exp_file, "rb") as f:
365+
files = {"file": f}
366+
upload_response = requests.post(upload_url, files=files, headers=upload_headers)
367+
368+
# Skip if upload failed
369+
if upload_response.status_code not in [200, 201]:
370+
_store_experiment_link(
371+
experiment_id,
372+
f"File upload failed: {upload_response.status_code} {upload_response.text}",
373+
"failure",
374+
)
375+
continue
376+
377+
# Create evaluation job (optional - don't skip experiment if this fails)
378+
eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs"
379+
# Truncate job ID to fit 63 character limit
380+
job_id_base = f"{dataset_name}-job"
381+
if len(job_id_base) > 63:
382+
# Keep the "-job" suffix and truncate the dataset_name part
383+
max_dataset_name_len = 63 - 4 # 4 = len("-job")
384+
truncated_dataset_name = dataset_name[:max_dataset_name_len]
385+
job_id_base = f"{truncated_dataset_name}-job"
386+
387+
eval_job_payload = {
388+
"evaluationJobId": job_id_base,
389+
"evaluationJob": {
390+
"evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy",
391+
"inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy",
392+
"outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}",
393+
},
394+
}
395+
396+
eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers)
397+
398+
if eval_response.status_code in [200, 201]:
399+
eval_job_data = eval_response.json()
400+
job_id = eval_job_data.get("evaluationJobId", job_id_base)
401+
402+
_store_experiment_link(
403+
experiment_id,
404+
f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}",
405+
"success",
406+
)
407+
else:
408+
_store_experiment_link(
409+
experiment_id,
410+
f"Job creation failed: {eval_response.status_code} {eval_response.text}",
411+
"failure",
412+
)
413+
414+
else:
415+
# Store failure for missing credentials for all experiments
416+
for experiment_id, exp_rows in experiments.items():
417+
if experiment_id and exp_rows:
418+
_store_experiment_link(
419+
experiment_id, "No Fireworks API key or account ID found", "failure"
420+
)
421+
422+
except Exception as e:
423+
# Do not fail evaluation if experiment JSONL writing fails
424+
print(f"Warning: Failed to persist results: {e}")
425+
pass
232426

233427
# Check threshold after logging
234428
if threshold is not None and not passed:
@@ -812,7 +1006,6 @@ def create_dual_mode_wrapper() -> Callable:
8121006
Returns:
8131007
A callable that can handle both pytest test execution and direct function calls
8141008
"""
815-
import asyncio
8161009

8171010
# Check if the test function is async
8181011
is_async = asyncio.iscoroutinefunction(test_func)
@@ -859,7 +1052,6 @@ async def dual_mode_wrapper(*args, **kwargs):
8591052
}
8601053

8611054
# Copy all attributes from the pytest wrapper to our dual mode wrapper
862-
import functools
8631055

8641056
functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)
8651057

eval_protocol/pytest/plugin.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import logging
1616
import os
1717
from typing import Optional
18+
import sys
19+
from pytest import StashKey
1820

1921

2022
def pytest_addoption(parser) -> None:
@@ -87,6 +89,15 @@ def pytest_addoption(parser) -> None:
8789
"Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts."
8890
),
8991
)
92+
group.addoption(
93+
"--ep-no-upload",
94+
action="store_true",
95+
default=False,
96+
help=(
97+
"Disable saving and uploading of detailed experiment JSON files to Fireworks. "
98+
"Default: false (experiment JSONs are saved and uploaded by default)."
99+
),
100+
)
90101

91102

92103
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -161,6 +172,9 @@ def pytest_configure(config) -> None:
161172
if fail_on_max_retry is not None:
162173
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
163174

175+
if config.getoption("--ep-no-upload"):
176+
os.environ["EP_NO_UPLOAD"] = "1"
177+
164178
# Allow ad-hoc overrides of input params via CLI flags
165179
try:
166180
import json as _json
@@ -198,3 +212,31 @@ def pytest_configure(config) -> None:
198212
except Exception:
199213
# best effort, do not crash pytest session
200214
pass
215+
216+
217+
def pytest_sessionfinish(session, exitstatus):
    """Print all Fireworks experiment links collected during the run.

    Reads the records appended by ``_store_experiment_link`` from the
    session stash and writes a summary banner to the real stderr
    (``sys.__stderr__``) so it is visible even under pytest's output
    capture. Best effort: any failure is swallowed so summary printing
    can never break the test session.
    """
    try:
        # Imported lazily to avoid a hard dependency at plugin import time.
        from .evaluation_test import EXPERIMENT_LINKS_STASH_KEY

        links = session.stash.get(EXPERIMENT_LINKS_STASH_KEY, [])
        if not links:
            return

        err = sys.__stderr__
        print("\n" + "=" * 80, file=err)
        print("🔥 FIREWORKS EXPERIMENT LINKS", file=err)
        print("=" * 80, file=err)
        for link in links:
            # Success entries carry a dashboard URL; failures carry a reason.
            icon = "🔗" if link["status"] == "success" else "❌"
            print(f"{icon} Experiment {link['experiment_id']}: {link['job_link']}", file=err)
        print("=" * 80, file=err)
        err.flush()
    except Exception:
        # Best effort — never fail the session over link reporting.
        pass

0 commit comments

Comments
 (0)