Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Add to your MCP settings in Cursor:

```json
{
"meeseeql": {
"sql-explorer": {
"command": "uvx",
"args": ["meeseeql"]
}
Expand Down
102 changes: 30 additions & 72 deletions config_example.yaml
Original file line number Diff line number Diff line change
@@ -1,77 +1,35 @@
databases:
# Option 1: Use connection string (existing format)
chinook:
type: sqlite
connection_string: "sqlite:///tests/Chinook_Sqlite.sqlite"
description: "Chinook sample SQLite database"

# Option 2: Use individual fields (new format)
my_postgres:
type: postgresql
description: "PostgreSQL database with individual fields"
host: localhost
port: 5432
database: mydb
username: myuser
password: mypassword
extra_params:
sslmode: require
connect_timeout: "10"

# Examples using password store (pass) integration
prod_postgres:
type: postgresql
description: "Production PostgreSQL using default pass key"
host: prod-db.example.com
example_db: # the identifier used in all MCP tool calls
type: postgresql # supported: postgresql, mysql, sqlite, sqlserver, snowflake

description: "Example database showing all configuration options" # Human readable description useful for the LLM

# Either use connection_string of individual connection parameters
connection_string: "postgresql://app_user:secure_password@db.example.com:5432/myapp_db"
host: db.example.com
port: 5432
database: myapp
database: myapp_db
username: app_user
# No password field - will use pass entry: databases/prod_postgres

dev_postgres:
type: postgresql
description: "Development PostgreSQL using custom pass key"
host: dev-db.example.com
port: 5432
database: dev_myapp
username: dev_user
password_store_key: "company/dev/postgres"
# Will use pass entry: company/dev/postgres

memory_sqlite:
type: sqlite
description: "In-memory SQLite for testing"
database: ":memory:"

snowflake_example:
type: snowflake
description: "Snowflake data warehouse"
host: "myaccount.snowflakecomputing.com"
port: 443
database: "PROD_DWH"
username: "myuser@example.com"
password: "mypassword"
account: "myaccount"
account: my_account # For Snowflake
password: secure_password # Omit to fetch the credentials from the password manager

# Password management via 'pass' tool (alternative to password field)
password_store_key: "company/databases/example_db" # Custom pass key, defaults to: databases/example_db

# Database-specific parameters
extra_params:
warehouse: "COMPUTE_WH"
schema: "PUBLIC"
authenticator: "externalbrowser"

# Schema filtering examples
staging_postgres:
type: postgresql
host: db.example.com
port: 5432
exclude_schemas: ["archived", "analytics_cache", "temp"]

dwh_snowflake:
type: snowflake
host: "myaccount.snowflakecomputing.com"
database: "PROD_DWH"
include_schemas: ["SALES", "MARKETING", "FINANCE", "OPERATIONS"]

# Global settings
sslmode: require
application_name: "mcp-sql-server"

# Schema filtering (choose one)
include_schemas: ["public", "analytics", "reports"] # Only these schemas
exclude_schemas: ["temp", "staging", "logs"] # All except these schemas

# Table filtering (choose one)
allowed_tables: ["users", "orders", "products", "analytics_summary"] # Only these tables
disallowed_tables: ["sensitive_data", "audit_logs", "temp_tables"] # All except these tables

# Global MCP server settings
settings:
max_query_timeout: 30
max_rows_per_query: 500
enable_write_operations: false
max_query_timeout: 30 # Maximum time in seconds for any single query
max_rows_per_query: 1000 # Maximum rows returned per query (pagination will apply)
39 changes: 31 additions & 8 deletions src/meeseeql/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,19 @@ class QueryError(Exception):
class DatabaseConfig(BaseModel):
type: str
description: str

# Option 1: Use connection string directly
connection_string: str | None = None

# Option 2: Use individual fields
host: str | None = None
port: int | None = None
database: str | None = None
username: str | None = None
password: str | None = None
account: str | None = None # For Snowflake

# Custom password store key (overrides default databases/{db_name})
password_store_key: str | None = None

extra_params: Dict[str, str] | None = None

include_schemas: List[str] | None = None
exclude_schemas: List[str] | None = None
allowed_tables: List[str] | None = None
disallowed_tables: List[str] | None = None

@model_validator(mode="after")
def validate_config(self):
Expand All @@ -58,6 +52,11 @@ def validate_config(self):
"Cannot specify both include_schemas and exclude_schemas"
)

if self.allowed_tables is not None and self.disallowed_tables is not None:
raise ConfigurationError(
"Cannot specify both allowed_tables and disallowed_tables"
)

# If using connection string, no further validation needed
if self.connection_string:
return self
Expand Down Expand Up @@ -238,6 +237,18 @@ def get_filtered_schemas(self, db_name: str) -> List[str] | None:
else:
return None

def get_filtered_tables(self, db_name: str) -> List[str] | None:
db_config = self.get_database_config(db_name)
if not db_config:
return None

if db_config.allowed_tables:
return db_config.allowed_tables
elif db_config.disallowed_tables:
return db_config.disallowed_tables
else:
return None

def get_schema_filter_type(self, db_name: str) -> str | None:
db_config = self.get_database_config(db_name)
if not db_config:
Expand All @@ -250,6 +261,18 @@ def get_schema_filter_type(self, db_name: str) -> str | None:
else:
return None

def get_table_filter_type(self, db_name: str) -> str | None:
db_config = self.get_database_config(db_name)
if not db_config:
return None

if db_config.allowed_tables:
return "allow"
elif db_config.disallowed_tables:
return "deny"
else:
return None

def reload_config(self, new_config: AppConfig, changed_db_names: set[str]):
for db_name in changed_db_names:
if db_name in self.engines:
Expand Down
43 changes: 43 additions & 0 deletions src/meeseeql/sql_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sqlglot
from sqlglot import expressions as exp
from typing import List
from typing_extensions import Self


Expand All @@ -15,6 +16,10 @@ class InvalidPaginationError(Exception):
pass


class TableAccessError(Exception):
pass


class SqlQueryTransformer:
def __init__(self, query: str, dialect: str | None = None):
self.query = query
Expand Down Expand Up @@ -120,5 +125,43 @@ def add_where_condition(self, condition: str) -> Self:
except Exception as e:
raise InvalidSqlError(f"Failed to add WHERE condition: {e}") from e

def _extract_table_names(self) -> List[str]:
table_names = []

for table_node in self.ast.find_all(exp.Table):
table_name = table_node.name
if table_name:
table_names.append(table_name.lower())

return list(set(table_names))

def validate_table_access(
self,
allowed_tables: List[str] | None = None,
disallowed_tables: List[str] | None = None,
) -> Self:
if not allowed_tables and not disallowed_tables:
return self

table_names = self._extract_table_names()

if allowed_tables:
allowed_tables_lower = [t.lower() for t in allowed_tables]
for table_name in table_names:
if table_name not in allowed_tables_lower:
raise TableAccessError(
f"Table '{table_name}' is not in the allowed list"
)

if disallowed_tables:
disallowed_tables_lower = [t.lower() for t in disallowed_tables]
for table_name in table_names:
if table_name in disallowed_tables_lower:
raise TableAccessError(
f"Table '{table_name}' is in the excluded list"
)

return self

def sql(self) -> str:
return self.ast.sql(dialect=self.dialect)
10 changes: 9 additions & 1 deletion src/meeseeql/tools/execute_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, Any, List
from pydantic import BaseModel
from meeseeql.database_manager import DatabaseManager
from meeseeql.sql_transformer import SqlQueryTransformer
from meeseeql.sql_transformer import SqlQueryTransformer, TableAccessError


class QueryResponse(BaseModel):
Expand Down Expand Up @@ -86,6 +86,14 @@ async def execute_query(
transformer = SqlQueryTransformer(query.strip(), dialect)
transformer.validate_read_only()

table_filter_type = db_manager.get_table_filter_type(database)
filtered_tables = db_manager.get_filtered_tables(database)

if table_filter_type == "allow":
transformer.validate_table_access(allowed_tables=filtered_tables)
elif table_filter_type == "deny":
transformer.validate_table_access(disallowed_tables=filtered_tables)

total_rows = None
if accurate_count:
count_query = transformer.to_count_query()
Expand Down
64 changes: 46 additions & 18 deletions src/meeseeql/tools/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __str__(self) -> str:
result = ""

for row in self.rows:
if row.data_type == "table":
if row.object_type == "table":
result += f"{row.object_type}: {row.user_friendly_descriptor}\n"
elif row.data_type and row.data_type != "null":
result += f"{row.object_type}: {row.user_friendly_descriptor} ({row.data_type}) in {row.schema_name}\n"
Expand All @@ -45,6 +45,50 @@ def _format_value(self, value) -> str:
return str(value)


def _apply_search_filters(
transformer: SqlQueryTransformer,
db_manager: DatabaseManager,
database: str,
dialect: str,
schema: str | None,
) -> None:
"""Apply schema and table filters to the search query"""
if schema:
transformer.add_where_condition(f"LOWER(schema_name) = LOWER('{schema}')")
else:
schema_filter_type = db_manager.get_schema_filter_type(database)
filtered_schemas = db_manager.get_filtered_schemas(database)

if schema_filter_type and filtered_schemas:
if schema_filter_type == "include":
schema_list = "', '".join(s.lower() for s in filtered_schemas)
transformer.add_where_condition(
f"LOWER(schema_name) IN ('{schema_list}')"
)
elif schema_filter_type == "exclude":
schema_list = "', '".join(s.lower() for s in filtered_schemas)
transformer.add_where_condition(
f"LOWER(schema_name) NOT IN ('{schema_list}')"
)

table_filter_type = db_manager.get_table_filter_type(database)
filtered_tables = db_manager.get_filtered_tables(database)

if table_filter_type and filtered_tables:
table_column = "name" if dialect == "sqlite" else "object_name"

if table_filter_type == "allow":
table_list = "', '".join(t.lower() for t in filtered_tables)
transformer.add_where_condition(
f"LOWER({table_column}) IN ('{table_list}') OR object_type != 'table'"
)
elif table_filter_type == "deny":
table_list = "', '".join(t.lower() for t in filtered_tables)
transformer.add_where_condition(
f"LOWER({table_column}) NOT IN ('{table_list}') OR object_type != 'table'"
)


async def search(
db_manager: DatabaseManager,
database: str,
Expand All @@ -61,23 +105,7 @@ async def search(

transformer = SqlQueryTransformer(sql_query, dialect)

if schema:
transformer.add_where_condition(f"LOWER(schema_name) = LOWER('{schema}')")
else:
filter_type = db_manager.get_schema_filter_type(database)
filtered_schemas = db_manager.get_filtered_schemas(database)

if filter_type and filtered_schemas:
if filter_type == "include":
schema_list = "', '".join(s.lower() for s in filtered_schemas)
transformer.add_where_condition(
f"LOWER(schema_name) IN ('{schema_list}')"
)
elif filter_type == "exclude":
schema_list = "', '".join(s.lower() for s in filtered_schemas)
transformer.add_where_condition(
f"LOWER(schema_name) NOT IN ('{schema_list}')"
)
_apply_search_filters(transformer, db_manager, database, dialect, schema)

paginated_query = transformer.add_pagination(limit).validate_read_only().sql()

Expand Down
Loading