diff --git a/.gitignore b/.gitignore index 277a7e3..8445289 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ venv/ # Environment / secrets .env *.env +.streamlit/secrets.toml # Data and outputs data/ diff --git a/.streamlit/secrets.toml.example b/.streamlit/secrets.toml.example new file mode 100644 index 0000000..66b4645 --- /dev/null +++ b/.streamlit/secrets.toml.example @@ -0,0 +1,28 @@ +# Example Streamlit secrets for the Analytics Copilot Text-to-SQL demo. +# +# IMPORTANT: +# - Do NOT commit real tokens or endpoint URLs to version control. +# - Create a local `.streamlit/secrets.toml` (ignored by git) and fill in +# environment-specific values there, or configure secrets directly in +# Streamlit Community Cloud. +# +# The app also supports reading these values from environment variables; when +# both are present, Streamlit secrets take precedence. + +HF_TOKEN = "hf_your_access_token_here" + +# Preferred: dedicated Inference Endpoint / TGI deployment with adapters. +# This should point to your HF Inference Endpoint URL, and HF_ADAPTER_ID should +# match the adapter identifier configured in the endpoint's LORA_ADAPTERS. +HF_ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud" +HF_ADAPTER_ID = "your-adapter-id" + +# Fallback: provider/router-based model id (no adapters). +# Use this only with models that are directly supported by HF Inference +# providers (e.g. merged text-to-SQL models), not pure adapter repos. +HF_MODEL_ID = "your-model-id" +HF_PROVIDER = "auto" + +# Compatibility: older name for the endpoint base URL. The app will treat this +# as an alias for HF_ENDPOINT_URL if set. 
+HF_INFERENCE_BASE_URL = "" \ No newline at end of file diff --git a/README.md b/README.md index 3a7e480..b6946d9 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,19 @@ python scripts/evaluate_internal.py \ --out_dir reports/ ``` +Quick smoke test (small subset, CPU-friendly): + +```bash +python scripts/evaluate_internal.py --smoke \ + --val_path data/processed/val.jsonl \ + --out_dir reports/ +``` + +- On GPU machines, `--smoke` runs a tiny subset through the full pipeline. +- On CPU-only environments, `--smoke` automatically falls back to `--mock` so + that no heavy model loading is attempted while still exercising the metrics + and reporting stack. + Outputs: - `reports/eval_internal.json` – metrics, config, and sample predictions. @@ -257,18 +270,137 @@ pytest -q These commands are also wired into the CI workflow (`.github/workflows/ci.yml`). -## Demo (placeholder) +--- + +## Publishing to Hugging Face Hub + +After training, you can publish the QLoRA adapter artifacts to the Hugging Face Hub +for reuse and remote inference. + +1. Authenticate with Hugging Face: + + ```bash + huggingface-cli login --token YOUR_HF_TOKEN + ``` + + or set an environment variable: + + ```bash + export HF_TOKEN=YOUR_HF_TOKEN + ``` + +2. Run the publish script, pointing it at your adapter directory and desired repo id: + + ```bash + python scripts/publish_to_hub.py \ + --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --adapter_dir outputs/adapters \ + --include_metrics reports/eval_internal.json + ``` + +The script will: + +- Validate that a Hugging Face token is available. +- Create the model repo if it does not exist (public by default, or `--private`). +- Ensure a README.md model card is written into the adapter directory (with metrics + if provided). +- Upload the entire adapter folder to the Hub using `huggingface_hub.HfApi`. + +You can re-run the script safely; it will perform another commit with the specified +`--commit_message`. 
+ +--- + + + +## Demo – Streamlit UI (Remote Inference) + +The repository includes a lightweight Streamlit UI that talks to a **remote** +model via Hugging Face Inference (no local GPU required). The app lives at +`app/streamlit_app.py` and intentionally does **not** import `torch` or +`transformers`. + +### Remote Inference Note + +- The Streamlit app is **UI-only**; it never loads model weights locally. +- All text-to-SQL generation is performed remotely using + `huggingface_hub.InferenceClient`. +- For small models you may be able to use Hugging Face serverless Inference, + but large models like Mistral-7B often require **Inference Endpoints** or a + dedicated provider. +- If serverless calls fail or time out, consider deploying a dedicated + Inference Endpoint or self-hosted TGI/serving stack and pointing the app at + its URL via `HF_ENDPOINT_URL` / `HF_INFERENCE_BASE_URL`. + +> **Adapter repos and the HF router:** If you point the app at a pure LoRA +> adapter repository (e.g. `BrejBala/analytics-copilot-mistral7b-text2sql-adapter`) +> using `HF_MODEL_ID` without an `HF_ENDPOINT_URL`, the request goes through +> the Hugging Face **router** and most providers will respond with +> `model_not_supported`. For adapter-based inference, use a dedicated +> Inference Endpoint and configure `HF_ENDPOINT_URL` + `HF_ADAPTER_ID` instead +> of trying to call the adapter repo directly via the router. + +### Running the app locally + +1. 
Configure Streamlit secrets by creating `.streamlit/secrets.toml` from the + example: + + ```bash + cp .streamlit/secrets.toml.example .streamlit/secrets.toml + ``` + + Then edit `.streamlit/secrets.toml` (not tracked by git) and fill in either: + + **Preferred: dedicated endpoint + adapter** + + ```toml + HF_TOKEN = "hf_your_access_token_here" + + # Dedicated Inference Endpoint / TGI URL (placeholder — use your own endpoint) + HF_ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud" + + # Adapter identifier configured in your endpoint's LORA_ADAPTERS + HF_ADAPTER_ID = "BrejBala/analytics-copilot-mistral7b-text2sql-adapter" + ``` + + **Fallback: provider/router-based merged model (no adapters)** + + ```toml + HF_TOKEN = "hf_your_access_token_here" + HF_MODEL_ID = "your-username/your-merged-text2sql-model" + HF_PROVIDER = "auto" # optional provider hint + ``` + + `HF_INFERENCE_BASE_URL` is also supported as an alias for `HF_ENDPOINT_URL`. + The app will always prefer secrets over environment variables when both are + set. + +2. Start the Streamlit app: + + ```bash + streamlit run app/streamlit_app.py + ``` + +3. In the UI: + + - Paste your database schema (DDL) into the **Schema** text area. + - Enter a natural-language question. + - Click **Generate SQL** to call the remote model. + - View the generated SQL in a code block (with a copy button). + - Optionally open the **Show prompt** expander to inspect the exact prompt + sent to the model (useful for debugging and prompt engineering). -> The Streamlit UI will be documented here when implemented. +### Deploying on Streamlit Community Cloud -Planned content for this section: +When deploying to Streamlit Cloud: -- How to start the Streamlit app (under `app/`). -- Sample configuration for connecting to a demo database or local SQLite DB. -- Usage examples: - - Asking natural-language questions. - - Viewing generated SQL and query results. - - Editing and re-running SQL. 
+- Add `HF_TOKEN`, `HF_ENDPOINT_URL`, and `HF_ADAPTER_ID` (or `HF_MODEL_ID` / + `HF_PROVIDER` for the router fallback) to the app's **Secrets** in the + Streamlit Cloud UI. +- The app will automatically construct an `InferenceClient` from those values + and use the dedicated endpoint when `HF_ENDPOINT_URL` is set. +- No GPU is required on the Streamlit side; all heavy lifting is done by the + remote Hugging Face Inference backend. --- @@ -278,19 +410,23 @@ Current high-level layout: ```text . -├── app/ # Streamlit app (to be implemented) +├── app/ # Streamlit UI (remote inference via HF InferenceClient) +│ └── streamlit_app.py ├── docs/ # Documentation, design notes, evaluation reports │ ├── dataset.md │ ├── training.md │ ├── evaluation.md │ └── external_validation.md ├── notebooks/ # Jupyter/Colab notebooks for experimentation -├── scripts/ # CLI scripts (dataset, training, evaluation) +├── scripts/ # CLI scripts (dataset, training, evaluation, utilities) │ ├── build_dataset.py +│ ├── check_syntax.py │ ├── smoke_load_dataset.py +│ ├── smoke_infer_endpoint.py │ ├── train_qlora.py │ ├── evaluate_internal.py -│ └── evaluate_spider_external.py +│ ├── evaluate_spider_external.py +│ └── publish_to_hub.py ├── src/ │ └── text2sql/ # Core Python package │ ├── __init__.py @@ -315,11 +451,12 @@ Current high-level layout: │ ├── test_repo_smoke.py │ ├── test_build_dataset_offline.py │ ├── test_data_prep.py +│ ├── test_eval_cli_args.py +│ ├── test_infer_quantization.py │ ├── test_prompt_formatting.py │ ├── test_normalize_sql.py │ ├── test_schema_adherence.py -│ ├── test_metrics_aggregate.py -│ └── test_prompt_building_spider.py +│ └── test_metrics_aggregate.py ├── .env.example # Example environment file ├── .gitignore ├── context.md # Persistent project context & decisions diff --git a/app/streamlit_app.py b/app/streamlit_app.py new file mode 100644 index 0000000..d99255e --- /dev/null +++ b/app/streamlit_app.py @@ -0,0 +1,412 @@ +""" +Streamlit UI for Analytics Copilot 
(Text-to-SQL). + +This app is intentionally **UI-only**: + +- It does NOT load any local models or GPU resources. +- All inference is performed remotely via Hugging Face Inference + (`huggingface_hub.InferenceClient`), making it suitable for + Streamlit Community Cloud deployments. + +Configuration can be provided via Streamlit secrets (`.streamlit/secrets.toml`) +or environment variables. Secrets take precedence over environment variables. + +Required: + HF_TOKEN = "hf_xxx" # Hugging Face access token + +Preferred (dedicated endpoint + adapters): + HF_ENDPOINT_URL = "https://..." # Dedicated Inference Endpoint / TGI URL + HF_ADAPTER_ID = "adapter-id" # Adapter identifier configured in TGI + +Fallback (provider/router-based; no adapters): + HF_MODEL_ID = "username/model" # Provider-supported merged model + HF_PROVIDER = "auto" # Provider hint for InferenceClient(model=...) + +Compatibility: + HF_INFERENCE_BASE_URL is also supported as an alias for HF_ENDPOINT_URL. + +Priority: +1. If HF_ENDPOINT_URL (or HF_INFERENCE_BASE_URL) is non-empty, we call: + + InferenceClient(base_url=HF_ENDPOINT_URL, api_key=HF_TOKEN) + + and send `adapter_id=HF_ADAPTER_ID` in `text_generation` requests. + +2. Otherwise, we fall back to provider-based routing via the HF router: + + InferenceClient(model=HF_MODEL_ID, api_key=HF_TOKEN, provider=HF_PROVIDER) + + Note that pure adapter repositories are **not** supported by the router; use + a dedicated endpoint for adapter-based inference. 
+""" + +from __future__ import annotations + +import logging +import os +from typing import Any, Mapping, NamedTuple, Optional, Tuple + +import streamlit as st +from huggingface_hub import InferenceClient # type: ignore[import] + + +logger = logging.getLogger(__name__) + + +def _configure_logging() -> None: + """Configure lightweight logging for the Streamlit app.""" + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] %(name)s - %(message)s", + ) + + +class HFConfig(NamedTuple): + """Resolved Hugging Face configuration for the Streamlit app.""" + + hf_token: str + endpoint_url: str + model_id: str + provider: str + adapter_id: Optional[str] + + +def _resolve_hf_config( + secrets: Mapping[str, Any], + environ: Mapping[str, str], +) -> HFConfig: + """ + Resolve Hugging Face configuration from secrets and environment variables. + + Precedence: + - Secrets take priority over environment variables. + - HF_ENDPOINT_URL and HF_INFERENCE_BASE_URL are treated as aliases. 
+ """ + + def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: + try: + value = mapping.get(key) # type: ignore[attr-defined] + except Exception: # noqa: BLE001 + value = None + if value is None: + return "" + return str(value).strip() + + hf_token = _get_from_mapping(secrets, "HF_TOKEN") or environ.get( + "HF_TOKEN", + "", + ).strip() + + endpoint_url = ( + _get_from_mapping(secrets, "HF_ENDPOINT_URL") + or _get_from_mapping(secrets, "HF_INFERENCE_BASE_URL") + or environ.get("HF_ENDPOINT_URL", "").strip() + or environ.get("HF_INFERENCE_BASE_URL", "").strip() + ) + + model_id = _get_from_mapping(secrets, "HF_MODEL_ID") or environ.get( + "HF_MODEL_ID", + "", + ).strip() + + provider = ( + _get_from_mapping(secrets, "HF_PROVIDER") + or environ.get("HF_PROVIDER", "").strip() + or "auto" + ) + + adapter_id_raw = _get_from_mapping(secrets, "HF_ADAPTER_ID") or environ.get( + "HF_ADAPTER_ID", + "", + ).strip() + adapter_id = adapter_id_raw or None + + return HFConfig( + hf_token=hf_token, + endpoint_url=endpoint_url, + model_id=model_id, + provider=provider, + adapter_id=adapter_id, + ) + + +@st.cache_resource(show_spinner=False) +def _get_cached_client( + hf_token: str, + base_url: str, + model_id: str, + provider: str, + timeout_s: int, +) -> InferenceClient: + """ + Construct and cache a Hugging Face InferenceClient instance. + + The cache key is derived from the provided parameters, so changing any of + them in Streamlit secrets or environment variables will cause a new client + to be created. 
+ """ + if base_url: + logger.info("Creating InferenceClient with base_url=%s", base_url) + return InferenceClient(base_url=base_url, api_key=hf_token, timeout=timeout_s) + + logger.info( + "Creating InferenceClient with model=%s and provider=%s", + model_id, + provider, + ) + return InferenceClient( + model=model_id, + api_key=hf_token, + provider=provider, + timeout=timeout_s, + ) + + +def _build_prompt(schema: str, question: str) -> Tuple[str, str]: + """ + Build the system and user prompt content for text-to-SQL generation. + + The prompt format mirrors the training/evaluation pipeline: + ### Schema: + + + ### Question: + + """ + system_prompt = ( + "You are a careful text-to-SQL assistant. " + "Given a database schema and a question, you respond with a single SQL " + "query that answers the question. " + "Return ONLY the SQL query, without explanation or commentary." + ) + + user_prompt = f"""### Schema: +{schema.strip()} + +### Question: +{question.strip()} + +Return only the SQL query.""" + return system_prompt, user_prompt + + +def _create_inference_client(timeout_s: int = 45) -> Tuple[InferenceClient, HFConfig]: + """ + Create (or retrieve) an InferenceClient based on Streamlit secrets/env. + + Raises Streamlit errors if required configuration is missing so the user + sees actionable feedback in the UI. + """ + secrets = st.secrets + hf_config = _resolve_hf_config(secrets=secrets, environ=os.environ) + + if not hf_config.hf_token: + st.error( + "Missing `HF_TOKEN` in Streamlit secrets or environment. " + "Set it in `.streamlit/secrets.toml` or as the `HF_TOKEN` " + "environment variable." + ) + st.stop() + + if not hf_config.endpoint_url and not hf_config.model_id: + st.error( + "Neither `HF_ENDPOINT_URL`/`HF_INFERENCE_BASE_URL` nor `HF_MODEL_ID` " + "is configured. Set at least one via Streamlit secrets or " + "environment variables." 
+ ) + st.stop() + + if hf_config.endpoint_url and hf_config.adapter_id is None: + st.error( + "HF_ENDPOINT_URL is set but `HF_ADAPTER_ID` is missing. " + "For adapter-based inference with a dedicated endpoint, set " + "`HF_ADAPTER_ID` to the adapter identifier configured in your " + "Text Generation Inference (TGI) endpoint (the value used as " + "`adapter_id` in requests)." + ) + st.stop() + + if hf_config.endpoint_url: + logger.info( + "Using dedicated Hugging Face Inference Endpoint at %s", + hf_config.endpoint_url, + ) + else: + logger.info( + "Using provider-based HF Inference routing with model=%s, provider=%s", + hf_config.model_id, + hf_config.provider, + ) + + client = _get_cached_client( + hf_token=hf_config.hf_token, + base_url=hf_config.endpoint_url, + model_id=hf_config.model_id, + provider=hf_config.provider, + timeout_s=timeout_s, + ) + + return client, hf_config + + +def _call_model( + client: InferenceClient, + schema: str, + question: str, + temperature: float, + max_tokens: int, + timeout_s: int = 45, + adapter_id: Optional[str] = None, + use_endpoint: bool = False, +) -> Tuple[Optional[str], str]: + """ + Call the remote model via text_generation and return the generated SQL. + + Returns a tuple of (sql_text_or_none, user_prompt), where sql_text_or_none + is None if the call failed. + + When `use_endpoint` is True and `adapter_id` is provided, the request will + include `adapter_id` to select the appropriate LoRA adapter on a TGI + Inference Endpoint. 
+ """ + system_prompt, user_prompt = _build_prompt(schema=schema, question=question) + full_prompt = f"{system_prompt}\n\n{user_prompt}" + + generation_kwargs: dict[str, Any] = { + "prompt": full_prompt, + "max_new_tokens": max_tokens, + "temperature": temperature, + } + if use_endpoint and adapter_id: + generation_kwargs["adapter_id"] = adapter_id + + try: + response = client.text_generation(**generation_kwargs) + except Exception as exc: # noqa: BLE001 + logger.error("Error while calling Hugging Face Inference API.", exc_info=True) + st.error( + "The Hugging Face Inference endpoint did not respond successfully. " + "This can happen if the endpoint is cold, overloaded, or misconfigured. " + "Please try again, or check your HF endpoint / model settings." + ) + st.caption(f"Details: {exc}") + return None, user_prompt + + # InferenceClient.text_generation may return a string, a dict, or a list. + try: + if isinstance(response, str): + text = response + elif isinstance(response, dict) and "generated_text" in response: + text = str(response["generated_text"]) + elif ( + isinstance(response, list) + and response + and isinstance(response[0], dict) + and "generated_text" in response[0] + ): + text = str(response[0]["generated_text"]) + else: + # Fallback: best-effort string representation. + text = str(response) + except Exception as exc: # noqa: BLE001 + logger.error("Unexpected response format from InferenceClient.", exc_info=True) + st.error( + "Received an unexpected response format from the Hugging Face " + "Inference API. Please check your endpoint and model configuration." 
+ ) + st.caption(f"Details: {exc}") + return None, user_prompt + + sql_text = (text or "").strip() + return sql_text, user_prompt + + +def main() -> None: + """Render the Streamlit UI.""" + _configure_logging() + st.set_page_config( + page_title="Analytics Copilot – Text-to-SQL", + layout="centered", + ) + + st.title("Analytics Copilot – Text-to-SQL") + st.markdown( + "This demo converts natural-language questions into SQL using a " + "remote model hosted on **Hugging Face Inference**. " + "The model is not loaded inside the Streamlit app; all heavy lifting " + "happens on a remote endpoint or serverless provider." + ) + + st.markdown("### Inputs") + + schema = st.text_area( + "Database schema (DDL)", + height=220, + placeholder="CREATE TABLE orders (...);\nCREATE TABLE customers (...);", + ) + + question = st.text_area( + "Question (natural language)", + height=140, + placeholder="What is the total order amount per customer for the last 7 days?", + ) + + with st.expander("Advanced generation settings", expanded=False): + col1, col2 = st.columns(2) + with col1: + temperature = st.slider( + "Temperature", + min_value=0.001, + max_value=1.0, + value=0.01, + step=0.05, + help="Sampling temperature for the model (0.0 = greedy).", + ) + with col2: + max_new_tokens = st.slider( + "Max new tokens", + min_value=32, + max_value=512, + value=256, + step=16, + help="Maximum number of tokens to generate for the SQL query.", + ) + + generate_clicked = st.button("Generate SQL", type="primary") + + if generate_clicked: + if not schema.strip(): + st.warning("Please provide a database schema (DDL) before generating SQL.") + return + if not question.strip(): + st.warning("Please provide a natural-language question.") + return + + with st.spinner("Calling Hugging Face Inference API..."): + client, hf_config = _create_inference_client(timeout_s=45) + sql_text, user_prompt = _call_model( + client=client, + schema=schema, + question=question, + temperature=temperature, + 
max_tokens=max_new_tokens, + timeout_s=45, + adapter_id=hf_config.adapter_id, + use_endpoint=bool(hf_config.endpoint_url), + ) + + if sql_text is not None: + st.subheader("Generated SQL") + st.code(sql_text, language="sql") + + with st.expander("Show prompt", expanded=False): + st.code(user_prompt, language="markdown") + else: + st.warning( + "No SQL was generated due to an error when calling the remote " + "inference endpoint. Please review the error message above." + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/context.md b/context.md index a309d00..cee1f95 100644 --- a/context.md +++ b/context.md @@ -134,4 +134,14 @@ This repo will contain: - **2026-01-10** – Updated dataset plan and smoke loader to use the parquet-backed **`b-mc2/sql-create-context`** dataset (compatible with `datasets>=4`) instead of the script-based `Salesforce/wikisql`, and documented this decision in the project context. - **2026-01-10** – Added a dataset preprocessing pipeline (`scripts/build_dataset.py`) that converts `b-mc2/sql-create-context` into Alpaca-style instruction-tuning JSONL files under `data/processed/` (train/val splits), along with reusable formatting utilities in `text2sql.data_prep`. - **2026-01-10** – Added QLoRA training scaffolding: a detailed Colab-friendly notebook (`notebooks/finetune_mistral7b_qlora_text2sql.ipynb`), a reproducible training script (`scripts/train_qlora.py`), training utilities under `src/text2sql/training/`, and documentation for training (`docs/training.md`) plus planned external validation on Spider dev (`docs/external_validation.md`). -- **2026-01-10** – Task 4: Added an evaluation pipeline with internal metrics on `b-mc2/sql-create-context` (Exact Match, No-values EM, SQL parse success, schema adherence) and a secondary external validation harness on Spider dev using `xlangai/spider` and `richardr1126/spider-schema`, along with reports under `reports/` and supporting documentation (`docs/evaluation.md`). 
\ No newline at end of file +- **2026-01-10** – Task 4 (evaluation pipeline): Added an evaluation pipeline with internal metrics on `b-mc2/sql-create-context` (Exact Match, No-values EM, SQL parse success, schema adherence) and a secondary external validation harness on Spider dev using `xlangai/spider` and `richardr1126/spider-schema`, along with reports under `reports/` and supporting documentation (`docs/evaluation.md`). +- **2026-01-10** – Task 4 extension (4-bit eval + HF Hub publish + Streamlit UI): + - Extended `scripts/evaluate_internal.py` to surface 4-bit quantization controls (`--load_in_4bit`, `--bnb_4bit_quant_type`, `--bnb_4bit_compute_dtype`, `--bnb_4bit_use_double_quant`) and added a CPU-friendly `--smoke` mode that evaluates only a handful of examples and automatically falls back to `--mock` when no GPU is available. + - Typical smoke command: `python scripts/evaluate_internal.py --smoke --val_path data/processed/val.jsonl --out_dir reports/`. + - Updated the shared inference helper `src/text2sql/infer.py` to accept the new 4-bit configuration knobs and to log quantization settings clearly while still falling back gracefully on CPU. + - Added `scripts/publish_to_hub.py` to publish trained QLoRA adapter artifacts (under `outputs/adapters`) to the Hugging Face Hub, including automatic creation of a README-style model card with optional embedded metrics from `reports/eval_internal.json` or `reports/eval_spider.json`. + - Typical publish command: `python scripts/publish_to_hub.py --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora --adapter_dir outputs/adapters --include_metrics reports/eval_internal.json`. + - Implemented a Streamlit demo app at `app/streamlit_app.py` that calls **remote** inference only (via `huggingface_hub.InferenceClient`), suitable for Streamlit Community Cloud. 
The app expects secrets (`HF_TOKEN`, plus either `HF_ENDPOINT_URL` + `HF_ADAPTER_ID` for a dedicated endpoint, or `HF_MODEL_ID` / optional `HF_PROVIDER` for the router fallback; `HF_INFERENCE_BASE_URL` is accepted as an alias for `HF_ENDPOINT_URL`) in `.streamlit/secrets.toml` and can be run locally with: + - `cp .streamlit/secrets.toml.example .streamlit/secrets.toml` (then edit values) + - `streamlit run app/streamlit_app.py`. + - Documented these additions and workflows in `README.md` (HF Hub publishing section, Streamlit remote-inference notes) and ensured syntax checking (`scripts/check_syntax.py`) and the test suite (`pytest -q`) cover the new code paths. \ No newline at end of file diff --git a/docs/evaluation.md b/docs/evaluation.md index 5e70dd5..83a4317 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -130,8 +130,13 @@ Notes: base model in **4-bit (bitsandbytes)** for faster and more memory-efficient evaluation. You can explicitly control this with: - `--load_in_4bit` / `--no_load_in_4bit` + - `--bnb_4bit_quant_type`, `--bnb_4bit_compute_dtype`, and + `--bnb_4bit_use_double_quant` for advanced 4-bit configuration. - `--dtype` (default `auto`, which maps to `float16` on CUDA and `float32` on CPU) - `--max_examples` allows you to subsample the validation set for quick runs. +- `--smoke` evaluates only a small handful of validation examples; on CPU-only + environments it automatically falls back to `--mock` to avoid loading the + large model while still exercising the metrics/reporting pipeline. - If you have a **merged model directory**, you can pass it as `--base_model` and omit `--adapter_dir`. @@ -279,15 +284,19 @@ Both evaluation scripts rely on a shared inference helper: Key functions: -- `load_model_for_inference(base_model, adapter_dir=None, device='auto', load_in_4bit=None, bnb_compute_dtype='float16', dtype='auto')` +- `load_model_for_inference(base_model, adapter_dir=None, device='auto', load_in_4bit=None, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=True, bnb_compute_dtype='float16', dtype='auto')` - Loads a base HF model or local directory. 
- Optionally applies LoRA adapters from `adapter_dir`. - Resolves device via: - `"auto"` → GPU if available, otherwise CPU (with a warning). - `"cuda"` / `"cpu"` for explicit control. - When running on CUDA and `load_in_4bit` is not explicitly set, the loader - defaults to 4-bit (NF4) quantization using bitsandbytes. This significantly - reduces memory usage and speeds up evaluation on Colab-style GPUs. + defaults to 4-bit quantization using bitsandbytes (NF4 + double-quant by + default). This significantly reduces memory usage and speeds up evaluation + on Colab-style GPUs. + - The `bnb_4bit_*` arguments allow you to tune quantization behavior when + needed (e.g. quantization type, compute dtype, and whether double quant is + used). - `generate_sql(prompt, max_new_tokens, temperature, top_p) -> str` - Uses the loaded model/tokenizer to generate text. diff --git a/docs/huggingface_deploy.md b/docs/huggingface_deploy.md new file mode 100644 index 0000000..3adc084 --- /dev/null +++ b/docs/huggingface_deploy.md @@ -0,0 +1,291 @@ +# Deploying the Text-to-SQL Adapter on Hugging Face + +This guide walks through publishing a trained Text-to-SQL LoRA/QLoRA adapter +to the Hugging Face Hub and deploying it via **Inference Endpoints** using +**Multi-LoRA** in Text Generation Inference (TGI). + +The recommended pattern is: + +- Host the **base model once** in an Inference Endpoint. +- Attach one or more **LoRA adapters** via the `LORA_ADAPTERS` environment + variable. +- Select the adapter at request time using `adapter_id`. + +--- + +## 1. Create a Hugging Face token + +1. Go to . +2. Create a new token with at least: + - **Write** access to model repos for publishing. + - **Read** access for inference. 
+ +You can authenticate in two ways: + +```bash +huggingface-cli login +``` + +or set the environment variable: + +```bash +export HF_TOKEN="hf_your_token_here" +``` + +The `scripts/publish_to_hub.py` script uses this token both to validate +authentication and to push adapter artifacts. + +--- + +## 2. Create a model repository (optional) + +You can either: + +- Create a repo in the UI at , or +- Let `publish_to_hub.py` create it for you automatically. + +If you prefer CLI: + +```bash +huggingface-cli repo create your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --type model +``` + +`publish_to_hub.py` will also call `create_repo(..., exist_ok=True)` so it is +safe to rerun. + +--- + +## 3. Publish the adapter with `scripts/publish_to_hub.py` + +Assuming you have run training and have an adapter saved under: + +- `outputs/adapters/` + - `adapter_config.json` + - `adapter_model.safetensors` (or `adapter_model.bin`) + +you can publish it using: + +```bash +python scripts/publish_to_hub.py \ + --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --adapter_dir outputs/adapters +``` + +The script will: + +1. Validate **adapter contents**: + - Require `adapter_config.json`. + - Require `adapter_model.safetensors` **or** `adapter_model.bin`. +2. Validate **Hugging Face authentication** via `HfApi().whoami()`. +3. Create or reuse the Hub repo (`repo_type="model"`). +4. Generate a **minimal README.md** (if none exists) that includes: + - Base model name (from `adapter_config.json.base_model_name_or_path`). + - Task description: Text-to-SQL (schema + question → SQL). + - Evaluation commands (internal + Spider external). + - Deployment notes for Inference Endpoints + Multi-LoRA. +5. Upload the entire adapter directory to the model repo. 
+ +### 3.1 Skipping or tightening README generation + +You can opt out of auto README generation: + +```bash +python scripts/publish_to_hub.py \ + --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --adapter_dir outputs/adapters \ + --skip_readme +``` + +- If `--skip_readme` is set: + - The script **does not** create or modify README.md. + - Any existing README.md in `adapter_dir` is left as-is. + +You can also enforce strict behavior when README generation fails: + +```bash +python scripts/publish_to_hub.py \ + --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --adapter_dir outputs/adapters \ + --strict_readme +``` + +- If `--strict_readme` is set: + - Any error while generating README.md will cause the script to **exit non-zero**. +- If it is **not** set (default): + - README errors are logged and the adapter upload still proceeds. + +### 3.2 Including evaluation metrics in the README + +If you have run evaluation scripts and saved metrics to a JSON report +(e.g. `reports/eval_internal.json` or `reports/eval_spider.json`), you can +embed a summary into the README: + +```bash +python scripts/publish_to_hub.py \ + --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --adapter_dir outputs/adapters \ + --include_metrics reports/eval_internal.json +``` + +The script expects either: + +- A raw metrics dict, or +- An object with a top-level `metrics` key (as produced by the evaluation + scripts). + +--- + +## 4. Create an Inference Endpoint with the base model + +Next, create an Inference Endpoint for the **base model** only, for example: + +- Base model: `mistralai/Mistral-7B-Instruct-v0.1` +- Task: Text Generation +- Accelerator: a GPU instance (e.g. A10G, A100) suitable for Mistral-7B +- Implementation: **Text Generation Inference (TGI)** + +You can create the endpoint via the HF UI: + +1. Go to . +2. Click **Create Endpoint**. +3. Choose the base model (e.g. `mistralai/Mistral-7B-Instruct-v0.1`). +4. 
Configure hardware and autoscaling. +5. In **Advanced configuration** / environment variables, you will set + `LORA_ADAPTERS` as described in the next section. + +--- + +## 5. Attach the adapter via `LORA_ADAPTERS` (Multi-LoRA) + +TGI supports loading multiple LoRA adapters using the `LORA_ADAPTERS` +environment variable, which contains a JSON array describing each adapter. + +For a single Text-to-SQL adapter, set: + +```bash +LORA_ADAPTERS='[ + {"id": "text2sql-qlora", "source": "your-username/analytics-copilot-text2sql-mistral7b-qlora"} +]' +``` + +- `id`: Logical adapter identifier you will reference as `adapter_id` in + requests (e.g. `"text2sql-qlora"`). +- `source`: Hugging Face Hub model repo containing the adapter (the same + `repo_id` you passed to `publish_to_hub.py`). + +You can set this in the endpoint’s **Environment variables** section in +the UI, or via the Inference Endpoints API. + +After saving the configuration and deploying the endpoint, TGI will: + +- Load the base Mistral-7B model. +- Load the LoRA adapter weights from the specified repo. +- Register the adapter under `id="text2sql-qlora"`. + +--- + +## 6. Sending requests with `adapter_id` + +To use the Text-to-SQL adapter at inference time, include `adapter_id` in +your request parameters. + +### 6.1 Raw HTTP request example + +Assuming your endpoint URL is: + +```text +https://your-endpoint-1234.us-east-1.aws.endpoints.huggingface.cloud +``` + +you can send a POST request: + +```bash +curl -X POST \ + -H "Authorization: Bearer $HF_TOKEN" \ + -H "Content-Type: application/json" \ + https://your-endpoint-1234.us-east-1.aws.endpoints.huggingface.cloud \ + -d '{ + "inputs": "### Schema:\n\n\n### Question:\n", + "parameters": { + "adapter_id": "text2sql-qlora", + "max_new_tokens": 256, + "temperature": 0.0 + } + }' +``` + +Key points: + +- `inputs` should contain the **schema + question** prompt in the same + format used for training/evaluation. 
+- `parameters.adapter_id` selects the **LoRA adapter** to apply. +- Other generation parameters (`max_new_tokens`, `temperature`, etc.) + can be tuned as usual. + +### 6.2 Using `huggingface_hub.InferenceClient` + +You can also call the endpoint from Python: + +```python +from huggingface_hub import InferenceClient + +ENDPOINT_URL = "https://your-endpoint-1234.us-east-1.aws.endpoints.huggingface.cloud" + +client = InferenceClient( + base_url=ENDPOINT_URL, + api_key="hf_your_token_here", +) + +schema = """CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER, + amount NUMERIC, + created_at TIMESTAMP +);""" + +question = "Total order amount per customer for the last 7 days." + +prompt = f"""### Schema: +{schema} + +### Question: +{question} + +Return only the SQL query.""" + +response = client.post( + json={ + "inputs": prompt, + "parameters": { + "adapter_id": "text2sql-qlora", + "max_new_tokens": 256, + "temperature": 0.0, + }, + } +) + +print(response) +``` + +Depending on how your endpoint is configured, you may also be able to use +`client.text_generation` or `client.chat_completion` with provider-specific +options for `adapter_id`. The raw `post` call shown above works with the +standard TGI JSON API. + +--- + +## 7. Summary + +- Use `scripts/publish_to_hub.py` to: + - Validate adapter files (`adapter_config.json`, adapter weights). + - Optionally generate a minimal README.md. + - Push the adapter to a Hub model repo. +- Deploy the **base model** once via an Inference Endpoint (TGI). +- Attach the adapter by setting `LORA_ADAPTERS` with the adapter repo id. +- At inference time, select the adapter using `adapter_id` in your request, + and pass the schema + question prompt in the same format as training. + +This pattern keeps deployment efficient (one base model, multiple adapters) +and makes it easy to switch between different specialized adapters without +reprovisioning new endpoints. 
\ No newline at end of file diff --git a/scripts/build_dataset.py b/scripts/build_dataset.py index 3abfb08..67b2459 100644 --- a/scripts/build_dataset.py +++ b/scripts/build_dataset.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: # pragma: no cover - import only for type checking from datasets import Dataset, DatasetDict # noqa: F401 -from text2sql.data_prep import format_record +from text2sql.data_prep import format_record # noqa: E402 # isort: skip logger = logging.getLogger(__name__) @@ -386,7 +386,7 @@ def main(argv: Optional[list[str]] = None) -> int: except (RuntimeError, ValueError) as exc: logger.error("Dataset build failed: %s", exc) return 1 - except Exception as exc: # noqa: BLE001 + except Exception: # noqa: BLE001 logger.error("Unexpected error during dataset build.", exc_info=True) return 1 diff --git a/scripts/check_syntax.py b/scripts/check_syntax.py index a6b9ea0..f7485d2 100644 --- a/scripts/check_syntax.py +++ b/scripts/check_syntax.py @@ -2,7 +2,6 @@ import compileall from pathlib import Path -import sys def _discover_python_files(root_dirs: list[Path]) -> list[Path]: diff --git a/scripts/evaluate_internal.py b/scripts/evaluate_internal.py index dd2db78..45c1847 100644 --- a/scripts/evaluate_internal.py +++ b/scripts/evaluate_internal.py @@ -30,6 +30,13 @@ @dataclass class EvalConfig: + """ + Configuration snapshot for an internal evaluation run. + + This is primarily used for logging and for embedding run metadata into the + JSON / Markdown reports so that results remain reproducible. 
+ """ + val_path: str base_model: str adapter_dir: Optional[str] @@ -42,6 +49,10 @@ class EvalConfig: mock: bool load_in_4bit: Optional[bool] = None dtype: str = "auto" + bnb_4bit_quant_type: str = "nf4" + bnb_4bit_compute_dtype: str = "float16" + bnb_4bit_use_double_quant: bool = True + smoke: bool = False def to_dict(self) -> Dict[str, Any]: return asdict(self) @@ -122,11 +133,12 @@ def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: parser.add_argument( "--load_in_4bit", action="store_true", - default=None, + dest="load_in_4bit", + default=True, help=( - "If set, force loading the base model in 4-bit (bitsandbytes) for " - "faster and more memory-efficient inference. By default this is " - "enabled automatically when running on CUDA and disabled on CPU." + "Enable loading the base model in 4-bit (bitsandbytes) for faster " + "and more memory-efficient inference. By default this is enabled " + "automatically when running on CUDA and disabled on CPU." ), ) parser.add_argument( @@ -145,6 +157,43 @@ def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: "float32 on CPU." ), ) + parser.add_argument( + "--bnb_4bit_quant_type", + type=str, + default="nf4", + help=( + "Quantization type for 4-bit weights (e.g. 'nf4', 'fp4'). " + "This is passed to transformers.BitsAndBytesConfig when " + "--load_in_4bit is enabled." + ), + ) + parser.add_argument( + "--bnb_4bit_compute_dtype", + type=str, + default="auto", + choices=["auto", "float16", "bfloat16"], + help=( + "Compute dtype for 4-bit quantization. " + "'auto' chooses bfloat16 on CUDA GPUs with bf16 support, " + "otherwise float16." + ), + ) + parser.add_argument( + "--bnb_4bit_use_double_quant", + action="store_true", + dest="bnb_4bit_use_double_quant", + default=True, + help=( + "Whether to use nested (double) quantization when loading in 4-bit. " + "Enabled by default for better memory efficiency." 
+ ), + ) + parser.add_argument( + "--no_bnb_4bit_use_double_quant", + action="store_false", + dest="bnb_4bit_use_double_quant", + help="Disable nested (double) quantization when using 4-bit weights.", + ) parser.add_argument( "--mock", action="store_true", @@ -153,6 +202,15 @@ def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: "predictions to validate the metric pipeline." ), ) + parser.add_argument( + "--smoke", + action="store_true", + help=( + "Run a lightweight smoke test over a small subset of validation " + "examples. On CPU-only environments this will automatically " + "fall back to --mock to skip GPU-only model loading." + ), + ) return parser.parse_args(argv) @@ -266,8 +324,12 @@ def _fmt_rate(entry: Dict[str, Any]) -> str: lines.append(f"- **top_p:** `{config.top_p}`") lines.append(f"- **max_new_tokens:** `{config.max_new_tokens}`") lines.append(f"- **mock:** `{config.mock}`") + lines.append(f"- **smoke:** `{config.smoke}`") lines.append(f"- **load_in_4bit:** `{config.load_in_4bit}`") lines.append(f"- **dtype:** `{config.dtype}`") + lines.append(f"- **bnb_4bit_quant_type:** `{config.bnb_4bit_quant_type}`") + lines.append(f"- **bnb_4bit_compute_dtype:** `{config.bnb_4bit_compute_dtype}`") + lines.append(f"- **bnb_4bit_use_double_quant:** `{config.bnb_4bit_use_double_quant}`") lines.append(f"- **n_evaluated_examples:** `{n_examples}`\n") lines.append("## Metrics\n") @@ -294,6 +356,11 @@ def _fmt_rate(entry: Dict[str, Any]) -> str: "- No-values EM further replaces string and numeric literals with " "placeholders, focusing on query structure rather than concrete values." ) + if config.smoke: + lines.append( + "- `--smoke` mode was enabled, so only a small subset of validation " + "examples was evaluated for a quick health check." 
+ ) lines.append("") lines.append("## Example Predictions\n") @@ -319,14 +386,83 @@ def run_eval(args: argparse.Namespace) -> int: val_path = Path(args.val_path) records = _load_alpaca_jsonl(val_path) + original_total = len(records) + if args.max_examples is not None and args.max_examples > 0: records = records[: args.max_examples] + if args.smoke: + # In smoke mode we intentionally evaluate only a tiny subset of the + # validation set to keep the run fast and cheap. + smoke_limit = min(5, len(records)) if records else 0 + records = records[:smoke_limit] + logger.info( + "Running in --smoke mode: evaluating %d example(s) out of %d loaded " + "(max_examples=%d).", + len(records), + original_total, + args.max_examples, + ) + + # On CPU-only environments, skip GPU-only model loading by switching + # to --mock semantics automatically. + if not args.mock: + try: + import torch # type: ignore[import] + except Exception: # pragma: no cover - import-time issues + logger.warning( + "PyTorch is not available; --smoke will run in mock mode " + "without loading a model." + ) + args.mock = True + else: + if not torch.cuda.is_available(): + logger.info( + "No CUDA device detected; --smoke will run in mock mode " + "and skip model loading." + ) + args.mock = True + logger.info( - "Loaded %d validation records from %s (max_examples=%d).", + "Loaded %d validation records from %s (max_examples=%d, smoke=%s).", len(records), val_path, args.max_examples, + args.smoke, + ) + + # Resolve the compute dtype for 4-bit quantization. We accept an 'auto' + # option that prefers bfloat16 on GPUs with bf16 support and otherwise + # falls back to float16. + bnb_compute_dtype = args.bnb_4bit_compute_dtype + if bnb_compute_dtype == "auto": + resolved = "float16" + try: + import torch # type: ignore[import] + except Exception: # pragma: no cover - import-time issues + logger.warning( + "Could not import torch while resolving --bnb_4bit_compute_dtype; " + "falling back to float16." 
+ ) + else: + has_cuda = torch.cuda.is_available() + if has_cuda and (args.device or "auto").lower() != "cpu": + is_bf16_supported = getattr(torch.cuda, "is_bf16_supported", None) + try: + if callable(is_bf16_supported) and torch.cuda.is_bf16_supported(): + resolved = "bfloat16" + except Exception: # pragma: no cover - defensive + resolved = "float16" + bnb_compute_dtype = resolved + + logger.info( + "4-bit loading configuration: load_in_4bit=%s, " + "bnb_4bit_quant_type=%s, bnb_4bit_compute_dtype=%s, " + "bnb_4bit_use_double_quant=%s", + args.load_in_4bit, + args.bnb_4bit_quant_type, + bnb_compute_dtype, + args.bnb_4bit_use_double_quant, ) eval_config = EvalConfig( @@ -342,6 +478,10 @@ def run_eval(args: argparse.Namespace) -> int: mock=args.mock, load_in_4bit=args.load_in_4bit, dtype=args.dtype, + bnb_4bit_quant_type=args.bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_compute_dtype, + bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant, + smoke=args.smoke, ) gold_sqls: List[str] = [] @@ -374,19 +514,25 @@ def run_eval(args: argparse.Namespace) -> int: logger.info( "Loading model for inference with base_model=%s, adapter_dir=%s, " - "device=%s, load_in_4bit=%s, dtype=%s", + "device=%s, load_in_4bit=%s, dtype=%s, bnb_4bit_quant_type=%s, " + "bnb_4bit_compute_dtype=%s, bnb_4bit_use_double_quant=%s", args.base_model, args.adapter_dir, args.device, args.load_in_4bit, args.dtype, + args.bnb_4bit_quant_type, + bnb_compute_dtype, + args.bnb_4bit_use_double_quant, ) load_model_for_inference( base_model=args.base_model, adapter_dir=args.adapter_dir, device=args.device, load_in_4bit=args.load_in_4bit, - bnb_compute_dtype="float16", + bnb_4bit_quant_type=args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant, + bnb_compute_dtype=bnb_compute_dtype, dtype=args.dtype, ) diff --git a/scripts/evaluate_spider_external.py b/scripts/evaluate_spider_external.py index 2ce6f28..9d4b73a 100644 --- a/scripts/evaluate_spider_external.py +++ 
b/scripts/evaluate_spider_external.py @@ -19,9 +19,7 @@ from text2sql.eval.metrics import aggregate_metrics # noqa: E402 # isort: skip from text2sql.eval.normalize import normalize_sql # noqa: E402 # isort: skip -from text2sql.eval.schema import parse_create_table_context # noqa: E402 # isort: skip from text2sql.eval.spider import ( # noqa: E402 # isort: skip - build_schema_map, build_spider_prompt, load_spider_schema_map, spider_schema_to_pseudo_ddl, diff --git a/scripts/publish_to_hub.py b/scripts/publish_to_hub.py new file mode 100644 index 0000000..95bfb56 --- /dev/null +++ b/scripts/publish_to_hub.py @@ -0,0 +1,570 @@ +""" +Utility script to publish QLoRA/LoRA adapter artifacts to Hugging Face Hub. + +This script is designed to be idempotent and friendly to both local +development environments and CI: + +- Validates that a Hugging Face token is available. +- Ensures the target repo exists (creates it if needed). +- Ensures a README.md model card is present under the adapter directory. +- Uploads the entire adapter folder (LoRA/QLoRA artifacts) to the Hub. + +Typical usage: + + python scripts/publish_to_hub.py \ + --repo_id your-username/analytics-copilot-text2sql-mistral7b-qlora \ + --adapter_dir outputs/adapters + +The script expects that you have already authenticated with Hugging Face, e.g.: + + huggingface-cli login + +or by setting an HF_TOKEN environment variable. 
+""" + +from __future__ import annotations + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional + +from huggingface_hub import HfApi # type: ignore[import] +from huggingface_hub.utils import HfHubHTTPError # type: ignore[import] + + +logger = logging.getLogger(__name__) + + +def configure_logging() -> None: + """Configure basic logging for the publish script.""" + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] %(name)s - %(message)s", + ) + + +def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace: + """Parse command-line arguments for the publish-to-hub workflow.""" + parser = argparse.ArgumentParser( + description=( + "Publish QLoRA/LoRA adapter artifacts from a local directory to a " + "Hugging Face Hub model repository." + ) + ) + + parser.add_argument( + "--repo_id", + type=str, + required=True, + help=( + "Target Hugging Face Hub repository id " + "(e.g. 'username/analytics-copilot-text2sql-mistral7b-qlora')." + ), + ) + parser.add_argument( + "--adapter_dir", + type=str, + default="outputs/adapters", + help=( + "Path to the local directory containing QLoRA/LoRA adapter artifacts. " + "This directory will be uploaded as the root of the HF model repo." + ), + ) + parser.add_argument( + "--private", + action="store_true", + help="Create the target repository as private instead of public.", + ) + parser.add_argument( + "--commit_message", + type=str, + default="Add QLoRA adapter artifacts", + help="Commit message to use for the upload.", + ) + parser.add_argument( + "--include_metrics", + type=str, + default=None, + help=( + "Optional path to a JSON metrics file (e.g. reports/eval_internal.json " + "or reports/eval_spider.json). If provided, a summary will be " + "embedded into the generated README.md model card." + ), + ) + parser.add_argument( + "--skip_readme", + action="store_true", + help=( + "Skip auto-generating README.md. 
Existing README.md (if any) is left " + "untouched and only adapter files are uploaded." + ), + ) + parser.add_argument( + "--strict_readme", + action="store_true", + help=( + "If set, fail the publish if README generation fails. " + "By default, README generation errors are logged and the upload " + "continues." + ), + ) + + return parser.parse_args(argv) + + +def _require_hf_token(api: HfApi) -> None: + """ + Ensure that a Hugging Face token is available and valid. + + We call `whoami` as a lightweight check that authentication is correctly + configured. If this fails, we raise a RuntimeError with a clear message so + that the caller can surface it and exit with a non-zero status. + """ + try: + api.whoami() + except HfHubHTTPError as exc: + msg = ( + "Hugging Face authentication failed. " + "Please run `huggingface-cli login` or set the HF_TOKEN environment " + "variable before running scripts/publish_to_hub.py." + ) + logger.error(msg) + raise RuntimeError(msg) from exc + except Exception as exc: # noqa: BLE001 + msg = ( + "Unexpected error while checking Hugging Face authentication. " + "Ensure you have network connectivity and a valid HF token." + ) + logger.error(msg, exc_info=True) + raise RuntimeError(msg) from exc + + +def _validate_adapter_dir(adapter_dir: Path) -> Dict[str, Any]: + """ + Validate that the adapter directory contains the expected files. + + Required: + - adapter_config.json + - adapter_model.safetensors OR adapter_model.bin + + Returns the parsed adapter_config.json payload on success. + """ + config_path = adapter_dir / "adapter_config.json" + if not config_path.is_file(): + raise RuntimeError( + f"Adapter directory '{adapter_dir}' is missing 'adapter_config.json'. " + "This file is required to describe the adapter configuration." 
+ ) + + safetensors_path = adapter_dir / "adapter_model.safetensors" + bin_path = adapter_dir / "adapter_model.bin" + + if not safetensors_path.is_file() and not bin_path.is_file(): + raise RuntimeError( + f"Adapter directory '{adapter_dir}' is missing adapter weights. " + "Expected either 'adapter_model.safetensors' or 'adapter_model.bin'." + ) + + try: + with config_path.open("r", encoding="utf-8") as f: + config = json.load(f) + except Exception as exc: # noqa: BLE001 + raise RuntimeError( + f"Failed to parse adapter config JSON at '{config_path}': {exc}" + ) from exc + + if not isinstance(config, dict): + raise RuntimeError( + f"Adapter config at '{config_path}' must be a JSON object." + ) + + return config + + +def _load_metrics(metrics_path: Path) -> Optional[Dict[str, Any]]: + """ + Load metrics from a JSON file, handling both raw metric dicts and + report-style payloads that wrap metrics under a 'metrics' key. + """ + if not metrics_path.is_file(): + logger.warning( + "Metrics file '%s' does not exist; skipping metrics section.", + metrics_path, + ) + return None + + try: + with metrics_path.open("r", encoding="utf-8") as f: + payload = json.load(f) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to parse metrics JSON file '%s'. Error: %s. " + "Continuing without embedding metrics.", + metrics_path, + exc, + ) + return None + + if isinstance(payload, dict) and "metrics" in payload and isinstance( + payload["metrics"], + dict, + ): + return payload["metrics"] + + if isinstance(payload, dict): + return payload + + logger.warning( + "Metrics file '%s' did not contain a JSON object; skipping metrics section.", + metrics_path, + ) + return None + + +def _format_metrics_section(metrics: Dict[str, Any]) -> str: + """ + Render a Markdown metrics section from a metrics dict. 
+ + The expected schema for internal / Spider evaluation reports is: + { + "n_examples": ..., + "exact_match": {"count": ..., "rate": ...}, + "no_values_em": {...}, + "parse_success": {...}, + "schema_adherence": {...}, + } + """ + if not metrics: + return "_No metrics available._\n" + + lines: list[str] = [] + + n_examples = metrics.get("n_examples") + if n_examples is not None: + lines.append(f"- **n_examples:** {n_examples}") + + def _fmt(entry_name: str, label: str) -> None: + entry = metrics.get(entry_name) + if not isinstance(entry, dict): + return + count = entry.get("count") + rate = entry.get("rate") + if count is None or rate is None: + return + lines.append(f"- **{label}:** {count} examples ({rate:.3f})") + + _fmt("exact_match", "Exact Match (normalized SQL)") + _fmt("no_values_em", "No-values Exact Match") + _fmt("parse_success", "SQL parse success rate") + _fmt("schema_adherence", "Schema adherence rate") + + if not lines: + return ( + "_Metrics JSON did not match the expected schema; " + "see raw file for details._\n" + ) + + return "\n".join(lines) + "\n" + + +def _ensure_readme( + adapter_dir: Path, + repo_id: str, + base_model: str, + metrics_path: Optional[Path], +) -> Path: + """ + Ensure that a README.md model card exists in `adapter_dir`. + + If README.md already exists, it is left unchanged. Otherwise, a minimal, + non-LLM README is created that documents the base model, task, evaluation + scripts, and a recommended Inference Endpoint + Multi-LoRA deployment + pattern. + """ + readme_path = adapter_dir / "README.md" + if readme_path.is_file(): + logger.info("README.md already exists at %s; leaving it unchanged.", readme_path) + return readme_path + + metrics: Optional[Dict[str, Any]] = None + if metrics_path is not None: + metrics = _load_metrics(metrics_path) + + metrics_section = ( + _format_metrics_section(metrics) if metrics else "_No metrics provided._\n" + ) + + lines: list[str] = [] + + # Header and basic info. 
+ lines.append("# Analytics Copilot Text-to-SQL – QLoRA Adapter") + lines.append("") + lines.append( + "This repository contains a LoRA/QLoRA adapter for a Text-to-SQL model. " + "It is intended to be applied on top of a base language model hosted on " + "Hugging Face Hub or in an Inference Endpoint." + ) + lines.append("") + lines.append(f"- **Base model:** `{base_model}`") + lines.append("- **Task:** Text-to-SQL (schema + question → SQL)") + lines.append(f"- **Adapter repo id:** `{repo_id}`") + lines.append("") + lines.append( + "> Note: This Hub repo typically contains adapters only (no full model " + "merge). To run the model locally you must load the base model and " + "apply these adapters." + ) + lines.append("") + lines.append("---") + lines.append("") + + # Evaluation instructions. + lines.append("## Evaluation") + lines.append("") + lines.append( + "You can evaluate the adapter using the internal and external " + "evaluation scripts in this project." + ) + lines.append("") + lines.append("### Internal evaluation (sql-create-context)") + lines.append("") + lines.append("Example command (assuming the repo layout from this project):") + lines.append("") + lines.append("```bash") + lines.append("python scripts/evaluate_internal.py \\") + lines.append(" --val_path data/processed/val.jsonl \\") + lines.append(f" --base_model {base_model} \\") + lines.append(" --adapter_dir /path/to/local/adapters \\") + lines.append(" --device auto \\") + lines.append(" --max_examples 200 \\") + lines.append(" --out_dir reports/") + lines.append("```") + lines.append("") + lines.append("### External evaluation (Spider dev)") + lines.append("") + lines.append("Example command:") + lines.append("") + lines.append("```bash") + lines.append("python scripts/evaluate_spider_external.py \\") + lines.append(f" --base_model {base_model} \\") + lines.append(" --adapter_dir /path/to/local/adapters \\") + lines.append(" --device auto \\") + lines.append(" --spider_source xlangai/spider 
\\") + lines.append(" --schema_source richardr1126/spider-schema \\") + lines.append(" --spider_split validation \\") + lines.append(" --max_examples 200 \\") + lines.append(" --out_dir reports/") + lines.append("```") + lines.append("") + lines.append( + "See `docs/evaluation.md` in the project for a detailed description of " + "the metrics and evaluation setup." + ) + lines.append("") + lines.append("---") + lines.append("") + + # Deployment instructions (Inference Endpoint + Multi-LoRA). + lines.append("## Deployment with Hugging Face Inference Endpoints (Multi-LoRA)") + lines.append("") + lines.append( + "A recommended way to serve this adapter is to deploy the base model once " + "using a Text Generation Inference (TGI) Inference Endpoint, and attach " + "this adapter via the `LORA_ADAPTERS` environment variable." + ) + lines.append("") + lines.append("High-level steps:") + lines.append("") + lines.append("1. Create a new Inference Endpoint based on the base model, e.g.:") + lines.append(f" - Base model: `{base_model}`") + lines.append(" - Hardware: choose a GPU instance suitable for Mistral-7B.") + lines.append("2. In the endpoint configuration, set the environment variable:") + lines.append("") + lines.append(" ```bash") + lines.append( + " LORA_ADAPTERS='[" + '{"id": "text2sql-qlora", "source": "' + repo_id + '"}' + "]'" + ) + lines.append(" ```") + lines.append("") + lines.append( + " This tells TGI to load the adapter from this Hub repo under the " + "logical adapter id `text2sql-qlora`." + ) + lines.append("3. Deploy the endpoint.") + lines.append("") + lines.append( + "At inference time, you can select the adapter by passing an `adapter_id` " + "parameter in your request. 
For example, using the raw HTTP API:" + ) + lines.append("") + lines.append("```json") + lines.append("{") + lines.append(' "inputs": "### Schema:\\n\\n\\n### Question:\\n",') + lines.append(' "parameters": {') + lines.append(' "adapter_id": "text2sql-qlora",') + lines.append(' "max_new_tokens": 256,') + lines.append(' "temperature": 0.0') + lines.append(" }") + lines.append("}") + lines.append("```") + lines.append("") + lines.append( + "If you are using `huggingface_hub.InferenceClient`, you can pass " + "`adapter_id` via the `extra_headers` or specific provider parameters " + "depending on the TGI configuration." + ) + lines.append("") + lines.append("---") + lines.append("") + + # Metrics section. + lines.append("## Metrics") + lines.append("") + lines.append( + "If you have run internal or Spider evaluations and passed " + "`--include_metrics` to `scripts/publish_to_hub.py`, a summary of those " + "metrics is included below." + ) + lines.append("") + if metrics: + lines.extend(metrics_section.rstrip("\n").splitlines()) + else: + lines.append("_No metrics provided._") + + content = "\n".join(lines) + "\n" + + readme_path.write_text(content, encoding="utf-8") + logger.info("Model card written to %s", readme_path) + return readme_path + + +def main(argv: Optional[list[str]] = None) -> int: + """Entry point for publishing adapter artifacts to Hugging Face Hub.""" + configure_logging() + args = parse_args(argv) + + adapter_dir = Path(args.adapter_dir) + if not adapter_dir.is_dir(): + logger.error( + "Adapter directory '%s' does not exist or is not a directory. " + "Ensure you have run training and that the path is correct.", + adapter_dir, + ) + return 1 + + # Validate adapter contents before hitting the Hub APIs. 
+ try: + adapter_config = _validate_adapter_dir(adapter_dir) + except (RuntimeError, ValueError) as exc: + logger.error("Adapter validation failed: %s", exc) + return 1 + + base_model = adapter_config.get("base_model_name_or_path", "") + + api = HfApi() + + try: + _require_hf_token(api) + except RuntimeError: + # Error already logged with details. + return 1 + + repo_id: str = args.repo_id + + logger.info( + "Ensuring Hugging Face Hub repo '%s' exists (private=%s).", + repo_id, + args.private, + ) + try: + api.create_repo( + repo_id=repo_id, + private=args.private, + exist_ok=True, + repo_type="model", + ) + except HfHubHTTPError: + logger.error( + "Failed to create or access repository '%s' on Hugging Face Hub.", + repo_id, + exc_info=True, + ) + return 1 + except Exception as exc: # noqa: BLE001 + logger.error( + "Unexpected error while creating or accessing repository '%s': %s", + repo_id, + exc, + exc_info=True, + ) + return 1 + + metrics_path: Optional[Path] = None + if args.include_metrics is not None: + metrics_path = Path(args.include_metrics) + + if args.skip_readme: + logger.info( + "Skipping README.md generation because --skip_readme was provided." + ) + else: + try: + _ensure_readme( + adapter_dir=adapter_dir, + repo_id=repo_id, + base_model=base_model, + metrics_path=metrics_path, + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Failed to generate README.md for adapter repo '%s': %s", + repo_id, + exc, + exc_info=True, + ) + if args.strict_readme: + logger.error( + "Aborting publish because --strict_readme was set and " + "README generation failed." + ) + return 1 + logger.info( + "Continuing without auto-generated README because " + "--strict_readme was not set." 
+ ) + + logger.info( + "Uploading contents of '%s' to Hugging Face Hub repo '%s' " + "with commit message: %s", + adapter_dir, + repo_id, + args.commit_message, + ) + + try: + api.upload_folder( + folder_path=str(adapter_dir), + repo_id=repo_id, + repo_type="model", + commit_message=args.commit_message, + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Failed to upload adapter directory '%s' to '%s': %s", + adapter_dir, + repo_id, + exc, + exc_info=True, + ) + return 1 + + logger.info("Successfully uploaded adapter artifacts to '%s'.", repo_id) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/scripts/smoke_infer_endpoint.py b/scripts/smoke_infer_endpoint.py new file mode 100644 index 0000000..3829a6e --- /dev/null +++ b/scripts/smoke_infer_endpoint.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Dict + +from huggingface_hub import InferenceClient # type: ignore[import] + + +logger = logging.getLogger(__name__) + + +def configure_logging() -> None: + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] %(name)s - %(message)s", + ) + + +def build_prompt(schema: str, question: str) -> str: + """ + Build a simple text-to-SQL prompt consistent with the main project. + + The schema and question are formatted as: + + ### Schema: + + + ### Question: + + """ + system_prompt = ( + "You are a careful text-to-SQL assistant. " + "Given a database schema and a question, you respond with a single SQL " + "query that answers the question. " + "Return ONLY the SQL query, without explanation or commentary." 
+ ) + + user_prompt = f"""### Schema: +{schema.strip()} + +### Question: +{question.strip()} + +Return only the SQL query.""" + + return f"{system_prompt}\n\n{user_prompt}" + + +def _extract_generated_text(response: Any) -> str: + """ + Best-effort extraction of generated text from InferenceClient.text_generation. + + Handles common shapes: + - str + - {"generated_text": "..."} + - [{"generated_text": "..."}] + """ + if isinstance(response, str): + return response + + if isinstance(response, dict) and "generated_text" in response: + return str(response["generated_text"]) + + if ( + isinstance(response, list) + and response + and isinstance(response[0], dict) + and "generated_text" in response[0] + ): + return str(response[0]["generated_text"]) + + return str(response) + + +def main() -> int: + """ + Smoke test for a dedicated HF Inference Endpoint with LoRA adapters. + + Reads configuration from environment variables: + + HF_TOKEN – required + HF_ENDPOINT_URL – required + HF_ADAPTER_ID – required + + Sends a single text_generation request with adapter_id and prints the raw + response to stdout. + """ + configure_logging() + + hf_token = os.getenv("HF_TOKEN", "").strip() + endpoint_url = os.getenv("HF_ENDPOINT_URL", "").strip() or os.getenv( + "HF_INFERENCE_BASE_URL", + "", + ).strip() + adapter_id = os.getenv("HF_ADAPTER_ID", "").strip() + + if not hf_token: + logger.error("HF_TOKEN is not set in the environment.") + return 1 + if not endpoint_url: + logger.error("HF_ENDPOINT_URL (or HF_INFERENCE_BASE_URL) is not set.") + return 1 + if not adapter_id: + logger.error( + "HF_ADAPTER_ID is not set. This smoke test is intended for " + "adapter-based endpoints and always sends adapter_id." + ) + return 1 + + logger.info("Using HF endpoint: %s", endpoint_url) + logger.info("Using adapter_id: %s", adapter_id) + + client = InferenceClient(base_url=endpoint_url, api_key=hf_token, timeout=60) + + # Tiny toy schema + question for a quick smoke test. 
+ schema = """CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER, + amount NUMERIC, + created_at TIMESTAMP +);""" + + question = "Total order amount per customer for the last 7 days." + + prompt = build_prompt(schema=schema, question=question) + + generation_kwargs: Dict[str, Any] = { + "prompt": prompt, + "max_new_tokens": 128, + "temperature": 0.0, + "adapter_id": adapter_id, + } + + try: + logger.info("Calling text_generation on the HF endpoint...") + response = client.text_generation(**generation_kwargs) + except Exception as exc: # noqa: BLE001 + logger.error("Error while calling the HF endpoint.", exc_info=True) + print(f"ERROR: {exc}") + return 1 + + text = _extract_generated_text(response) + print("=== Raw text_generation response ===") + print(text) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/scripts/train_qlora.py b/scripts/train_qlora.py index 6d78404..21a38ea 100644 --- a/scripts/train_qlora.py +++ b/scripts/train_qlora.py @@ -15,8 +15,8 @@ if str(SRC_DIR) not in sys.path: sys.path.insert(0, str(SRC_DIR)) -from text2sql.training.config import TrainingConfig -from text2sql.training.formatting import build_prompt, ensure_sql_only +from text2sql.training.config import TrainingConfig # noqa: E402 # isort: skip +from text2sql.training.formatting import build_prompt, ensure_sql_only # noqa: E402 # isort: skip logger = logging.getLogger(__name__) @@ -466,7 +466,7 @@ def main(argv: Optional[List[str]] = None) -> int: except (RuntimeError, ValueError) as exc: logger.error("Training run failed: %s", exc) return 1 - except Exception as exc: # noqa: BLE001 + except Exception: # noqa: BLE001 logger.error("Unexpected error during training run.", exc_info=True) return 1 diff --git a/src/text2sql/eval/metrics.py b/src/text2sql/eval/metrics.py index 86fa4e5..f804c55 100644 --- a/src/text2sql/eval/metrics.py +++ b/src/text2sql/eval/metrics.py @@ -1,6 +1,6 @@ from __future__ import 
annotations -from typing import Dict, Iterable, List, Mapping, Optional, Sequence +from typing import Dict, Optional, Sequence import sqlglot diff --git a/src/text2sql/infer.py b/src/text2sql/infer.py index a2d2655..e691c51 100644 --- a/src/text2sql/infer.py +++ b/src/text2sql/infer.py @@ -74,6 +74,8 @@ def load_model_for_inference( adapter_dir: Optional[str] = None, device: str = "auto", load_in_4bit: Optional[bool] = None, + bnb_4bit_quant_type: str = "nf4", + bnb_4bit_use_double_quant: bool = True, bnb_compute_dtype: str = "float16", dtype: str = "auto", ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: @@ -99,6 +101,10 @@ def load_model_for_inference( If True, attempt to load the base model using 4-bit quantization (bitsandbytes). If None, 4-bit is enabled automatically when running on CUDA and disabled otherwise. + bnb_4bit_quant_type : str, optional + Quantization type for 4-bit weights (e.g. "nf4", "fp4"). Defaults to "nf4". + bnb_4bit_use_double_quant : bool, optional + Whether to use nested (double) quantization for 4-bit weights. Defaults to True. bnb_compute_dtype : str, optional Compute dtype for 4-bit quantization (e.g. "float16", "bfloat16"). Defaults to "float16". 
@@ -123,13 +129,16 @@ def load_model_for_inference( logger.info( "Loading model for inference: base_model=%s, adapter_dir=%s, device=%s, " - "load_in_4bit=%s, dtype=%s, bnb_compute_dtype=%s", + "load_in_4bit=%s, dtype=%s, bnb_compute_dtype=%s, bnb_4bit_quant_type=%s, " + "bnb_4bit_use_double_quant=%s", base_model, adapter_dir, resolved_device, use_4bit, dtype, bnb_compute_dtype, + bnb_4bit_quant_type, + bnb_4bit_use_double_quant, ) tokenizer = AutoTokenizer.from_pretrained(adapter_dir or base_model) @@ -163,16 +172,18 @@ def load_model_for_inference( quant_config = BitsAndBytesConfig( load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, bnb_4bit_compute_dtype=compute_dtype, ) logger.info( - "Using 4-bit NF4 quantization for base model with torch_dtype=%s and " - "compute_dtype=%s.", + "Using 4-bit quantization for base model with torch_dtype=%s, " + "compute_dtype=%s, quant_type=%s, double_quant=%s.", torch_dtype, compute_dtype, + bnb_4bit_quant_type, + bnb_4bit_use_double_quant, ) model = AutoModelForCausalLM.from_pretrained( diff --git a/tests/test_eval_cli_args.py b/tests/test_eval_cli_args.py index d796851..b683ef5 100644 --- a/tests/test_eval_cli_args.py +++ b/tests/test_eval_cli_args.py @@ -1,8 +1,6 @@ from pathlib import Path import sys -import pytest - def _ensure_root_on_path() -> None: """Ensure that the project root is available on sys.path for script imports.""" @@ -17,18 +15,28 @@ def _ensure_root_on_path() -> None: from scripts import evaluate_spider_external # noqa: E402 # isort: skip -def test_evaluate_internal_parses_4bit_flags() -> None: +def test_evaluate_internal_parses_4bit_and_smoke_flags() -> None: args = evaluate_internal.parse_args( [ "--mock", "--load_in_4bit", "--dtype", "float16", + "--bnb_4bit_quant_type", + "nf4", + "--bnb_4bit_compute_dtype", + "bfloat16", + "--no_bnb_4bit_use_double_quant", + "--smoke", ] ) 
assert args.mock is True assert args.load_in_4bit is True assert args.dtype == "float16" + assert args.bnb_4bit_quant_type == "nf4" + assert args.bnb_4bit_compute_dtype == "bfloat16" + assert args.bnb_4bit_use_double_quant is False + assert args.smoke is True def test_evaluate_spider_parses_4bit_flags() -> None: diff --git a/tests/test_infer_quantization.py b/tests/test_infer_quantization.py index 9b806bf..c1f6900 100644 --- a/tests/test_infer_quantization.py +++ b/tests/test_infer_quantization.py @@ -2,8 +2,6 @@ import sys from unittest import mock -import pytest - def _ensure_src_on_path() -> None: """Ensure that the 'src' directory is available on sys.path for imports.""" @@ -22,7 +20,7 @@ def test_load_model_for_inference_4bit_uses_quantization_config() -> None: actually downloading a model, and that it wires BitsAndBytesConfig through to AutoModelForCausalLM.from_pretrained. """ - import text2sql.infer as infer # noqa: WPS433 # isort: skip + import text2sql.infer as infer # isort: skip with mock.patch.object(infer, "AutoTokenizer") as mock_tok_cls, \ mock.patch.object(infer, "AutoModelForCausalLM") as mock_model_cls, \ diff --git a/tests/test_normalize_sql.py b/tests/test_normalize_sql.py index 079e3ec..06ec069 100644 --- a/tests/test_normalize_sql.py +++ b/tests/test_normalize_sql.py @@ -1,8 +1,6 @@ from pathlib import Path import sys -import pytest - def _ensure_src_on_path() -> None: """Ensure that the 'src' directory is available on sys.path for imports.""" diff --git a/tests/test_prompt_building_spider.py b/tests/test_prompt_building_spider.py index 9007766..66f4bfa 100644 --- a/tests/test_prompt_building_spider.py +++ b/tests/test_prompt_building_spider.py @@ -16,7 +16,6 @@ def _ensure_src_on_path() -> None: _ensure_src_on_path() from text2sql.eval.spider import ( # noqa: E402 # isort: skip - build_schema_map, build_spider_prompt, load_spider_schema_map, spider_schema_to_pseudo_ddl, diff --git a/tests/test_publish_to_hub.py b/tests/test_publish_to_hub.py 
new file mode 100644 index 0000000..36c1bca --- /dev/null +++ b/tests/test_publish_to_hub.py @@ -0,0 +1,121 @@ +from pathlib import Path +import json +import sys +from unittest import mock + +import pytest + + +def _ensure_root_on_path() -> None: + """Ensure that the project root is available on sys.path for script imports.""" + root = Path(__file__).resolve().parents[1] + if str(root) not in sys.path: + sys.path.insert(0, str(root)) + + +_ensure_root_on_path() + +from scripts import publish_to_hub # noqa: E402 # isort: skip + + +def _make_minimal_adapter_dir(tmp_path: Path) -> Path: + """Create a minimal fake adapter directory with required files.""" + adapter_dir = tmp_path / "adapters" + adapter_dir.mkdir(parents=True, exist_ok=True) + + config = {"base_model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.1"} + (adapter_dir / "adapter_config.json").write_text( + json.dumps(config), + encoding="utf-8", + ) + + # Touch a fake weights file; content does not matter for validation. + (adapter_dir / "adapter_model.safetensors").write_bytes(b"") + + return adapter_dir + + +def test_publish_to_hub_skip_readme_does_not_create_readme(tmp_path: Path) -> None: + """--skip_readme should avoid README generation and still call upload_folder.""" + adapter_dir = _make_minimal_adapter_dir(tmp_path) + + with mock.patch.object(publish_to_hub, "HfApi") as mock_hfapi_cls: + api_instance = mock_hfapi_cls.return_value + api_instance.whoami.return_value = {"name": "tester"} + api_instance.create_repo.return_value = None + api_instance.upload_folder.return_value = None + + rc = publish_to_hub.main( + [ + "--repo_id", + "user/test-adapter", + "--adapter_dir", + str(adapter_dir), + "--skip_readme", + ] + ) + + assert rc == 0 + # README.md should not be auto-created when --skip_readme is used. 
+ assert not (adapter_dir / "README.md").exists() + api_instance.upload_folder.assert_called_once() + + +def test_publish_to_hub_creates_readme_when_missing(tmp_path: Path) -> None: + """When README.md is missing, the script should create a minimal README.""" + adapter_dir = _make_minimal_adapter_dir(tmp_path) + readme_path = adapter_dir / "README.md" + assert not readme_path.exists() + + with mock.patch.object(publish_to_hub, "HfApi") as mock_hfapi_cls: + api_instance = mock_hfapi_cls.return_value + api_instance.whoami.return_value = {"name": "tester"} + api_instance.create_repo.return_value = None + api_instance.upload_folder.return_value = None + + rc = publish_to_hub.main( + [ + "--repo_id", + "user/test-adapter", + "--adapter_dir", + str(adapter_dir), + ] + ) + + assert rc == 0 + assert readme_path.is_file() + content = readme_path.read_text(encoding="utf-8") + # Basic sanity checks on README contents. + assert "Text-to-SQL" in content + assert "Deployment with Hugging Face Inference Endpoints" in content + assert "LORA_ADAPTERS" in content + + +def test_validate_adapter_dir_missing_config_raises(tmp_path: Path) -> None: + """Adapter validation should fail fast if adapter_config.json is missing.""" + adapter_dir = tmp_path / "adapters_missing_config" + adapter_dir.mkdir(parents=True, exist_ok=True) + # Only weights present. 
+ (adapter_dir / "adapter_model.safetensors").write_bytes(b"") + + with pytest.raises(RuntimeError) as excinfo: + publish_to_hub._validate_adapter_dir(adapter_dir) # type: ignore[attr-defined] + + assert "adapter_config.json" in str(excinfo.value) + + +def test_validate_adapter_dir_missing_weights_raises(tmp_path: Path) -> None: + """Adapter validation should fail fast if no adapter weights are present.""" + adapter_dir = tmp_path / "adapters_missing_weights" + adapter_dir.mkdir(parents=True, exist_ok=True) + config = {"base_model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.1"} + (adapter_dir / "adapter_config.json").write_text( + json.dumps(config), + encoding="utf-8", + ) + + with pytest.raises(RuntimeError) as excinfo: + publish_to_hub._validate_adapter_dir(adapter_dir) # type: ignore[attr-defined] + + msg = str(excinfo.value) + assert "adapter_model.safetensors" in msg or "adapter_model.bin" in msg \ No newline at end of file diff --git a/tests/test_schema_adherence.py b/tests/test_schema_adherence.py index fab57d4..4c9d4fc 100644 --- a/tests/test_schema_adherence.py +++ b/tests/test_schema_adherence.py @@ -1,8 +1,6 @@ from pathlib import Path import sys -import pytest - def _ensure_src_on_path() -> None: """Ensure that the 'src' directory is available on sys.path for imports.""" diff --git a/tests/test_streamlit_config.py b/tests/test_streamlit_config.py new file mode 100644 index 0000000..bdb6d62 --- /dev/null +++ b/tests/test_streamlit_config.py @@ -0,0 +1,78 @@ +from pathlib import Path +import sys + +from typing import Dict, Any + + +def _ensure_root_on_path() -> None: + """Ensure that the project root is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + if str(root) not in sys.path: + sys.path.insert(0, str(root)) + + +_ensure_root_on_path() + +from app import streamlit_app # noqa: E402 # isort: skip + + +def test_resolve_hf_config_prefers_secrets_over_env() -> None: + secrets: Dict[str, Any] = { + "HF_TOKEN": 
"secret-token", + "HF_ENDPOINT_URL": "https://endpoint-from-secrets", + "HF_ADAPTER_ID": "adapter-from-secrets", + "HF_MODEL_ID": "model-from-secrets", + "HF_PROVIDER": "hf-inference", + } + environ = { + "HF_TOKEN": "env-token", + "HF_ENDPOINT_URL": "https://endpoint-from-env", + "HF_ADAPTER_ID": "adapter-from-env", + "HF_MODEL_ID": "model-from-env", + "HF_PROVIDER": "env-provider", + } + + cfg = streamlit_app._resolve_hf_config(secrets=secrets, environ=environ) + + assert cfg.hf_token == "secret-token" + assert cfg.endpoint_url == "https://endpoint-from-secrets" + assert cfg.adapter_id == "adapter-from-secrets" + assert cfg.model_id == "model-from-secrets" + assert cfg.provider == "hf-inference" + + +def test_resolve_hf_config_falls_back_to_env_when_secrets_missing() -> None: + secrets: Dict[str, Any] = {} + environ = { + "HF_TOKEN": "env-token", + "HF_ENDPOINT_URL": "https://endpoint-from-env", + "HF_ADAPTER_ID": "adapter-from-env", + "HF_MODEL_ID": "model-from-env", + "HF_PROVIDER": "env-provider", + } + + cfg = streamlit_app._resolve_hf_config(secrets=secrets, environ=environ) + + assert cfg.hf_token == "env-token" + assert cfg.endpoint_url == "https://endpoint-from-env" + assert cfg.adapter_id == "adapter-from-env" + assert cfg.model_id == "model-from-env" + assert cfg.provider == "env-provider" + + +def test_resolve_hf_config_router_mode_without_endpoint() -> None: + secrets: Dict[str, Any] = { + "HF_TOKEN": "secret-token", + "HF_MODEL_ID": "model-from-secrets", + } + environ: Dict[str, str] = {} + + cfg = streamlit_app._resolve_hf_config(secrets=secrets, environ=environ) + + assert cfg.hf_token == "secret-token" + assert cfg.endpoint_url == "" + assert cfg.model_id == "model-from-secrets" + # Default provider should fall back to "auto" + assert cfg.provider == "auto" + # No adapter id when not configured + assert cfg.adapter_id is None \ No newline at end of file