diff --git a/.streamlit/secrets.toml.example b/.streamlit/secrets.toml.example index 66b4645..87fdeac 100644 --- a/.streamlit/secrets.toml.example +++ b/.streamlit/secrets.toml.example @@ -25,4 +25,12 @@ HF_PROVIDER = "auto" # Compatibility: older name for the endpoint base URL. The app will treat this # as an alias for HF_ENDPOINT_URL if set. -HF_INFERENCE_BASE_URL = "" \ No newline at end of file +HF_INFERENCE_BASE_URL = "" + +# Optional: OpenAI fallback settings used when HF Inference requests fail. +OPENAI_API_KEY = "" +OPENAI_FALLBACK_MODEL = "gpt-5-nano" +# When set to a truthy value ("true", "1", "yes", "on"), the app will use structured +# JSON outputs for the fallback ({"sql": "..."}) and attempt to parse the SQL +# field. If parsing fails, it falls back to plain text handling. +OPENAI_FALLBACK_STRICT_JSON = "false" \ No newline at end of file diff --git a/app/streamlit_app.py b/app/streamlit_app.py index d99255e..8d07200 100644 --- a/app/streamlit_app.py +++ b/app/streamlit_app.py @@ -42,12 +42,14 @@ from __future__ import annotations +import json import logging import os from typing import Any, Mapping, NamedTuple, Optional, Tuple import streamlit as st from huggingface_hub import InferenceClient # type: ignore[import] +from openai import OpenAI # type: ignore[import] logger = logging.getLogger(__name__) @@ -130,6 +132,72 @@ def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: ) + + +def _get_openai_settings() -> Tuple[str, str]: + """ + Resolve OpenAI fallback settings from Streamlit secrets and environment variables. + + Secrets take precedence over environment variables. The model name falls back + to "gpt-5-nano" when not configured explicitly. 
+ """ + try: + secrets: Mapping[str, Any] = st.secrets # type: ignore[assignment] + except Exception: # noqa: BLE001 + secrets = {} + + def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: + try: + value = mapping.get(key) # type: ignore[attr-defined] + except Exception: # noqa: BLE001 + value = None + if value is None: + return "" + return str(value).strip() + + api_key = _get_from_mapping(secrets, "OPENAI_API_KEY") or os.environ.get( + "OPENAI_API_KEY", + "", + ).strip() + + model_name = _get_from_mapping(secrets, "OPENAI_FALLBACK_MODEL") or os.environ.get( + "OPENAI_FALLBACK_MODEL", + "", + ).strip() + + if not model_name: + model_name = "gpt-5-nano" + + return api_key, model_name + + +def _is_openai_strict_json_enabled() -> bool: + """ + Return True if structured JSON mode is enabled for the OpenAI fallback. + + Controlled by OPENAI_FALLBACK_STRICT_JSON in Streamlit secrets or environment + variables. Defaults to False. + """ + try: + secrets: Mapping[str, Any] = st.secrets # type: ignore[assignment] + except Exception: # noqa: BLE001 + secrets = {} + + def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: + try: + value = mapping.get(key) # type: ignore[attr-defined] + except Exception: # noqa: BLE001 + value = None + if value is None: + return "" + return str(value).strip() + + raw_value = _get_from_mapping(secrets, "OPENAI_FALLBACK_STRICT_JSON") or os.environ.get( + "OPENAI_FALLBACK_STRICT_JSON", + "", + ).strip() + + return raw_value.lower() in {"1", "true", "yes", "on"} + + @st.cache_resource(show_spinner=False) def _get_cached_client( hf_token: str, @@ -162,6 +230,157 @@ def _get_cached_client( ) +def _openai_response_text(resp: Any) -> str: + """ + Extract text content from an OpenAI Responses API response object. + + 1) Prefer the convenience `output_text` attribute when present and non-empty. + 2) Otherwise, iterate over `resp.output` and aggregate any text content blocks. 
+ + The Responses API typically exposes text as: + resp.output[0].content[0].text.value + but this helper is defensive and also supports dict-like structures. + """ + # 1) Convenience attribute, when present. + try: + value = getattr(resp, "output_text", None) + except Exception: # noqa: BLE001 + value = None + + if isinstance(value, str) and value.strip(): + return value.strip() + + # 2) Walk the output structure. + try: + output = getattr(resp, "output", None) + except Exception: # noqa: BLE001 + output = None + + if not output: + return "" + + chunks: list[str] = [] + + try: + for item in output: + # Handle both object-style and dict-style access to content. + content_list: Any + if isinstance(item, dict): + content_list = item.get("content") + else: + try: + content_list = getattr(item, "content", None) + except Exception: # noqa: BLE001 + content_list = None + + if not content_list: + continue + + for content in content_list: + # Retrieve the "text" field from the content object or dict. + text_obj: Any = None + if isinstance(content, dict): + text_obj = content.get("text") + else: + try: + text_obj = getattr(content, "text", None) + except Exception: # noqa: BLE001 + text_obj = None + + if text_obj is None: + continue + + # Unwrap the actual string value, handling nested objects/dicts. + text_val: Optional[str] = None + if isinstance(text_obj, str): + text_val = text_obj + else: + # Object-style: content.text.value + try: + inner_val = getattr(text_obj, "value", None) + except Exception: # noqa: BLE001 + inner_val = None + if isinstance(inner_val, str): + text_val = inner_val + elif isinstance(text_obj, dict): + # Dict-style: {"text": {"value": "..."}} or similar. 
+ inner_val = text_obj.get("value") + if isinstance(inner_val, str): + text_val = inner_val + + if isinstance(text_val, str) and text_val: + chunks.append(text_val) + except Exception: # noqa: BLE001 + logger.exception("Failed to extract text from OpenAI response.output.") + return "" + + return "".join(chunks).strip() + + +def _call_openai_fallback( + system_prompt: str, + user_prompt: str, +) -> str: + """ + Call the OpenAI Responses API as a fallback when HF inference fails. + + Uses the cheapest nano model (default: gpt-5-nano) and avoids passing + unsupported parameters such as temperature or top_p. + """ + api_key, model_name = _get_openai_settings() + if not api_key: + raise RuntimeError("OpenAI fallback not configured") + + client = OpenAI(api_key=api_key) + full_input = f"{system_prompt}\n\n{user_prompt}" + + extra_kwargs: dict[str, Any] = {} + strict_json = _is_openai_strict_json_enabled() + if strict_json: + extra_kwargs["text"] = { + "format": { + "type": "json_schema", + "name": "sql_fallback", + "schema": { + "type": "object", + "properties": { + "sql": {"type": "string"}, + }, + "required": ["sql"], + "additionalProperties": False, + }, + "strict": True, + }, + } + + response = client.responses.create( + model=model_name, + input=full_input, + **extra_kwargs, + ) + + raw_text = _openai_response_text(response) + + if strict_json and raw_text: + try: + parsed = json.loads(raw_text) + if isinstance(parsed, dict): + sql_val = parsed.get("sql") + if isinstance(sql_val, str) and sql_val.strip(): + return sql_val.strip() + except Exception: # noqa: BLE001 + logger.exception( + "Failed to parse structured JSON output from OpenAI fallback; " + "falling back to raw text.", + ) + + if not raw_text: + logger.warning("OpenAI fallback returned empty text output.") + return "" + + return raw_text.strip() + + +def _build_prompt(schema: str, question: str) -> Tuple[str, str]: + """ + Build the system and user prompt content for text-to-SQL generation. 
@@ -249,6 +468,17 @@ def _create_inference_client(timeout_s: int = 45) -> Tuple[InferenceClient, HFCo return client, hf_config +def _extract_sql_from_text(raw_text: str) -> str: + """ + Extract SQL from model output text. + + Currently implemented as a simple pass-through that returns the + stripped text. Defined as a separate helper so it can be extended + or replaced in tests without changing call sites. + """ + return raw_text.strip() + + def _call_model( client: InferenceClient, schema: str, @@ -283,14 +513,40 @@ def _call_model( try: response = client.text_generation(**generation_kwargs) except Exception as exc: # noqa: BLE001 - logger.error("Error while calling Hugging Face Inference API.", exc_info=True) - st.error( - "The Hugging Face Inference endpoint did not respond successfully. " - "This can happen if the endpoint is cold, overloaded, or misconfigured. " - "Please try again, or check your HF endpoint / model settings." - ) - st.caption(f"Details: {exc}") - return None, user_prompt + message = str(exc) if exc is not None else "" + lower_message = message.lower() + if "endpoint is paused" in lower_message: + logger.warning( + "Hugging Face Inference endpoint is paused; using OpenAI fallback.", + ) + else: + logger.exception( + "Error while calling Hugging Face Inference API. " + "Attempting OpenAI fallback if configured.", + ) + try: + raw_text = _call_openai_fallback( + system_prompt=system_prompt, + user_prompt=user_prompt, + ) + except Exception: # noqa: BLE001 + logger.exception("OpenAI fallback inference failed.") + st.error( + "The service is temporarily unavailable. Please try again in a moment." + ) + return None, user_prompt + + if not raw_text.strip(): + sql_text = "-- Fallback provider returned empty output. Try again." 
+ else: + extracted = _extract_sql_from_text(raw_text) + if extracted and extracted.strip(): + sql_text = extracted.strip() + else: + sql_text = raw_text.strip() + + st.caption("Using backup inference provider") + return sql_text, user_prompt # InferenceClient.text_generation may return a string, a dict, or a list. try: diff --git a/requirements.txt b/requirements.txt index ca8cfa0..6b066bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,4 +21,7 @@ pytest>=8.0.0 huggingface-hub>=0.36.0 # --- Evaluation helpers --- -sqlglot>=28.5.0 \ No newline at end of file +sqlglot>=28.5.0 + +# --- Inference fallbacks --- +openai>=2.15.0 \ No newline at end of file diff --git a/scripts/smoke_openai_fallback_local.py b/scripts/smoke_openai_fallback_local.py new file mode 100644 index 0000000..0bdb214 --- /dev/null +++ b/scripts/smoke_openai_fallback_local.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +""" +Local smoke test for the OpenAI fallback path used by the Streamlit app. + +This script: +- Loads OpenAI fallback configuration from Streamlit-style settings. +- Prints which config keys are present (masking the API key). +- Runs a single real OpenAI Responses API call with a fixed schema + question. +- Prints the raw output returned by the fallback helper. +""" + +from __future__ import annotations + +from pathlib import Path +import sys +from typing import Tuple + +"""from openai import OpenAI # type: ignore[import]""" + + +def _ensure_root_on_path() -> None: + """Ensure that the project root is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + if str(root) not in sys.path: + sys.path.insert(0, str(root)) + + +_ensure_root_on_path() + +from app import streamlit_app # noqa: E402 # isort: skip # pylint: disable=wrong-import-position + + +def _mask_api_key(value: str) -> str: + """Return a masked representation of an API key for logging/debugging.""" + if not value: + return "" + if len(value) <= 8: + return value[0] + "..." + value[-1]
+ return value[:4] + "..." + value[-4:] + + +def _load_openai_config() -> Tuple[str, str]: + """Load OpenAI fallback API key and model name using the app helper.""" + api_key, model_name = streamlit_app._get_openai_settings() # type: ignore[attr-defined] + return api_key, model_name + + +def main() -> None: + api_key, model_name = _load_openai_config() + + print("OpenAI fallback configuration:") + print(f" OPENAI_API_KEY: {_mask_api_key(api_key)}") + print(f" OPENAI_FALLBACK_MODEL: {model_name!r}") + + if not api_key: + print("ERROR: OPENAI_API_KEY is not configured. Aborting smoke test.") + return + + """client = OpenAI(api_key=api_key)""" + + schema = """ +CREATE TABLE customers ( + customer_id INTEGER PRIMARY KEY, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(100), + city VARCHAR(50) +); + +CREATE TABLE orders ( + order_id INTEGER PRIMARY KEY, + customer_id INTEGER, + order_date DATE, + total_amount REAL, + FOREIGN KEY (customer_id) REFERENCES customers(customer_id) +); +""".strip() + + question = """ +What are the full names of the customers who have placed an order with a total amount +greater than 100, and how many orders did each of them place? +""".strip() + + system_prompt, user_prompt = streamlit_app._build_prompt(schema=schema, question=question) # type: ignore[attr-defined] + + # Use the same fallback helper as the Streamlit app so behavior is consistent. 
+ raw_sql = streamlit_app._call_openai_fallback( # type: ignore[attr-defined] + system_prompt=system_prompt, + user_prompt=user_prompt, + ) + + print("\nOpenAI fallback raw output:") + print(raw_sql) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_infer_quantization.py b/tests/test_infer_quantization.py index c1f6900..a03e979 100644 --- a/tests/test_infer_quantization.py +++ b/tests/test_infer_quantization.py @@ -2,6 +2,8 @@ import sys from unittest import mock +import pytest + def _ensure_src_on_path() -> None: """Ensure that the 'src' directory is available on sys.path for imports.""" @@ -19,8 +21,14 @@ def test_load_model_for_inference_4bit_uses_quantization_config() -> None: Ensure that load_model_for_inference can be called in 4-bit mode without actually downloading a model, and that it wires BitsAndBytesConfig through to AutoModelForCausalLM.from_pretrained. + + If the local environment cannot import the necessary transformer stack + (e.g. due to version constraints), this test is skipped instead of failing. 
""" - import text2sql.infer as infer # isort: skip + try: + import text2sql.infer as infer # isort: skip + except ImportError as exc: + pytest.skip(f"Skipping 4-bit quantization test due to import error: {exc}") with mock.patch.object(infer, "AutoTokenizer") as mock_tok_cls, \ mock.patch.object(infer, "AutoModelForCausalLM") as mock_model_cls, \ diff --git a/tests/test_streamlit_openai_fallback.py b/tests/test_streamlit_openai_fallback.py new file mode 100644 index 0000000..ecf6fba --- /dev/null +++ b/tests/test_streamlit_openai_fallback.py @@ -0,0 +1,160 @@ +from pathlib import Path +import sys +from typing import Any + +import pytest + + +def _ensure_root_on_path() -> None: + """Ensure that the project root is available on sys.path for imports.""" + root = Path(__file__).resolve().parents[1] + if str(root) not in sys.path: + sys.path.insert(0, str(root)) + + +_ensure_root_on_path() + +from app import streamlit_app # noqa: E402 # isort: skip + + +class _DummyHFClient: + def text_generation(self, **_: Any) -> str: + raise RuntimeError("HF inference failed") + + +class _DummyOpenAIResponse: + def __init__(self, text: str) -> None: + self.output_text = text + + +class _DummyOpenAIResponses: + def __init__(self, text: str) -> None: + self._text = text + + def create(self, *args: Any, **kwargs: Any) -> _DummyOpenAIResponse: # noqa: ARG002 + """Mimic OpenAI Responses.create without enforcing a specific signature.""" + return _DummyOpenAIResponse(self._text) + + +class _DummyOpenAIClient: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.responses = _DummyOpenAIResponses("SELECT 1;") + + +class _DummyStreamlit: + def __init__(self) -> None: + self.caption_called = False + self.error_called = False + + def caption(self, *_: Any, **__: Any) -> None: + self.caption_called = True + + def error(self, *_: Any, **__: Any) -> None: + self.error_called = True + + +def test_openai_response_text_uses_output_when_output_text_empty() -> None: + class _Resp: 
+ def __init__(self) -> None: + self.output_text = "" + self.output = [ + { + "content": [ + {"text": "SELECT 42;"}, + ], + } + ] + + resp = _Resp() + text = streamlit_app._openai_response_text(resp) + assert text == "SELECT 42;" + + +def test_hf_error_triggers_openai_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_client = _DummyHFClient() + + # Ensure OpenAI settings return an API key and model name without touching real secrets/env. + monkeypatch.setattr( + streamlit_app, + "_get_openai_settings", + lambda: ("test-api-key", "gpt-5-nano"), + ) + + # Replace OpenAI client in the app module with our dummy implementation. + monkeypatch.setattr(streamlit_app, "OpenAI", _DummyOpenAIClient) + + # Replace Streamlit module used inside the app with a minimal stub to avoid UI dependencies. + dummy_st = _DummyStreamlit() + monkeypatch.setattr(streamlit_app, "st", dummy_st) + + sql_text, user_prompt = streamlit_app._call_model( + client=dummy_client, + schema="CREATE TABLE test (id INT);", + question="How many rows are in test?", + temperature=0.1, + max_tokens=128, + timeout_s=45, + adapter_id=None, + use_endpoint=True, + ) + + assert sql_text == "SELECT 1;" + assert "How many rows" in user_prompt + assert dummy_st.caption_called is True + assert dummy_st.error_called is False + + +def test_openai_fallback_uses_raw_text_when_extractor_returns_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dummy_client = _DummyHFClient() + + # Ensure OpenAI settings return an API key and model name without touching real secrets/env. 
+ monkeypatch.setattr( + streamlit_app, + "_get_openai_settings", + lambda: ("test-api-key", "gpt-5-nano"), + ) + + class _DummyOpenAIResponseRaw: + def __init__(self, text: str) -> None: + self.output_text = text + + class _DummyOpenAIResponsesRaw: + def __init__(self, text: str) -> None: + self._text = text + + def create(self, *args: Any, **kwargs: Any) -> _DummyOpenAIResponseRaw: # noqa: ARG002 + """Mimic OpenAI Responses.create without enforcing a specific signature.""" + return _DummyOpenAIResponseRaw(self._text) + + class _DummyOpenAIClientRaw: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.responses = _DummyOpenAIResponsesRaw("RAW TEXT OUTPUT") + + # Replace OpenAI client in the app module with our dummy implementation. + monkeypatch.setattr(streamlit_app, "OpenAI", _DummyOpenAIClientRaw) + + # Force the SQL extractor to return an empty string to ensure we fall back to the raw text. + monkeypatch.setattr(streamlit_app, "_extract_sql_from_text", lambda _raw: "") + + # Replace Streamlit module used inside the app with a minimal stub to avoid UI dependencies. + dummy_st = _DummyStreamlit() + monkeypatch.setattr(streamlit_app, "st", dummy_st) + + sql_text, _ = streamlit_app._call_model( + client=dummy_client, + schema="CREATE TABLE test (id INT);", + question="How many rows are in test?", + temperature=0.1, + max_tokens=128, + timeout_s=45, + adapter_id=None, + use_endpoint=True, + ) + + assert sql_text == "RAW TEXT OUTPUT" + assert dummy_st.caption_called is True + assert dummy_st.error_called is False \ No newline at end of file