Skip to content

Commit 97072d4

Browse files
author
Dylan Huang
committed
part 3
1 parent b17cf90 commit 97072d4

File tree

4 files changed

+255
-246
lines changed

4 files changed

+255
-246
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 246 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,13 @@
11
import asyncio
2-
import configparser
32
import functools
43
import inspect
54
import json
65
import math
76
import os
87
import pathlib
9-
import requests
108
import statistics
119
import time
1210
from collections import defaultdict
13-
from pathlib import Path
1411
from typing import Any, Callable
1512

1613
import pytest
@@ -29,6 +26,7 @@
2926
Message,
3027
Status,
3128
)
29+
from eval_protocol.pytest.handle_persist_flow import handle_persist_flow
3230
from eval_protocol.pytest.parameterize import pytest_parametrize
3331
from eval_protocol.pytest.validate_signature import validate_signature
3432
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
@@ -67,42 +65,6 @@
6765

6866
from ..common_utils import load_jsonl
6967

70-
from pytest import StashKey
71-
from typing_extensions import Literal
72-
73-
74-
EXPERIMENT_LINKS_STASH_KEY = StashKey[list]()
75-
76-
77-
def _store_experiment_link(experiment_id: str, job_link: str, status: Literal["success", "failure"]):
78-
"""Store experiment link in pytest session stash."""
79-
try:
80-
import sys
81-
82-
# Walk up the call stack to find the pytest session
83-
session = None
84-
frame = sys._getframe()
85-
while frame:
86-
if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"):
87-
session = frame.f_locals["session"]
88-
break
89-
frame = frame.f_back
90-
91-
if session is not None:
92-
global EXPERIMENT_LINKS_STASH_KEY
93-
94-
if EXPERIMENT_LINKS_STASH_KEY not in session.stash:
95-
session.stash[EXPERIMENT_LINKS_STASH_KEY] = []
96-
97-
session.stash[EXPERIMENT_LINKS_STASH_KEY].append(
98-
{"experiment_id": experiment_id, "job_link": job_link, "status": status}
99-
)
100-
else:
101-
pass
102-
103-
except Exception as e:
104-
pass
105-
10668

10769
def postprocess(
10870
all_results: list[list[EvaluationRow]],
@@ -254,193 +216,7 @@ def postprocess(
254216
# Do not fail evaluation if summary writing fails
255217
pass
256218

257-
try:
258-
# Default is to save and upload experiment JSONL files, unless explicitly disabled
259-
should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1"
260-
261-
if should_save_and_upload:
262-
current_run_rows = [item for sublist in all_results for item in sublist]
263-
if current_run_rows:
264-
experiments: Dict[str, List[EvaluationRow]] = defaultdict(list)
265-
for row in current_run_rows:
266-
if row.execution_metadata and row.execution_metadata.experiment_id:
267-
experiments[row.execution_metadata.experiment_id].append(row)
268-
269-
exp_dir = pathlib.Path("experiment_results")
270-
exp_dir.mkdir(parents=True, exist_ok=True)
271-
272-
# Create one JSONL file per experiment_id
273-
for experiment_id, exp_rows in experiments.items():
274-
if not experiment_id or not exp_rows:
275-
continue
276-
277-
# Generate dataset name (sanitize for Fireworks API compatibility)
278-
# API requires: lowercase a-z, 0-9, and hyphen (-) only
279-
safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower()
280-
safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower()
281-
dataset_name = f"{safe_test_func_name}-{safe_experiment_id}"
282-
283-
if len(dataset_name) > 63:
284-
dataset_name = dataset_name[:63]
285-
286-
exp_file = exp_dir / f"{experiment_id}.jsonl"
287-
with open(exp_file, "w", encoding="utf-8") as f:
288-
for row in exp_rows:
289-
row_data = row.model_dump(exclude_none=True, mode="json")
290-
291-
if row.evaluation_result:
292-
row_data["evals"] = {"score": row.evaluation_result.score}
293-
294-
row_data["eval_details"] = {
295-
"score": row.evaluation_result.score,
296-
"is_score_valid": row.evaluation_result.is_score_valid,
297-
"reason": row.evaluation_result.reason or "",
298-
"metrics": {
299-
name: metric.model_dump() if metric else {}
300-
for name, metric in (row.evaluation_result.metrics or {}).items()
301-
},
302-
}
303-
else:
304-
# Default values if no evaluation result
305-
row_data["evals"] = {"score": 0}
306-
row_data["eval_details"] = {
307-
"score": 0,
308-
"is_score_valid": True,
309-
"reason": "No evaluation result",
310-
"metrics": {},
311-
}
312-
313-
json.dump(row_data, f, ensure_ascii=False)
314-
f.write("\n")
315-
316-
def get_auth_value(key):
317-
"""Get auth value from config file or environment."""
318-
try:
319-
config_path = Path.home() / ".fireworks" / "auth.ini"
320-
if config_path.exists():
321-
config = configparser.ConfigParser()
322-
config.read(config_path)
323-
for section in ["DEFAULT", "auth"]:
324-
if config.has_section(section) and config.has_option(section, key):
325-
return config.get(section, key)
326-
except Exception:
327-
pass
328-
return os.getenv(key)
329-
330-
fireworks_api_key = get_auth_value("FIREWORKS_API_KEY")
331-
fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID")
332-
333-
if not fireworks_api_key and not fireworks_account_id:
334-
_store_experiment_link(
335-
experiment_id,
336-
"No Fireworks API key AND account ID found",
337-
"failure",
338-
)
339-
continue
340-
elif not fireworks_api_key:
341-
_store_experiment_link(
342-
experiment_id,
343-
"No Fireworks API key found",
344-
"failure",
345-
)
346-
continue
347-
elif not fireworks_account_id:
348-
_store_experiment_link(
349-
experiment_id,
350-
"No Fireworks account ID found",
351-
"failure",
352-
)
353-
continue
354-
355-
headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"}
356-
357-
# Make dataset first
358-
dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets"
359-
360-
dataset_payload = {
361-
"dataset": {
362-
"displayName": dataset_name,
363-
"evalProtocol": {},
364-
"format": "FORMAT_UNSPECIFIED",
365-
"exampleCount": f"{len(exp_rows)}",
366-
},
367-
"datasetId": dataset_name,
368-
}
369-
370-
dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers)
371-
372-
# Skip if dataset creation failed
373-
if dataset_response.status_code not in [200, 201]:
374-
_store_experiment_link(
375-
experiment_id,
376-
f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}",
377-
"failure",
378-
)
379-
continue
380-
381-
dataset_data = dataset_response.json()
382-
dataset_id = dataset_data.get("datasetId", dataset_name)
383-
384-
# Upload the JSONL file content
385-
upload_url = (
386-
f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
387-
)
388-
upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"}
389-
390-
with open(exp_file, "rb") as f:
391-
files = {"file": f}
392-
upload_response = requests.post(upload_url, files=files, headers=upload_headers)
393-
394-
# Skip if upload failed
395-
if upload_response.status_code not in [200, 201]:
396-
_store_experiment_link(
397-
experiment_id,
398-
f"File upload failed: {upload_response.status_code} {upload_response.text}",
399-
"failure",
400-
)
401-
continue
402-
403-
# Create evaluation job (optional - don't skip experiment if this fails)
404-
eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs"
405-
# Truncate job ID to fit 63 character limit
406-
job_id_base = f"{dataset_name}-job"
407-
if len(job_id_base) > 63:
408-
# Keep the "-job" suffix and truncate the dataset_name part
409-
max_dataset_name_len = 63 - 4 # 4 = len("-job")
410-
truncated_dataset_name = dataset_name[:max_dataset_name_len]
411-
job_id_base = f"{truncated_dataset_name}-job"
412-
413-
eval_job_payload = {
414-
"evaluationJobId": job_id_base,
415-
"evaluationJob": {
416-
"evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy",
417-
"inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy",
418-
"outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}",
419-
},
420-
}
421-
422-
eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers)
423-
424-
if eval_response.status_code in [200, 201]:
425-
eval_job_data = eval_response.json()
426-
job_id = eval_job_data.get("evaluationJobId", job_id_base)
427-
428-
_store_experiment_link(
429-
experiment_id,
430-
f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}",
431-
"success",
432-
)
433-
else:
434-
_store_experiment_link(
435-
experiment_id,
436-
f"Job creation failed: {eval_response.status_code} {eval_response.text}",
437-
"failure",
438-
)
439-
440-
except Exception as e:
441-
# Do not fail evaluation if experiment JSONL writing fails
442-
print(f"Warning: Failed to persist results: {e}")
443-
pass
219+
handle_persist_flow(all_results, test_func_name)
444220

445221
# Check threshold after logging
446222
if threshold is not None and not passed:
@@ -566,26 +342,15 @@ def decorator(
566342
validate_signature(sig, mode, completion_params)
567343

568344
# Calculate all possible combinations of parameters
569-
if mode == "groupwise":
570-
combinations = generate_parameter_combinations(
571-
input_dataset,
572-
completion_params,
573-
input_messages,
574-
input_rows,
575-
evaluation_test_kwargs,
576-
max_dataset_rows,
577-
combine_datasets,
578-
)
579-
else:
580-
combinations = generate_parameter_combinations(
581-
input_dataset,
582-
completion_params,
583-
input_messages,
584-
input_rows,
585-
evaluation_test_kwargs,
586-
max_dataset_rows,
587-
combine_datasets,
588-
)
345+
combinations = generate_parameter_combinations(
346+
input_dataset,
347+
completion_params,
348+
input_messages,
349+
input_rows,
350+
evaluation_test_kwargs,
351+
max_dataset_rows,
352+
combine_datasets,
353+
)
589354
if len(combinations) == 0:
590355
raise ValueError(
591356
"No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."

0 commit comments

Comments
 (0)