|
42 | 42 |
|
43 | 43 | from __future__ import annotations |
44 | 44 |
|
| 45 | +import json |
45 | 46 | import logging |
46 | 47 | import os |
47 | 48 | from typing import Any, Mapping, NamedTuple, Optional, Tuple |
48 | 49 |
|
49 | 50 | import streamlit as st |
50 | 51 | from huggingface_hub import InferenceClient # type: ignore[import] |
| 52 | +from openai import OpenAI # type: ignore[import] |
51 | 53 |
|
52 | 54 |
|
53 | 55 | logger = logging.getLogger(__name__) |
@@ -130,6 +132,72 @@ def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: |
130 | 132 | ) |
131 | 133 |
|
132 | 134 |
|
| 135 | +def _get_openai_settings() -> Tuple[str, str]: |
| 136 | + """ |
| 137 | + Resolve OpenAI fallback settings from Streamlit secrets and environment variables. |
| 138 | +
|
| 139 | + Secrets take precedence over environment variables. The model name falls back |
| 140 | + to "gpt-5-nano" when not configured explicitly. |
| 141 | + """ |
| 142 | + try: |
| 143 | + secrets: Mapping[str, Any] = st.secrets # type: ignore[assignment] |
| 144 | + except Exception: # noqa: BLE001 |
| 145 | + secrets = {} |
| 146 | + |
| 147 | + def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: |
| 148 | + try: |
| 149 | + value = mapping.get(key) # type: ignore[attr-defined] |
| 150 | + except Exception: # noqa: BLE001 |
| 151 | + value = None |
| 152 | + if value is None: |
| 153 | + return "" |
| 154 | + return str(value).strip() |
| 155 | + |
| 156 | + api_key = _get_from_mapping(secrets, "OPENAI_API_KEY") or os.environ.get( |
| 157 | + "OPENAI_API_KEY", |
| 158 | + "", |
| 159 | + ).strip() |
| 160 | + |
| 161 | + model_name = _get_from_mapping(secrets, "OPENAI_FALLBACK_MODEL") or os.environ.get( |
| 162 | + "OPENAI_FALLBACK_MODEL", |
| 163 | + "", |
| 164 | + ).strip() |
| 165 | + |
| 166 | + if not model_name: |
| 167 | + model_name = "gpt-5-nano" |
| 168 | + |
| 169 | + return api_key, model_name |
| 170 | + |
| 171 | + |
| 172 | +def _is_openai_strict_json_enabled() -> bool: |
| 173 | + """ |
| 174 | + Return True if structured JSON mode is enabled for the OpenAI fallback. |
| 175 | +
|
| 176 | + Controlled by OPENAI_FALLBACK_STRICT_JSON in Streamlit secrets or environment |
| 177 | + variables. Defaults to False. |
| 178 | + """ |
| 179 | + try: |
| 180 | + secrets: Mapping[str, Any] = st.secrets # type: ignore[assignment] |
| 181 | + except Exception: # noqa: BLE001 |
| 182 | + secrets = {} |
| 183 | + |
| 184 | + def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str: |
| 185 | + try: |
| 186 | + value = mapping.get(key) # type: ignore[attr-defined] |
| 187 | + except Exception: # noqa: BLE001 |
| 188 | + value = None |
| 189 | + if value is None: |
| 190 | + return "" |
| 191 | + return str(value).strip() |
| 192 | + |
| 193 | + raw_value = _get_from_mapping(secrets, "OPENAI_FALLBACK_STRICT_JSON") or os.environ.get( |
| 194 | + "OPENAI_FALLBACK_STRICT_JSON", |
| 195 | + "", |
| 196 | + ).strip() |
| 197 | + |
| 198 | + return raw_value.lower() in {"1", "true", "yes", "on"} |
| 199 | + |
| 200 | + |
133 | 201 | @st.cache_resource(show_spinner=False) |
134 | 202 | def _get_cached_client( |
135 | 203 | hf_token: str, |
@@ -162,6 +230,157 @@ def _get_cached_client( |
162 | 230 | ) |
163 | 231 |
|
164 | 232 |
|
| 233 | +def _openai_response_text(resp: Any) -> str: |
| 234 | + """ |
| 235 | + Extract text content from an OpenAI Responses API response object. |
| 236 | +
|
| 237 | + 1) Prefer the convenience `output_text` attribute when present and non-empty. |
| 238 | + 2) Otherwise, iterate over `resp.output` and aggregate any text content blocks. |
| 239 | +
|
| 240 | + The Responses API typically exposes text as: |
| 241 | + resp.output[0].content[0].text.value |
| 242 | + but this helper is defensive and also supports dict-like structures. |
| 243 | + """ |
| 244 | + # 1) Convenience attribute, when present. |
| 245 | + try: |
| 246 | + value = getattr(resp, "output_text", None) |
| 247 | + except Exception: # noqa: BLE001 |
| 248 | + value = None |
| 249 | + |
| 250 | + if isinstance(value, str) and value.strip(): |
| 251 | + return value.strip() |
| 252 | + |
| 253 | + # 2) Walk the output structure. |
| 254 | + try: |
| 255 | + output = getattr(resp, "output", None) |
| 256 | + except Exception: # noqa: BLE001 |
| 257 | + output = None |
| 258 | + |
| 259 | + if not output: |
| 260 | + return "" |
| 261 | + |
| 262 | + chunks: list[str] = [] |
| 263 | + |
| 264 | + try: |
| 265 | + for item in output: |
| 266 | + # Handle both object-style and dict-style access to content. |
| 267 | + content_list: Any |
| 268 | + if isinstance(item, dict): |
| 269 | + content_list = item.get("content") |
| 270 | + else: |
| 271 | + try: |
| 272 | + content_list = getattr(item, "content", None) |
| 273 | + except Exception: # noqa: BLE001 |
| 274 | + content_list = None |
| 275 | + |
| 276 | + if not content_list: |
| 277 | + continue |
| 278 | + |
| 279 | + for content in content_list: |
| 280 | + # Retrieve the "text" field from the content object or dict. |
| 281 | + text_obj: Any = None |
| 282 | + if isinstance(content, dict): |
| 283 | + text_obj = content.get("text") |
| 284 | + else: |
| 285 | + try: |
| 286 | + text_obj = getattr(content, "text", None) |
| 287 | + except Exception: # noqa: BLE001 |
| 288 | + text_obj = None |
| 289 | + |
| 290 | + if text_obj is None: |
| 291 | + continue |
| 292 | + |
| 293 | + # Unwrap the actual string value, handling nested objects/dicts. |
| 294 | + text_val: Optional[str] = None |
| 295 | + if isinstance(text_obj, str): |
| 296 | + text_val = text_obj |
| 297 | + else: |
| 298 | + # Object-style: content.text.value |
| 299 | + try: |
| 300 | + inner_val = getattr(text_obj, "value", None) |
| 301 | + except Exception: # noqa: BLE001 |
| 302 | + inner_val = None |
| 303 | + if isinstance(inner_val, str): |
| 304 | + text_val = inner_val |
| 305 | + elif isinstance(text_obj, dict): |
| 306 | + # Dict-style: {"text": {"value": "..."}} or similar. |
| 307 | + inner_val = text_obj.get("value") |
| 308 | + if isinstance(inner_val, str): |
| 309 | + text_val = inner_val |
| 310 | + |
| 311 | + if isinstance(text_val, str) and text_val: |
| 312 | + chunks.append(text_val) |
| 313 | + except Exception: # noqa: BLE001 |
| 314 | + logger.exception("Failed to extract text from OpenAI response.output.") |
| 315 | + return "" |
| 316 | + |
| 317 | + return "".join(chunks).strip() |
| 318 | + |
| 319 | + |
def _call_openai_fallback(
    system_prompt: str,
    user_prompt: str,
) -> str:
    """
    Call the OpenAI Responses API as a fallback when HF inference fails.

    Uses the configured fallback model (default: gpt-5-nano) and avoids passing
    unsupported sampling parameters such as temperature or top_p.

    Returns:
        The generated text (or the extracted ``sql`` field in strict JSON mode);
        an empty string when the provider produced no text.

    Raises:
        RuntimeError: If no OpenAI API key is configured.
    """
    api_key, model_name = _get_openai_settings()
    if not api_key:
        raise RuntimeError("OpenAI fallback not configured")

    client = OpenAI(api_key=api_key)
    full_input = f"{system_prompt}\n\n{user_prompt}"

    extra_kwargs: dict[str, Any] = {}
    strict_json = _is_openai_strict_json_enabled()
    if strict_json:
        # BUGFIX: the Responses API configures structured output via the
        # `text.format` parameter with `name`/`schema`/`strict` at the format
        # level. The previous `response_format={"type": "json_schema", ...}`
        # shape belongs to the Chat Completions API and is rejected by
        # `client.responses.create` (TypeError: unexpected keyword argument).
        extra_kwargs["text"] = {
            "format": {
                "type": "json_schema",
                "name": "sql_fallback",
                "schema": {
                    "type": "object",
                    "properties": {
                        "sql": {"type": "string"},
                    },
                    "required": ["sql"],
                    "additionalProperties": False,
                },
                "strict": True,
            },
        }

    response = client.responses.create(
        model=model_name,
        input=full_input,
        **extra_kwargs,
    )

    raw_text = _openai_response_text(response)

    if strict_json and raw_text:
        # Strict mode yields a JSON object like {"sql": "..."}; unwrap it.
        try:
            parsed = json.loads(raw_text)
            if isinstance(parsed, dict):
                sql_val = parsed.get("sql")
                if isinstance(sql_val, str) and sql_val.strip():
                    return sql_val.strip()
        except Exception:  # noqa: BLE001
            logger.exception(
                "Failed to parse structured JSON output from OpenAI fallback; "
                "falling back to raw text.",
            )

    if not raw_text:
        logger.warning("OpenAI fallback returned empty text output.")
        return ""

    return raw_text.strip()
| 382 | + |
| 383 | + |
165 | 384 | def _build_prompt(schema: str, question: str) -> Tuple[str, str]: |
166 | 385 | """ |
167 | 386 | Build the system and user prompt content for text-to-SQL generation. |
@@ -249,6 +468,17 @@ def _create_inference_client(timeout_s: int = 45) -> Tuple[InferenceClient, HFCo |
249 | 468 | return client, hf_config |
250 | 469 |
|
251 | 470 |
|
| 471 | +def _extract_sql_from_text(raw_text: str) -> str: |
| 472 | + """ |
| 473 | + Extract SQL from model output text. |
| 474 | +
|
| 475 | + Currently implemented as a simple pass-through that returns the |
| 476 | + stripped text. Defined as a separate helper so it can be extended |
| 477 | + or replaced in tests without changing call sites. |
| 478 | + """ |
| 479 | + return raw_text.strip() |
| 480 | + |
| 481 | + |
252 | 482 | def _call_model( |
253 | 483 | client: InferenceClient, |
254 | 484 | schema: str, |
@@ -283,14 +513,40 @@ def _call_model( |
283 | 513 | try: |
284 | 514 | response = client.text_generation(**generation_kwargs) |
285 | 515 | except Exception as exc: # noqa: BLE001 |
286 | | - logger.error("Error while calling Hugging Face Inference API.", exc_info=True) |
287 | | - st.error( |
288 | | - "The Hugging Face Inference endpoint did not respond successfully. " |
289 | | - "This can happen if the endpoint is cold, overloaded, or misconfigured. " |
290 | | - "Please try again, or check your HF endpoint / model settings." |
291 | | - ) |
292 | | - st.caption(f"Details: {exc}") |
293 | | - return None, user_prompt |
| 516 | + message = str(exc) if exc is not None else "" |
| 517 | + lower_message = message.lower() |
| 518 | + if "endpoint is paused" in lower_message: |
| 519 | + logger.warning( |
| 520 | + "Hugging Face Inference endpoint is paused; using OpenAI fallback.", |
| 521 | + ) |
| 522 | + else: |
| 523 | + logger.exception( |
| 524 | + "Error while calling Hugging Face Inference API. " |
| 525 | + "Attempting OpenAI fallback if configured.", |
| 526 | + ) |
| 527 | + try: |
| 528 | + raw_text = _call_openai_fallback( |
| 529 | + system_prompt=system_prompt, |
| 530 | + user_prompt=user_prompt, |
| 531 | + ) |
| 532 | + except Exception: # noqa: BLE001 |
| 533 | + logger.exception("OpenAI fallback inference failed.") |
| 534 | + st.error( |
| 535 | + "The service is temporarily unavailable. Please try again in a moment." |
| 536 | + ) |
| 537 | + return None, user_prompt |
| 538 | + |
| 539 | + if not raw_text.strip(): |
| 540 | + sql_text = "-- Fallback provider returned empty output. Try again." |
| 541 | + else: |
| 542 | + extracted = _extract_sql_from_text(raw_text) |
| 543 | + if extracted and extracted.strip(): |
| 544 | + sql_text = extracted.strip() |
| 545 | + else: |
| 546 | + sql_text = raw_text.strip() |
| 547 | + |
| 548 | + st.caption("Using backup inference provider") |
| 549 | + return sql_text, user_prompt |
294 | 550 |
|
295 | 551 | # InferenceClient.text_generation may return a string, a dict, or a list. |
296 | 552 | try: |
|
0 commit comments