-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpurchase_behavior_app.py
More file actions
171 lines (147 loc) · 7.4 KB
/
purchase_behavior_app.py
File metadata and controls
171 lines (147 loc) · 7.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import streamlit as st
import pandas as pd
import os
import subprocess
import sys
import traceback
from dotenv import load_dotenv
from sql_agent import SQLAgent
from llm_client import LLMClient
# Load environment variables
load_dotenv()

# Page configuration
st.set_page_config(
    page_title="SQL Agent - User Purchase Behavior",
    layout="wide",
)

# Sidebar heading plus a short description of what the app does
st.sidebar.title("SQL Agent")
sidebar_blurb = (
    "This application uses an LLM to analyze user purchase behavior "
    "and predict future purchases based on natural language queries."
)
st.sidebar.info(sidebar_blurb)

# Database file path
DB_PATH = "sales_database.db"
# Check if database exists, create it if it doesn't.
# The setup scripts are run with the same interpreter as this app, in the
# same order as before. If any of them fails, whatever file ended up at
# DB_PATH is removed: the file did not exist before this block ran, so it can
# only be a partial build, and leaving it behind would make the existence
# check above skip setup on every later run.
if not os.path.exists(DB_PATH):
    st.sidebar.warning("Database not found. Setting up the database...")
    try:
        for setup_script in ("setup_database.py", "add_user_purchase_data.py"):
            subprocess.run([sys.executable, setup_script], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: a script exited non-zero.
        # OSError: the interpreter itself could not be launched.
        st.sidebar.error(f"Error setting up database: {str(e)}")
        if os.path.exists(DB_PATH):
            try:
                os.remove(DB_PATH)
            except OSError as cleanup_error:
                # Best effort only; tell the user so they can delete it by hand.
                st.sidebar.warning(
                    f"Could not remove incomplete database file: {cleanup_error}"
                )
    else:
        st.sidebar.success("Database setup complete!")
# LLM connection settings, editable from the sidebar
llm_url = st.sidebar.text_input(
    "LLM API URL",
    value="http://127.0.0.1:1234",
)
st.sidebar.caption("Make sure your LLM Studio is running at this URL")

# Upper bound on SQL improvement attempts (handed to SQLAgent as max_retries)
max_retries = st.sidebar.slider(
    "Max SQL improvement attempts",
    min_value=1,
    max_value=5,
    value=3,
)
# Check if the LLM endpoint is reachable and whether the expected model is
# listed by GET <llm_url>/v1/models. The outcome is only reported in the
# sidebar; nothing here stops the script.
import requests  # NOTE(review): mid-file import kept in place; consider moving to the top

EXPECTED_MODEL = "deepseek-r1-distill-qwen-14b"

try:
    # Strip a trailing slash so a URL typed as "http://host:1234/" does not
    # turn into "http://host:1234//v1/models".
    response = requests.get(f"{llm_url.rstrip('/')}/v1/models", timeout=2)
except requests.exceptions.RequestException:
    st.sidebar.error("LLM endpoint is not reachable. Make sure LLM Studio is running.")
else:
    if response.status_code != 200:
        st.sidebar.error(f"LLM endpoint returned status code {response.status_code}")
    else:
        # The server answered; a malformed body is a different problem from
        # "not reachable", so it is parsed and reported separately.
        try:
            payload = response.json()
        except ValueError:
            payload = None
        models = payload.get("data", []) if isinstance(payload, dict) else None
        if not isinstance(models, list):
            st.sidebar.error("LLM endpoint returned an unexpected response from /v1/models")
        else:
            # Skip entries without an id: joining a None would raise TypeError.
            model_names = [
                str(model["id"])
                for model in models
                if isinstance(model, dict) and model.get("id") is not None
            ]
            if EXPECTED_MODEL in model_names:
                st.sidebar.success(f" {EXPECTED_MODEL} model is available")
            else:
                st.sidebar.warning(
                    f" {EXPECTED_MODEL} model not found. Available models: "
                    + ", ".join(model_names)
                )
# Build the LLM client and the SQL agent. The rest of the page needs both, so
# any construction failure is shown to the user and the script run is halted.
try:
    llm_client = LLMClient(base_url=llm_url)
    sql_agent = SQLAgent(DB_PATH, llm_client, max_retries=max_retries)
except Exception as init_error:
    st.error("Error initializing SQL Agent: " + str(init_error))
    st.stop()
# Main content
st.title("User Purchase Behavior Analysis")
st.subheader("Ask questions about user purchase behavior in natural language")

# Example queries: one button per example. Clicking a button stores its text
# under the "query" session key, which the text area below is bound to.
st.caption("Example queries:")
examples = [
    "Show me customers who have purchased products but not services",
    "List users with high likelihood of purchasing services",
    "Find customers who haven't purchased anything but have high purchase likelihood",
    "Show me the top 5 customers most likely to purchase both products and services",
    "Which customers had their last interaction in the past 30 days and have a high purchase likelihood?",
]
for example_text in examples:
    clicked = st.button(example_text)
    if clicked:
        st.session_state["query"] = example_text

# Input box for the query
query = st.text_area("Enter your query:", height=100, key="query")
# Likelihood columns that get charted when present in a result set, in the
# order they appear in the chart.
LIKELIHOOD_COLUMNS = ("purchase_likelihood", "service_purchase_likelihood")
# Columns tried, in this order, as labels for the chart's index.
LABEL_COLUMNS = ("customer_name", "name")


def _show_likelihood_chart(df):
    """Draw a bar chart of whichever likelihood columns ``df`` contains.

    Does nothing when none of ``LIKELIHOOD_COLUMNS`` is present. When a
    ``customer_name`` column (preferred) or a ``name`` column exists, its
    values become the chart's index.
    """
    present = [col for col in LIKELIHOOD_COLUMNS if col in df.columns]
    if not present:
        return
    st.subheader("Visualization:")
    chart_data = df[present]
    label_col = next((col for col in LABEL_COLUMNS if col in df.columns), None)
    if label_col is not None:
        chart_data.index = df[label_col]
    st.bar_chart(chart_data)


def _show_query_result(query_result):
    """Render a query result: table, row count, CSV download and optional chart.

    ``query_result`` is read through its ``row_count``, ``rows`` and
    ``columns`` attributes only.
    """
    st.subheader("Results:")
    if query_result.row_count == 0:
        st.info("Query returned no results.")
        return

    # Convert to DataFrame for better display
    df = pd.DataFrame(query_result.rows, columns=query_result.columns)
    st.dataframe(df)
    st.caption(f"Returned {query_result.row_count} rows")

    # Option to download results as CSV
    # NOTE(review): this button lives inside the "Run Query" branch; if clicking
    # it reruns the script, the results above would disappear - confirm against
    # the Streamlit version in use.
    st.download_button(
        label="Download results as CSV",
        data=df.to_csv(index=False),
        file_name="query_results.csv",
        mime="text/csv",
    )

    # Visualize data if it contains purchase likelihood
    _show_likelihood_chart(df)


# Submit button
if st.button("Run Query"):
    # A query made only of whitespace is treated the same as an empty one.
    if not query or not query.strip():
        st.warning("Please enter a query.")
    else:
        with st.spinner("Processing query..."):
            try:
                # Process the query
                response = sql_agent.process_query(query)

                # Display the generated SQL
                st.subheader("Generated SQL:")
                st.code(response.generated_sql, language="sql")

                # Display the explanation
                st.subheader("Explanation:")
                st.write(response.explanation)

                # Display improvement history if available
                if response.improvement_history:
                    with st.expander("SQL Improvement Attempts", expanded=True):
                        for attempt in response.improvement_history:
                            st.markdown(f"**Attempt {attempt.attempt}**")
                            st.code(attempt.sql, language="sql")
                            st.error(f"Error: {attempt.error}")
                            st.markdown("---")

                # Display the results or error
                if response.error:
                    st.error(f"Error executing SQL: {response.error}")
                elif response.query_result:
                    _show_query_result(response.query_result)
            except Exception as e:
                # Top-level UI boundary: show the failure instead of crashing the page.
                st.error(f"Error processing query: {str(e)}")
                st.error(traceback.format_exc())
# Collapsible panel showing the schema text exposed by the agent
schema_panel = st.expander("View Database Schema")
schema_panel.code(sql_agent.schema_info)
# Footer: a divider followed by a credit line at the bottom of the sidebar
footer_credit = "SQL Agent powered by LLM Studio using deepseek-r1-distill-qwen-14b"
st.sidebar.markdown("---")
st.sidebar.caption(footer_credit)