-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfast_api_demo.py
More file actions
84 lines (65 loc) · 2.65 KB
/
fast_api_demo.py
File metadata and controls
84 lines (65 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pathlib
import re

import requests
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.agents import create_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Download the sample database on first run; later runs reuse the local copy.
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path("Chinook.db")
if not local_path.exists():
    # Without a timeout a stalled connection would hang the import forever.
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        # Write to a sibling file and rename it into place, so an interrupted
        # write never leaves a truncated Chinook.db that the exists() check
        # above would accept on the next start.
        _partial_path = local_path.with_name(local_path.name + ".part")
        _partial_path.write_bytes(response.content)
        _partial_path.replace(local_path)
        print(f"Database downloaded: {local_path}")
    else:
        raise RuntimeError(
            f"Failed to download database (status {response.status_code})"
        )
# Chat model shared by the SQL toolkit and the agent.
model = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.7,
)

# Database handle for the local SQLite file.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# The toolkit supplies the tool list that is handed to the agent.
toolkit = SQLDatabaseToolkit(llm=model, db=db)
tools = toolkit.get_tools()
# System prompt instructing the agent not to expose database internals.
# Kept as a plain (non-f) string: it has no placeholders, and a plain literal
# stays safe if braces are ever added to the instructions.
system_prompt = """
You are a merchant agent designed to answer questions about the data in a SQL database.
Do NOT reveal table names, column names, database schema, or any internal database details.
You can only provide answers to user questions using the data content.
Always limit query results to at most 5 rows.
Do NOT run any DML statements (INSERT, UPDATE, DELETE, DROP, etc.).
If a question asks for internal structure, respond politely that you cannot reveal it.
"""
# Build the agent from the chat model, the SQL tools and the system prompt.
agent = create_agent(
    model,
    tools,
    system_prompt=system_prompt,
)
# ASGI application object that the routes below attach to.
app = FastAPI(title="SQL Agent API")

# CORS configuration: every origin, HTTP method and request header is allowed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Request body schema for the /ask endpoint.
class QueryRequest(BaseModel):
    # The user's question as free text; ask_question strips surrounding
    # whitespace before using it.
    question: str
# Whole-word patterns for questions that probe the database structure.
# Word boundaries keep ordinary words that merely contain "table"
# (e.g. "profitable", "vegetable") from being rejected.
_FORBIDDEN_PATTERN = re.compile(
    r"\b(?:schemas?|tables?|columns?|database\s+structure)\b",
    re.IGNORECASE,
)


# API endpoint
@app.post("/ask")
async def ask_question(request: QueryRequest):
    """Answer a natural-language question using the SQL agent.

    Args:
        request: Request body carrying the user's ``question``.

    Returns:
        A dict with a single ``"answer"`` key. The value is the text of the
        last message the agent produced, or a fixed notice when the question
        is blank or asks about the database structure.
    """
    question = request.question.strip()

    # Nothing to ask: reply directly instead of sending an empty message
    # to the model.
    if not question:
        return {"answer": "Please provide a question."}

    # Simple sanitization: block questions explicitly asking for schema,
    # table or column details.
    if _FORBIDDEN_PATTERN.search(question):
        return {"answer": "I'm sorry, I cannot provide internal database structure details."}

    final_answer = ""
    # Use the async stream so the event loop stays free to serve other
    # requests while the agent runs. With stream_mode="values" each step
    # carries the message list so far, so the last step's last message is
    # the final reply.
    async for step in agent.astream(
        {"messages": [{"role": "user", "content": question}]},
        stream_mode="values",
    ):
        final_answer = step["messages"][-1].text
    return {"answer": final_answer}