-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexample.py
More file actions
96 lines (85 loc) · 3.36 KB
/
example.py
File metadata and controls
96 lines (85 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import logging
import os
import sys
import uuid
import pandas as pd
from camel.embeddings import OpenAIEmbedding
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType
from colorama import Fore
from tabulate import tabulate
from camel_database_agent import DatabaseAgent
from camel_database_agent.database.manager import DatabaseManager
from camel_database_agent.database_base import TrainLevel
# Route log output to stdout using a bare "%(message)s" format. The root
# logger itself only lets ERROR and above through; ``force=True`` removes and
# closes any handlers already attached to the root logger, so this
# configuration wins even if something configured logging earlier.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format="%(message)s",
    level=logging.ERROR,
    force=True,
)

# The agent package may log at INFO. Its records propagate up to the stdout
# handler above, which has no level filter of its own, so they are shown.
logging.getLogger("camel_database_agent").setLevel(logging.INFO)

# Logger for this script's own messages.
logger = logging.getLogger(__name__)
# Print DataFrames in full: no cap on the number of rows or columns, never
# truncate the text inside a cell, and let pandas auto-detect the width.
pd.set_option(
    "display.max_rows", None,
    "display.max_columns", None,
    "display.width", None,
    "display.max_colwidth", None,
)
# SQLAlchemy-style connection string for the sample SQLite database.
# NOTE(review): this is a relative path, so it is presumably resolved against
# the current working directory; run the script from the directory that
# contains ``database/`` (confirm against DatabaseManager).
database_url = "sqlite:///database/sqlite/music.sqlite"

# The chat model and the embedding model are pointed at the same
# OpenAI-compatible endpoint, so read the credentials once and share them.
_openai_settings = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "url": os.getenv("OPENAI_API_BASE_URL"),
}

# Assemble the agent from a database manager, an LLM (GPT-4o mini on the
# OpenAI platform) and an OpenAI embedding model.
database_agent = DatabaseAgent(
    interactive_mode=True,
    database_manager=DatabaseManager(db_url=database_url),
    model=ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=ModelType.GPT_4O_MINI,
        **_openai_settings,
    ),
    embedding_model=OpenAIEmbedding(**_openai_settings),
)
# Build the agent's knowledge of the database before answering questions.
#
#   reset_train=False -> reuse previously cached knowledge when it exists;
#                        passing True would force every insight and example
#                        to be regenerated from scratch.
#   level=MEDIUM      -> a middle ground between training time and depth:
#                        analyzes schema relationships, extracts
#                        representative sample data and generates a moderate
#                        number of query examples.
database_agent.train_knowledge(
    reset_train=False,
    level=TrainLevel.MEDIUM,
)
# Print two headed sections: the database summary, then the recommended
# example questions. Green is switched on at the title and reset only after
# the body, so the underline and the body text are green as well.
for _title, _fetch in (
    ("Database Overview", database_agent.get_summary),
    ("Recommendation Question", database_agent.get_recommendation_question),
):
    print(f"{Fore.GREEN}{_title}")
    print("=" * 50)
    print(f"{_fetch()}\n\n{Fore.RESET}")
# Ask one natural-language question in a fresh session (random UUID id).
response = database_agent.ask(
    session_id=str(uuid.uuid4()),
    question="List all playlists with more than 5 tracks",
)

# Report the outcome. On failure, print the error in red. On success, print
# the result table (or a "no results" note) in green, then the generated SQL
# in yellow.
if not response.success:
    print(f"{Fore.RED}+ {response.error}{Fore.RESET}")
else:
    if response.dataset is None:
        print(f"{Fore.GREEN}No results found.{Fore.RESET}")
    else:
        # headers='keys' takes the header row from the dataset's keys/column
        # names; the 'psql' format draws a PostgreSQL-style bordered table.
        table = tabulate(tabular_data=response.dataset, headers='keys', tablefmt='psql')
        print(f"{Fore.GREEN}{table}{Fore.RESET}")
    print(f"{Fore.YELLOW}{response.sql}{Fore.RESET}")