-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
176 lines (141 loc) · 5.71 KB
/
app.py
File metadata and controls
176 lines (141 loc) · 5.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import streamlit as st
import os
import pandas as pd
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from sqlalchemy import create_engine, text, inspect
# Load environment variables
load_dotenv()

# SQL Server settings: each value is read from Streamlit secrets, falling
# back to the literal default when the key is absent.
# NOTE(review): the fallbacks embed a real-looking host, user and password;
# consider requiring these secrets instead of shipping defaults.
SERVER_NAME = st.secrets.get("HOST", "SHRABANI")
DATABASE_NAME = st.secrets.get("DATABASE", "demo")
USERNAME = st.secrets.get("USER", "python_user")
PASSWORD = st.secrets.get("PWD", "test")
CSV_PATH = st.secrets.get("path", "")

# Ensure the session-state slots for the connection and the schema exist
for _slot in ("engine", "schema_info"):
    if _slot not in st.session_state:
        st.session_state[_slot] = None
@st.cache_resource
def _create_sql_engine():
    """Build the SQL Server engine and verify it with a trivial query.

    Raises on any failure instead of returning None, so a failed attempt is
    never stored by ``st.cache_resource`` and the next rerun can retry.
    """
    from urllib.parse import quote

    # Percent-encode the credentials: characters such as '@', ':' or '/'
    # would otherwise be parsed as URL delimiters.
    user = quote(str(USERNAME), safe="")
    password = quote(str(PASSWORD), safe="")
    conn_string = (
        f"mssql+pyodbc://{user}:{password}@{SERVER_NAME}/{DATABASE_NAME}"
        "?driver=ODBC+Driver+17+for+SQL+Server"
    )
    engine = create_engine(conn_string)
    try:
        # Test connection
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception:
        # Release the pool of an engine that will never be handed out
        engine.dispose()
        raise
    return engine


def get_sql_engine():
    """Create SQL Server connection.

    Returns:
        The cached SQLAlchemy engine, or None when the connection could not
        be established (an error is shown in the UI in that case).
    """
    try:
        return _create_sql_engine()
    except Exception as e:
        st.error(f"Failed to connect to SQL Server: {e}")
        return None
@st.cache_data
def _load_database_schema(_engine):
    """Inspect the database and render its tables and columns as text.

    The leading underscore on ``_engine`` tells Streamlit not to hash the
    argument. Raises on failure so that an error result is never cached.
    """
    inspector = inspect(_engine)
    parts = ["Database Tables:\n\n"]
    for table_name in inspector.get_table_names():
        parts.append(f"Table: {table_name}\n")
        parts.extend(
            f" - {column['name']} ({column['type']})\n"
            for column in inspector.get_columns(table_name)
        )
        parts.append("\n")
    return "".join(parts)


def get_database_schema(_engine):
    """Retrieve database schema information.

    Args:
        _engine: SQLAlchemy engine to inspect; a falsy value yields "".

    Returns:
        A text listing of every table with its column names and types, or
        "" when no engine was given or inspection failed (an error is shown
        in the UI in the latter case).
    """
    if not _engine:
        return ""
    try:
        return _load_database_schema(_engine)
    except Exception as e:
        st.error(f"Error retrieving schema: {e}")
        return ""
def _halt(message):
    """Show *message* as an error and stop the current script run."""
    st.error(message)
    st.stop()


# Page chrome
st.set_page_config(page_title="AI SQL Chatbot", layout="wide")
st.title("🤖 AI SQL Chatbot")
st.markdown(f"Connected to SQL Server: **{SERVER_NAME}** | Database: **{DATABASE_NAME}**")

# The OpenAI key is required; without it nothing below can work
openai_key = os.environ.get("OPENAI_API_KEY")
if not openai_key:
    _halt("❌ OPENAI_API_KEY not found in .env file")

# Open the database connection
engine = get_sql_engine()
if not engine:
    _halt("❌ Could not connect to SQL Server. Check your connection credentials.")

# Read the schema that is later given to the model as context
schema_info = get_database_schema(engine)
if not schema_info:
    _halt("❌ Could not retrieve database schema.")

# Show the schema read-only in the sidebar
st.sidebar.header("Database Schema")
st.sidebar.text_area("Available Tables and Columns", schema_info, height=400, disabled=True)

# Chat model used to translate questions into SQL
model = ChatOpenAI(model="gpt-4", temperature=0, openai_api_key=openai_key)
# System prompt with a {schema_info} placeholder. The schema is bound below
# with .partial() rather than interpolated into the template text, so braces
# occurring in table or column names cannot be mistaken for template variables.
# NOTE(review): rule 2 asks for parameterized queries, but the generated SQL is
# executed without bound parameters; confirm this rule is still wanted.
system_prompt = """You are an expert SQL Server database assistant. Your job is to convert natural language questions into SQL Server T-SQL queries.
Important rules:
1. Only generate SELECT queries (no INSERT, UPDATE, DELETE, DROP, etc.)
2. Always use parameterized queries when possible
3. Include meaningful column aliases
4. Optimize for readability and performance
5. If the question is ambiguous, make reasonable assumptions and add a comment explaining them
Available Database Schema:
{schema_info}
Generate ONLY the SQL query, no explanations."""
prompt_template = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", "{input}"),
]).partial(schema_info=schema_info)
# Chat interface
st.header("Ask Your Question")


def _clear_question():
    """Empty the question box.

    Runs as a button callback, i.e. before the script reruns and the text
    area is re-created; a keyed widget's value cannot be reassigned once the
    widget has been instantiated in the current run.
    """
    st.session_state["user_question"] = ""


# The key ties the widget's value to st.session_state["user_question"],
# which is what the Clear button resets.
user_question = st.text_area(
    "Enter your question in natural language:",
    placeholder="E.g., Show me all customers from New York with orders over $1000",
    key="user_question",
)
col1, col2 = st.columns([1, 1])
with col1:
    generate_btn = st.button("🔍 Generate SQL Query", key="generate")
with col2:
    # The click itself triggers a rerun after the callback, so no explicit
    # st.rerun() is needed.
    clear_btn = st.button("🗑️ Clear", key="clear", on_click=_clear_question)
def _extract_sql(llm_output):
    """Return the SQL text of a model reply, unwrapping a Markdown code fence if present."""
    import re

    fenced = re.search(r"```[^\n`]*\n(.*?)```", llm_output, flags=re.DOTALL)
    return (fenced.group(1) if fenced else llm_output).strip()


def _is_read_only_query(sql):
    """Return True when *sql* looks like one read-only SELECT statement.

    This is a best-effort guard, not a SQL parser: it errs on the side of
    rejecting. The database login should still be restricted to read access.
    """
    import re

    # Remove comments, string literals and quoted identifiers in one
    # left-to-right pass so their contents can neither hide nor fake keywords.
    bare = re.sub(
        r"--[^\n]*|/\*.*?\*/|'(?:[^']|'')*'|\[[^\]]*\]|\"[^\"]*\"",
        " ",
        sql,
        flags=re.DOTALL,
    )
    bare = bare.strip().rstrip(";").strip()
    if not re.match(r"(?i)(select|with)\b", bare):
        return False
    if ";" in bare:
        # More than one statement
        return False
    forbidden = (
        r"(?i)\b(insert|update|delete|drop|alter|create|truncate|merge|"
        r"exec|execute|grant|revoke|into)\b"
    )
    return re.search(forbidden, bare) is None


if generate_btn and not user_question.strip():
    st.warning("Please enter a question before generating a query.")
elif generate_btn:
    try:
        st.divider()
        # Generate SQL query
        with st.spinner("Generating SQL query..."):
            chain = prompt_template | model | StrOutputParser()
            sql_query = _extract_sql(chain.invoke({"input": user_question}))
        # Display generated query
        st.subheader("Generated SQL Query:")
        st.code(sql_query, language="sql")
        # Execute query
        st.subheader("Query Results:")
        if not _is_read_only_query(sql_query):
            # The prompt asks for SELECT only, but model output is untrusted
            st.error("❌ The generated query is not a single read-only SELECT statement, so it was not executed.")
        else:
            try:
                with st.spinner("Executing query..."):
                    with engine.connect() as conn:
                        result = conn.execute(text(sql_query))
                        df = pd.DataFrame(result.fetchall(), columns=result.keys())
                if df.empty:
                    st.info("Query executed successfully but returned no results.")
                else:
                    st.dataframe(df, use_container_width=True)
                    # Download option
                    csv = df.to_csv(index=False)
                    st.download_button(
                        label="📥 Download Results as CSV",
                        data=csv,
                        file_name="query_results.csv",
                        mime="text/csv",
                    )
            except Exception as e:
                st.error(f"❌ Error executing query: {str(e)}")
                st.info("Please review the generated SQL query and try again.")
    except Exception as e:
        st.error(f"❌ Error generating SQL query: {str(e)}")