-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsql_assistant.py
More file actions
1236 lines (1042 loc) · 56.4 KB
/
sql_assistant.py
File metadata and controls
1236 lines (1042 loc) · 56.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import logging
import json
from typing import List, Dict, Any, Tuple, Optional
import mysql.connector
import sqlparse
from mysql.connector import FieldType
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from google import genai
from dotenv import load_dotenv
import functools
import uuid
from fastapi_sessions.frontends.implementations import SessionCookie, CookieParameters
from fastapi_sessions.backends.implementations import InMemoryBackend
from fastapi_sessions.session_verifier import SessionVerifier
from contextlib import asynccontextmanager
import re
from starlette.staticfiles import StaticFiles
# --- Configuration ---
load_dotenv()  # Load environment variables from .env file
# Logging Configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Database Configuration with defaults
MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "root")
# Gemini API Configuration
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
GEMINI_MODEL_NAME = "gemini-2.5-flash-lite-preview-06-17"
# Session cookie signing key. NOTE(review): when SESSION_SECRET_KEY is not set
# in the environment, a fresh random key is generated on every process start,
# which invalidates all existing session cookies — confirm this is intended.
SESSION_SECRET_KEY = os.getenv("SESSION_SECRET_KEY", str(uuid.uuid4()))
# Variable to track if Gemini API is initialized (set by initialize_gemini_api)
gemini_initialized = False
# Global Gemini Client (created/validated by initialize_gemini_api)
gemini_client: Optional[genai.client.Client] = None
# In-memory store for chat history (a list of message dictionaries)
MAX_HISTORY_LENGTH = 20  # Max number of user/model turn pairs to keep
# Path to .env file (located next to this module)
ENV_FILE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
# --- Session Management Setup ---
class SessionData(BaseModel):
    # Conversation turns in Gemini "contents" format:
    # {"role": ..., "parts": [{"text": ...}]} — appended by add_to_history.
    # NOTE: pydantic copies field defaults per model instance, so the mutable
    # [] default is safe here (unlike a plain-Python default argument).
    history: List[Dict[str, Any]] = []
cookie_params = CookieParameters()
# Session cookie frontend; session IDs are UUIDs.
cookie = SessionCookie(
    cookie_name="session_id",
    identifier="general_verifier",  # must match the verifier's identifier below
    auto_error=True,  # reject requests automatically when the cookie is missing/invalid
    secret_key=SESSION_SECRET_KEY,
    cookie_params=cookie_params,
)
# Sessions live in process memory only: they are lost on restart and are not
# shared across worker processes.
session_backend = InMemoryBackend[uuid.UUID, SessionData]()
class SessionManager(SessionVerifier[uuid.UUID, SessionData]):
    """Session verifier: any session that exists in the backend is accepted."""

    def __init__(
        self,
        *,
        identifier: str,
        auto_error: bool,
        backend: InMemoryBackend[uuid.UUID, SessionData],
        auth_http_exception: HTTPException,
    ):
        # Stash the configuration on private attributes; the read-only
        # properties below satisfy the SessionVerifier interface.
        self._identifier, self._auto_error = identifier, auto_error
        self._backend, self._auth_http_exception = backend, auth_http_exception

    def verify_session(self, model: SessionData) -> bool:
        """If the session exists, it is valid"""
        return True

    @property
    def identifier(self):
        return self._identifier

    @property
    def backend(self):
        return self._backend

    @property
    def auto_error(self):
        return self._auto_error

    @property
    def auth_http_exception(self):
        return self._auth_http_exception
# Single module-level verifier instance, used as a FastAPI dependency.
session_verifier = SessionManager(
    identifier="general_verifier",
    auto_error=True,
    backend=session_backend,
    auth_http_exception=HTTPException(status_code=403, detail="Invalid session"),
)
# --- Chat History Management (Now operates on a session) ---
def add_to_history(session_data: SessionData, role: str, text: str) -> None:
    """
    Adds a message to the chat history for a given session and truncates it.

    Args:
        session_data: Session whose history is mutated in place.
        role: Message role, e.g. "user" or "model".
        text: Message text, stored in Gemini "contents" format.
    """
    session_data.history.append({"role": role, "parts": [{"text": text}]})
    # Keep at most MAX_HISTORY_LENGTH user/model turn pairs. Slicing from the
    # end (instead of dropping a single fixed pair, as before) also recovers
    # in one step when the history is over-long by more than one pair.
    max_messages = MAX_HISTORY_LENGTH * 2
    if len(session_data.history) > max_messages:
        session_data.history = session_data.history[-max_messages:]
        logger.info(f"Chat history truncated. New length: {len(session_data.history)}")
def clear_chat_history(session_data: Optional[SessionData] = None) -> None:
    """
    Clears chat history.

    Args:
        session_data: When provided, clears that session's history in place
            (the supported, session-scoped mechanism). When omitted, falls
            back to resetting the legacy module-level store for backward
            compatibility.
    """
    if session_data is not None:
        session_data.history = []
        logger.info("Session chat history has been cleared.")
        return
    # Legacy path: this global predates session-scoped history and is not
    # read anywhere else in the visible module.
    global chat_history_store
    chat_history_store = []
    logger.info("Chat history has been cleared.")
# Function to initialize Gemini API
def initialize_gemini_api():
    """
    Creates the global Gemini client and validates the configured API key.

    Validation is a minimal one-token generate_content call. On any failure,
    or when no key is configured, the module-level state is reset so callers
    can detect that the API is unavailable.

    Returns:
        bool: True when the client was created and validated, else False.
    """
    global gemini_initialized, gemini_client
    if not GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY not found. Some features will be limited.")
        gemini_initialized = False
        gemini_client = None
        return False
    try:
        # The new SDK routes every interaction through one client object.
        client = genai.Client(api_key=GEMINI_API_KEY)
        # Cheapest possible call that still exercises the key.
        client.models.generate_content(
            model=GEMINI_MODEL_NAME,
            contents="ping",
            config={
                "max_output_tokens": 1
            }
        )
    except Exception as e:
        logger.error(f"Failed to initialize or validate Gemini API key: {e}")
        gemini_initialized = False
        gemini_client = None
        return False
    gemini_client = client
    gemini_initialized = True
    logger.info("Gemini API initialized and validated successfully")
    return True
# Initialize Gemini API on startup (validates the configured key once at import
# time so startup logs reflect API availability).
initialize_gemini_api()
# Function to update environment variables and .env file
def update_environment(config_data):
    """
    Applies a new configuration to module globals, os.environ and the .env file.

    Args:
        config_data: Mapping with keys "mysql_host", "mysql_user",
            "mysql_password" and "gemini_api_key". Missing or falsy MySQL
            values fall back to built-in defaults; "gemini_api_key" may be an
            empty string to deliberately clear the key.
    """
    global MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, GEMINI_API_KEY
    defaults = {
        "mysql_host": "localhost",
        "mysql_user": "root",
        "mysql_password": "root",
        "gemini_api_key": ""  # Explicitly empty for API key default handling
    }
    # MySQL settings: a provided truthy value wins, otherwise the default.
    MYSQL_HOST = config_data.get("mysql_host") or defaults["mysql_host"]
    os.environ["MYSQL_HOST"] = MYSQL_HOST
    MYSQL_USER = config_data.get("mysql_user") or defaults["mysql_user"]
    os.environ["MYSQL_USER"] = MYSQL_USER
    MYSQL_PASSWORD = config_data.get("mysql_password") or defaults["mysql_password"]
    os.environ["MYSQL_PASSWORD"] = MYSQL_PASSWORD
    # Gemini API key: an explicit empty string is a valid value (clears the key).
    if "gemini_api_key" in config_data:
        GEMINI_API_KEY = config_data["gemini_api_key"]
    else:
        GEMINI_API_KEY = defaults["gemini_api_key"]  # Should not happen if key always in config_data
    os.environ["GEMINI_API_KEY"] = GEMINI_API_KEY
    # Re-initialize unconditionally. This fixes a stale-client bug: the old
    # code skipped re-initialization when the key was cleared, leaving a
    # previously valid client still marked as initialized.
    initialize_gemini_api()
    update_env_file()
    logger.info("Environment variables updated with new configuration")
# Function to update .env file
def update_env_file():
    """Persists the current global configuration values to the .env file.

    Comments, blank lines and unrelated variables already present in the file
    are preserved; managed keys are rewritten in place and appended when
    missing. A fresh file is created when none exists or reading fails.
    """
    global MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, GEMINI_API_KEY, SESSION_SECRET_KEY
    try:
        managed = {
            "MYSQL_HOST": MYSQL_HOST,
            "MYSQL_USER": MYSQL_USER,
            "MYSQL_PASSWORD": MYSQL_PASSWORD,
            "GEMINI_API_KEY": GEMINI_API_KEY,
            "SESSION_SECRET_KEY": SESSION_SECRET_KEY
        }

        def write_fresh():
            # (Re)create the file from scratch with only the managed keys.
            with open(ENV_FILE_PATH, "w") as env_file:
                for key, value in managed.items():
                    env_file.write(f"{key}={value}\n")

        if not os.path.exists(ENV_FILE_PATH):
            write_fresh()
            logger.info("Created .env file with new configuration")
            return
        try:
            with open(ENV_FILE_PATH, "r") as env_file:
                lines = env_file.readlines()
        except Exception as e:
            logger.error(f"Error reading .env file: {e}, recreating file.")
            write_fresh()
            logger.info("Created new .env file after read failure")
            return
        rewritten = []
        seen = set()
        for line in lines:
            stripped = line.strip()
            # Comments, blank lines and malformed (no '=') entries pass
            # through untouched, as do variables we do not manage.
            if not stripped or stripped.startswith("#") or "=" not in stripped:
                rewritten.append(line)
                continue
            name = stripped.split("=", 1)[0].strip()
            if name in managed:
                rewritten.append(f"{name}={managed[name]}\n")
                seen.add(name)
            else:
                rewritten.append(line)
        # Append any managed keys that weren't in the file originally.
        for key, value in managed.items():
            if key not in seen:
                rewritten.append(f"{key}={value}\n")
        with open(ENV_FILE_PATH, "w") as env_file:
            env_file.writelines(rewritten)
        logger.info("Updated .env file with new configuration")
    except Exception as e:
        logger.error(f"Error updating .env file: {e}")
# --- Database Interaction ---
def get_db_connection(db_name: Optional[str] = None):
    """
    Establishes a connection to the MySQL server.
    Connects to a specific database if db_name is provided.
    Returns the connection object or None if connection fails
    (the error is logged rather than raised).
    """
    params = {
        'host': MYSQL_HOST,
        'user': MYSQL_USER,
        'password': MYSQL_PASSWORD,
        'pool_name': "mypool",
        'pool_size': 5,
        'auth_plugin': 'mysql_native_password'
    }
    if db_name:
        params['database'] = db_name
    try:
        connection = mysql.connector.connect(**params)
    except mysql.connector.Error as err:
        logger.error(f"Database connection error (connecting to {db_name or 'server'}): {err}")
        return None
    logger.info(f"DB connection established (Database: {db_name or 'None'})")
    return connection
def execute_sql_query(query: str) -> Tuple[Optional[List[Any]], Optional[List[str]], Optional[str], int, Optional[str]]:
    """
    Executes an SQL query against the database.
    Args:
        query: The SQL query string to execute.
    Returns:
        A tuple containing:
        - results: List of result tuples (or None).
        - column_names: List of column names (or None).
        - column_types_str: String describing column names and types (or None).
        - status_code: 1 (SELECT/SHOW success), 2 (Other DML/DDL success), 3 (Error).
        - error_message: Error details if status_code is 3.
    """
    logger.info(f"Executing Query: {query}")
    conn = None
    cursor = None
    results: Optional[List[Any]] = None
    column_names: Optional[List[str]] = None
    column_types_str: Optional[str] = None
    try:
        conn = get_db_connection(db_name=None)  # Connect WITHOUT specifying a default database
        if not conn:
            error_message = "SQL Error: Failed to connect to the database server for query execution."
            logger.error(error_message)
            return None, None, None, 3, error_message
        # Buffered cursor: rows are fetched eagerly, avoiding "unread result"
        # issues when the connection is closed in the finally block.
        cursor = conn.cursor(buffered=True)
        # --- SECURITY WARNING ---
        # Executing arbitrary SQL generated by an LLM or user input is a
        # significant security risk. In a production environment, you MUST:
        # 1. Sanitize and validate the query.
        # 2. Use parameterized queries where possible.
        # 3. Limit database user permissions (e.g., read-only access).
        # 4. Consider query allow-listing or blocking certain commands.
        # This example executes the query directly for simplicity, but DO NOT deploy like this.
        cursor.execute(query)
        # Only SELECT/SHOW statements produce row results; everything else is
        # committed and reported as status 2 with no data.
        query_lower = query.strip().lower()
        if query_lower.startswith("select") or query_lower.startswith("show"):
            results = cursor.fetchmany(100)  # Limit results for display
            if cursor.description:
                column_names = [i[0] for i in cursor.description]
                from mysql.connector.constants import FieldType  # Ensure FieldType is imported
                # Map each column's numeric type code to a readable name.
                col_dtypes = [[i[0], FieldType.get_info(i[1])] for i in cursor.description]
                column_types_str = 'Column : Dtype\n' + '\n'.join(f'{k}: {v}' for k, v in col_dtypes)
            else:
                # No description metadata available; fall back to a placeholder.
                column_names = ["Result"]
                column_types_str = "Column : Dtype\nResult: <unknown>"
            # Defensive normalization: if rows came back as bare scalars,
            # wrap them so downstream code can always iterate tuples.
            if results and isinstance(results[0], (str, int, float, bytes)):
                results = [(r,) for r in results]  # Wrap single values in tuples
            conn.commit()  # Necessary even for SELECT with some configurations/engines
            result_count = len(results) if results is not None else 0
            logger.info(f"Query executed successfully, fetched {result_count} rows.")
            return results, column_names, column_types_str, 1, None
        else:
            conn.commit()
            logger.info("Non-SELECT/SHOW query executed successfully.")
            return None, None, None, 2, None  # Success for non-select queries
    except mysql.connector.Error as e:
        logger.error(f"SQL Error executing query '{query}': {e}")
        error_message = f"SQL Error: {e}"
        if conn:  # Rollback changes if an error occurs during non-select queries
            try:
                conn.rollback()
            except mysql.connector.Error as rb_err:
                logger.error(f"Error during rollback: {rb_err}")
        return None, None, None, 3, error_message
    except Exception as e:
        logger.error(f"Unexpected error executing query '{query}': {e}", exc_info=True)
        error_message = f"Unexpected Error: {e}"
        if conn:  # Rollback changes if an unexpected error occurs
            try:
                conn.rollback()
            except mysql.connector.Error as rb_err:
                logger.error(f"Error during rollback: {rb_err}")
        return None, None, None, 3, error_message
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():  # Check conn exists and is connected before closing
            conn.close()
            logger.info("DB connection closed.")
def fetch_all_tables_and_columns() -> Dict[str, Dict[str, Any]]:
    """
    Fetches all non-system databases, their tables, and columns.
    Returns: Dict[db_name, Dict[table_name, List[column_name]]]
    Returns an error structure if connection or queries fail.

    Per-table/per-database errors are recorded in place (as error strings /
    placeholders) so one bad object does not abort the whole schema fetch.
    """
    schema_info: Dict[str, Dict[str, Any]] = {}
    conn = None
    cursor = None
    system_databases = {'information_schema', 'mysql', 'performance_schema', 'sys'}  # Exclude system databases
    try:
        conn = get_db_connection(db_name=None)  # Connect without specifying a database
        if not conn:
            logger.error("Failed to get DB connection for schema fetching.")
            return {"error": {"schema": ["Failed to connect to the database server."]}}
        cursor = conn.cursor()
        cursor.execute("SHOW DATABASES;")  # Get all databases
        fetch_result = cursor.fetchall()
        if not fetch_result:
            logger.warning("No databases returned from SHOW DATABASES query.")
            return {}
        databases = [str(row[0]) for row in fetch_result if row[0] not in system_databases]  # type: ignore
        if not databases:
            logger.warning("No user databases found.")
            return {}
        for db_name in databases:  # Get tables and columns for each relevant database
            schema_info[db_name] = {}
            try:
                # NOTE(review): db/table names are interpolated into
                # backtick-quoted identifiers; a name containing a backtick
                # would break the quoting. These names come from the server
                # itself, but verify if untrusted servers are possible.
                cursor.execute(f"SHOW TABLES FROM `{db_name}`;")
                fetch_result = cursor.fetchall()
                if fetch_result is None:
                    logger.warning(f"No tables returned for database {db_name}.")
                    tables = []
                else:
                    tables = [str(row[0]) for row in fetch_result]  # type: ignore
                for table_name in tables:
                    try:
                        cursor.execute(f"SHOW COLUMNS FROM `{db_name}`.`{table_name}`;")
                        fetch_result = cursor.fetchall()
                        if fetch_result is None:
                            logger.warning(f"No columns returned for table {db_name}.{table_name}.")
                            columns = []
                        else:
                            columns = [str(column[0]) for column in fetch_result]  # type: ignore
                        schema_info[db_name][table_name] = columns
                    except mysql.connector.Error as e:
                        logger.warning(f"Could not fetch columns for table {db_name}.{table_name}: {e}")
                        schema_info[str(db_name)][str(table_name)] = [f"Error fetching columns: {e}"]
            except mysql.connector.Error as e:
                logger.warning(f"Could not fetch tables for database {db_name}: {e}")
                schema_info[str(db_name)] = {"error": [f"Error fetching tables: {e}"]}  # Add an error placeholder
        logger.info(f"Fetched schema for {len(databases)} databases.")
        return schema_info
    except mysql.connector.Error as e:
        logger.error(f"SQL Error fetching database list: {e}")
        return {"error": {"schema": [f"SQL Error fetching databases: {e}"]}}
    except Exception as e:
        logger.error(f"Error fetching schema: {e}", exc_info=True)
        return {"error": {"schema": [f"Unexpected error fetching schema: {str(e)}"]}}
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()
# --- Gemini API Interaction ---
# Decorator to check Gemini API initialization
def ensure_gemini_initialized(func):
    """Decorator that short-circuits a call when the Gemini API is unavailable.

    The decorated functions all return strings, so the guard returns an
    explanatory error string instead of invoking the wrapped function.
    """
    @functools.wraps(func)
    def guarded(*args, **kwargs):
        if gemini_initialized:
            return func(*args, **kwargs)
        logger.warning(f"Gemini API not initialized. Call to {func.__name__} will be skipped.")
        # Functions decorated are expected to return a string, so return an error string.
        return "Error: Gemini API not configured. Please set up your API key in the configuration."
    return guarded
@ensure_gemini_initialized
def generate_sql_with_gemini(user_query: str, schema: Dict[str, Dict[str, List[str]]], history: List[Dict[str, Any]]) -> Optional[str]:
    """Generates an SQL query using the Gemini API based on user input and multi-DB schema.

    Args:
        user_query: The user's natural-language question.
        schema: Mapping of database -> table -> column names, as produced by
            fetch_all_tables_and_columns (may contain "error" entries).
        history: Prior conversation turns in Gemini "contents" format.

    Returns:
        The generated SQL string, or a string starting with "Error:" on any
        failure (model error, conversational query, non-SQL output).
    """
    # Render the schema mapping into a plain-text outline for the prompt,
    # tolerating the error placeholders the schema fetcher may have inserted.
    schema_string = ""
    if not schema or "error" in schema:
        schema_string = "Could not fetch schema. Please ensure database connection is correct."
    else:
        for db_name, tables in schema.items():
            schema_string += f"\nDatabase: `{db_name}`\n"
            if isinstance(tables, dict):
                if not tables:
                    schema_string += " (No tables found or accessible)\n"
                elif "error" in tables:
                    schema_string += f" Error fetching tables: {tables['error']}\n"
                else:
                    for table_name, columns in tables.items():
                        col_string = ', '.join([f"`{c}`" for c in columns])
                        schema_string += f" - Table: `{table_name}`: Columns: {col_string}\n"
            else:
                schema_string += f" Error retrieving table details for this database.\n"
    # The system instruction or initial prompt part
    system_prompt = f"""You are an expert SQL assistant. Given the following database schema across potentially multiple databases and a user question, generate the most appropriate SQL query to answer the question.
Database Schema:
{schema_string}
User Question: "{user_query}"
Instructions:
- Your **only** task is to generate a single, executable MySQL query to answer the user's question based on the schema.
- **Always** attempt to generate a query. Do not engage in conversation or ask for clarification.
- If the user asks for a relationship between tables, generate a `JOIN` query. When using a `JOIN`, do not use `SELECT *`. Instead, select specific, useful columns from both tables to show the relationship.
- If the user asks to create a table, make reasonable assumptions for column types (e.g., VARCHAR(255) for text, INT for IDs).
- If querying a table, use the fully qualified name (e.g., `database_name`.`table_name`).
- Do NOT include any `USE database_name;` statement. Avoid changing the current database context; always reference tables with the dotted notation as shown above.
- Generate only **one single** SQL statement. Do not include multiple statements or comments.
- Do not include any explanations, introductory text, backticks (```sql), or markdown formatting.
- If the user's request is impossible to answer with a SQL query (e.g., it's a greeting like "hello"), then and only then, respond with the exact text: "Error: This is a conversational query."
SQL Query:"""
    # Combine the history with the new system prompt
    # The API expects the 'contents' to be a list of these dictionaries.
    request_contents = history + [{"role": "user", "parts": [{"text": system_prompt}]}]
    try:
        # The new SDK uses client.models.generate_content
        if not gemini_client:
            return "Error: Gemini client not initialized."
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL_NAME,
            contents=request_contents
        )
        if not hasattr(response, 'text') or not response.text:
            logger.warning(f"Gemini returned no text for SQL generation from user query: {user_query}")
            return "Error: The AI model did not return a response."
        sql_query = response.text.strip()  # Clean up potential markdown formatting
        # Strip a ```sql ... ``` fence if the model ignored the no-markdown rule.
        if sql_query.startswith("```sql"):
            sql_query = sql_query[6:]
        if sql_query.endswith("```"):
            sql_query = sql_query[:-3]
        sql_query = sql_query.strip()
        logger.info(f"Gemini generated SQL: {sql_query}")
        # The prompt asks the model to flag conversational queries with "Error:".
        if sql_query.lower().startswith("error:"):
            logger.warning(f"Gemini indicated an error: {sql_query}")
            return sql_query
        # Basic validation
        if not any(kw in sql_query.lower() for kw in ["select", "insert", "update", "delete", "show", "create", "alter", "drop", "use"]):
            logger.warning(f"Generated text doesn't look like SQL: {sql_query}")
            return "Error: Generated text does not appear to be a valid SQL query."
        return sql_query
    except Exception as e:
        logger.error(f"Error calling Gemini API for SQL generation: {e}", exc_info=True)
        return "Error: Failed to communicate with the AI model for SQL generation."
@ensure_gemini_initialized
def get_insights_with_gemini(original_query: str, sql_query: str, results: List[Any], columns: List[str], col_types: str, history: List[Dict[str, Any]]) -> str:
    """Generates insights on the data using the Gemini API.

    Args:
        original_query: The user's natural-language question.
        sql_query: The SQL statement that produced the results.
        results: Result rows (only the first 20 are sent to the model).
        columns: Column names of the result set (currently unused here;
            column info reaches the model via col_types).
        col_types: "Column : Dtype" description string from execute_sql_query.
        history: Prior conversation turns in Gemini "contents" format.

    Returns:
        Markdown insights text, or a fallback/error message string.
    """
    if not results:
        return "No results to analyze."
    # default=str makes non-JSON-native values (dates, Decimals, ...) serializable.
    results_preview = json.dumps(results[:20], indent=2, default=str)  # Limit results sent to Gemini
    # The main prompt for this specific task
    prompt = f"""You are a data analyst assistant. A user asked the following question:
"{original_query}"
The following SQL query was executed to fetch data:
```sql
{sql_query}
```
The query returned the following data (showing up to 20 rows):
Columns and Types:
{col_types}
Results (JSON format):
{results_preview}
Instructions:
- Provide concise, insightful observations based *only* on the provided data sample.
- Do not invent data or make assumptions beyond what's shown.
- Suggest 1-2 potential follow-up questions or SQL queries the user might be interested in, based on these results and the original question.
- **Crucially, any suggested SQL query MUST be enclosed in its own Markdown fenced code block with the `sql` language identifier.** For example:
```sql
SELECT your_column FROM your_table;
```
- Format your response clearly using Markdown. Start with "### Data Insights" and then "### Suggested Follow-up".
Analysis:"""
    # Combine history with the new prompt
    request_contents = history + [{"role": "user", "parts": [{"text": prompt}]}]
    try:
        if not gemini_client:
            return "Error: Gemini client not initialized."
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL_NAME,
            contents=request_contents
        )
        logger.info("Gemini generated insights.")
        return response.text if response.text else "No insights could be generated from the data."
    except Exception as e:
        logger.error(f"Error calling Gemini API for insights: {e}", exc_info=True)
        return "Error generating insights from the AI model."
@ensure_gemini_initialized
def get_conversational_response_with_gemini(user_message: str, history: List[Dict[str, Any]]) -> str:
    """Gets a conversational response from Gemini for non-SQL related queries.

    Args:
        user_message: The user's latest chat message.
        history: Prior conversation turns in Gemini "contents" format.

    Returns:
        The model's reply text, or a fallback/error message string.
    """
    logger.info(f"Getting conversational response for: {user_message}")
    # Construct the final prompt for the API call
    # NOTE: Updated to make the assistant more context-aware so it can handle
    # follow-up questions such as "ok then do it" that implicitly refer to a
    # previously suggested SQL query or insight. The assistant is now
    # allowed to reference earlier discussion and, when appropriate, suggest
    # or execute SQL. This should prevent situations where the model keeps
    # asking what the user means by "it".
    prompt = f"""You are a helpful assistant specialising in SQL, databases and data analysis. Continue the conversation below, making sure to use the prior context to resolve pronouns or vague references (e.g. what "it" refers to).
Guidelines:
1. If the user implicitly refers to a SQL statement that was suggested earlier (for example saying "ok then do it"), infer that reference from the prior messages and either:
• show the SQL you believe they are referring to and ask for explicit confirmation, **or**
• execute the action if it is clearly safe and previously agreed upon.
2. If the user explicitly asks for or clearly implies a database action, you may include or discuss SQL. Otherwise respond conversationally.
3. If the reference is genuinely ambiguous after considering the context, politely ask for clarification **once** instead of repeatedly.
User Message: "{user_message}"
Assistant Response:"""
    # Combine history with the new prompt
    request_contents = history + [{"role": "user", "parts": [{"text": prompt}]}]
    try:
        if not gemini_client:
            return "Error: Gemini client not initialized."
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL_NAME,
            contents=request_contents
        )
        # Respect safety blocks reported by the API.
        if response.prompt_feedback and response.prompt_feedback.block_reason:
            logger.warning(f"Conversational response blocked. Reason: {response.prompt_feedback.block_reason}")
            return "I cannot provide a response to that topic."
        return response.text.strip() if response.text else "I am unable to provide a response at this time."
    except Exception as e:
        logger.error(f"Error calling Gemini API for conversational response: {e}", exc_info=True)
        return "I'm having trouble responding right now. Please try again later."
@ensure_gemini_initialized
def get_error_explanation_with_gemini(original_user_query: Optional[str], failed_sql_query: str, error_message: str, schema: Optional[Dict[str, Any]] = None, history: Optional[List[Dict[str, Any]]] = None) -> str:
    """Generates a user-friendly explanation for an SQL error using Gemini.

    Args:
        original_user_query: The user's natural-language request, if known.
        failed_sql_query: The SQL statement that failed.
        error_message: The raw database error message.
        schema: Optional multi-database schema mapping for extra context.
        history: Prior conversation turns in Gemini "contents" format.
            Defaults to no history. (Previously defaulted to a mutable `[]`,
            which Python shares across all calls — classic pitfall; callers
            passing a list explicitly are unaffected by this fix.)

    Returns:
        A Markdown explanation string, or an error string on failure.
    """
    if history is None:
        history = []
    prompt_context = f"User's original request (if available): \"{original_user_query}\"\n"
    if not original_user_query:
        prompt_context = "The user was attempting to execute a specific SQL query.\n"
    # Optionally render the schema so the model can suggest concrete fixes
    # (e.g. the correct column name).
    schema_context = ""
    if schema:
        schema_string = ""
        if not schema or "error" in schema:
            schema_string = "Could not fetch schema."
        else:
            for db_name, tables in schema.items():
                schema_string += f"\nDatabase: `{db_name}`\n"
                if isinstance(tables, dict):
                    if not tables:
                        schema_string += " (No tables found or accessible)\n"
                    elif "error" in tables:
                        schema_string += f" Error fetching tables: {tables['error']}\n"
                    else:
                        for table_name, columns in tables.items():
                            col_string = ', '.join([f"`{c}`" for c in columns])
                            schema_string += f" - Table: `{table_name}`: Columns: {col_string}\n"
                else:
                    schema_string += f" Error retrieving table details for this database.\n"
        schema_context = f"""
For context, here is the database schema the query was run against:
{schema_string}
"""
    prompt = f"""You are an expert SQL troubleshooting assistant.
The following SQL query failed:
```sql
{failed_sql_query}
```
The database returned this error message:
{error_message}
{prompt_context}
{schema_context}
Instructions:
- Explain the error message in simple, easy-to-understand terms in 5-10 sentences maximum.
- **Use the provided database schema to give a specific, actionable suggestion.** For example, if a column name is wrong, suggest the correct one from the schema.
- What are the common reasons for this specific error?
- What should the user check or try to resolve this issue?
- If the error indicates the table is not insertable (e.g., it's a view), explain what that means.
- Format your response clearly using Markdown. Start with "### AI Troubleshooting Suggestion".
- Do not repeat the SQL query or the raw error message unless it's for specific emphasis within your explanation.
Explanation:"""
    # Combine history with the new prompt
    request_contents = history + [{"role": "user", "parts": [{"text": prompt}]}]
    try:
        if not gemini_client:
            return "Error: Gemini client not initialized."
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL_NAME,
            contents=request_contents
        )
        # Respect safety blocks reported by the API.
        if response.prompt_feedback and response.prompt_feedback.block_reason:
            logger.warning(f"Error explanation response blocked. Reason: {response.prompt_feedback.block_reason}")
            return "AI explanation could not be generated for this error due to content restrictions."
        logger.info("Gemini generated SQL error explanation.")
        return response.text.strip() if response.text else "An AI explanation could not be generated for this error."
    except Exception as e:
        logger.error(f"Error calling Gemini API for SQL error explanation: {e}", exc_info=True)
        return "Error generating AI explanation for the SQL error."
def get_query_risk_level(sql_query: str) -> int:
    """Classify a SQL query into a risk level based on its statement type.

    This is a primary safeguard against accidental or malicious database
    structure changes.  The function fails closed: anything that cannot be
    parsed or classified is treated as high-risk.

    Args:
        sql_query: Raw SQL text to classify.

    Returns:
        0: Read-only (safe to execute immediately).
        1: Data-modifying (requires user confirmation).
        2: Structure-modifying or potentially unsafe (should be blocked).
    """
    # Whitelists for the statement types reported by sqlparse.
    read_only_types = ["SELECT", "USE"]
    # NOTE(review): CREATE is DDL, yet it is grouped with the confirm-only
    # data-modifying statements instead of the blocked list below.  Kept
    # as-is to preserve existing behavior — confirm this is intentional.
    data_modifying_types = ["INSERT", "UPDATE", "DELETE", "CREATE"]
    # Any other DDL or DCL is considered structure-modifying and high-risk.
    structure_modifying_types = ["ALTER", "DROP", "TRUNCATE", "GRANT", "REVOKE", "RENAME"]
    # Fast path for common safe commands that sqlparse may label 'UNKNOWN'.
    query_upper = sql_query.strip().upper()
    if query_upper.startswith(("SHOW", "DESCRIBE", "EXPLAIN")):
        return 0
    try:
        parsed_statements = sqlparse.parse(sql_query)
        if not parsed_statements:
            logger.warning("Could not parse SQL query. Flagging as high-risk for safety.")
            return 2  # Fail-closed: if parsing fails, assume it's unsafe.
        # Multiple statements are blocked outright to prevent stacked-query attacks.
        if len(parsed_statements) > 1:
            logger.warning(f"Detected multiple SQL statements. Blocking for security: {sql_query}")
            return 2
        try:
            stmt = parsed_statements[0]
            stmt_type = stmt.get_type()
            stmt_upper = str(stmt).strip().upper()
        except Exception as e:
            logger.error(f"Error accessing parsed SQL statement: {e}. Flagging as high-risk.")
            return 2
        if stmt_type in structure_modifying_types:
            logger.warning(f"Detected structure-modifying query (by type '{stmt_type}'). Blocking.")
            return 2
        if stmt_type in data_modifying_types:
            return 1
        if stmt_type in read_only_types:
            # Secondary check for dangerous clauses sometimes found in SELECT statements.
            if "INTO OUTFILE" in stmt_upper or "INTO DUMPFILE" in stmt_upper:
                logger.warning("Detected potentially harmful 'SELECT...INTO' clause. Blocking.")
                return 2
            return 0
        # UNKNOWN or any type not on the whitelists is blocked.
        logger.warning(f"Detected query with unclassified type '{stmt_type}'. Blocking for safety.")
        return 2
    except Exception as e:
        logger.error(f"Failed to parse SQL query for security check: {e}. Flagging as high-risk.")
        return 2  # Any other exception during parsing: fail-closed.
# --- Pydantic Models ---
# Helper to classify whether a message is conversational-only (does not request new SQL).
# NOTE: Placed here so it is available when the /chat endpoint is defined below.
def _looks_conversational_only(message: str) -> bool:
"""Determine if the user message looks like a conversational or explanatory request.
The goal is to let the assistant stay in conversational mode when the user says
things such as "explain the above" or "don't write SQL".
"""
msg = message.strip().lower()
if msg.startswith(("hello", "hi", "hey", "thanks", "thank you", "ok")):
return True
conversational_markers = [
"explain", "insight", "insights", "what does", "meaning", "dont write sql",
"don't write sql", "not in sql", "just explain", "only explain", "explain above",
"explain the above"
]
if any(marker in msg for marker in conversational_markers):
return True
# NOTE: Previously any message ending with a question-mark and lacking obvious SQL keywords was
# treated as conversational. This heuristic was too aggressive and caused genuine NLQ requests
# (that naturally end with a question-mark) to be mis-classified. It has been removed so that
# only explicit conversational markers trigger conversational handling.
return False
class ChatRequest(BaseModel):
    """Request body for the chat endpoint: a single user message."""
    message: str
class ConfigRequest(BaseModel):
    """Request body for /config: connection credentials to apply and test."""
    mysql_host: str
    mysql_user: str
    mysql_password: str
    gemini_api_key: str
class PublicConfig(BaseModel):
    """Non-sensitive configuration status returned by /config_status."""
    mysql_host: str
    mysql_user: str
    mysql_password_set: bool  # True when a password is configured; never the value itself
    gemini_api_key_set: bool  # True when an API key is configured; never the value itself
class ConfirmedExecutionRequest(BaseModel):
    """Request body for executing a query the user has explicitly confirmed."""
    query: str
# --- Static Files with Caching ---
class CachingStaticFiles(StaticFiles):
    """StaticFiles subclass that attaches a long-lived Cache-Control header.

    Successful (200) responses receive ``public, max-age=<max_age>,
    immutable`` unless the response already carries a Cache-Control header.
    """

    def __init__(self, *args, max_age: int = 31536000, **kwargs):
        # max_age defaults to one year, in seconds.
        super().__init__(*args, **kwargs)
        self.max_age = max_age

    async def get_response(self, path: str, scope):
        resp = await super().get_response(path, scope)
        if resp.status_code != 200:
            return resp
        cache_value = f"public, max-age={self.max_age}, immutable"
        # setdefault: never clobber an explicitly-set Cache-Control header.
        resp.headers.setdefault("Cache-Control", cache_value)
        return resp
# --- FastAPI Application ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook.

    On startup a fresh session is created and its id stashed on
    ``app.state`` so the very first visit to "/" (which carries no cookie
    yet) can be associated with a valid session.
    """
    # Workaround: pre-create one session for the first cookie-less visitor.
    first_session_id = uuid.uuid4()
    await session_backend.create(first_session_id, SessionData())
    # The frontend receives this session id via the response from "/".
    # A middleware that lazily creates missing sessions would be a cleaner
    # long-term approach than this app-state hand-off.
    app.state.initial_session_id = first_session_id
    yield
    # No shutdown logic is required at the moment.
# --- Application wiring: FastAPI instance, static assets, templates ---
app = FastAPI(title="SQL Assistant with Gemini", lifespan=lifespan)
# Mount static files using the caching-enabled subclass so that browsers can cache assets effectively.
# max_age=31536000 (one year) matches the subclass default; assets are served as immutable.
app.mount("/static", CachingStaticFiles(directory="static", max_age=31536000), name="static")
templates = Jinja2Templates(directory=".") # Expect index.html in the root directory
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the main HTML page, guaranteeing the visitor a valid session.

    A fresh session (and cookie) is issued when:
      1. the request carries no session cookie,
      2. the cookie value is not a valid UUID, or
      3. the cookie references a session unknown to the backend
         (e.g. a stale cookie after a server restart).
    """
    response = templates.TemplateResponse("index.html", {"request": request})
    needs_new_session = True
    raw_cookie = request.cookies.get("session_id")
    if raw_cookie:
        try:
            candidate_id = uuid.UUID(raw_cookie)
            # A cookie is only trusted if its session still exists server-side.
            if await session_backend.read(candidate_id) is not None:
                needs_new_session = False
        except Exception:
            # Malformed UUID or backend failure: fall through and reissue.
            pass
    if needs_new_session:
        fresh_id = uuid.uuid4()
        await session_backend.create(fresh_id, SessionData())
        # Overwrites any stale cookie so subsequent requests are valid.
        cookie.attach_to_response(response, fresh_id)
    return response
@app.get("/config_status", response_model=PublicConfig)
async def get_config_status():
    """Report the current configuration without exposing secrets.

    The password and API key are reduced to booleans that only indicate
    whether a non-empty value is configured.
    """
    password_configured = bool(MYSQL_PASSWORD)  # True if password is not an empty string
    api_key_configured = bool(GEMINI_API_KEY)   # True if key is not an empty string
    return PublicConfig(
        mysql_host=MYSQL_HOST,
        mysql_user=MYSQL_USER,
        mysql_password_set=password_configured,
        gemini_api_key_set=api_key_configured,
    )
@app.get("/schema", response_class=JSONResponse)
async def get_schema():
    """API endpoint to fetch the current database schema.

    Always responds with HTTP 200 and a ``{"schema": ...}`` payload.
    Failures are reported in-band: the schema dict carries an "error" key
    and the client is expected to inspect it.  (The previous error branch
    built the identical 200 response, so the duplication is collapsed.)
    """
    schema = fetch_all_tables_and_columns()
    return JSONResponse(content={"schema": schema})
@app.post("/reset_chat", response_class=JSONResponse)
async def reset_chat(session_id: uuid.UUID = Depends(cookie)):
    """Clear the server-side chat history for the caller's session."""
    # Wipe the stored history by swapping in a brand-new, empty SessionData.
    await session_backend.update(session_id, SessionData())
    logger.info(f"Chat history for session {session_id} has been reset.")
    return JSONResponse(content={"status": "success", "message": "Chat history has been reset."})
@app.post("/config", response_class=JSONResponse)
async def update_config(config_request: ConfigRequest):
"""Updates application configuration and tests connections."""
try:
mysql_host = config_request.mysql_host if config_request.mysql_host else "localhost"
mysql_user = config_request.mysql_user if config_request.mysql_user else "root"
mysql_password = config_request.mysql_password if config_request.mysql_password else "root"
gemini_api_key = config_request.gemini_api_key if config_request.gemini_api_key else ""
config_data = {
"mysql_host": mysql_host,
"mysql_user": mysql_user,
"mysql_password": mysql_password,
"gemini_api_key": gemini_api_key
}
mysql_status = "failed: unknown"
try: # Test MySQL connection
conn = mysql.connector.connect(
host=mysql_host,
user=mysql_user,
password=mysql_password,
auth_plugin='mysql_native_password'
)
conn.close()
mysql_status = "success"
except mysql.connector.Error as err:
logger.error(f"Failed to connect with new MySQL credentials: {err}")
mysql_status = f"failed: {err}"
gemini_status = "failed: unknown"
original_key = GEMINI_API_KEY
try: # Test Gemini API key
if gemini_api_key:
# Use the new SDK's client for testing
test_client = genai.Client(api_key=gemini_api_key)
test_response = test_client.models.generate_content(
model=GEMINI_MODEL_NAME,
contents="ping",
config={
"max_output_tokens": 1
}
)
if test_response.text:
gemini_status = "success"
else:
gemini_status = "failed: no response"
else:
gemini_status = "not_provided"
except Exception as e:
logger.error(f"Failed to initialize Gemini API with new key: {e}")
gemini_status = f"failed: {e}"
# Restore the original key if test failed