-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
81 lines (64 loc) · 2.32 KB
/
main.py
File metadata and controls
81 lines (64 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Main entry point for the World Bank SQL Agent application
"""
import os
from config import Config
from utils import setup_nltk, load_indicator_data
from few_shot_selector import FewShotSelector
from indicator_search import IndicatorSearch
from agent import WorldBankAgent
def _augment_query(user_query: str, indicator_ids) -> str:
    """Return ``user_query`` with a hint listing candidate indicator IDs appended.

    When ``indicator_ids`` is falsy (nothing found), the query is returned
    unchanged. The hint text is part of the prompt sent to the agent, so it
    must stay byte-for-byte stable.
    """
    if not indicator_ids:
        return user_query
    return (
        user_query
        + " . **You should use relevant indicator_id's from these based on their description** "
        # Doubled braces emit literal "{" / "}" around the interpolated IDs.
        + f"{{ indicator_id = {indicator_ids} }}"
    )


def _print_results(user_query: str, summary) -> None:
    """Print the original query and the summary, framed by 80-char '=' rules."""
    rule = "=" * 80
    print("\n" + rule)
    print(f"User Query: {user_query}")
    print(rule)
    print(f"\nResponse Summary:\n{summary}")
    print(rule)


def main(user_query: str):
    """
    Main function to process user queries.

    Runs the full pipeline for one question:
    1. Validate the configuration and set up NLTK.
    2. Load the indicator data from ``Config.DATABASE_PATH``.
    3. Select few-shot examples and search for candidate indicator IDs.
    4. Build the SQL agent and run the (possibly augmented) query.
    5. Summarise the response.

    Progress messages and the final summary are printed to stdout.

    Args:
        user_query: Natural language query from user.

    Returns:
        The value returned by ``agent.generate_summary(user_query, response)``.

    Raises:
        ValueError: If ``user_query`` is empty or contains only whitespace.
            This is checked before any setup work, so no models or data are
            loaded for a query that cannot be answered.
    """
    if not user_query or not user_query.strip():
        raise ValueError("user_query must be a non-empty string")

    # Validate configuration
    Config.validate()

    # Setup NLTK
    setup_nltk()

    # Load indicator data
    print("Loading indicator data...")
    indicator_df = load_indicator_data(Config.DATABASE_PATH)

    # Initialize components
    print("Initializing components...")
    few_shot_selector = FewShotSelector()
    indicator_search = IndicatorSearch(indicator_df)
    agent = WorldBankAgent()

    # Select relevant few-shot examples
    print("Selecting few-shot examples...")
    selected_examples = few_shot_selector.select_examples(user_query)
    formatted_examples = FewShotSelector.format_examples(selected_examples)

    # Search for relevant indicators
    print("Searching for relevant indicators...")
    indicator_ids = indicator_search.search(user_query, top_n=Config.TOP_N_INDICATORS)

    # Augment the query with indicator IDs (only if any were found)
    augmented_query = _augment_query(user_query, indicator_ids)

    # Create agent with few-shot examples
    print("Creating SQL agent...")
    agent.create_agent(formatted_examples)

    # Execute query
    print("Executing query...")
    response = agent.query_with_tokens(augmented_query)

    # The summary is generated against the user's original wording, not the
    # augmented prompt.
    print("Generating summary...")
    summary = agent.generate_summary(user_query, response)

    _print_results(user_query, summary)
    return summary
if __name__ == "__main__":
    # Demo invocation: runs the whole pipeline on one sample question.
    # Swap this string out to try a different query.
    sample_query = (
        "Are citizens in developed countries more aware of their rights under "
        "financial consumer protection laws compared to those in developing nations?"
    )
    main(user_query=sample_query)