|
1 | 1 | import asyncio |
2 | | -import configparser |
3 | 2 | import functools |
4 | 3 | import inspect |
5 | 4 | import json |
6 | 5 | import math |
7 | 6 | import os |
8 | 7 | import pathlib |
9 | | -import requests |
10 | 8 | import statistics |
11 | 9 | import time |
12 | 10 | from collections import defaultdict |
13 | | -from pathlib import Path |
14 | 11 | from typing import Any, Callable |
15 | 12 |
|
16 | 13 | import pytest |
|
29 | 26 | Message, |
30 | 27 | Status, |
31 | 28 | ) |
| 29 | +from eval_protocol.pytest.handle_persist_flow import handle_persist_flow |
32 | 30 | from eval_protocol.pytest.parameterize import pytest_parametrize |
33 | 31 | from eval_protocol.pytest.validate_signature import validate_signature |
34 | 32 | from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter |
|
67 | 65 |
|
68 | 66 | from ..common_utils import load_jsonl |
69 | 67 |
|
70 | | -from pytest import StashKey |
71 | | -from typing_extensions import Literal |
72 | | - |
73 | | - |
74 | | -EXPERIMENT_LINKS_STASH_KEY = StashKey[list]() |
75 | | - |
76 | | - |
77 | | -def _store_experiment_link(experiment_id: str, job_link: str, status: Literal["success", "failure"]): |
78 | | - """Store experiment link in pytest session stash.""" |
79 | | - try: |
80 | | - import sys |
81 | | - |
82 | | - # Walk up the call stack to find the pytest session |
83 | | - session = None |
84 | | - frame = sys._getframe() |
85 | | - while frame: |
86 | | - if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"): |
87 | | - session = frame.f_locals["session"] |
88 | | - break |
89 | | - frame = frame.f_back |
90 | | - |
91 | | - if session is not None: |
92 | | - global EXPERIMENT_LINKS_STASH_KEY |
93 | | - |
94 | | - if EXPERIMENT_LINKS_STASH_KEY not in session.stash: |
95 | | - session.stash[EXPERIMENT_LINKS_STASH_KEY] = [] |
96 | | - |
97 | | - session.stash[EXPERIMENT_LINKS_STASH_KEY].append( |
98 | | - {"experiment_id": experiment_id, "job_link": job_link, "status": status} |
99 | | - ) |
100 | | - else: |
101 | | - pass |
102 | | - |
103 | | - except Exception as e: |
104 | | - pass |
105 | | - |
106 | 68 |
|
107 | 69 | def postprocess( |
108 | 70 | all_results: list[list[EvaluationRow]], |
@@ -254,193 +216,7 @@ def postprocess( |
254 | 216 | # Do not fail evaluation if summary writing fails |
255 | 217 | pass |
256 | 218 |
|
257 | | - try: |
258 | | - # Default is to save and upload experiment JSONL files, unless explicitly disabled |
259 | | - should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1" |
260 | | - |
261 | | - if should_save_and_upload: |
262 | | - current_run_rows = [item for sublist in all_results for item in sublist] |
263 | | - if current_run_rows: |
264 | | - experiments: Dict[str, List[EvaluationRow]] = defaultdict(list) |
265 | | - for row in current_run_rows: |
266 | | - if row.execution_metadata and row.execution_metadata.experiment_id: |
267 | | - experiments[row.execution_metadata.experiment_id].append(row) |
268 | | - |
269 | | - exp_dir = pathlib.Path("experiment_results") |
270 | | - exp_dir.mkdir(parents=True, exist_ok=True) |
271 | | - |
272 | | - # Create one JSONL file per experiment_id |
273 | | - for experiment_id, exp_rows in experiments.items(): |
274 | | - if not experiment_id or not exp_rows: |
275 | | - continue |
276 | | - |
277 | | - # Generate dataset name (sanitize for Fireworks API compatibility) |
278 | | - # API requires: lowercase a-z, 0-9, and hyphen (-) only |
279 | | - safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower() |
280 | | - safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower() |
281 | | - dataset_name = f"{safe_test_func_name}-{safe_experiment_id}" |
282 | | - |
283 | | - if len(dataset_name) > 63: |
284 | | - dataset_name = dataset_name[:63] |
285 | | - |
286 | | - exp_file = exp_dir / f"{experiment_id}.jsonl" |
287 | | - with open(exp_file, "w", encoding="utf-8") as f: |
288 | | - for row in exp_rows: |
289 | | - row_data = row.model_dump(exclude_none=True, mode="json") |
290 | | - |
291 | | - if row.evaluation_result: |
292 | | - row_data["evals"] = {"score": row.evaluation_result.score} |
293 | | - |
294 | | - row_data["eval_details"] = { |
295 | | - "score": row.evaluation_result.score, |
296 | | - "is_score_valid": row.evaluation_result.is_score_valid, |
297 | | - "reason": row.evaluation_result.reason or "", |
298 | | - "metrics": { |
299 | | - name: metric.model_dump() if metric else {} |
300 | | - for name, metric in (row.evaluation_result.metrics or {}).items() |
301 | | - }, |
302 | | - } |
303 | | - else: |
304 | | - # Default values if no evaluation result |
305 | | - row_data["evals"] = {"score": 0} |
306 | | - row_data["eval_details"] = { |
307 | | - "score": 0, |
308 | | - "is_score_valid": True, |
309 | | - "reason": "No evaluation result", |
310 | | - "metrics": {}, |
311 | | - } |
312 | | - |
313 | | - json.dump(row_data, f, ensure_ascii=False) |
314 | | - f.write("\n") |
315 | | - |
def get_auth_value(key):
    """Resolve an auth setting from ``~/.fireworks/auth.ini``, falling back to the environment.

    Looks for *key* in the ``[DEFAULT]`` and ``[auth]`` sections of the config
    file, then falls back to ``os.getenv(key)``.

    Note: ``ConfigParser.has_section`` never reports the DEFAULT section, so
    the original ``has_section(section) and has_option(...)`` guard silently
    skipped ``[DEFAULT]`` values.  We probe with ``has_option`` directly,
    which does honour the defaults section.

    Args:
        key: Option / environment-variable name, e.g. ``"FIREWORKS_API_KEY"``.

    Returns:
        The configured value, or ``os.getenv(key)`` (possibly ``None``) when
        the file is missing, unreadable, or does not define the key.
    """
    try:
        config_path = Path.home() / ".fireworks" / "auth.ini"
        if config_path.exists():
            config = configparser.ConfigParser()
            config.read(config_path)
            for section in ("DEFAULT", "auth"):
                # Guard non-DEFAULT sections first: has_option raises
                # NoSectionError for a missing regular section.
                if section != "DEFAULT" and not config.has_section(section):
                    continue
                if config.has_option(section, key):
                    return config.get(section, key)
    except Exception:
        # Best effort: a malformed config file must not break auth lookup.
        pass
    return os.getenv(key)
329 | | - |
330 | | - fireworks_api_key = get_auth_value("FIREWORKS_API_KEY") |
331 | | - fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID") |
332 | | - |
333 | | - if not fireworks_api_key and not fireworks_account_id: |
334 | | - _store_experiment_link( |
335 | | - experiment_id, |
336 | | - "No Fireworks API key AND account ID found", |
337 | | - "failure", |
338 | | - ) |
339 | | - continue |
340 | | - elif not fireworks_api_key: |
341 | | - _store_experiment_link( |
342 | | - experiment_id, |
343 | | - "No Fireworks API key found", |
344 | | - "failure", |
345 | | - ) |
346 | | - continue |
347 | | - elif not fireworks_account_id: |
348 | | - _store_experiment_link( |
349 | | - experiment_id, |
350 | | - "No Fireworks account ID found", |
351 | | - "failure", |
352 | | - ) |
353 | | - continue |
354 | | - |
355 | | - headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"} |
356 | | - |
357 | | - # Make dataset first |
358 | | - dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets" |
359 | | - |
360 | | - dataset_payload = { |
361 | | - "dataset": { |
362 | | - "displayName": dataset_name, |
363 | | - "evalProtocol": {}, |
364 | | - "format": "FORMAT_UNSPECIFIED", |
365 | | - "exampleCount": f"{len(exp_rows)}", |
366 | | - }, |
367 | | - "datasetId": dataset_name, |
368 | | - } |
369 | | - |
370 | | - dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers) |
371 | | - |
372 | | - # Skip if dataset creation failed |
373 | | - if dataset_response.status_code not in [200, 201]: |
374 | | - _store_experiment_link( |
375 | | - experiment_id, |
376 | | - f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}", |
377 | | - "failure", |
378 | | - ) |
379 | | - continue |
380 | | - |
381 | | - dataset_data = dataset_response.json() |
382 | | - dataset_id = dataset_data.get("datasetId", dataset_name) |
383 | | - |
384 | | - # Upload the JSONL file content |
385 | | - upload_url = ( |
386 | | - f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload" |
387 | | - ) |
388 | | - upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"} |
389 | | - |
390 | | - with open(exp_file, "rb") as f: |
391 | | - files = {"file": f} |
392 | | - upload_response = requests.post(upload_url, files=files, headers=upload_headers) |
393 | | - |
394 | | - # Skip if upload failed |
395 | | - if upload_response.status_code not in [200, 201]: |
396 | | - _store_experiment_link( |
397 | | - experiment_id, |
398 | | - f"File upload failed: {upload_response.status_code} {upload_response.text}", |
399 | | - "failure", |
400 | | - ) |
401 | | - continue |
402 | | - |
403 | | - # Create evaluation job (optional - don't skip experiment if this fails) |
404 | | - eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs" |
405 | | - # Truncate job ID to fit 63 character limit |
406 | | - job_id_base = f"{dataset_name}-job" |
407 | | - if len(job_id_base) > 63: |
408 | | - # Keep the "-job" suffix and truncate the dataset_name part |
409 | | - max_dataset_name_len = 63 - 4 # 4 = len("-job") |
410 | | - truncated_dataset_name = dataset_name[:max_dataset_name_len] |
411 | | - job_id_base = f"{truncated_dataset_name}-job" |
412 | | - |
413 | | - eval_job_payload = { |
414 | | - "evaluationJobId": job_id_base, |
415 | | - "evaluationJob": { |
416 | | - "evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy", |
417 | | - "inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy", |
418 | | - "outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}", |
419 | | - }, |
420 | | - } |
421 | | - |
422 | | - eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers) |
423 | | - |
424 | | - if eval_response.status_code in [200, 201]: |
425 | | - eval_job_data = eval_response.json() |
426 | | - job_id = eval_job_data.get("evaluationJobId", job_id_base) |
427 | | - |
428 | | - _store_experiment_link( |
429 | | - experiment_id, |
430 | | - f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}", |
431 | | - "success", |
432 | | - ) |
433 | | - else: |
434 | | - _store_experiment_link( |
435 | | - experiment_id, |
436 | | - f"Job creation failed: {eval_response.status_code} {eval_response.text}", |
437 | | - "failure", |
438 | | - ) |
439 | | - |
440 | | - except Exception as e: |
441 | | - # Do not fail evaluation if experiment JSONL writing fails |
442 | | - print(f"Warning: Failed to persist results: {e}") |
443 | | - pass |
| 219 | + handle_persist_flow(all_results, test_func_name) |
444 | 220 |
|
445 | 221 | # Check threshold after logging |
446 | 222 | if threshold is not None and not passed: |
@@ -566,26 +342,15 @@ def decorator( |
566 | 342 | validate_signature(sig, mode, completion_params) |
567 | 343 |
|
568 | 344 | # Calculate all possible combinations of parameters |
569 | | - if mode == "groupwise": |
570 | | - combinations = generate_parameter_combinations( |
571 | | - input_dataset, |
572 | | - completion_params, |
573 | | - input_messages, |
574 | | - input_rows, |
575 | | - evaluation_test_kwargs, |
576 | | - max_dataset_rows, |
577 | | - combine_datasets, |
578 | | - ) |
579 | | - else: |
580 | | - combinations = generate_parameter_combinations( |
581 | | - input_dataset, |
582 | | - completion_params, |
583 | | - input_messages, |
584 | | - input_rows, |
585 | | - evaluation_test_kwargs, |
586 | | - max_dataset_rows, |
587 | | - combine_datasets, |
588 | | - ) |
| 345 | + combinations = generate_parameter_combinations( |
| 346 | + input_dataset, |
| 347 | + completion_params, |
| 348 | + input_messages, |
| 349 | + input_rows, |
| 350 | + evaluation_test_kwargs, |
| 351 | + max_dataset_rows, |
| 352 | + combine_datasets, |
| 353 | + ) |
589 | 354 | if len(combinations) == 0: |
590 | 355 | raise ValueError( |
591 | 356 | "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows." |
|
0 commit comments