-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
86 lines (66 loc) · 2.77 KB
/
main.py
File metadata and controls
86 lines (66 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import re
import sys

from db_manager import DBManager
from llm_client import LLMClient
# Keywords accepted as the first word of a read-only query.
_READ_ONLY_LEADING_KEYWORDS = frozenset({"SELECT", "WITH"})

# Keywords suggesting a statement may modify data or schema. Matching is on
# whole \w+ tokens, so identifiers such as ``created_at`` or ``updated`` do not
# trip it. A string literal containing one of these words will trip it; that
# false positive is accepted because this is a safety heuristic, not a parser.
_WRITE_KEYWORDS = frozenset({
    "INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE",
    "MERGE", "GRANT", "REVOKE", "ATTACH", "DETACH",
})


def _looks_like_read_only_select(sql_query):
    """Return True if *sql_query* plausibly is a single read-only SELECT.

    This is a conservative heuristic that keeps obviously unsafe generated
    SQL off the ``--execute`` path. It is not a SQL parser and not a security
    boundary; a read-only database role is the real control. It requires:

    * exactly one statement (trailing ``;`` characters are allowed),
    * a first word of ``SELECT`` or ``WITH``, and
    * no data- or schema-modifying keyword anywhere in the text (this also
      catches data-modifying CTEs such as ``WITH x AS (DELETE ...)``).
    """
    statement = sql_query.strip().rstrip(";")
    if ";" in statement:
        # Multiple statements, e.g. "SELECT 1; DROP TABLE users".
        return False
    words = re.findall(r"\w+", statement.upper())
    if not words or words[0] not in _READ_ONLY_LEADING_KEYWORDS:
        return False
    return _WRITE_KEYWORDS.isdisjoint(words)


def _print_results(results):
    """Print the value returned by ``db.execute_query``.

    Handles the three shapes this script reads: a dict with an ``"error"``
    key, a dict with a ``"message"`` key, or a dict with ``"columns"`` and
    ``"data"`` (a sequence of rows). NOTE(review): these shapes are taken from
    how this script reads the result; confirm against db_manager.
    """
    if "error" in results:
        print(f"Database Error: {results['error']}")
        return
    if "message" in results:
        print(results['message'])
        return
    columns = results['columns']
    rows = results['data']
    # str() each name so non-string column labels don't make join() raise.
    header = " | ".join(map(str, columns))
    print(f"\nResults ({len(rows)} rows):")
    print(header)
    print("-" * len(header))
    for row in rows:
        print(" | ".join(map(str, row)))


def _handle_question(db, llm, user_input, execute):
    """Generate SQL for one question and run it if *execute* is set and it looks safe."""
    print("\nAnalyzing schema and generating SQL...")
    try:
        # 1. Get schema
        schema = db.get_schema_context()
        # 2. Generate SQL
        sql_query = llm.generate_sql(schema, user_input)
        print(f"Generated SQL: {sql_query}")
        # Heuristic check before anything can reach the database.
        if not _looks_like_read_only_select(sql_query):
            print("Generated text does not look like a SELECT query.")
            if execute:
                print("Skipping execution due to safety check.")
                return
        # 3. Execute SQL (only if the flag is set)
        if execute:
            print("Executing query...")
            _print_results(db.execute_query(sql_query))
        else:
            print("[Dry Run] Query not executed.")
        print("\n" + "="*30 + "\n")
    except Exception as e:
        # REPL boundary: report the failure and return to the prompt.
        print(f"An error occurred: {e}")


def main():
    """Run the interactive NaturalQuery text-to-SQL prompt.

    Parses ``--execute`` from the command line and constructs a ``DBManager``
    and an ``LLMClient``. It then loops:

    1. Read a question.
    2. Fetch the schema context.
    3. Ask the LLM for SQL and print it.
    4. Only with ``--execute``, and only if the SQL passes the read-only
       check, run it and print the results.

    Typing ``exit``/``quit``, or sending EOF / Ctrl-C at the prompt, ends the
    session. The database handle is closed on every exit path after a
    successful initialization.

    Raises:
        SystemExit: on invalid arguments (from argparse), or with status 1
            if the database or LLM client cannot be initialized.
    """
    parser = argparse.ArgumentParser(description="NaturalQuery - Text to SQL")
    parser.add_argument("--execute", action="store_true", help="Execute the generated SQL query against the database.")
    args = parser.parse_args()

    print("Initializing NaturalQuery System...")
    db = None
    try:
        db = DBManager()
        llm = LLMClient()
    except Exception as e:
        print(f"Initialization Error: {e}", file=sys.stderr)
        if db is not None:
            # The DB came up but the LLM client failed: don't leak the connection.
            db.close()
        # Non-zero status so shell scripts can detect the failure.
        sys.exit(1)

    try:
        print("Connected to database successfully.")
        if not args.execute:
            print("Running in DRY-RUN mode. Queries will NOT be executed. Use --execute to run them.")
        else:
            print("WARNING: Execution mode ENABLED. Queries will be executed against the database.")
        print("Type 'exit' to quit.\n")

        while True:
            try:
                user_input = input("Ask a question: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt: exit cleanly instead of a traceback.
                print()
                break
            if user_input.lower() in ('exit', 'quit'):
                break
            if not user_input:
                continue
            _handle_question(db, llm, user_input, args.execute)
    finally:
        # Runs even if an uncaught exception (e.g. Ctrl-C mid-query) escapes.
        db.close()
    print("Goodbye.")
if __name__ == "__main__":
    # Entry point: start the interactive prompt only when run as a script,
    # never as a side effect of importing this module.
    main()