-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmistral.py
More file actions
152 lines (132 loc) · 4.84 KB
/
mistral.py
File metadata and controls
152 lines (132 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import re
from decimal import Decimal

import psycopg2
import requests
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()

# Mistral API configuration.
# The endpoint can be set through MISTRAL_API_URL; the default is a
# placeholder that must be replaced with a specific Mistral model URL.
API_URL = os.getenv(
    "MISTRAL_API_URL",
    "https://api-inference.huggingface.co/models/mistral-model-name",
)
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
# NOTE(review): if HUGGINGFACE_API_KEY is unset this sends "Bearer None",
# which the API will reject; confirm the key is always provided.
headers = {
    "Authorization": f"Bearer {HUGGINGFACE_API_KEY}",
    "Content-Type": "application/json",
}

# Database configuration, read from the environment (or the .env file above).
DB_CONFIG = {
    "dbname": os.getenv("DB_NAME"),
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "host": os.getenv("DB_HOST"),
    "port": os.getenv("DB_PORT"),
}

# Shared connection/cursor pair, created lazily by get_connection() and
# released by close_connection().
conn = None
cur = None
def get_connection():
    """Return the shared ``(connection, cursor)`` pair, opening it if needed.

    A new connection (and cursor) is opened from DB_CONFIG when none exists
    yet or the existing one reports itself closed, e.g. after
    close_connection() or a broken server link. When only the cursor is
    missing or closed, a fresh cursor is created on the existing connection
    instead of opening (and leaking) a second connection.

    Returns:
        tuple: ``(conn, cur)``, the module-level psycopg2 connection and cursor.
    """
    global conn, cur
    # psycopg2's connection.closed is 0 while open and nonzero once closed/broken.
    if conn is None or conn.closed:
        conn = psycopg2.connect(**DB_CONFIG)
        cur = conn.cursor()
    elif cur is None or cur.closed:
        cur = conn.cursor()
    return conn, cur
def close_connection():
    """Close the shared cursor and connection and forget them.

    The module-level handles are reset to None first, so the next
    get_connection() call opens a fresh connection instead of returning the
    closed one. The connection is closed even if closing the cursor raises.
    Nothing is committed here; psycopg2 discards uncommitted changes on close.
    Calling this when nothing is open is a no-op.
    """
    global conn, cur
    connection, cursor = conn, cur
    conn = cur = None
    try:
        if cursor is not None:
            cursor.close()
    finally:
        if connection is not None:
            connection.close()
# Names of all tables/views in the public schema, in a stable order.
_PUBLIC_TABLES_SQL = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
ORDER BY table_name
"""

# Columns of one public table plus whether each column belongs to the table's
# primary key. The PK lookup goes through key_column_usage so that only the
# actual key columns are flagged (joining table_constraints on the table name
# alone would flag every column of a table that has any primary key).
_TABLE_COLUMNS_SQL = """
SELECT
    c.column_name,
    c.data_type,
    c.is_nullable,
    c.column_default,
    (pk.column_name IS NOT NULL) AS is_primary_key
FROM information_schema.columns c
LEFT JOIN (
    SELECT kcu.table_schema, kcu.table_name, kcu.column_name
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON kcu.constraint_schema = tc.constraint_schema
        AND kcu.constraint_name = tc.constraint_name
        AND kcu.table_name = tc.table_name
    WHERE tc.constraint_type = 'PRIMARY KEY'
) pk
    ON pk.table_schema = c.table_schema
    AND pk.table_name = c.table_name
    AND pk.column_name = c.column_name
WHERE c.table_schema = 'public'
    AND c.table_name = %s
ORDER BY c.ordinal_position;
"""


def _format_create_table(table_name, columns):
    """Render one table's column rows as a CREATE TABLE statement for the prompt.

    Args:
        table_name: Name of the table.
        columns: Rows of ``(name, data_type, is_nullable, column_default,
            is_primary_key)`` as returned by _TABLE_COLUMNS_SQL.

    A single-column primary key is tagged inline on its column; a composite
    key is emitted as a table-level ``PRIMARY KEY (a, b)`` clause instead.
    """
    pk_columns = [col[0] for col in columns if col[4]]
    inline_pk = len(pk_columns) == 1
    lines = []
    for name, data_type, nullable, default, is_pk in columns:
        parts = [f" {name} {data_type}"]
        if is_pk and inline_pk:
            parts.append("PRIMARY KEY")
        if nullable == 'NO':
            parts.append("NOT NULL")
        if default:
            parts.append(f"DEFAULT {default}")
        lines.append(" ".join(parts))
    if len(pk_columns) > 1:
        lines.append(f" PRIMARY KEY ({', '.join(pk_columns)})")
    return f"CREATE TABLE {table_name} (\n" + ",\n".join(lines) + "\n);\n\n"


def get_schema():
    """Describe the public schema as CREATE TABLE statements for the LLM prompt.

    Returns:
        str: ``"CREATE TABLE statements:\\n\\n"`` followed by one statement per
        table in the public schema, or None if any database call fails (the
        error is printed).
    """
    connection = None
    try:
        connection, cursor = get_connection()
        cursor.execute(_PUBLIC_TABLES_SQL)
        table_names = [row[0] for row in cursor.fetchall()]
        statements = []
        for table_name in table_names:
            # Parameterized rather than interpolated: names may contain quotes.
            cursor.execute(_TABLE_COLUMNS_SQL, (table_name,))
            statements.append(_format_create_table(table_name, cursor.fetchall()))
        return "CREATE TABLE statements:\n\n" + "".join(statements)
    except Exception as e:
        print(f"Error fetching schema: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; roll back so later queries on it can still run.
        if connection is not None and not connection.closed:
            try:
                connection.rollback()
            except psycopg2.Error:
                pass  # best effort: the original error was already reported
        return None
# First Markdown code fence in the model output, optionally tagged
# sql/postgresql/mysql/...; an unterminated fence runs to the end of the text.
_SQL_FENCE_RE = re.compile(r"```(?:[a-z]*sql\b)?(.*?)(?:```|\Z)", re.DOTALL | re.IGNORECASE)


def _extract_generated_text(response_data):
    """Return the ``generated_text`` of a text-generation API response.

    The hosted text-generation task answers with a list such as
    ``[{"generated_text": "..."}]``; a bare dict with that key is accepted
    too. Any other shape yields "".
    """
    if isinstance(response_data, list):
        response_data = response_data[0] if response_data else {}
    if not isinstance(response_data, dict):
        return ""
    return str(response_data.get("generated_text") or "")


def _strip_code_fence(text):
    """Return *text* stripped, reduced to the body of its first code fence if present."""
    text = text.strip()
    match = _SQL_FENCE_RE.search(text)
    if match:
        text = match.group(1)
    return text.strip()


def get_sql_query(user_query, schema, timeout=60):
    """Ask the Mistral model to translate *user_query* into a SQL query.

    Args:
        user_query: The user's natural-language question.
        schema: Schema text (e.g. from get_schema()) embedded in the prompt.
        timeout: Seconds to wait for the inference API before giving up.

    Returns:
        The generated SQL, terminated with ';', or None if the API reported an
        error, the request failed, or the model returned no SQL (the reason is
        printed).
    """
    try:
        prompt = f"""### PostgreSQL database schema:
{schema}
### Instructions:
- Generate a single, correct SQL query that answers the user's question.
- Include appropriate table aliases, JOIN conditions, and WHERE clauses.
- Provide the SQL query only, with no additional explanations or questions.
### User Query: {user_query}
### SQL Query:"""
        payload = {
            "inputs": prompt,
            # The text-generation task echoes the prompt by default; ask for the
            # completion only so the schema text is not mistaken for SQL.
            "parameters": {"return_full_text": False},
        }
        # Without a timeout a stalled endpoint would block the caller forever.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response_data = response.json()
        # API-level failures come back as a JSON object with an "error" key.
        if isinstance(response_data, dict) and "error" in response_data:
            print(f"Error from API: {response_data['error']}")
            return None
        generated = _extract_generated_text(response_data)
        # Fallback in case the endpoint ignored return_full_text.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        sql_query = _strip_code_fence(generated)
        if not sql_query:
            print("Error getting SQL query: model returned no SQL")
            return None
        if not sql_query.endswith(';'):
            sql_query += ';'
        return sql_query
    except Exception as e:
        print(f"Error getting SQL query: {e}")
        return None
def execute_query(sql_query):
    """Run *sql_query* on the shared connection and return its result.

    Args:
        sql_query: SQL text to execute. It is usually model-generated, so treat
            it as untrusted. NOTE(review): consider a read-only DB role or
            transaction before executing it.

    Returns:
        - For statements that produce a result set (SELECT, WITH ... SELECT,
          ... RETURNING, etc.): a list of dicts mapping column name to value,
          with Decimal values converted to float.
        - For other statements: ``{"message": "Query executed successfully"}``.
        - None on any error; the error is printed and the transaction is
          rolled back.
    """
    connection = None
    try:
        connection, cursor = get_connection()
        cursor.execute(sql_query)
        # cursor.description is None exactly when the statement produced no
        # result set. This is more reliable than checking for a leading
        # "SELECT", which misses CTEs, parenthesized queries, leading
        # comments and RETURNING clauses.
        if cursor.description is None:
            # NOTE(review): nothing here commits, and psycopg2 connections are
            # not autocommit by default, so data changes stay in an open
            # transaction. Confirm this is intended.
            return {"message": "Query executed successfully"}
        columns = [desc[0] for desc in cursor.description]
        return [
            {
                column: float(value) if isinstance(value, Decimal) else value
                for column, value in zip(columns, row)
            }
            for row in cursor.fetchall()
        ]
    except Exception as e:
        print(f"Error executing query: {e}")
        # Keep the shared connection usable: after an error PostgreSQL rejects
        # every further statement until the transaction is rolled back.
        if connection is not None and not connection.closed:
            try:
                connection.rollback()
            except psycopg2.Error:
                pass  # best effort: the original error was already reported
        return None