diff --git a/README.md b/README.md index 97952da..9952b83 100644 --- a/README.md +++ b/README.md @@ -42,31 +42,27 @@ The package doesn't have the dataset, it is stored on our [HuggingFace page](htt ## Latest News πŸ“£ -* [2025/12] Evaluation class converted to function see [new `evaluate(...)` function](./llmsql/evaluation/evaluate.py#evaluate) +* [2026/03] Added support for API inference, for now only for OpenAI-compatable APIs, see [`inference_api()` function](./llmsql/inference/inference_api.py#inference_api) -* New page version added to [`https://llmsql.github.io/llmsql-benchmark/`](https://llmsql.github.io/llmsql-benchmark/) +* [2026/03] The page now contains first version of [leaderboard](https://llmsql.github.io/llmsql-benchmark/#:~:text=%F0%9F%93%8A%20Leaderboard%20%E2%80%94%20Execution%20Accuracy%20%28EX)! -* Vllm inference method now supports chat templates, see [`inference_vllm(...)`](./llmsql/inference/inference_vllm.py#inference_vllm). -* Transformers inference now supports custom chat tempalates with `chat_template` argument, see [`inference_transformers(...)`](./llmsql/inference/inference_transformers.py#inference_transformers) +* [2026/02] The new LLMSQL 2.0 version is out now! See the [dataset](https://huggingface.co/datasets/llmsql-bench/llmsql-2.0). The support is already added with the `version` parameter to each `inference` function. -* More stable and deterministic inference with [`inference_vllm(...)`](./llmsql/inference/inference_vllm.py#inference_vllm) function added by setting [some envars](./llmsql/inference/inference_vllm.py) +* [2025/12] Evaluation class converted to function see [new `evaluate(...)` function](./llmsql/evaluation/evaluate.py#evaluate) -* `padding_side` argument added to [`inference_transformers(...)`](./llmsql/inference/inference_transformers.py#inference_transformers) function with default `left` option. ## Usage Recommendations -Modern LLMs are already strong at **producing SQL queries without finetuning**. +Modern LLMs are already strong at producing SQL queries without finetuning. We therefore recommend that most users: 1. **Run inference** directly on the full benchmark: - model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct", - output_file="path_to_your_outputs.jsonl", - - Use [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) (the function for transformers inference) for generation of SQL predictions with your model. If you want to do vllm based inference, use [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py). Works both with HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct` and model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)` + - Use [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) (the function for transformers inference) for generation of SQL predictions with your model. If you want to do vllm based inference, use [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py). Works both with HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct` and model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`. The api inference is also supported, see [`inference_api()`](./llmsql/inference/inference_api.py#inference_api) - Evaluate results against the benchmark with the [`llmsql.evaluate`](./llmsql/evaluation/evaluator.py) function. 2. **Optional finetuning**: - - For research or domain adaptation, we provide finetuning version for HF models. Use [Finetune Ready](https://huggingface.co/datasets/llmsql-bench/llmsql-benchmark-finetune-ready) dataset from HuggingFace. + - For research or domain adaptation, we provide finetuning version for HF models. Use [Finetune Ready](https://huggingface.co/collections/llmsql-bench/fine-tune-ready-versions-of-the-llmsql-benchmark) datasets from HuggingFace. > [!Tip] > You can find additional manuals in the README files of each folder([Inferece Readme](./llmsql/inference/README.md), [Evaluation Readme](./llmsql/evaluation/README.md)) @@ -80,7 +76,7 @@ We therefore recommend that most users: ``` llmsql/ -β”œβ”€β”€ evaluation/ # Scripts for downloading DB + evaluating predictions +β”œβ”€β”€ evaluation/ # Scripts for evaluation └── inference/ # Generate SQL queries with your LLM ``` @@ -159,10 +155,12 @@ print(report) ``` +For more examples check the [examples folder](./examples/) + ## Prompt Template -The prompt defines explicit constraints on the generated output. -The model is instructed to output only a valid SQL `SELECT` query, to use a fixed table name (`"Table"`) **(which will be replaced with the actual table name during evaluation)**, to quote all table and column names, and to restrict generation to the specified SQL functions, condition operators, and keywords. +The prompt defines explicit constraints on the generated output. +The model is instructed to output only a valid SQL `SELECT` query, to use a fixed table name (`"Table"`) **(which will be replaced with the actual table name during evaluation)**, to quote all table and column names, and to restrict generation to the specified SQL functions, condition operators, and keywords. The full prompt specification is provided in the prompt template. Below is an example of the **5-shot prompt template** used during inference. @@ -224,13 +222,6 @@ Implementations of 0-shot, 1-shot, and 5-shot prompt templates are available her πŸ‘‰ [link-to-file](./llmsql/prompts/prompts.py) - -## Suggested Workflow - -* **Primary**: Run inference on all questions with vllm or transformers β†’ Evaluate with `evaluate()`. -* **Secondary (optional)**: Fine-tune on `train/val` β†’ Test on `test_questions.jsonl`. You can find the datasets here [HF Finetune Ready](https://huggingface.co/datasets/llmsql-bench/llmsql-benchmark-finetune-ready). - - ## Contributing Check out our [open issues](https://github.com/LLMSQL/llmsql-benchmark/issues), fork this repo and feel free to submit pull requests! diff --git a/docs/_templates/index.html b/docs/_templates/index.html index a93b256..925447d 100644 --- a/docs/_templates/index.html +++ b/docs/_templates/index.html @@ -113,7 +113,7 @@

1️⃣ Installation

2️⃣ Inference from CLI

vLLM Backend (Recommended)

-
llmsql inference --method vllm \
+
llmsql inference vllm \
 --model-name Qwen/Qwen2.5-1.5B-Instruct \
 --output-file outputs/preds.jsonl \
 --batch-size 8 \
@@ -121,7 +121,7 @@ 

2️⃣ Inference from CLI

--temperature 0.0

Transformers Backend

-
llmsql inference --method transformers \
+
llmsql inference transformers \
 --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
 --output-file outputs/preds.jsonl \
 --batch-size 8 \
@@ -163,7 +163,7 @@ 

πŸ“„ Citation

@inproceedings{llmsql_bench,
   title={LLMSQL: Upgrading WikiSQL for the LLM Era of Text-to-SQL},
   author={Pihulski, Dzmitry and Charchut, Karol and Novogrodskaia, Viktoria and Koco{'n}, Jan},
-  booktitle={2025 IEEE ICувцDMW},
+  booktitle={2025 IEEE International Conference on Data Mining Workshops (ICDMW)},
   year={2025},
   organization={IEEE}
 }
diff --git a/docs/docs/inference.rst b/docs/docs/inference.rst
index 5bcf0c6..1a33533 100644
--- a/docs/docs/inference.rst
+++ b/docs/docs/inference.rst
@@ -14,6 +14,12 @@ Inference API Reference
 
 ---
 
+.. automodule:: llmsql.inference.inference_api
+   :members:
+   :undoc-members:
+
+---
+
 .. raw:: html
 
    
diff --git a/docs/docs/usage.rst b/docs/docs/usage.rst index 806a4e8..b8966fe 100644 --- a/docs/docs/usage.rst +++ b/docs/docs/usage.rst @@ -77,6 +77,41 @@ Using vllm backend. print(report) +Using OpenAI-compateble API. + +.. code-block:: python + + from llmsql import inference_api + from dotenv import load_dotenv + import os + load_dotenv() + + # Run inference (will take some time) + results = inference_api( + model_name="gpt-5-mini", + base_url="https://api.openai.com/v1/", + api_key=os.environ["OPENAI_API_KEY"], + api_kwargs={ + "response_format": { + "type": "text" + }, + "verbosity": "medium", + "reasoning_effort": "medium", + "store": False + }, + requests_per_minute=100, + output_file="test_output_api.jsonl", + limit=50, + num_fewshots = 5, + seed=42, + version="2.0" + ) + + # Evaluate the results + evaluator = LLMSQLEvaluator() + report = evaluator.evaluate(outputs_path="outputs/preds_transformers.jsonl") + print(report) + --- .. raw:: html diff --git a/examples/inference_api.ipynb b/examples/inference_api.ipynb new file mode 100644 index 0000000..a049257 --- /dev/null +++ b/examples/inference_api.ipynb @@ -0,0 +1,132 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5409b21a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from llmsql import inference_api\n", + "from dotenv import load_dotenv\n", + "import os\n", + "load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "581e9c25", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-03-04 08:10:34,504 [INFO] llmsql-bench: Removing existing path: llmsql_workdir/questions.jsonl\n", + "2026-03-04 08:10:34,506 [INFO] llmsql-bench: Downloading questions.jsonl from Hugging Face Hub...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a71443d8f32840838ba484eadf26d9d0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "questions.jsonl: 0%| | 0.00/18.3M [00:00 None: # ========================= @@ -101,7 +107,6 @@ def add_common_generation_args( ) self._parser_transformers.add_argument("--chat-template") - # Generation add_common_generation_args(self._parser_transformers, 0.0, False) self._parser_transformers.add_argument("--top-p", type=float, default=1.0) self._parser_transformers.add_argument("--top-k", type=int, default=50) @@ -113,8 +118,6 @@ def add_common_generation_args( add_common_benchmark_args(self._parser_transformers) - self._parser_transformers.set_defaults(func=self._execute_transformers) - # ========================= # vLLM # ========================= @@ -155,7 +158,42 @@ def add_common_generation_args( add_common_benchmark_args(self._parser_vllm) - self._parser_vllm.set_defaults(func=self._execute_vllm) + # ========================= + # OpenAI-compatible API + # ========================= + self._parser_api.add_argument( + "--model-name", + required=True, + help="Target model name expected by the API", + ) + self._parser_api.add_argument( + "--base-url", + required=True, + help="API base URL, e.g. https://api.openai.com/v1", + ) + self._parser_api.add_argument( + "--endpoint", + default="chat/completions", + help="Completion endpoint path relative to --base-url", + ) + self._parser_api.add_argument("--api-key") + self._parser_api.add_argument("--timeout", type=float, default=120.0) + self._parser_api.add_argument( + "--requests-per-minute", + type=float, + help="Rate limit for API requests", + ) + self._parser_api.add_argument( + "--api-kwargs", + type=json.loads, + help="JSON string merged into API request payload", + ) + self._parser_api.add_argument( + "--request-headers", + type=json.loads, + help="JSON string merged into HTTP request headers", + ) + add_common_benchmark_args(self._parser_api) @staticmethod def _execute_transformers(args: argparse.Namespace) -> None: @@ -213,3 +251,26 @@ def _execute_vllm(args: argparse.Namespace) -> None: batch_size=args.batch_size, seed=args.seed, ) + + @staticmethod + def _execute_api(args: argparse.Namespace) -> None: + from llmsql import inference_api + + inference_api( + model_name=args.model_name, + base_url=args.base_url, + endpoint=args.endpoint, + api_key=args.api_key, + timeout=args.timeout, + requests_per_minute=args.requests_per_minute, + api_kwargs=args.api_kwargs, + request_headers=args.request_headers, + version=args.version, + output_file=args.output_file, + questions_path=args.questions_path, + tables_path=args.tables_path, + workdir_path=args.workdir_path, + limit=args.limit, + num_fewshots=args.num_fewshots, + seed=args.seed, + ) diff --git a/llmsql/_cli/llmsql_cli.py b/llmsql/_cli/llmsql_cli.py index 156caee..b12e509 100644 --- a/llmsql/_cli/llmsql_cli.py +++ b/llmsql/_cli/llmsql_cli.py @@ -38,6 +38,12 @@ def __init__(self) -> None: --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 \ --llm-kwargs '{"max_model_len": 4096}' + # 5️⃣ OpenAI-compatible API backend + llmsql inference api \ + --model-name gpt-5-mini \ + --base-url https://api.openai.com/v1 \ + --requests-per-minute 30 + Visit https://github.com/LLMSQL/llmsql-benchmark for more """), formatter_class=argparse.RawDescriptionHelpFormatter, diff --git a/llmsql/inference/README.md b/llmsql/inference/README.md index b93339c..096d1b5 100644 --- a/llmsql/inference/README.md +++ b/llmsql/inference/README.md @@ -2,8 +2,9 @@ LLMSQL provides two inference backends for **Text-to-SQL generation** with large language models: -* 🧠 **Transformers** β€” runs inference using the standard Hugging Face `transformers` pipeline. -* ⚑ **vLLM** β€” runs inference using the high-performance [vLLM](https://github.com/vllm-project/vllm) backend. +* **Transformers** β€” runs inference using the standard Hugging Face `transformers` pipeline. +* **vLLM** β€” runs inference using the high-performance [vLLM](https://github.com/vllm-project/vllm) backend. +* **API** β€” runs inference against an OpenAI-compatible Chat Completions API with configurable base URL and rate limiting. Both backends load benchmark questions and table schemas, build prompts (with few-shot examples), and generate SQL queries in parallel batches. @@ -27,7 +28,7 @@ pip install llmsql[vllm] ## Quick Start -### βœ… Option 1 β€” Using the **Transformers** backend +### Option 1 β€” Using the **Transformers** backend ```python from llmsql import inference_transformers @@ -52,7 +53,7 @@ results = inference_transformers( --- -### ⚑ Option 2 β€” Using the **vLLM** backend +### Option 2 β€” Using the **vLLM** backend ```python from llmsql import inference_vllm @@ -65,6 +66,36 @@ results = inference_vllm( ) ``` +### Option 3 β€” Using an OpenAI-compatible API backend + +```python +from llmsql import inference_api +from dotenv import load_dotenv +import os +load_dotenv() + +results = inference_api( + model_name="gpt-5-mini", + base_url="https://api.openai.com/v1/", + api_key=os.environ["OPENAI_API_KEY"], + api_kwargs={ + "response_format": { + "type": "text" + }, + "verbosity": "medium", + "reasoning_effort": "medium", + "store": False + }, + requests_per_minute=100, + output_file="test_output_api.jsonl", + limit=50, + num_fewshots = 5, + seed=42, + version="2.0" +) +``` + + --- ## Command-Line Interface (CLI) @@ -72,7 +103,7 @@ results = inference_vllm( You can also run inference directly from the command line: ```bash -llmsql inference --method vllm \ +llmsql inference vllm \ --model-name Qwen/Qwen2.5-1.5B-Instruct \ --output-file outputs/preds.jsonl \ --batch-size 8 \ @@ -83,7 +114,7 @@ llmsql inference --method vllm \ Or use the Transformers backend: ```bash -llmsql inference --method transformers \ +llmsql inference transformers \ --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \ --output-file outputs/preds.jsonl \ --batch-size 8 \ diff --git a/llmsql/inference/inference_api.py b/llmsql/inference/inference_api.py new file mode 100644 index 0000000..c01f0f9 --- /dev/null +++ b/llmsql/inference/inference_api.py @@ -0,0 +1,283 @@ +""" +LLMSQL OpenAI-Compatible API Inference Function +=============================================== + +This module provides ``inference_api()`` for text-to-SQL generation against an +OpenAI-compatible Chat Completions API. +""" + +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +import time +from typing import Any, Literal + +import aiohttp +from dotenv import load_dotenv +import nest_asyncio +from tqdm.asyncio import tqdm + +from llmsql.config.config import ( + DEFAULT_LLMSQL_VERSION, + DEFAULT_WORKDIR_PATH, + get_repo_id, +) +from llmsql.loggers.logging_config import log +from llmsql.utils.inference_utils import _maybe_download, _setup_seed +from llmsql.utils.utils import ( + choose_prompt_builder, + load_jsonl, + overwrite_jsonl, + save_jsonl_lines, +) + +load_dotenv() + + +class _AsyncRateLimiter: + """ + Token-bucket style async rate limiter. + + Releases one token every (60 / requests_per_minute) seconds, + so requests are spaced from their *start* time β€” not from when + the previous one finished. This allows concurrent in-flight + requests while still honouring the RPM cap. + """ + + def __init__(self, requests_per_minute: float | None) -> None: + if requests_per_minute is not None and requests_per_minute <= 0: + raise ValueError("requests_per_minute must be > 0 when provided.") + self._interval: float | None = ( + 60.0 / requests_per_minute if requests_per_minute is not None else None + ) + self._next_allowed: float = 0.0 + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + """Wait until a request slot is available, then claim it.""" + if self._interval is None: + return + + async with self._lock: + now = time.monotonic() + wait = self._next_allowed - now + if wait > 0: + await asyncio.sleep(wait) + # Claim the next slot *before* releasing the lock so the + # following coroutine waits for exactly one more interval. + self._next_allowed = time.monotonic() + self._interval + + +async def _post_chat_completion_async( + *, + session: aiohttp.ClientSession, + base_url: str, + endpoint: str, + payload: dict[str, Any], + timeout: float, +) -> dict[str, Any]: + base = base_url.rstrip("/") + ep = endpoint.lstrip("/") + url = f"{base}/{ep}" + + async with session.post( + url, json=payload, timeout=aiohttp.ClientTimeout(total=timeout) + ) as resp: + resp.raise_for_status() + parsed: dict[str, Any] = await resp.json() + + if "choices" not in parsed: + raise ValueError("API response does not contain `choices`.") + return parsed + + +async def _inference_api_async( + model_name: str, + *, + base_url: str, + endpoint: str, + headers: dict[str, str], + timeout: float, + requests_per_minute: float | None, + api_kwargs: dict[str, Any], + questions: list[dict[str, Any]], + tables: dict[str, Any], + prompt_builder: Any, + output_file: str, +) -> list[dict[str, str]]: + limiter = _AsyncRateLimiter(requests_per_minute) + all_results: list[dict[str, str]] = [] + # Lock to serialise file writes while allowing concurrent HTTP calls. + write_lock = asyncio.Lock() + + async with aiohttp.ClientSession(headers=headers) as session: + + async def process_question(q: dict[str, Any]) -> dict[str, str]: + tbl = tables[q["table_id"]] + example_row = tbl["rows"][0] if tbl["rows"] else [] + prompt = prompt_builder( + q["question"], tbl["header"], tbl["types"], example_row + ) + + payload = { + "model": model_name, + "messages": [ + {"role": "user", "content": [{"type": "text", "text": prompt}]} + ], + **api_kwargs, + } + + # Acquire a rate-limit slot *before* firing the request so that + # the HTTP round-trip time doesn't count against the interval. + await limiter.acquire() + + response = await _post_chat_completion_async( + session=session, + base_url=base_url, + endpoint=endpoint, + payload=payload, + timeout=timeout, + ) + completion = response["choices"][0]["message"]["content"] + + result = { + "question_id": q.get("question_id", q.get("id", "")), + "completion": completion, + } + + async with write_lock: + save_jsonl_lines(output_file, [result]) + + return result + + tasks = [process_question(q) for q in questions] + for coro in tqdm( + asyncio.as_completed(tasks), + total=len(tasks), + desc="Generating", + ): + result = await coro + all_results.append(result) + + return all_results + + +def inference_api( + model_name: str, + *, + base_url: str, + endpoint: str = "chat/completions", + api_key: str | None = None, + timeout: float = 120.0, + requests_per_minute: float | None = None, + api_kwargs: dict[str, Any] | None = None, + request_headers: dict[str, str] | None = None, + version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION, + output_file: str = "llm_sql_predictions.jsonl", + questions_path: str | None = None, + tables_path: str | None = None, + workdir_path: str = DEFAULT_WORKDIR_PATH, + limit: int | float | None = None, + num_fewshots: int = 5, + seed: int = 42, +) -> list[dict[str, str]]: + """Run SQL generation using an OpenAI-compatible Chat Completions API. + + Requests are dispatched concurrently so that HTTP round-trip time does + not count against the rate-limit interval β€” achieving a true + `requests_per_minute` throughput rather than + ``requests_per_minute / (1 + latency_in_minutes)``. + + Args: + model_name: The model name of the api. + + base_url: e.g. "https://api.openai.com/v1/" + endpoint: e.g. "chat/completions" + + # Benchmark: + version: LLMSQL version + output_file: Path to write outputs (will be overwritten). + questions_path: Path to questions.jsonl (auto-downloads if missing). + tables_path: Path to tables.jsonl (auto-downloads if missing). + workdir_path: Directory to store downloaded data. + num_fewshots: Number of few-shot examples (0, 1, or 5). + batch_size: Number of questions per generation batch. + seed: Random seed for reproducibility. + limit: Limit the number of questions to evaluate. If an integer, evaluates + the first N samples. If a float between 0.0 and 1.0, evaluates the + first X*100% of samples. If None, evaluates all samples (default). + + Returns: + List of dicts containing `question_id` and generated `completion`. + """ + _setup_seed(seed=seed) + api_kwargs = api_kwargs or {} + request_headers = request_headers or {} + + workdir = Path(workdir_path) + workdir.mkdir(parents=True, exist_ok=True) + + repo_id = get_repo_id(version) + questions_path = _maybe_download(repo_id, "questions.jsonl", questions_path) + tables_path = _maybe_download(repo_id, "tables.jsonl", tables_path) + + questions = load_jsonl(questions_path) + tables_list = load_jsonl(tables_path) + tables = {t["table_id"]: t for t in tables_list} + + if limit is not None: + if isinstance(limit, float): + if not (0.0 < limit <= 1.0): + raise ValueError( + f"When a float, `limit` must be between 0.0 and 1.0, got {limit}." + ) + limit = max(1, int(len(questions) * limit)) + if not isinstance(limit, int) or limit < 1: + raise ValueError( + f"`limit` must be a positive integer or a float in (0.0, 1.0], got {limit!r}." + ) + questions = questions[:limit] + + key = api_key or os.environ.get("OPENAI_API_KEY") + headers: dict[str, str] = { + "Content-Type": "application/json", + **request_headers, + } + if key: + headers["Authorization"] = f"Bearer {key}" + + prompt_builder = choose_prompt_builder(num_fewshots) + + overwrite_jsonl(output_file) + + coro = _inference_api_async( + model_name, + base_url=base_url, + endpoint=endpoint, + headers=headers, + timeout=timeout, + requests_per_minute=requests_per_minute, + api_kwargs=api_kwargs, + questions=questions, + tables=tables, + prompt_builder=prompt_builder, + output_file=output_file, + ) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None and loop.is_running(): + # Inside a Jupyter notebook (or any other environment that already + # owns an event loop) β€” patch the loop so nested runs are allowed. + nest_asyncio.apply(loop) + all_results = loop.run_until_complete(coro) + else: + all_results = asyncio.run(coro) + + log.info(f"Generation completed. {len(all_results)} results saved to {output_file}") + return all_results diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 4f2d543..8f57657 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -84,6 +84,45 @@ async def test_vllm_backend_called(monkeypatch): assert call_kwargs["tensor_parallel_size"] == 2 +@pytest.mark.asyncio +async def test_api_backend_called(monkeypatch): + """ + Ensure API backend is correctly invoked. + """ + mock_inference = AsyncMock(return_value=[]) + + monkeypatch.setattr( + "llmsql.inference_api", + mock_inference, + ) + + test_args = [ + "llmsql", + "inference", + "api", + "--model-name", + "gpt-4o-mini", + "--base-url", + "https://api.openai.com/v1", + "--requests-per-minute", + "30", + ] + + monkeypatch.setattr(sys, "argv", test_args) + + cli = ParserCLI() + args = cli.parse_args() + + cli.execute(args) + + mock_inference.assert_called_once() + + call_kwargs = mock_inference.call_args.kwargs + assert call_kwargs["model_name"] == "gpt-4o-mini" + assert call_kwargs["base_url"] == "https://api.openai.com/v1" + assert call_kwargs["requests_per_minute"] == 30.0 + + @pytest.mark.asyncio async def test_missing_backend_errors(monkeypatch): """ diff --git a/tests/inference/test_inference_api.py b/tests/inference/test_inference_api.py new file mode 100644 index 0000000..21c5547 --- /dev/null +++ b/tests/inference/test_inference_api.py @@ -0,0 +1,356 @@ +"""Tests for the async inference_api implementation.""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import llmsql.inference.inference_api as api_mod +from llmsql.inference.inference_api import _AsyncRateLimiter, inference_api + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write_jsonl(path: Path, rows: list[dict]) -> None: + path.write_text("\n".join(json.dumps(r) for r in rows) + "\n") + + +def _make_fixtures(tmp_path: Path) -> tuple[Path, Path, Path]: + """Return (questions_path, tables_path, out_path) pre-populated with minimal data.""" + questions = [ + {"question_id": "q1", "question": "What is 1+1?", "table_id": "t1"}, + {"question_id": "q2", "question": "What is 2+2?", "table_id": "t1"}, + ] + tables = [{"table_id": "t1", "header": ["col"], "types": ["text"], "rows": [["x"]]}] + qpath = tmp_path / "questions.jsonl" + tpath = tmp_path / "tables.jsonl" + outpath = tmp_path / "out.jsonl" + _write_jsonl(qpath, questions) + _write_jsonl(tpath, tables) + return qpath, tpath, outpath + + +def _fake_http_response(content: str = "SELECT 1") -> MagicMock: + """Build a mock that looks like a successful aiohttp response.""" + mock_resp = AsyncMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json = AsyncMock( + return_value={"choices": [{"message": {"content": content}}]} + ) + # Support async context-manager usage: `async with session.post(...) as resp` + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + return mock_ctx + + +# --------------------------------------------------------------------------- +# _AsyncRateLimiter unit tests +# --------------------------------------------------------------------------- + + +class TestAsyncRateLimiter: + def test_raises_on_non_positive_rpm(self): + with pytest.raises(ValueError, match="requests_per_minute must be > 0"): + _AsyncRateLimiter(0) + with pytest.raises(ValueError, match="requests_per_minute must be > 0"): + _AsyncRateLimiter(-10) + + def test_no_limit_returns_immediately(self): + """acquire() with rpm=None should not sleep at all.""" + limiter = _AsyncRateLimiter(None) + + async def _run(): + t0 = time.monotonic() + await limiter.acquire() + await limiter.acquire() + await limiter.acquire() + return time.monotonic() - t0 + + elapsed = asyncio.run(_run()) + assert elapsed < 0.05 # well under any sleep threshold + + def test_slots_are_spaced_correctly(self): + """With 60 RPM the limiter should space 3 calls ~1 s apart (2 s total).""" + limiter = _AsyncRateLimiter(60) # 1 request per second + + timestamps: list[float] = [] + + async def _run(): + for _ in range(3): + await limiter.acquire() + timestamps.append(time.monotonic()) + + asyncio.run(_run()) + gaps = [timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)] + for gap in gaps: + assert gap == pytest.approx(1.0, abs=0.15), f"gap {gap:.3f}s not ~1 s" + + def test_concurrent_slots_are_serialised(self): + """ + Multiple concurrent coroutines must each get a distinct slot β€” + _next_allowed must advance monotonically even under concurrency. + """ + limiter = _AsyncRateLimiter(600) # 0.1 s interval β€” fast enough for tests + TASKS = 5 + slots: list[float] = [] + + async def _worker(): + await limiter.acquire() + slots.append(time.monotonic()) + + async def _run(): + await asyncio.gather(*[_worker() for _ in range(TASKS)]) + + asyncio.run(_run()) + slots.sort() + gaps = [slots[i + 1] - slots[i] for i in range(len(slots) - 1)] + for gap in gaps: + # Each gap should be β‰₯ interval minus a small scheduling tolerance. + assert gap >= 0.08, f"slots overlap too closely: gap={gap:.4f}s" + + +# --------------------------------------------------------------------------- +# _post_chat_completion_async unit test +# --------------------------------------------------------------------------- + + +class TestPostChatCompletionAsync: + def test_returns_parsed_response(self): + mock_ctx = _fake_http_response("SELECT 42") + + async def _run(): + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=mock_ctx) + return await api_mod._post_chat_completion_async( + session=mock_session, + base_url="http://fake/v1", + endpoint="chat/completions", + payload={"model": "x", "messages": []}, + timeout=5.0, + ) + + result = asyncio.run(_run()) + assert result["choices"][0]["message"]["content"] == "SELECT 42" + + def test_raises_on_missing_choices(self): + mock_resp = AsyncMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json = AsyncMock(return_value={"error": "oops"}) + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + async def _run(): + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=mock_ctx) + return await api_mod._post_chat_completion_async( + session=mock_session, + base_url="http://fake/v1", + endpoint="chat/completions", + payload={}, + timeout=5.0, + ) + + with pytest.raises(ValueError, match="does not contain `choices`"): + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# inference_api integration-style tests (HTTP mocked) +# --------------------------------------------------------------------------- + + +class TestInferenceApi: + """Patch aiohttp.ClientSession so no real HTTP calls are made.""" + + def _patch_session(self, monkeypatch, content: str = "SELECT 1"): + """Replace aiohttp.ClientSession with a fully async-compatible mock.""" + mock_ctx = _fake_http_response(content) + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=mock_ctx) + + session_ctx = AsyncMock() + session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + + monkeypatch.setattr( + api_mod.aiohttp, "ClientSession", MagicMock(return_value=session_ctx) + ) + + def test_returns_results_for_all_questions(self, monkeypatch, tmp_path): + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + results = inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + ) + + assert len(results) == 2 + assert all("question_id" in r and "completion" in r for r in results) + + def test_output_file_written_correctly(self, monkeypatch, tmp_path): + self._patch_session(monkeypatch, content="SELECT 99") + qpath, tpath, outpath = _make_fixtures(tmp_path) + + inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + ) + + lines = outpath.read_text().strip().splitlines() + assert len(lines) == 2 + for line in lines: + row = json.loads(line) + assert row["completion"] == "SELECT 99" + assert "question_id" in row + + def test_limit_integer(self, monkeypatch, tmp_path): + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + results = inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + limit=1, + ) + + assert len(results) == 1 + + def test_limit_float(self, monkeypatch, tmp_path): + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + results = inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + limit=0.5, # 50% of 2 questions β†’ 1 + ) + + assert len(results) == 1 + + def test_invalid_limit_raises(self, monkeypatch, tmp_path): + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + with pytest.raises(ValueError): + inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + limit=1.5, # float out of (0, 1] + ) + + def test_negative_limit_raises(self, monkeypatch, tmp_path): + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + with pytest.raises(ValueError): + inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + limit=-10, # float out of (0, 1] + ) + + def test_api_key_set_in_header(self, monkeypatch, tmp_path): + """Authorization header must be forwarded to the session.""" + captured_headers: dict = {} + + def fake_client_session(headers=None, **_): + captured_headers.update(headers or {}) + mock_ctx = _fake_http_response() + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=mock_ctx) + session_ctx = AsyncMock() + session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + return session_ctx + + monkeypatch.setattr(api_mod.aiohttp, "ClientSession", fake_client_session) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + api_key="sk-test-key", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + ) + + assert captured_headers.get("Authorization") == "Bearer sk-test-key" + + def test_no_rate_limit_completes(self, monkeypatch, tmp_path): + """rpm=None should not raise and should still return all results.""" + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + results = inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + requests_per_minute=None, + ) + + assert len(results) == 2 + + def test_notebook_compat_with_running_loop(self, monkeypatch, tmp_path): + """ + Simulate a Jupyter environment: inference_api is called from inside a + running event loop. nest_asyncio should be applied and results returned. + """ + self._patch_session(monkeypatch) + qpath, tpath, outpath = _make_fixtures(tmp_path) + + apply_called = {"n": 0} + real_apply = api_mod.nest_asyncio.apply + + def spy_apply(loop=None): + apply_called["n"] += 1 + real_apply(loop) + + monkeypatch.setattr(api_mod.nest_asyncio, "apply", spy_apply) + + async def _run_inside_loop(): + return inference_api( + model_name="dummy", + base_url="http://localhost:9999/v1", + output_file=str(outpath), + questions_path=str(qpath), + tables_path=str(tpath), + ) + + import nest_asyncio + + nest_asyncio.apply() # allow the outer asyncio.run below to nest + results = asyncio.run(_run_inside_loop()) + + assert apply_called["n"] >= 1, "nest_asyncio.apply() should have been called" + assert len(results) == 2