-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpostgresql_server.py
More file actions
211 lines (184 loc) · 9.29 KB
/
postgresql_server.py
File metadata and controls
211 lines (184 loc) · 9.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import json
import logging
import os
import re
import sys
import traceback

import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor

# Import necessary components from the mcp SDK
from mcp.server.fastmcp import FastMCP
# Configure logging. Everything goes to a file: with the stdio transport,
# stdout is the protocol channel to the MCP client.
# The directory can be overridden via POSTGRES_MCP_LOG_DIR; the default is the
# container path this script was written for.
log_dir = os.environ.get('POSTGRES_MCP_LOG_DIR', '/app/data/logs')
os.makedirs(log_dir, exist_ok=True)
log_file_path = os.path.join(log_dir, 'postgres_server.log')
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    filename=log_file_path,
    filemode='w'
)
logger = logging.getLogger(__name__)

# Let the MCP SDK's loggers emit DEBUG records. They propagate to the root
# file handler configured above, so no extra handler is attached. A second
# FileHandler on the same file would log every 'mcp' record twice. Because the
# root handler opened the file with mode 'w' (no O_APPEND), the two independent
# handles would also overwrite each other's output.
mcp_logger = logging.getLogger('mcp')
mcp_logger.setLevel(logging.DEBUG)
logger.info("PostgreSQL MCP Server script started.")

# Global variable to hold the database connection. It is None until
# connect_to_postgres succeeds, and is reset to None on a failed connect.
db_connection = None

# Create a FastMCP server instance
logger.info("Creating FastMCP instance for PostgreSQL server...")
mcp = FastMCP("PostgreSQLServer")
logger.info("FastMCP instance created.")
def _redact_connection_string(connection_string: str) -> str:
    """Return *connection_string* with any embedded password masked as ``***``.

    Covers both libpq connection-string forms:
    - the URI form (``postgresql://user:secret@host/db`` or ``...?password=secret``);
    - the keyword/value form (``password=secret`` or ``password='se cret'``).

    Intended for log output only, where masking too much is harmless but
    leaking a credential is not.
    """
    # URI userinfo: keep "://user:" and "@", mask the password between them.
    redacted = re.sub(r"(://[^:/@\s]*:)[^@/\s]*@", r"\1***@", connection_string)
    # Keyword/value pair or URI query parameter, optionally single-quoted.
    return re.sub(
        r"(password\s*=\s*)(?:'(?:[^'\\]|\\.)*'|[^\s&]+)",
        r"\1***",
        redacted,
        flags=re.IGNORECASE,
    )


@mcp.tool()
def connect_to_postgres(connection_string: str) -> str:
    """
    Connects to a PostgreSQL database using the provided connection string.
    Example connection string: "postgresql://user:password@host:port/dbname"
    Any previously open connection is closed first. Returns a JSON string with
    "status" ("success" or "error") and a "message".
    """
    global db_connection
    # Never write the raw connection string to the log: it normally carries
    # the password.
    logger.debug(
        "Tool 'connect_to_postgres' called with connection_string: "
        f"{_redact_connection_string(connection_string)}"
    )
    try:
        if db_connection:
            db_connection.close()
            logger.info("Closed existing database connection.")
        db_connection = psycopg2.connect(connection_string)
        logger.info("Successfully connected to PostgreSQL database.")
        return json.dumps({"status": "success", "message": "Successfully connected to PostgreSQL database."})
    except Exception as e:
        logger.error(f"Failed to connect to PostgreSQL: {e}", exc_info=True)
        # Drop any half-state so the other tools report "not connected".
        db_connection = None
        return json.dumps({"status": "error", "message": f"Failed to connect to PostgreSQL. {e}"})
# Number of rows sampled from each table/view by get_schema_and_sample_data.
_SAMPLE_ROW_LIMIT = 5


def _list_schema_objects(cur):
    """Return all non-system tables and views.

    Each row has keys 'schemaname', 'tablename' and 'type' ('TABLE' or
    'VIEW'). Objects in the 'public' schema come first; within each group rows
    are ordered by object name.
    """
    cur.execute("""
        SELECT * FROM (
            SELECT schemaname, tablename, 'TABLE' as type
            FROM pg_catalog.pg_tables
            WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
            UNION ALL
            SELECT schemaname, viewname as tablename, 'VIEW' as type
            FROM pg_catalog.pg_views
            WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
        ) as all_objects
        ORDER BY CASE WHEN schemaname = 'public' THEN 0 ELSE 1 END, tablename;
    """)
    return cur.fetchall()


def _describe_view(cur, schema_name, view_name):
    """Return a CREATE VIEW statement built from pg_views, or "" if the view is not found."""
    cur.execute(
        "SELECT definition FROM pg_views WHERE schemaname = %s AND viewname = %s;",
        (schema_name, view_name),
    )
    view_definition = cur.fetchone()
    if not view_definition:
        return ""
    return f"CREATE VIEW {schema_name}.{view_name} AS\n{view_definition['definition']}"


def _describe_table(cur, schema_name, table_name):
    """Return an approximate CREATE TABLE statement built from information_schema.columns.

    Only column names, data types, NOT NULL and DEFAULT clauses are included;
    constraints and indexes are not.
    """
    cur.execute("""
        SELECT column_name, data_type, is_nullable, column_default
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position;
    """, (schema_name, table_name))
    column_defs = []
    for col in cur.fetchall():
        part = f" {col['column_name']} {col['data_type']}"
        if col['is_nullable'] == 'NO':
            part += " NOT NULL"
        if col['column_default'] is not None:
            part += f" DEFAULT {col['column_default']}"
        column_defs.append(part)
    lines = [f"CREATE TABLE {schema_name}.{table_name} ("]
    if column_defs:
        lines.append(",\n".join(column_defs))
    lines.append(");")
    return "\n".join(lines)


def _fetch_sample_rows(cur, schema_name, object_name, limit=_SAMPLE_ROW_LIMIT):
    """Return up to *limit* rows of schema_name.object_name as plain dicts."""
    # psycopg2.sql quotes the identifiers, so unusual object names cannot break
    # the statement. The limit is inlined as a Literal rather than passed as a
    # %s parameter: parameter interpolation would misread a '%' inside a
    # quoted identifier.
    query = sql.SQL("SELECT * FROM {}.{} LIMIT {};").format(
        sql.Identifier(schema_name), sql.Identifier(object_name), sql.Literal(limit)
    )
    cur.execute(query)
    return [dict(r) for r in cur.fetchall()]


@mcp.tool()
def get_schema_and_sample_data(output_file_path: str) -> str:
    """
    Fetches schema and sample data, saves it to a file, and returns status.
    Prioritizes 'public' schema to handle large databases.
    Objects that cannot be read (e.g. missing privileges or a broken view) are
    saved with an "error" entry instead of failing the whole export.
    """
    global db_connection
    logger.debug(f"Tool 'get_schema_and_sample_data' called. Output path: {output_file_path}")
    # A psycopg2 connection object stays truthy after it is closed or broken;
    # its .closed attribute is non-zero in that case.
    if not db_connection or db_connection.closed:
        logger.warning("No active database connection.")
        return json.dumps({"status": "error", "message": "Not connected to any database."})
    try:
        with db_connection.cursor(cursor_factory=RealDictCursor) as cur:
            all_objects = _list_schema_objects(cur)
            logger.debug(f"Found {len(all_objects)} tables and views. Processing all of them.")
            result = {}
            failed = 0
            for obj in all_objects:
                schema_name = obj['schemaname']
                item_name = obj['tablename']
                item_type = obj['type']
                full_item_name = f"{schema_name}.{item_name}"
                item_info = {"type": item_type, "schema": "", "sample_data": []}
                # Each object runs inside its own savepoint. Without it, one
                # failing statement would abort the whole transaction and with
                # it the rest of the export.
                cur.execute("SAVEPOINT schema_item;")
                try:
                    if item_type == 'VIEW':
                        item_info["schema"] = _describe_view(cur, schema_name, item_name)
                    else:  # It's a table
                        item_info["schema"] = _describe_table(cur, schema_name, item_name)
                    item_info["sample_data"] = _fetch_sample_rows(cur, schema_name, item_name)
                except psycopg2.Error as item_error:
                    cur.execute("ROLLBACK TO SAVEPOINT schema_item;")
                    cur.execute("RELEASE SAVEPOINT schema_item;")
                    failed += 1
                    item_info["error"] = str(item_error).strip()
                    logger.warning(f"Could not read {full_item_name}: {item_error}")
                else:
                    cur.execute("RELEASE SAVEPOINT schema_item;")
                result[full_item_name] = item_info
        # Save the result directly to the specified file path
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=4, default=str)  # Use default=str for data types like datetime
        processed = len(all_objects) - failed
        logger.info(f"Successfully saved schema for {processed} of {len(all_objects)} tables and views to {output_file_path}.")
        message = f"Schema for {processed} tables and views saved successfully."
        if failed:
            message += f" {failed} could not be read; see their 'error' entries."
        return json.dumps({
            "status": "success",
            "message": message,
            "objects_found": len(all_objects),
            "objects_processed": processed,
            "objects_failed": failed,
        })
    except Exception as e:
        logger.error(f"Error fetching schema and data: {e}", exc_info=True)
        # Format the full traceback to send back to the client
        tb_str = traceback.format_exc()
        return json.dumps({"status": "error", "message": f"Server-side exception: {str(e)}\nTraceback:\n{tb_str}"})
    finally:
        # This tool only reads, so end its transaction. Otherwise the connection
        # would sit "idle in transaction" holding locks on every table it
        # touched, or stay in an aborted state that breaks later queries.
        try:
            db_connection.rollback()
        except psycopg2.Error as rollback_error:
            logger.warning(f"Could not end schema-read transaction: {rollback_error}")
@mcp.tool()
def execute_postgres_query(query: str) -> str:
    """
    Executes a given SQL query on the currently connected PostgreSQL database.
    Requires a prior successful connection using 'connect_to_postgres'.
    For SELECT queries, it returns a JSON string of the rows.
    For other queries (INSERT, UPDATE, DELETE), it returns a JSON string with a success message and row count.
    On failure the transaction is rolled back and a JSON error status is returned.
    """
    global db_connection
    logger.debug(f"Tool 'execute_postgres_query' called with query: {query[:100]}...")  # Log truncated query
    # A psycopg2 connection object stays truthy after it is closed or broken;
    # its .closed attribute is non-zero in that case.
    if not db_connection or db_connection.closed:
        logger.warning("No active database connection for 'execute_postgres_query'.")
        return json.dumps({"status": "error", "message": "Not connected to any database. Please use 'connect_to_postgres' first."})
    try:
        with db_connection.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(query)
            # cursor.description is set only when the statement returned rows.
            if cur.description:
                results = cur.fetchall()
                logger.info(f"Query executed successfully. Fetched {len(results)} rows.")
                db_connection.commit()
                return json.dumps({"status": "success", "data": results}, default=str)
            rowcount = cur.rowcount
            db_connection.commit()
            logger.info(f"Query executed successfully. {rowcount} rows affected.")
            return json.dumps({"status": "success", "message": f"Query executed successfully. {rowcount} rows affected."})
    except Exception as e:
        # Roll back so the connection is usable for the next call. If the
        # connection itself is gone, rollback() raises as well; that must not
        # escape the tool or hide the original error.
        try:
            db_connection.rollback()
        except psycopg2.Error as rollback_error:
            logger.warning(f"Rollback after failed query also failed: {rollback_error}")
        logger.error(f"Error executing query: {e}", exc_info=True)
        return json.dumps({"status": "error", "message": f"Could not execute query. {e}"})
logger.info("PostgreSQL MCP tools defined.")

if __name__ == "__main__":
    logger.info("__main__ block started for PostgreSQL server.")
    if sys.platform == "win32":
        logger.info("Configuring stdout/stdin for win32.")
        sys.stdout.reconfigure(encoding='utf-8')
        sys.stdin.reconfigure(encoding='utf-8')
        logger.info("stdout/stdin configured for win32.")
    logger.info("Starting MCP PostgreSQL Server (stdio)...")
    # With the stdio transport, stdout carries the JSON-RPC stream to the
    # client, and any non-protocol text there can break the client's parser.
    # Human-readable status therefore goes to stderr.
    print("Starting MCP PostgreSQL Server (stdio)...", file=sys.stderr)
    exit_code = 0
    try:
        mcp.run(transport='stdio')
    except Exception as e:
        logger.critical(f"Critical error running MCP PostgreSQL server: {e}", exc_info=True)
        # Report the crash to whatever launched the server instead of exiting 0.
        exit_code = 1
    finally:
        if db_connection:
            db_connection.close()
            logger.info("Closed database connection on server exit.")
        logger.info("MCP PostgreSQL server finished or exited.")
        logging.shutdown()
    sys.exit(exit_code)