-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmerchant.py
More file actions
266 lines (215 loc) · 8.99 KB
/
merchant.py
File metadata and controls
266 lines (215 loc) · 8.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
#!/usr/bin/env python3
"""
merchant_cli_option_c.py
Option C (strict): LLM only generates SQL, Python executes SQL, LLM formats answer from SQL results.
Includes forecasting flow.
"""
import os
import re
import sys
import json
import psycopg2
import psycopg2.extras
from dotenv import load_dotenv
from typing import Any, List, Dict
# ---- Replace / adapt this import to your LLM client if necessary ----
from langchain_deepseek import ChatDeepSeek # used earlier in your project
# --------------------------------------------------------------------
load_dotenv()
# Connection settings come from the environment (load_dotenv above can
# populate it from a .env file).
DATABASE_URL = os.getenv("DATABASE_URL")
# NOTE(review): not validated here; presumably LLM calls fail later if it is
# unset — confirm whether an early check is wanted.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not DATABASE_URL:
    # RuntimeError is more specific than a bare Exception but still an
    # Exception subclass, so existing `except Exception` handlers keep working.
    raise RuntimeError("Please set DATABASE_URL in your .env file")
# --------- Initialize LLM (adapt parameters as needed) ----------
# Single module-level chat client used by llm_invoke for both SQL generation
# and answer formatting. The DeepSeek LangChain wrapper is pointed at
# OpenRouter's API base, so `model` is an OpenRouter model id.
model = ChatDeepSeek(
    model="x-ai/grok-4.1-fast:free",
    api_key=OPENROUTER_API_KEY,
    api_base="https://openrouter.ai/api/v1",
    temperature=0.0,  # keep SQL generation as repeatable as the provider allows
    # Provider-specific extras sent in the request body (reasoning toggle,
    # live-search limits) — TODO confirm they are honored through OpenRouter.
    extra_body={
        "reasoning": {"enabled": True},
        "search_parameters": {"mode": "auto", "max_search_results": 2},
    },
)
# ---------------------------------------------------------------
# ---------- DB helpers (psycopg2) -----------
def get_db_conn():
    """Open a brand-new psycopg2 connection to DATABASE_URL.

    Nothing is pooled or cached: the caller owns the returned connection and
    is responsible for closing it.
    """
    connection = psycopg2.connect(DATABASE_URL)
    return connection
def set_rls_merchant(conn, merchant_id: int):
    """Scope the session to one merchant for row-level security.

    Writes merchant_id into the `app.current_merchant_id` setting (the value
    the generated SQL filters on) and commits, so the SET outlives the
    transaction psycopg2 opened implicitly to run it.
    """
    statement = "SET app.current_merchant_id = %s;"
    cursor = conn.cursor()
    try:
        cursor.execute(statement, (merchant_id,))
    finally:
        cursor.close()
    conn.commit()
def execute_sql(conn, sql: str) -> List[Dict[str, Any]]:
    """
    Execute SQL and return list of rows as dicts.
    WARNING: This executes arbitrary SQL generated by the LLM. We rely on
    query-generation rules.

    The query runs in its own transaction, which is marked READ ONLY and is
    always rolled back afterwards. As a result:
      * a data-modifying statement that slipped past the caller's keyword
        filter is rejected by PostgreSQL instead of writing data;
      * settings changed during the query (e.g. through set_config) are
        discarded together with the transaction;
      * a failed query does not leave the connection stuck in an aborted
        transaction, which would otherwise make every later query on the
        same connection fail (interactive_mode reuses one connection).
    Settings committed earlier, such as the one written by set_rls_merchant,
    are not affected by the rollback.

    Args:
        conn: open psycopg2 connection. The READ ONLY guard assumes it is not
            in autocommit mode (psycopg2's default for a new connection). In
            autocommit mode PostgreSQL ignores SET TRANSACTION with a warning.
        sql: the query text to run.

    Returns:
        One dict-like RealDictRow per result row, or [] when the statement
        produces no result set.

    Raises:
        Whatever psycopg2 raises for the query (psycopg2.Error subclasses).
    """
    try:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            # Switching a transaction to read-only is allowed at any point,
            # and this is normally its first statement anyway.
            cur.execute("SET TRANSACTION READ ONLY;")
            cur.execute(sql)
            # description is None when the statement returned no result set.
            if cur.description is None:
                return []
            return cur.fetchall()
    finally:
        # Always end the transaction; skip if the connection is already
        # closed/broken so the original error is not masked.
        if not conn.closed:
            conn.rollback()
# ---------- Utility ----------
FORECAST_KEYWORDS = [
    "forecast", "predict", "prediction", "will", "trend", "project", "forecasting",
    "increase", "decrease", "next month", "next week", "next year", "future"
]


def is_forecast_question(text: str) -> bool:
    """Heuristically decide whether *text* asks for a forecast.

    Lower-cases the text and looks for any FORECAST_KEYWORDS entry as a plain
    substring. Keywords therefore also hit inside longer words
    (e.g. "will" inside "William").
    """
    lowered = text.lower()
    for keyword in FORECAST_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def sanitize_merchant_id(value: str) -> int:
    """
    Parse a user-supplied merchant id.

    Args:
        value: text to parse. Anything int() accepts is allowed (surrounding
            whitespace, a leading "+", digit-group underscores).

    Returns:
        The id as a non-negative int.

    Raises:
        ValueError: "merchant_id must be a non-negative integer" when value is
            not an integer or is negative. The low-level parse error is not
            chained, so tracebacks show only this message.
    """
    error_message = "merchant_id must be a non-negative integer"
    try:
        mid = int(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(error_message) from None
    # Range check lives outside the try so it is not caught and re-raised.
    if mid < 0:
        raise ValueError(error_message)
    return mid
# ---------- Prompt templates ----------
# Both templates are filled with str.format() in handle_user_question, so any
# literal "{" or "}" added to their text must be doubled ("{{" / "}}").
#
# SQL_GEN_PROMPT: sent to the LLM to obtain the SQL text. Placeholder:
#   {question} - the user's question, inserted verbatim.
# The filtering example refers to the session setting that set_rls_merchant
# writes (app.current_merchant_id).
SQL_GEN_PROMPT = """
You are a strict SQL generator. DO NOT output any natural language—ONLY output a single valid SQL SELECT query, and nothing else.
Rules:
- Use the `orders` table and any other merchant tables only if necessary.
- ALWAYS filter by merchant_id using the RLS session variable (do NOT hardcode merchant_id in the SQL).
*Example of correct filtering: WHERE merchant_id = current_setting('app.current_merchant_id')::int
- Do not attempt to show schema or column names unless the user explicitly asked for that exact column.
- If the user question asks for aggregates (counts, sums, averages), produce the appropriate SQL aggregate.
- For date ranges, use PostgreSQL functions (e.g., DATE_TRUNC).
- Limit results to at most 1000 rows.
- If the user's question is ambiguous about time ranges, generate a conservative SQL (e.g., last 30 days) not centuries of data.
- Do NOT include multiple statements; output exactly one SELECT query ending with a semicolon.
User question:
{question}
If this is a forecasting request, still generate SQL that returns the historical data needed for forecasting (e.g., daily totals for the last N months).
"""
# FORMATTER_PROMPT: sent to the LLM to turn query rows into a user-facing
# answer. Placeholders:
#   {results_json} - the rows serialized with json.dumps(..., default=str);
#   {question}     - the original user question.
# NOTE(review): the "No data available in your database." wording is also a
# literal in handle_user_question; keep the two in sync.
FORMATTER_PROMPT = """
You are an assistant that converts SQL execution results into a user-friendly answer.
You MUST use ONLY the SQL results provided below — no external knowledge, no hallucination.
If the SQL result set is empty, reply exactly:
"No data available in your database."
If this was a forecasting request, produce a short forecast (1-3 sentences) based on the historical numbers provided and describe assumptions briefly.
Do NOT mention RLS, schema, or how the data was fetched.
SQL Results (JSON):
{results_json}
Original user question:
{question}
Provide a concise, factual answer in plain language.
"""
# ---------- LLM call wrappers (adjust parsing per SDK) ----------
def _content_to_text(content: Any) -> str:
    """Flatten a chat-message ``content`` (str or list of parts) to plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # LangChain content blocks: keep plain strings and {"type": "text"}
        # parts only, so reasoning or other non-text parts do not leak into
        # the SQL or the user-facing answer.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)
    return str(content)


def llm_invoke(prompt: str, llm: Any = None) -> str:
    """
    Call the LLM and return its textual output, stripped of surrounding
    whitespace ("" when the reply carries no text).

    Args:
        prompt: full prompt text.
        llm: optional client exposing ``invoke(prompt)``. Defaults to the
            module-level ``model``. It lets callers use a different model for
            a step, or a stub in tests.

    Reply shapes handled:
      * an object with ``.content`` as a string or as a list of content parts;
      * an object with ``.text`` as an attribute or a zero-argument method;
      * anything else, via ``str(resp)``.
    """
    client = model if llm is None else llm
    resp = client.invoke(prompt)
    if hasattr(resp, "content"):
        text = _content_to_text(resp.content)
    elif hasattr(resp, "text"):
        # Some SDK versions expose .text as a method rather than a property.
        text = resp.text() if callable(resp.text) else resp.text
    else:
        # fallback
        text = str(resp)
    if text is None:
        return ""
    return str(text).strip()
def extract_sql_only(llm_output: str) -> str:
    """
    Extract one SELECT statement from raw LLM output.

    Strategies, tried in order:
      1. the body of the first ``` / ```sql fenced block that starts with SELECT;
      2. the first whole word SELECT anywhere, up to and including the next ';';
      3. the whole output, when it starts with SELECT but contains no ';'.
    Matching is case-insensitive, but SELECT must be a whole word, so prose
    such as "selected" or "selection" is not taken as the start of the query.

    Returns:
        The query text, ending with ';'.

    Raises:
        ValueError: if no SELECT is found, or if a fenced block holds more
            than one statement (a ';' before its end). Such text is rejected
            because the driver would run every statement, e.g. a trailing
            SET that changes app.current_merchant_id.
    """
    # Try to find text between triple backticks or code fences
    fenced = re.search(r"```(?:sql)?\s*(\bSELECT\b[\s\S]*?);?\s*```", llm_output, re.IGNORECASE)
    if fenced:
        body = fenced.group(1).strip()
        if ";" in body:
            raise ValueError("LLM returned more than one SQL statement. Got:\n" + llm_output[:1000])
        return body + ";"
    # Otherwise take the first SELECT ...; (only the first statement).
    first_stmt = re.search(r"(\bSELECT\b[\s\S]*?;)", llm_output, re.IGNORECASE)
    if first_stmt:
        return first_stmt.group(1).strip()
    # No ';' anywhere, but the whole output looks like SQL starting with SELECT.
    stripped = llm_output.strip()
    if stripped.upper().startswith("SELECT"):
        return stripped if stripped.endswith(";") else stripped + ";"
    raise ValueError("LLM did not return a valid SELECT query. Got:\n" + llm_output[:1000])
# ---------- Main logic ----------
def handle_user_question(conn, question: str):
    """
    Answer one natural-language question with the strict SQL pipeline.

    Steps: (1) the LLM writes SQL from SQL_GEN_PROMPT, (2) Python runs it via
    execute_sql, (3) the LLM phrases the rows via FORMATTER_PROMPT.

    Args:
        conn: open psycopg2 connection. Callers in this file call
            set_rls_merchant on it first so queries are merchant-scoped.
        question: the user's question, inserted verbatim into both prompts.

    Returns:
        The answer text; the exact phrase "No data available in your
        database." when the query returns no rows; or a short error/refusal
        message. Extraction and SQL-execution failures come back as messages.
        Exceptions raised by the LLM calls propagate to the caller.
    """
    # Step 1: Force LLM to generate SQL only
    sql_prompt = SQL_GEN_PROMPT.format(question=question)
    llm_sql_out = llm_invoke(sql_prompt)
    try:
        sql = extract_sql_only(llm_sql_out)
    except ValueError:
        return "Error: LLM did not return a valid SQL query. Please rephrase your question."

    # Extra safety — the SQL is untrusted model output:
    # only one statement may run. Anything after a ';' could, for example,
    # SET app.current_merchant_id and escape the merchant scope.
    if ";" in sql.rstrip().rstrip(";"):
        return "Refusing to run multiple SQL statements."
    # Refuse write/DDL keywords, SELECT ... INTO (which creates a table) and
    # set_config() (which can change the merchant scope from inside a SELECT).
    # Whole-word matching lets names like created_at or last_update through.
    if re.search(
        r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE|ALTER|CREATE|INTO|SET_CONFIG)\b",
        sql,
        re.IGNORECASE,
    ):
        return "Refusing to run non-SELECT statements."

    # Step 2: Execute SQL
    try:
        rows = execute_sql(conn, sql)
    except Exception as e:
        return f"SQL execution error: {e}"

    # The formatter prompt mandates this exact phrase for empty results, so
    # return it directly instead of paying for an LLM call whose output would
    # be discarded.
    if not rows:
        return "No data available in your database."

    # Step 3: Format results via LLM (only to beautify results / forecast)
    results_json = json.dumps(rows, default=str)
    formatter_prompt = FORMATTER_PROMPT.format(results_json=results_json, question=question)
    return llm_invoke(formatter_prompt)
def interactive_mode():
    """
    Run the interactive CLI.

    Asks for a merchant id once, scopes a single database connection to it
    with set_rls_merchant, then answers questions in a loop on that
    connection. Typing 'exit' or 'quit', or pressing Ctrl-D (EOF) or Ctrl-C
    at a prompt, ends the session without a traceback. The connection is
    always closed.
    """
    print("Merchant Intelligence Agent CLI — Option C (strict SQL pipeline)")
    print("--------------------------------------------------------------")
    try:
        merchant_input = input("Enter Merchant ID: ").strip()
    except (EOFError, KeyboardInterrupt):
        # No id entered; leave quietly.
        print()
        return
    try:
        merchant_id = sanitize_merchant_id(merchant_input)
    except ValueError as e:
        print("Invalid merchant id:", e)
        return
    # Open a DB connection for the session
    try:
        conn = get_db_conn()
    except Exception as e:
        print("Failed to connect to the database:", e)
        return
    try:
        try:
            set_rls_merchant(conn, merchant_id)
        except Exception as e:
            print("Failed to set merchant scope:", e)
            return
        print(f"Merchant scope set to {merchant_id}. Type 'exit' to quit.\n")
        while True:
            try:
                q = input("> ").strip()
            except (EOFError, KeyboardInterrupt):
                print()
                break
            if not q:
                continue
            if q.lower() in ("exit", "quit"):
                break
            # Forecasting needs no special routing here: SQL_GEN_PROMPT already
            # asks for historical series when the question is a forecast.
            try:
                answer = handle_user_question(conn, q)
            except Exception as e:
                # e.g. an LLM/network failure: report it and keep the session alive.
                answer = f"Error: {e}"
            print("\n" + answer + "\n")
    finally:
        conn.close()
if __name__ == "__main__":
    # Two entry modes:
    #   <script> MERCHANT_ID QUESTION...   answer one question, then exit
    #   <script>                           interactive prompt (also chosen
    #                                      when only one argument is given)
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        interactive_mode()
    else:
        raw_merchant, *question_words = cli_args
        try:
            merchant_id = sanitize_merchant_id(raw_merchant)
        except ValueError as err:
            print("Invalid merchant id:", err)
            sys.exit(1)
        conn = get_db_conn()
        try:
            set_rls_merchant(conn, merchant_id)
            print(handle_user_question(conn, " ".join(question_words)))
        finally:
            conn.close()