diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..db260f3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,37 @@ +name: CI + +on: + push: + branches: + - "*" + pull_request: + branches: + - "*" + +jobs: + lint-and-test: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install ruff + + - name: Syntax check (compileall) + run: python scripts/check_syntax.py + + - name: Ruff lint + run: ruff check . + + - name: Run tests + run: pytest -q \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8051a04..277a7e3 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ venv/ data/ outputs/ logs/ +reports/ # OS / editor files .DS_Store diff --git a/README.md b/README.md index ff5f7c7..3a7e480 100644 --- a/README.md +++ b/README.md @@ -151,34 +151,112 @@ QLoRA/LoRA configuration, and troubleshooting (OOM, sequence length, etc.). --- -## Evaluation (placeholder) +## Evaluation -> Evaluation scripts and methodology will be documented here later. +The repository includes a lightweight but robust evaluation pipeline for +text-to-SQL: -Planned content for this section: +- **Internal evaluation** on the processed `b-mc2/sql-create-context` val set. +- **Secondary external validation** on the Spider dev split. + +See [`docs/evaluation.md`](./docs/evaluation.md) for full details. Below are the +most common commands. + +### Internal Evaluation (b-mc2/sql-create-context val) + +Mock mode (no model required, exercises metrics/reporting): + +```bash +python scripts/evaluate_internal.py --mock \ + --val_path data/processed/val.jsonl \ + --out_dir reports/ +``` -- How to run evaluation on: - - WikiSQL test split. - - (Optional) Spider dev set. 
-- Metrics: - - Logical form accuracy (exact SQL match). - - Execution accuracy (matching query results). - - Latency benchmarks (p50/p95). -- How to generate evaluation reports under `docs/` or `outputs/`. +With a trained QLoRA adapter (GPU recommended): + +```bash +python scripts/evaluate_internal.py \ + --val_path data/processed/val.jsonl \ + --base_model mistralai/Mistral-7B-Instruct-v0.1 \ + --adapter_dir /path/to/outputs/adapters \ + --device auto \ + --max_examples 200 \ + --out_dir reports/ +``` + +Outputs: + +- `reports/eval_internal.json` – metrics, config, and sample predictions. +- `reports/eval_internal.md` – human-readable summary. + +### External Validation (Spider dev) + +Mock mode (offline fixtures only, no internet required): + +```bash +python scripts/evaluate_spider_external.py --mock \ + --out_dir reports/ +``` + +Full Spider dev evaluation with a trained model (requires internet + HF Datasets): + +```bash +python scripts/evaluate_spider_external.py \ + --base_model mistralai/Mistral-7B-Instruct-v0.1 \ + --adapter_dir /path/to/outputs/adapters \ + --device auto \ + --spider_source xlangai/spider \ + --schema_source richardr1126/spider-schema \ + --spider_split validation \ + --max_examples 200 \ + --out_dir reports/ +``` + +Outputs: + +- `reports/eval_spider.json` – metrics, config, and sample predictions. +- `reports/eval_spider.md` – human-readable summary, including notes on + differences from official Spider evaluation. --- -## External Validation (Spider dev) – planned +## External Validation (Spider dev) + +After training on `b-mc2/sql-create-context`, we run a secondary evaluation +harness on the **Spider** dev set (e.g., `xlangai/spider`) to measure +generalization to harder, multi-table, cross-domain text-to-SQL tasks. + +Spider evaluation uses a **lightweight EM-style** metric suite: + +- Exact Match and No-values Exact Match on normalized SQL. +- SQL parse success using `sqlglot`. 
+- Schema adherence checks against serialized schemas from + `richardr1126/spider-schema` (licensed under **CC BY-SA 4.0**). -After training on `b-mc2/sql-create-context`, we plan to add a secondary -evaluation harness on the **Spider** dev set (e.g., `xlangai/spider`) to -measure generalization to harder, multi-table, cross-domain text-to-SQL tasks. +Spider and its schema helper are used **only for evaluation**, not for +training. + +For details, see [`docs/external_validation.md`](./docs/external_validation.md). -For the high-level plan, see [`docs/external_validation.md`](./docs/external_validation.md). -_code --new- The Streamlit UI will be documented here when implemented. @@ -200,24 +278,55 @@ Current high-level layout: ```text . -├── app/ # Streamlit app (to be implemented) -├── docs/ # Documentation, design notes, evaluation reports -├── notebooks/ # Jupyter/Colab notebooks for experimentation -├── scripts/ # CLI scripts (e.g., dataset loading, training, eval) -│ └── smoke_load_dataset.py +├── app/ # Streamlit app (to be implemented) +├── docs/ # Documentation, design notes, evaluation reports +│ ├── dataset.md +│ ├── training.md +│ ├── evaluation.md +│ └── external_validation.md +├── notebooks/ # Jupyter/Colab notebooks for experimentation +├── scripts/ # CLI scripts (dataset, training, evaluation) +│ ├── build_dataset.py +│ ├── smoke_load_dataset.py +│ ├── train_qlora.py +│ ├── evaluate_internal.py +│ └── evaluate_spider_external.py ├── src/ -│ └── text2sql/ # Core Python package +│ └── text2sql/ # Core Python package │ ├── __init__.py -│ └── utils/ # Utility modules (to be implemented) -│ └── __init__.py +│ ├── data_prep.py +│ ├── infer.py +│ ├── training/ +│ │ ├── __init__.py +│ │ ├── config.py +│ │ └── formatting.py +│ └── eval/ +│ ├── __init__.py +│ ├── normalize.py +│ ├── schema.py +│ ├── metrics.py +│ └── spider.py ├── tests/ -│ └── test_repo_smoke.py # Basic smoke test (imports the package) -├── .env.example # Example environment file +│ ├── 
fixtures/ +│ │ ├── sql_create_context_sample.jsonl +│ │ ├── eval_internal_sample.jsonl +│ │ ├── spider_sample.jsonl +│ │ └── spider_schema_sample.jsonl +│ ├── test_repo_smoke.py +│ ├── test_build_dataset_offline.py +│ ├── test_data_prep.py +│ ├── test_prompt_formatting.py +│ ├── test_normalize_sql.py +│ ├── test_schema_adherence.py +│ ├── test_metrics_aggregate.py +│ └── test_prompt_building_spider.py +├── .env.example # Example environment file ├── .gitignore -├── context.md # Persistent project context & decisions +├── context.md # Persistent project context & decisions ├── LICENSE ├── README.md └── requirements.txt ``` -As the project progresses, this structure will be refined and additional modules, scripts, and documentation will be added. \ No newline at end of file +As the project progresses, this structure will be refined and additional modules, +scripts, and documentation will be added. \ No newline at end of file diff --git a/context.md b/context.md index 6d1e3bd..a309d00 100644 --- a/context.md +++ b/context.md @@ -119,6 +119,7 @@ This repo will contain: - **2026-01-10** – Decided to create our own deterministic validation split (default 8% of the data, seed=42) from the single `train` split shipped with `b-mc2/sql-create-context`, to enable reproducible model selection and early-stopping. - **2026-01-10** – Selected **`mistralai/Mistral-7B-Instruct-v0.1`** as the base model for fine-tuning, using **QLoRA (4-bit) + LoRA adapters** implemented via **Unsloth + bitsandbytes** for efficient training on a single GPU. - **2026-01-10** – Planned a **secondary external validation** step on **Spider dev** (e.g., `xlangai/spider`) after primary training on `b-mc2/sql-create-context`, to measure cross-domain, multi-table generalization. 
+- **2026-01-10** – Implemented a dedicated evaluation pipeline (internal + Spider dev) using normalized SQL metrics, schema adherence checks, and lightweight external validation based on `xlangai/spider` and `richardr1126/spider-schema` (Spider used only for evaluation, not training). --- @@ -132,4 +133,5 @@ This repo will contain: - Added basic pytest smoke test to verify that the `text2sql` package imports successfully. - **2026-01-10** – Updated dataset plan and smoke loader to use the parquet-backed **`b-mc2/sql-create-context`** dataset (compatible with `datasets>=4`) instead of the script-based `Salesforce/wikisql`, and documented this decision in the project context. - **2026-01-10** – Added a dataset preprocessing pipeline (`scripts/build_dataset.py`) that converts `b-mc2/sql-create-context` into Alpaca-style instruction-tuning JSONL files under `data/processed/` (train/val splits), along with reusable formatting utilities in `text2sql.data_prep`. -- **2026-01-10** – Added QLoRA training scaffolding: a detailed Colab-friendly notebook (`notebooks/finetune_mistral7b_qlora_text2sql.ipynb`), a reproducible training script (`scripts/train_qlora.py`), training utilities under `src/text2sql/training/`, and documentation for training (`docs/training.md`) plus planned external validation on Spider dev (`docs/external_validation.md`).LoRA training scaffolding: a detailed Colab-friendly notebook (`notebooks/finetune_mistral7b_qlora_text2sql.ipynb`), a reproducible training script (`scripts/train_qlora.py`), training utilities under `src`. \ No newline at end of file +- **2026-01-10** – Added QLoRA training scaffolding: a detailed Colab-friendly notebook (`notebooks/finetune_mistral7b_qlora_text2sql.ipynb`), a reproducible training script (`scripts/train_qlora.py`), training utilities under `src/text2sql/training/`, and documentation for training (`docs/training.md`) plus planned external validation on Spider dev (`docs/external_validation.md`). 
+- **2026-01-10** – Task 4: Added an evaluation pipeline with internal metrics on `b-mc2/sql-create-context` (Exact Match, No-values EM, SQL parse success, schema adherence) and a secondary external validation harness on Spider dev using `xlangai/spider` and `richardr1126/spider-schema`, along with reports under `reports/` and supporting documentation (`docs/evaluation.md`). \ No newline at end of file diff --git a/docs/evaluation.md b/docs/evaluation.md new file mode 100644 index 0000000..5e70dd5 --- /dev/null +++ b/docs/evaluation.md @@ -0,0 +1,330 @@ +# Evaluation for Analytics Copilot (Text-to-SQL) + +This document describes the evaluation pipeline for the Analytics Copilot +(Text-to-SQL) model, including: + +- **Internal evaluation** on the preprocessed `b-mc2/sql-create-context` val set. +- **Secondary external validation** on the Spider dev set using lightweight, + portfolio-friendly metrics. + +The goal is to provide reproducible, scriptable evaluation that can run both +in local development environments (including mock/offline modes) and in +GPU-backed Colab sessions with trained adapters. + +--- + +## 1. Internal Evaluation (b-mc2/sql-create-context val) + +### 1.1 Dataset + +Internal evaluation uses the Alpaca-style validation file produced by the +dataset builder: + +- `data/processed/val.jsonl` + +Each line is a JSON object with at least: + +- `instruction` +- `input` – formatted schema + question, e.g.: + + ```text + ### Schema: + + + ### Question: + + ``` + +- `output` – normalized gold SQL query. + +### 1.2 Metrics + +The internal evaluation script computes: + +- **Exact Match (EM)** – comparison on *normalized* SQL: + - Strips leading/trailing whitespace. + - Removes trailing semicolons. + - Collapses runs of whitespace into a single space. + - Implemented via `text2sql.eval.normalize.normalize_sql`. 
+ +- **No-values Exact Match** + - Builds on `normalize_sql` and additionally replaces: + - Single-quoted string literals with a placeholder (`'__STR__'`). + - Numeric literals (integers/decimals, optionally negative) with a + placeholder (`__NUM__`). + - Useful to detect structural matches even when literal values differ. + +- **SQL parse success rate** + - Fraction of predictions that can be parsed by `sqlglot.parse_one`. + - Provides a lightweight proxy for syntactic validity of generated SQL. + +- **Schema adherence rate** + - Uses the `CREATE TABLE` context from each example and parses it with + `sqlglot` to recover: + - Known tables. + - Known columns per table. + - Parses the predicted SQL and extracts referenced table and column names. + - A prediction is schema-adherent if **all** referenced tables/columns + appear in the context. + - Implemented via: + - `text2sql.eval.schema.parse_create_table_context` + - `text2sql.eval.schema.referenced_identifiers` + - `text2sql.eval.schema.schema_adherence` + +All metrics are aggregated via: + +- `text2sql.eval.metrics.aggregate_metrics` + +which returns: + +- `n_examples` +- `exact_match` – `{count, rate}` +- `no_values_em` – `{count, rate}` +- `parse_success` – `{count, rate}` +- `schema_adherence` – `{count, rate}` + +### 1.3 How to Run Internal Evaluation + +#### 1.3.1 Mock Mode (no model required) + +Mock mode is designed for quick local checks and CI: + +```bash +python scripts/evaluate_internal.py --mock \ + --val_path data/processed/val.jsonl \ + --out_dir reports/ +``` + +Behavior: + +- Uses the gold SQL (`output`) as the prediction. +- Exercises normalization, parsing, schema adherence, and reporting code. 
+- Produces: + + - `reports/eval_internal.json` + - `reports/eval_internal.md` + +#### 1.3.2 Real Evaluation with Adapters (GPU recommended) + +After fine-tuning with QLoRA (see `docs/training.md`), you can evaluate the +model using the trained adapters: + +```bash +python scripts/evaluate_internal.py \ + --val_path data/processed/val.jsonl \ + --base_model mistralai/Mistral-7B-Instruct-v0.1 \ + --adapter_dir /path/to/outputs/adapters \ + --device auto \ + --max_examples 200 \ + --temperature 0.0 \ + --top_p 0.9 \ + --max_new_tokens 256 \ + --out_dir reports/ +``` + +Notes: + +- `--device auto` prefers GPU when available and falls back to CPU otherwise + (with a warning). +- By default, when running on CUDA the inference loader will try to load the + base model in **4-bit (bitsandbytes)** for faster and more memory-efficient + evaluation. You can explicitly control this with: + - `--load_in_4bit` / `--no_load_in_4bit` + - `--dtype` (default `auto`, which maps to `float16` on CUDA and `float32` on CPU) +- `--max_examples` allows you to subsample the validation set for quick runs. +- If you have a **merged model directory**, you can pass it as `--base_model` + and omit `--adapter_dir`. + +--- + +## 2. External Validation on Spider Dev + +### 2.1 Datasets and Licensing + +External validation uses two Hugging Face datasets: + +1. **Spider examples** + + - Dataset: `xlangai/spider` + - Split: `validation` (configured via `--spider_split`) + - Provides: + - `db_id` + - `question` + - `query` (gold SQL) + +2. **Spider schema helper** + + - Dataset: `richardr1126/spider-schema` + - Provides: + - `db_id` + - `create_table_context` – a serialized schema context with `CREATE TABLE` + information for all tables in the database. + +> **License:** `xlangai/spider` is derived from the original Spider benchmark, +> and `richardr1126/spider-schema` is licensed under **CC BY-SA 4.0**. In this +> project, Spider is used **only for evaluation**, **not** for training. 
+ +### 2.2 Prompt Construction + +For each Spider example: + +1. Look up `db_id` in the schema helper dataset to retrieve + `create_table_context`. +2. Build the schema + question input using the **same** format as internal + evaluation: + + ```text + ### Schema: + + + ### Question: + + ``` + +3. Use the same instruction text as training: + + > "Write a SQL query that answers the user's question using ONLY the tables + > and columns provided in the schema." + +4. Wrap instruction + input into a full prompt using the training formatter: + + - Implemented in `text2sql.eval.spider.build_spider_prompt`, which internally + reuses: + - `text2sql.data_prep.INSTRUCTION_TEXT` + - `text2sql.data_prep.build_input_text` + - `text2sql.training.formatting.build_prompt` + +### 2.3 Metrics + +Spider evaluation uses the **same metric suite** as internal evaluation: + +- **Exact Match (normalized SQL)** +- **No-values Exact Match** +- **SQL parse success rate** +- **Schema adherence rate** + +This provides a **lightweight generalization check** on Spider dev, but it is +**not a full reproduction** of official Spider evaluation. In particular: + +- Official Spider metrics include detailed component matching (SELECT, WHERE, + GROUP BY, etc.). +- Execution-based evaluation is often used to measure semantic equivalence via + query results. + +Here we focus on structural/logical-form approximations that are easy to run +without database execution, suitable for a portfolio-style baseline. + +### 2.4 How to Run Spider Evaluation + +#### 2.4.1 Mock Mode (offline, fixtures only) + +Mock mode uses small offline fixtures under `tests/fixtures/` and **does not +require internet**: + +```bash +python scripts/evaluate_spider_external.py --mock \ + --out_dir reports/ +``` + +Behavior: + +- Loads: + - `tests/fixtures/spider_sample.jsonl` + - `tests/fixtures/spider_schema_sample.jsonl` +- Uses gold SQL as predictions. 
+- Produces: + + - `reports/eval_spider.json` + - `reports/eval_spider.md` + +This is ideal for local smoke tests of the Spider pipeline. + +#### 2.4.2 Real Evaluation with Adapters (GPU recommended) + +With network access and a trained model, you can run full Spider dev evaluation: + +```bash +python scripts/evaluate_spider_external.py \ + --base_model mistralai/Mistral-7B-Instruct-v0.1 \ + --adapter_dir /path/to/outputs/adapters \ + --device auto \ + --spider_source xlangai/spider \ + --schema_source richardr1126/spider-schema \ + --spider_split validation \ + --max_examples 200 \ + --temperature 0.0 \ + --top_p 0.9 \ + --max_new_tokens 256 \ + --out_dir reports/ +``` + +Notes: + +- By default, when running on CUDA the inference loader will try to load the + base model in **4-bit (bitsandbytes)** for faster and more memory-efficient + evaluation. You can explicitly control this with: + - `--load_in_4bit` / `--no_load_in_4bit` + - `--dtype` (default `auto`, which maps to `float16` on CUDA and `float32` on CPU) +- `--max_examples` allows a lighter-weight subset run (e.g., 50–200 examples). +- When `--mock` is not set, the script downloads datasets via + `datasets.load_dataset`, so internet access is required. + +--- + +## 3. Inference Wrapper + +Both evaluation scripts rely on a shared inference helper: + +- `src/text2sql/infer.py` + +Key functions: + +- `load_model_for_inference(base_model, adapter_dir=None, device='auto', load_in_4bit=None, bnb_compute_dtype='float16', dtype='auto')` + - Loads a base HF model or local directory. + - Optionally applies LoRA adapters from `adapter_dir`. + - Resolves device via: + - `"auto"` → GPU if available, otherwise CPU (with a warning). + - `"cuda"` / `"cpu"` for explicit control. + - When running on CUDA and `load_in_4bit` is not explicitly set, the loader + defaults to 4-bit (NF4) quantization using bitsandbytes. This significantly + reduces memory usage and speeds up evaluation on Colab-style GPUs. 
+ +- `generate_sql(prompt, max_new_tokens, temperature, top_p) -> str` + - Uses the loaded model/tokenizer to generate text. + - Evaluation scripts post-process the raw text via + `text2sql.training.formatting.ensure_sql_only` before metric computation. + +This separation keeps the evaluation scripts thin and allows reuse of the +inference pipeline in other tools (e.g., a Streamlit demo or interactive +notebooks). + +--- + +## 4. Local Testing Strategy (No Internet Required) + +To keep the test suite lightweight and offline-friendly: + +- Fixtures under `tests/fixtures/` provide small synthetic datasets: + - `eval_internal_sample.jsonl` – mini val-style examples. + - `spider_sample.jsonl` and `spider_schema_sample.jsonl` – Spider-like + examples and schemas. +- Unit tests cover: + - SQL normalization (`test_normalize_sql.py`). + - Schema parsing and adherence (`test_schema_adherence.py`). + - Metric aggregation (`test_metrics_aggregate.py`). + - Spider prompt construction (`test_prompt_building_spider.py`). + +CI or local developers can run: + +```bash +pytest -q +``` + +without requiring internet access or GPU hardware. For full model-based +evaluation, see the commands in sections 1.3.2 and 2.4.2 above. + +If you see TensorFlow CUDA warnings in Colab logs (e.g. about missing +`libcudart`), they can generally be ignored for this project. The evaluation +scripts also set `TF_CPP_MIN_LOG_LEVEL=3` to suppress most TensorFlow log +noise; you can optionally uninstall TensorFlow entirely if you are not using +it elsewhere in your notebook. 
\ No newline at end of file diff --git a/docs/external_validation.md b/docs/external_validation.md index 046014d..da33297 100644 --- a/docs/external_validation.md +++ b/docs/external_validation.md @@ -1,11 +1,10 @@ -# External Validation on Spider Dev (Planned) +# External Validation on Spider Dev -This document outlines the planned **secondary external validation** workflow +This document describes the **secondary external validation** workflow for the Analytics Copilot (Text-to-SQL) model using the **Spider** dataset. -The actual implementation will be tackled in a later task (Task 4). This file -serves as scaffolding so that the evaluation plan is embedded in the project -narrative from the beginning. +The implementation of this evaluation pipeline is completed as part of +**Task 4** and wired into the repository via dedicated scripts and tests. --- @@ -35,15 +34,20 @@ beyond the training distribution**. --- -## 2. Planned Dataset Source +## 2. Dataset Sources and Licensing -We plan to use a Hugging Face-hosted Spider variant, such as: +We use Hugging Face-hosted Spider variants: -- `xlangai/spider` (dev split) +- Examples: `xlangai/spider` (dev/validation split) +- Schema helper: `richardr1126/spider-schema` This choice keeps the evaluation flow consistent with the rest of the project, which already relies on Hugging Face Datasets for loading and caching. +> **License:** The `richardr1126/spider-schema` dataset is distributed under +> **CC BY-SA 4.0**. In this project, Spider and its schema helper are used +> **only for evaluation**, not for training. + --- ## 3. High-Level Evaluation Plan @@ -51,67 +55,76 @@ which already relies on Hugging Face Datasets for loading and caching. The high-level steps for external validation on Spider dev are: 1. **Load Spider dev** - - Use `datasets.load_dataset("xlangai/spider", split="validation")` or the - equivalent dev split. - - Inspect fields: questions, database ids, schemas, and SQL queries. 
+ - Use `datasets.load_dataset("xlangai/spider", split="validation")` (or a + compatible dev split). + - Keep only rows with `db_id`, `question`, and `query` populated. 2. **Schema Serialization** - - Build a schema-serialization strategy that converts Spider’s multi-table - schemas into a textual context suitable for the model. - - Likely format (TBD): + - Load the schema helper dataset `richardr1126/spider-schema` and build a + mapping `{db_id -> create_table_context}` using + `text2sql.eval.spider.build_schema_map`. + - For each Spider example, retrieve `create_table_context` by `db_id` and + treat it as a textual schema context. + +3. **Prompt Construction** + - For each example, construct the input section: + ```text ### Schema: - CREATE TABLE table1 (...); - CREATE TABLE table2 (...); - ... + ### Question: ``` - - This should align with the prompt style used during training on - `b-mc2/sql-create-context`. -3. **Prompt Construction** - - Reuse or extend the same formatting utilities used for training: - - `build_prompt(...)`-style function for instruction + input. - - Ensure that the model receives a consistent prompt structure across - training and evaluation. + - Use the same instruction text as internal training/evaluation: + + > Write a SQL query that answers the user's question using ONLY the tables + > and columns provided in the schema. + + - Wrap instruction + input into a full prompt using + `text2sql.eval.spider.build_spider_prompt`, which internally reuses the + training formatter. 4. **Model Inference** - - Load the fine-tuned Mistral-7B-Instruct model with the QLoRA adapters. - - Run the model on Spider dev questions with the appropriate schema - context. - - Decode generated SQL queries. + - Load the fine-tuned Mistral-7B-Instruct model with QLoRA adapters (or a + merged model) via `text2sql.infer.load_model_for_inference`. + - Generate SQL for each prompt using `text2sql.infer.generate_sql`. 
+ - Post-process generated text into clean SQL with + `text2sql.training.formatting.ensure_sql_only`. 5. **Metrics** - - We plan to report (at minimum): - - **Logical form accuracy** (exact match between generated and gold SQL). - - **Execution accuracy** (whether executing the generated SQL matches - the gold answer). - - Where possible, reuse or adapt existing Spider evaluation scripts to - ensure comparability with prior work. + - Compute lightweight logical-form metrics using + `text2sql.eval.metrics.aggregate_metrics`: + - **Exact Match (normalized SQL)**. + - **No-values Exact Match** (string and numeric literals replaced). + - **SQL parse success rate** using `sqlglot`. + - **Schema adherence** (references confined to the serialized schema). + - These are intentionally lightweight and do **not** attempt to reproduce + the full official Spider evaluation protocol. 6. **Reporting** - Summarize: - Overall metrics. - - Per-database performance. - - Qualitative examples (successes and failures). - - Integrate key results into `docs/` and the main `README`. + - Representative examples (successes and failures). + - Write machine-readable JSON and human-readable Markdown reports under + `reports/` (see `docs/evaluation.md` for details). --- -## 4. Implementation Notes (for Task 4) +## 4. Implementation Notes (Task 4) -When we implement this evaluation pipeline (Task 4), we expect to: +Task 4 implemented this evaluation pipeline with the following components: -- Add a dedicated evaluation script under `scripts/` - (e.g., `scripts/eval_spider.py`). -- Add utility functions under `src/text2sql/` for: - - Schema serialization specific to Spider. - - Prompt construction for multi-table schemas. - - Metric computation (exact match, execution accuracy). -- Add tests that: - - Use small Spider-like fixtures to validate serialization and metrics. 
+- A dedicated evaluation script under `scripts/`: + - `scripts/evaluate_spider_external.py` +- Utility functions under `src/text2sql/eval/`: + - `spider.py` for schema mapping and prompt construction. + - `normalize.py`, `schema.py`, and `metrics.py` shared between internal and + external evaluation. +- Tests that: + - Use small Spider-like fixtures in `tests/fixtures/` to validate prompt + construction and metrics. - Do not require access to the full Spider dataset or a database engine. --- diff --git a/docs/training.md b/docs/training.md index cc71d31..15f3363 100644 --- a/docs/training.md +++ b/docs/training.md @@ -303,12 +303,15 @@ If you encounter CUDA OOM errors: --- -## 6. External Validation (Spider dev) – Planned +## 6. External Validation (Spider dev) -After primary training on `b-mc2/sql-create-context`, we plan to perform +After primary training on `b-mc2/sql-create-context`, we perform **secondary external validation** on the **Spider dev** split (e.g., `xlangai/spider`), which is significantly more challenging (multi-table, cross-domain text-to-SQL). -This will be implemented as a dedicated evaluation pipeline (Task 4). For the -high-level plan, see [`docs/external_validation.md`](./external_validation.md). \ No newline at end of file +This is implemented via the Spider evaluation pipeline (Task 4). 
For details on
+datasets, metrics, and how to run the scripts, see:
+
+- [`docs/external_validation.md`](./external_validation.md)
+- [`docs/evaluation.md`](./evaluation.md)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index e8061f2..ca8cfa0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,4 +18,7 @@ bitsandbytes
 
 # --- Tooling / testing ---
 pytest>=8.0.0
-huggingface-hub>=0.36.0
\ No newline at end of file
+huggingface-hub>=0.36.0
+
+# --- Evaluation helpers ---
+sqlglot>=28.5.0
\ No newline at end of file
diff --git a/scripts/build_dataset.py b/scripts/build_dataset.py
index 2840603..3abfb08 100644
--- a/scripts/build_dataset.py
+++ b/scripts/build_dataset.py
@@ -5,7 +5,7 @@
 from typing import Any, Dict, Optional
 import sys
 
-from datasets import Dataset, DatasetDict, load_dataset
+from typing import TYPE_CHECKING
 
 # Ensure the src/ directory is on sys.path so that `text2sql` can be imported
 # when this script is run directly via `python scripts/build_dataset.py`.
@@ -14,6 +14,9 @@
 if str(SRC_DIR) not in sys.path:
     sys.path.insert(0, str(SRC_DIR))
 
+if TYPE_CHECKING:  # pragma: no cover - import only for type checking
+    from datasets import Dataset, DatasetDict  # noqa: F401
+
 from text2sql.data_prep import format_record
 
 
@@ -90,9 +93,9 @@ def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
 def load_raw_dataset(
     input_jsonl: Optional[Path], max_rows: Optional[int]
-) -> Dataset:
+):
     """
-    Load the raw dataset, either from a local JSONL file or from Hugging Face.
+    Load the raw dataset, either from a local JSONL file or from Hugging Face.
 
     Parameters
     ----------
@@ -106,7 +109,8 @@
     Returns
     -------
     datasets.Dataset
-        A datasets.Dataset object with at least the keys: question, context, answer.
+        A datasets.Dataset-like object with at least the keys: question, context,
+        answer.
Raises ------ @@ -127,8 +131,15 @@ def load_raw_dataset( raise RuntimeError( f"No records found in local JSONL file: {input_jsonl}" ) + + # Import datasets lazily so that running with --input_jsonl does not + # require the 'datasets' package to be installed at module import time. + from datasets import Dataset # type: ignore[import] + ds = Dataset.from_list(records) else: + from datasets import DatasetDict, load_dataset # type: ignore[import] + logger.info("Loading dataset '%s' from Hugging Face Datasets...", DATASET_NAME) ds_dict: DatasetDict = load_dataset(DATASET_NAME) if "train" not in ds_dict: @@ -163,8 +174,8 @@ def load_raw_dataset( def split_dataset( - ds: Dataset, val_ratio: float, seed: int -) -> Dict[str, Dataset]: + ds, val_ratio: float, seed: int +) -> Dict[str, Any]: """ Split the dataset into train and validation sets deterministically. @@ -180,7 +191,9 @@ def split_dataset( Returns ------- dict - A dictionary with keys 'train' and 'val', each a datasets.Dataset. + A dictionary with keys 'train' and 'val', each a datasets.Dataset-like + obj_codeecnewt list[Path]: + files: list[Path] = [] + for root in root_dirs: + if not root.is_dir(): + continue + for path in root.rglob("*.py"): + # Skip compiled files and cache dirs just in case. + if "__pycache__" in path.parts: + continue + files.append(path) + return files + + +def main(argv: list[str] | None = None) -> int: + """ + Compile all Python files under src/, scripts/, and app/ to check syntax. + + This uses the Python stdlib `compileall` module and exits non-zero if any + file fails to compile. 
+ """ + project_root = Path(__file__).resolve().parents[1] + candidates = [ + project_root / "src", + project_root / "scripts", + project_root / "app", + ] + + py_files = _discover_python_files(candidates) + if not py_files: + print("No Python files found under src/, scripts/, or app/.") + return 0 + + failures: list[Path] = [] + for path in py_files: + ok = compileall.compile_file( + str(path), + ddir=str(path.parent), + quiet=1, + ) + if not ok: + failures.append(path) + + if failures: + print("Syntax check FAILED for the following files:") + for path in failures: + print(f" - {path}") + print(f"Total: {len(failures)} file(s) with syntax errors.") + return 1 + + print(f"Syntax OK: compiled {len(py_files)} files.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/scripts/evaluate_internal.py b/scripts/evaluate_internal.py new file mode 100644 index 0000000..dd2db78 --- /dev/null +++ b/scripts/evaluate_internal.py @@ -0,0 +1,464 @@ +import argparse +import json +import logging +import os +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +import sys + +# Reduce TensorFlow/CUDA log noise if TensorFlow is installed. 
+os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3") + +# Ensure the src/ directory is on sys.path so that `text2sql` can be imported +ROOT = Path(__file__).resolve().parents[1] +SRC_DIR = ROOT / "src" +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) + +from text2sql.eval.metrics import aggregate_metrics # noqa: E402 # isort: skip +from text2sql.eval.normalize import normalize_sql # noqa: E402 # isort: skip +from text2sql.training.formatting import ( # noqa: E402 # isort: skip + build_prompt, + ensure_sql_only, +) + + +logger = logging.getLogger(__name__) + + +@dataclass +class EvalConfig: + val_path: str + base_model: str + adapter_dir: Optional[str] + device: str + max_examples: int + out_dir: str + temperature: float + top_p: float + max_new_tokens: int + mock: bool + load_in_4bit: Optional[bool] = None + dtype: str = "auto" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def configure_logging() -> None: + """Configure basic logging for the evaluation script.""" + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] %(name)s - %(message)s", + ) + + +def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: + """Parse command-line arguments for internal evaluation.""" + parser = argparse.ArgumentParser( + description="Evaluate a text-to-SQL model on the internal validation set." + ) + + parser.add_argument( + "--val_path", + type=str, + default="data/processed/val.jsonl", + help="Path to the Alpaca-style validation JSONL file.", + ) + parser.add_argument( + "--base_model", + type=str, + default="mistralai/Mistral-7B-Instruct-v0.1", + help="Base model name or path for inference.", + ) + parser.add_argument( + "--adapter_dir", + type=str, + default=None, + help=( + "Path to LoRA adapter directory. If omitted, the script will run " + "with the base model only (or a merged model if base_model points " + "to a local directory)." 
+ ), + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="Device to use: 'auto', 'cuda', or 'cpu'.", + ) + parser.add_argument( + "--max_examples", + type=int, + default=200, + help="Maximum number of validation examples to evaluate.", + ) + parser.add_argument( + "--out_dir", + type=str, + default="reports", + help="Output directory for reports (JSON and Markdown).", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature for generation (0.0 for greedy).", + ) + parser.add_argument( + "--top_p", + type=float, + default=0.9, + help="Top-p sampling parameter for generation.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=256, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument( + "--load_in_4bit", + action="store_true", + default=None, + help=( + "If set, force loading the base model in 4-bit (bitsandbytes) for " + "faster and more memory-efficient inference. By default this is " + "enabled automatically when running on CUDA and disabled on CPU." + ), + ) + parser.add_argument( + "--no_load_in_4bit", + action="store_false", + dest="load_in_4bit", + help="Disable 4-bit loading even when running on CUDA.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float16", "bfloat16", "float32"], + help=( + "Model dtype for base weights. 'auto' selects float16 on CUDA and " + "float32 on CPU." + ), + ) + parser.add_argument( + "--mock", + action="store_true", + help=( + "Run in mock mode: skip model loading and use gold SQL as " + "predictions to validate the metric pipeline." 
+ ), + ) + + return parser.parse_args(argv) + + +def _load_alpaca_jsonl(path: Path) -> List[Dict[str, Any]]: + """Load an Alpaca-style JSONL file into a list of dicts.""" + if not path.is_file(): + raise FileNotFoundError(f"Validation JSONL file not found: {path}") + + records: List[Dict[str, Any]] = [] + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + records.append(json.loads(line)) + + if not records: + raise RuntimeError(f"No records found in validation file: {path}") + + return records + + +def _extract_schema_and_question(input_text: str) -> Tuple[str, str]: + """ + Extract schema context and question from the formatted input text. + + The expected structure is: + + ### Schema: + + + ### Question: + + """ + schema = "" + question = "" + + if not input_text: + return schema, question + + schema_marker = "### Schema:" + question_marker = "### Question:" + + text = input_text + + if schema_marker in text: + after_schema = text.split(schema_marker, 1)[1] + else: + after_schema = text + + if question_marker in after_schema: + schema_part, question_part = after_schema.split(question_marker, 1) + schema = schema_part.strip() + question = question_part.strip() + else: + schema = after_schema.strip() + question = "" + + return schema, question + + +def _write_json_report( + out_path: Path, + config: EvalConfig, + metrics: Dict[str, Any], + examples: List[Dict[str, Any]], +) -> None: + """Write a JSON report with metrics, config, and example predictions.""" + out_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "config": config.to_dict(), + "metrics": metrics, + "examples": examples, + "generated_at_utc": datetime.now(timezone.utc).isoformat(), + } + with out_path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, ensure_ascii=False) + + +def _write_markdown_report( + out_path: Path, + config: EvalConfig, + metrics: Dict[str, Any], + examples: List[Dict[str, Any]], +) -> None: 
+ """Write a human-readable Markdown evaluation report.""" + out_path.parent.mkdir(parents=True, exist_ok=True) + + n_examples = metrics.get("n_examples", 0) + em = metrics.get("exact_match", {}) + nvem = metrics.get("no_values_em", {}) + parse = metrics.get("parse_success", {}) + schema = metrics.get("schema_adherence", {}) + + def _fmt_rate(entry: Dict[str, Any]) -> str: + count = entry.get("count", 0) + rate = entry.get("rate", 0.0) + return f"{count}/{n_examples} ({rate:.3f})" + + lines: List[str] = [] + + lines.append("# Internal Evaluation – b-mc2/sql-create-context val\n") + lines.append("## Configuration\n") + lines.append(f"- **val_path:** `{config.val_path}`") + lines.append(f"- **base_model:** `{config.base_model}`") + lines.append(f"- **adapter_dir:** `{config.adapter_dir or 'None (base/merged model only)'}`") + lines.append(f"- **device:** `{config.device}`") + lines.append(f"- **max_examples:** `{config.max_examples}`") + lines.append(f"- **temperature:** `{config.temperature}`") + lines.append(f"- **top_p:** `{config.top_p}`") + lines.append(f"- **max_new_tokens:** `{config.max_new_tokens}`") + lines.append(f"- **mock:** `{config.mock}`") + lines.append(f"- **load_in_4bit:** `{config.load_in_4bit}`") + lines.append(f"- **dtype:** `{config.dtype}`") + lines.append(f"- **n_evaluated_examples:** `{n_examples}`\n") + + lines.append("## Metrics\n") + lines.append(f"- **Exact Match (normalized SQL):** {_fmt_rate(em)}") + lines.append(f"- **No-values Exact Match:** {_fmt_rate(nvem)}") + lines.append(f"- **SQL parse success rate:** {_fmt_rate(parse)}") + if schema: + lines.append(f"- **Schema adherence rate:** {_fmt_rate(schema)}") + lines.append("") + + lines.append("## Notes\n") + if config.mock: + lines.append( + "- This run used `--mock`, so predictions are set equal to gold SQL to " + "validate the evaluation pipeline. Metrics should be near 1.0 except " + "for parser robustness." 
+ ) + else: + lines.append( + "- Exact Match and No-values EM are computed on normalized SQL strings " + "(whitespace collapsed, trailing semicolons removed)." + ) + lines.append( + "- No-values EM further replaces string and numeric literals with " + "placeholders, focusing on query structure rather than concrete values." + ) + lines.append("") + + lines.append("## Example Predictions\n") + if not examples: + lines.append("_No examples available._") + else: + for idx, ex in enumerate(examples, start=1): + lines.append(f"### Example {idx}\n") + lines.append(f"- **id:** `{ex.get('id', '')}`") + lines.append(f"- **Question:** {ex.get('question', '').strip()}") + schema_snippet = ex.get("schema_snippet", "") + lines.append(f"- **Schema snippet:** `{schema_snippet}`") + lines.append(f"- **Gold SQL:** `{ex.get('gold_sql', '').strip()}`") + lines.append(f"- **Predicted SQL:** `{ex.get('pred_sql', '').strip()}`") + lines.append("") + + with out_path.open("w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + +def run_eval(args: argparse.Namespace) -> int: + """Execute the internal evaluation pipeline.""" + val_path = Path(args.val_path) + records = _load_alpaca_jsonl(val_path) + + if args.max_examples is not None and args.max_examples > 0: + records = records[: args.max_examples] + + logger.info( + "Loaded %d validation records from %s (max_examples=%d).", + len(records), + val_path, + args.max_examples, + ) + + eval_config = EvalConfig( + val_path=str(val_path), + base_model=args.base_model, + adapter_dir=args.adapter_dir, + device=args.device, + max_examples=args.max_examples, + out_dir=args.out_dir, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + mock=args.mock, + load_in_4bit=args.load_in_4bit, + dtype=args.dtype, + ) + + gold_sqls: List[str] = [] + pred_sqls: List[str] = [] + contexts: List[str] = [] + questions: List[str] = [] + example_ids: List[str] = [] + + for rec in records: + input_text = rec.get("input", "") + 
output_sql = rec.get("output", "") + rec_id = rec.get("id", "") + + schema, question = _extract_schema_and_question(input_text) + contexts.append(schema) + gold_sqls.append(output_sql) + questions.append(question) + example_ids.append(str(rec_id)) + + if args.mock: + logger.info("Running in --mock mode; using gold SQL as predictions.") + pred_sqls = list(gold_sqls) + else: + # Import inference helpers lazily so that --mock mode does not require + # heavy runtime dependencies like torch to be installed. + from text2sql.infer import ( # type: ignore[import] + generate_sql, + load_model_for_inference, + ) + + logger.info( + "Loading model for inference with base_model=%s, adapter_dir=%s, " + "device=%s, load_in_4bit=%s, dtype=%s", + args.base_model, + args.adapter_dir, + args.device, + args.load_in_4bit, + args.dtype, + ) + load_model_for_inference( + base_model=args.base_model, + adapter_dir=args.adapter_dir, + device=args.device, + load_in_4bit=args.load_in_4bit, + bnb_compute_dtype="float16", + dtype=args.dtype, + ) + + for idx, rec in enumerate(records): + instruction = rec.get("instruction", "") + input_text = rec.get("input", "") + prompt = build_prompt(instruction=instruction, input=input_text) + + logger.info("Generating prediction for example %d/%d", idx + 1, len(records)) + raw_output = generate_sql( + prompt=prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + ) + cleaned_sql = ensure_sql_only(raw_output) + pred_sqls.append(cleaned_sql) + + metrics = aggregate_metrics( + predictions=pred_sqls, + golds=gold_sqls, + contexts=contexts, + compute_schema_adherence=True, + ) + logger.info("Evaluation metrics: %s", metrics) + + # Build up to 10 example entries for the reports. 
+ example_entries: List[Dict[str, Any]] = [] + num_examples_to_show = min(10, len(records)) + for i in range(num_examples_to_show): + schema = contexts[i] + schema_snippet = schema.replace("\n", " ") + if len(schema_snippet) > 200: + schema_snippet = schema_snippet[:197] + "..." + + example_entries.append( + { + "id": example_ids[i], + "question": questions[i], + "schema_snippet": schema_snippet, + "gold_sql": normalize_sql(gold_sqls[i]), + "pred_sql": normalize_sql(pred_sqls[i]), + } + ) + + out_dir = Path(args.out_dir) + json_path = out_dir / "eval_internal.json" + md_path = out_dir / "eval_internal.md" + + _write_json_report(json_path, eval_config, metrics, example_entries) + _write_markdown_report(md_path, eval_config, metrics, example_entries) + + logger.info("Internal evaluation reports written to %s and %s", json_path, md_path) + return 0 + + +def main(argv: Optional[List[str]] = None) -> int: + configure_logging() + args = parse_args(argv) + + try: + return run_eval(args) + except FileNotFoundError as exc: + logger.error("File not found: %s", exc) + return 1 + except (RuntimeError, ValueError) as exc: + logger.error("Evaluation failed: %s", exc) + return 1 + except Exception: # noqa: BLE001 + logger.error("Unexpected error during evaluation.", exc_info=True) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/scripts/evaluate_spider_external.py b/scripts/evaluate_spider_external.py new file mode 100644 index 0000000..2ce6f28 --- /dev/null +++ b/scripts/evaluate_spider_external.py @@ -0,0 +1,596 @@ +import argparse +import json +import logging +import os +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +import sys + +# Reduce TensorFlow/CUDA log noise if TensorFlow is installed. 
+os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3") + +# Ensure the src/ directory is on sys.path so that `text2sql` can be imported +ROOT = Path(__file__).resolve().parents[1] +SRC_DIR = ROOT / "src" +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) + +from text2sql.eval.metrics import aggregate_metrics # noqa: E402 # isort: skip +from text2sql.eval.normalize import normalize_sql # noqa: E402 # isort: skip +from text2sql.eval.schema import parse_create_table_context # noqa: E402 # isort: skip +from text2sql.eval.spider import ( # noqa: E402 # isort: skip + build_schema_map, + build_spider_prompt, + load_spider_schema_map, + spider_schema_to_pseudo_ddl, +) +from text2sql.training.formatting import ensure_sql_only # noqa: E402 # isort: skip + + +logger = logging.getLogger(__name__) + + +@dataclass +class SpiderEvalConfig: + base_model: str + adapter_dir: Optional[str] + device: str + spider_split: str + spider_source: str + schema_source: str + max_examples: int + out_dir: str + temperature: float + top_p: float + max_new_tokens: int + mock: bool + load_in_4bit: Optional[bool] = None + dtype: str = "auto" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def configure_logging() -> None: + """Configure basic logging for the evaluation script.""" + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] %(name)s - %(message)s", + ) + + +def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: + """Parse command-line arguments for Spider external evaluation.""" + parser = argparse.ArgumentParser( + description="Evaluate a text-to-SQL model on the Spider dev set (external validation)." + ) + + parser.add_argument( + "--base_model", + type=str, + default="mistralai/Mistral-7B-Instruct-v0.1", + help="Base model name or path for inference.", + ) + parser.add_argument( + "--adapter_dir", + type=str, + default=None, + help=( + "Path to LoRA adapter directory. 
If omitted, the script will run " + "with the base model only (or a merged model if base_model points " + "to a local directory)." + ), + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="Device to use: 'auto', 'cuda', or 'cpu'.", + ) + parser.add_argument( + "--spider_split", + type=str, + default="validation", + help="Spider split to evaluate on (e.g., 'validation').", + ) + parser.add_argument( + "--max_examples", + type=int, + default=200, + help="Maximum number of Spider examples to evaluate.", + ) + parser.add_argument( + "--out_dir", + type=str, + default="reports", + help="Output directory for reports (JSON and Markdown).", + ) + parser.add_argument( + "--schema_source", + type=str, + default="richardr1126/spider-schema", + help="Hugging Face dataset id for Spider schemas.", + ) + parser.add_argument( + "--spider_source", + type=str, + default="xlangai/spider", + help="Hugging Face dataset id for Spider examples.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature for generation (0.0 for greedy).", + ) + parser.add_argument( + "--top_p", + type=float, + default=0.9, + help="Top-p sampling parameter for generation.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=256, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument( + "--load_in_4bit", + action="store_true", + default=None, + help=( + "If set, force loading the base model in 4-bit (bitsandbytes) for " + "faster and more memory-efficient inference. By default this is " + "enabled automatically when running on CUDA and disabled on CPU." + ), + ) + parser.add_argument( + "--no_load_in_4bit", + action="store_false", + dest="load_in_4bit", + help="Disable 4-bit loading even when running on CUDA.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float16", "bfloat16", "float32"], + help=( + "Model dtype for base weights. 
'auto' selects float16 on CUDA and " + "float32 on CPU." + ), + ) + parser.add_argument( + "--mock", + action="store_true", + help=( + "Run in mock mode: load small local Spider fixtures from tests/fixtures " + "and use gold SQL as predictions to validate prompt building and metrics." + ), + ) + + return parser.parse_args(argv) + + +def _load_jsonl(path: Path) -> List[Dict[str, Any]]: + """Load a JSONL file into a list of dicts.""" + if not path.is_file(): + raise FileNotFoundError(f"JSONL file not found: {path}") + rows: List[Dict[str, Any]] = [] + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if not rows: + raise RuntimeError(f"No records found in JSONL file: {path}") + return rows + + +def _build_examples_with_schema( + spider_rows: List[Dict[str, Any]], + schema_map: Dict[str, str], +) -> List[Dict[str, Any]]: + """ + Join Spider examples with their schemas and log intersection diagnostics. 
+ """ + spider_db_ids = {str(row["db_id"]) for row in spider_rows if "db_id" in row} + schema_db_ids = set(schema_map.keys()) + intersection = spider_db_ids & schema_db_ids + + logger.info("Total Spider examples: %d", len(spider_rows)) + logger.info("Unique Spider db_ids: %d", len(spider_db_ids)) + logger.info("Total schema db_ids: %d", len(schema_db_ids)) + logger.info("Intersection size (db_ids in both): %d", len(intersection)) + + if not intersection: + spider_samples = sorted(spider_db_ids)[:10] + schema_samples = sorted(schema_db_ids)[:10] + logger.error("No matching db_ids between Spider split and schema source.") + logger.error("Sample Spider db_ids: %s", spider_samples) + logger.error("Sample schema db_ids: %s", schema_samples) + raise RuntimeError( + "No matching db_ids between Spider split and schema source" + ) + + examples: List[Dict[str, Any]] = [] + skipped_due_to_missing_schema = 0 + + for row in spider_rows: + db_id = row.get("db_id") + question = row.get("question") + query = row.get("query") + if db_id is None or question is None or query is None: + continue + + db_id_str = str(db_id) + schema_text = schema_map.get(db_id_str) + if not schema_text: + skipped_due_to_missing_schema += 1 + continue + + examples.append( + { + "db_id": db_id_str, + "question": str(question), + "query": str(query), + "schema_text": str(schema_text), + } + ) + + logger.info( + "After joining with schema: evaluated_count=%d, " + "skipped_due_to_missing_schema=%d", + len(examples), + skipped_due_to_missing_schema, + ) + return examples + + +def _load_spider_from_fixtures() -> Tuple[List[Dict[str, Any]], Dict[str, str]]: + """ + Load Spider examples and schema mapping from local test fixtures. + + This is used when running in --mock mode so that tests and local runs do + not require internet access or full Spider downloads. 
+ """ + fixtures_dir = ROOT / "tests" / "fixtures" + spider_path = fixtures_dir / "spider_sample.jsonl" + schema_path = fixtures_dir / "spider_schema_sample.jsonl" + + spider_rows = _load_jsonl(spider_path) + schema_records = _load_jsonl(schema_path) + schema_map = load_spider_schema_map(schema_records) + + examples = _build_examples_with_schema(spider_rows, schema_map) + return examples, schema_map + + +def _write_json_report( + out_path: Path, + config: SpiderEvalConfig, + metrics: Dict[str, Any], + examples: List[Dict[str, Any]], +) -> None: + """Write a JSON report with metrics, config, and example predictions.""" + out_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "config": config.to_dict(), + "metrics": metrics, + "examples": examples, + "generated_at_utc": datetime.now(timezone.utc).isoformat(), + } + with out_path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, ensure_ascii=False) + + +def _write_markdown_report( + out_path: Path, + config: SpiderEvalConfig, + metrics: Dict[str, Any], + examples: List[Dict[str, Any]], +) -> None: + """Write a human-readable Markdown evaluation report for Spider.""" + out_path.parent.mkdir(parents=True, exist_ok=True) + + n_examples = metrics.get("n_examples", 0) + em = metrics.get("exact_match", {}) + nvem = metrics.get("no_values_em", {}) + parse = metrics.get("parse_success", {}) + schema = metrics.get("schema_adherence", {}) + + def _fmt_rate(entry: Dict[str, Any]) -> str: + count = entry.get("count", 0) + rate = entry.get("rate", 0.0) + return f"{count}/{n_examples} ({rate:.3f})" + + lines: List[str] = [] + + lines.append("# External Evaluation – Spider dev (lightweight)\n") + lines.append("## Configuration\n") + lines.append(f"- **spider_source:** `{config.spider_source}`") + lines.append(f"- **schema_source:** `{config.schema_source}`") + lines.append(f"- **spider_split:** `{config.spider_split}`") + lines.append(f"- **base_model:** `{config.base_model}`") + lines.append(f"- 
**adapter_dir:** `{config.adapter_dir or 'None (base/merged model only)'}`") + lines.append(f"- **device:** `{config.device}`") + lines.append(f"- **max_examples:** `{config.max_examples}`") + lines.append(f"- **temperature:** `{config.temperature}`") + lines.append(f"- **top_p:** `{config.top_p}`") + lines.append(f"- **max_new_tokens:** `{config.max_new_tokens}`") + lines.append(f"- **mock:** `{config.mock}`") + lines.append(f"- **n_evaluated_examples:** `{n_examples}`\n") + + lines.append("## Metrics\n") + lines.append(f"- **Exact Match (normalized SQL):** {_fmt_rate(em)}") + lines.append(f"- **No-values Exact Match:** {_fmt_rate(nvem)}") + lines.append(f"- **SQL parse success rate:** {_fmt_rate(parse)}") + if schema: + lines.append(f"- **Schema adherence rate:** {_fmt_rate(schema)}") + lines.append("") + + lines.append("## Notes\n") + lines.append( + "- This is a lightweight Spider external validation intended as a " + "portfolio-style baseline, not a full reproduction of official Spider " + "evaluation." + ) + lines.append( + "- Official Spider metrics include component matching and execution-based " + "evaluation. Here we report simple logical-form approximations: Exact " + "Match, No-values EM, and parse success." + ) + lines.append( + "- Schema adherence is computed by checking that predicted queries only " + "reference tables and columns present in the serialized schema context." + ) + if config.mock: + lines.append( + "- This run used local fixtures and `--mock`, so predictions are set " + "equal to gold SQL to validate prompt construction and metric logic." 
+ ) + lines.append("") + + lines.append("## Example Predictions\n") + if not examples: + lines.append("_No examples available._") + else: + for idx, ex in enumerate(examples, start=1): + lines.append(f"### Example {idx}\n") + lines.append(f"- **db_id:** `{ex.get('db_id', '')}`") + lines.append(f"- **Question:** {ex.get('question', '').strip()}") + schema_snippet = ex.get("schema_snippet", "") + lines.append(f"- **Schema snippet:** `{schema_snippet}`") + lines.append(f"- **Gold SQL:** `{ex.get('gold_sql', '').strip()}`") + lines.append(f"- **Predicted SQL:** `{ex.get('pred_sql', '').strip()}`") + lines.append("") + + with out_path.open("w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + +def _load_spider_from_hub( + spider_source: str, + spider_split: str, + schema_source: str, + max_examples: int, +) -> Tuple[List[Dict[str, Any]], Dict[str, str]]: + """ + Load Spider examples and schema mapping from Hugging Face datasets. + + Parameters + ---------- + spider_source : str + HF dataset id for Spider examples. + spider_split : str + Split name for Spider examples (e.g., 'validation'). + schema_source : str + HF dataset id for Spider schema helper. + max_examples : int + Maximum number of examples to keep. + + Returns + ------- + (examples, schema_map) + examples: list of records with at least db_id, question, query and + attached schema text. + schema_map: mapping {db_id -> raw schema text}. + """ + from datasets import load_dataset # Imported lazily to keep tests lightweight. + + logger.info( + "Loading Spider dataset '%s' (split=%s) and schema '%s' from Hugging Face.", + spider_source, + spider_split, + schema_source, + ) + + spider_ds = load_dataset(spider_source, split=spider_split) + schema_ds = load_dataset(schema_source, split="train") + + schema_map = load_spider_schema_map(schema_ds) + + # Materialise Spider rows so we can compute intersections and slice. 
+ spider_rows: List[Dict[str, Any]] = [dict(row) for row in spider_ds] # type: ignore[arg-type] + + examples = _build_examples_with_schema(spider_rows, schema_map) + + if max_examples is not None and max_examples > 0: + examples = examples[:max_examples] + + logger.info( + "Loaded %d Spider examples with matching schema entries.", len(examples) + ) + return examples, schema_map + + +def run_eval(args: argparse.Namespace) -> int: + """Execute the Spider external evaluation pipeline.""" + if args.mock: + logger.info("Running Spider evaluation in --mock mode using local fixtures.") + examples, schema_map = _load_spider_from_fixtures() + else: + logger.info( + "Running Spider evaluation with spider_source=%s, schema_source=%s, split=%s", + args.spider_source, + args.schema_source, + args.spider_split, + ) + examples, schema_map = _load_spider_from_hub( + spider_source=args.spider_source, + spider_split=args.spider_split, + schema_source=args.schema_source, + max_examples=args.max_examples, + ) + + if not examples: + raise RuntimeError("No Spider examples available for evaluation.") + + if args.max_examples is not None and args.max_examples > 0: + examples = examples[: args.max_examples] + + eval_config = SpiderEvalConfig( + base_model=args.base_model, + adapter_dir=args.adapter_dir, + device=args.device, + spider_split=args.spider_split, + spider_source=args.spider_source, + schema_source=args.schema_source, + max_examples=args.max_examples, + out_dir=args.out_dir, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + mock=args.mock, + ) + + gold_sqls: List[str] = [] + pred_sqls: List[str] = [] + contexts: List[str] = [] + questions: List[str] = [] + db_ids: List[str] = [] + + for ex in examples: + db_id = ex["db_id"] + question = ex["question"] + gold_query = ex["query"] + raw_schema_text = ex.get("schema_text", "") + schema_context = spider_schema_to_pseudo_ddl(raw_schema_text) + + db_ids.append(db_id) + 
questions.append(question) + gold_sqls.append(gold_query) + contexts.append(schema_context) + + if not db_ids: + raise RuntimeError("After filtering, no Spider examples had matching schemas.") + + if args.mock: + logger.info("Using gold SQL as predictions in --mock mode.") + pred_sqls = list(gold_sqls) + else: + # Import inference helpers lazily so that --mock mode does not require + # heavy runtime dependencies like torch to be installed. + from text2sql.infer import ( # type: ignore[import] + generate_sql, + load_model_for_inference, + ) + + logger.info( + "Loading model for inference with base_model=%s, adapter_dir=%s, " + "device=%s, load_in_4bit=%s, dtype=%s", + args.base_model, + args.adapter_dir, + args.device, + args.load_in_4bit, + args.dtype, + ) + load_model_for_inference( + base_model=args.base_model, + adapter_dir=args.adapter_dir, + device=args.device, + load_in_4bit=args.load_in_4bit, + bnb_compute_dtype="float16", + dtype=args.dtype, + ) + + for idx, (db_id, question, schema_context) in enumerate( + zip(db_ids, questions, contexts) + ): + prompt = build_spider_prompt(schema_context=schema_context, question=question) + logger.info( + "Generating prediction for Spider example %d/%d (db_id=%s)", + idx + 1, + len(db_ids), + db_id, + ) + raw_output = generate_sql( + prompt=prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + ) + cleaned_sql = ensure_sql_only(raw_output) + pred_sqls.append(cleaned_sql) + + metrics = aggregate_metrics( + predictions=pred_sqls, + golds=gold_sqls, + contexts=contexts, + compute_schema_adherence=True, + ) + logger.info("Spider evaluation metrics: %s", metrics) + + # Build up to 10 example entries for the reports. 
+ example_entries: List[Dict[str, Any]] = [] + num_examples_to_show = min(10, len(db_ids)) + for i in range(num_examples_to_show): + schema_context = contexts[i] + schema_snippet = schema_context.replace("\n", " ") + if len(schema_snippet) > 200: + schema_snippet = schema_snippet[:197] + "..." + + example_entries.append( + { + "db_id": db_ids[i], + "question": questions[i], + "schema_snippet": schema_snippet, + "gold_sql": normalize_sql(gold_sqls[i]), + "pred_sql": normalize_sql(pred_sqls[i]), + } + ) + + out_dir = Path(args.out_dir) + json_path = out_dir / "eval_spider.json" + md_path = out_dir / "eval_spider.md" + + _write_json_report(json_path, eval_config, metrics, example_entries) + _write_markdown_report(md_path, eval_config, metrics, example_entries) + + logger.info("Spider external evaluation reports written to %s and %s", json_path, md_path) + return 0 + + +def main(argv: Optional[List[str]] = None) -> int: + configure_logging() + args = parse_args(argv) + + try: + return run_eval(args) + except FileNotFoundError as exc: + logger.error("File not found: %s", exc) + return 1 + except (RuntimeError, ValueError) as exc: + logger.error("Spider evaluation failed: %s", exc) + return 1 + except Exception: # noqa: BLE001 + logger.error("Unexpected error during Spider evaluation.", exc_info=True) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/src/text2sql/eval/__init__.py b/src/text2sql/eval/__init__.py new file mode 100644 index 0000000..6ae4205 --- /dev/null +++ b/src/text2sql/eval/__init__.py @@ -0,0 +1,23 @@ +""" +Evaluation utilities for the Analytics Copilot (Text-to-SQL) project. + +This subpackage contains: +- SQL normalization helpers (normalize_sql, normalize_sql_no_values). +- Schema parsing and adherence checks for CREATE TABLE context. +- Aggregate metrics for text-to-SQL evaluation. +- Spider-specific prompt construction helpers. 
+""" + +from .normalize import normalize_sql, normalize_sql_no_values +from .schema import parse_create_table_context, referenced_identifiers, schema_adherence +from .metrics import exact_match, aggregate_metrics + +__all__ = [ + "normalize_sql", + "normalize_sql_no_values", + "parse_create_table_context", + "referenced_identifiers", + "schema_adherence", + "exact_match", + "aggregate_metrics", +] \ No newline at end of file diff --git a/src/text2sql/eval/metrics.py b/src/text2sql/eval/metrics.py new file mode 100644 index 0000000..86fa4e5 --- /dev/null +++ b/src/text2sql/eval/metrics.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import Dict, Iterable, List, Mapping, Optional, Sequence + +import sqlglot + +from .normalize import normalize_sql, normalize_sql_no_values +from .schema import schema_adherence as _schema_adherence + + +def exact_match(pred: str, gold: str) -> bool: + """ + Exact match on normalized SQL strings. + + Both prediction and gold SQL are normalized with `normalize_sql` before + comparison to make the metric robust to trivial formatting differences + (whitespace, trailing semicolons). + """ + return normalize_sql(pred) == normalize_sql(gold) + + +def _parse_success(sql: str) -> bool: + """ + Return True if sqlglot can parse the given SQL string. + + This is used as a lightweight proxy for syntactic validity of the query. + """ + if not sql or not sql.strip(): + return False + try: + sqlglot.parse_one(sql) + return True + except Exception: # noqa: BLE001 + return False + + +def aggregate_metrics( + predictions: Sequence[str], + golds: Sequence[str], + *, + contexts: Optional[Sequence[str]] = None, + compute_schema_adherence: bool = False, +) -> Dict[str, object]: + """ + Aggregate core text-to-SQL evaluation metrics over a set of examples. + + Parameters + ---------- + predictions : Sequence[str] + Model-predicted SQL strings. + golds : Sequence[str] + Gold/reference SQL strings. 
+ contexts : Optional[Sequence[str]], optional + Schema context strings (e.g., CREATE TABLE statements). Required if + `compute_schema_adherence` is True. + compute_schema_adherence : bool, optional + Whether to compute the schema adherence rate using the provided + contexts, by default False. + + Returns + ------- + dict + Dictionary containing counts and rates for: + - n_examples + - exact_match + - no_values_em + - parse_success + - schema_adherence (if requested) + """ + if len(predictions) != len(golds): + raise ValueError( + f"Expected predictions and golds to have the same length, " + f"got {len(predictions)} and {len(golds)}." + ) + + if compute_schema_adherence: + if contexts is None: + raise ValueError( + "contexts must be provided when compute_schema_adherence is True." + ) + if len(contexts) != len(predictions): + raise ValueError( + f"Expected contexts to have the same length as predictions, " + f"got {len(contexts)} and {len(predictions)}." + ) + + n_examples = len(predictions) + if n_examples == 0: + return { + "n_examples": 0, + "exact_match": {"count": 0, "rate": 0.0}, + "no_values_em": {"count": 0, "rate": 0.0}, + "parse_success": {"count": 0, "rate": 0.0}, + "schema_adherence": {"count": 0, "rate": 0.0} + if compute_schema_adherence + else None, + } + + em_count = 0 + no_values_em_count = 0 + parse_success_count = 0 + schema_adherence_count = 0 + + for idx, (pred, gold) in enumerate(zip(predictions, golds)): + if exact_match(pred, gold): + em_count += 1 + + pred_no_vals = normalize_sql_no_values(pred) + gold_no_vals = normalize_sql_no_values(gold) + if pred_no_vals == gold_no_vals: + no_values_em_count += 1 + + if _parse_success(pred): + parse_success_count += 1 + + if compute_schema_adherence and contexts is not None: + ctx = contexts[idx] + if _schema_adherence(pred, ctx): + schema_adherence_count += 1 + + def _rate(count: int) -> float: + return float(count) / float(n_examples) if n_examples else 0.0 + + result: Dict[str, object] = { + 
"n_examples": n_examples, + "exact_match": {"count": em_count, "rate": _rate(em_count)}, + "no_values_em": { + "count": no_values_em_count, + "rate": _rate(no_values_em_count), + }, + "parse_success": { + "count": parse_success_count, + "rate": _rate(parse_success_count), + }, + } + + if compute_schema_adherence: + result["schema_adherence"] = { + "count": schema_adherence_count, + "rate": _rate(schema_adherence_count), + } + + return result \ No newline at end of file diff --git a/src/text2sql/eval/normalize.py b/src/text2sql/eval/normalize.py new file mode 100644 index 0000000..0054cf4 --- /dev/null +++ b/src/text2sql/eval/normalize.py @@ -0,0 +1,81 @@ +import re +from typing import Final + +# Placeholders used when stripping literal values from SQL. +_NUM_PLACEHOLDER: Final[str] = "__NUM__" +_STR_PLACEHOLDER: Final[str] = "__STR__" + + +def normalize_sql(sql: str) -> str: + """ + Normalize a SQL string for string-based comparison. + + The normalization is intentionally conservative and focuses on: + - Stripping leading/trailing whitespace. + - Removing trailing semicolons. + - Collapsing runs of whitespace (spaces, tabs, newlines) into a single space. + + Parameters + ---------- + sql : str + Raw SQL string. + + Returns + ------- + str + Normalized SQL string. + """ + if sql is None: + return "" + + text = sql.strip() + if not text: + return "" + + # Remove one or more trailing semicolons plus any following whitespace. + text = re.sub(r";\s*$", "", text) + + # Collapse all whitespace (spaces, tabs, newlines) into a single space. + text = re.sub(r"\s+", " ", text) + + return text.strip() + + +def normalize_sql_no_values(sql: str) -> str: + """ + Normalize a SQL string while stripping literal values. + + This function builds on `normalize_sql` and additionally: + - Replaces single-quoted string literals with a placeholder. + - Replaces numeric literals (integers and decimals, optionally negative) + with a placeholder. 
+ + The goal is to support "no-values" exact match evaluations that ignore + concrete literal values but still compare query structure. + + Parameters + ---------- + sql : str + Raw SQL string. + + Returns + ------- + str + Normalized SQL string with literal values replaced by placeholders. + """ + text = normalize_sql(sql) + if not text: + return "" + + # Replace single-quoted string literals, e.g., 'Alice' -> '__STR__'. + # This is a best-effort approach and does not handle all SQL dialect edge cases. + text = re.sub(r"'([^']*)'", f"'{_STR_PLACEHOLDER}'", text) + + # Replace numeric literals, e.g., 42, -3.14 -> '__NUM__'. + # We use word boundaries to avoid touching identifiers like col1 or t2_name. + text = re.sub(r"\b-?\d+(?:\.\d+)?\b", _NUM_PLACEHOLDER, text) + + # Collapse whitespace again in case substitutions introduced irregular spacing. + text = re.sub(r"\s+", " ", text) + + return text.strip() \ No newline at end of file diff --git a/src/text2sql/eval/schema.py b/src/text2sql/eval/schema.py new file mode 100644 index 0000000..39eebc6 --- /dev/null +++ b/src/text2sql/eval/schema.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Dict, Iterable, Mapping, MutableMapping, Set + +import sqlglot +from sqlglot import expressions as exp + + +def parse_create_table_context( + context: str, +) -> Dict[str, object]: + """ + Parse a CREATE TABLE schema context into tables and columns. + + Parameters + ---------- + context : str + One or more CREATE TABLE statements, typically the schema context from + b-mc2/sql-create-context or the Spider schema helper dataset. + + Returns + ------- + dict + A dictionary with: + - "tables": set[str] of table names (lowercased). + - "columns_by_table": dict[str, set[str]] mapping table -> set of + column names (lowercased). 
+ """ + tables: Set[str] = set() + columns_by_table: MutableMapping[str, Set[str]] = defaultdict(set) + + if not context or not context.strip(): + return {"tables": tables, "columns_by_table": dict(columns_by_table)} + + try: + statements: Iterable[exp.Expression] = sqlglot.parse(context) + except Exception: # noqa: BLE001 + # Best-effort: if parsing fails entirely, return empty schema so that + # adherence checks can degrade gracefully. + return {"tables": tables, "columns_by_table": dict(columns_by_table)} + + for stmt in statements: + if not isinstance(stmt, exp.Create): + continue + + # For CREATE TABLE, sqlglot represents the schema as stmt.this (Schema), + # whose .this is the Table expression and .expressions are ColumnDef nodes. + schema_expr = stmt.this + if not isinstance(schema_expr, exp.Schema): + continue + + table_expr = schema_expr.this + if not isinstance(table_expr, exp.Table): + continue + + table_name = table_expr.name + if not table_name: + continue + + table_key = table_name.lower() + tables.add(table_key) + + for col_def in schema_expr.expressions: + if isinstance(col_def, exp.ColumnDef) and col_def.this is not None: + col_ident = col_def.this + col_name = getattr(col_ident, "this", None) + if col_name: + columns_by_table[table_key].add(str(col_name).lower()) + + return {"tables": tables, "columns_by_table": dict(columns_by_table)} + + +def referenced_identifiers(sql: str) -> Dict[str, Set[str]]: + """ + Extract referenced table and column identifiers from a SQL query. + + Parameters + ---------- + sql : str + SQL query string. + + Returns + ------- + dict + Dictionary with: + - "tables": set[str] of referenced table names (lowercased). + - "columns": set[str] of referenced column names (lowercased). 
+ """ + tables: Set[str] = set() + columns: Set[str] = set() + + if not sql or not sql.strip(): + return {"tables": tables, "columns": columns} + + try: + expr = sqlglot.parse_one(sql) + except Exception: # noqa: BLE001 + # If parsing fails, we conservatively return empty sets; callers can + # combine this with parse_success metrics separately. + return {"tables": tables, "columns": columns} + + for table_expr in expr.find_all(exp.Table): + name = table_expr.name + if name: + tables.add(name.lower()) + + for col_expr in expr.find_all(exp.Column): + # sqlglot Column expressions expose a `.name` property with the column + # name, independent of table/alias qualification. + col_name = getattr(col_expr, "name", None) + if col_name: + columns.add(str(col_name).lower()) + + return {"tables": tables, "columns": columns} + + +def schema_adherence(sql: str, context: str) -> bool: + """ + Check whether a SQL query only references tables/columns in a given schema. + + The check is intentionally conservative and operates purely at the level + of identifier names: + - A table reference is considered valid if its (lowercased) name appears + among the tables parsed from the CREATE TABLE context. + - A column reference is considered valid if its (lowercased) name appears + in the union of all column sets from the parsed context. + + If the schema cannot be parsed and the query references identifiers, this + function returns False. If the query references no identifiers, adherence + is considered True even if the schema is empty. + + Parameters + ---------- + sql : str + SQL query string. + context : str + CREATE TABLE context string. + + Returns + ------- + bool + True if all referenced tables/columns are present in the schema; + False otherwise. 
+ """ + schema = parse_create_table_context(context) + schema_tables: Set[str] = schema["tables"] # type: ignore[assignment] + columns_by_table: Mapping[str, Set[str]] = schema[ + "columns_by_table" + ] # type: ignore[assignment] + + refs = referenced_identifiers(sql) + ref_tables = refs["tables"] + ref_columns = refs["columns"] + + if not schema_tables and not any(columns_by_table.values()): + # No schema information available. + if ref_tables or ref_columns: + return False + return True + + # Tables must be a subset of known schema tables. + for table in ref_tables: + if table not in schema_tables: + return False + + # Columns must appear in at least one table's column set. + all_columns: Set[str] = set() + for cols in columns_by_table.values(): + all_columns.update(cols) + + for col in ref_columns: + if col not in all_columns: + return False + + return True \ No newline at end of file diff --git a/src/text2sql/eval/spider.py b/src/text2sql/eval/spider.py new file mode 100644 index 0000000..ae7da14 --- /dev/null +++ b/src/text2sql/eval/spider.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import logging +import re +from typing import Any, Dict, Iterable, Mapping, Sequence + +from text2sql.data_prep import INSTRUCTION_TEXT, build_input_text +from text2sql.training.formatting import build_prompt + +logger = logging.getLogger(__name__) + + +def build_spider_prompt(schema_context: str, question: str) -> str: + """ + Build a text-to-SQL prompt for a Spider example. + + This reuses the same instruction text and input formatting as the + internal b-mc2/sql-create-context pipeline: + + - Instruction: INSTRUCTION_TEXT from text2sql.data_prep. + - Input: built via build_input_text(schema_context, question). + - Prompt: wrapped using training.formatting.build_prompt. 
+ """ + input_text = build_input_text(context=schema_context, question=question) + return build_prompt(INSTRUCTION_TEXT, input_text) + + +def build_schema_map( + records: Sequence[Mapping[str, object]], + *, + db_id_field: str = "db_id", + schema_field: str = "create_table_context", +) -> Dict[str, str]: + """ + Build a mapping {db_id -> schema_context} from Spider-schema records. + + This helper is kept for backwards compatibility in tests or scripts that + already provide a concrete `schema_field` (e.g. a pre-built + `create_table_context`). For the real Spider schema helper dataset, use + ``load_spider_schema_map`` instead, which can automatically detect the + schema text column. + """ + schema_map: Dict[str, str] = {} + for row in records: + db_id = row.get(db_id_field) + schema_context = row.get(schema_field) + if db_id is None or schema_context is None: + continue + db_id_str = str(db_id) + schema_map[db_id_str] = str(schema_context) + return schema_map + + +_SCHEMA_FIELD_CANDIDATES: tuple[str, ...] = ( + # Exact field name used by richardr1126/spider-schema. + "Schema (values (type))", + # Common alternatives seen in other helper datasets. + "schema", + "schema_text", + "ddl", + "create_table", +) + + +def _infer_column_names(schema_ds: Any) -> list[str]: + """ + Best-effort extraction of column names from a dataset-like object. + + Supports: + - Hugging Face datasets with ``column_names``. + - Simple sequences of dicts (e.g. JSONL fixture rows). + """ + if hasattr(schema_ds, "column_names"): + return list(schema_ds.column_names) # type: ignore[no-any-return] + + if isinstance(schema_ds, (list, tuple)) and schema_ds: + first = schema_ds[0] + if isinstance(first, Mapping): + return list(first.keys()) + + # Fallback: materialise a small iterator if needed. 
+ try: + iterator = iter(schema_ds) + except TypeError: + return [] + rows = list(iterator) + if not rows: + return [] + first = rows[0] + if isinstance(first, Mapping): + return list(first.keys()) + return [] + + +def load_spider_schema_map(schema_ds: Iterable[Mapping[str, Any]]) -> Dict[str, str]: + """ + Load a mapping {db_id -> schema_text} from a Spider schema dataset. + + The richardr1126/spider-schema helper uses a non-standard column name + for the serialized schema: + + \"Schema (values (type))\" + + This loader inspects available columns and chooses the first match from + a fallback list: + + - \"Schema (values (type))\" + - \"schema\" + - \"schema_text\" + - \"ddl\" + - \"create_table\" + + If none of these are present, a ValueError is raised with the available + columns listed. + """ + column_names = _infer_column_names(schema_ds) + logger.info("Spider schema dataset columns: %s", column_names) + + schema_field: str | None = None + for candidate in _SCHEMA_FIELD_CANDIDATES: + if candidate in column_names: + schema_field = candidate + break + + if schema_field is None: + raise ValueError( + "Could not find a schema text field in Spider schema dataset. " + f"Available columns: {column_names}. " + f"Tried: {_SCHEMA_FIELD_CANDIDATES}." + ) + + schema_map: Dict[str, str] = {} + for row in schema_ds: + db_id = row.get("db_id") + if db_id is None: + continue + schema_text = row.get(schema_field) + if not schema_text: + continue + db_id_str = str(db_id) + schema_map[db_id_str] = str(schema_text) + + sample_db_ids = sorted(schema_map.keys())[:5] + logger.info( + "Loaded %d Spider schema entries using field '%s'. Sample db_ids: %s", + len(schema_map), + schema_field, + sample_db_ids, + ) + return schema_map + + +def spider_schema_to_pseudo_ddl(schema_text: str) -> str: + """ + Convert a compact Spider schema description into pseudo-DDL. 
+ + The richardr1126/spider-schema dataset serializes each database schema as + a compact text description, for example:: + + \"department : Department_ID (number) , Name (text) ... + course : Course_ID (number) , Title (text) ...\" + + This function turns that into a more SQL-looking form that is aligned with + the internal training format, eg:: + + CREATE TABLE department (Department_ID NUMBER, Name TEXT, ...); + CREATE TABLE course (Course_ID NUMBER, Title TEXT, ...); + + The exact SQL dialect is not important for evaluation; it just needs to be + consistent and readable. + """ + if not schema_text: + return "" + + lines = [line.strip() for line in schema_text.splitlines() if line.strip()] + statements: list[str] = [] + + for line in lines: + # Expected basic pattern: \"table_name : col1 (type) , col2 (type) , ...\" + if ":" not in line: + continue + + table_part, cols_part = line.split(":", 1) + table_name = table_part.strip() + if not table_name: + continue + + column_specs: list[str] = [] + for raw_col in cols_part.split(","): + col = raw_col.strip() + if not col: + continue + + # Try to match \"column_name (type)\". + match = re.match(r"([^()]+)\(([^)]+)\)", col) + if not match: + # Fallback: keep the raw column name if parsing fails. 
+ column_name = col.strip() + if column_name: + column_specs.append(column_name) + continue + + column_name = match.group(1).strip() + column_type = match.group(2).strip() + + type_norm = column_type.upper() + if type_norm in {"NUMBER", "INT", "INTEGER"}: + type_norm = "NUMBER" + elif type_norm in {"TEXT", "STRING", "VARCHAR"}: + type_norm = "TEXT" + elif type_norm in {"REAL", "FLOAT", "DOUBLE"}: + type_norm = "REAL" + + if column_name: + column_specs.append(f"{column_name} {type_norm}") + + if column_specs: + stmt = f"CREATE TABLE {table_name} ({', '.join(column_specs)});" + statements.append(stmt) + + return " ".join(statements) \ No newline at end of file diff --git a/src/text2sql/infer.py b/src/text2sql/infer.py new file mode 100644 index 0000000..a2d2655 --- /dev/null +++ b/src/text2sql/infer.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Optional, Tuple + +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase + +logger = logging.getLogger(__name__) + +_MODEL: Optional[PreTrainedModel] = None +_TOKENIZER: Optional[PreTrainedTokenizerBase] = None +_DEVICE: Optional[str] = None + + +def _resolve_torch_dtype(dtype: str, device: str) -> torch.dtype: + """ + Resolve a string dtype name to a torch.dtype, with an 'auto' option. + + - 'auto' → float16 on CUDA, float32 on CPU. + - Accepts common aliases like 'fp16'/'bf16'/'fp32'. + """ + dtype_normalized = (dtype or "auto").lower() + if dtype_normalized == "auto": + return torch.float16 if device == "cuda" else torch.float32 + + mapping = { + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float32": torch.float32, + "fp32": torch.float32, + } + try: + return mapping[dtype_normalized] + except KeyError as exc: + raise ValueError( + f"Unsupported dtype '{dtype}'. 
Expected one of: auto, float16, bfloat16, float32." + ) from exc + + +def _select_device(device: str) -> str: + """ + Resolve the requested device string to an actual device ("cpu" or "cuda"). + + If CUDA is requested but not available, this falls back to CPU and logs a + warning since generation will be significantly slower. + """ + device = (device or "auto").lower() + if device not in {"auto", "cpu", "cuda"}: + raise ValueError(f"Unsupported device '{device}'. Expected one of: auto, cpu, cuda.") + + if device == "cpu": + return "cpu" + + if device in {"auto", "cuda"}: + if torch.cuda.is_available(): + return "cuda" + logger.warning( + "CUDA was requested or auto-selected, but no GPU is available. " + "Falling back to CPU; generation will be slow." + ) + return "cpu" + + # Fallback; should not be reached. + return "cpu" + + +def load_model_for_inference( + base_model: str, + adapter_dir: Optional[str] = None, + device: str = "auto", + load_in_4bit: Optional[bool] = None, + bnb_compute_dtype: str = "float16", + dtype: str = "auto", +) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + """ + Load a base model and optional LoRA adapters for text-to-SQL inference. + + This function supports two main modes: + 1) Base HF model + LoRA adapters (adapter_dir points to the trained QLoRA + adapters from the training pipeline). + 2) A locally merged model directory passed as `base_model` with + `adapter_dir` left as None. + + Parameters + ---------- + base_model : str + Hugging Face model id or local path to the base (or merged) model. + adapter_dir : Optional[str], optional + Path to a directory containing LoRA adapters (PEFT), by default None. + device : str, optional + "auto", "cuda", or "cpu". If "auto", prefer CUDA when available, + otherwise fall back to CPU. + load_in_4bit : Optional[bool], optional + If True, attempt to load the base model using 4-bit quantization + (bitsandbytes). 
If None, 4-bit is enabled automatically when running + on CUDA and disabled otherwise. + bnb_compute_dtype : str, optional + Compute dtype for 4-bit quantization (e.g. "float16", "bfloat16"). + Defaults to "float16". + dtype : str, optional + Torch dtype for the base model weights ("auto", "float16", + "bfloat16", or "float32"). "auto" resolves to float16 on CUDA and + float32 on CPU. + + Returns + ------- + (model, tokenizer) + Loaded model and tokenizer ready for inference. + """ + global _MODEL, _TOKENIZER, _DEVICE + + resolved_device = _select_device(device) + + # Decide whether to enable 4-bit quantization. + use_4bit = load_in_4bit + if use_4bit is None: + use_4bit = resolved_device == "cuda" + + logger.info( + "Loading model for inference: base_model=%s, adapter_dir=%s, device=%s, " + "load_in_4bit=%s, dtype=%s, bnb_compute_dtype=%s", + base_model, + adapter_dir, + resolved_device, + use_4bit, + dtype, + bnb_compute_dtype, + ) + + tokenizer = AutoTokenizer.from_pretrained(adapter_dir or base_model) + if tokenizer.pad_token_id is None: + # Many causal LM tokenizers do not have an explicit pad token; use EOS. + tokenizer.pad_token_id = tokenizer.eos_token_id + + torch_dtype = _resolve_torch_dtype(dtype, resolved_device) + + if use_4bit and resolved_device != "cuda": + logger.warning( + "4-bit quantization was requested (load_in_4bit=True) but device is '%s'. " + "Disabling 4-bit and using dtype %s instead.", + resolved_device, + torch_dtype, + ) + use_4bit = False + + if use_4bit: + # Lazy import to avoid introducing a hard dependency for non-4bit users. + try: + from transformers import BitsAndBytesConfig + except ImportError as exc: # pragma: no cover - environment-specific + raise ImportError( + "4-bit quantization requested, but transformers.BitsAndBytesConfig " + "is not available. Ensure you have a recent transformers and " + "bitsandbytes installed, or disable 4-bit loading." 
+ ) from exc + + compute_dtype = _resolve_torch_dtype(bnb_compute_dtype, resolved_device) + + quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=compute_dtype, + ) + + logger.info( + "Using 4-bit NF4 quantization for base model with torch_dtype=%s and " + "compute_dtype=%s.", + torch_dtype, + compute_dtype, + ) + + model = AutoModelForCausalLM.from_pretrained( + base_model, + quantization_config=quant_config, + device_map="auto", + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + ) + else: + # Standard full-precision / mixed-precision loading. + model = AutoModelForCausalLM.from_pretrained( + base_model, + torch_dtype=torch_dtype, + ) + model.to(resolved_device) + + if adapter_dir: + adapter_path = Path(adapter_dir) + if not adapter_path.is_dir(): + raise FileNotFoundError( + f"Adapter directory not found: {adapter_dir}. " + "Ensure you pass the correct path to the trained LoRA adapters." + ) + logger.info("Loading LoRA adapters from %s", adapter_dir) + model = PeftModel.from_pretrained(model, adapter_dir) + + model.eval() + + _MODEL = model + _TOKENIZER = tokenizer + _DEVICE = resolved_device + + return model, tokenizer + + +def generate_sql( + prompt: str, + max_new_tokens: int = 256, + temperature: float = 0.0, + top_p: float = 0.9, +) -> str: + """ + Generate a SQL string from a prompt using the loaded model. + + `load_model_for_inference` must be called once before this function is + used; otherwise a RuntimeError will be raised. + """ + if _MODEL is None or _TOKENIZER is None or _DEVICE is None: + raise RuntimeError( + "Model is not loaded. Call load_model_for_inference(...) before generate_sql(...)." 
+ ) + + model = _MODEL + tokenizer = _TOKENIZER + + do_sample = temperature > 0.0 + + inputs = tokenizer( + prompt, + return_tensors="pt", + truncation=True, + ) + input_ids = inputs["input_ids"].to(_DEVICE) + attention_mask = inputs.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(_DEVICE) + + with torch.no_grad(): + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=max(temperature, 1e-5) if do_sample else 1.0, + top_p=top_p, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + generated_ids = outputs[0][input_ids.shape[1] :] + text = tokenizer.decode(generated_ids, skip_special_tokens=True) + return text.strip() \ No newline at end of file diff --git a/tests/fixtures/eval_internal_sample.jsonl b/tests/fixtures/eval_internal_sample.jsonl new file mode 100644 index 0000000..3355860 --- /dev/null +++ b/tests/fixtures/eval_internal_sample.jsonl @@ -0,0 +1,3 @@ +{"id": "eval-internal-000001", "instruction": "Write a SQL query that answers the user's question using ONLY the tables and columns provided in the schema.", "input": "### Schema:\nCREATE TABLE head (id INTEGER PRIMARY KEY, age INTEGER, department TEXT);\n\n### Question:\nHow many heads of the departments are older than 56?", "output": "SELECT COUNT(*) FROM head WHERE age > 56", "source": "b-mc2/sql-create-context", "meta": {"split": "val", "row": 0}} +{"id": "eval-internal-000002", "instruction": "Write a SQL query that answers the user's question using ONLY the tables and columns provided in the schema.", "input": "### Schema:\nCREATE TABLE singer (id INTEGER PRIMARY KEY, name TEXT, country TEXT);\n\n### Question:\nList the names of all singers from the USA.", "output": "SELECT name FROM singer WHERE country = 'USA'", "source": "b-mc2/sql-create-context", "meta": {"split": "val", "row": 1}} +{"id": "eval-internal-000003", 
"instruction": "Write a SQL query that answers the user's question using ONLY the tables and columns provided in the schema.", "input": "### Schema:\nCREATE TABLE concert (id INTEGER PRIMARY KEY, city TEXT, year INTEGER);\n\n### Question:\nShow each concert city and the number of concerts held there.", "output": "SELECT city, COUNT(*) FROM concert GROUP BY city", "source": "b-mc2/sql-create-context", "meta": {"split": "val", "row": 2}} \ No newline at end of file diff --git a/tests/fixtures/spider_sample.jsonl b/tests/fixtures/spider_sample.jsonl new file mode 100644 index 0000000..6b946b4 --- /dev/null +++ b/tests/fixtures/spider_sample.jsonl @@ -0,0 +1,5 @@ +{"db_id": "concert_singer", "question": "Show each city and the number of concerts held there.", "query": "SELECT city, COUNT(*) FROM concert GROUP BY city"} +{"db_id": "concert_singer", "question": "List the names of singers who have concerts in New York.", "query": "SELECT DISTINCT s.name FROM singer AS s JOIN concert AS c ON s.singer_id = c.singer_id WHERE c.city = 'New York'"} +{"db_id": "academic", "question": "How many professors are older than 50?", "query": "SELECT COUNT(*) FROM professor WHERE age > 50"} +{"db_id": "academic", "question": "List university names and the number of professors at each.", "query": "SELECT u.name, COUNT(*) FROM university AS u JOIN professor AS p ON u.university_id = p.university_id GROUP BY u.name"} +{"db_id": "college_1", "question": "Show all distinct cities where colleges are located.", "query": "SELECT DISTINCT city FROM college"} \ No newline at end of file diff --git a/tests/fixtures/spider_schema_sample.jsonl b/tests/fixtures/spider_schema_sample.jsonl new file mode 100644 index 0000000..0887323 --- /dev/null +++ b/tests/fixtures/spider_schema_sample.jsonl @@ -0,0 +1,3 @@ +{"db_id": "concert_singer", "Schema (values (type))": "singer : singer_id (number) , name (text) , country (text)\nconcert : concert_id (number) , singer_id (number) , city (text) , year 
(number)", "Primary Keys": "singer.singer_id, concert.concert_id", "Foreign Keys": "concert.singer_id = singer.singer_id"} +{"db_id": "academic", "Schema (values (type))": "professor : prof_id (number) , name (text) , age (number) , university_id (number)\nuniversity : university_id (number) , name (text) , city (text)", "Primary Keys": "professor.prof_id, university.university_id", "Foreign Keys": "professor.university_id = university.university_id"} +{"db_id": "college_1", "Schema (values (type))": "college : college_id (number) , name (text) , city (text)", "Primary Keys": "college.college_id", "Foreign Keys": ""} \ No newline at end of file diff --git a/tests/test_eval_cli_args.py b/tests/test_eval_cli_args.py new file mode 100644 index 0000000..d796851 --- /dev/null +++ b/tests/test_eval_cli_args.py @@ -0,0 +1,45 @@ +from pathlib import Path +import sys + +import pytest + + +def _ensure_root_on_path() -> None: + """Ensure that the project root is available on sys.path for script imports.""" + root = Path(__file__).resolve().parents[1] + if str(root) not in sys.path: + sys.path.insert(0, str(root)) + + +_ensure_root_on_path() + +from scripts import evaluate_internal # noqa: E402 # isort: skip +from scripts import evaluate_spider_external # noqa: E402 # isort: skip + + +def test_evaluate_internal_parses_4bit_flags() -> None: + args = evaluate_internal.parse_args( + [ + "--mock", + "--load_in_4bit", + "--dtype", + "float16", + ] + ) + assert args.mock is True + assert args.load_in_4bit is True + assert args.dtype == "float16" + + +def test_evaluate_spider_parses_4bit_flags() -> None: + args = evaluate_spider_external.parse_args( + [ + "--mock", + "--load_in_4bit", + "--dtype", + "float16", + ] + ) + assert args.mock is True + assert args.load_in_4bit is True + assert args.dtype == "float16" \ No newline at end of file diff --git a/tests/test_infer_quantization.py b/tests/test_infer_quantization.py new file mode 100644 index 0000000..9b806bf --- /dev/null +++ 
b/tests/test_infer_quantization.py @@ -0,0 +1,65 @@ +from pathlib import Path +import sys +from unittest import mock + +import pytest + + +def _ensure_src_on_path() -> None: + """Ensure that the 'src' directory is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_ensure_src_on_path() + + +def test_load_model_for_inference_4bit_uses_quantization_config() -> None: + """ + Ensure that load_model_for_inference can be called in 4-bit mode without + actually downloading a model, and that it wires BitsAndBytesConfig through + to AutoModelForCausalLM.from_pretrained. + """ + import text2sql.infer as infer # noqa: WPS433 # isort: skip + + with mock.patch.object(infer, "AutoTokenizer") as mock_tok_cls, \ + mock.patch.object(infer, "AutoModelForCausalLM") as mock_model_cls, \ + mock.patch.object(infer, "PeftModel") as mock_peft_cls, \ + mock.patch.object(infer.torch.cuda, "is_available", return_value=True): + + # Mock tokenizer to avoid any real downloads. + tok = mock.Mock() + tok.eos_token_id = 0 + tok.pad_token_id = None + mock_tok_cls.from_pretrained.return_value = tok + + # Mock model loading. + model_instance = mock.Mock() + mock_model_cls.from_pretrained.return_value = model_instance + + # Make PEFT a no-op that simply returns the base model. + mock_peft_cls.from_pretrained.side_effect = lambda base_model, adapter_dir: base_model + + model, tokenizer = infer.load_model_for_inference( + base_model="dummy-model", + adapter_dir=None, + device="cuda", + load_in_4bit=True, + bnb_compute_dtype="float16", + dtype="float16", + ) + + # Ensure we returned the mocked objects. + assert model is model_instance + assert tokenizer is tok + + # Check that quantization configuration and device_map were passed. 
+ mock_model_cls.from_pretrained.assert_called_once() + _, kwargs = mock_model_cls.from_pretrained.call_args + assert "quantization_config" in kwargs + quant_config = kwargs["quantization_config"] + assert getattr(quant_config, "load_in_4bit", False) is True + assert kwargs.get("device_map") == "auto" + assert kwargs.get("low_cpu_mem_usage") is True \ No newline at end of file diff --git a/tests/test_metrics_aggregate.py b/tests/test_metrics_aggregate.py new file mode 100644 index 0000000..57da5f2 --- /dev/null +++ b/tests/test_metrics_aggregate.py @@ -0,0 +1,66 @@ +from pathlib import Path +import sys + +import pytest + + +def _ensure_src_on_path() -> None: + """Ensure that the 'src' directory is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_ensure_src_on_path() + +from text2sql.eval.metrics import aggregate_metrics # noqa: E402 # isort: skip + + +SCHEMA_CONTEXT = """ +CREATE TABLE head ( + id INTEGER PRIMARY KEY, + age INTEGER, + department TEXT +); +""" + + +def test_aggregate_metrics_basic_counts_and_rates() -> None: + preds = [ + "SELECT department FROM head WHERE age > 56", + "SELECT age FROM head WHERE age > 30", + ] + golds = [ + "SELECT department FROM head WHERE age > 56", + "SELECT age FROM head WHERE age > 40", + ] + contexts = [SCHEMA_CONTEXT, SCHEMA_CONTEXT] + + metrics = aggregate_metrics( + predictions=preds, + golds=golds, + contexts=contexts, + compute_schema_adherence=True, + ) + + assert metrics["n_examples"] == 2 + + # First example is exact match; second differs by literal. + em = metrics["exact_match"] + assert em["count"] == 1 + assert em["rate"] == pytest.approx(0.5) + + # No-values EM should treat both as matches. + nvem = metrics["no_values_em"] + assert nvem["count"] == 2 + assert nvem["rate"] == pytest.approx(1.0) + + # All predictions should parse and adhere to the schema. 
+ parse = metrics["parse_success"] + assert parse["count"] == 2 + assert parse["rate"] == pytest.approx(1.0) + + schema = metrics["schema_adherence"] + assert schema["count"] == 2 + assert schema["rate"] == pytest.approx(1.0) \ No newline at end of file diff --git a/tests/test_normalize_sql.py b/tests/test_normalize_sql.py new file mode 100644 index 0000000..079e3ec --- /dev/null +++ b/tests/test_normalize_sql.py @@ -0,0 +1,43 @@ +from pathlib import Path +import sys + +import pytest + + +def _ensure_src_on_path() -> None: + """Ensure that the 'src' directory is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_ensure_src_on_path() + +from text2sql.eval.normalize import ( # noqa: E402 # isort: skip + normalize_sql, + normalize_sql_no_values, +) + + +def test_normalize_sql_strips_whitespace_and_trailing_semicolons() -> None: + raw = " \n\tSELECT * FROM head ; \n" + normalized = normalize_sql(raw) + + assert normalized == "SELECT * FROM head" + assert normalized.endswith(";") is False + assert " " not in normalized + + +def test_normalize_sql_no_values_replaces_literals() -> None: + raw = "SELECT * FROM head WHERE name = 'Alice' AND age >= 30;" + normalized = normalize_sql_no_values(raw) + + # Should normalize structural SQL. + assert normalized.startswith("SELECT * FROM head WHERE") + # String literal should be replaced. + assert "Alice" not in normalized + assert "__STR__" in normalized + # Numeric literal should be replaced. 
+ assert "30" not in normalized + assert "__NUM__" in normalized \ No newline at end of file diff --git a/tests/test_prompt_building_spider.py b/tests/test_prompt_building_spider.py new file mode 100644 index 0000000..9007766 --- /dev/null +++ b/tests/test_prompt_building_spider.py @@ -0,0 +1,105 @@ +from pathlib import Path +import json +import sys + +import pytest + + +def _ensure_src_on_path() -> None: + """Ensure that the 'src' directory is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_ensure_src_on_path() + +from text2sql.eval.spider import ( # noqa: E402 # isort: skip + build_schema_map, + build_spider_prompt, + load_spider_schema_map, + spider_schema_to_pseudo_ddl, +) + + +def _load_jsonl(path: Path) -> list[dict]: + rows: list[dict] = [] + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + return rows + + +@pytest.fixture() +def spider_fixtures() -> tuple[list[dict], dict[str, str]]: + root = Path(__file__).resolve().parents[1] + fixtures_dir = root / "tests" / "fixtures" + spider_examples_path = fixtures_dir / "spider_sample.jsonl" + spider_schema_path = fixtures_dir / "spider_schema_sample.jsonl" + + examples = _load_jsonl(spider_examples_path) + schema_records = _load_jsonl(spider_schema_path) + schema_map = load_spider_schema_map(schema_records) + + return examples, schema_map + + +def test_build_spider_prompt_uses_schema_and_question( + spider_fixtures: tuple[list[dict], dict[str, str]] +) -> None: + examples, schema_map = spider_fixtures + example = examples[0] + + db_id = example["db_id"] + question = example["question"] + raw_schema_text = schema_map[db_id] + schema_context = spider_schema_to_pseudo_ddl(raw_schema_text) + + prompt = build_spider_prompt(schema_context=schema_context, question=question) + + # Prompt should include 
markers from training-style formatting. + assert "Instruction" in prompt + assert "Input" in prompt + assert "Response" in prompt + + # Ensure schema and question are present. + assert "### Schema:" in prompt + assert "CREATE TABLE" in prompt + assert "### Question:" in prompt + assert question in prompt + + +def test_load_spider_schema_map_handles_real_column_name( + spider_fixtures: tuple[list[dict], dict[str, str]] +) -> None: + _, schema_map = spider_fixtures + # Ensure we have mappings for known db_ids from the fixtures. + for db_id in ("concert_singer", "academic", "college_1"): + assert db_id in schema_map + assert isinstance(schema_map[db_id], str) + assert schema_map[db_id] + + +def test_intersection_nonzero_with_fixtures( + spider_fixtures: tuple[list[dict], dict[str, str]] +) -> None: + examples, schema_map = spider_fixtures + spider_db_ids = {ex["db_id"] for ex in examples} + schema_db_ids = set(schema_map.keys()) + intersection = spider_db_ids & schema_db_ids + assert intersection, "Expected non-empty intersection between Spider and schema db_ids" + + +def test_spider_schema_to_pseudo_ddl_nonempty( + spider_fixtures: tuple[list[dict], dict[str, str]] +) -> None: + _, schema_map = spider_fixtures + # Take one schema text and ensure we can build a non-empty pseudo-DDL string. 
+ raw_schema_text = next(iter(schema_map.values())) + pseudo_ddl = spider_schema_to_pseudo_ddl(raw_schema_text) + assert pseudo_ddl + assert "CREATE TABLE" in pseudo_ddl \ No newline at end of file diff --git a/tests/test_schema_adherence.py b/tests/test_schema_adherence.py new file mode 100644 index 0000000..fab57d4 --- /dev/null +++ b/tests/test_schema_adherence.py @@ -0,0 +1,73 @@ +from pathlib import Path +import sys + +import pytest + + +def _ensure_src_on_path() -> None: + """Ensure that the 'src' directory is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_ensure_src_on_path() + +from text2sql.eval.schema import ( # noqa: E402 # isort: skip + parse_create_table_context, + referenced_identifiers, + schema_adherence, +) + + +SCHEMA_CONTEXT = """ +CREATE TABLE head ( + id INTEGER PRIMARY KEY, + age INTEGER, + department TEXT +); + +CREATE TABLE manager ( + id INTEGER PRIMARY KEY, + name TEXT +); +""" + + +def test_parse_create_table_context_extracts_tables_and_columns() -> None: + parsed = parse_create_table_context(SCHEMA_CONTEXT) + tables = parsed["tables"] + columns_by_table = parsed["columns_by_table"] + + assert "head" in tables + assert "manager" in tables + + assert "age" in columns_by_table["head"] + assert "department" in columns_by_table["head"] + assert "name" in columns_by_table["manager"] + + +def test_referenced_identifiers_finds_tables_and_columns() -> None: + sql = "SELECT department, COUNT(*) FROM head WHERE age > 56" + refs = referenced_identifiers(sql) + + assert "head" in refs["tables"] + # We only track column names (lowercased, without table prefixes). 
+ assert "department" in refs["columns"] + assert "age" in refs["columns"] + + +def test_schema_adherence_true_for_known_tables_and_columns() -> None: + sql = "SELECT department FROM head WHERE age > 56" + assert schema_adherence(sql, SCHEMA_CONTEXT) is True + + +def test_schema_adherence_false_for_unknown_table() -> None: + sql = "SELECT department FROM unknown_table WHERE age > 56" + assert schema_adherence(sql, SCHEMA_CONTEXT) is False + + +def test_schema_adherence_false_for_unknown_column() -> None: + sql = "SELECT salary FROM head WHERE age > 56" + assert schema_adherence(sql, SCHEMA_CONTEXT) is False \ No newline at end of file