diff --git a/eval_protocol/pytest/handle_persist_flow.py b/eval_protocol/pytest/handle_persist_flow.py index ff386873..e2f2a93d 100644 --- a/eval_protocol/pytest/handle_persist_flow.py +++ b/eval_protocol/pytest/handle_persist_flow.py @@ -16,9 +16,10 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: str): try: # Default is to save and upload experiment JSONL files, unless explicitly disabled - should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1" + custom_output_dir = os.getenv("EP_OUTPUT_DIR") + should_save = os.getenv("EP_NO_UPLOAD") != "1" or custom_output_dir is not None - if should_save_and_upload: + if should_save: current_run_rows = [item for sublist in all_results for item in sublist] if current_run_rows: experiments: dict[str, list[EvaluationRow]] = defaultdict(list) @@ -27,6 +28,8 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: experiments[row.execution_metadata.experiment_id].append(row) eval_protocol_dir = find_eval_protocol_dir() + if custom_output_dir: + eval_protocol_dir = custom_output_dir exp_dir = pathlib.Path(eval_protocol_dir) / "experiment_results" exp_dir.mkdir(parents=True, exist_ok=True) @@ -81,6 +84,10 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: json.dump(row_data, f, ensure_ascii=False) f.write("\n") + should_upload = os.getenv("EP_NO_UPLOAD") != "1" + if not should_upload: + continue + def get_auth_value(key: str) -> str | None: """Get auth value from config file or environment.""" try: diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 030c367e..d0c4af4d 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -133,6 +133,11 @@ def pytest_addoption(parser) -> None: default=None, help=("If set, use this base URL for remote rollout processing. Example: http://localhost:8000"), ) + group.addoption( + "--ep-output-dir", + default=None, + help=("If set, save evaluation results to this directory in jsonl format."), + ) def _normalize_max_rows(val: Optional[str]) -> Optional[str]: @@ -258,6 +263,10 @@ def pytest_configure(config) -> None: if threshold_env is not None: os.environ["EP_PASSED_THRESHOLD"] = threshold_env + if config.getoption("--ep-output-dir"): + # set this to save eval results to the target dir in jsonl format + os.environ["EP_OUTPUT_DIR"] = config.getoption("--ep-output-dir") + if config.getoption("--ep-no-upload"): os.environ["EP_NO_UPLOAD"] = "1"