10 changes: 9 additions & 1 deletion .streamlit/secrets.toml.example
@@ -25,4 +25,12 @@ HF_PROVIDER = "auto"

# Compatibility: older name for the endpoint base URL. The app will treat this
# as an alias for HF_ENDPOINT_URL if set.
HF_INFERENCE_BASE_URL = ""

# Optional: OpenAI fallback settings used when HF Inference requests fail.
OPENAI_API_KEY = ""
OPENAI_FALLBACK_MODEL = "gpt-5-nano"
# When set to a truthy value ("true", "1", "yes", "on"), the app will use
# structured JSON output for the fallback ({"sql": "..."}) and attempt to
# parse the "sql" field. If parsing fails, it falls back to plain-text handling.
OPENAI_FALLBACK_STRICT_JSON = "false"
272 changes: 264 additions & 8 deletions app/streamlit_app.py
@@ -42,12 +42,14 @@

from __future__ import annotations

import json
import logging
import os
from typing import Any, Mapping, NamedTuple, Optional, Tuple

import streamlit as st
from huggingface_hub import InferenceClient # type: ignore[import]
from openai import OpenAI # type: ignore[import]


logger = logging.getLogger(__name__)
@@ -130,6 +132,72 @@ def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str:
)


def _get_openai_settings() -> Tuple[str, str]:
"""
Resolve OpenAI fallback settings from Streamlit secrets and environment variables.

Secrets take precedence over environment variables. The model name falls back
to "gpt-5-nano" when not configured explicitly.
"""
try:
secrets: Mapping[str, Any] = st.secrets # type: ignore[assignment]
except Exception: # noqa: BLE001
secrets = {}

def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str:
try:
value = mapping.get(key) # type: ignore[attr-defined]
except Exception: # noqa: BLE001
value = None
if value is None:
return ""
return str(value).strip()

api_key = _get_from_mapping(secrets, "OPENAI_API_KEY") or os.environ.get(
"OPENAI_API_KEY",
"",
).strip()

model_name = _get_from_mapping(secrets, "OPENAI_FALLBACK_MODEL") or os.environ.get(
"OPENAI_FALLBACK_MODEL",
"",
).strip()

if not model_name:
model_name = "gpt-5-nano"

return api_key, model_name
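
As a usage sketch, a pytest-style check (hypothetical, not part of this diff) can pin the documented precedence and default, assuming no OPENAI_* entries exist in st.secrets for the test process:

def test_openai_settings_env_fallback(monkeypatch):
    # With no secret set, the environment variable wins and the model
    # name falls back to the documented default.
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
    monkeypatch.delenv("OPENAI_FALLBACK_MODEL", raising=False)
    api_key, model_name = _get_openai_settings()
    assert api_key == "sk-test"
    assert model_name == "gpt-5-nano"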


def _is_openai_strict_json_enabled() -> bool:
"""
Return True if structured JSON mode is enabled for the OpenAI fallback.

Controlled by OPENAI_FALLBACK_STRICT_JSON in Streamlit secrets or environment
variables. Defaults to False.
"""
try:
secrets: Mapping[str, Any] = st.secrets # type: ignore[assignment]
except Exception: # noqa: BLE001
secrets = {}

def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str:
try:
value = mapping.get(key) # type: ignore[attr-defined]
except Exception: # noqa: BLE001
value = None
if value is None:
return ""
return str(value).strip()

raw_value = _get_from_mapping(secrets, "OPENAI_FALLBACK_STRICT_JSON") or os.environ.get(
"OPENAI_FALLBACK_STRICT_JSON",
"",
).strip()

return raw_value.lower() in {"1", "true", "yes", "on"}
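
A quick illustrative check of the accepted spellings (case-insensitive; anything else disables the mode):

TRUTHY = {"1", "true", "yes", "on"}
assert all(raw.lower() in TRUTHY for raw in ("1", "true", "TRUE", "yes", "on"))
assert all(raw.lower() not in TRUTHY for raw in ("false", "0", ""))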


@st.cache_resource(show_spinner=False)
def _get_cached_client(
hf_token: str,
@@ -162,6 +230,157 @@ def _get_cached_client(
)


def _openai_response_text(resp: Any) -> str:
"""
Extract text content from an OpenAI Responses API response object.

1) Prefer the convenience `output_text` attribute when present and non-empty.
2) Otherwise, iterate over `resp.output` and aggregate any text content blocks.

    Depending on the SDK version and response type, the text may appear as a
    plain string (`content.text`) or as a nested object (`content.text.value`);
    this helper is defensive and supports both, including dict-like structures.
"""
# 1) Convenience attribute, when present.
try:
value = getattr(resp, "output_text", None)
except Exception: # noqa: BLE001
value = None

if isinstance(value, str) and value.strip():
return value.strip()

# 2) Walk the output structure.
try:
output = getattr(resp, "output", None)
except Exception: # noqa: BLE001
output = None

if not output:
return ""

chunks: list[str] = []

try:
for item in output:
# Handle both object-style and dict-style access to content.
content_list: Any
if isinstance(item, dict):
content_list = item.get("content")
else:
try:
content_list = getattr(item, "content", None)
except Exception: # noqa: BLE001
content_list = None

if not content_list:
continue

for content in content_list:
# Retrieve the "text" field from the content object or dict.
text_obj: Any = None
if isinstance(content, dict):
text_obj = content.get("text")
else:
try:
text_obj = getattr(content, "text", None)
except Exception: # noqa: BLE001
text_obj = None

if text_obj is None:
continue

# Unwrap the actual string value, handling nested objects/dicts.
text_val: Optional[str] = None
if isinstance(text_obj, str):
text_val = text_obj
else:
# Object-style: content.text.value
try:
inner_val = getattr(text_obj, "value", None)
except Exception: # noqa: BLE001
inner_val = None
if isinstance(inner_val, str):
text_val = inner_val
elif isinstance(text_obj, dict):
# Dict-style: {"text": {"value": "..."}} or similar.
inner_val = text_obj.get("value")
if isinstance(inner_val, str):
text_val = inner_val

if isinstance(text_val, str) and text_val:
chunks.append(text_val)
except Exception: # noqa: BLE001
logger.exception("Failed to extract text from OpenAI response.output.")
return ""

return "".join(chunks).strip()


def _call_openai_fallback(
system_prompt: str,
user_prompt: str,
) -> str:
"""
Call the OpenAI Responses API as a fallback when HF inference fails.

    Uses a low-cost model (default: gpt-5-nano) and avoids passing sampling
    parameters such as temperature or top_p, which the model may not support.
"""
api_key, model_name = _get_openai_settings()
if not api_key:
raise RuntimeError("OpenAI fallback not configured")

client = OpenAI(api_key=api_key)
full_input = f"{system_prompt}\n\n{user_prompt}"

extra_kwargs: dict[str, Any] = {}
strict_json = _is_openai_strict_json_enabled()
if strict_json:
extra_kwargs["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "sql_fallback",
"schema": {
"type": "object",
"properties": {
"sql": {"type": "string"},
},
"required": ["sql"],
"additionalProperties": False,
},
"strict": True,
},
}

response = client.responses.create(
model=model_name,
input=full_input,
**extra_kwargs,
)

raw_text = _openai_response_text(response)

if strict_json and raw_text:
try:
parsed = json.loads(raw_text)
if isinstance(parsed, dict):
sql_val = parsed.get("sql")
if isinstance(sql_val, str) and sql_val.strip():
return sql_val.strip()
except Exception: # noqa: BLE001
logger.exception(
"Failed to parse structured JSON output from OpenAI fallback; "
"falling back to raw text.",
)

if not raw_text:
logger.warning("OpenAI fallback returned empty text output.")
return ""

return raw_text.strip()
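
An illustrative call (the prompts here are made up; requires OPENAI_API_KEY to be configured):

sql = _call_openai_fallback(
    system_prompt="Translate the question into a single SQLite query.",
    user_prompt="Schema: users(id, name)\nQuestion: how many users are there?",
)
# With OPENAI_FALLBACK_STRICT_JSON enabled, the model is asked for
# {"sql": "..."} and only the parsed "sql" field is returned; otherwise
# the stripped raw text comes back as-is.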


def _build_prompt(schema: str, question: str) -> Tuple[str, str]:
"""
Build the system and user prompt content for text-to-SQL generation.
@@ -249,6 +468,17 @@ def _create_inference_client(timeout_s: int = 45) -> Tuple[InferenceClient, HFConfig]:
return client, hf_config


def _extract_sql_from_text(raw_text: str) -> str:
"""
Extract SQL from model output text.

Currently implemented as a simple pass-through that returns the
stripped text. Defined as a separate helper so it can be extended
or replaced in tests without changing call sites.
"""
return raw_text.strip()
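
Because the helper exists to be extended, one possible variant (a sketch, not part of this diff) unwraps the markdown code fences that chat-tuned models often emit:

import re

def _extract_sql_from_text_fenced(raw_text: str) -> str:
    # Hypothetical extension: prefer the body of a ```sql ... ``` block
    # when one is present; otherwise fall back to the stripped text.
    match = re.search(r"```(?:sql)?\s*(.+?)```", raw_text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return raw_text.strip()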


def _call_model(
client: InferenceClient,
schema: str,
@@ -283,14 +513,40 @@ def _call_model(
try:
response = client.text_generation(**generation_kwargs)
except Exception as exc: # noqa: BLE001
logger.error("Error while calling Hugging Face Inference API.", exc_info=True)
st.error(
"The Hugging Face Inference endpoint did not respond successfully. "
"This can happen if the endpoint is cold, overloaded, or misconfigured. "
"Please try again, or check your HF endpoint / model settings."
)
st.caption(f"Details: {exc}")
return None, user_prompt
        lower_message = str(exc).lower()
if "endpoint is paused" in lower_message:
logger.warning(
"Hugging Face Inference endpoint is paused; using OpenAI fallback.",
)
else:
logger.exception(
"Error while calling Hugging Face Inference API. "
"Attempting OpenAI fallback if configured.",
)
try:
raw_text = _call_openai_fallback(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
except Exception: # noqa: BLE001
logger.exception("OpenAI fallback inference failed.")
st.error(
"The service is temporarily unavailable. Please try again in a moment."
)
return None, user_prompt

        if not raw_text.strip():
            sql_text = "-- Fallback provider returned empty output. Try again."
        else:
            # _extract_sql_from_text currently just strips the text; fall
            # back to the stripped raw text if it ever returns empty.
            sql_text = _extract_sql_from_text(raw_text) or raw_text.strip()

st.caption("Using backup inference provider")
return sql_text, user_prompt

# InferenceClient.text_generation may return a string, a dict, or a list.
try:
5 changes: 4 additions & 1 deletion requirements.txt
@@ -21,4 +21,7 @@ pytest>=8.0.0
huggingface-hub>=0.36.0

# --- Evaluation helpers ---
sqlglot>=28.5.0

# --- Inference fallbacks ---
openai>=2.15.0