Skip to content

Commit 58dd9d2

Browse files
authored
Persisting Flow onto Fireworks (#122)
* Persist flow
* Update
* Add Literal type
1 parent 5d64cda commit 58dd9d2

File tree

2 files changed

+257
-18
lines changed

2 files changed

+257
-18
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 215 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
import asyncio
2+
import configparser
23
import copy
4+
import functools
35
import inspect
46
import json
57
import math
68
import os
79
import pathlib
810
import re
11+
import requests
912
import statistics
1013
import time
1114
from dataclasses import replace
1215
from typing import Any, Callable, Dict, List, Literal, Optional, Union
1316
from collections import defaultdict
17+
from pathlib import Path
1418
import hashlib
1519
import ast
1620
from mcp.types import Completion
1721
import pytest
1822

23+
1924
from eval_protocol.dataset_logger import default_logger
2025
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
2126
from eval_protocol.human_id import generate_id, num_combinations
@@ -66,6 +71,42 @@
6671

6772
from ..common_utils import load_jsonl
6873

74+
from pytest import StashKey
75+
from typing_extensions import Literal
76+
77+
78+
EXPERIMENT_LINKS_STASH_KEY = StashKey[list]()
79+
80+
81+
def _store_experiment_link(experiment_id: str, job_link: str, status: Literal["success", "failure"]):
82+
"""Store experiment link in pytest session stash."""
83+
try:
84+
import sys
85+
86+
# Walk up the call stack to find the pytest session
87+
session = None
88+
frame = sys._getframe()
89+
while frame:
90+
if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"):
91+
session = frame.f_locals["session"]
92+
break
93+
frame = frame.f_back
94+
95+
if session is not None:
96+
global EXPERIMENT_LINKS_STASH_KEY
97+
98+
if EXPERIMENT_LINKS_STASH_KEY not in session.stash:
99+
session.stash[EXPERIMENT_LINKS_STASH_KEY] = []
100+
101+
session.stash[EXPERIMENT_LINKS_STASH_KEY].append(
102+
{"experiment_id": experiment_id, "job_link": job_link, "status": status}
103+
)
104+
else:
105+
pass
106+
107+
except Exception as e:
108+
pass
109+
69110

70111
def postprocess(
71112
all_results: List[List[EvaluationRow]],
@@ -214,22 +255,180 @@ def postprocess(
214255
# Do not fail evaluation if summary writing fails
215256
pass
216257

217-
# # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary
218-
# try:
219-
# if active_logger is not None:
220-
# rows = active_logger.read()
221-
# # Write to a .jsonl file alongside the summary file
222-
# jsonl_path = "logs.jsonl"
223-
# import json
224-
225-
# with open(jsonl_path, "w", encoding="utf-8") as f_jsonl:
226-
# for row in rows:
227-
# json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl)
228-
# f_jsonl.write("\n")
229-
# except Exception as e:
230-
# # Do not fail evaluation if log writing fails
231-
# print(e)
232-
# pass
258+
try:
259+
# Default is to save and upload experiment JSONL files, unless explicitly disabled
260+
should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1"
261+
262+
if should_save_and_upload:
263+
current_run_rows = [item for sublist in all_results for item in sublist]
264+
if current_run_rows:
265+
experiments: Dict[str, List[EvaluationRow]] = defaultdict(list)
266+
for row in current_run_rows:
267+
if row.execution_metadata and row.execution_metadata.experiment_id:
268+
experiments[row.execution_metadata.experiment_id].append(row)
269+
270+
exp_dir = pathlib.Path("experiment_results")
271+
exp_dir.mkdir(parents=True, exist_ok=True)
272+
273+
# Create one JSONL file per experiment_id
274+
for experiment_id, exp_rows in experiments.items():
275+
if not experiment_id or not exp_rows:
276+
continue
277+
278+
# Generate dataset name (sanitize for Fireworks API compatibility)
279+
# API requires: lowercase a-z, 0-9, and hyphen (-) only
280+
safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower()
281+
safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower()
282+
dataset_name = f"{safe_test_func_name}-{safe_experiment_id}"
283+
284+
if len(dataset_name) > 63:
285+
dataset_name = dataset_name[:63]
286+
287+
exp_file = exp_dir / f"{experiment_id}.jsonl"
288+
with open(exp_file, "w", encoding="utf-8") as f:
289+
for row in exp_rows:
290+
row_data = row.model_dump(exclude_none=True, mode="json")
291+
292+
if row.evaluation_result:
293+
row_data["evals"] = {"score": row.evaluation_result.score}
294+
295+
row_data["eval_details"] = {
296+
"score": row.evaluation_result.score,
297+
"is_score_valid": row.evaluation_result.is_score_valid,
298+
"reason": row.evaluation_result.reason or "",
299+
"metrics": {
300+
name: metric.model_dump() if metric else {}
301+
for name, metric in (row.evaluation_result.metrics or {}).items()
302+
},
303+
}
304+
else:
305+
# Default values if no evaluation result
306+
row_data["evals"] = {"score": 0}
307+
row_data["eval_details"] = {
308+
"score": 0,
309+
"is_score_valid": True,
310+
"reason": "No evaluation result",
311+
"metrics": {},
312+
}
313+
314+
json.dump(row_data, f, ensure_ascii=False)
315+
f.write("\n")
316+
317+
def get_auth_value(key):
318+
"""Get auth value from config file or environment."""
319+
try:
320+
config_path = Path.home() / ".fireworks" / "auth.ini"
321+
if config_path.exists():
322+
config = configparser.ConfigParser()
323+
config.read(config_path)
324+
for section in ["DEFAULT", "auth"]:
325+
if config.has_section(section) and config.has_option(section, key):
326+
return config.get(section, key)
327+
except Exception:
328+
pass
329+
return os.getenv(key)
330+
331+
fireworks_api_key = get_auth_value("FIREWORKS_API_KEY")
332+
fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID")
333+
334+
if fireworks_api_key and fireworks_account_id:
335+
headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"}
336+
337+
# Make dataset first
338+
dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets"
339+
340+
dataset_payload = {
341+
"dataset": {
342+
"displayName": dataset_name,
343+
"evalProtocol": {},
344+
"format": "FORMAT_UNSPECIFIED",
345+
"exampleCount": f"{len(exp_rows)}",
346+
},
347+
"datasetId": dataset_name,
348+
}
349+
350+
dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers)
351+
352+
# Skip if dataset creation failed
353+
if dataset_response.status_code not in [200, 201]:
354+
_store_experiment_link(
355+
experiment_id,
356+
f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}",
357+
"failure",
358+
)
359+
continue
360+
361+
dataset_data = dataset_response.json()
362+
dataset_id = dataset_data.get("datasetId", dataset_name)
363+
364+
# Upload the JSONL file content
365+
upload_url = (
366+
f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
367+
)
368+
upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"}
369+
370+
with open(exp_file, "rb") as f:
371+
files = {"file": f}
372+
upload_response = requests.post(upload_url, files=files, headers=upload_headers)
373+
374+
# Skip if upload failed
375+
if upload_response.status_code not in [200, 201]:
376+
_store_experiment_link(
377+
experiment_id,
378+
f"File upload failed: {upload_response.status_code} {upload_response.text}",
379+
"failure",
380+
)
381+
continue
382+
383+
# Create evaluation job (optional - don't skip experiment if this fails)
384+
eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs"
385+
# Truncate job ID to fit 63 character limit
386+
job_id_base = f"{dataset_name}-job"
387+
if len(job_id_base) > 63:
388+
# Keep the "-job" suffix and truncate the dataset_name part
389+
max_dataset_name_len = 63 - 4 # 4 = len("-job")
390+
truncated_dataset_name = dataset_name[:max_dataset_name_len]
391+
job_id_base = f"{truncated_dataset_name}-job"
392+
393+
eval_job_payload = {
394+
"evaluationJobId": job_id_base,
395+
"evaluationJob": {
396+
"evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy",
397+
"inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy",
398+
"outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}",
399+
},
400+
}
401+
402+
eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers)
403+
404+
if eval_response.status_code in [200, 201]:
405+
eval_job_data = eval_response.json()
406+
job_id = eval_job_data.get("evaluationJobId", job_id_base)
407+
408+
_store_experiment_link(
409+
experiment_id,
410+
f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}",
411+
"success",
412+
)
413+
else:
414+
_store_experiment_link(
415+
experiment_id,
416+
f"Job creation failed: {eval_response.status_code} {eval_response.text}",
417+
"failure",
418+
)
419+
420+
else:
421+
# Store failure for missing credentials for all experiments
422+
for experiment_id, exp_rows in experiments.items():
423+
if experiment_id and exp_rows:
424+
_store_experiment_link(
425+
experiment_id, "No Fireworks API key or account ID found", "failure"
426+
)
427+
428+
except Exception as e:
429+
# Do not fail evaluation if experiment JSONL writing fails
430+
print(f"Warning: Failed to persist results: {e}")
431+
pass
233432

234433
# Check threshold after logging
235434
if threshold is not None and not passed:
@@ -814,7 +1013,6 @@ def create_dual_mode_wrapper() -> Callable:
8141013
Returns:
8151014
A callable that can handle both pytest test execution and direct function calls
8161015
"""
817-
import asyncio
8181016

8191017
# Check if the test function is async
8201018
is_async = asyncio.iscoroutinefunction(test_func)
@@ -861,7 +1059,6 @@ async def dual_mode_wrapper(*args, **kwargs):
8611059
}
8621060

8631061
# Copy all attributes from the pytest wrapper to our dual mode wrapper
864-
import functools
8651062

8661063
functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)
8671064

eval_protocol/pytest/plugin.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from typing import Optional
1818
import json
1919
import pathlib
20+
import sys
21+
from pytest import StashKey
2022

2123

2224
def pytest_addoption(parser) -> None:
@@ -104,6 +106,15 @@ def pytest_addoption(parser) -> None:
104106
"Pass a float >= 0.0 (e.g., 0.05). If only this is set, success threshold defaults to 0.0."
105107
),
106108
)
109+
group.addoption(
110+
"--ep-no-upload",
111+
action="store_true",
112+
default=False,
113+
help=(
114+
"Disable saving and uploading of detailed experiment JSON files to Fireworks. "
115+
"Default: false (experiment JSONs are saved and uploaded by default)."
116+
),
117+
)
107118

108119

109120
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -229,6 +240,9 @@ def pytest_configure(config) -> None:
229240
if threshold_env is not None:
230241
os.environ["EP_PASSED_THRESHOLD"] = threshold_env
231242

243+
if config.getoption("--ep-no-upload"):
244+
os.environ["EP_NO_UPLOAD"] = "1"
245+
232246
# Allow ad-hoc overrides of input params via CLI flags
233247
try:
234248
merged: dict = {}
@@ -263,3 +277,31 @@ def pytest_configure(config) -> None:
263277
except Exception:
264278
# best effort, do not crash pytest session
265279
pass
280+
281+
282+
def pytest_sessionfinish(session, exitstatus):
283+
"""Print all collected Fireworks experiment links from pytest stash."""
284+
try:
285+
from .evaluation_test import EXPERIMENT_LINKS_STASH_KEY
286+
287+
# Get links from pytest stash using shared key
288+
links = []
289+
290+
if EXPERIMENT_LINKS_STASH_KEY in session.stash:
291+
links = session.stash[EXPERIMENT_LINKS_STASH_KEY]
292+
293+
if links:
294+
print("\n" + "=" * 80, file=sys.__stderr__)
295+
print("🔥 FIREWORKS EXPERIMENT LINKS", file=sys.__stderr__)
296+
print("=" * 80, file=sys.__stderr__)
297+
298+
for link in links:
299+
if link["status"] == "success":
300+
print(f"🔗 Experiment {link['experiment_id']}: {link['job_link']}", file=sys.__stderr__)
301+
else:
302+
print(f"❌ Experiment {link['experiment_id']}: {link['job_link']}", file=sys.__stderr__)
303+
304+
print("=" * 80, file=sys.__stderr__)
305+
sys.__stderr__.flush()
306+
except Exception as e:
307+
pass

0 commit comments

Comments (0)