diff --git a/docs/docs/evaluation.rst b/docs/docs/evaluation.rst index 8b98b15..fa60e38 100644 --- a/docs/docs/evaluation.rst +++ b/docs/docs/evaluation.rst @@ -36,15 +36,13 @@ Evaluate from a list of Python dicts: report = evaluate(predictions) print(report) -Providing your own DB and questions (skip workdir): +Using a persistent cache directory for benchmark downloads: .. code-block:: python report = evaluate( "path_to_outputs.jsonl", - questions_path="bench/questions.jsonl", - db_path="bench/sqlite_tables.db", - workdir_path=None + workdir_path="./benchmark-cache", ) Function Arguments @@ -59,11 +57,7 @@ Function Arguments * - outputs - Path to JSONL file or a list of prediction dicts (required). * - workdir_path - - Directory for automatic benchmark downloads. Ignored if both questions_path and db_path are provided. Default: "llmsql_workdir". - * - questions_path - - Optional path to benchmark questions JSONL file. - * - db_path - - Optional path to SQLite DB with evaluation tables. + - Directory used to cache downloaded benchmark files. If omitted, a temporary directory is created automatically. * - save_report - Path to save detailed JSON report. Defaults to "evaluation_results_{uuid}.json". * - show_mismatches diff --git a/docs/docs/index.rst b/docs/docs/index.rst index b2760cd..fe3654c 100644 --- a/docs/docs/index.rst +++ b/docs/docs/index.rst @@ -36,8 +36,6 @@ Example: Running your first evaluation (with transformers backend) results = inference_transformers( model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct", output_file="outputs/preds_transformers.jsonl", - questions_path="data/questions.jsonl", - tables_path="data/tables.jsonl", num_fewshots=5, batch_size=8, max_new_tokens=256, diff --git a/docs/docs/usage.rst b/docs/docs/usage.rst index b8966fe..767fad6 100644 --- a/docs/docs/usage.rst +++ b/docs/docs/usage.rst @@ -27,8 +27,7 @@ Using transformers backend. 
results = inference_transformers( model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct", output_file="outputs/preds_transformers.jsonl", - questions_path="data/questions.jsonl", - tables_path="data/tables.jsonl", + workdir_path="./benchmark-cache", num_fewshots=5, batch_size=8, max_new_tokens=256, @@ -58,8 +57,7 @@ Using vllm backend. results = inference_vllm( model_name="Qwen/Qwen2.5-1.5B-Instruct", output_file="outputs/preds_vllm.jsonl", - questions_path="data/questions.jsonl", - tables_path="data/tables.jsonl", + workdir_path="./benchmark-cache", num_fewshots=5, batch_size=8, max_new_tokens=256, diff --git a/llmsql/_cli/evaluate.py b/llmsql/_cli/evaluate.py index 942e4fb..bec8c0e 100644 --- a/llmsql/_cli/evaluate.py +++ b/llmsql/_cli/evaluate.py @@ -2,7 +2,7 @@ from typing import Any from llmsql._cli.subparsers import SubCommand -from llmsql.config.config import DEFAULT_LLMSQL_VERSION, DEFAULT_WORKDIR_PATH +from llmsql.config.config import DEFAULT_LLMSQL_VERSION from llmsql.evaluation.evaluate import evaluate @@ -10,22 +10,18 @@ class Evaluate(SubCommand): """Command for LLM evaluation""" def __init__( - self, - subparsers:argparse._SubParsersAction, - *args:Any, - **kwargs:Any - )->None: + self, subparsers: argparse._SubParsersAction, *args: Any, **kwargs: Any + ) -> None: self._parser = subparsers.add_parser( "evaluate", - help = "Evaluate predictions against the LLMSQL benchmark", + help="Evaluate predictions against the LLMSQL benchmark", formatter_class=argparse.RawDescriptionHelpFormatter, ) self._add_args() - self._parser.set_defaults(func = self._execute) + self._parser.set_defaults(func=self._execute) - - def _add_args(self)->None: + def _add_args(self) -> None: """Add evaluation-specific arguments to the parser.""" self._parser.add_argument( "--outputs", @@ -36,30 +32,15 @@ def _add_args(self)->None: self._parser.add_argument( "--version", - type = str, - default = DEFAULT_LLMSQL_VERSION, - choices=["1.0","2.0"], - help = f"LLMSQL benchmark version 
(default:{DEFAULT_LLMSQL_VERSION})" + type=str, + default=DEFAULT_LLMSQL_VERSION, + choices=["1.0", "2.0"], + help=f"LLMSQL benchmark version (default:{DEFAULT_LLMSQL_VERSION})", ) self._parser.add_argument( "--workdir-path", - default = DEFAULT_WORKDIR_PATH, - help = f"Directory for benchmark files (default: {DEFAULT_WORKDIR_PATH})", - ) - - self._parser.add_argument( - "--questions-path", - type = str, - default = None, - help = "Manual path to benchmark questions JSON file.", - ) - - self._parser.add_argument( - "--db-path", - type=str, - default = None, - help = "Path to SQLite benchmark database.", + help="Directory for benchmark downloads. If omitted, a temporary directory is used.", ) self._parser.add_argument( @@ -82,21 +63,18 @@ def _add_args(self)->None: default=None, help="Path to save evaluation report JSON.", ) - + @staticmethod def _execute(args: argparse.Namespace) -> None: """Execute the evaluate function with parsed arguments.""" try: evaluate( outputs=args.outputs, - version = args.version, + version=args.version, workdir_path=args.workdir_path, - questions_path=args.questions_path, - db_path=args.db_path, save_report=args.save_report, show_mismatches=args.show_mismatches, max_mismatches=args.max_mismatches, ) except Exception as e: print(f"Error during evaluation: {e}") - diff --git a/llmsql/_cli/inference.py b/llmsql/_cli/inference.py index 57aaef2..e4355f7 100644 --- a/llmsql/_cli/inference.py +++ b/llmsql/_cli/inference.py @@ -56,9 +56,11 @@ def _add_args(self) -> None: def add_common_benchmark_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--version", default="2.0", choices=["1.0", "2.0"]) parser.add_argument("--output-file", default="llm_sql_predictions.jsonl") - parser.add_argument("--questions-path") - parser.add_argument("--tables-path") - parser.add_argument("--workdir-path", default="./workdir") + parser.add_argument( + "--workdir-path", + default=None, + help="Directory for benchmark downloads. 
If omitted, a temporary directory is used.", + ) parser.add_argument("--num-fewshots", type=int, default=5) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--seed", type=int, default=42) @@ -217,8 +219,6 @@ def _execute_transformers(args: argparse.Namespace) -> None: generation_kwargs=args.generation_kwargs, version=args.version, output_file=args.output_file, - questions_path=args.questions_path, - tables_path=args.tables_path, workdir_path=args.workdir_path, num_fewshots=args.num_fewshots, batch_size=args.batch_size, @@ -243,8 +243,6 @@ def _execute_vllm(args: argparse.Namespace) -> None: sampling_kwargs=args.sampling_kwargs, version=args.version, output_file=args.output_file, - questions_path=args.questions_path, - tables_path=args.tables_path, workdir_path=args.workdir_path, limit=args.limit, num_fewshots=args.num_fewshots, @@ -267,8 +265,6 @@ def _execute_api(args: argparse.Namespace) -> None: request_headers=args.request_headers, version=args.version, output_file=args.output_file, - questions_path=args.questions_path, - tables_path=args.tables_path, workdir_path=args.workdir_path, limit=args.limit, num_fewshots=args.num_fewshots, diff --git a/llmsql/config/config.py b/llmsql/config/config.py index 7dfb693..5654481 100644 --- a/llmsql/config/config.py +++ b/llmsql/config/config.py @@ -6,7 +6,6 @@ } DEFAULT_LLMSQL_VERSION: Literal["1.0", "2.0"] = "2.0" -DEFAULT_WORKDIR_PATH = "llmsql_workdir" def get_repo_id(version: str = DEFAULT_LLMSQL_VERSION) -> str: diff --git a/llmsql/evaluation/README.md b/llmsql/evaluation/README.md index 4f727c6..419bc3e 100644 --- a/llmsql/evaluation/README.md +++ b/llmsql/evaluation/README.md @@ -43,9 +43,7 @@ print(report) evaluate( outputs, *, - workdir_path: str | None = "llmsql_workdir", - questions_path: str | None = None, - db_path: str | None = None, + workdir_path: str | None = None, save_report: str | None = None, show_mismatches: bool = True, max_mismatches: int = 5, @@ -55,9 +53,7 @@ evaluate( 
| Argument | Description | | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- | | `outputs` | **Required**. Either a path to a JSONL file or a list of dicts with predictions. | -| `workdir_path` | Directory for automatic download of benchmark files (ignored if both `questions_path` and `db_path` are provided). Default: `"llmsql_workdir"`. | -| `questions_path` | Optional path to benchmark questions JSONL file. | -| `db_path` | Optional path to SQLite DB with evaluation tables. | +| `workdir_path` | Directory used to cache downloaded benchmark files. If omitted, a temporary directory is created automatically. | | `save_report` | Optional path to save detailed JSON report. Defaults to `evaluation_results_{uuid}.json`. | | `show_mismatches` | Print mismatches while evaluating. Default: `True`. | | `max_mismatches` | Maximum number of mismatches to print. Default: `5`. | diff --git a/llmsql/evaluation/evaluate.py b/llmsql/evaluation/evaluate.py index 9c68a8f..1fe79c1 100644 --- a/llmsql/evaluation/evaluate.py +++ b/llmsql/evaluation/evaluate.py @@ -9,21 +9,19 @@ """ from datetime import datetime, timezone -from pathlib import Path import uuid from rich.progress import track from llmsql.config.config import ( DEFAULT_LLMSQL_VERSION, - DEFAULT_WORKDIR_PATH, get_repo_id, ) from llmsql.utils.evaluation_utils import ( connect_sqlite, - download_benchmark_file, evaluate_sample, ) +from llmsql.utils.inference_utils import _maybe_download, resolve_workdir_path from llmsql.utils.rich_utils import log_mismatch, print_summary from llmsql.utils.utils import load_jsonl, load_jsonl_dict_by_key, save_json_report @@ -32,9 +30,7 @@ def evaluate( outputs: str | list[dict[int, str | int]], *, version: str = DEFAULT_LLMSQL_VERSION, - workdir_path: str | None = DEFAULT_WORKDIR_PATH, - questions_path: str | None = None, - db_path: str | None = None, + workdir_path: str | None = 
None, save_report: str | None = None, show_mismatches: bool = True, max_mismatches: int = 5, @@ -45,9 +41,8 @@ def evaluate( Args: version: LLMSQL version outputs: Either a JSONL file path or a list of dicts. - workdir_path: Directory for auto-downloads (ignored if all paths provided). - questions_path: Manual path to benchmark questions JSONL. - db_path: Manual path to SQLite benchmark DB. + workdir_path: Directory to store downloaded benchmark files. If omitted, a + temporary directory is created automatically. save_report: Optional manual save path. If None → auto-generated. show_mismatches: Print mismatches while evaluating. max_mismatches: Max mismatches to print. @@ -58,39 +53,12 @@ def evaluate( # Determine input type input_mode = "jsonl_path" if isinstance(outputs, str) else "dict_list" + workdir = resolve_workdir_path(workdir_path) repo_id = get_repo_id(version) - # --- Resolve inputs if needed --- - workdir = Path(workdir_path) if workdir_path else None - if workdir_path is not None and (questions_path is None or db_path is None): - workdir.mkdir(parents=True, exist_ok=True) # type: ignore - - if questions_path is None: - if workdir is None: - raise ValueError( - "questions_path not provided, and workdir_path disabled. " - "Enable workdir or provide questions_path explicitly." - ) - local_q = workdir / "questions.jsonl" - questions_path = ( - str(local_q) - if local_q.is_file() - else download_benchmark_file(repo_id, "questions.jsonl", workdir) - ) - - if db_path is None: - if workdir is None: - raise ValueError( - "db_path not provided, and workdir_path disabled. " - "Enable workdir or provide db_path explicitly." 
- ) - local_db = workdir / "sqlite_tables.db" - db_path = ( - str(local_db) - if local_db.is_file() - else download_benchmark_file(repo_id, "sqlite_tables.db", workdir) - ) + questions_path = _maybe_download(repo_id, "questions.jsonl", workdir) + db_path = _maybe_download(repo_id, "sqlite_tables.db", workdir) # --- Load benchmark questions --- questions = load_jsonl_dict_by_key(questions_path, key="question_id") diff --git a/llmsql/inference/README.md b/llmsql/inference/README.md index 096d1b5..1aba3f7 100644 --- a/llmsql/inference/README.md +++ b/llmsql/inference/README.md @@ -36,8 +36,7 @@ from llmsql import inference_transformers results = inference_transformers( model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct", output_file="outputs/preds_transformers.jsonl", - questions_path="data/questions.jsonl", - tables_path="data/tables.jsonl", + workdir_path="./benchmark-cache", num_fewshots=5, batch_size=8, max_new_tokens=256, @@ -168,9 +167,7 @@ Runs inference using the Hugging Face `transformers` backend. | Argument | Type | Default | Description | | ------------------------------- | ------- | ------------------------- | ------------------------------------------------ | | `output_file` | `str` | `"outputs/predictions.jsonl"` | Path to write predictions as JSONL. | -| `questions_path` | `str \| None` | `None` | Path to questions.jsonl (auto-downloads if missing). | -| `tables_path` | `str \| None` | `None` | Path to tables.jsonl (auto-downloads if missing). | -| `workdir_path` | `str` | `"llmsql_workdir"` | Working directory for downloaded files. | +| `workdir_path` | `str \| None` | `None` | Directory used to cache downloaded benchmark files. If omitted, a temporary directory is created automatically. | | `num_fewshots` | `int` | `5` | Number of few-shot examples (0, 1, or 5). | | `batch_size` | `int` | `8` | Batch size for inference. | | `seed` | `int` | `42` | Random seed for reproducibility. 
| @@ -210,9 +207,7 @@ Runs inference using the [vLLM](https://github.com/vllm-project/vllm) backend fo | Argument | Type | Default | Description | | ------------------------------- | -------------- | ----------------------------- | ------------------------------------------------ | | `output_file` | `str` | `"outputs/predictions.jsonl"` | Path to write predictions as JSONL. | -| `questions_path` | `str \| None` | `None` | Path to questions.jsonl (auto-downloads if missing). | -| `tables_path` | `str \| None` | `None` | Path to tables.jsonl (auto-downloads if missing). | -| `workdir_path` | `str` | `"llmsql_workdir"` | Working directory for downloaded files. | +| `workdir_path` | `str \| None` | `None` | Directory used to cache downloaded benchmark files. If omitted, a temporary directory is created automatically. | | `num_fewshots` | `int` | `5` | Number of few-shot examples (0, 1, or 5). | | `batch_size` | `int` | `8` | Number of prompts per batch. | | `seed` | `int` | `42` | Random seed for reproducibility. 
| diff --git a/llmsql/inference/inference_api.py b/llmsql/inference/inference_api.py index c01f0f9..86f8d7f 100644 --- a/llmsql/inference/inference_api.py +++ b/llmsql/inference/inference_api.py @@ -10,7 +10,6 @@ import asyncio import os -from pathlib import Path import time from typing import Any, Literal @@ -21,11 +20,14 @@ from llmsql.config.config import ( DEFAULT_LLMSQL_VERSION, - DEFAULT_WORKDIR_PATH, get_repo_id, ) from llmsql.loggers.logging_config import log -from llmsql.utils.inference_utils import _maybe_download, _setup_seed +from llmsql.utils.inference_utils import ( + _maybe_download, + _setup_seed, + resolve_workdir_path, +) from llmsql.utils.utils import ( choose_prompt_builder, load_jsonl, @@ -176,9 +178,7 @@ def inference_api( request_headers: dict[str, str] | None = None, version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION, output_file: str = "llm_sql_predictions.jsonl", - questions_path: str | None = None, - tables_path: str | None = None, - workdir_path: str = DEFAULT_WORKDIR_PATH, + workdir_path: str | None = None, limit: int | float | None = None, num_fewshots: int = 5, seed: int = 42, @@ -199,9 +199,8 @@ def inference_api( # Benchmark: version: LLMSQL version output_file: Path to write outputs (will be overwritten). - questions_path: Path to questions.jsonl (auto-downloads if missing). - tables_path: Path to tables.jsonl (auto-downloads if missing). - workdir_path: Directory to store downloaded data. + workdir_path: Directory to store downloaded benchmark files. If omitted, a + temporary directory is created automatically. num_fewshots: Number of few-shot examples (0, 1, or 5). batch_size: Number of questions per generation batch. seed: Random seed for reproducibility. 
@@ -216,12 +215,11 @@ def inference_api( api_kwargs = api_kwargs or {} request_headers = request_headers or {} - workdir = Path(workdir_path) - workdir.mkdir(parents=True, exist_ok=True) + workdir = resolve_workdir_path(workdir_path) repo_id = get_repo_id(version) - questions_path = _maybe_download(repo_id, "questions.jsonl", questions_path) - tables_path = _maybe_download(repo_id, "tables.jsonl", tables_path) + questions_path = _maybe_download(repo_id, "questions.jsonl", workdir) + tables_path = _maybe_download(repo_id, "tables.jsonl", workdir) questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) diff --git a/llmsql/inference/inference_transformers.py b/llmsql/inference/inference_transformers.py index 3272d5b..0a0b553 100644 --- a/llmsql/inference/inference_transformers.py +++ b/llmsql/inference/inference_transformers.py @@ -16,8 +16,6 @@ model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct", repo_id="llmsql-bench/llmsql-2.0", output_file="outputs/preds_transformers.jsonl", - questions_path="data/questions.jsonl", - tables_path="data/tables.jsonl", num_fewshots=5, batch_size=8, max_new_tokens=256, @@ -39,7 +37,6 @@ """ -from pathlib import Path from typing import Any, Literal from dotenv import load_dotenv @@ -49,11 +46,14 @@ from llmsql.config.config import ( DEFAULT_LLMSQL_VERSION, - DEFAULT_WORKDIR_PATH, get_repo_id, ) from llmsql.loggers.logging_config import log -from llmsql.utils.inference_utils import _maybe_download, _setup_seed +from llmsql.utils.inference_utils import ( + _maybe_download, + _setup_seed, + resolve_workdir_path, +) from llmsql.utils.utils import ( choose_prompt_builder, load_jsonl, @@ -92,9 +92,7 @@ def inference_transformers( # --- Benchmark Parameters --- version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION, output_file: str = "llm_sql_predictions.jsonl", - questions_path: str | None = None, - tables_path: str | None = None, - workdir_path: str = DEFAULT_WORKDIR_PATH, + workdir_path: str | None = None, 
num_fewshots: int = 5, batch_size: int = 8, limit: int | float | None = None, @@ -137,9 +135,8 @@ def inference_transformers( # Benchmark: version: LLMSQL version output_file: Output JSONL file path for completions. - questions_path: Path to benchmark questions JSONL. - tables_path: Path to benchmark tables JSONL. - workdir_path: Working directory path. + workdir_path: Directory to store downloaded benchmark files. If omitted, a + temporary directory is created automatically. num_fewshots: Number of few-shot examples (0, 1, or 5). batch_size: Batch size for inference. seed: Random seed for reproducibility. @@ -153,9 +150,6 @@ def inference_transformers( # --- Setup --- _setup_seed(seed=seed) - workdir = Path(workdir_path) - workdir.mkdir(parents=True, exist_ok=True) - model_kwargs = model_kwargs or {} tokenizer_kwargs = tokenizer_kwargs or {} generation_kwargs = generation_kwargs or {} @@ -219,10 +213,11 @@ def inference_transformers( model.eval() # --- Load necessary files --- + workdir = resolve_workdir_path(workdir_path) repo_id = get_repo_id(version) - questions_path = _maybe_download(repo_id, "questions.jsonl", questions_path) - tables_path = _maybe_download(repo_id, "tables.jsonl", tables_path) + questions_path = _maybe_download(repo_id, "questions.jsonl", workdir) + tables_path = _maybe_download(repo_id, "tables.jsonl", workdir) questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index f661c6b..56e77b1 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -15,8 +15,6 @@ results = inference_vllm( model_name="Qwen/Qwen2.5-1.5B-Instruct", version="2.0", - output_file="outputs/predictions.jsonl", - questions_path="data/questions.jsonl", tables_path="data/tables.jsonl", num_fewshots=5, batch_size=8, @@ -41,7 +39,6 @@ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" 
-from pathlib import Path from typing import Any, Literal from dotenv import load_dotenv @@ -50,11 +47,14 @@ from llmsql.config.config import ( DEFAULT_LLMSQL_VERSION, - DEFAULT_WORKDIR_PATH, get_repo_id, ) from llmsql.loggers.logging_config import log -from llmsql.utils.inference_utils import _maybe_download, _setup_seed +from llmsql.utils.inference_utils import ( + _maybe_download, + _setup_seed, + resolve_workdir_path, +) from llmsql.utils.utils import ( choose_prompt_builder, load_jsonl, @@ -82,9 +82,7 @@ def inference_vllm( # === Benchmark Parameters === version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION, output_file: str = "llm_sql_predictions.jsonl", - questions_path: str | None = None, - tables_path: str | None = None, - workdir_path: str = DEFAULT_WORKDIR_PATH, + workdir_path: str | None = None, limit: int | float | None = None, num_fewshots: int = 5, batch_size: int = 8, @@ -116,9 +114,8 @@ def inference_vllm( # Benchmark: version: LLMSQL version output_file: Path to write outputs (will be overwritten). - questions_path: Path to questions.jsonl (auto-downloads if missing). - tables_path: Path to tables.jsonl (auto-downloads if missing). - workdir_path: Directory to store downloaded data. + workdir_path: Directory to store downloaded benchmark files. If omitted, a + temporary directory is created automatically. num_fewshots: Number of few-shot examples (0, 1, or 5). batch_size: Number of questions per generation batch. seed: Random seed for reproducibility. 
@@ -135,16 +132,15 @@ def inference_vllm( _setup_seed(seed=seed) hf_token = hf_token or os.environ.get("HF_TOKEN") - workdir = Path(workdir_path) - workdir.mkdir(parents=True, exist_ok=True) # --- load input data --- log.info("Preparing questions and tables...") + workdir = resolve_workdir_path(workdir_path) repo_id = get_repo_id(version) - questions_path = _maybe_download(repo_id, "questions.jsonl", questions_path) - tables_path = _maybe_download(repo_id, "tables.jsonl", tables_path) + questions_path = _maybe_download(repo_id, "questions.jsonl", workdir) + tables_path = _maybe_download(repo_id, "tables.jsonl", workdir) questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) diff --git a/llmsql/utils/inference_utils.py b/llmsql/utils/inference_utils.py index 892ccd9..72c5880 100644 --- a/llmsql/utils/inference_utils.py +++ b/llmsql/utils/inference_utils.py @@ -1,21 +1,40 @@ from pathlib import Path import random +import tempfile from huggingface_hub import hf_hub_download import numpy as np import torch -from llmsql.config.config import DEFAULT_WORKDIR_PATH from llmsql.loggers.logging_config import log +def resolve_workdir_path(workdir_path: str | Path | None) -> Path: + if workdir_path is None: + resolved = Path(tempfile.mkdtemp(prefix="llmsql-")) + log.info(f"Created temporary workdir: {resolved}") + return resolved + + resolved = Path(workdir_path) + if resolved.exists() and not resolved.is_dir(): + raise ValueError( + f"workdir_path must point to a directory, got file: {resolved}" + ) + + resolved.mkdir(parents=True, exist_ok=True) + return resolved + + # --- Load benchmark data --- -def _download_file(repo_id: str, filename: str) -> str: +def _download_file( + repo_id: str, filename: str, workdir_path: str | Path | None = None +) -> str: + local_dir = resolve_workdir_path(workdir_path) path = hf_hub_download( repo_id=repo_id, filename=filename, repo_type="dataset", - local_dir=DEFAULT_WORKDIR_PATH, + local_dir=local_dir, ) assert 
isinstance(path, str) return path @@ -29,24 +48,27 @@ def _setup_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def _maybe_download(repo_id: str, filename: str, local_path: str | None) -> str: - if local_path is not None: - return local_path - - target_path = Path(DEFAULT_WORKDIR_PATH) / filename +def _maybe_download( + repo_id: str, filename: str, workdir_path: str | Path | None +) -> str: + target_dir = resolve_workdir_path(workdir_path) + target_path = target_dir / filename if target_path.exists(): - log.info(f"Removing existing path: {target_path}") - if target_path.is_file() or target_path.is_symlink(): - target_path.unlink() + if not target_path.is_file(): + raise ValueError( + f"Expected downloaded benchmark file path to be a file: {target_path}" + ) + log.info(f"Using cached benchmark file: {target_path}") + return str(target_path) log.info(f"Downloading {filename} from Hugging Face Hub...") local_path = hf_hub_download( repo_id=repo_id, filename=filename, repo_type="dataset", - local_dir=DEFAULT_WORKDIR_PATH, + local_dir=str(target_dir), ) log.info(f"Downloaded {filename} to: {local_path}") - return local_path + return local_path # type: ignore diff --git a/tests/conftest.py b/tests/conftest.py index 1f60dca..9f4b0d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -110,12 +110,10 @@ def mock_utils(mocker, tmp_path): mocker.patch("llmsql.evaluation.evaluate.log_mismatch") mocker.patch("llmsql.evaluation.evaluate.print_summary") - # download files + # benchmark file resolver mocker.patch( - "llmsql.evaluation.evaluate.download_benchmark_file", - side_effect=lambda repo_id, filename, local_dir: str( Path(local_dir) / filename ), + "llmsql.evaluation.evaluate._maybe_download", + side_effect=lambda repo_id, filename, workdir: str(Path(workdir) / filename), ) # report writer diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py index 8fce11a..12f82fa 
100644 --- a/tests/evaluation/test_evaluator_different_llmsql_versions.py +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -1,10 +1,11 @@ import json +import shutil + import pytest from llmsql import evaluate from llmsql.config.config import get_available_versions - VALID_LLMSQL_VERSIONS = [None] + get_available_versions() INVALID_LLMSQL_VERSION = "1.1" @@ -27,11 +28,11 @@ async def test_evaluate_runs_with_valid_versions( ) ) + shutil.copy(dummy_db_file, temp_dir / "sqlite_tables.db") + # Fake outputs.jsonl outputs_path = temp_dir / "outputs.jsonl" - outputs_path.write_text( - json.dumps({"question_id": 1, "completion": "SELECT 1"}) - ) + outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"})) # Monkeypatch exactly like reference file monkeypatch.setattr( @@ -47,8 +48,7 @@ async def test_evaluate_runs_with_valid_versions( kwargs = { "outputs": str(outputs_path), - "questions_path": str(questions_path), - "db_path": dummy_db_file, + "workdir_path": str(temp_dir), "show_mismatches": False, } @@ -58,7 +58,6 @@ async def test_evaluate_runs_with_valid_versions( evaluate(**kwargs) - @pytest.mark.asyncio async def test_evaluate_raises_with_invalid_version( monkeypatch, temp_dir, dummy_db_file @@ -74,11 +73,10 @@ async def test_evaluate_raises_with_invalid_version( } ) ) + shutil.copy(dummy_db_file, temp_dir / "sqlite_tables.db") outputs_path = temp_dir / "outputs.jsonl" - outputs_path.write_text( - json.dumps({"question_id": 1, "completion": "SELECT 1"}) - ) + outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"})) monkeypatch.setattr( "llmsql.utils.evaluation_utils.evaluate_sample", @@ -91,11 +89,10 @@ async def test_evaluate_raises_with_invalid_version( monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) - with pytest.raises(Exception): + with pytest.raises(ValueError): evaluate( 
outputs=str(outputs_path), - questions_path=str(questions_path), - db_path=dummy_db_file, + workdir_path=str(temp_dir), show_mismatches=False, version=INVALID_LLMSQL_VERSION, - ) \ No newline at end of file + ) diff --git a/tests/evaluation/test_evaluator_stability.py b/tests/evaluation/test_evaluator_stability.py index b493431..09fb1a4 100644 --- a/tests/evaluation/test_evaluator_stability.py +++ b/tests/evaluation/test_evaluator_stability.py @@ -1,4 +1,6 @@ import json +from pathlib import Path +import shutil import pytest @@ -20,6 +22,8 @@ async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file): ) ) + shutil.copy(dummy_db_file, temp_dir / "sqlite_tables.db") + # Fake outputs.jsonl outputs_path = temp_dir / "outputs.jsonl" outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"})) @@ -34,8 +38,7 @@ async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file): report = evaluate( outputs=str(outputs_path), - questions_path=str(questions_path), - db_path=dummy_db_file, + workdir_path=str(temp_dir), show_mismatches=False, ) @@ -55,6 +58,7 @@ async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file): {"question_id": 1, "table_id": 1, "question": "Test", "sql": "SELECT 1"} ) ) + shutil.copy(dummy_db_file, temp_dir / "sqlite_tables.db") outputs_path = temp_dir / "outputs.jsonl" outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"})) @@ -71,8 +75,7 @@ async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file): evaluate( outputs=str(outputs_path), - questions_path=str(questions_path), - db_path=dummy_db_file, + workdir_path=str(temp_dir), save_report=str(report_path), show_mismatches=False, ) @@ -94,6 +97,7 @@ async def test_evaluate_with_jsonl_file(monkeypatch, temp_dir, dummy_db_file): {"question_id": 1, "table_id": 1, "question": "Sample", "sql": "SELECT 1"} ) ) + shutil.copy(dummy_db_file, temp_dir / "sqlite_tables.db") # Create fake outputs.jsonl outputs_path 
= temp_dir / "outputs.jsonl" @@ -115,8 +119,7 @@ async def test_evaluate_with_jsonl_file(monkeypatch, temp_dir, dummy_db_file): report = evaluate( outputs=str(outputs_path), - questions_path=str(questions_path), - db_path=dummy_db_file, + workdir_path=str(temp_dir), show_mismatches=False, ) @@ -135,6 +138,7 @@ async def test_evaluate_with_dict_list(monkeypatch, temp_dir, dummy_db_file): {"question_id": 1, "table_id": 1, "question": "Sample", "sql": "SELECT 1"} ) ) + shutil.copy(dummy_db_file, temp_dir / "sqlite_tables.db") # Output as a list of dicts outputs_list = [{"question_id": 1, "completion": "SELECT 1"}] @@ -159,8 +163,7 @@ async def test_evaluate_with_dict_list(monkeypatch, temp_dir, dummy_db_file): report = evaluate( outputs=outputs_list, - questions_path=str(questions_path), - db_path=dummy_db_file, + workdir_path=str(temp_dir), show_mismatches=False, ) @@ -192,32 +195,26 @@ def test_evaluate_with_jsonl_path(mock_utils, mocker): assert report["input_mode"] == "jsonl_path" -def test_missing_workdir_and_no_questions_path_raises(): - with pytest.raises(ValueError): - evaluate( - outputs=[{"question_id": 1, "completion": "x"}], - workdir_path=None, - questions_path=None, - ) - +def test_auto_temp_workdir_is_used_when_not_provided(mocker): + resolve = mocker.patch( + "llmsql.evaluation.evaluate.resolve_workdir_path", + return_value=Path("/tmp/llmsql-test"), + ) -def test_missing_workdir_and_no_db_path_raises(): - with pytest.raises(ValueError): - evaluate( - outputs=[{"question_id": 1, "completion": "x"}], - workdir_path=None, - db_path=None, - ) + evaluate([{"question_id": 1, "completion": "x"}], workdir_path=None) + resolve.assert_called_once_with(None) def test_download_occurs_if_files_missing(mock_utils, mocker): - dl = mocker.patch("llmsql.evaluation.evaluate.download_benchmark_file") + dl = mocker.patch("llmsql.evaluation.evaluate._maybe_download") + dl.side_effect = [ + str(mock_utils / "questions.jsonl"), + str(mock_utils / "sqlite_tables.db"), + ] 
evaluate( [{"question_id": 1, "completion": "SELECT 1"}], workdir_path=str(mock_utils), - questions_path=None, - db_path=None, ) assert dl.call_count == 2 # questions + sqlite diff --git a/tests/inference/test_inference_api.py b/tests/inference/test_inference_api.py index 21c5547..56af026 100644 --- a/tests/inference/test_inference_api.py +++ b/tests/inference/test_inference_api.py @@ -193,8 +193,7 @@ def test_returns_results_for_all_questions(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), ) assert len(results) == 2 @@ -208,8 +207,7 @@ def test_output_file_written_correctly(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), ) lines = outpath.read_text().strip().splitlines() @@ -227,8 +225,7 @@ def test_limit_integer(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=1, ) @@ -242,8 +239,7 @@ def test_limit_float(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=0.5, # 50% of 2 questions → 1 ) @@ -258,8 +254,7 @@ def test_invalid_limit_raises(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=1.5, # float out of (0, 1] ) @@ -272,8 +267,7 @@ def test_negative_limit_raises(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - 
tables_path=str(tpath), + workdir_path=str(tmp_path), limit=-10, # float out of (0, 1] ) @@ -299,8 +293,7 @@ def fake_client_session(headers=None, **_): base_url="http://localhost:9999/v1", api_key="sk-test-key", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), ) assert captured_headers.get("Authorization") == "Bearer sk-test-key" @@ -314,8 +307,7 @@ def test_no_rate_limit_completes(self, monkeypatch, tmp_path): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), requests_per_minute=None, ) @@ -343,8 +335,7 @@ async def _run_inside_loop(): model_name="dummy", base_url="http://localhost:9999/v1", output_file=str(outpath), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), ) import nest_asyncio diff --git a/tests/inference/test_inference_stability.py b/tests/inference/test_inference_stability.py index be721ea..93b080b 100644 --- a/tests/inference/test_inference_stability.py +++ b/tests/inference/test_inference_stability.py @@ -58,8 +58,7 @@ async def test_inference_vllm_with_local_files(monkeypatch, tmp_path): results = mod.inference_vllm( model_name="dummy-model", output_file=str(out_file), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), num_fewshots=1, batch_size=1, max_new_tokens=5, @@ -122,8 +121,6 @@ def fake_download(repo_id, filename, path, **_): results = mod.inference_vllm( model_name="dummy-model", output_file=str(out_file), - questions_path=None, - tables_path=None, workdir_path=str(tmp_path), ) diff --git a/tests/inference/test_inference_transformers_different_llmsql_versions.py b/tests/inference/test_inference_transformers_different_llmsql_versions.py index cb1dec5..41ac978 100644 --- a/tests/inference/test_inference_transformers_different_llmsql_versions.py +++ 
b/tests/inference/test_inference_transformers_different_llmsql_versions.py @@ -9,7 +9,11 @@ questions = [ {"question_id": "q1", "table_id": "t1", "question": "Select name from students;"}, - {"question_id": "q2", "table_id": "t1", "question": "Count students older than 20;"}, + { + "question_id": "q2", + "table_id": "t1", + "question": "Count students older than 20;", + }, ] tables = [ { @@ -47,8 +51,7 @@ async def test_inference_stability_on_valid_version_flags(version_arg): "model_or_model_name_or_path": "sshleifer/tiny-gpt2", "tokenizer_or_name": "sshleifer/tiny-gpt2", "output_file": str(output_file), - "questions_path": str(questions_file), - "tables_path": str(tables_file), + "workdir_path": str(tmpdir_path), "batch_size": 1, "max_new_tokens": 8, "temperature": 0.0, @@ -56,7 +59,7 @@ async def test_inference_stability_on_valid_version_flags(version_arg): } if version_arg is not None: - kwargs["version"] = version_arg + kwargs["version"] = version_arg results = inference_transformers(**kwargs) @@ -66,7 +69,7 @@ async def test_inference_stability_on_valid_version_flags(version_arg): if version_arg is not None: for r in results: - assert "completion" in r + assert "completion" in r @pytest.mark.asyncio @@ -84,8 +87,7 @@ async def test_inference_stability_on_invalid_version_flag(): "model_or_model_name_or_path": "sshleifer/tiny-gpt2", "tokenizer_or_name": "sshleifer/tiny-gpt2", "output_file": str(out_file), - "questions_path": str(q_file), - "tables_path": str(t_file), + "workdir_path": str(tmpdir_path), "batch_size": 1, "max_new_tokens": 8, "temperature": 0.0, @@ -93,5 +95,5 @@ async def test_inference_stability_on_invalid_version_flag(): "version": INVALID_LLMSQL_VERSION, # invalid version } - with pytest.raises(Exception): + with pytest.raises(ValueError): inference_transformers(**kwargs) diff --git a/tests/inference/test_inference_transformers_stability.py b/tests/inference/test_inference_transformers_stability.py index 6a2b4f3..ba6500c 100644 --- 
a/tests/inference/test_inference_transformers_stability.py +++ b/tests/inference/test_inference_transformers_stability.py @@ -48,8 +48,7 @@ async def test_inference_stability(): model_or_model_name_or_path="sshleifer/tiny-gpt2", # tiny HF model for fast tests tokenizer_or_name="sshleifer/tiny-gpt2", output_file=str(output_file), - questions_path=str(questions_file), - tables_path=str(tables_file), + workdir_path=str(tmpdir_path), batch_size=1, max_new_tokens=8, temperature=0.0, diff --git a/tests/inference/test_inference_vllm_different_llmsql_versions.py b/tests/inference/test_inference_vllm_different_llmsql_versions.py index bc70af7..3e8954e 100644 --- a/tests/inference/test_inference_vllm_different_llmsql_versions.py +++ b/tests/inference/test_inference_vllm_different_llmsql_versions.py @@ -56,8 +56,7 @@ async def test_inference_vllm_valid_versions(monkeypatch, tmp_path, version_arg) kwargs = { "model_name": "dummy-model", "output_file": str(out_file), - "questions_path": str(q_file), - "tables_path": str(t_file), + "workdir_path": str(tmp_path), "num_fewshots": 1, "batch_size": 1, "max_new_tokens": 8, @@ -99,8 +98,7 @@ async def test_inference_vllm_invalid_version(monkeypatch, tmp_path): kwargs = { "model_name": "dummy-model", "output_file": str(out_file), - "questions_path": str(q_file), - "tables_path": str(t_file), + "workdir_path": str(tmp_path), "num_fewshots": 1, "batch_size": 1, "max_new_tokens": 8, diff --git a/tests/inference/test_limit_argument.py b/tests/inference/test_limit_argument.py index e2d2e09..5a68875 100644 --- a/tests/inference/test_limit_argument.py +++ b/tests/inference/test_limit_argument.py @@ -117,8 +117,7 @@ async def test_limit_integer_restricts_results(self, monkeypatch, tmp_path): results = vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=3, ) @@ -136,8 +135,7 @@ async def 
test_limit_float_restricts_results(self, monkeypatch, tmp_path): results = vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=0.4, ) @@ -155,8 +153,7 @@ async def test_limit_none_uses_all_samples(self, monkeypatch, tmp_path): results = vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=None, ) @@ -173,8 +170,7 @@ async def test_limit_float_1_uses_all_samples(self, monkeypatch, tmp_path): results = vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=1.0, ) @@ -192,8 +188,7 @@ async def test_limit_invalid_float_raises(self, monkeypatch, tmp_path): vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=1.5, ) @@ -209,8 +204,7 @@ async def test_limit_invalid_int_raises(self, monkeypatch, tmp_path): vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=0, ) @@ -225,8 +219,7 @@ async def test_limit_larger_than_dataset_uses_all(self, monkeypatch, tmp_path): results = vllm_mod.inference_vllm( model_name="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=9999, ) @@ -249,8 +242,7 @@ def test_limit_invalid_float_raises(self, monkeypatch, tmp_path): transformers_mod.inference_transformers( model_or_model_name_or_path="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=2.0, ) @@ -264,7 +256,6 
@@ def test_limit_invalid_int_raises(self, monkeypatch, tmp_path): transformers_mod.inference_transformers( model_or_model_name_or_path="dummy", output_file=str(tmp_path / "out.jsonl"), - questions_path=str(qpath), - tables_path=str(tpath), + workdir_path=str(tmp_path), limit=-1, ) diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py index 4ca241a..761e10c 100644 --- a/tests/utils/test_inference_utils.py +++ b/tests/utils/test_inference_utils.py @@ -6,11 +6,7 @@ import pytest import torch -from llmsql.config.config import ( - DEFAULT_LLMSQL_VERSION, - DEFAULT_WORKDIR_PATH, - get_repo_id, -) +from llmsql.config.config import DEFAULT_LLMSQL_VERSION, get_repo_id from llmsql.utils import inference_utils as mod @@ -22,12 +18,14 @@ async def test_download_file(monkeypatch, tmp_path): def fake_hf_hub_download(repo_id, filename, repo_type, local_dir): assert repo_id == get_repo_id(DEFAULT_LLMSQL_VERSION) assert repo_type == "dataset" - assert local_dir == DEFAULT_WORKDIR_PATH + assert Path(local_dir).is_dir() assert filename == "questions.jsonl" return expected_path monkeypatch.setattr(mod, "hf_hub_download", fake_hf_hub_download) - path = mod._download_file(get_repo_id(DEFAULT_LLMSQL_VERSION), "questions.jsonl") + path = mod._download_file( + get_repo_id(DEFAULT_LLMSQL_VERSION), "questions.jsonl", tmp_path + ) assert path == expected_path @@ -52,7 +50,6 @@ async def test_setup_seed(monkeypatch): @pytest.mark.asyncio async def test_maybe_download_calls_hf_hub(monkeypatch, tmp_path): """_maybe_download downloads file if missing.""" - monkeypatch.setattr(mod, "DEFAULT_WORKDIR_PATH", str(tmp_path)) filename = "questions.jsonl" called = {} @@ -66,7 +63,7 @@ def fake_hf_hub_download(**kwargs): monkeypatch.setattr(mod, "hf_hub_download", fake_hf_hub_download) path = mod._maybe_download( - get_repo_id(DEFAULT_LLMSQL_VERSION), filename, local_path=None + get_repo_id(DEFAULT_LLMSQL_VERSION), filename, workdir_path=tmp_path ) assert 
Path(path).exists() assert called["repo_id"] == get_repo_id(DEFAULT_LLMSQL_VERSION)