-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgemini2.py
More file actions
145 lines (130 loc) · 4.73 KB
/
gemini2.py
File metadata and controls
145 lines (130 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import psycopg2
import requests
import os
from decimal import Decimal
from dotenv import load_dotenv
# Populate os.environ from a .env file (when one is present) so that the
# os.getenv() lookups below can see its values.
load_dotenv()

# ---------------------------------------------------------------------------
# Gemini 1.5 Flash REST endpoint and credentials
# ---------------------------------------------------------------------------
API_URL = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    "gemini-1.5-flash:generateContent"
)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
headers = {"Content-Type": "application/json"}

# ---------------------------------------------------------------------------
# PostgreSQL connection settings, forwarded as keyword arguments to
# psycopg2.connect(); any variable that is unset comes through as None.
# ---------------------------------------------------------------------------
DB_CONFIG = dict(
    dbname=os.getenv("DB_NAME"),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    host=os.getenv("DB_HOST"),
    port=os.getenv("DB_PORT"),
)

# Shared connection and cursor, created lazily by get_connection() and
# released by close_connection().
conn = cur = None
def get_connection():
    """Return the shared ``(connection, cursor)`` pair, opening it on demand.

    A new psycopg2 connection is opened from ``DB_CONFIG`` when none exists
    yet or the previous one has been closed. A new cursor is created when none
    exists or the existing one has been closed. An open connection is reused
    instead of being leaked when only the cursor needs replacing.

    Returns:
        tuple: ``(conn, cur)``, the module-level connection and cursor.
    """
    global conn, cur
    # psycopg2 exposes ``closed`` on both objects (non-zero / True once
    # closed), so objects closed elsewhere are replaced instead of returned.
    if conn is None or conn.closed:
        conn = psycopg2.connect(**DB_CONFIG)
        cur = None  # a cursor tied to a previous connection is unusable
    if cur is None or cur.closed:
        cur = conn.cursor()
    return conn, cur
def close_connection():
    """Close the shared cursor and connection, if open, and forget them.

    Both module-level globals are reset to ``None`` before closing. As a
    result, ``get_connection()`` (which only reconnects when they are
    ``None``) opens a fresh connection next time instead of returning closed
    objects. The connection is closed even if closing the cursor raises.
    Calling this when nothing is open is a no-op.
    """
    global conn, cur
    old_conn, old_cur = conn, cur
    conn = cur = None
    try:
        if old_cur is not None:
            old_cur.close()
    finally:
        if old_conn is not None:
            old_conn.close()
# Column metadata for one table of the public schema. The primary-key flag is
# derived via key_column_usage, so only columns that are actually part of the
# primary key are flagged (joining on table name alone would flag them all).
_COLUMNS_SQL = """
    SELECT
        c.column_name,
        c.data_type,
        c.is_nullable,
        c.column_default,
        pk.column_name IS NOT NULL AS is_primary_key
    FROM information_schema.columns c
    LEFT JOIN (
        SELECT kcu.table_schema, kcu.table_name, kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON kcu.constraint_schema = tc.constraint_schema
         AND kcu.constraint_name = tc.constraint_name
         AND kcu.table_schema = tc.table_schema
         AND kcu.table_name = tc.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
    ) pk
      ON pk.table_schema = c.table_schema
     AND pk.table_name = c.table_name
     AND pk.column_name = c.column_name
    WHERE c.table_schema = 'public'
      AND c.table_name = %s
    ORDER BY c.ordinal_position;
"""


def _format_create_table(table_name, columns):
    """Render one table's column rows as a ``CREATE TABLE`` statement.

    ``columns`` holds ``(name, data_type, is_nullable, default,
    is_primary_key)`` rows in column order, as returned by ``_COLUMNS_SQL``.
    A single primary-key column is marked inline. A composite key is emitted
    as a table-level ``PRIMARY KEY (...)`` clause so the DDL stays valid.
    """
    pk_columns = [row[0] for row in columns if row[4]]
    lines = []
    for name, data_type, nullable, default, is_pk in columns:
        line = f" {name} {data_type}"
        if is_pk and len(pk_columns) == 1:
            line += " PRIMARY KEY"
        if nullable == 'NO':
            line += " NOT NULL"
        if default:
            line += f" DEFAULT {default}"
        lines.append(line)
    if len(pk_columns) > 1:
        lines.append(f" PRIMARY KEY ({', '.join(pk_columns)})")
    return f"CREATE TABLE {table_name} (\n" + ",\n".join(lines) + "\n);\n\n"


def get_schema():
    """Describe the ``public`` schema as text made of ``CREATE TABLE`` statements.

    Relations are listed in name order. ``information_schema.tables`` also
    lists views, and these are rendered the same way as tables.

    Returns:
        str | None: the schema description, or ``None`` if querying the
        catalog failed (the error is printed).
    """
    conn = None
    try:
        conn, cur = get_connection()
        cur.execute("""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public'
            ORDER BY table_name
        """)
        table_names = [row[0] for row in cur.fetchall()]

        parts = ["CREATE TABLE statements:\n\n"]
        for table_name in table_names:
            # Bind the name as a parameter instead of splicing it into SQL,
            # so names containing quotes cannot break (or alter) the query.
            cur.execute(_COLUMNS_SQL, (table_name,))
            parts.append(_format_create_table(table_name, cur.fetchall()))
        return "".join(parts)
    except Exception as e:
        print(f"Error fetching schema: {e}")
        if conn is not None:
            try:
                # psycopg2 leaves the transaction aborted after an error;
                # without a rollback every later statement on the shared
                # connection would fail too.
                conn.rollback()
            except Exception:
                pass  # best effort: the connection may already be unusable
        return None
def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if present.

    Handles fences with or without a language tag (```sql, ```SQL, ```),
    whether the closing fence sits on its own line or right after the code.
    A lone word on the opening fence line is treated as a language tag.
    Unfenced text is returned stripped but otherwise unchanged.
    """
    text = text.strip()
    if len(text) < 6 or not (text.startswith("```") and text.endswith("```")):
        return text
    inner = text[3:-3]
    first_line, newline, rest = inner.partition("\n")
    tag = first_line.strip()
    if newline and (not tag or tag.isalnum()):
        inner = rest
    return inner.strip()


def get_sql_query(user_query, schema, timeout=30):
    """Ask Gemini 1.5 Flash to translate *user_query* into a PostgreSQL query.

    Args:
        user_query: natural-language question from the user.
        schema: schema description embedded verbatim in the prompt.
        timeout: seconds to wait for the HTTP response before giving up
            (without one, a stalled connection would block forever).

    Returns:
        str | None: one SQL statement terminated by ``;``, or ``None`` if the
        request failed or the response held no usable text (the reason is
        printed).
    """
    try:
        prompt = (
            "### PostgreSQL database schema:\n"
            f"{schema}\n"
            "### Instructions:\n"
            "- Generate a single, correct SQL query that answers the user's question.\n"
            "- Include appropriate table aliases, JOIN conditions, and WHERE clauses.\n"
            "- Provide the SQL query only, with no additional explanations or questions.\n"
            f"### User Query: {user_query}\n"
            "### SQL Query:"
        )
        payload = {"contents": [{"parts": [{"text": prompt}]}]}
        # Send the key as a header rather than a ?key= query parameter: requests
        # embeds the URL in its exception messages, which are printed below.
        request_headers = {**headers, "x-goog-api-key": GEMINI_API_KEY}
        response = requests.post(
            API_URL, headers=request_headers, json=payload, timeout=timeout
        )
        if not response.ok:
            print(f"Gemini API error (HTTP {response.status_code}): {response.text}")
            return None

        response_data = response.json()
        candidates = response_data.get('candidates') or []
        if not candidates:
            print("No candidates in response")
            return None

        first = candidates[0]
        parts = (first.get('content') or {}).get('parts') or []
        sql_query = _strip_code_fence("".join(p.get('text', '') for p in parts))
        if not sql_query:
            # e.g. generation stopped for safety reasons and returned no parts
            print(f"No SQL text in response (finishReason={first.get('finishReason')})")
            return None

        if not sql_query.endswith(';'):
            sql_query += ';'
        return sql_query
    except Exception as e:
        print(f"Error getting SQL query: {e}")
        return None
def execute_query(sql_query):
    """Run *sql_query* on the shared connection and return its outcome.

    Returns:
        list[dict] | dict | None: for statements that return rows, a list of
        ``{column_name: value}`` dicts (``Decimal`` values converted to
        ``float``); for other statements,
        ``{"message": "Query executed successfully"}``; ``None`` if execution
        failed (the error is printed).
    """
    # NOTE(review): the SQL is executed verbatim; if it comes from the LLM,
    # consider a read-only role or transaction to rule out destructive DML.
    conn = None
    try:
        conn, cur = get_connection()
        cur.execute(sql_query)
        # psycopg2 sets cursor.description to None for operations that return
        # no rows. This classifies WITH ... SELECT, leading comments,
        # INSERT ... RETURNING and SELECT INTO correctly, unlike a keyword test.
        if cur.description is None:
            # NOTE(review): this function never commits, so data-modifying
            # statements are left uncommitted — confirm that is intended.
            return {"message": "Query executed successfully"}
        columns = [desc[0] for desc in cur.description]
        return [
            {
                column: float(value) if isinstance(value, Decimal) else value
                for column, value in zip(columns, row)
            }
            for row in cur.fetchall()
        ]
    except Exception as e:
        print(f"Error executing query: {e}")
        if conn is not None:
            try:
                # A failed statement leaves the psycopg2 transaction aborted;
                # roll back so later queries on the shared connection work.
                conn.rollback()
            except Exception:
                pass  # best effort: the connection may already be unusable
        return None