From 252479709763e01bd80af7ff4a7133606fe1495c Mon Sep 17 00:00:00 2001 From: Thomas Smith Date: Wed, 6 Aug 2025 16:55:32 +0200 Subject: [PATCH] Add table allow/disallow list --- README.md | 2 +- config_example.yaml | 102 ++++++------------ src/meeseeql/database_manager.py | 39 +++++-- src/meeseeql/sql_transformer.py | 43 ++++++++ src/meeseeql/tools/execute_query.py | 10 +- src/meeseeql/tools/search.py | 64 +++++++---- src/meeseeql/tools/table_summary.py | 58 +++++++--- .../test_config_validation.py | 58 ++++++++++ tests/test_sql_transformer_table_access.py | 65 +++++++++++ tests/tools/test_execute_query.py | 52 ++++++++- tests/tools/test_search.py | 75 ++++++++++++- tests/tools/test_table_summary.py | 48 +++++++++ 12 files changed, 499 insertions(+), 117 deletions(-) create mode 100644 tests/test_sql_transformer_table_access.py diff --git a/README.md b/README.md index 161576a..f7215d9 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Add to your MCP settings in Cursor: ```json { - "meeseeql": { + "sql-explorer": { "command": "uvx", "args": ["meeseeql"] } diff --git a/config_example.yaml b/config_example.yaml index ff0fcdc..9961c8e 100644 --- a/config_example.yaml +++ b/config_example.yaml @@ -1,77 +1,35 @@ databases: - # Option 1: Use connection string (existing format) - chinook: - type: sqlite - connection_string: "sqlite:///tests/Chinook_Sqlite.sqlite" - description: "Chinook sample SQLite database" - - # Option 2: Use individual fields (new format) - my_postgres: - type: postgresql - description: "PostgreSQL database with individual fields" - host: localhost - port: 5432 - database: mydb - username: myuser - password: mypassword - extra_params: - sslmode: require - connect_timeout: "10" - - # Examples using password store (pass) integration - prod_postgres: - type: postgresql - description: "Production PostgreSQL using default pass key" - host: prod-db.example.com + example_db: # the identifier used in all MCP tool calls + type: postgresql # supported: postgresql, mysql, sqlite, sqlserver, snowflake + + description: "Example database showing all configuration options" # Human readable description useful for the LLM + + # Either use connection_string of individual connection parameters + connection_string: "postgresql://app_user:secure_password@db.example.com:5432/myapp_db" + host: db.example.com port: 5432 - database: myapp + database: myapp_db username: app_user - # No password field - will use pass entry: databases/prod_postgres - - dev_postgres: - type: postgresql - description: "Development PostgreSQL using custom pass key" - host: dev-db.example.com - port: 5432 - database: dev_myapp - username: dev_user - password_store_key: "company/dev/postgres" - # Will use pass entry: company/dev/postgres - - memory_sqlite: - type: sqlite - description: "In-memory SQLite for testing" - database: ":memory:" - - snowflake_example: - type: snowflake - description: "Snowflake data warehouse" - host: "myaccount.snowflakecomputing.com" - port: 443 - database: "PROD_DWH" - username: "myuser@example.com" - password: "mypassword" - account: "myaccount" + account: my_account # For Snowflake + password: secure_password # Omit to fetch the credentials from the password manager + + # Password management via 'pass' tool (alternative to password field) + password_store_key: "company/databases/example_db" # Custom pass key, defaults to: databases/example_db + + # Database-specific parameters extra_params: - warehouse: "COMPUTE_WH" - schema: "PUBLIC" - authenticator: "externalbrowser" - - # Schema filtering examples - staging_postgres: - type: postgresql - host: db.example.com - port: 5432 - exclude_schemas: ["archived", "analytics_cache", "temp"] - - dwh_snowflake: - type: snowflake - host: "myaccount.snowflakecomputing.com" - database: "PROD_DWH" - include_schemas: ["SALES", "MARKETING", "FINANCE", "OPERATIONS"] - -# Global settings + sslmode: require + application_name: "mcp-sql-server" + + # Schema filtering (choose one) + include_schemas: ["public", "analytics", "reports"] # Only these schemas + exclude_schemas: ["temp", "staging", "logs"] # All except these schemas + + # Table filtering (choose one) + allowed_tables: ["users", "orders", "products", "analytics_summary"] # Only these tables + disallowed_tables: ["sensitive_data", "audit_logs", "temp_tables"] # All except these tables + +# Global MCP server settings settings: - max_query_timeout: 30 - max_rows_per_query: 500 - enable_write_operations: false + max_query_timeout: 30 # Maximum time in seconds for any single query + max_rows_per_query: 1000 # Maximum rows returned per query (pagination will apply) diff --git a/src/meeseeql/database_manager.py b/src/meeseeql/database_manager.py index 22dc210..46ce3eb 100644 --- a/src/meeseeql/database_manager.py +++ b/src/meeseeql/database_manager.py @@ -27,25 +27,19 @@ class QueryError(Exception): class DatabaseConfig(BaseModel): type: str description: str - - # Option 1: Use connection string directly connection_string: str | None = None - - # Option 2: Use individual fields host: str | None = None port: int | None = None database: str | None = None username: str | None = None password: str | None = None account: str | None = None # For Snowflake - - # Custom password store key (overrides default databases/{db_name}) password_store_key: str | None = None - extra_params: Dict[str, str] | None = None - include_schemas: List[str] | None = None exclude_schemas: List[str] | None = None + allowed_tables: List[str] | None = None + disallowed_tables: List[str] | None = None @model_validator(mode="after") def validate_config(self): @@ -58,6 +52,11 @@ def validate_config(self): "Cannot specify both include_schemas and exclude_schemas" ) + if self.allowed_tables is not None and self.disallowed_tables is not None: + raise ConfigurationError( + "Cannot specify both allowed_tables and disallowed_tables" + ) + # If using connection string, no further validation needed if self.connection_string: return self @@ -238,6 +237,18 @@ def get_filtered_schemas(self, db_name: str) -> List[str] | None: else: return None + def get_filtered_tables(self, db_name: str) -> List[str] | None: + db_config = self.get_database_config(db_name) + if not db_config: + return None + + if db_config.allowed_tables: + return db_config.allowed_tables + elif db_config.disallowed_tables: + return db_config.disallowed_tables + else: + return None + def get_schema_filter_type(self, db_name: str) -> str | None: db_config = self.get_database_config(db_name) if not db_config: @@ -250,6 +261,18 @@ def get_schema_filter_type(self, db_name: str) -> str | None: else: return None + def get_table_filter_type(self, db_name: str) -> str | None: + db_config = self.get_database_config(db_name) + if not db_config: + return None + + if db_config.allowed_tables: + return "allow" + elif db_config.disallowed_tables: + return "deny" + else: + return None + def reload_config(self, new_config: AppConfig, changed_db_names: set[str]): for db_name in changed_db_names: if db_name in self.engines: diff --git a/src/meeseeql/sql_transformer.py b/src/meeseeql/sql_transformer.py index dd0bc6f..7154560 100644 --- a/src/meeseeql/sql_transformer.py +++ b/src/meeseeql/sql_transformer.py @@ -1,5 +1,6 @@ import sqlglot from sqlglot import expressions as exp +from typing import List from typing_extensions import Self @@ -15,6 +16,10 @@ class InvalidPaginationError(Exception): pass +class TableAccessError(Exception): + pass + + class SqlQueryTransformer: def __init__(self, query: str, dialect: str | None = None): self.query = query @@ -120,5 +125,43 @@ def add_where_condition(self, condition: str) -> Self: except Exception as e: raise InvalidSqlError(f"Failed to add WHERE condition: {e}") from e + def _extract_table_names(self) -> List[str]: + table_names = [] + + for table_node in self.ast.find_all(exp.Table): + table_name = table_node.name + if table_name: + table_names.append(table_name.lower()) + + return list(set(table_names)) + + def validate_table_access( + self, + allowed_tables: List[str] | None = None, + disallowed_tables: List[str] | None = None, + ) -> Self: + if not allowed_tables and not disallowed_tables: + return self + + table_names = self._extract_table_names() + + if allowed_tables: + allowed_tables_lower = [t.lower() for t in allowed_tables] + for table_name in table_names: + if table_name not in allowed_tables_lower: + raise TableAccessError( + f"Table '{table_name}' is not in the allowed list" + ) + + if disallowed_tables: + disallowed_tables_lower = [t.lower() for t in disallowed_tables] + for table_name in table_names: + if table_name in disallowed_tables_lower: + raise TableAccessError( + f"Table '{table_name}' is in the excluded list" + ) + + return self + def sql(self) -> str: return self.ast.sql(dialect=self.dialect) diff --git a/src/meeseeql/tools/execute_query.py b/src/meeseeql/tools/execute_query.py index f45850b..b449990 100644 --- a/src/meeseeql/tools/execute_query.py +++ b/src/meeseeql/tools/execute_query.py @@ -2,7 +2,7 @@ from typing import Dict, Any, List from pydantic import BaseModel from meeseeql.database_manager import DatabaseManager -from meeseeql.sql_transformer import SqlQueryTransformer +from meeseeql.sql_transformer import SqlQueryTransformer, TableAccessError class QueryResponse(BaseModel): @@ -86,6 +86,14 @@ async def execute_query( transformer = SqlQueryTransformer(query.strip(), dialect) transformer.validate_read_only() + table_filter_type = db_manager.get_table_filter_type(database) + filtered_tables = db_manager.get_filtered_tables(database) + + if table_filter_type == "allow": + transformer.validate_table_access(allowed_tables=filtered_tables) + elif table_filter_type == "deny": + transformer.validate_table_access(disallowed_tables=filtered_tables) + total_rows = None if accurate_count: count_query = transformer.to_count_query() diff --git a/src/meeseeql/tools/search.py b/src/meeseeql/tools/search.py index eed65ae..02a8161 100644 --- a/src/meeseeql/tools/search.py +++ b/src/meeseeql/tools/search.py @@ -23,7 +23,7 @@ def __str__(self) -> str: result = "" for row in self.rows: - if row.data_type == "table": + if row.object_type == "table": result += f"{row.object_type}: {row.user_friendly_descriptor}\n" elif row.data_type and row.data_type != "null": result += f"{row.object_type}: {row.user_friendly_descriptor} ({row.data_type}) in {row.schema_name}\n" @@ -45,6 +45,50 @@ def _format_value(self, value) -> str: return str(value) +def _apply_search_filters( + transformer: SqlQueryTransformer, + db_manager: DatabaseManager, + database: str, + dialect: str, + schema: str | None, +) -> None: + """Apply schema and table filters to the search query""" + if schema: + transformer.add_where_condition(f"LOWER(schema_name) = LOWER('{schema}')") + else: + schema_filter_type = db_manager.get_schema_filter_type(database) + filtered_schemas = db_manager.get_filtered_schemas(database) + + if schema_filter_type and filtered_schemas: + if schema_filter_type == "include": + schema_list = "', '".join(s.lower() for s in filtered_schemas) + transformer.add_where_condition( + f"LOWER(schema_name) IN ('{schema_list}')" + ) + elif schema_filter_type == "exclude": + schema_list = "', '".join(s.lower() for s in filtered_schemas) + transformer.add_where_condition( + f"LOWER(schema_name) NOT IN ('{schema_list}')" + ) + + table_filter_type = db_manager.get_table_filter_type(database) + filtered_tables = db_manager.get_filtered_tables(database) + + if table_filter_type and filtered_tables: + table_column = "name" if dialect == "sqlite" else "object_name" + + if table_filter_type == "allow": + table_list = "', '".join(t.lower() for t in filtered_tables) + transformer.add_where_condition( + f"LOWER({table_column}) IN ('{table_list}') OR object_type != 'table'" + ) + elif table_filter_type == "deny": + table_list = "', '".join(t.lower() for t in filtered_tables) + transformer.add_where_condition( + f"LOWER({table_column}) NOT IN ('{table_list}') OR object_type != 'table'" + ) + + async def search( db_manager: DatabaseManager, database: str, @@ -61,23 +105,7 @@ async def search( transformer = SqlQueryTransformer(sql_query, dialect) - if schema: - transformer.add_where_condition(f"LOWER(schema_name) = LOWER('{schema}')") - else: - filter_type = db_manager.get_schema_filter_type(database) - filtered_schemas = db_manager.get_filtered_schemas(database) - - if filter_type and filtered_schemas: - if filter_type == "include": - schema_list = "', '".join(s.lower() for s in filtered_schemas) - transformer.add_where_condition( - f"LOWER(schema_name) IN ('{schema_list}')" - ) - elif filter_type == "exclude": - schema_list = "', '".join(s.lower() for s in filtered_schemas) - transformer.add_where_condition( - f"LOWER(schema_name) NOT IN ('{schema_list}')" - ) + _apply_search_filters(transformer, db_manager, database, dialect, schema) paginated_query = transformer.add_pagination(limit).validate_read_only().sql() diff --git a/src/meeseeql/tools/table_summary.py b/src/meeseeql/tools/table_summary.py index b182631..7709d3f 100644 --- a/src/meeseeql/tools/table_summary.py +++ b/src/meeseeql/tools/table_summary.py @@ -349,39 +349,67 @@ async def _check_table_exists( ) from e -async def table_summary( +async def _validate_table_summary_inputs( db_manager: DatabaseManager, database: str, table_name: str, - db_schema: str | None = None, - limit: int = 250, - page: int = 1, -) -> TableSummary: + db_schema: str | None, + limit: int, + page: int, +) -> str: + """Validate inputs and return the schema value to use""" if limit < 1: raise TableSummaryError("Limit must be greater than 0") if page < 1: raise TableSummaryError("Page number must be greater than 0") - max_rows = db_manager.config.settings.get("max_rows_per_query", 1000) - if limit > max_rows: - limit = max_rows - - dialect = db_manager.get_dialect_name(database) + table_filter_type = db_manager.get_table_filter_type(database) + filtered_tables = db_manager.get_filtered_tables(database) - if not db_schema: - schema_value = db_manager.get_default_schema(database) - else: - schema_value = db_schema + if table_filter_type and filtered_tables: + if table_filter_type == "allow": + if table_name.lower() not in [t.lower() for t in filtered_tables]: + raise TableNotFoundError( + f"Table '{table_name}' is not in the allowed list" + ) + elif table_filter_type == "deny": + if table_name.lower() in [t.lower() for t in filtered_tables]: + raise TableNotFoundError(f"Table '{table_name}' is in the exluded list") table_exists = await _check_table_exists( - db_manager, database, table_name, schema_value + db_manager, database, table_name, db_schema ) if not table_exists: raise TableNotFoundError( f"Table '{table_name}' not found in database '{database}'" ) + +async def table_summary( + db_manager: DatabaseManager, + database: str, + table_name: str, + db_schema: str | None = None, + limit: int = 250, + page: int = 1, +) -> TableSummary: + + if not db_schema: + schema_value = db_manager.get_default_schema(database) + else: + schema_value = db_schema + + await _validate_table_summary_inputs( + db_manager, database, table_name, schema_value, limit, page + ) + + max_rows = db_manager.config.settings.get("max_rows_per_query", 1000) + if limit > max_rows: + limit = max_rows + + dialect = db_manager.get_dialect_name(database) + column_count, outgoing_fk_count, incoming_fk_count = await _get_counts( db_manager, database, table_name, schema_value ) diff --git a/tests/database_manager/test_config_validation.py b/tests/database_manager/test_config_validation.py index 8aa3ee9..d711aa0 100644 --- a/tests/database_manager/test_config_validation.py +++ b/tests/database_manager/test_config_validation.py @@ -103,3 +103,61 @@ def test_no_schema_filtering(): ) assert config.include_schemas is None assert config.exclude_schemas is None + + +def test_both_allowed_and_disallowed_tables(): + with raises( + ConfigurationError, + match="Cannot specify both allowed_tables and disallowed_tables", + ): + DatabaseConfig( + type="postgresql", + description="Test DB", + host="localhost", + database="testdb", + username="user", + password="pass", + allowed_tables=["users", "products"], + disallowed_tables=["logs", "temp"], + ) + + +def test_allowed_tables_only(): + config = DatabaseConfig( + type="postgresql", + description="Test DB", + host="localhost", + database="testdb", + username="user", + password="pass", + allowed_tables=["users", "products"], + ) + assert config.allowed_tables == ["users", "products"] + assert config.disallowed_tables is None + + +def test_disallowed_tables_only(): + config = DatabaseConfig( + type="postgresql", + description="Test DB", + host="localhost", + database="testdb", + username="user", + password="pass", + disallowed_tables=["logs", "temp"], + ) + assert config.disallowed_tables == ["logs", "temp"] + assert config.allowed_tables is None + + +def test_no_table_filtering(): + config = DatabaseConfig( + type="postgresql", + description="Test DB", + host="localhost", + database="testdb", + username="user", + password="pass", + ) + assert config.allowed_tables is None + assert config.disallowed_tables is None diff --git a/tests/test_sql_transformer_table_access.py b/tests/test_sql_transformer_table_access.py new file mode 100644 index 0000000..fda4ade --- /dev/null +++ b/tests/test_sql_transformer_table_access.py @@ -0,0 +1,65 @@ +import pytest +from meeseeql.sql_transformer import SqlQueryTransformer, TableAccessError + + +def test_validate_table_access_allow_allowed(): + transformer = SqlQueryTransformer("SELECT * FROM users", "sqlite") + result = transformer.validate_table_access(allowed_tables=["users", "products"]) + assert result is transformer + + +def test_validate_table_access_allow_denied(): + transformer = SqlQueryTransformer("SELECT * FROM logs", "sqlite") + with pytest.raises( + TableAccessError, match="Table 'logs' is not in the allowed list" + ): + transformer.validate_table_access(allowed_tables=["users", "products"]) + + +def test_validate_table_access_disallow_allowed(): + transformer = SqlQueryTransformer("SELECT * FROM users", "sqlite") + result = transformer.validate_table_access(disallowed_tables=["logs", "temp"]) + assert result is transformer + + +def test_validate_table_access_disallow_denied(): + transformer = SqlQueryTransformer("SELECT * FROM logs", "sqlite") + with pytest.raises(TableAccessError, match="Table 'logs' is in the excluded list"): + transformer.validate_table_access(disallowed_tables=["logs", "temp"]) + + +def test_validate_table_access_multiple_tables_allow(): + transformer = SqlQueryTransformer( + "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id", + "postgresql", + ) + result = transformer.validate_table_access( + allowed_tables=["users", "orders", "products"] + ) + assert result is transformer + + +def test_validate_table_access_multiple_tables_allow_partial_denied(): + transformer = SqlQueryTransformer( + "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id", + "postgresql", + ) + with pytest.raises( + TableAccessError, match="Table 'orders' is not in the allowed list" + ): + transformer.validate_table_access(allowed_tables=["users", "products"]) + + +def test_validate_table_access_no_restrictions(): + transformer = SqlQueryTransformer("SELECT * FROM anything", "sqlite") + result = transformer.validate_table_access() + assert result is transformer + + +def test_validate_table_access_case_insensitive(): + transformer = SqlQueryTransformer("SELECT * FROM Users", "sqlite") + result = transformer.validate_table_access(allowed_tables=["USERS", "products"]) + assert result is transformer + + with pytest.raises(TableAccessError): + transformer.validate_table_access(disallowed_tables=["USERS"]) diff --git a/tests/tools/test_execute_query.py b/tests/tools/test_execute_query.py index c571565..2ece508 100644 --- a/tests/tools/test_execute_query.py +++ b/tests/tools/test_execute_query.py @@ -6,9 +6,11 @@ load_config, DatabaseManager, QueryError, + DatabaseConfig, + AppConfig, ) from meeseeql.tools.execute_query import execute_query, QueryResponse -from meeseeql.sql_transformer import ReadOnlyViolationError +from meeseeql.sql_transformer import ReadOnlyViolationError, TableAccessError @pytest.fixture @@ -217,3 +219,51 @@ async def test_execute_query_accurate_count_complex_query(db_manager): assert result.total_pages == math.ceil(result.total_rows / 5) else: assert result.total_pages == 1 + + +async def test_execute_query_respects_allowed_tables(): + config = AppConfig( + databases={ + "test_db": DatabaseConfig( + type="sqlite", + connection_string="sqlite:///tests/Chinook_Sqlite.sqlite", + description="Test DB with table inclusion", + allowed_tables=["Track", "Album"], + ) + }, + settings={}, + ) + db_manager = DatabaseManager(config) + + # Should work for allowed table + result = await execute_query(db_manager, "test_db", "SELECT * FROM Track LIMIT 5") + assert isinstance(result, QueryResponse) + assert len(result.columns) > 0 + + # Should raise error for disallowed table + with pytest.raises(TableAccessError, match="not in the allowed list"): + await execute_query(db_manager, "test_db", "SELECT * FROM Artist LIMIT 5") + + +async def test_execute_query_respects_disallowed_tables(): + config = AppConfig( + databases={ + "test_db": DatabaseConfig( + type="sqlite", + connection_string="sqlite:///tests/Chinook_Sqlite.sqlite", + description="Test DB with table exclusion", + disallowed_tables=["Track", "Album"], + ) + }, + settings={}, + ) + db_manager = DatabaseManager(config) + + # Should work for non-excluded table + result = await execute_query(db_manager, "test_db", "SELECT * FROM Artist LIMIT 5") + assert isinstance(result, QueryResponse) + assert len(result.columns) > 0 + + # Should raise error for excluded table + with pytest.raises(TableAccessError, match="in the excluded list"): + await execute_query(db_manager, "test_db", "SELECT * FROM Track LIMIT 5") diff --git a/tests/tools/test_search.py b/tests/tools/test_search.py index bff2a35..a3154f9 100644 --- a/tests/tools/test_search.py +++ b/tests/tools/test_search.py @@ -1,5 +1,10 @@ import pytest -from meeseeql.database_manager import load_config, DatabaseManager +from meeseeql.database_manager import ( + load_config, + DatabaseManager, + DatabaseConfig, + AppConfig, +) from meeseeql.tools.search import search @@ -67,3 +72,71 @@ async def test_search_case_insensitive(db_manager): assert len(result_lower.rows) == len(result_upper.rows) == len(result_mixed.rows) assert len(result_lower.rows) > 0 + + +async def test_search_respects_allowed_tables(): + config = AppConfig( + databases={ + "test_db": DatabaseConfig( + type="sqlite", + connection_string="sqlite:///tests/Chinook_Sqlite.sqlite", + description="Test DB with table inclusion", + allowed_tables=["Track", "Album"], + ) + }, + settings={}, + ) + db_manager = DatabaseManager(config) + + result = await search(db_manager, "test_db", "a") + + # Should only return results from Track and Album tables, not Artist or other tables + table_results = [r for r in result.rows if r.object_type == "table"] + table_names = [ + ( + r.user_friendly_descriptor.split(".")[1] + if "." in r.user_friendly_descriptor + else r.user_friendly_descriptor + ) + for r in table_results + ] + + for table_name in table_names: + assert table_name.lower() in [ + "track", + "album", + ], f"Found disallowed table: {table_name}" + + +async def test_search_respects_disallowed_tables(): + config = AppConfig( + databases={ + "test_db": DatabaseConfig( + type="sqlite", + connection_string="sqlite:///tests/Chinook_Sqlite.sqlite", + description="Test DB with table exclusion", + disallowed_tables=["Track", "Album"], + ) + }, + settings={}, + ) + db_manager = DatabaseManager(config) + + result = await search(db_manager, "test_db", "a") + + # Should not return Track or Album tables + table_results = [r for r in result.rows if r.object_type == "table"] + table_names = [ + ( + r.user_friendly_descriptor.split(".")[1] + if "." in r.user_friendly_descriptor + else r.user_friendly_descriptor + ) + for r in table_results + ] + + for table_name in table_names: + assert table_name.lower() not in [ + "track", + "album", + ], f"Found excluded table: {table_name}" diff --git a/tests/tools/test_table_summary.py b/tests/tools/test_table_summary.py index a0d9d7d..3cf22c5 100644 --- a/tests/tools/test_table_summary.py +++ b/tests/tools/test_table_summary.py @@ -4,6 +4,8 @@ from meeseeql.database_manager import ( load_config, DatabaseManager, + DatabaseConfig, + AppConfig, ) from meeseeql.tools.table_summary import ( table_summary, @@ -150,3 +152,49 @@ async def test_table_summary_case_insensitive(db_manager): # Test that non-existent table still raises error with pytest.raises(TableNotFoundError): await table_summary(db_manager, "chinook_sqlite", "nonexistenttable") + + +async def test_table_summary_respects_allowed_tables(): + config = AppConfig( + databases={ + "test_db": DatabaseConfig( + type="sqlite", + connection_string="sqlite:///tests/Chinook_Sqlite.sqlite", + description="Test DB with table inclusion", + allowed_tables=["Track", "Album"], + ) + }, + settings={}, + ) + db_manager = DatabaseManager(config) + + # Should work for allowed table + result = await table_summary(db_manager, "test_db", "Track") + assert result.table == "main.Track" + + # Should raise error for disallowed table + with pytest.raises(TableNotFoundError, match="not in the allowed list"): + await table_summary(db_manager, "test_db", "Artist") + + +async def test_table_summary_respects_disallowed_tables(): + config = AppConfig( + databases={ + "test_db": DatabaseConfig( + type="sqlite", + connection_string="sqlite:///tests/Chinook_Sqlite.sqlite", + description="Test DB with table exclusion", + disallowed_tables=["Track", "Album"], + ) + }, + settings={}, + ) + db_manager = DatabaseManager(config) + + # Should work for non-excluded table + result = await table_summary(db_manager, "test_db", "Artist") + assert result.table == "main.Artist" + + # Should raise error for excluded table + with pytest.raises(TableNotFoundError, match="exluded list"): + await table_summary(db_manager, "test_db", "Track")