Skip to content

Commit 1ddc41e

Browse files
authored
Merge pull request #6 from brej-29/cosine/feat/openai-nano-fallback-hf-endpoint
Add OpenAI nano fallback for HF inference outages (silent UI)
2 parents 687ca67 + 37eba61 commit 1ddc41e

6 files changed

Lines changed: 543 additions & 11 deletions

File tree

.streamlit/secrets.toml.example

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,12 @@ HF_PROVIDER = "auto"
2525

2626
# Compatibility: older name for the endpoint base URL. The app will treat this
2727
# as an alias for HF_ENDPOINT_URL if set.
28-
HF_INFERENCE_BASE_URL = ""
28+
HF_INFERENCE_BASE_URL = ""
29+
30+
# Optional: OpenAI fallback settings used when HF Inference requests fail.
31+
OPENAI_API_KEY = ""
32+
OPENAI_FALLBACK_MODEL = "gpt-5-nano"
33+
# When set to a truthy value ("true", "1", "yes"), the app will use structured
34+
# JSON outputs for the fallback ({"sql": "..."}) and attempt to parse the SQL
35+
# field. If parsing fails, it falls back to plain text handling.
36+
OPENAI_FALLBACK_STRICT_JSON = "false"

app/streamlit_app.py

Lines changed: 264 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@
4242

4343
from __future__ import annotations
4444

45+
import json
4546
import logging
4647
import os
4748
from typing import Any, Mapping, NamedTuple, Optional, Tuple
4849

4950
import streamlit as st
5051
from huggingface_hub import InferenceClient # type: ignore[import]
52+
from openai import OpenAI # type: ignore[import]
5153

5254

5355
logger = logging.getLogger(__name__)
@@ -130,6 +132,72 @@ def _get_from_mapping(mapping: Mapping[str, Any], key: str) -> str:
130132
)
131133

132134

135+
def _get_openai_settings() -> Tuple[str, str]:
    """
    Resolve the OpenAI fallback API key and model name.

    Streamlit secrets take precedence over environment variables; the model
    name defaults to "gpt-5-nano" when neither source provides one.

    Returns:
        A ``(api_key, model_name)`` tuple; ``api_key`` is "" when unset.
    """
    try:
        secret_store: Mapping[str, Any] = st.secrets  # type: ignore[assignment]
    except Exception:  # noqa: BLE001
        # st.secrets can raise when no secrets file exists; treat as empty.
        secret_store = {}

    def _secret(name: str) -> str:
        # Defensive read: secret stores may raise on access in some deployments.
        try:
            raw = secret_store.get(name)  # type: ignore[attr-defined]
        except Exception:  # noqa: BLE001
            raw = None
        return "" if raw is None else str(raw).strip()

    api_key = _secret("OPENAI_API_KEY") or os.environ.get(
        "OPENAI_API_KEY",
        "",
    ).strip()

    model_name = (
        _secret("OPENAI_FALLBACK_MODEL")
        or os.environ.get("OPENAI_FALLBACK_MODEL", "").strip()
        or "gpt-5-nano"
    )

    return api_key, model_name
172+
def _is_openai_strict_json_enabled() -> bool:
    """
    Report whether the OpenAI fallback should request structured JSON output.

    Reads OPENAI_FALLBACK_STRICT_JSON from Streamlit secrets (preferred) or
    the environment. Unset or unrecognized values mean False.
    """
    try:
        secret_store: Mapping[str, Any] = st.secrets  # type: ignore[assignment]
    except Exception:  # noqa: BLE001
        # st.secrets can raise when no secrets file exists; treat as empty.
        secret_store = {}

    # Defensive read: secret stores may raise on access in some deployments.
    try:
        raw = secret_store.get("OPENAI_FALLBACK_STRICT_JSON")  # type: ignore[attr-defined]
    except Exception:  # noqa: BLE001
        raw = None

    flag = "" if raw is None else str(raw).strip()
    if not flag:
        flag = os.environ.get("OPENAI_FALLBACK_STRICT_JSON", "").strip()

    return flag.lower() in {"1", "true", "yes", "on"}
200+
133201
@st.cache_resource(show_spinner=False)
134202
def _get_cached_client(
135203
hf_token: str,
@@ -162,6 +230,157 @@ def _get_cached_client(
162230
)
163231

164232

233+
def _openai_response_text(resp: Any) -> str:
234+
"""
235+
Extract text content from an OpenAI Responses API response object.
236+
237+
1) Prefer the convenience `output_text` attribute when present and non-empty.
238+
2) Otherwise, iterate over `resp.output` and aggregate any text content blocks.
239+
240+
The Responses API typically exposes text as:
241+
resp.output[0].content[0].text.value
242+
but this helper is defensive and also supports dict-like structures.
243+
"""
244+
# 1) Convenience attribute, when present.
245+
try:
246+
value = getattr(resp, "output_text", None)
247+
except Exception: # noqa: BLE001
248+
value = None
249+
250+
if isinstance(value, str) and value.strip():
251+
return value.strip()
252+
253+
# 2) Walk the output structure.
254+
try:
255+
output = getattr(resp, "output", None)
256+
except Exception: # noqa: BLE001
257+
output = None
258+
259+
if not output:
260+
return ""
261+
262+
chunks: list[str] = []
263+
264+
try:
265+
for item in output:
266+
# Handle both object-style and dict-style access to content.
267+
content_list: Any
268+
if isinstance(item, dict):
269+
content_list = item.get("content")
270+
else:
271+
try:
272+
content_list = getattr(item, "content", None)
273+
except Exception: # noqa: BLE001
274+
content_list = None
275+
276+
if not content_list:
277+
continue
278+
279+
for content in content_list:
280+
# Retrieve the "text" field from the content object or dict.
281+
text_obj: Any = None
282+
if isinstance(content, dict):
283+
text_obj = content.get("text")
284+
else:
285+
try:
286+
text_obj = getattr(content, "text", None)
287+
except Exception: # noqa: BLE001
288+
text_obj = None
289+
290+
if text_obj is None:
291+
continue
292+
293+
# Unwrap the actual string value, handling nested objects/dicts.
294+
text_val: Optional[str] = None
295+
if isinstance(text_obj, str):
296+
text_val = text_obj
297+
else:
298+
# Object-style: content.text.value
299+
try:
300+
inner_val = getattr(text_obj, "value", None)
301+
except Exception: # noqa: BLE001
302+
inner_val = None
303+
if isinstance(inner_val, str):
304+
text_val = inner_val
305+
elif isinstance(text_obj, dict):
306+
# Dict-style: {"text": {"value": "..."}} or similar.
307+
inner_val = text_obj.get("value")
308+
if isinstance(inner_val, str):
309+
text_val = inner_val
310+
311+
if isinstance(text_val, str) and text_val:
312+
chunks.append(text_val)
313+
except Exception: # noqa: BLE001
314+
logger.exception("Failed to extract text from OpenAI response.output.")
315+
return ""
316+
317+
return "".join(chunks).strip()
318+
319+
320+
def _call_openai_fallback(
    system_prompt: str,
    user_prompt: str,
) -> str:
    """
    Call the OpenAI Responses API as a fallback when HF inference fails.

    Uses the configured fallback model (default: gpt-5-nano) and avoids
    passing sampling parameters such as temperature or top_p.

    Args:
        system_prompt: System-level instructions for the model.
        user_prompt: The user's question/prompt content.

    Returns:
        The generated text (or the extracted "sql" field in strict-JSON
        mode), stripped; "" when the provider returned no text.

    Raises:
        RuntimeError: if no OpenAI API key is configured.
    """
    api_key, model_name = _get_openai_settings()
    if not api_key:
        raise RuntimeError("OpenAI fallback not configured")

    client = OpenAI(api_key=api_key)
    full_input = f"{system_prompt}\n\n{user_prompt}"

    extra_kwargs: dict[str, Any] = {}
    strict_json = _is_openai_strict_json_enabled()
    if strict_json:
        # BUG FIX: the Responses API configures structured outputs via the
        # ``text.format`` parameter; ``response_format`` belongs to the
        # legacy Chat Completions API and is rejected by responses.create(),
        # which made strict-JSON mode fail on every call.
        extra_kwargs["text"] = {
            "format": {
                "type": "json_schema",
                "name": "sql_fallback",
                "schema": {
                    "type": "object",
                    "properties": {
                        "sql": {"type": "string"},
                    },
                    "required": ["sql"],
                    "additionalProperties": False,
                },
                "strict": True,
            },
        }

    response = client.responses.create(
        model=model_name,
        input=full_input,
        **extra_kwargs,
    )

    raw_text = _openai_response_text(response)

    if strict_json and raw_text:
        # Structured mode returns {"sql": "..."}; fall back to raw-text
        # handling if the payload cannot be parsed or lacks a usable field.
        try:
            parsed = json.loads(raw_text)
            if isinstance(parsed, dict):
                sql_val = parsed.get("sql")
                if isinstance(sql_val, str) and sql_val.strip():
                    return sql_val.strip()
        except Exception:  # noqa: BLE001
            logger.exception(
                "Failed to parse structured JSON output from OpenAI fallback; "
                "falling back to raw text.",
            )

    if not raw_text:
        logger.warning("OpenAI fallback returned empty text output.")
        return ""

    return raw_text.strip()
165384
def _build_prompt(schema: str, question: str) -> Tuple[str, str]:
166385
"""
167386
Build the system and user prompt content for text-to-SQL generation.
@@ -249,6 +468,17 @@ def _create_inference_client(timeout_s: int = 45) -> Tuple[InferenceClient, HFCo
249468
return client, hf_config
250469

251470

471+
def _extract_sql_from_text(raw_text: str) -> str:
472+
"""
473+
Extract SQL from model output text.
474+
475+
Currently implemented as a simple pass-through that returns the
476+
stripped text. Defined as a separate helper so it can be extended
477+
or replaced in tests without changing call sites.
478+
"""
479+
return raw_text.strip()
480+
481+
252482
def _call_model(
253483
client: InferenceClient,
254484
schema: str,
@@ -283,14 +513,40 @@ def _call_model(
283513
try:
284514
response = client.text_generation(**generation_kwargs)
285515
except Exception as exc: # noqa: BLE001
286-
logger.error("Error while calling Hugging Face Inference API.", exc_info=True)
287-
st.error(
288-
"The Hugging Face Inference endpoint did not respond successfully. "
289-
"This can happen if the endpoint is cold, overloaded, or misconfigured. "
290-
"Please try again, or check your HF endpoint / model settings."
291-
)
292-
st.caption(f"Details: {exc}")
293-
return None, user_prompt
516+
message = str(exc) if exc is not None else ""
517+
lower_message = message.lower()
518+
if "endpoint is paused" in lower_message:
519+
logger.warning(
520+
"Hugging Face Inference endpoint is paused; using OpenAI fallback.",
521+
)
522+
else:
523+
logger.exception(
524+
"Error while calling Hugging Face Inference API. "
525+
"Attempting OpenAI fallback if configured.",
526+
)
527+
try:
528+
raw_text = _call_openai_fallback(
529+
system_prompt=system_prompt,
530+
user_prompt=user_prompt,
531+
)
532+
except Exception: # noqa: BLE001
533+
logger.exception("OpenAI fallback inference failed.")
534+
st.error(
535+
"The service is temporarily unavailable. Please try again in a moment."
536+
)
537+
return None, user_prompt
538+
539+
if not raw_text.strip():
540+
sql_text = "-- Fallback provider returned empty output. Try again."
541+
else:
542+
extracted = _extract_sql_from_text(raw_text)
543+
if extracted and extracted.strip():
544+
sql_text = extracted.strip()
545+
else:
546+
sql_text = raw_text.strip()
547+
548+
st.caption("Using backup inference provider")
549+
return sql_text, user_prompt
294550

295551
# InferenceClient.text_generation may return a string, a dict, or a list.
296552
try:

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,7 @@ pytest>=8.0.0
2121
huggingface-hub>=0.36.0
2222

2323
# --- Evaluation helpers ---
24-
sqlglot>=28.5.0
24+
sqlglot>=28.5.0
25+
26+
# --- Inference fallbacks ---
27+
openai>=2.15.0

0 commit comments

Comments
 (0)