diff --git a/.gitignore b/.gitignore index 64de0dd87..9e41a2dd4 100644 --- a/.gitignore +++ b/.gitignore @@ -77,3 +77,5 @@ tools/scripts/profiles/*.prof .beads-credential-key .codex .mypy_worker* +data/ +.antigravitycli/ diff --git a/pyproject.toml b/pyproject.toml index 097e9a51e..60b0033a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -587,6 +587,7 @@ ignore = [ "PLW0603", # Using the global statement to update is discouraged "PLW0108", # Replace lambda expression with a def "RUF067", # ruff - init should only have import statements + "PLW0717", # too many statements in try block ] select = ["ALL"] diff --git a/sqlspec/__init__.py b/sqlspec/__init__.py index 7fa64e824..47bf7fa97 100644 --- a/sqlspec/__init__.py +++ b/sqlspec/__init__.py @@ -89,8 +89,6 @@ from sqlspec.typing import ConnectionT, PoolT, SchemaT, StatementParameters, SupportedSchemaModel from sqlspec.utils.logging import suppress_erroneous_sqlglot_log_messages -suppress_erroneous_sqlglot_log_messages() - __all__ = ( "SQL", "ArrowResult", @@ -165,3 +163,5 @@ "typing", "utils", ) + +suppress_erroneous_sqlglot_log_messages() diff --git a/sqlspec/_typing.py b/sqlspec/_typing.py index bfd920acf..4a1419e98 100644 --- a/sqlspec/_typing.py +++ b/sqlspec/_typing.py @@ -16,6 +16,89 @@ from sqlspec.utils.module_loader import dependency_flag, module_available +__all__ = ( + "ALLOYDB_CONNECTOR_INSTALLED", + "ATTRS_INSTALLED", + "CATTRS_INSTALLED", + "CLOUD_SQL_CONNECTOR_INSTALLED", + "FSSPEC_INSTALLED", + "LITESTAR_INSTALLED", + "MSGSPEC_INSTALLED", + "NANOID_INSTALLED", + "NUMPY_INSTALLED", + "OBSTORE_INSTALLED", + "OPENTELEMETRY_INSTALLED", + "ORJSON_INSTALLED", + "PANDAS_INSTALLED", + "PGVECTOR_INSTALLED", + "POLARS_INSTALLED", + "PROMETHEUS_INSTALLED", + "PYARROW_INSTALLED", + "PYDANTIC_INSTALLED", + "UNSET", + "UNSET_STUB", + "UUID_UTILS_INSTALLED", + "ArrowRecordBatch", + "ArrowRecordBatchReader", + "ArrowRecordBatchReaderProtocol", + "ArrowRecordBatchResult", + "ArrowSchema", + "ArrowSchemaProtocol", + "ArrowTable", + "ArrowTableResult", + "AttrsInstance", + "AttrsInstanceStub", + "BaseModel", + "BaseModelStub", + "Counter", + "DTOData", + "DTODataStub", + "DataclassProtocol", + "Empty", + "EmptyEnum", + "EmptyType", + "FailFast", + "FailFastStub", + "Gauge", + "Histogram", + "NumpyArray", + "NumpyArrayStub", + "PandasDataFrame", + "PandasDataFrameProtocol", + "PolarsDataFrame", + "PolarsDataFrameProtocol", + "Span", + "Status", + "StatusCode", + "Struct", + "StructStub", + "T", + "T_co", + "Tracer", + "TypeAdapter", + "TypeAdapterStub", + "UnsetType", + "UnsetTypeStub", + "attrs_asdict", + "attrs_asdict_stub", + "attrs_define", + "attrs_define_stub", + "attrs_field", + "attrs_field_stub", + "attrs_fields", + "attrs_fields_stub", + "attrs_has", + "attrs_has_stub", + "cattrs_structure", + "cattrs_unstructure", + "convert", + "convert_stub", + "module_available", + "msgspec_fields", + "msgspec_fields_stub", + "trace", +) + @runtime_checkable class DataclassProtocol(Protocol): @@ -630,85 +713,3 @@ def labels(self, *labelvalues: str, **labelkwargs: str) -> _MetricInstance: ALLOYDB_CONNECTOR_INSTALLED = dependency_flag("google.cloud.alloydb.connector") NANOID_INSTALLED = dependency_flag("fastnanoid") UUID_UTILS_INSTALLED = dependency_flag("uuid_utils") -__all__ = ( - "ALLOYDB_CONNECTOR_INSTALLED", - "ATTRS_INSTALLED", - "CATTRS_INSTALLED", - "CLOUD_SQL_CONNECTOR_INSTALLED", - "FSSPEC_INSTALLED", - "LITESTAR_INSTALLED", - "MSGSPEC_INSTALLED", - "NANOID_INSTALLED", - "NUMPY_INSTALLED", - "OBSTORE_INSTALLED", - "OPENTELEMETRY_INSTALLED", - "ORJSON_INSTALLED", - "PANDAS_INSTALLED", - "PGVECTOR_INSTALLED", - "POLARS_INSTALLED", - "PROMETHEUS_INSTALLED", - "PYARROW_INSTALLED", - "PYDANTIC_INSTALLED", - "UNSET", - "UNSET_STUB", - "UUID_UTILS_INSTALLED", - "ArrowRecordBatch", - "ArrowRecordBatchReader", - "ArrowRecordBatchReaderProtocol", - "ArrowRecordBatchResult", - "ArrowSchema", - "ArrowSchemaProtocol", - "ArrowTable", - "ArrowTableResult", - "AttrsInstance", - "AttrsInstanceStub", - "BaseModel", - "BaseModelStub", - "Counter", - "DTOData", - "DTODataStub", - "DataclassProtocol", - "Empty", - "EmptyEnum", - "EmptyType", - "FailFast", - "FailFastStub", - "Gauge", - "Histogram", - "NumpyArray", - "NumpyArrayStub", - "PandasDataFrame", - "PandasDataFrameProtocol", - "PolarsDataFrame", - "PolarsDataFrameProtocol", - "Span", - "Status", - "StatusCode", - "Struct", - "StructStub", - "T", - "T_co", - "Tracer", - "TypeAdapter", - "TypeAdapterStub", - "UnsetType", - "UnsetTypeStub", - "attrs_asdict", - "attrs_asdict_stub", - "attrs_define", - "attrs_define_stub", - "attrs_field", - "attrs_field_stub", - "attrs_fields", - "attrs_fields_stub", - "attrs_has", - "attrs_has_stub", - "cattrs_structure", - "cattrs_unstructure", - "convert", - "convert_stub", - "module_available", - "msgspec_fields", - "msgspec_fields_stub", - "trace", -) diff --git a/sqlspec/adapters/adbc/_typing.py b/sqlspec/adapters/adbc/_typing.py index df9982cfc..3880f82d2 100644 --- a/sqlspec/adapters/adbc/_typing.py +++ b/sqlspec/adapters/adbc/_typing.py @@ -28,6 +28,8 @@ AdbcConnection = _AdbcConnection AdbcRawCursor = _AdbcRawCursor +__all__ = ("AdbcConnection", "AdbcCursor", "AdbcRawCursor", "AdbcSessionContext") + class AdbcCursor: """Context manager for cursor management.""" @@ -106,6 +108,3 @@ def __exit__( self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("AdbcConnection", "AdbcCursor", "AdbcRawCursor", "AdbcSessionContext") diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 1b1166f29..c5d5041a9 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -14,9 +14,10 @@ from sqlspec.adapters.adbc.config import AdbcConfig from sqlspec.extensions.adk import MemoryRecord +__all__ = ("AdbcADKMemoryStore", "AdbcADKStore") + logger = get_logger("sqlspec.adapters.adbc.adk.store") -__all__ = ("AdbcADKMemoryStore", "AdbcADKStore") DIALECT_POSTGRESQL: Final = "postgresql" DIALECT_SQLITE: Final = "sqlite" diff --git a/sqlspec/adapters/adbc/core.py b/sqlspec/adapters/adbc/core.py index e1bf88c88..4d120b034 100644 --- a/sqlspec/adapters/adbc/core.py +++ b/sqlspec/adapters/adbc/core.py @@ -437,7 +437,7 @@ def handle_postgres_rollback(dialect: str, cursor: Any, logger: Any | None = Non cursor: Database cursor to execute rollback. logger: Optional logger for diagnostics. """ - if dialect != "postgres": + if not is_postgres_dialect(dialect): return try: cursor.execute("ROLLBACK") @@ -457,7 +457,7 @@ def normalize_postgres_empty_parameters(dialect: str, parameters: Any) -> Any: Returns: Normalized parameter payload. """ - if dialect == "postgres" and isinstance(parameters, dict) and not parameters: + if is_postgres_dialect(dialect) and isinstance(parameters, dict) and not parameters: return None return parameters diff --git a/sqlspec/adapters/adbc/data_dictionary.py b/sqlspec/adapters/adbc/data_dictionary.py index 4124f1ea4..7e9442195 100644 --- a/sqlspec/adapters/adbc/data_dictionary.py +++ b/sqlspec/adapters/adbc/data_dictionary.py @@ -23,12 +23,12 @@ from sqlspec.exceptions import SQLFileNotFoundError from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -__all__ = ("AdbcDataDictionary",) - if TYPE_CHECKING: from sqlspec.adapters.adbc.driver import AdbcDriver from sqlspec.core import SQL +__all__ = ("AdbcDataDictionary",) + @mypyc_attr(allow_interpreted_subclasses=True, native_class=False) class AdbcDataDictionary(SyncDataDictionaryBase): diff --git a/sqlspec/adapters/aiomysql/_typing.py b/sqlspec/adapters/aiomysql/_typing.py index 88979bab9..cc153f69a 100644 --- a/sqlspec/adapters/aiomysql/_typing.py +++ b/sqlspec/adapters/aiomysql/_typing.py @@ -39,6 +39,15 @@ def close(self) -> Any: ... AiomysqlDictCursor = _AiomysqlDictCursor AiomysqlPool = _AiomysqlPool +__all__ = ( + "AiomysqlConnection", + "AiomysqlCursor", + "AiomysqlDictCursor", + "AiomysqlPool", + "AiomysqlRawCursor", + "AiomysqlSessionContext", +) + class AiomysqlCursor: """Context manager for aiomysql cursor operations. @@ -125,13 +134,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "AiomysqlConnection", - "AiomysqlCursor", - "AiomysqlDictCursor", - "AiomysqlPool", - "AiomysqlRawCursor", - "AiomysqlSessionContext", -) diff --git a/sqlspec/adapters/aiomysql/driver.py b/sqlspec/adapters/aiomysql/driver.py index 9da2da084..a72167a1d 100644 --- a/sqlspec/adapters/aiomysql/driver.py +++ b/sqlspec/adapters/aiomysql/driver.py @@ -184,6 +184,10 @@ async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "Execu sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if prepared_parameters is not None and len(statements) > 1: + msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements" + raise SQLSpecError(msg) + successful_count = 0 last_cursor = cursor diff --git a/sqlspec/adapters/aiomysql/litestar/store.py b/sqlspec/adapters/aiomysql/litestar/store.py index 8aeb26725..b595f2c99 100644 --- a/sqlspec/adapters/aiomysql/litestar/store.py +++ b/sqlspec/adapters/aiomysql/litestar/store.py @@ -12,9 +12,10 @@ if TYPE_CHECKING: from sqlspec.adapters.aiomysql.config import AiomysqlConfig +__all__ = ("AiomysqlStore",) + logger = get_logger("sqlspec.adapters.aiomysql.litestar.store") -__all__ = ("AiomysqlStore",) MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 diff --git a/sqlspec/adapters/aiosqlite/_typing.py b/sqlspec/adapters/aiosqlite/_typing.py index 8a4635e3e..53d4f40ce 100644 --- a/sqlspec/adapters/aiosqlite/_typing.py +++ b/sqlspec/adapters/aiosqlite/_typing.py @@ -27,6 +27,8 @@ AiosqliteConnection = _AiosqliteConnection AiosqliteRawCursor = aiosqlite.Cursor +__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteRawCursor", "AiosqliteSessionContext") + class AiosqliteCursor: """Async context manager for AIOSQLite cursors.""" @@ -44,8 +46,6 @@ async def __aenter__(self) -> "AiosqliteRawCursor": async def __aexit__( self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> None: - if exc_type is not None: - return if self.cursor is not None: with contextlib.suppress(Exception): await self.cursor.close() @@ -104,6 +104,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteRawCursor", "AiosqliteSessionContext") diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index d8012f258..c361e2087 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -12,13 +12,13 @@ from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.extensions.adk import MemoryRecord +__all__ = ("AiosqliteADKMemoryStore", "AiosqliteADKStore") + SECONDS_PER_DAY = 86400.0 JULIAN_EPOCH = 2440587.5 SQLITE_TABLE_NOT_FOUND_ERROR: Final = "no such table" -__all__ = ("AiosqliteADKMemoryStore", "AiosqliteADKStore") - def _datetime_to_julian(dt: datetime) -> float: """Convert datetime to Julian Day number for SQLite storage. diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 25f7fd6fc..1c5c84761 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -14,7 +14,7 @@ AiosqlitePoolConnection, AiosqlitePoolConnectionContext, ) -from sqlspec.adapters.sqlite.type_converter import register_type_handlers +from sqlspec.adapters.aiosqlite.type_converter import register_type_handlers from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.driver._async import AsyncPoolConnectionContext, AsyncPoolSessionFactory from sqlspec.utils.config_tools import normalize_connection_config diff --git a/sqlspec/adapters/aiosqlite/core.py b/sqlspec/adapters/aiosqlite/core.py index 17c22828b..9b35b7dd4 100644 --- a/sqlspec/adapters/aiosqlite/core.py +++ b/sqlspec/adapters/aiosqlite/core.py @@ -9,10 +9,13 @@ CheckViolationError, DatabaseConnectionError, DataError, + DeadlockError, ForeignKeyViolationError, IntegrityError, NotNullViolationError, OperationalError, + PermissionDeniedError, + QueryTimeoutError, SQLParsingError, SQLSpecError, UniqueViolationError, @@ -53,6 +56,11 @@ SQLITE_CANTOPEN_CODE = 14 SQLITE_IOERR_CODE = 10 SQLITE_MISMATCH_CODE = 20 +SQLITE_BUSY_CODE = 5 +SQLITE_LOCKED_CODE = 6 +SQLITE_INTERRUPT_CODE = 9 +SQLITE_PERM_CODE = 3 +SQLITE_READONLY_CODE = 8 def _bool_to_int(value: bool) -> int: @@ -209,11 +217,29 @@ def create_mapped_exception(error: BaseException) -> SQLSpecError: error_name = None error_msg = str(error).lower() - if "locked" in error_msg: - msg = f"AIOSQLite database locked: {error}. Consider enabling WAL mode or reducing concurrency." - exc = SQLSpecError(msg) - exc.__cause__ = error # pyright: ignore[reportAttributeAccessIssue] - return exc + # Check for busy/locked conditions first (deadlock-like scenarios in SQLite) + # SQLITE_BUSY means another process has the database locked + # SQLITE_LOCKED means another connection has the table/rows locked + if error_code == SQLITE_BUSY_CODE or error_name == "SQLITE_BUSY": + return _create_aiosqlite_error(error, error_code, DeadlockError, "database busy") + if error_code == SQLITE_LOCKED_CODE or error_name == "SQLITE_LOCKED": + return _create_aiosqlite_error(error, error_code, DeadlockError, "database locked") + if "locked" in error_msg or "busy" in error_msg: + return _create_aiosqlite_error(error, error_code or 0, DeadlockError, "database locked") + + # Query interruption (timeout-like behavior) + if error_code == SQLITE_INTERRUPT_CODE or error_name == "SQLITE_INTERRUPT": + return _create_aiosqlite_error(error, error_code, QueryTimeoutError, "query interrupted") + if "interrupt" in error_msg: + return _create_aiosqlite_error(error, error_code or 0, QueryTimeoutError, "query interrupted") + + # Permission errors + if error_code == SQLITE_PERM_CODE or error_name == "SQLITE_PERM": + return _create_aiosqlite_error(error, error_code, PermissionDeniedError, "permission denied") + if error_code == SQLITE_READONLY_CODE or error_name == "SQLITE_READONLY": + return _create_aiosqlite_error(error, error_code, PermissionDeniedError, "database is read-only") + if "permission denied" in error_msg or "readonly" in error_msg: + return _create_aiosqlite_error(error, error_code or 0, PermissionDeniedError, "permission denied") if not error_code: if "unique constraint" in error_msg: @@ -255,9 +281,9 @@ def build_profile() -> "DriverParameterProfile": return DriverParameterProfile( name="AIOSQLite", default_style=ParameterStyle.QMARK, - supported_styles={ParameterStyle.QMARK}, + supported_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON}, default_execution_style=ParameterStyle.QMARK, - supported_execution_styles={ParameterStyle.QMARK}, + supported_execution_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON}, has_native_list_expansion=False, preserve_parameter_format=True, needs_static_script_compilation=False, @@ -286,7 +312,10 @@ def build_statement_config( deserializer = json_deserializer or from_json profile = driver_profile return build_statement_config_from_profile( - profile, statement_overrides={"dialect": "sqlite"}, json_serializer=serializer, json_deserializer=deserializer + profile, + statement_overrides={"dialect": "sqlite", "enable_parameter_type_wrapping": False}, + json_serializer=serializer, + json_deserializer=deserializer, ) diff --git a/sqlspec/adapters/aiosqlite/litestar/store.py b/sqlspec/adapters/aiosqlite/litestar/store.py index bb0b21933..421f586bd 100644 --- a/sqlspec/adapters/aiosqlite/litestar/store.py +++ b/sqlspec/adapters/aiosqlite/litestar/store.py @@ -8,12 +8,12 @@ if TYPE_CHECKING: from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +__all__ = ("AiosqliteStore",) + SECONDS_PER_DAY = 86400.0 JULIAN_EPOCH = 2440587.5 -__all__ = ("AiosqliteStore",) - class AiosqliteStore(BaseSQLSpecStore["AiosqliteConfig"]): """SQLite session store using AioSQLite driver. diff --git a/sqlspec/adapters/aiosqlite/pool.py b/sqlspec/adapters/aiosqlite/pool.py index fae8ccf1a..61bbd1d36 100644 --- a/sqlspec/adapters/aiosqlite/pool.py +++ b/sqlspec/adapters/aiosqlite/pool.py @@ -187,6 +187,7 @@ class AiosqliteConnectionPool: "_pool_size", "_queue_instance", "_wal_initialized", + "_wal_initialized_event", "_warmed", ) @@ -230,6 +231,7 @@ def __init__( self._queue_instance: asyncio.Queue[AiosqlitePoolConnection] | None = None self._lock_instance: asyncio.Lock | None = None self._closed_event_instance: asyncio.Event | None = None + self._wal_initialized_event: asyncio.Event | None = None @property def _queue(self) -> "asyncio.Queue[AiosqlitePoolConnection]": @@ -252,6 +254,13 @@ def _closed_event(self) -> asyncio.Event: self._closed_event_instance = asyncio.Event() return self._closed_event_instance + @property + def _wal_ready_event(self) -> asyncio.Event: + """Lazy initialization of WAL-ready event for Python 3.9 compatibility.""" + if self._wal_initialized_event is None: + self._wal_initialized_event = asyncio.Event() + return self._wal_initialized_event + @property def is_closed(self) -> bool: """Check if pool is closed. @@ -395,6 +404,7 @@ async def _create_connection(self) -> AiosqlitePoolConnection: if is_shared_cache: self._wal_initialized = True + self._wal_ready_event.set() except Exception: log_with_context( @@ -652,8 +662,6 @@ async def acquire(self) -> AiosqlitePoolConnection: msg = f"Connection acquisition timed out after {self._connect_timeout}s" raise AiosqliteConnectTimeoutError(msg) from e - if not self._wal_initialized and "cache=shared" in str(self._connection_parameters.get("database", "")): - await asyncio.sleep(0.01) return connection async def release(self, connection: AiosqlitePoolConnection) -> None: diff --git a/sqlspec/adapters/aiosqlite/type_converter.py b/sqlspec/adapters/aiosqlite/type_converter.py new file mode 100644 index 000000000..02c80abd6 --- /dev/null +++ b/sqlspec/adapters/aiosqlite/type_converter.py @@ -0,0 +1,115 @@ +# Keep in sync with sqlspec/adapters/sqlite/type_converter.py +"""SQLite custom type handlers for optional JSON and type conversion support. + +Provides registration functions for SQLite's adapter/converter system to enable +custom type handling. All handlers are optional and must be explicitly enabled +via SqliteDriverFeatures configuration. + +All functions are designed for mypyc compilation using functools.partial +instead of lambdas for adapter registration. +""" + +import json +import sqlite3 +from functools import partial +from typing import TYPE_CHECKING, Any + +from sqlspec.utils.logging import get_logger + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = ("json_adapter", "json_converter", "register_type_handlers", "unregister_type_handlers") + +logger = get_logger(__name__) + +DEFAULT_JSON_TYPE = "JSON" + + +def json_adapter(value: Any, serializer: "Callable[[Any], str] | None" = None) -> str: + """Convert Python dict/list to JSON string for SQLite storage. + + Args: + value: Python dict or list to serialize. + serializer: Optional JSON serializer callable. Defaults to standard json.dumps. + + Returns: + JSON string representation. + """ + if serializer is None: + return json.dumps(value, ensure_ascii=False) + return serializer(value) + + +def json_converter(value: bytes, deserializer: "Callable[[str], Any] | None" = None) -> Any: + """Convert JSON string from SQLite to Python dict/list. + + Args: + value: UTF-8 encoded JSON bytes from SQLite. + deserializer: Optional JSON deserializer callable. Defaults to standard json.loads. + + Returns: + Deserialized Python object (dict or list). + """ + if deserializer is None: + return json.loads(value.decode("utf-8")) + return deserializer(value.decode("utf-8")) + + +def _make_json_adapter(serializer: "Callable[[Any], str] | None") -> "Callable[[Any], str]": + """Create a JSON adapter function with bound serializer. + + This is a module-level factory to avoid lambda closures which are + problematic for mypyc compilation. + + Args: + serializer: Optional JSON serializer callable. + + Returns: + Adapter function ready for sqlite3.register_adapter. + """ + return partial(json_adapter, serializer=serializer) + + +def _make_json_converter(deserializer: "Callable[[str], Any] | None") -> "Callable[[bytes], Any]": + """Create a JSON converter function with bound deserializer. + + This is a module-level factory to avoid lambda closures which are + problematic for mypyc compilation. + + Args: + deserializer: Optional JSON deserializer callable. + + Returns: + Converter function ready for sqlite3.register_converter. + """ + return partial(json_converter, deserializer=deserializer) + + +def register_type_handlers( + json_serializer: "Callable[[Any], str] | None" = None, json_deserializer: "Callable[[str], Any] | None" = None +) -> None: + """Register custom type adapters and converters with sqlite3 module. + + This function registers handlers globally for the sqlite3 module. It should be + called once during application initialization if custom type handling is needed. + + Args: + json_serializer: Optional custom JSON serializer (e.g., orjson.dumps). + json_deserializer: Optional custom JSON deserializer (e.g., orjson.loads). + """ + dict_adapter = _make_json_adapter(json_serializer) + list_adapter = _make_json_adapter(json_serializer) + converter = _make_json_converter(json_deserializer) + + sqlite3.register_adapter(dict, dict_adapter) + sqlite3.register_adapter(list, list_adapter) + sqlite3.register_converter(DEFAULT_JSON_TYPE, converter) + + +def unregister_type_handlers() -> None: + """Unregister custom type handlers from sqlite3 module. + + Note: sqlite3 module does not provide an official unregister API, so this + function is a no-op placeholder for API consistency with other adapters. + """ diff --git a/sqlspec/adapters/arrow_odbc/_typing.py b/sqlspec/adapters/arrow_odbc/_typing.py index ee974f4c4..a0d9d4a31 100644 --- a/sqlspec/adapters/arrow_odbc/_typing.py +++ b/sqlspec/adapters/arrow_odbc/_typing.py @@ -21,6 +21,15 @@ ArrowOdbcConnection = _arrow_odbc.Connection ArrowOdbcRawCursor = _arrow_odbc.Connection +__all__ = ( + "ArrowOdbcConnection", + "ArrowOdbcCursor", + "ArrowOdbcError", + "ArrowOdbcRawCursor", + "ArrowOdbcSessionContext", + "arrow_odbc_connect", +) + ArrowOdbcError: "type[Exception]" = getattr(_arrow_odbc, "Error", Exception) @@ -84,13 +93,3 @@ def __exit__( self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "ArrowOdbcConnection", - "ArrowOdbcCursor", - "ArrowOdbcError", - "ArrowOdbcRawCursor", - "ArrowOdbcSessionContext", - "arrow_odbc_connect", -) diff --git a/sqlspec/adapters/arrow_odbc/config.py b/sqlspec/adapters/arrow_odbc/config.py index cad61ff2c..7258fbb7a 100644 --- a/sqlspec/adapters/arrow_odbc/config.py +++ b/sqlspec/adapters/arrow_odbc/config.py @@ -127,16 +127,16 @@ def __init__( **kwargs: Any, ) -> None: """Initialize arrow-odbc configuration.""" - self.connection_config = normalize_connection_config(connection_config) + normalized = normalize_connection_config(connection_config) features = apply_driver_features(dict(driver_features or {})) - connection_string = self.connection_config.get("connection_string") + connection_string = normalized.get("connection_string") if connection_string is not None: features.setdefault("connection_string", str(connection_string)) - elif self.connection_config.get("driver") is not None: - features.setdefault("dbms_name", str(self.connection_config["driver"])) + elif normalized.get("driver") is not None: + features.setdefault("dbms_name", str(normalized["driver"])) super().__init__( - connection_config=self.connection_config, + connection_config=normalized, connection_instance=connection_instance, migration_config=migration_config, statement_config=statement_config or default_statement_config, diff --git a/sqlspec/adapters/asyncmy/_typing.py b/sqlspec/adapters/asyncmy/_typing.py index f1e601c66..eda10c513 100644 --- a/sqlspec/adapters/asyncmy/_typing.py +++ b/sqlspec/adapters/asyncmy/_typing.py @@ -8,6 +8,8 @@ from asyncmy import Connection # pyright: ignore from asyncmy.cursors import Cursor as _AsyncmyCursor # pyright: ignore +from asyncmy.cursors import DictCursor as _AsyncmyDictCursor # pyright: ignore +from asyncmy.pool import Pool as _AsyncmyPool # pyright: ignore if TYPE_CHECKING: from collections.abc import Callable @@ -27,12 +29,25 @@ async def rollback(self) -> Any: ... async def close(self) -> Any: ... AsyncmyConnection: TypeAlias = AsyncmyConnectionProtocol + AsyncmyDictCursor: TypeAlias = _AsyncmyDictCursor + AsyncmyPool: TypeAlias = _AsyncmyPool AsyncmyRawCursor: TypeAlias = _AsyncmyCursor if not TYPE_CHECKING: AsyncmyConnection = Connection + AsyncmyDictCursor = _AsyncmyDictCursor + AsyncmyPool = _AsyncmyPool AsyncmyRawCursor = _AsyncmyCursor +__all__ = ( + "AsyncmyConnection", + "AsyncmyCursor", + "AsyncmyDictCursor", + "AsyncmyPool", + "AsyncmyRawCursor", + "AsyncmySessionContext", +) + class AsyncmyCursor: """Context manager for AsyncMy cursor operations. @@ -108,6 +123,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("AsyncmyConnection", "AsyncmyCursor", "AsyncmyRawCursor", "AsyncmySessionContext") diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 6a1d40f4b..63c12fbd4 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -4,12 +4,17 @@ from weakref import WeakSet import asyncmy -from asyncmy.cursors import Cursor, DictCursor # pyright: ignore -from asyncmy.pool import Pool as AsyncmyPool # pyright: ignore from mypy_extensions import mypyc_attr from typing_extensions import NotRequired -from sqlspec.adapters.asyncmy._typing import AsyncmyConnection, AsyncmyCursor, AsyncmySessionContext +from sqlspec.adapters.asyncmy._typing import ( + AsyncmyConnection, + AsyncmyCursor, + AsyncmyDictCursor, + AsyncmyPool, + AsyncmyRawCursor, + AsyncmySessionContext, +) from sqlspec.adapters.asyncmy.core import apply_driver_features, default_statement_config from sqlspec.adapters.asyncmy.driver import AsyncmyDriver, AsyncmyExceptionHandler from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs @@ -21,9 +26,6 @@ from collections.abc import Awaitable, Callable from types import TracebackType - from asyncmy.cursors import Cursor, DictCursor # pyright: ignore - from asyncmy.pool import Pool # pyright: ignore - from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -49,7 +51,7 @@ class AsyncmyConnectionParams(TypedDict): ssl: NotRequired[Any] sql_mode: NotRequired[str] init_command: NotRequired[str] - cursor_class: NotRequired[type["Cursor"] | type["DictCursor"]] + cursor_class: NotRequired[type["AsyncmyRawCursor"] | type["AsyncmyDictCursor"]] extra: NotRequired["dict[str, Any]"] @@ -269,7 +271,7 @@ async def create_connection(self) -> AsyncmyConnection: await self._ensure_connection_initialized(connection) return connection - async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": + async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncmyPool": """Provide async pool instance. Returns: @@ -292,11 +294,13 @@ def get_signature_namespace(self) -> "dict[str, Any]": "AsyncmyConnection": AsyncmyConnection, "AsyncmyConnectionParams": AsyncmyConnectionParams, "AsyncmyCursor": AsyncmyCursor, + "AsyncmyDictCursor": AsyncmyDictCursor, "AsyncmyDriver": AsyncmyDriver, "AsyncmyDriverFeatures": AsyncmyDriverFeatures, "AsyncmyExceptionHandler": AsyncmyExceptionHandler, "AsyncmyPool": AsyncmyPool, "AsyncmyPoolParams": AsyncmyPoolParams, + "AsyncmyRawCursor": AsyncmyRawCursor, "AsyncmySessionContext": AsyncmySessionContext, }) return namespace diff --git a/sqlspec/adapters/asyncmy/core.py b/sqlspec/adapters/asyncmy/core.py index 7fd77c206..593cd8d11 100644 --- a/sqlspec/adapters/asyncmy/core.py +++ b/sqlspec/adapters/asyncmy/core.py @@ -168,7 +168,7 @@ def build_profile() -> "DriverParameterProfile": """Create the AsyncMy driver parameter profile.""" coercions: dict[type, Callable[[Any], Any]] = {bool: _bool_to_int, **build_uuid_coercions()} return DriverParameterProfile( - name="AsyncMy", + name="asyncmy", default_style=ParameterStyle.QMARK, supported_styles={ParameterStyle.QMARK}, default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT, diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 375e55a7f..fb9573cd4 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -187,6 +187,10 @@ async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "Execu sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if prepared_parameters is not None and len(statements) > 1: + msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements" + raise SQLSpecError(msg) + successful_count = 0 last_cursor = cursor diff --git a/sqlspec/adapters/asyncmy/litestar/store.py b/sqlspec/adapters/asyncmy/litestar/store.py index 85636ca37..5b4e86ad5 100644 --- a/sqlspec/adapters/asyncmy/litestar/store.py +++ b/sqlspec/adapters/asyncmy/litestar/store.py @@ -9,9 +9,10 @@ if TYPE_CHECKING: from sqlspec.adapters.asyncmy.config import AsyncmyConfig +__all__ = ("AsyncmyStore",) + logger = get_logger("sqlspec.adapters.asyncmy.litestar.store") -__all__ = ("AsyncmyStore",) MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 diff --git a/sqlspec/adapters/asyncpg/_typing.py b/sqlspec/adapters/asyncpg/_typing.py index d2d8b2000..443297842 100644 --- a/sqlspec/adapters/asyncpg/_typing.py +++ b/sqlspec/adapters/asyncpg/_typing.py @@ -29,6 +29,8 @@ AsyncpgPool = Pool AsyncpgPreparedStatement = PreparedStatement +__all__ = ("AsyncpgConnection", "AsyncpgCursor", "AsyncpgPool", "AsyncpgPreparedStatement", "AsyncpgSessionContext") + class AsyncpgCursor: """Context manager for AsyncPG cursor management.""" @@ -98,6 +100,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("AsyncpgConnection", "AsyncpgCursor", "AsyncpgPool", "AsyncpgPreparedStatement", "AsyncpgSessionContext") diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index cbd48a060..27247da9f 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -258,7 +258,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async """ driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver - connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment] + connection_type: "ClassVar[type[AsyncpgConnection]]" = AsyncpgConnection # type: ignore[assignment] supports_transactional_ddl: "ClassVar[bool]" = True supports_migration_schemas: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 2dca9efe2..a17f3e23c 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -1,8 +1,9 @@ """AsyncPG PostgreSQL driver implementation for async PostgreSQL operations.""" +import re from collections import OrderedDict from io import BytesIO -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Final, cast import asyncpg @@ -51,9 +52,21 @@ __all__ = ("AsyncpgCursor", "AsyncpgDriver", "AsyncpgExceptionHandler", "AsyncpgSessionContext") +_COPY_FROM_STDIN_RE: re.Pattern[str] = re.compile( + r'COPY\s+((?:"[^"]+"|\w+)(?:\.(?:"[^"]+"|\w+))?)(?:\s*\([^)]*\))?\s+FROM\s+STDIN', re.IGNORECASE +) +_QUALIFIED_TABLE_NAME_PARTS: Final = 2 + logger = get_logger("sqlspec.adapters.asyncpg") +def _split_copy_table_name(raw_name: str) -> "tuple[str | None, str]": + parts = raw_name.split(".", 1) + if len(parts) == _QUALIFIED_TABLE_NAME_PARTS: + return parts[0].strip('"'), parts[1].strip('"') + return None, raw_name.strip('"') + + class AsyncpgExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling AsyncPG database exceptions. @@ -473,7 +486,8 @@ async def _get_prepared_statement(self, sql: str) -> "AsyncpgPreparedStatement": async def _handle_copy_operation(self, cursor: "AsyncpgConnection", statement: "SQL") -> None: """Handle PostgreSQL COPY operations. - Supports both COPY FROM STDIN and COPY TO STDOUT operations. + Supports COPY FROM STDIN data import and falls through to cursor.execute + for COPY TO STDOUT and file-based COPY variants. Args: cursor: AsyncPG connection object @@ -483,10 +497,9 @@ async def _handle_copy_operation(self, cursor: "AsyncpgConnection", statement: " execution_args = statement.statement_config.execution_args metadata: dict[str, Any] = dict(execution_args) if execution_args else {} sql_text, _ = self._get_compiled_sql(statement, statement.statement_config) - sql_upper = sql_text.upper() copy_data = metadata.get("postgres_copy_data") - if copy_data and is_copy_from_operation(statement.operation_type) and "FROM STDIN" in sql_upper: + if copy_data is not None and is_copy_from_operation(statement.operation_type): if isinstance(copy_data, dict): data_str = ( str(next(iter(copy_data.values()))) @@ -499,7 +512,30 @@ async def _handle_copy_operation(self, cursor: "AsyncpgConnection", statement: " data_str = str(copy_data) data_io = BytesIO(data_str.encode("utf-8")) - await cursor.copy_from_query(sql_text, output=data_io) + table_name = cast("str | None", metadata.get("postgres_copy_table")) + schema_name: str | None = None + + if table_name is None: + match = _COPY_FROM_STDIN_RE.search(sql_text) + if match: + table_name = match.group(1) + + if table_name is None: + msg = "COPY FROM STDIN requires a table name or postgres_copy_table execution argument" + raise SQLSpecError(msg) + + schema_name, table_name = _split_copy_table_name(table_name) + copy_kwargs: dict[str, Any] = {"source": data_io} + if schema_name is not None: + copy_kwargs["schema_name"] = schema_name + if "postgres_copy_columns" in metadata: + copy_kwargs["columns"] = metadata["postgres_copy_columns"] + if "postgres_copy_format" in metadata: + copy_kwargs["format"] = metadata["postgres_copy_format"] + if "postgres_copy_delimiter" in metadata: + copy_kwargs["delimiter"] = metadata["postgres_copy_delimiter"] + + await cursor.copy_to_table(table_name, **copy_kwargs) return await cursor.execute(sql_text) diff --git a/sqlspec/adapters/asyncpg/events/_hub.py b/sqlspec/adapters/asyncpg/events/_hub.py index 613a01c84..95c33a60b 100644 --- a/sqlspec/adapters/asyncpg/events/_hub.py +++ b/sqlspec/adapters/asyncpg/events/_hub.py @@ -24,10 +24,10 @@ from sqlspec.adapters.asyncpg.config import AsyncpgConfig -logger = get_logger("sqlspec.adapters.asyncpg.events.hub") - __all__ = ("AsyncpgListenerHub",) +logger = get_logger("sqlspec.adapters.asyncpg.events.hub") + class AsyncpgListenerHub: """Per-channel persistent listener for asyncpg-based event backends.""" diff --git a/sqlspec/adapters/asyncpg/events/backend.py b/sqlspec/adapters/asyncpg/events/backend.py index 2ccf8d533..a2c16e328 100644 --- a/sqlspec/adapters/asyncpg/events/backend.py +++ b/sqlspec/adapters/asyncpg/events/backend.py @@ -22,10 +22,10 @@ if TYPE_CHECKING: from sqlspec.adapters.asyncpg.config import AsyncpgConfig -logger = get_logger("sqlspec.events.postgres") - __all__ = ("AsyncpgEventsBackend", "AsyncpgHybridEventsBackend", "create_event_backend") +logger = get_logger("sqlspec.events.postgres") + def _extract_event_id(payload: "str | None") -> "str | None": if not payload: diff --git a/sqlspec/adapters/bigquery/_typing.py b/sqlspec/adapters/bigquery/_typing.py index c660e58c7..54db18cdc 100644 --- a/sqlspec/adapters/bigquery/_typing.py +++ b/sqlspec/adapters/bigquery/_typing.py @@ -24,6 +24,8 @@ BigQueryConnection = Client BigQueryParam = ArrayQueryParameter | ScalarQueryParameter +__all__ = ("BigQueryConnection", "BigQueryCursor", "BigQueryParam", "BigQuerySessionContext") + class BigQueryCursor: """BigQuery cursor with resource management.""" @@ -103,6 +105,3 @@ def __exit__( self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("BigQueryConnection", "BigQueryCursor", "BigQueryParam", "BigQuerySessionContext") diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 23d28e591..1bc509d21 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -26,6 +26,8 @@ from sqlspec.core import StatementConfig +__all__ = ("BigQueryConfig", "BigQueryConnectionParams", "BigQueryDriverFeatures") + class BigQueryConnectionParams(TypedDict): """Standard BigQuery connection parameters. @@ -144,9 +146,6 @@ def release_connection(self, _conn: "BigQueryConnection", **kwargs: Any) -> None return None -__all__ = ("BigQueryConfig", "BigQueryConnectionParams", "BigQueryDriverFeatures") - - class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]): """BigQuery configuration. @@ -263,6 +262,7 @@ def _setup_default_job_config(self) -> None: if query_timeout_ms is not None: job_config.job_timeout_ms = query_timeout_ms + # job_timeout_ms intentionally wins when both timeout aliases are configured. job_timeout_ms = self.connection_config.get("job_timeout_ms") if job_timeout_ms is not None: job_config.job_timeout_ms = job_timeout_ms diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index 20bc30dd9..29425ff7b 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -597,7 +597,7 @@ def build_profile() -> "DriverParameterProfile": return DriverParameterProfile( name="BigQuery", default_style=ParameterStyle.NAMED_AT, - supported_styles={ParameterStyle.NAMED_AT, ParameterStyle.QMARK}, + supported_styles={ParameterStyle.NAMED_AT}, default_execution_style=ParameterStyle.NAMED_AT, supported_execution_styles={ParameterStyle.NAMED_AT}, has_native_list_expansion=True, diff --git a/sqlspec/adapters/bigquery/data_dictionary.py b/sqlspec/adapters/bigquery/data_dictionary.py index a37d736ff..dd3f1d1ed 100644 --- a/sqlspec/adapters/bigquery/data_dictionary.py +++ b/sqlspec/adapters/bigquery/data_dictionary.py @@ -11,11 +11,11 @@ from sqlspec.driver import SyncDataDictionaryBase from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -__all__ = ("BigQueryDataDictionary",) - if TYPE_CHECKING: from sqlspec.adapters.bigquery.driver import BigQueryDriver +__all__ = ("BigQueryDataDictionary",) + @mypyc_attr(allow_interpreted_subclasses=True, native_class=False) class BigQueryDataDictionary(SyncDataDictionaryBase): diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 4aea8bcad..e3a3f542c 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -59,10 +59,10 @@ ) from sqlspec.typing import ArrowReturnFormat, StatementParameters -logger = get_logger(__name__) - __all__ = ("BigQueryCursor", "BigQueryDriver", "BigQueryExceptionHandler", "BigQuerySessionContext") +logger = get_logger(__name__) + class BigQueryExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling BigQuery API exceptions. diff --git a/sqlspec/adapters/cockroach_asyncpg/_typing.py b/sqlspec/adapters/cockroach_asyncpg/_typing.py index ff3944a6d..8a1d8345e 100644 --- a/sqlspec/adapters/cockroach_asyncpg/_typing.py +++ b/sqlspec/adapters/cockroach_asyncpg/_typing.py @@ -22,6 +22,8 @@ CockroachAsyncpgConnection = PoolConnectionProxy CockroachAsyncpgPool = Pool +__all__ = ("CockroachAsyncpgConnection", "CockroachAsyncpgPool", "CockroachAsyncpgSessionContext") + class CockroachAsyncpgSessionContext: """Async context manager for CockroachDB AsyncPG sessions.""" @@ -40,7 +42,7 @@ def __init__( self, acquire_connection: "Callable[[], Any]", release_connection: "Callable[[Any], Any]", - statement_config: "StatementConfig", + statement_config: "StatementConfig | Callable[[], StatementConfig]", driver_features: "dict[str, Any]", prepare_driver: "Callable[[CockroachAsyncpgDriver], CockroachAsyncpgDriver]", ) -> None: @@ -56,8 +58,9 @@ async def __aenter__(self) -> "CockroachAsyncpgDriver": from sqlspec.adapters.cockroach_asyncpg.driver import CockroachAsyncpgDriver self._connection = await self._acquire_connection() + statement_config = self._statement_config() if callable(self._statement_config) else self._statement_config self._driver = CockroachAsyncpgDriver( - connection=self._connection, statement_config=self._statement_config, driver_features=self._driver_features + connection=self._connection, statement_config=statement_config, driver_features=self._driver_features ) return self._prepare_driver(self._driver) @@ -68,6 +71,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("CockroachAsyncpgConnection", "CockroachAsyncpgPool", "CockroachAsyncpgSessionContext") diff --git a/sqlspec/adapters/cockroach_asyncpg/config.py b/sqlspec/adapters/cockroach_asyncpg/config.py index 4547be033..423260604 100644 --- a/sqlspec/adapters/cockroach_asyncpg/config.py +++ b/sqlspec/adapters/cockroach_asyncpg/config.py @@ -10,6 +10,8 @@ build_connection_config, default_statement_config, register_json_codecs, + register_pgvector_support, + resolve_runtime_statement_config, ) from sqlspec.adapters.cockroach_asyncpg._typing import ( CockroachAsyncpgConnection, @@ -154,9 +156,11 @@ def __init__( observability_config: "ObservabilityConfig | None" = None, **kwargs: Any, ) -> None: + raw_enable_pgvector = bool(driver_features and driver_features.get("enable_pgvector") is True) connection_config = normalize_connection_config(connection_config) statement_config = statement_config or default_statement_config statement_config, driver_features = apply_driver_features(statement_config, driver_features) + driver_features["enable_pgvector"] = raw_enable_pgvector driver_features.setdefault("enable_auto_retry", True) @@ -192,6 +196,9 @@ async def _init_connection(self, connection: "CockroachAsyncpgConnection") -> No decoder=self.driver_features.get("json_deserializer", from_json), ) + if self.driver_features.get("enable_pgvector") is True: + await register_pgvector_support(connection) + if self._user_connection_hook is not None: await self._user_connection_hook(connection) @@ -224,7 +231,8 @@ def provide_session( return CockroachAsyncpgSessionContext( acquire_connection=factory.acquire_connection, release_connection=factory.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, + statement_config=statement_config + or (lambda: resolve_runtime_statement_config(None, self.statement_config, default_statement_config)), driver_features=driver_features, prepare_driver=self._prepare_driver, ) diff --git a/sqlspec/adapters/cockroach_asyncpg/driver.py b/sqlspec/adapters/cockroach_asyncpg/driver.py index 149687158..e52706f00 100644 --- a/sqlspec/adapters/cockroach_asyncpg/driver.py +++ b/sqlspec/adapters/cockroach_asyncpg/driver.py @@ -106,12 +106,12 @@ async def _dispatch_execute_impl(self, cursor: "CockroachAsyncpgConnection", sta async def _dispatch_execute_many_impl( self, cursor: "CockroachAsyncpgConnection", statement: SQL ) -> "ExecutionResult": - return await super().dispatch_execute_many(cursor, statement) + return await AsyncpgDriver.dispatch_execute_many(self, cursor, statement) async def _dispatch_execute_script_impl( self, cursor: "CockroachAsyncpgConnection", statement: SQL ) -> "ExecutionResult": - return await super().dispatch_execute_script(cursor, statement) + return await AsyncpgDriver.dispatch_execute_script(self, cursor, statement) async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: @@ -120,12 +120,12 @@ async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResul async def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: - return await super().dispatch_execute_many(cursor, statement) + return await self._dispatch_execute_many_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) async def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: - return await super().dispatch_execute_script(cursor, statement) + return await self._dispatch_execute_script_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) def handle_database_exceptions(self) -> "CockroachAsyncpgExceptionHandler": # type: ignore[override] diff --git a/sqlspec/adapters/cockroach_psycopg/_typing.py b/sqlspec/adapters/cockroach_psycopg/_typing.py index 263bc6c07..3d2c714f2 100644 --- a/sqlspec/adapters/cockroach_psycopg/_typing.py +++ b/sqlspec/adapters/cockroach_psycopg/_typing.py @@ -31,6 +31,16 @@ CockroachSyncCursor = Cursor CockroachAsyncCursor = AsyncCursor +__all__ = ( + "CockroachAsyncConnection", + "CockroachAsyncCursor", + "CockroachPsycopgAsyncSessionContext", + "CockroachPsycopgSyncSessionContext", + "CockroachSyncConnection", + "CockroachSyncCursor", + "PsycopgDictRow", +) + class CockroachPsycopgSyncSessionContext: """Sync context manager for CockroachDB psycopg sessions.""" @@ -124,14 +134,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "CockroachAsyncConnection", - "CockroachAsyncCursor", - "CockroachPsycopgAsyncSessionContext", - "CockroachPsycopgSyncSessionContext", - "CockroachSyncConnection", - "CockroachSyncCursor", - "PsycopgDictRow", -) diff --git a/sqlspec/adapters/cockroach_psycopg/config.py b/sqlspec/adapters/cockroach_psycopg/config.py index 3b6644187..1a56289af 100644 --- a/sqlspec/adapters/cockroach_psycopg/config.py +++ b/sqlspec/adapters/cockroach_psycopg/config.py @@ -487,7 +487,7 @@ def provide_session( return CockroachPsycopgAsyncSessionContext( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or build_statement_config(), + statement_config=statement_config or self.statement_config or default_statement_config, driver_features=driver_features, prepare_driver=self._prepare_driver, ) diff --git a/sqlspec/adapters/cockroach_psycopg/driver.py b/sqlspec/adapters/cockroach_psycopg/driver.py index 8a9f09f6b..509b34d83 100644 --- a/sqlspec/adapters/cockroach_psycopg/driver.py +++ b/sqlspec/adapters/cockroach_psycopg/driver.py @@ -146,10 +146,10 @@ def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "ExecutionResul return super().dispatch_execute(cursor, statement) def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": - return super().dispatch_execute_many(cursor, statement) + return PsycopgSyncDriver.dispatch_execute_many(self, cursor, statement) def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": - return super().dispatch_execute_script(cursor, statement) + return PsycopgSyncDriver.dispatch_execute_script(self, cursor, statement) def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: @@ -158,12 +158,12 @@ def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: - return super().dispatch_execute_many(cursor, statement) + return self._dispatch_execute_many_impl(cursor, statement) return self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: - return super().dispatch_execute_script(cursor, statement) + return self._dispatch_execute_script_impl(cursor, statement) return self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) def handle_database_exceptions(self) -> "CockroachPsycopgSyncExceptionHandler": # type: ignore[override] @@ -173,7 +173,7 @@ def handle_database_exceptions(self) -> "CockroachPsycopgSyncExceptionHandler": def data_dictionary(self) -> "CockroachPsycopgSyncDataDictionary": # type: ignore[override] if self._data_dictionary is None: # Intentionally assign CockroachDB-specific data dictionary to parent slot - object.__setattr__(self, "_data_dictionary", CockroachPsycopgSyncDataDictionary()) + self._data_dictionary = CockroachPsycopgSyncDataDictionary() return cast("CockroachPsycopgSyncDataDictionary", self._data_dictionary) @@ -239,10 +239,10 @@ async def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "Executio return await super().dispatch_execute(cursor, statement) async def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": - return await super().dispatch_execute_many(cursor, statement) + return await PsycopgAsyncDriver.dispatch_execute_many(self, cursor, statement) async def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": - return await super().dispatch_execute_script(cursor, statement) + return await PsycopgAsyncDriver.dispatch_execute_script(self, cursor, statement) async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: @@ -251,12 +251,12 @@ async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResul async def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: - return await super().dispatch_execute_many(cursor, statement) + return await self._dispatch_execute_many_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) async def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: - return await super().dispatch_execute_script(cursor, statement) + return await self._dispatch_execute_script_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) def handle_database_exceptions(self) -> "CockroachPsycopgAsyncExceptionHandler": # type: ignore[override] @@ -266,7 +266,7 @@ def handle_database_exceptions(self) -> "CockroachPsycopgAsyncExceptionHandler": def data_dictionary(self) -> "CockroachPsycopgAsyncDataDictionary": # type: ignore[override] if self._data_dictionary is None: # Intentionally assign CockroachDB-specific data dictionary to parent slot - object.__setattr__(self, "_data_dictionary", CockroachPsycopgAsyncDataDictionary()) + self._data_dictionary = CockroachPsycopgAsyncDataDictionary() return cast("CockroachPsycopgAsyncDataDictionary", self._data_dictionary) diff --git a/sqlspec/adapters/duckdb/_typing.py b/sqlspec/adapters/duckdb/_typing.py index 78816debd..33d84d09f 100644 --- a/sqlspec/adapters/duckdb/_typing.py +++ b/sqlspec/adapters/duckdb/_typing.py @@ -23,6 +23,8 @@ if not TYPE_CHECKING: DuckDBConnection = _DuckDBConnection +__all__ = ("DuckDBConnection", "DuckDBCursor", "DuckDBSessionContext") + class DuckDBCursor: """Context manager for DuckDB connection-as-cursor. @@ -99,6 +101,3 @@ def __exit__( self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("DuckDBConnection", "DuckDBCursor", "DuckDBSessionContext") diff --git a/sqlspec/adapters/duckdb/data_dictionary.py b/sqlspec/adapters/duckdb/data_dictionary.py index 969e41d41..644372390 100644 --- a/sqlspec/adapters/duckdb/data_dictionary.py +++ b/sqlspec/adapters/duckdb/data_dictionary.py @@ -7,11 +7,11 @@ from sqlspec.driver import SyncDataDictionaryBase from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -__all__ = ("DuckDBDataDictionary",) - if TYPE_CHECKING: from sqlspec.adapters.duckdb.driver import DuckDBDriver +__all__ = ("DuckDBDataDictionary",) + @mypyc_attr(allow_interpreted_subclasses=True, native_class=False) class DuckDBDataDictionary(SyncDataDictionaryBase): diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 626c9355c..ad16c3292 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -85,8 +85,7 @@ def __init__( statement_config = default_statement_config.replace( enable_caching=get_cache_config().compiled_cache_enabled ) - - statement_config = apply_driver_features(statement_config, driver_features) + statement_config = apply_driver_features(statement_config, driver_features) super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: DuckDBDataDictionary | None = None diff --git a/sqlspec/adapters/duckdb/pool.py b/sqlspec/adapters/duckdb/pool.py index 5ec288fee..a28babbbc 100644 --- a/sqlspec/adapters/duckdb/pool.py +++ b/sqlspec/adapters/duckdb/pool.py @@ -1,6 +1,7 @@ """DuckDB connection pool with thread-local connections.""" import logging +import re import threading import time import uuid @@ -15,6 +16,21 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator +__all__ = ("DuckDBConnectionPool",) + + +_SQL_IDENTIFIER_RE: "Final[re.Pattern[str]]" = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") + + +def _validate_sql_identifier(value: str, field_name: str) -> None: + """Raise ValueError if value is not safe to interpolate as a SQL identifier.""" + if not _SQL_IDENTIFIER_RE.fullmatch(value): + msg = ( + f"Invalid SQL identifier for {field_name!r}: {value!r}. " + "Must start with a letter and contain only letters, digits, and underscores." + ) + raise ValueError(msg) + logger = get_logger(POOL_LOGGER_NAME) _ADAPTER_NAME = "duckdb" @@ -25,8 +41,6 @@ POOL_RECYCLE: Final[int] = 86400 HEALTH_CHECK_INTERVAL: Final[float] = 30.0 -__all__ = ("DuckDBConnectionPool",) - class DuckDBConnectionPool: """Thread-local connection manager for DuckDB. @@ -41,8 +55,6 @@ class DuckDBConnectionPool: __slots__ = ( "_connection_config", - "_connection_times", - "_created_connections", "_extension_flags", "_extensions", "_health_check_interval", @@ -85,8 +97,6 @@ def __init__( self._on_connection_create = on_connection_create self._thread_local = threading.local() self._lock = threading.RLock() - self._created_connections = 0 - self._connection_times: dict[int, float] = {} self._pool_id = str(uuid.uuid4())[:8] # Track if this pool uses an in-memory database # In-memory databases require connections to stay alive to preserve data @@ -156,6 +166,9 @@ def _create_connection(self) -> DuckDBConnection: if not (secret_type and secret_name and secret_value): continue + _validate_sql_identifier(secret_name, "secret_name") + _validate_sql_identifier(secret_type, "secret_type") + value_pairs = [] for key, value in secret_value.items(): escaped_value = str(value).replace("'", "''") @@ -178,11 +191,6 @@ def _create_connection(self) -> DuckDBConnection: with suppress(Exception): self._on_connection_create(connection) - conn_id = id(connection) - with self._lock: - self._created_connections += 1 - self._connection_times[conn_id] = time.time() - return connection def _apply_extension_flags(self, connection: DuckDBConnection) -> None: diff --git a/sqlspec/adapters/mssql_python/_typing.py b/sqlspec/adapters/mssql_python/_typing.py index c4b230d2d..457b668f3 100644 --- a/sqlspec/adapters/mssql_python/_typing.py +++ b/sqlspec/adapters/mssql_python/_typing.py @@ -24,6 +24,16 @@ MssqlPythonConnection = Connection MssqlPythonRawCursor = Cursor +__all__ = ( + "MSSQL_PYTHON_MODULE", + "MssqlPythonAsyncCursor", + "MssqlPythonAsyncSessionContext", + "MssqlPythonConnection", + "MssqlPythonCursor", + "MssqlPythonRawCursor", + "MssqlPythonSessionContext", +) + class MssqlPythonCursor: """Sync context manager for mssql-python cursor management.""" @@ -153,14 +163,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "MSSQL_PYTHON_MODULE", - "MssqlPythonAsyncCursor", - "MssqlPythonAsyncSessionContext", - "MssqlPythonConnection", - "MssqlPythonCursor", - "MssqlPythonRawCursor", - "MssqlPythonSessionContext", -) diff --git a/sqlspec/adapters/mssql_python/config.py b/sqlspec/adapters/mssql_python/config.py index f4f2b62ac..d460c4895 100644 --- a/sqlspec/adapters/mssql_python/config.py +++ b/sqlspec/adapters/mssql_python/config.py @@ -1,6 +1,7 @@ """mssql-python database configuration.""" import asyncio +import warnings from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from typing_extensions import NotRequired @@ -35,6 +36,8 @@ "MssqlPythonPoolParams", ) +_POOLING_PARAMS: "tuple[int, int, bool] | None" = None + class MssqlPythonConnectionParams(TypedDict): """mssql-python connection parameters.""" @@ -111,7 +114,16 @@ def __init__( self.enabled = enabled self.on_connection_create = on_connection_create self._closed = False + global _POOLING_PARAMS + new_params = (max_size, idle_timeout, enabled) + if _POOLING_PARAMS is not None and new_params != _POOLING_PARAMS: + warnings.warn( + f"mssql-python pooling config already set to {_POOLING_PARAMS}; " + f"overwriting with {new_params}. Only one pool config per process is supported.", + stacklevel=2, + ) MSSQL_PYTHON_MODULE.pooling(max_size=max_size, idle_timeout=idle_timeout, enabled=enabled) + _POOLING_PARAMS = new_params def acquire(self) -> "MssqlPythonConnection": if self._closed: diff --git a/sqlspec/adapters/mssql_python/core.py b/sqlspec/adapters/mssql_python/core.py index 88f085542..693b6eb7c 100644 --- a/sqlspec/adapters/mssql_python/core.py +++ b/sqlspec/adapters/mssql_python/core.py @@ -124,14 +124,19 @@ def build_connection_config(params: dict[str, Any]) -> tuple[str, dict[str, Any] parts: list[str] = [] consumed: set[str] = set() + seen_odbc_keys: set[str] = set() for key, option_name in _CONNECTION_STRING_KEYS: if key not in config: continue value = config[key] if value is None: continue + if option_name in seen_odbc_keys: + consumed.add(key) + continue parts.append(f"{option_name}={_format_connection_value(value)}") consumed.add(key) + seen_odbc_keys.add(option_name) extra = config.get("extra") if isinstance(extra, dict): diff --git a/sqlspec/adapters/mssql_python/data_dictionary.py b/sqlspec/adapters/mssql_python/data_dictionary.py index 9041a6b52..ff8d54134 100644 --- a/sqlspec/adapters/mssql_python/data_dictionary.py +++ b/sqlspec/adapters/mssql_python/data_dictionary.py @@ -22,10 +22,10 @@ if TYPE_CHECKING: from sqlspec.data_dictionary._types import DialectConfig -logger = get_logger("sqlspec.adapters.mssql_python.data_dictionary") - __all__ = ("MssqlPythonAsyncDataDictionary", "MssqlPythonSyncDataDictionary", "MssqlVersionInfo") +logger = get_logger("sqlspec.adapters.mssql_python.data_dictionary") + class MssqlVersionInfo(VersionInfo): """MSSQL database version info with build, revision, and Azure SQL detection.""" diff --git a/sqlspec/adapters/mysqlconnector/_typing.py b/sqlspec/adapters/mysqlconnector/_typing.py index d90a6723f..3f570fdd0 100644 --- a/sqlspec/adapters/mysqlconnector/_typing.py +++ b/sqlspec/adapters/mysqlconnector/_typing.py @@ -43,6 +43,17 @@ async def close(self) -> Any: ... MysqlConnectorSyncRawCursor = _MysqlConnectorSyncRawCursor MysqlConnectorAsyncRawCursor = _MysqlConnectorAsyncRawCursor +__all__ = ( + "MysqlConnectorAsyncConnection", + "MysqlConnectorAsyncCursor", + "MysqlConnectorAsyncRawCursor", + "MysqlConnectorAsyncSessionContext", + "MysqlConnectorSyncConnection", + "MysqlConnectorSyncCursor", + "MysqlConnectorSyncRawCursor", + "MysqlConnectorSyncSessionContext", +) + class MysqlConnectorSyncCursor: """Context manager for mysql-connector sync cursor operations.""" @@ -172,15 +183,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "MysqlConnectorAsyncConnection", - "MysqlConnectorAsyncCursor", - "MysqlConnectorAsyncRawCursor", - "MysqlConnectorAsyncSessionContext", - "MysqlConnectorSyncConnection", - "MysqlConnectorSyncCursor", - "MysqlConnectorSyncRawCursor", - "MysqlConnectorSyncSessionContext", -) diff --git a/sqlspec/adapters/mysqlconnector/driver.py b/sqlspec/adapters/mysqlconnector/driver.py index f08178690..a377d3041 100644 --- a/sqlspec/adapters/mysqlconnector/driver.py +++ b/sqlspec/adapters/mysqlconnector/driver.py @@ -151,6 +151,10 @@ def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionRe sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if prepared_parameters is not None and len(statements) > 1: + msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements" + raise SQLSpecError(msg) + successful_count = 0 last_cursor = cursor @@ -370,6 +374,10 @@ async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "Execu sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if prepared_parameters is not None and len(statements) > 1: + msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements" + raise SQLSpecError(msg) + successful_count = 0 last_cursor = cursor diff --git a/sqlspec/adapters/mysqlconnector/litestar/store.py b/sqlspec/adapters/mysqlconnector/litestar/store.py index f7f8fedf6..f9a0db8ee 100644 --- a/sqlspec/adapters/mysqlconnector/litestar/store.py +++ b/sqlspec/adapters/mysqlconnector/litestar/store.py @@ -12,9 +12,10 @@ if TYPE_CHECKING: from sqlspec.adapters.mysqlconnector.config import MysqlConnectorAsyncConfig, MysqlConnectorSyncConfig +__all__ = ("MysqlConnectorAsyncStore", "MysqlConnectorSyncStore") + logger = get_logger("sqlspec.adapters.mysqlconnector.litestar.store") -__all__ = ("MysqlConnectorAsyncStore", "MysqlConnectorSyncStore") MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 diff --git a/sqlspec/adapters/oracledb/_typing.py b/sqlspec/adapters/oracledb/_typing.py index b9adff08e..239f18621 100644 --- a/sqlspec/adapters/oracledb/_typing.py +++ b/sqlspec/adapters/oracledb/_typing.py @@ -56,6 +56,31 @@ OracleSyncRawCursor = Cursor OracleAsyncRawCursor = AsyncCursor +__all__ = ( + "AQMSG_INVISIBLE", + "AQMSG_PAYLOAD_TYPE_JSON", + "AQMSG_VISIBLE", + "DB_TYPE_BLOB", + "DB_TYPE_CLOB", + "DB_TYPE_JSON", + "DB_TYPE_RAW", + "DB_TYPE_VECTOR", + "AQDequeueOptions", + "DatabaseError", + "OracleAsyncConnection", + "OracleAsyncConnectionPool", + "OracleAsyncCursor", + "OracleAsyncRawCursor", + "OracleAsyncSessionContext", + "OraclePipelineDriver", + "OracleSyncConnection", + "OracleSyncConnectionPool", + "OracleSyncCursor", + "OracleSyncRawCursor", + "OracleSyncSessionContext", + "OracleVectorType", +) + AQDequeueOptions: Any | None = getattr(_oracledb, "AQDequeueOptions", None) AQMSG_VISIBLE: int | None = getattr(_oracledb, "AQMSG_VISIBLE", None) AQMSG_INVISIBLE: int | None = getattr(_oracledb, "AQMSG_INVISIBLE", None) @@ -230,29 +255,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "AQMSG_INVISIBLE", - "AQMSG_PAYLOAD_TYPE_JSON", - "AQMSG_VISIBLE", - "DB_TYPE_BLOB", - "DB_TYPE_CLOB", - "DB_TYPE_JSON", - "DB_TYPE_RAW", - "DB_TYPE_VECTOR", - "AQDequeueOptions", - "DatabaseError", - "OracleAsyncConnection", - "OracleAsyncConnectionPool", - "OracleAsyncCursor", - "OracleAsyncRawCursor", - "OracleAsyncSessionContext", - "OraclePipelineDriver", - "OracleSyncConnection", - "OracleSyncConnectionPool", - "OracleSyncCursor", - "OracleSyncRawCursor", - "OracleSyncSessionContext", - "OracleVectorType", -) diff --git a/sqlspec/adapters/oracledb/_vector_handlers.py b/sqlspec/adapters/oracledb/_vector_handlers.py index f4c8eec6e..55b6e3e65 100644 --- a/sqlspec/adapters/oracledb/_vector_handlers.py +++ b/sqlspec/adapters/oracledb/_vector_handlers.py @@ -144,6 +144,10 @@ def _pack_python_sequence(value: "list[Any] | tuple[Any, ...]") -> "array.array[ return array.array(_TYPECODE_FLOAT32, [float(v) for v in value]) +def _pack_sequence_converter(v: Any) -> "array.array[Any]": + return _pack_python_sequence(v) + + def _input_type_handler(cursor: "Cursor | AsyncCursor", value: Any, arraysize: int) -> Any: """Oracle input type handler for vector payloads. @@ -168,8 +172,7 @@ def _input_type_handler(cursor: "Cursor | AsyncCursor", value: Any, arraysize: i if isinstance(value, array.array): return cursor.var(DB_TYPE_VECTOR, arraysize=arraysize) - packed = _pack_python_sequence(value) - return cursor.var(DB_TYPE_VECTOR, arraysize=arraysize, inconverter=lambda _v: packed) + return cursor.var(DB_TYPE_VECTOR, arraysize=arraysize, inconverter=_pack_sequence_converter) def _output_type_handler(cursor: "Cursor | AsyncCursor", metadata: Any) -> Any: diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 3d1dd7243..bd7603c7a 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -25,8 +25,6 @@ from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig from sqlspec.extensions.adk import MemoryRecord -logger = get_logger("sqlspec.adapters.oracledb.adk.store") - __all__ = ( "JSONStorageType", "OracleAsyncADKMemoryStore", @@ -37,6 +35,9 @@ "storage_type_from_version", ) +logger = get_logger("sqlspec.adapters.oracledb.adk.store") + + ORACLE_TABLE_NOT_FOUND_ERROR: Final = 942 ORACLE_MIN_JSON_NATIVE_VERSION: Final = 21 ORACLE_MIN_JSON_NATIVE_COMPATIBLE: Final = 20 diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py index 509f39c6e..e71c07f1e 100644 --- a/sqlspec/adapters/oracledb/data_dictionary.py +++ b/sqlspec/adapters/oracledb/data_dictionary.py @@ -25,10 +25,10 @@ from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver from sqlspec.data_dictionary._types import DialectConfig -logger = get_logger("sqlspec.adapters.oracledb.data_dictionary") - __all__ = ("OracleVersionInfo", "OracledbAsyncDataDictionary", "OracledbSyncDataDictionary") +logger = get_logger("sqlspec.adapters.oracledb.data_dictionary") + class OracleVersionInfo(VersionInfo): """Oracle database version information.""" diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index fe3a56261..6aa59dc2f 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -67,13 +67,6 @@ from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry from sqlspec.typing import ArrowReturnFormat, StatementParameters, VersionInfo - -logger = get_logger(__name__) - -# Oracle SQL-context byte thresholds (4000 / 2000) live in driver_features so users -# on MAX_STRING_SIZE=EXTENDED databases can override them; defaults are wired in -# core.apply_driver_features and read at the dispatch_execute call sites below. - __all__ = ( "OracleAsyncDriver", "OracleAsyncExceptionHandler", @@ -83,6 +76,14 @@ "OracleSyncSessionContext", ) + +logger = get_logger(__name__) + +# Oracle SQL-context byte thresholds (4000 / 2000) live in driver_features so users +# on MAX_STRING_SIZE=EXTENDED databases can override them; defaults are wired in +# core.apply_driver_features and read at the dispatch_execute call sites below. + + PIPELINE_MIN_DRIVER_VERSION: "tuple[int, int, int]" = (2, 4, 0) PIPELINE_MIN_DATABASE_MAJOR: int = 23 @@ -335,7 +336,9 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": cursor.execute(sql, prepared_parameters or {}) # SELECT result processing for Oracle - if statement.returns_rows(): + is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor) + + if is_select_like: fetched_data = cursor.fetchall() column_names, requires_lob_coercion = self._resolve_row_metadata(cursor.description) data, column_names = collect_sync_rows( diff --git a/sqlspec/adapters/oracledb/events/_hub.py b/sqlspec/adapters/oracledb/events/_hub.py index a7b1918ce..743d0e991 100644 --- a/sqlspec/adapters/oracledb/events/_hub.py +++ b/sqlspec/adapters/oracledb/events/_hub.py @@ -21,10 +21,10 @@ if TYPE_CHECKING: from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig -logger = get_logger("sqlspec.adapters.oracledb.events.hub") - __all__ = ("OracleAsyncAQHub", "OracleSyncAQHub") +logger = get_logger("sqlspec.adapters.oracledb.events.hub") + _AQ_AVAILABLE = False OracleDatabaseError: Any diff --git a/sqlspec/adapters/oracledb/events/backend.py b/sqlspec/adapters/oracledb/events/backend.py index b4658475f..13c78b15f 100644 --- a/sqlspec/adapters/oracledb/events/backend.py +++ b/sqlspec/adapters/oracledb/events/backend.py @@ -14,6 +14,8 @@ if TYPE_CHECKING: from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig +__all__ = ("OracleAsyncAQEventBackend", "OracleSyncAQEventBackend", "create_event_backend") + _ORACLEDB_AVAILABLE = False try: # pragma: no cover - optional dependency path @@ -39,7 +41,6 @@ logger = get_logger("sqlspec.events.oracle") -__all__ = ("OracleAsyncAQEventBackend", "OracleSyncAQEventBackend", "create_event_backend") _DEFAULT_QUEUE_NAME = "SQLSPEC_EVENTS_QUEUE" _DEFAULT_VISIBILITY: "int | None" diff --git a/sqlspec/adapters/oracledb/litestar/store.py b/sqlspec/adapters/oracledb/litestar/store.py index 92d6a7ad0..7629f00f8 100644 --- a/sqlspec/adapters/oracledb/litestar/store.py +++ b/sqlspec/adapters/oracledb/litestar/store.py @@ -10,10 +10,10 @@ if TYPE_CHECKING: from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig +__all__ = ("OracleAsyncStore", "OracleSyncStore") -ORACLE_SMALL_BLOB_LIMIT = 32000 -__all__ = ("OracleAsyncStore", "OracleSyncStore") +ORACLE_SMALL_BLOB_LIMIT = 32000 def _coerce_bytes_payload(value: object) -> bytes: diff --git a/sqlspec/adapters/psqlpy/_typing.py b/sqlspec/adapters/psqlpy/_typing.py index fd272ec9c..70424a7eb 100644 --- a/sqlspec/adapters/psqlpy/_typing.py +++ b/sqlspec/adapters/psqlpy/_typing.py @@ -24,6 +24,8 @@ PsqlpyConnection = _PsqlpyConnection PsqlpyListener = _PsqlpyListener +__all__ = ("PsqlpyConnection", "PsqlpyCursor", "PsqlpyListener", "PsqlpySessionContext") + class PsqlpyCursor: """Context manager for psqlpy cursor management.""" @@ -116,6 +118,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("PsqlpyConnection", "PsqlpyCursor", "PsqlpyListener", "PsqlpySessionContext") diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 2c3caa440..e9fce7b1a 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -28,6 +28,8 @@ from sqlspec.core import StatementConfig +__all__ = ("PsqlpyConfig", "PsqlpyConnectionParams", "PsqlpyCursor", "PsqlpyDriverFeatures", "PsqlpyPoolParams") + class PsqlpyConnectionParams(TypedDict): """Psqlpy connection parameters.""" @@ -144,9 +146,6 @@ async def release_connection(self, _conn: "PsqlpyConnection", **kwargs: Any) -> self._ctx = None -__all__ = ("PsqlpyConfig", "PsqlpyConnectionParams", "PsqlpyCursor", "PsqlpyDriverFeatures", "PsqlpyPoolParams") - - class PsqlpyConnectionContext(AsyncPoolConnectionContext): """Async context manager for Psqlpy connections.""" diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 8d3d30dae..131d2b1e8 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -180,12 +180,17 @@ async def dispatch_execute_script(self, cursor: "PsqlpyConnection", statement: S prepared_parameters = cast("Sequence[Any] | Mapping[str, Any] | None", prepared_parameters) statement_config = statement.statement_config statements = self.split_script_statements(sql, statement_config, strip_trailing_semicolon=True) + if len(statements) > 1 and prepared_parameters: + msg = ( + "Parameterized multi-statement scripts are not supported; use execute_many or individual execute calls" + ) + raise SQLSpecError(msg) successful_count = 0 last_result = None for stmt in statements: - last_result = await cursor.execute(stmt, prepared_parameters or []) + last_result = await cursor.execute(stmt, []) successful_count += 1 return self.create_execution_result( diff --git a/sqlspec/adapters/psqlpy/events/_hub.py b/sqlspec/adapters/psqlpy/events/_hub.py index fb31a4f87..7df57892f 100644 --- a/sqlspec/adapters/psqlpy/events/_hub.py +++ b/sqlspec/adapters/psqlpy/events/_hub.py @@ -25,9 +25,10 @@ from sqlspec.adapters.psqlpy._typing import PsqlpyListener from sqlspec.adapters.psqlpy.config import PsqlpyConfig +__all__ = ("PsqlpyListenerHub",) + logger = get_logger("sqlspec.adapters.psqlpy.events.hub") -__all__ = ("PsqlpyListenerHub",) _PSQLPY_CALLBACK_MIN_ARGS = 3 diff --git a/sqlspec/adapters/psqlpy/events/backend.py b/sqlspec/adapters/psqlpy/events/backend.py index 7c74ab80b..206e6704a 100644 --- a/sqlspec/adapters/psqlpy/events/backend.py +++ b/sqlspec/adapters/psqlpy/events/backend.py @@ -23,10 +23,10 @@ if TYPE_CHECKING: from sqlspec.adapters.psqlpy.config import PsqlpyConfig +__all__ = ("PsqlpyEventsBackend", "PsqlpyHybridEventsBackend", "create_event_backend") -logger = get_logger("sqlspec.events.psqlpy") -__all__ = ("PsqlpyEventsBackend", "PsqlpyHybridEventsBackend", "create_event_backend") +logger = get_logger("sqlspec.events.psqlpy") def _extract_event_id(payload: "str | None") -> "str | None": diff --git a/sqlspec/adapters/psycopg/_typing.py b/sqlspec/adapters/psycopg/_typing.py index 508f59eb9..59a235b90 100644 --- a/sqlspec/adapters/psycopg/_typing.py +++ b/sqlspec/adapters/psycopg/_typing.py @@ -32,6 +32,22 @@ PsycopgSyncRawCursor = Cursor PsycopgAsyncRawCursor = AsyncCursor +__all__ = ( + "PsycopgAsyncConnection", + "PsycopgAsyncCursor", + "PsycopgAsyncRawCursor", + "PsycopgAsyncSessionContext", + "PsycopgComposed", + "PsycopgDictRow", + "PsycopgIdentifier", + "PsycopgPipelineDriver", + "PsycopgSQL", + "PsycopgSyncConnection", + "PsycopgSyncCursor", + "PsycopgSyncRawCursor", + "PsycopgSyncSessionContext", +) + class PsycopgSyncCursor: """Context manager for PostgreSQL psycopg cursor management.""" @@ -199,20 +215,3 @@ async def __aexit__( await self._release_connection(self._connection) self._connection = None return None - - -__all__ = ( - "PsycopgAsyncConnection", - "PsycopgAsyncCursor", - "PsycopgAsyncRawCursor", - "PsycopgAsyncSessionContext", - "PsycopgComposed", - "PsycopgDictRow", - "PsycopgIdentifier", - "PsycopgPipelineDriver", - "PsycopgSQL", - "PsycopgSyncConnection", - "PsycopgSyncCursor", - "PsycopgSyncRawCursor", - "PsycopgSyncSessionContext", -) diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 82edee4a5..b94433f0d 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -41,6 +41,16 @@ from sqlspec.core import StatementConfig +__all__ = ( + "PsycopgAsyncConfig", + "PsycopgAsyncCursor", + "PsycopgConnectionParams", + "PsycopgDriverFeatures", + "PsycopgPoolParams", + "PsycopgSyncConfig", + "PsycopgSyncCursor", +) + class PsycopgConnectionParams(TypedDict): """Psycopg connection parameters.""" @@ -120,17 +130,6 @@ class PsycopgDriverFeatures(TypedDict): events_backend: NotRequired[str] -__all__ = ( - "PsycopgAsyncConfig", - "PsycopgAsyncCursor", - "PsycopgConnectionParams", - "PsycopgDriverFeatures", - "PsycopgPoolParams", - "PsycopgSyncConfig", - "PsycopgSyncCursor", -) - - class PsycopgSyncConnectionContext(SyncPoolConnectionContext): """Context manager for Psycopg connections.""" diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 62366348b..c4d392045 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -238,6 +238,11 @@ def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionRe """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if len(statements) > 1 and prepared_parameters: + msg = ( + "Parameterized multi-statement scripts are not supported; use execute_many or individual execute calls" + ) + raise SQLSpecError(msg) successful_count = 0 last_cursor = cursor @@ -715,6 +720,11 @@ async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "Execu """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if len(statements) > 1 and prepared_parameters: + msg = ( + "Parameterized multi-statement scripts are not supported; use execute_many or individual execute calls" + ) + raise SQLSpecError(msg) successful_count = 0 last_cursor = cursor @@ -742,13 +752,12 @@ async def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQL return None sql, _ = self._get_compiled_sql(statement, statement.statement_config) - sql_upper = sql.upper() operation_type = statement.operation_type copy_data = statement.parameters if isinstance(copy_data, list) and len(copy_data) == 1: copy_data = copy_data[0] - if is_copy_from_operation(operation_type) and "FROM STDIN" in sql_upper: + if is_copy_from_operation(operation_type): if isinstance(copy_data, (str, bytes)): data_to_write = copy_data elif is_readable(copy_data): @@ -768,7 +777,7 @@ async def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQL data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"} ) - if is_copy_to_operation(operation_type) and "TO STDOUT" in sql_upper: + if is_copy_to_operation(operation_type): output_data: list[str] = [] async with cursor.copy(sql) as copy_ctx: output_data.extend([row.decode() if isinstance(row, bytes) else str(row) async for row in copy_ctx]) diff --git a/sqlspec/adapters/psycopg/events/_hub.py b/sqlspec/adapters/psycopg/events/_hub.py index 0a8f69c70..35d154d9c 100644 --- a/sqlspec/adapters/psycopg/events/_hub.py +++ b/sqlspec/adapters/psycopg/events/_hub.py @@ -23,9 +23,10 @@ if TYPE_CHECKING: from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig +__all__ = ("PsycopgAsyncListenerHub", "PsycopgSyncListenerHub") + logger = get_logger("sqlspec.adapters.psycopg.events.hub") -__all__ = ("PsycopgAsyncListenerHub", "PsycopgSyncListenerHub") _PUMP_TIMEOUT = 0.05 _PUMP_BATCH = 32 diff --git a/sqlspec/adapters/psycopg/events/backend.py b/sqlspec/adapters/psycopg/events/backend.py index a257bc99e..f07ff170e 100644 --- a/sqlspec/adapters/psycopg/events/backend.py +++ b/sqlspec/adapters/psycopg/events/backend.py @@ -23,8 +23,6 @@ if TYPE_CHECKING: from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig -logger = get_logger("sqlspec.events.psycopg") - __all__ = ( "PsycopgAsyncEventsBackend", "PsycopgAsyncHybridEventsBackend", @@ -33,6 +31,8 @@ "create_event_backend", ) +logger = get_logger("sqlspec.events.psycopg") + def _extract_event_id(payload: "str | None") -> "str | None": if not payload: diff --git a/sqlspec/adapters/pymysql/_typing.py b/sqlspec/adapters/pymysql/_typing.py index 7aae2125b..beed40407 100644 --- a/sqlspec/adapters/pymysql/_typing.py +++ b/sqlspec/adapters/pymysql/_typing.py @@ -23,6 +23,8 @@ PyMysqlConnection = pymysql.connections.Connection PyMysqlRawCursor = pymysql.cursors.Cursor +__all__ = ("PyMysqlConnection", "PyMysqlCursor", "PyMysqlRawCursor", "PyMysqlSessionContext") + class PyMysqlCursor: """Context manager for PyMySQL cursor operations.""" @@ -87,6 +89,3 @@ def __exit__( self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("PyMysqlConnection", "PyMysqlCursor", "PyMysqlRawCursor", "PyMysqlSessionContext") diff --git a/sqlspec/adapters/pymysql/driver.py b/sqlspec/adapters/pymysql/driver.py index 812b69b34..ecb656d77 100644 --- a/sqlspec/adapters/pymysql/driver.py +++ b/sqlspec/adapters/pymysql/driver.py @@ -124,6 +124,10 @@ def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionRe sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if prepared_parameters is not None and len(statements) > 1: + msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements" + raise SQLSpecError(msg) + successful_count = 0 last_cursor = cursor diff --git a/sqlspec/adapters/pymysql/litestar/store.py b/sqlspec/adapters/pymysql/litestar/store.py index aa1d3ce74..11c2e773a 100644 --- a/sqlspec/adapters/pymysql/litestar/store.py +++ b/sqlspec/adapters/pymysql/litestar/store.py @@ -12,9 +12,10 @@ if TYPE_CHECKING: from sqlspec.adapters.pymysql.config import PyMysqlConfig +__all__ = ("PyMysqlStore",) + logger = get_logger("sqlspec.adapters.pymysql.litestar.store") -__all__ = ("PyMysqlStore",) MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 diff --git a/sqlspec/adapters/pymysql/pool.py b/sqlspec/adapters/pymysql/pool.py index d538139f6..ae9985185 100644 --- a/sqlspec/adapters/pymysql/pool.py +++ b/sqlspec/adapters/pymysql/pool.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator +__all__ = ("PyMysqlConnectionPool",) + class PyMysqlConnectionParams(TypedDict): """PyMySQL connection parameters.""" @@ -40,8 +42,6 @@ class PyMysqlConnectionParams(TypedDict): extra: NotRequired["dict[str, Any]"] -__all__ = ("PyMysqlConnectionPool",) - logger = get_logger(POOL_LOGGER_NAME) _ADAPTER_NAME = "pymysql" diff --git a/sqlspec/adapters/spanner/_typing.py b/sqlspec/adapters/spanner/_typing.py index 853dd2537..676ec4a83 100644 --- a/sqlspec/adapters/spanner/_typing.py +++ b/sqlspec/adapters/spanner/_typing.py @@ -22,6 +22,8 @@ if not TYPE_CHECKING: SpannerConnection = Any +__all__ = ("SpannerConnection", "SpannerSessionContext", "SpannerSyncCursor") + class SpannerSyncCursor: """Context manager that yields the active Spanner connection.""" @@ -97,6 +99,3 @@ def __exit__( self._release_connection(self._connection, exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb) self._connection = None return None - - -__all__ = ("SpannerConnection", "SpannerSessionContext", "SpannerSyncCursor") diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index a389ef26d..b3b8e030b 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -19,10 +19,11 @@ from sqlspec.config import ADKConfig from sqlspec.extensions.adk import MemoryRecord -SPANNER_PARAM_TYPES: SpannerParamTypesProtocol = cast("SpannerParamTypesProtocol", param_types) __all__ = ("SpannerSyncADKMemoryStore", "SpannerSyncADKStore") +SPANNER_PARAM_TYPES: SpannerParamTypesProtocol = cast("SpannerParamTypesProtocol", param_types) + class SpannerSyncADKStore(BaseAsyncADKStore[SpannerSyncConfig]): """Spanner ADK store backed by synchronous Spanner client.""" diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 535d62e13..070c2489a 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -239,18 +239,7 @@ def get_database(self) -> "Database": return self._database def create_connection(self) -> SpannerConnection: - instance_id = self.connection_config.get("instance_id") - database_id = self.connection_config.get("database_id") - if not instance_id or not database_id: - msg = "instance_id and database_id are required." - raise ImproperConfigurationError(msg) - - if self.connection_instance is None: - self.connection_instance = self.provide_pool() - - client = self._get_client() - database = client.instance(instance_id).database(database_id, pool=self.connection_instance) # type: ignore[no-untyped-call] - return cast("SpannerConnection", database.snapshot()) + return cast("SpannerConnection", self.get_database().snapshot()) def _create_pool(self) -> AbstractSessionPool: instance_id = self.connection_config.get("instance_id") diff --git a/sqlspec/adapters/spanner/data_dictionary.py b/sqlspec/adapters/spanner/data_dictionary.py index 9e13cb2c0..5b7b438a1 100644 --- a/sqlspec/adapters/spanner/data_dictionary.py +++ b/sqlspec/adapters/spanner/data_dictionary.py @@ -7,11 +7,11 @@ from sqlspec.driver import SyncDataDictionaryBase from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -__all__ = ("SpannerDataDictionary",) - if TYPE_CHECKING: from sqlspec.adapters.spanner.driver import SpannerSyncDriver +__all__ = ("SpannerDataDictionary",) + @mypyc_attr(allow_interpreted_subclasses=True, native_class=False) class SpannerDataDictionary(SyncDataDictionaryBase): diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index 9a461ad1f..553ac567a 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -3,8 +3,10 @@ from collections.abc import Iterator from typing import TYPE_CHECKING, Any, Protocol, cast +import sqlglot as _sqlglot from google.api_core import exceptions as api_exceptions from google.cloud.spanner_v1.transaction import Transaction +from sqlglot import exp as _sqlglot_exp from sqlspec.adapters.spanner._typing import SpannerSessionContext, SpannerSyncCursor from sqlspec.adapters.spanner.core import ( @@ -179,21 +181,21 @@ def dispatch_execute_many(self, cursor: "SpannerConnection", statement: "SQL") - msg = "execute_many requires at least one parameter set" raise SQLConversionError(msg) - coerce_params = self._coerce_params - infer_param_types = self._infer_param_types + _coerce = self._coerce_params + _infer = self._infer_param_types param_types_cache: dict[tuple[tuple[str, type[Any]], ...], dict[str, Any]] = {} empty_param_types: dict[str, Any] = {} batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = [] append_batch_arg = batch_args.append for params in prepared_parameters: - coerced_params = coerce_params(cast("dict[str, Any] | None", params)) + coerced_params = _coerce(cast("dict[str, Any] | None", params)) if not coerced_params: append_batch_arg((sql, {}, empty_param_types)) continue signature = build_param_type_signature(coerced_params) param_types = param_types_cache.get(signature) if param_types is None: - param_types = infer_param_types(coerced_params) + param_types = _infer(coerced_params) param_types_cache[signature] = param_types append_batch_arg((sql, coerced_params, param_types)) @@ -212,7 +214,11 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") count = 0 script_params = cast("dict[str, Any] | None", params) for stmt in statements: - is_select = stmt.upper().strip().startswith("SELECT") + try: + parsed = _sqlglot.parse_one(stmt) + is_select = isinstance(parsed, _sqlglot_exp.Select) + except Exception: + is_select = stmt.upper().strip().startswith("SELECT") coerced_params = self._coerce_params(script_params) if not is_select and not is_transaction: raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE) diff --git a/sqlspec/adapters/sqlite/_typing.py b/sqlspec/adapters/sqlite/_typing.py index 7ef3baa3c..c7ea27c8e 100644 --- a/sqlspec/adapters/sqlite/_typing.py +++ b/sqlspec/adapters/sqlite/_typing.py @@ -25,6 +25,8 @@ SqliteConnection = _SqliteConnection SqliteRawCursor = sqlite3.Cursor +__all__ = ("SqliteConnection", "SqliteCursor", "SqliteRawCursor", "SqliteSessionContext") + class SqliteCursor: """Context manager for SQLite cursor management. @@ -118,6 +120,3 @@ def __exit__( self._release_connection(self._connection) self._connection = None return None - - -__all__ = ("SqliteConnection", "SqliteCursor", "SqliteRawCursor", "SqliteSessionContext") diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 137beba74..2194b328e 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -16,12 +16,13 @@ from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.extensions.adk import MemoryRecord +__all__ = ("SqliteADKMemoryStore", "SqliteADKStore") + SECONDS_PER_DAY = 86400.0 JULIAN_EPOCH = 2440587.5 SQLITE_TABLE_NOT_FOUND_ERROR: Final = "no such table" -__all__ = ("SqliteADKMemoryStore", "SqliteADKStore") logger: "logging.Logger" = get_logger("sqlspec.adapters.sqlite.adk.store") diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index e1bff9ad7..01d3dfd52 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -22,6 +22,8 @@ from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig +__all__ = ("SqliteConfig", "SqliteConnectionParams", "SqliteDriverFeatures") + class SqliteConnectionParams(TypedDict): """SQLite connection parameters.""" @@ -69,9 +71,6 @@ class SqliteDriverFeatures(TypedDict): events_backend: NotRequired[str] -__all__ = ("SqliteConfig", "SqliteConnectionParams", "SqliteDriverFeatures") - - class SqliteConnectionContext(SyncPoolConnectionContext): """Context manager for Sqlite connections.""" diff --git a/sqlspec/adapters/sqlite/data_dictionary.py b/sqlspec/adapters/sqlite/data_dictionary.py index 61b427565..efd77b2e4 100644 --- a/sqlspec/adapters/sqlite/data_dictionary.py +++ b/sqlspec/adapters/sqlite/data_dictionary.py @@ -10,11 +10,11 @@ from sqlspec.driver import SyncDataDictionaryBase from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -__all__ = ("SqliteDataDictionary",) - if TYPE_CHECKING: from sqlspec.adapters.sqlite.driver import SqliteDriver +__all__ = ("SqliteDataDictionary",) + @mypyc_attr(allow_interpreted_subclasses=True, native_class=False) class SqliteDataDictionary(SyncDataDictionaryBase): diff --git a/sqlspec/adapters/sqlite/litestar/store.py b/sqlspec/adapters/sqlite/litestar/store.py index c91c00ab6..a43b3c368 100644 --- a/sqlspec/adapters/sqlite/litestar/store.py +++ b/sqlspec/adapters/sqlite/litestar/store.py @@ -9,12 +9,12 @@ if TYPE_CHECKING: from sqlspec.adapters.sqlite.config import SqliteConfig +__all__ = ("SQLiteStore",) + SECONDS_PER_DAY = 86400.0 JULIAN_EPOCH = 2440587.5 -__all__ = ("SQLiteStore",) - class SQLiteStore(BaseSQLSpecStore["SqliteConfig"]): """SQLite session store using synchronous SQLite driver. diff --git a/sqlspec/adapters/sqlite/pool.py b/sqlspec/adapters/sqlite/pool.py index 6692726bc..1155ab1e5 100644 --- a/sqlspec/adapters/sqlite/pool.py +++ b/sqlspec/adapters/sqlite/pool.py @@ -7,9 +7,7 @@ import time import uuid from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, TypedDict, cast - -from typing_extensions import NotRequired +from typing import TYPE_CHECKING, Any, cast from sqlspec.adapters.sqlite._typing import SqliteConnection from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context @@ -17,20 +15,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator - -class SqliteConnectionParams(TypedDict): - """SQLite connection parameters.""" - - database: NotRequired[str] - timeout: NotRequired[float] - detect_types: NotRequired[int] - isolation_level: "NotRequired[str | None]" - check_same_thread: NotRequired[bool] - factory: "NotRequired[type[SqliteConnection] | None]" - cached_statements: NotRequired[int] - uri: NotRequired[bool] - - __all__ = ("SqliteConnectionPool",) logger = get_logger(POOL_LOGGER_NAME) diff --git a/sqlspec/adapters/sqlite/type_converter.py b/sqlspec/adapters/sqlite/type_converter.py index 51a15fd85..d8dbfa251 100644 --- a/sqlspec/adapters/sqlite/type_converter.py +++ b/sqlspec/adapters/sqlite/type_converter.py @@ -1,3 +1,4 @@ +# Keep in sync with sqlspec/adapters/aiosqlite/type_converter.py """SQLite custom type handlers for optional JSON and type conversion support. Provides registration functions for SQLite's adapter/converter system to enable diff --git a/sqlspec/base.py b/sqlspec/base.py index 42fc1591d..d1f5d6ccd 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -23,6 +23,7 @@ get_cache_statistics, log_cache_stats, reset_cache_stats, + reset_stats_only, update_cache_config, ) from sqlspec.exceptions import ImproperConfigurationError @@ -168,6 +169,10 @@ def get_cache_stats() -> "dict[str, Any]": def reset_cache_stats() -> None: reset_cache_stats() + @staticmethod + def reset_stats_only() -> None: + reset_stats_only() + @staticmethod def log_cache_stats() -> None: log_cache_stats() @@ -728,6 +733,11 @@ def reset_cache_stats() -> None: """Reset all cache statistics to zero.""" _CACHE_MANAGER.reset_cache_stats() + @staticmethod + def reset_stats_only() -> None: + """Reset cache statistics without clearing cached data.""" + _CACHE_MANAGER.reset_stats_only() + @staticmethod def log_cache_stats() -> None: """Log current cache statistics using the configured logger.""" diff --git a/sqlspec/builder/__init__.py b/sqlspec/builder/__init__.py index 31fc0a827..e0707988a 100644 --- a/sqlspec/builder/__init__.py +++ b/sqlspec/builder/__init__.py @@ -94,9 +94,6 @@ from sqlspec.builder._vector_distance import VectorDistance from sqlspec.exceptions import SQLBuilderError -# Register temporal query SQL generators on module import -register_version_generators() - __all__ = ( "AggregateExpression", "AlterTable", @@ -179,3 +176,6 @@ "sql", "to_expression", ) + +# Register temporal query SQL generators on module import +register_version_generators() diff --git a/sqlspec/builder/_ddl.py b/sqlspec/builder/_ddl.py index c33e6f59d..9ae75f791 100644 --- a/sqlspec/builder/_ddl.py +++ b/sqlspec/builder/_ddl.py @@ -4,7 +4,7 @@ TRUNCATE, and other schema manipulation statements. """ -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from sqlglot import exp from typing_extensions import Self @@ -59,7 +59,6 @@ FOREIGN_KEY_ACTION_SET_DEFAULT, FOREIGN_KEY_ACTION_RESTRICT, FOREIGN_KEY_ACTION_NO_ACTION, - None, } VALID_CONSTRAINT_TYPES = { @@ -89,6 +88,9 @@ def build_column_expression(col: "ColumnDefinition") -> "exp.Expr": if col.unique: constraints.append(exp.ColumnConstraint(kind=exp.UniqueColumnConstraint())) + if col.auto_increment: + constraints.append(exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint())) + if col.default is not None: default_expr: exp.Expr | None = None if isinstance(col.default, str): @@ -138,6 +140,13 @@ def build_constraint_expression(constraint: "ConstraintDefinition") -> "exp.Expr return pk_constraint if constraint.constraint_type == CONSTRAINT_TYPE_FOREIGN_KEY: + fk_options: list[str] = [] + if constraint.deferrable: + if constraint.initially_deferred: + fk_options.append("DEFERRABLE INITIALLY DEFERRED") + else: + fk_options.append("DEFERRABLE INITIALLY IMMEDIATE") + fk_constraint = exp.ForeignKey( expressions=[exp.to_identifier(col) for col in constraint.columns], reference=exp.Reference( @@ -146,6 +155,7 @@ def build_constraint_expression(constraint: "ConstraintDefinition") -> "exp.Expr on_delete=constraint.on_delete, on_update=constraint.on_update, ), + options=fk_options or None, ) if constraint.name: @@ -471,7 +481,7 @@ def unique_constraint(self, columns: "str | list[str]", name: "str | None" = Non self._constraints.append(constraint) return self - def check_constraint(self, condition: Union[str, "ColumnExpression"], name: "str | None" = None) -> "Self": + def check_constraint(self, condition: "str | ColumnExpression", name: "str | None" = None) -> "Self": """Add a check constraint.""" if not condition: self._raise_sql_builder_error("Check constraint must have a condition") @@ -887,6 +897,11 @@ def _create_base_expression(self) -> exp.Expr: index_params = exp.IndexParameters(columns=cols) if cols else None + if self._using: + if index_params is None: + index_params = exp.IndexParameters() + index_params.set("using", exp.Var(this=self._using)) + index_expr = exp.Index( this=exp.to_identifier(self._index_name), table=exp.to_table(self._table_name), params=index_params ) @@ -1517,6 +1532,32 @@ def drop_not_null(self, column: str) -> "Self": self._operations.append(operation) return self + def in_schema(self, schema_name: str) -> "Self": + """Set the schema for the table.""" + self._schema = schema_name + return self + + def set_column_default(self, column: str, value: "Any") -> "Self": + """Set the default value for a column.""" + if not column: + self._raise_sql_builder_error("Column name must be a non-empty string") + + column_def = ColumnDefinition(name=column, dtype="", default=value) + operation = AlterOperation( + operation_type="ALTER COLUMN SET DEFAULT", column_name=column, column_definition=column_def + ) + self._operations.append(operation) + return self + + def drop_column_default(self, column: str) -> "Self": + """Remove the default value from a column.""" + if not column: + self._raise_sql_builder_error("Column name must be a non-empty string") + + operation = AlterOperation(operation_type="ALTER COLUMN DROP DEFAULT", column_name=column) + self._operations.append(operation) + return self + def _create_base_expression(self) -> "exp.Expr": """Create the SQLGlot expression for ALTER TABLE.""" if not self._operations: diff --git a/sqlspec/builder/_factory.py b/sqlspec/builder/_factory.py index 82bece085..a68823f7a 100644 --- a/sqlspec/builder/_factory.py +++ b/sqlspec/builder/_factory.py @@ -1432,11 +1432,11 @@ def not_any_(values: list[Any] | exp.Expr | str) -> FunctionExpression: sql .select("*") .from_("users") - .where(sql.id.neq(sql.not_any(subquery))) + .where(sql.id != sql.not_any_(subquery)) ) ``` """ - return SQLFactory.any(values) + return FunctionExpression(exp.Not(this=SQLFactory.any(values).expression)) @staticmethod def concat(*expressions: str | exp.Expr) -> StringExpression: diff --git a/sqlspec/builder/_parsing_utils.py b/sqlspec/builder/_parsing_utils.py index 92abe79b5..be355399c 100644 --- a/sqlspec/builder/_parsing_utils.py +++ b/sqlspec/builder/_parsing_utils.py @@ -19,6 +19,16 @@ has_parameter_builder, ) +__all__ = ( + "extract_expression", + "extract_sql_object_expression", + "parse_column_expression", + "parse_condition_expression", + "parse_order_expression", + "parse_table_expression", + "to_expression", +) + ALIAS_PARTS_EXPECTED_COUNT = 2 @@ -282,14 +292,3 @@ def to_expression(value: Any) -> exp.Expr: if isinstance(value, exp.Expr): return value return exp.convert(value) - - -__all__ = ( - "extract_expression", - "extract_sql_object_expression", - "parse_column_expression", - "parse_condition_expression", - "parse_order_expression", - "parse_table_expression", - "to_expression", -) diff --git a/sqlspec/builder/_select.py b/sqlspec/builder/_select.py index b17a9b1e3..2554a3869 100644 --- a/sqlspec/builder/_select.py +++ b/sqlspec/builder/_select.py @@ -35,11 +35,6 @@ is_iterable_parameters, ) -BETWEEN_BOUND_COUNT = 2 -PAIR_LENGTH = 2 -TRIPLE_LENGTH = 3 - - if TYPE_CHECKING: from collections.abc import Callable @@ -67,6 +62,10 @@ "WindowFunctionBuilder", ) +BETWEEN_BOUND_COUNT = 2 +PAIR_LENGTH = 2 +TRIPLE_LENGTH = 3 + def is_explicitly_quoted(identifier: Any) -> bool: """Detect if identifier was provided with explicit quotes.""" diff --git a/sqlspec/builder/_vector_distance.py b/sqlspec/builder/_vector_distance.py index 3ea2b0e43..3f2b46529 100644 --- a/sqlspec/builder/_vector_distance.py +++ b/sqlspec/builder/_vector_distance.py @@ -17,8 +17,6 @@ from sqlspec.dialects.spanner import Spangres, Spanner -SupportedVectorDistanceDialect: TypeAlias = "BigQuery | DuckDB | MySQL | Oracle | Postgres | Spangres | Spanner" - __all__ = ( "VectorDistance", "has_vector_distance_ancestor", @@ -32,6 +30,8 @@ "vector_distance_metric", ) +SupportedVectorDistanceDialect: TypeAlias = "BigQuery | DuckDB | MySQL | Oracle | Postgres | Spangres | Spanner" + _VECTOR_DISTANCE_META_KEY: Final[str] = "sqlspec_vector_distance_metric" _OperatorTransform = Callable[[Any, exp.Operator], str] _SQLGLOT_VECTOR_DISTANCE_REGISTERED = False diff --git a/sqlspec/config.py b/sqlspec/config.py index 4b6ffcaa5..61a444fd7 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1401,7 +1401,9 @@ def get_observability_runtime(self) -> "ObservabilityRuntime": if self._observability_runtime is None: self.attach_observability(None) - assert self._observability_runtime is not None + if self._observability_runtime is None: + msg = "ObservabilityRuntime was not set by attach_observability; this is a bug" + raise RuntimeError(msg) return self._observability_runtime def _prepare_driver(self, driver: DriverT) -> DriverT: @@ -1779,7 +1781,6 @@ def init_migrations(self, directory: "str | None" = None, package: bool = True) directory = str(migration_config.get("script_location") or "migrations") commands = self._ensure_migration_commands() - assert directory is not None commands.init(directory, package) def stamp_migration(self, revision: str) -> None: @@ -1943,7 +1944,6 @@ async def init_migrations(self, directory: "str | None" = None, package: bool = directory = str(migration_config.get("script_location") or "migrations") commands = cast("AsyncMigrationCommands[Any]", self._ensure_migration_commands()) - assert directory is not None await commands.init(directory, package) async def stamp_migration(self, revision: str) -> None: @@ -2155,7 +2155,6 @@ def init_migrations(self, directory: "str | None" = None, package: bool = True) directory = str(migration_config.get("script_location") or "migrations") commands = self._ensure_migration_commands() - assert directory is not None commands.init(directory, package) def stamp_migration(self, revision: str) -> None: @@ -2367,7 +2366,6 @@ async def init_migrations(self, directory: "str | None" = None, package: bool = directory = str(migration_config.get("script_location") or "migrations") commands = cast("AsyncMigrationCommands[Any]", self._ensure_migration_commands()) - assert directory is not None await commands.init(directory, package) async def stamp_migration(self, revision: str) -> None: diff --git a/sqlspec/core/__init__.py b/sqlspec/core/__init__.py index ed57e5cdb..b40673477 100644 --- a/sqlspec/core/__init__.py +++ b/sqlspec/core/__init__.py @@ -107,6 +107,7 @@ log_cache_stats, reset_cache_stats, reset_pipeline_registry, + reset_stats_only, update_cache_config, ) from sqlspec.core.compiler import ( @@ -364,6 +365,7 @@ "replace_placeholders_with_literals", "reset_cache_stats", "reset_pipeline_registry", + "reset_stats_only", "safe_modify_with_cte", "split_sql_script", "update_cache_config", diff --git a/sqlspec/core/cache.py b/sqlspec/core/cache.py index ccfa6e1d6..8ecc84a20 100644 --- a/sqlspec/core/cache.py +++ b/sqlspec/core/cache.py @@ -12,6 +12,7 @@ import logging import threading import time +import warnings from typing import TYPE_CHECKING, Any, Final from mypy_extensions import mypyc_attr @@ -435,6 +436,15 @@ def clear_all_caches() -> None: reset_statement_pipeline_cache() +def reset_stats_only() -> None: + """Reset cache statistics without evicting cached entries.""" + if _default_cache is not None: + _default_cache.get_stats().reset() + if _namespaced_cache is not None: + for cache in _namespaced_cache._caches.values(): + cache.get_stats().reset() + + def get_cache_statistics() -> "dict[str, CacheStats]": """Get statistics from all cache instances. @@ -556,7 +566,13 @@ def update_cache_config(config: CacheConfig) -> None: def reset_cache_stats() -> None: - """Reset all cache statistics.""" + """Deprecated alias for clearing all cache entries.""" + warnings.warn( + "reset_cache_stats() clears cached data and is deprecated; use clear_all_caches() to evict entries or " + "reset_stats_only() to reset counters without eviction.", + DeprecationWarning, + stacklevel=2, + ) clear_all_caches() diff --git a/sqlspec/core/compiler.py b/sqlspec/core/compiler.py index 4200a27e9..b20203f34 100644 --- a/sqlspec/core/compiler.py +++ b/sqlspec/core/compiler.py @@ -9,7 +9,7 @@ import logging from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, final import sqlglot from mypy_extensions import mypyc_attr @@ -107,7 +107,7 @@ COPY_TO_OPERATION_TYPES: "tuple[OperationType, ...]" = ("COPY_TO",) -ParseCacheEntry = tuple[exp.Expr | None, OperationType, dict[int, str], tuple[bool, bool]] +ParseCacheEntry = tuple[exp.Expr | None, OperationType, tuple[bool, bool]] def is_copy_operation(operation_type: "OperationType") -> bool: @@ -169,6 +169,7 @@ def _assign_placeholder_position( return placeholder_positions[placeholder_key] +@final @mypyc_attr(allow_interpreted_subclasses=False) class OperationProfile: """Semantic characteristics derived from the parsed SQL expression.""" @@ -202,6 +203,7 @@ def _is_effectively_empty_parameters(value: Any) -> bool: return False +@final @mypyc_attr(allow_interpreted_subclasses=False) class CompiledSQL: """Compiled SQL result. @@ -297,6 +299,7 @@ def __repr__(self) -> str: ) +@final @mypyc_attr(allow_interpreted_subclasses=False) class SQLProcessor: """SQL processor with compilation and caching. @@ -366,9 +369,7 @@ def __init__( ) self._cache_hits = 0 self._cache_misses = 0 - self._parse_cache: OrderedDict[ - Any, tuple[exp.Expr | None, OperationType, dict[int, str], tuple[bool, bool]] - ] = OrderedDict() + self._parse_cache: OrderedDict[Any, tuple[exp.Expr | None, OperationType, tuple[bool, bool]]] = OrderedDict() self._parse_cache_hits = 0 self._parse_cache_misses = 0 self._last_cache_key: Any | None = None @@ -384,7 +385,12 @@ def __init__( ) def compile( - self, sql: str, parameters: Any = None, is_many: bool = False, expression: "exp.Expr | None" = None + self, + sql: str, + parameters: Any = None, + is_many: bool = False, + expression: "exp.Expr | None" = None, + param_fingerprint: "Any | None" = None, ) -> CompiledSQL: """Compile SQL statement. @@ -393,6 +399,7 @@ def compile( parameters: Parameter values for substitution is_many: Whether this is for execute_many operation expression: Pre-parsed SQLGlot expression to reuse + param_fingerprint: Pre-computed parameter fingerprint. Returns: CompiledSQL with execution information @@ -400,7 +407,8 @@ def compile( if not self._config.enable_caching or not self._cache_enabled: return self._compile_uncached(sql, parameters, is_many, expression, param_fingerprint=None) - param_fingerprint = self._get_param_fingerprint(parameters, is_many) + if param_fingerprint is None: + param_fingerprint = self._get_param_fingerprint(parameters, is_many) cache_key = self._make_cache_key(sql, param_fingerprint, is_many) # MICRO-CACHE: Fast path for repeating statements @@ -498,8 +506,9 @@ def _prepare_parameters( process_result.applied_wrap_types, ) + @staticmethod def _normalize_expression_override( - self, expression_override: "exp.Expr | None", sqlglot_sql: str, sql: str + expression_override: "exp.Expr | None", sqlglot_sql: str, sql: str ) -> "exp.Expr | None": """Validate expression overrides against the input SQL. @@ -519,7 +528,7 @@ def _normalize_expression_override( def _parse_expression_uncached( self, sqlglot_sql: str, dialect_str: "str | None", expression_override: "exp.Expr | None" - ) -> "tuple[exp.Expr | None, OperationType, dict[int, str], OperationProfile]": + ) -> "tuple[exp.Expr | None, OperationType, OperationProfile]": """Parse SQL into an expression without cache. Args: @@ -536,19 +545,17 @@ def _parse_expression_uncached( else: expression = sqlglot.parse_one(sqlglot_sql, dialect=dialect_str) except ParseError: - return None, "COMMAND", {}, OperationProfile.empty() + return None, "COMMAND", OperationProfile.empty() else: - operation_type = self._detect_operation_type(expression) - parameter_casts = self._detect_parameter_casts(expression) - operation_profile = self._build_operation_profile(expression, operation_type) - return expression, operation_type, parameter_casts, operation_profile + operation_type = SQLProcessor._detect_operation_type(expression) + operation_profile = SQLProcessor._build_operation_profile(expression, operation_type) + return expression, operation_type, operation_profile def _store_parse_cache( self, parse_cache_key: Any, expression: "exp.Expr | None", operation_type: "OperationType", - parameter_casts: "dict[int, str]", operation_profile: "OperationProfile", ) -> None: """Store parsed expression details in cache. @@ -557,7 +564,6 @@ def _store_parse_cache( parse_cache_key: Cache key for the parsed SQL. expression: Parsed SQLGlot expression. operation_type: Detected operation type. - parameter_casts: Parameter cast mappings. operation_profile: Operation metadata. """ if len(self._parse_cache) >= self._parse_cache_max_size: @@ -567,13 +573,13 @@ def _store_parse_cache( self._parse_cache[parse_cache_key] = ( expression, operation_type, - parameter_casts, (operation_profile.returns_rows, operation_profile.modifies_rows), ) + @staticmethod def _unpack_parse_cache_entry( - self, parse_cache_entry: "ParseCacheEntry" - ) -> "tuple[exp.Expr | None, OperationType, dict[int, str], OperationProfile]": + parse_cache_entry: "ParseCacheEntry", + ) -> "tuple[exp.Expr | None, OperationType, OperationProfile]": """Expand cached parse results into runtime objects. Args: @@ -582,17 +588,16 @@ def _unpack_parse_cache_entry( Returns: Parsed expression metadata. """ - cached_expression, cached_operation, cached_casts, cached_profile = parse_cache_entry + cached_expression, cached_operation, cached_profile = parse_cache_entry # Return expression reference without copying - _apply_ast_transformers will copy # if transformers are configured and will modify it. This avoids unnecessary copies # when no transformers are active (common case). operation_profile = OperationProfile(returns_rows=cached_profile[0], modifies_rows=cached_profile[1]) - # cached_casts is already a dict, no need to copy - it's not mutated by callers - return cached_expression, cached_operation, cached_casts, operation_profile + return cached_expression, cached_operation, operation_profile def _resolve_expression( self, sqlglot_sql: str, dialect_str: "str | None", expression_override: "exp.Expr | None" - ) -> "tuple[exp.Expr | None, OperationType, dict[int, str], OperationProfile, Any | None, ParseCacheEntry | None]": + ) -> "tuple[exp.Expr | None, OperationType, OperationProfile, Any | None, ParseCacheEntry | None]": """Resolve an SQLGlot expression with caching. Args: @@ -606,23 +611,21 @@ def _resolve_expression( parse_cache_key = None parse_cache_entry = None if self._config.enable_caching and self._parse_cache_max_size > 0: - parse_cache_key = self._make_parse_cache_key(sqlglot_sql, dialect_str) + parse_cache_key = SQLProcessor._make_parse_cache_key(sqlglot_sql, dialect_str) parse_cache_entry = self._parse_cache.get(parse_cache_key) if parse_cache_entry is not None: self._parse_cache_hits += 1 self._parse_cache.move_to_end(parse_cache_key) if parse_cache_entry is None: self._parse_cache_misses += 1 - expression, operation_type, parameter_casts, operation_profile = self._parse_expression_uncached( + expression, operation_type, operation_profile = self._parse_expression_uncached( sqlglot_sql, dialect_str, expression_override ) if parse_cache_key is not None: - self._store_parse_cache(parse_cache_key, expression, operation_type, parameter_casts, operation_profile) + self._store_parse_cache(parse_cache_key, expression, operation_type, operation_profile) else: - expression, operation_type, parameter_casts, operation_profile = self._unpack_parse_cache_entry( - parse_cache_entry - ) - return expression, operation_type, parameter_casts, operation_profile, parse_cache_key, parse_cache_entry + expression, operation_type, operation_profile = SQLProcessor._unpack_parse_cache_entry(parse_cache_entry) + return expression, operation_type, operation_profile, parse_cache_key, parse_cache_entry def _apply_ast_transformers( self, @@ -630,12 +633,11 @@ def _apply_ast_transformers( parameters: Any, parameter_profile: "ParameterProfile", operation_type: "OperationType", - parameter_casts: "dict[int, str]", operation_profile: "OperationProfile", parse_cache_key: "Any | None", parse_cache_entry: "ParseCacheEntry | None", expression_override: "exp.Expr | None", - ) -> "tuple[exp.Expr | None, Any, bool, OperationType, dict[int, str], OperationProfile]": + ) -> "tuple[exp.Expr | None, Any, bool, OperationType, OperationProfile]": """Apply AST transformers and update metadata. Args: @@ -643,7 +645,6 @@ def _apply_ast_transformers( parameters: Execution parameters. parameter_profile: Parameter profile metadata. operation_type: Current operation type. - parameter_casts: Current parameter cast mapping. operation_profile: Current operation profile. parse_cache_key: Parse cache key when used. parse_cache_entry: Cached parse entry when available. @@ -655,7 +656,7 @@ def _apply_ast_transformers( statement_transformers = self._config.statement_transformers ast_transformer = self._config.parameter_config.ast_transformer if expression is None or (not statement_transformers and not ast_transformer): - return expression, parameters, False, operation_type, parameter_casts, operation_profile + return expression, parameters, False, operation_type, operation_profile # Must copy the expression before transformers modify it to avoid corrupting: # 1. Cache entries (for both cache hits and misses that will be cached) @@ -679,12 +680,11 @@ def _apply_ast_transformers( ast_was_transformed = True if ast_was_transformed: if expression is None: - return expression, parameters, ast_was_transformed, operation_type, parameter_casts, operation_profile - operation_type = self._detect_operation_type(expression) - parameter_casts = self._detect_parameter_casts(expression) - operation_profile = self._build_operation_profile(expression, operation_type) + return expression, parameters, ast_was_transformed, operation_type, operation_profile + operation_type = SQLProcessor._detect_operation_type(expression) + operation_profile = SQLProcessor._build_operation_profile(expression, operation_type) - return expression, parameters, ast_was_transformed, operation_type, parameter_casts, operation_profile + return expression, parameters, ast_was_transformed, operation_type, operation_profile def _finalize_compilation( self, @@ -799,7 +799,6 @@ def _compile_uncached( operation_profile = OperationProfile.empty() try: - dialect_str = str(self._config.dialect) if self._config.dialect else None ( processed_sql, processed_params, @@ -807,8 +806,10 @@ def _compile_uncached( sqlglot_sql, input_named_parameters, applied_wrap_types, - ) = self._prepare_parameters(sql, parameters, is_many, dialect_str, param_fingerprint=param_fingerprint) - expression_override = self._normalize_expression_override(expression_override, sqlglot_sql, sql) + ) = self._prepare_parameters( + sql, parameters, is_many, self._dialect_str, param_fingerprint=param_fingerprint + ) + expression_override = SQLProcessor._normalize_expression_override(expression_override, sqlglot_sql, sql) final_parameters = processed_params ast_was_transformed = False @@ -819,27 +820,23 @@ def _compile_uncached( parse_cache_entry = None if self._config.enable_parsing: - (expression, operation_type, parameter_casts, operation_profile, parse_cache_key, parse_cache_entry) = ( - self._resolve_expression(sqlglot_sql, dialect_str, expression_override) + (expression, operation_type, operation_profile, parse_cache_key, parse_cache_entry) = ( + self._resolve_expression(sqlglot_sql, self._dialect_str, expression_override) ) - ( - expression, - final_parameters, - ast_was_transformed, - operation_type, - parameter_casts, - operation_profile, - ) = self._apply_ast_transformers( - expression, - final_parameters, - parameter_profile, - operation_type, - parameter_casts, - operation_profile, - parse_cache_key, - parse_cache_entry, - expression_override, + (expression, final_parameters, ast_was_transformed, operation_type, operation_profile) = ( + self._apply_ast_transformers( + expression, + final_parameters, + parameter_profile, + operation_type, + operation_profile, + parse_cache_key, + parse_cache_entry, + expression_override, + ) ) + if expression is not None: + parameter_casts = SQLProcessor._detect_parameter_casts(expression) final_sql, final_params, parameter_profile, input_named_params, applied_wrap = self._finalize_compilation( processed_sql, @@ -848,7 +845,7 @@ def _compile_uncached( final_parameters, parameter_profile, is_many, - dialect_str, + self._dialect_str, ast_was_transformed, ) @@ -910,7 +907,8 @@ def _make_cache_key(self, sql: str, param_fingerprint: Any, is_many: bool = Fals sql, param_fingerprint, self._input_style, self._exec_style, self._dialect_str, is_many ) - def _detect_operation_type(self, expression: "exp.Expr") -> "OperationType": + @staticmethod + def _detect_operation_type(expression: "exp.Expr") -> "OperationType": """Detect operation type from AST. Args: @@ -936,7 +934,8 @@ def _detect_operation_type(self, expression: "exp.Expr") -> "OperationType": # that doesn't fit specific categories return "COMMAND" - def _detect_parameter_casts(self, expression: "exp.Expr | None") -> "dict[int, str]": + @staticmethod + def _detect_parameter_casts(expression: "exp.Expr | None") -> "dict[int, str]": """Detect explicit type casts on parameters in the AST. Args: @@ -1007,9 +1006,8 @@ def _apply_final_transformations( return sql, parameters - def _build_operation_profile( - self, expression: "exp.Expr | None", operation_type: "OperationType" - ) -> "OperationProfile": + @staticmethod + def _build_operation_profile(expression: "exp.Expr | None", operation_type: "OperationType") -> "OperationProfile": if expression is None: return OperationProfile.empty() @@ -1050,7 +1048,8 @@ def clear_cache(self) -> None: self._parse_cache_misses = 0 self._parameter_processor.clear_cache() - def _make_parse_cache_key(self, sql: str, dialect: "str | None") -> Any: + @staticmethod + def _make_parse_cache_key(sql: str, dialect: "str | None") -> Any: dialect_marker = dialect or "default" # Use tuple as cache key instead of hashing. Python handles tuple keys efficiently. return (dialect_marker, sql) diff --git a/sqlspec/core/filters.py b/sqlspec/core/filters.py index f674202d8..46b608658 100644 --- a/sqlspec/core/filters.py +++ b/sqlspec/core/filters.py @@ -242,7 +242,6 @@ def extract_parameters(self) -> "tuple[list[Any], dict[str, Any]]": def append_to_statement(self, statement: "SQL") -> "SQL": """Apply filter to SQL expression only.""" conditions: list[Condition] = [] - col_expr = self._get_column_expression(self.field_name) proposed_names = self.get_param_names() if not proposed_names: @@ -255,12 +254,21 @@ def append_to_statement(self, statement: "SQL") -> "SQL": if self.before: before_param_name = resolved_names[param_idx] param_idx += 1 - conditions.append(exp.LT(this=col_expr, expression=exp.Placeholder(this=before_param_name))) + conditions.append( + exp.LT( + this=self._get_column_expression(self.field_name), + expression=exp.Placeholder(this=before_param_name), + ) + ) result = result.add_named_parameter(before_param_name, self.before) if self.after: after_param_name = resolved_names[param_idx] - conditions.append(exp.GT(this=col_expr, expression=exp.Placeholder(this=after_param_name))) + conditions.append( + exp.GT( + this=self._get_column_expression(self.field_name), expression=exp.Placeholder(this=after_param_name) + ) + ) result = result.add_named_parameter(after_param_name, self.after) final_condition = conditions[0] @@ -330,7 +338,6 @@ def extract_parameters(self) -> "tuple[list[Any], dict[str, Any]]": def append_to_statement(self, statement: "SQL") -> "SQL": conditions: list[Condition] = [] - col_expr = self._get_column_expression(self.field_name) proposed_names = self.get_param_names() if not proposed_names: @@ -343,12 +350,21 @@ def append_to_statement(self, statement: "SQL") -> "SQL": if self.on_or_before: before_param_name = resolved_names[param_idx] param_idx += 1 - conditions.append(exp.LTE(this=col_expr, expression=exp.Placeholder(this=before_param_name))) + conditions.append( + exp.LTE( + this=self._get_column_expression(self.field_name), + expression=exp.Placeholder(this=before_param_name), + ) + ) result = result.add_named_parameter(before_param_name, self.on_or_before) if self.on_or_after: after_param_name = resolved_names[param_idx] - conditions.append(exp.GTE(this=col_expr, expression=exp.Placeholder(this=after_param_name))) + conditions.append( + exp.GTE( + this=self._get_column_expression(self.field_name), expression=exp.Placeholder(this=after_param_name) + ) + ) result = result.add_named_parameter(after_param_name, self.on_or_after) final_condition = conditions[0] @@ -720,6 +736,7 @@ def append_to_statement(self, statement: "SQL") -> "SQL": elif isinstance(self.field_name, exp.Expression): result = statement.where(like_op(this=self.field_name, expression=exp.Placeholder(this=param_name))) elif isinstance(self.field_name, set) and self.field_name: + # do not hoist Placeholder outside this comprehension; sqlglot nodes carry parent pointers. field_conditions: list[Condition] = [ like_op(this=self._get_column_expression(field), expression=exp.Placeholder(this=param_name)) for field in self.field_name @@ -874,6 +891,7 @@ def append_to_statement(self, statement: "SQL") -> "SQL": exp.Not(this=like_op(this=self.field_name, expression=exp.Placeholder(this=param_name))) ) elif isinstance(self.field_name, set) and self.field_name: + # do not hoist Placeholder outside this comprehension; sqlglot nodes carry parent pointers. field_conditions: list[Condition] = [ exp.Not( this=like_op(this=self._get_column_expression(field), expression=exp.Placeholder(this=param_name)) diff --git a/sqlspec/core/parameters/_alignment.py b/sqlspec/core/parameters/_alignment.py index 987106dd9..c2c5978b6 100644 --- a/sqlspec/core/parameters/_alignment.py +++ b/sqlspec/core/parameters/_alignment.py @@ -4,7 +4,7 @@ from typing import Any, cast import sqlspec.exceptions -from sqlspec.core.parameters._types import ParameterProfile, ParameterStyle +from sqlspec.core.parameters._types import _NAMED_STYLES, ParameterProfile, ParameterStyle __all__ = ( "EXECUTE_MANY_MIN_ROWS", @@ -119,12 +119,7 @@ def _collect_expected_identifiers(parameter_profile: "ParameterProfile") -> "set for parameter in parameters: style = parameter.style name = parameter.name - if style in { - ParameterStyle.NAMED_COLON, - ParameterStyle.NAMED_AT, - ParameterStyle.NAMED_DOLLAR, - ParameterStyle.NAMED_PYFORMAT, - }: + if style in _NAMED_STYLES: identifiers.add(("named", name or f"param_{parameter.ordinal}")) elif style in {ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_COLON}: # When mixed with ordinal styles (like QMARK), use ordinal instead of explicit index diff --git a/sqlspec/core/parameters/_converter.py b/sqlspec/core/parameters/_converter.py index 1e0386ee9..7be223a01 100644 --- a/sqlspec/core/parameters/_converter.py +++ b/sqlspec/core/parameters/_converter.py @@ -6,6 +6,7 @@ from mypy_extensions import mypyc_attr from sqlspec.core.parameters._types import ( + _NAMED_STYLES, ConvertedParameters, NamedParameterOutput, ParameterInfo, @@ -129,6 +130,7 @@ def convert_placeholder_style( *, strict_named_parameters: bool = True, param_info: "list[ParameterInfo] | None" = None, + precomputed_plan: "tuple[list[ParameterInfo], dict[str, int]] | None" = None, ) -> "tuple[str, ConvertedParameters]": extracted_param_info = param_info if param_info is not None else self.validator.extract_parameters(sql) @@ -148,7 +150,7 @@ def convert_placeholder_style( ) return sql, converted_parameters - converted_sql = self._convert_placeholders_to_style(sql, extracted_param_info, target_style) + converted_sql = self._convert_placeholders_to_style(sql, extracted_param_info, target_style, precomputed_plan) converted_parameters = self._convert_parameter_format( parameters, extracted_param_info, @@ -174,14 +176,21 @@ def _build_conversion_plan( return ordered_params, unique_params def _convert_placeholders_to_style( - self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + self, + sql: str, + param_info: "list[ParameterInfo]", + target_style: "ParameterStyle", + precomputed_plan: "tuple[list[ParameterInfo], dict[str, int]] | None" = None, ) -> str: generator = self._placeholder_generators.get(target_style) if generator is None: msg = f"Unsupported target parameter style: {target_style}" raise ValueError(msg) - ordered_params, unique_params = self._build_conversion_plan(param_info, target_style) + if precomputed_plan is not None: + ordered_params, unique_params = precomputed_plan + else: + ordered_params, unique_params = self._build_conversion_plan(param_info, target_style) placeholder_text_len_cache: dict[str, int] = {} # Build SQL using forward iteration with list join (O(n) vs O(n^2) string slicing) @@ -214,14 +223,20 @@ def _convert_placeholders_to_style( return "".join(segments) def convert_parameter_info_style( - self, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + self, + param_info: "list[ParameterInfo]", + target_style: "ParameterStyle", + precomputed_plan: "tuple[list[ParameterInfo], dict[str, int]] | None" = None, ) -> "list[ParameterInfo]": generator = self._placeholder_generators.get(target_style) if generator is None: msg = f"Unsupported target parameter style: {target_style}" raise ValueError(msg) - ordered_params, unique_params = self._build_conversion_plan(param_info, target_style) + if precomputed_plan is not None: + ordered_params, unique_params = precomputed_plan + else: + ordered_params, unique_params = self._build_conversion_plan(param_info, target_style) is_positional_style = _is_positional_style(target_style) converted_param_info: list[ParameterInfo] = [] delta = 0 @@ -333,15 +348,9 @@ def _extract_param_value_single_style( def _collect_missing_named_parameters( self, param_info: "list[ParameterInfo]", parameters: "ParameterMapping" ) -> "list[str]": - named_styles = { - ParameterStyle.NAMED_COLON, - ParameterStyle.NAMED_AT, - ParameterStyle.NAMED_DOLLAR, - ParameterStyle.NAMED_PYFORMAT, - } missing: list[str] = [] for param in param_info: - if param.style not in named_styles or not param.name: + if param.style not in _NAMED_STYLES or not param.name: continue if param.name in parameters or param.placeholder_text in parameters: continue diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py index 3146c229a..b0007079c 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -9,6 +9,9 @@ from sqlspec.core.parameters._alignment import looks_like_execute_many from sqlspec.core.parameters._converter import ParameterConverter from sqlspec.core.parameters._types import ( + _NAMED_STYLE_VALUES, + _NAMED_STYLES, + _POSITIONAL_STYLE_VALUES, ConvertedParameters, ParameterInfo, ParameterPayload, @@ -565,7 +568,7 @@ def _coerce_parameter_types( return cast("ConvertedParameters", result) if result_type is tuple: return cast("ConvertedParameters", result) - return None + return cast("ConvertedParameters", result) def _store_cached_result( self, cache_key: Any | None, result: "ParameterProcessingResult" @@ -617,24 +620,12 @@ def _transform_cached_parameters( processed = self._coerce_parameter_types(processed, config.type_coercion_map, is_many) if processed: - positional_styles = { - ParameterStyle.QMARK.value, - ParameterStyle.NUMERIC.value, - ParameterStyle.POSITIONAL_COLON.value, - ParameterStyle.POSITIONAL_PYFORMAT.value, - } - named_styles = { - ParameterStyle.NAMED_COLON.value, - ParameterStyle.NAMED_AT.value, - ParameterStyle.NAMED_DOLLAR.value, - ParameterStyle.NAMED_PYFORMAT.value, - } cached_styles = cached_profile.styles - if input_named_parameters and any(style in positional_styles for style in cached_styles): + if input_named_parameters and any(style in _POSITIONAL_STYLE_VALUES for style in cached_styles): processed = self._map_named_to_positional( processed, input_named_parameters, is_many, strict=config.strict_named_parameters ) - elif any(style in named_styles for style in cached_styles): + elif any(style in _NAMED_STYLE_VALUES for style in cached_styles): processed = self._map_positional_to_named(processed, cached_profile, is_many) return processed @@ -769,16 +760,7 @@ def _needs_mapping_normalization( if not payload or not param_info: return False - has_named_placeholders = any( - param.style - in { - ParameterStyle.NAMED_COLON, - ParameterStyle.NAMED_AT, - ParameterStyle.NAMED_DOLLAR, - ParameterStyle.NAMED_PYFORMAT, - } - for param in param_info - ) + has_named_placeholders = any(param.style in _NAMED_STYLES for param in param_info) if has_named_placeholders: return False @@ -1028,6 +1010,9 @@ def _process_internal( if requires_mapping: target_style = self._select_execution_style(original_styles, config) + mapping_plan = self._converter._build_conversion_plan( # pyright: ignore[reportPrivateUsage] + param_info, target_style + ) processed_sql, processed_parameters = self._converter.convert_placeholder_style( processed_sql, processed_parameters, @@ -1035,8 +1020,9 @@ def _process_internal( is_many, strict_named_parameters=config.strict_named_parameters, param_info=param_info, + precomputed_plan=mapping_plan, ) - param_info = self._converter.convert_parameter_info_style(param_info, target_style) + param_info = self._converter.convert_parameter_info_style(param_info, target_style, mapping_plan) original_styles = {target_style} needs_execution_conversion = False @@ -1155,10 +1141,15 @@ def _convert_placeholders_for_execution( return sql, dict(parameters), None if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): return sql, list(parameters), None + if len(param_info) == 1: + return sql, [parameters], None return sql, None, None target_style = self._select_execution_style(original_styles, config) - converted_param_info = self._converter.convert_parameter_info_style(param_info, target_style) + execution_plan = self._converter._build_conversion_plan( # pyright: ignore[reportPrivateUsage] + param_info, target_style + ) + converted_param_info = self._converter.convert_parameter_info_style(param_info, target_style, execution_plan) if is_many and config.preserve_original_params_for_many and isinstance(parameters, (list, tuple)): processed_sql, _ = self._converter.convert_placeholder_style( @@ -1168,6 +1159,7 @@ def _convert_placeholders_for_execution( is_many, strict_named_parameters=config.strict_named_parameters, param_info=param_info, + precomputed_plan=execution_plan, ) return processed_sql, parameters, converted_param_info @@ -1178,5 +1170,6 @@ def _convert_placeholders_for_execution( is_many, strict_named_parameters=config.strict_named_parameters, param_info=param_info, + precomputed_plan=execution_plan, ) return processed_sql, processed_parameters, converted_param_info diff --git a/sqlspec/core/parameters/_transformers.py b/sqlspec/core/parameters/_transformers.py index 4fa51a8ef..d1d02fa89 100644 --- a/sqlspec/core/parameters/_transformers.py +++ b/sqlspec/core/parameters/_transformers.py @@ -4,6 +4,7 @@ from collections.abc import Callable, Mapping, Sequence from typing import Any, cast +from mypy_extensions import mypyc_attr from sqlglot import exp as _exp from sqlspec.core.parameters._alignment import ( @@ -13,13 +14,15 @@ validate_parameter_alignment, ) from sqlspec.core.parameters._types import ( + _NAMED_STYLES, + _POSITIONAL_STYLES, ConvertedParameters, ParameterMapping, ParameterPayload, ParameterProfile, - ParameterStyle, ) from sqlspec.core.parameters._validator import ParameterValidator +from sqlspec.utils.deprecation import warn_deprecation from sqlspec.utils.type_guards import get_value_attribute __all__ = ( @@ -29,9 +32,12 @@ "replace_placeholders_with_literals", ) +# Flyweight validator used exclusively by the deprecated no-profile fallback +# path. Callers that pass validator= bypass this singleton. _AST_TRANSFORMER_VALIDATOR: "ParameterValidator" = ParameterValidator() +@mypyc_attr(allow_interpreted_subclasses=False) class _NullPruningTransform: __slots__ = ("_dialect", "_validator") @@ -51,6 +57,7 @@ def __call__( ) +@mypyc_attr(allow_interpreted_subclasses=False) class _LiteralInliningTransform: __slots__ = ("_json_serializer",) @@ -66,6 +73,7 @@ def __call__( return literal_expression, parameters +@mypyc_attr(allow_interpreted_subclasses=False) class _NullPlaceholderTransformer: __slots__ = ("_null_names", "_null_positions", "_qmark_position", "_sorted_null_positions") @@ -112,6 +120,7 @@ def __call__(self, node: Any) -> Any: return node +@mypyc_attr(allow_interpreted_subclasses=False) class _PlaceholderLiteralTransformer: __slots__ = ("_json_serializer", "_parameters", "_placeholder_index") @@ -226,6 +235,17 @@ def replace_null_parameters_with_literals( validator_instance = validator or _AST_TRANSFORMER_VALIDATOR profile = parameter_profile if profile is None: + warn_deprecation( + "0.49", + "replace_null_parameters_with_literals", + "function", + info=( + "Calling replace_null_parameters_with_literals without parameter_profile triggers a " + "round-trip AST serialization. Pass parameter_profile for efficiency. This fallback will " + "be removed in a future version." + ), + stacklevel=3, + ) parameter_info = validator_instance.extract_parameters(expression.sql(dialect=dialect)) profile = ParameterProfile(parameter_info) validate_parameter_alignment(profile, parameters) @@ -243,26 +263,14 @@ def replace_null_parameters_with_literals( return expression, list(parameters) return expression, None - named_styles = { - ParameterStyle.NAMED_COLON, - ParameterStyle.NAMED_AT, - ParameterStyle.NAMED_DOLLAR, - ParameterStyle.NAMED_PYFORMAT, - } - positional_styles = { - ParameterStyle.QMARK, - ParameterStyle.POSITIONAL_PYFORMAT, - ParameterStyle.POSITIONAL_COLON, - ParameterStyle.NUMERIC, - } null_names: set[str] = set() positional_null_positions: set[int] = set() for parameter in profile.parameters: if parameter.ordinal not in null_positions: continue - if parameter.style in named_styles and parameter.name: + if parameter.style in _NAMED_STYLES and parameter.name: null_names.add(parameter.name) - if parameter.style in positional_styles: + if parameter.style in _POSITIONAL_STYLES: positional_null_positions.add(parameter.ordinal) sorted_null_positions = sorted(positional_null_positions) diff --git a/sqlspec/core/parameters/_types.py b/sqlspec/core/parameters/_types.py index 226898408..4a73bcf58 100644 --- a/sqlspec/core/parameters/_types.py +++ b/sqlspec/core/parameters/_types.py @@ -4,7 +4,6 @@ from datetime import date, datetime, time from decimal import Decimal from enum import Enum -from functools import singledispatch from types import MappingProxyType from typing import Any, Literal, TypeAlias @@ -41,19 +40,20 @@ """Type alias for parameter payloads accepted by the processing pipeline.""" -ConvertedParameters: TypeAlias = "dict[str, Any] | list[Any] | tuple[Any, ...] | None" +ConvertedParameters: TypeAlias = "dict[str, Any] | list[Any] | tuple[Any, ...] | object | None" """Type alias for parameters after conversion to driver-consumable format. This type represents the concrete output of parameter conversion functions. Unlike :data:`ParameterPayload` (which represents inputs and can include abstract -Mapping/Sequence types), :data:`ConvertedParameters` only includes concrete types -that database drivers can directly consume. +Mapping/Sequence types), :data:`ConvertedParameters` includes concrete container +types and scalar objects that database drivers can directly consume. The union includes: - ``dict[str, Any]``: Named parameters (e.g., ``{"name": "Alice", "age": 30}``) - ``list[Any]``: Positional parameters as list (e.g., ``["Alice", 30]``) - ``tuple[Any, ...]``: Positional parameters as tuple (e.g., ``("Alice", 30)``) +- ``object``: Scalar parameter payloads after type coercion (e.g., ``3.14``) - ``None``: When parameters are statically embedded in SQL string """ @@ -141,6 +141,22 @@ class ParameterStyle(str, Enum): POSITIONAL_PYFORMAT = "pyformat_positional" +_NAMED_STYLES: frozenset["ParameterStyle"] = frozenset({ + ParameterStyle.NAMED_COLON, + ParameterStyle.NAMED_AT, + ParameterStyle.NAMED_DOLLAR, + ParameterStyle.NAMED_PYFORMAT, +}) +_POSITIONAL_STYLES: frozenset["ParameterStyle"] = frozenset({ + ParameterStyle.QMARK, + ParameterStyle.NUMERIC, + ParameterStyle.POSITIONAL_COLON, + ParameterStyle.POSITIONAL_PYFORMAT, +}) +_NAMED_STYLE_VALUES: frozenset[str] = frozenset(style.value for style in _NAMED_STYLES) +_POSITIONAL_STYLE_VALUES: frozenset[str] = frozenset(style.value for style in _POSITIONAL_STYLES) + + @mypyc_attr(allow_interpreted_subclasses=False) class TypedParameter: """Wrapper that preserves original parameter type information.""" @@ -190,41 +206,22 @@ def __call__(self, value: Any) -> "Any": return self._serializer(value) -@singledispatch def _wrap_parameter_by_type(value: Any, semantic_name: "str | None" = None) -> Any: + if isinstance(value, bool): + return TypedParameter(value, bool, semantic_name) + if isinstance(value, Decimal): + return TypedParameter(value, Decimal, semantic_name) + if isinstance(value, datetime): + return TypedParameter(value, datetime, semantic_name) + if isinstance(value, date): + return TypedParameter(value, date, semantic_name) + if isinstance(value, time): + return TypedParameter(value, time, semantic_name) + if isinstance(value, bytes): + return TypedParameter(value, bytes, semantic_name) return value -@_wrap_parameter_by_type.register -def _(value: bool, semantic_name: "str | None" = None) -> "TypedParameter": - return TypedParameter(value, bool, semantic_name) - - -@_wrap_parameter_by_type.register -def _(value: Decimal, semantic_name: "str | None" = None) -> "TypedParameter": - return TypedParameter(value, Decimal, semantic_name) - - -@_wrap_parameter_by_type.register -def _(value: datetime, semantic_name: "str | None" = None) -> "TypedParameter": - return TypedParameter(value, datetime, semantic_name) - - -@_wrap_parameter_by_type.register -def _(value: date, semantic_name: "str | None" = None) -> "TypedParameter": - return TypedParameter(value, date, semantic_name) - - -@_wrap_parameter_by_type.register -def _(value: time, semantic_name: "str | None" = None) -> "TypedParameter": - return TypedParameter(value, time, semantic_name) - - -@_wrap_parameter_by_type.register -def _(value: bytes, semantic_name: "str | None" = None) -> "TypedParameter": - return TypedParameter(value, bytes, semantic_name) - - @mypyc_attr(allow_interpreted_subclasses=False) class ParameterInfo: """Metadata describing a single detected SQL parameter.""" @@ -310,27 +307,18 @@ def __hash__(self) -> int: tuple(sorted(self.type_coercion_map.keys(), key=str)) if self.type_coercion_map else None, self.has_native_list_expansion, self.preserve_original_params_for_many, - bool(self.output_transformer), + id(self.output_transformer) if self.output_transformer else None, self.needs_static_script_compilation, self.allow_mixed_parameter_styles, self.preserve_parameter_format, self.strict_named_parameters, - bool(self.ast_transformer), + id(self.ast_transformer) if self.ast_transformer else None, self.json_serializer, self.json_deserializer, ) self._hash_cache = hash(hash_components) return self._hash_cache - def hash(self) -> int: - """Return the hash value for caching compatibility. - - Returns: - Hash value matching :func:`hash` output for this config. - """ - - return hash(self) - def __reduce__(self) -> "tuple[Any, ...]": """Reconstruct via the public ctor so copy/pickle work on mypyc native classes.""" return ( diff --git a/sqlspec/core/pipeline.py b/sqlspec/core/pipeline.py index 1f0d11761..f795b9c72 100644 --- a/sqlspec/core/pipeline.py +++ b/sqlspec/core/pipeline.py @@ -15,6 +15,13 @@ from sqlspec.core.statement import StatementConfig +__all__ = ( + "StatementPipelineRegistry", + "compile_with_pipeline", + "get_statement_pipeline_metrics", + "reset_statement_pipeline_cache", +) + DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" DEFAULT_PIPELINE_CACHE_SIZE: Final[int] = 1000 DEFAULT_PIPELINE_PARSE_CACHE_SIZE: Final[int] = 5000 @@ -25,6 +32,9 @@ def _is_truthy(value: "str | None") -> bool: return bool(value and value.strip().lower() in {"1", "true", "yes", "on"}) +_RECORD_PIPELINE_METRICS: Final[bool] = _is_truthy(os.environ.get(DEBUG_ENV_FLAG)) + + @mypyc_attr(allow_interpreted_subclasses=False) class _PipelineMetrics: __slots__ = ( @@ -147,9 +157,17 @@ def __init__( self._metrics = _PipelineMetrics() if record_metrics else None def compile( - self, sql: str, parameters: Any, is_many: bool, record_metrics: bool, expression: "exp.Expr | None" = None + self, + sql: str, + parameters: Any, + is_many: bool, + record_metrics: bool, + expression: "exp.Expr | None" = None, + param_fingerprint: "Any | None" = None, ) -> "CompiledSQL": - result = self._processor.compile(sql, parameters, is_many=is_many, expression=expression) + result = self._processor.compile( + sql, parameters, is_many=is_many, expression=expression, param_fingerprint=param_fingerprint + ) if record_metrics and self._metrics is not None: self._metrics.update(self._processor.cache_stats) return result @@ -189,10 +207,11 @@ def compile( parameters: Any, is_many: bool = False, expression: "exp.Expr | None" = None, + param_fingerprint: "Any | None" = None, ) -> "CompiledSQL": key = self._fingerprint_config(config) pipeline = self._pipelines.get(key) - record_metrics = _is_truthy(os.getenv(DEBUG_ENV_FLAG)) + record_metrics = _RECORD_PIPELINE_METRICS if pipeline is not None: self._pipelines.move_to_end(key) @@ -204,7 +223,9 @@ def compile( self._pipelines.popitem(last=False) self._pipelines[key] = pipeline - return pipeline.compile(sql, parameters, is_many, record_metrics, expression=expression) + return pipeline.compile( + sql, parameters, is_many, record_metrics, expression=expression, param_fingerprint=param_fingerprint + ) def reset(self) -> None: for pipeline in self._pipelines.values(): @@ -218,7 +239,7 @@ def configure_cache(self, cache_size: int, parse_cache_size: int, cache_enabled: self.reset() def metrics(self) -> "list[dict[str, Any]]": - if not _is_truthy(os.getenv(DEBUG_ENV_FLAG)): + if not _RECORD_PIPELINE_METRICS: return [] snapshots: list[dict[str, Any]] = [] @@ -253,12 +274,15 @@ def metrics(self) -> "list[dict[str, Any]]": def _fingerprint_config(self, config: "Any") -> str: # Optimization: Use cached fingerprint if available # Configs are effectively immutable after creation, so caching is safe - cached = getattr(config, "_fingerprint_cache", None) - if isinstance(cached, str): - return cached + try: + cached = config._fingerprint_cache # pyright: ignore[reportPrivateUsage] + if isinstance(cached, str): + return cached + except AttributeError: + pass param_config = config.parameter_config - param_config_hash = param_config.hash() + param_config_hash = hash(param_config) converter_type = type(config.parameter_converter) if config.parameter_converter else None validator_type = type(config.parameter_validator) if config.parameter_validator else None output_transformer_id = id(config.output_transformer) if config.output_transformer else None @@ -298,8 +322,8 @@ def _fingerprint_config(self, config: "Any") -> str: full_fingerprint = f"pipeline::{fingerprint}" # Cache the fingerprint for future calls - configs are immutable in practice - with contextlib.suppress(AttributeError, TypeError): - config._fingerprint_cache = full_fingerprint + with contextlib.suppress(AttributeError): + config._fingerprint_cache = full_fingerprint # pyright: ignore[reportPrivateUsage] return full_fingerprint @@ -308,9 +332,16 @@ def _fingerprint_config(self, config: "Any") -> str: def compile_with_pipeline( - config: "Any", sql: str, parameters: Any, is_many: bool = False, expression: "exp.Expr | None" = None + config: "Any", + sql: str, + parameters: Any, + is_many: bool = False, + expression: "exp.Expr | None" = None, + param_fingerprint: "Any | None" = None, ) -> "CompiledSQL": - return _PIPELINE_REGISTRY.compile(config, sql, parameters, is_many=is_many, expression=expression) + return _PIPELINE_REGISTRY.compile( + config, sql, parameters, is_many=is_many, expression=expression, param_fingerprint=param_fingerprint + ) def reset_statement_pipeline_cache() -> None: @@ -323,11 +354,3 @@ def configure_statement_pipeline_cache(cache_size: int, parse_cache_size: int, c def get_statement_pipeline_metrics() -> "list[dict[str, Any]]": return _PIPELINE_REGISTRY.metrics() - - -__all__ = ( - "StatementPipelineRegistry", - "compile_with_pipeline", - "get_statement_pipeline_metrics", - "reset_statement_pipeline_cache", -) diff --git a/sqlspec/core/result/_base.py b/sqlspec/core/result/_base.py index d1a80bcf9..2474becf2 100644 --- a/sqlspec/core/result/_base.py +++ b/sqlspec/core/result/_base.py @@ -384,11 +384,6 @@ def _get_schema_rows(self, schema_type: "type[SchemaT]", rows: "list[dict[str, A return converted_rows def _get_schema_row(self, schema_type: "type[SchemaT]", row: "dict[str, Any]") -> "SchemaT": - rows_cache = self._schema_rows_cache - if rows_cache is not None: - cached_rows = rows_cache.get(schema_type) - if cached_rows: - return cast("SchemaT", cached_rows[0]) row_cache = self._schema_row_cache if row_cache is not None: cached_row = row_cache.get(schema_type) diff --git a/sqlspec/core/statement.py b/sqlspec/core/statement.py index 8b6a40d43..30e56b0fc 100644 --- a/sqlspec/core/statement.py +++ b/sqlspec/core/statement.py @@ -106,6 +106,29 @@ "_is_frozen", ) +_PUBLIC_CONFIG_FIELDS: Final = frozenset(( + "dialect", + "enable_analysis", + "enable_caching", + "enable_column_pruning", + "enable_expression_simplification", + "enable_parameter_type_wrapping", + "enable_parsing", + "enable_sqlcommenter", + "enable_transformations", + "enable_validation", + "execution_mode", + "execution_args", + "output_transformer", + "sqlcommenter_attributes", + "sqlcommenter_enable_context", + "sqlcommenter_enable_traceparent", + "statement_transformers", + "parameter_config", + "parameter_converter", + "parameter_validator", +)) + PROCESSED_STATE_SLOTS: Final = ( "compiled_sql", "execution_parameters", @@ -296,8 +319,7 @@ def __init__( if isinstance(statement, str): self._raw_sql = statement else: - dialect = self._dialect - self._raw_sql = statement.sql(dialect=str(dialect) if dialect else None) + self._raw_sql = "" self._raw_expression = statement self._is_many = is_many if is_many is not None else self._should_auto_detect_many(parameters) @@ -325,7 +347,7 @@ def __reduce__(self) -> "tuple[Any, ...]": return ( _rebuild_sql, ( - self._raw_sql, + self._get_raw_sql(), self._original_parameters, tuple(self._filters), self._statement_config, @@ -370,6 +392,13 @@ def _normalize_dialect(self, dialect: "DialectType") -> "str | None": return dialect return dialect.__class__.__name__.lower() + def _get_raw_sql(self) -> str: + """Return raw SQL, materializing deferred expression SQL when needed.""" + if self._raw_sql == "" and self._raw_expression is not None: + dialect = self._dialect + self._raw_sql = self._raw_expression.sql(dialect=str(dialect) if dialect else None) + return self._raw_sql + def _init_from_sql_object(self, sql_obj: "SQL") -> None: """Initialize instance attributes from existing SQL object. @@ -482,7 +511,7 @@ def _normalize_parameters(self, parameters: "tuple[Any, ...]") -> None: @property def sql(self) -> str: """Get the raw SQL string.""" - return self._raw_sql + return self._get_raw_sql() @property def raw_sql(self) -> str: @@ -491,7 +520,7 @@ def raw_sql(self) -> str: Returns: The raw SQL string """ - return self._raw_sql + return self._get_raw_sql() @property def parameters(self) -> Any: @@ -597,7 +626,9 @@ def validation_errors(self) -> "list[str]": @property def has_errors(self) -> bool: """Check if there are validation errors.""" - return len(self.validation_errors) > 0 + if self._processed_state is Empty: + return False + return bool(self._processed_state.validation_errors) def returns_rows(self) -> bool: """Check if statement returns rows. @@ -673,12 +704,20 @@ def compile(self) -> "tuple[str, Any]": if self._processed_state is Empty: try: config = self._statement_config - raw_sql = self._raw_sql + raw_sql = self._get_raw_sql() params = self._named_parameters or self._positional_parameters is_many = self._is_many param_fingerprint = structural_fingerprint(params, is_many=is_many) + pipeline_fingerprint: Any | None = ( + None if config.parameter_config.needs_static_script_compilation else param_fingerprint + ) compiled_result = pipeline.compile_with_pipeline( - config, raw_sql, params, is_many=is_many, expression=self._raw_expression + config, + raw_sql, + params, + is_many=is_many, + expression=self._raw_expression, + param_fingerprint=pipeline_fingerprint, ) self._processed_state = self._build_processed_state( @@ -860,13 +899,10 @@ def _build_processed_state( return state def _handle_compile_failure(self, error: Exception) -> ProcessedState: - import traceback - - traceback.print_exc() - logger.debug("Processing failed, using fallback: %s", error) + logger.debug("Processing failed, using fallback: %s", error, exc_info=(type(error), error, error.__traceback__)) params = self._named_parameters or self._positional_parameters return self._build_processed_state( - compiled_sql=self._raw_sql, + compiled_sql=self._get_raw_sql(), execution_parameters=self._named_parameters or self._positional_parameters, parsed_expression=None, operation_type="COMMAND", @@ -1533,11 +1569,16 @@ def builder(self, dialect: "DialectType | None" = None) -> "QueryBuilder": self._statement_config.parameter_validator ) raw_params = self.parameters + raw_sql_for_builder = self._get_raw_sql() converted_sql, converted_params = converter.convert_placeholder_style( - self._raw_sql, raw_params, ParameterStyle.NAMED_COLON, is_many=False + raw_sql_for_builder, raw_params, ParameterStyle.NAMED_COLON, is_many=False ) - if self._raw_expression is not None and converted_sql == self._raw_sql and (builder_dialect == self._dialect): + if ( + self._raw_expression is not None + and converted_sql == raw_sql_for_builder + and (builder_dialect == self._dialect) + ): expression = self._raw_expression.copy() else: try: @@ -1606,7 +1647,7 @@ def __hash__(self) -> int: if self._hash is None: positional_tuple = tuple(self._positional_parameters) named_tuple = tuple(sorted(self._named_parameters.items())) if self._named_parameters else () - raw_sql = self._raw_sql + raw_sql = self._get_raw_sql() is_many = self._is_many is_script = self._is_script self._hash = hash((raw_sql, positional_tuple, named_tuple, is_many, is_script)) @@ -1617,7 +1658,7 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, SQL): return False return ( - self._raw_sql == other._raw_sql + self._get_raw_sql() == other._get_raw_sql() and self._positional_parameters == other._positional_parameters and self._named_parameters == other._named_parameters and self._is_many == other._is_many @@ -1640,7 +1681,7 @@ def __repr__(self) -> str: flags.append("is_script") flags_str = f", {', '.join(flags)}" if flags else "" - return f"SQL({self._raw_sql!r}{params_str}{flags_str})" + return f"SQL({self._get_raw_sql()!r}{params_str}{flags_str})" @mypyc_attr(allow_interpreted_subclasses=False) @@ -1769,7 +1810,7 @@ def replace(self, **kwargs: Any) -> "StatementConfig": New StatementConfig instance with updated attributes """ for key in kwargs: - if key not in SQL_CONFIG_SLOTS: + if key not in _PUBLIC_CONFIG_FIELDS: msg = f"{key!r} is not a field in {type(self).__name__}" raise TypeError(msg) @@ -1811,7 +1852,7 @@ def __hash__(self) -> int: self.enable_parameter_type_wrapping, self.enable_caching, str(self.dialect), - self.parameter_config.hash(), + hash(self.parameter_config), self.execution_mode, self.output_transformer, self.statement_transformers, diff --git a/sqlspec/data_dictionary/_registry.py b/sqlspec/data_dictionary/_registry.py index 6c822fae2..2d437db91 100644 --- a/sqlspec/data_dictionary/_registry.py +++ b/sqlspec/data_dictionary/_registry.py @@ -8,7 +8,7 @@ _DIALECT_CONFIGS: dict[str, "DialectConfig"] = {} -_DIALECTS_LOADED = False +_DIALECTS_LOADED: bool = False DIALECT_ALIASES: dict[str, str] = { "postgresql": "postgres", @@ -38,7 +38,7 @@ def _load_default_dialects() -> None: if _DIALECTS_LOADED: return importlib.import_module("sqlspec.data_dictionary.dialects") - _DIALECTS_LOADED = True # pyright: ignore + _DIALECTS_LOADED = True def register_dialect(config: "DialectConfig") -> None: diff --git a/sqlspec/dialects/__init__.py b/sqlspec/dialects/__init__.py index f4b012222..121bafddb 100644 --- a/sqlspec/dialects/__init__.py +++ b/sqlspec/dialects/__init__.py @@ -14,9 +14,9 @@ from sqlspec.dialects.postgres import ParadeDB, PGVector from sqlspec.dialects.spanner import Spangres, Spanner +__all__ = ("PGVector", "ParadeDB", "Spangres", "Spanner") + Dialect.classes["pgvector"] = PGVector Dialect.classes["paradedb"] = ParadeDB Dialect.classes["spanner"] = Spanner Dialect.classes["spangres"] = Spangres - -__all__ = ("PGVector", "ParadeDB", "Spangres", "Spanner") diff --git a/sqlspec/dialects/spanner/_generators.py b/sqlspec/dialects/spanner/_generators.py index b4a071bab..78e141a53 100644 --- a/sqlspec/dialects/spanner/_generators.py +++ b/sqlspec/dialects/spanner/_generators.py @@ -161,15 +161,9 @@ def _bq_create_transform(self: Any, expression: exp.Create) -> str: if dialect_name == "Spanner" and expression.this and expression.kind == "TABLE": properties = expression.args.get("properties") if properties: - # Re-order properties so Spanner ones stay at the schema boundary - new_expressions = [] - for p in properties.expressions: - if _is_post_schema_spanner_property(p): - # Force to POST_SCHEMA if it's a Spanner property - new_expressions.append(p) - else: - new_expressions.append(p) - properties.set("expressions", new_expressions) + spanner_props = [p for p in properties.expressions if _is_post_schema_spanner_property(p)] + other_props = [p for p in properties.expressions if not _is_post_schema_spanner_property(p)] + properties.set("expressions", other_props + spanner_props) if _original_bq_create_transform is not None: return str(_original_bq_create_transform(self, expression)) diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index c3807647d..7f9a0f85f 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -1105,6 +1105,8 @@ def _stmt_cache_prepare_direct(self, statement: str, params: "tuple[Any, ...] | needs_rebind = type(params[0]) in coercion_types else: needs_rebind = any(type(p) in coercion_types for p in params) + if not needs_rebind and _CACHED_NAMED_STYLES.intersection(cached.parameter_profile.styles): + needs_rebind = True if needs_rebind: rebound_params = self.stmt_cache_rebind(params, cached) params_are_simple = False @@ -1985,7 +1987,7 @@ def _generate_compilation_cache_key( else () ) context_hash = hash(( - config.parameter_config.hash(), + hash(config.parameter_config), config.dialect, statement.is_script, statement.is_many, diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index 36a81a4ef..b2ca8c5b6 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -50,11 +50,11 @@ VersionInfo, ) +__all__ = ("SyncDataDictionaryBase", "SyncDriverAdapterBase", "SyncPoolConnectionContext", "SyncPoolSessionFactory") + _LOGGER_NAME: Final[str] = "sqlspec.driver" logger = get_logger(_LOGGER_NAME) -__all__ = ("SyncDataDictionaryBase", "SyncDriverAdapterBase", "SyncPoolConnectionContext", "SyncPoolSessionFactory") - EMPTY_FILTERS: Final["list[StatementFilter]"] = [] @@ -214,19 +214,18 @@ def dispatch_statement_execution(self, statement: "SQL", connection: "Any") -> " exc_handler = self.handle_database_exceptions() with exc_handler, self.with_cursor(connection) as cursor: # Logic mirrors the instrumentation path below but without telemetry - if statement.is_script: + special_result = self.dispatch_special_handling(cursor, statement) + if special_result is not None: + result = special_result + elif statement.is_script: execution_result = self.dispatch_execute_script(cursor, statement) result = self.build_statement_result(statement, execution_result) elif statement.is_many: execution_result = self.dispatch_execute_many(cursor, statement) result = self.build_statement_result(statement, execution_result) else: - special_result = self.dispatch_special_handling(cursor, statement) - if special_result is not None: - result = special_result - else: - execution_result = self.dispatch_execute(cursor, statement) - result = self.build_statement_result(statement, execution_result) + execution_result = self.dispatch_execute(cursor, statement) + result = self.build_statement_result(statement, execution_result) self._check_pending_exception(exc_handler) assert result is not None return result diff --git a/sqlspec/exceptions.py b/sqlspec/exceptions.py index 37457a5d1..c5afcae0f 100644 --- a/sqlspec/exceptions.py +++ b/sqlspec/exceptions.py @@ -2,8 +2,6 @@ from contextlib import contextmanager from typing import Any, Final -STACK_SQL_PREVIEW_LIMIT: Final[int] = 120 - __all__ = ( "SQLSTATE_EXCEPTION_MAP", "CheckViolationError", @@ -48,6 +46,8 @@ "map_sqlstate_to_exception", ) +STACK_SQL_PREVIEW_LIMIT: Final[int] = 120 + class SQLSpecError(Exception): """Base exception class for SQLSpec exceptions.""" @@ -465,10 +465,7 @@ def wrap_exceptions( yield except Exception as exc: - if suppress is not None and ( - (isinstance(suppress, type) and isinstance(exc, suppress)) - or (isinstance(suppress, tuple) and isinstance(exc, suppress)) - ): + if suppress is not None and isinstance(exc, suppress): return if isinstance(exc, SQLSpecError): diff --git a/sqlspec/extensions/adk/artifact/service.py b/sqlspec/extensions/adk/artifact/service.py index e8a277ec6..0ca096efe 100644 --- a/sqlspec/extensions/adk/artifact/service.py +++ b/sqlspec/extensions/adk/artifact/service.py @@ -27,9 +27,10 @@ from sqlspec.extensions.adk.artifact.store import BaseAsyncADKArtifactStore +__all__ = ("SQLSpecArtifactService",) + logger = get_logger("sqlspec.extensions.adk.artifact.service") -__all__ = ("SQLSpecArtifactService",) # Matches path traversal and absolute path components _UNSAFE_PATH_CHARS = re.compile(r"(?:^|/)\.\.(?:/|$)|[\x00]") diff --git a/sqlspec/extensions/adk/artifact/store.py b/sqlspec/extensions/adk/artifact/store.py index 0ad05724f..9be54b862 100644 --- a/sqlspec/extensions/adk/artifact/store.py +++ b/sqlspec/extensions/adk/artifact/store.py @@ -21,11 +21,12 @@ from sqlspec.config import DatabaseConfigProtocol from sqlspec.extensions.adk.artifact._types import ArtifactRecord +__all__ = ("BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore") + ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") logger = get_logger("sqlspec.extensions.adk.artifact.store") -__all__ = ("BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore") VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") MAX_TABLE_NAME_LENGTH: Final = 63 diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index d25161904..58bbb56b7 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -49,6 +49,7 @@ def session_to_record(session: "Session") -> SessionRecord: app_name=session.app_name, user_id=session.user_id, state=filter_temp_state(session.state), + # create_time is not exposed by ADK Session; a re-upsert of a restored session will reset this timestamp. create_time=datetime.now(timezone.utc), update_time=datetime.fromtimestamp(session.last_update_time, tz=timezone.utc), ) @@ -68,6 +69,8 @@ def compute_update_marker(update_time: "datetime") -> str: """ if update_time.tzinfo is not None: update_time = update_time.astimezone(timezone.utc) + else: + update_time = update_time.replace(tzinfo=timezone.utc) return update_time.isoformat(timespec="microseconds") @@ -205,8 +208,8 @@ def merge_scoped_state( Merged state dict. """ merged = dict(session_state) - if app_state: + if app_state is not None: merged.update(app_state) - if user_state: + if user_state is not None: merged.update(user_state) return merged diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py index fafea6d12..d11ccb891 100644 --- a/sqlspec/extensions/adk/memory/converters.py +++ b/sqlspec/extensions/adk/memory/converters.py @@ -17,8 +17,6 @@ from google.adk.sessions import Session from google.genai import types -logger = get_logger("sqlspec.extensions.adk.memory.converters") - __all__ = ( "event_to_memory_record", "extract_content_text", @@ -28,6 +26,8 @@ "session_to_memory_records", ) +logger = get_logger("sqlspec.extensions.adk.memory.converters") + def extract_content_text(content: "types.Content") -> str: """Extract plain text from ADK Content for search indexing. diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py index 4e977992f..bb43e01ff 100644 --- a/sqlspec/extensions/adk/memory/service.py +++ b/sqlspec/extensions/adk/memory/service.py @@ -20,10 +20,10 @@ from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore -logger = get_logger("sqlspec.extensions.adk.memory.service") - __all__ = ("SQLSpecMemoryService", "SQLSpecSyncMemoryService") +logger = get_logger("sqlspec.extensions.adk.memory.service") + class SQLSpecMemoryService(BaseMemoryService): """SQLSpec-backed implementation of BaseMemoryService. diff --git a/sqlspec/extensions/adk/memory/store.py b/sqlspec/extensions/adk/memory/store.py index 8301df2f7..3bb8b96e6 100644 --- a/sqlspec/extensions/adk/memory/store.py +++ b/sqlspec/extensions/adk/memory/store.py @@ -13,11 +13,12 @@ from sqlspec.config import DatabaseConfigProtocol from sqlspec.extensions.adk.memory._types import MemoryRecord +__all__ = ("BaseAsyncADKMemoryStore", "BaseSyncADKMemoryStore") + ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") logger = get_logger("sqlspec.extensions.adk.memory.store") -__all__ = ("BaseAsyncADKMemoryStore", "BaseSyncADKMemoryStore") VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") diff --git a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py index 1dd61a676..1c3233d36 100644 --- a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py @@ -17,10 +17,10 @@ from sqlspec.extensions.adk.store import BaseAsyncADKStore from sqlspec.migrations.context import MigrationContext -logger = get_logger("sqlspec.migrations.adk.tables") - __all__ = ("down", "up") +logger = get_logger("sqlspec.migrations.adk.tables") + def _get_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKStore]": """Get the appropriate store class based on the config's module path. diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 9f7d2a36e..928dec605 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -21,10 +21,10 @@ from sqlspec.extensions.adk.store import BaseAsyncADKStore -logger = get_logger("sqlspec.extensions.adk.service") - __all__ = ("SQLSpecSessionService",) +logger = get_logger("sqlspec.extensions.adk.service") + class SQLSpecSessionService(BaseSessionService): """SQLSpec-backed implementation of BaseSessionService. diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 04ae2d874..1b473e82c 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -15,11 +15,12 @@ from sqlspec.config import DatabaseConfigProtocol from sqlspec.extensions.adk._types import EventRecord, SessionRecord +__all__ = ("BaseAsyncADKStore", "BaseSyncADKStore") + ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") logger = get_logger("sqlspec.extensions.adk.store") -__all__ = ("BaseAsyncADKStore", "BaseSyncADKStore") VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") diff --git a/sqlspec/extensions/events/_channel.py b/sqlspec/extensions/events/_channel.py index 4f08b55cc..b77ac203c 100644 --- a/sqlspec/extensions/events/_channel.py +++ b/sqlspec/extensions/events/_channel.py @@ -24,8 +24,6 @@ from sqlspec.extensions.events._protocols import AsyncEventHandler, SyncEventHandler from sqlspec.observability import ObservabilityRuntime -logger = get_logger("sqlspec.events.channel") - __all__ = ( "AsyncEventChannel", "AsyncEventListener", @@ -36,6 +34,9 @@ "resolve_poll_interval", ) +logger = get_logger("sqlspec.events.channel") + + _ADAPTER_MODULE_PARTS = 3 @@ -385,6 +386,9 @@ def ack(self, event_id: str) -> None: def nack(self, event_id: str) -> None: """Return an event to the queue for redelivery.""" + if not self._backend.supports_sync: + msg = "Current events backend does not support sync nack" + raise ImproperConfigurationError(msg) span = _start_event_span(self._runtime, "nack", self._backend_name, self._adapter_name, mode="sync") try: self._backend.nack(event_id) @@ -633,6 +637,9 @@ async def ack(self, event_id: str) -> None: async def nack(self, event_id: str) -> None: """Return an event to the queue for redelivery.""" + if not self._backend.supports_async: + msg = "Current events backend does not support async nack" + raise ImproperConfigurationError(msg) span = _start_event_span(self._runtime, "nack", self._backend_name, self._adapter_name, mode="async") try: await self._backend.nack(event_id) diff --git a/sqlspec/extensions/events/_queue.py b/sqlspec/extensions/events/_queue.py index f1c9efce1..adb566075 100644 --- a/sqlspec/extensions/events/_queue.py +++ b/sqlspec/extensions/events/_queue.py @@ -20,9 +20,10 @@ from sqlspec.config import DatabaseConfigProtocol from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase +__all__ = ("AsyncTableEventQueue", "SyncTableEventQueue", "build_queue_backend") + logger = get_logger("sqlspec.events.queue") -__all__ = ("AsyncTableEventQueue", "SyncTableEventQueue", "build_queue_backend") _PENDING_STATUS = "pending" _LEASED_STATUS = "leased" diff --git a/sqlspec/extensions/events/_store.py b/sqlspec/extensions/events/_store.py index f2df503fe..022b58d99 100644 --- a/sqlspec/extensions/events/_store.py +++ b/sqlspec/extensions/events/_store.py @@ -9,9 +9,10 @@ if TYPE_CHECKING: from sqlspec.config import DatabaseConfigProtocol +__all__ = ("BaseEventQueueStore", "normalize_event_channel_name", "normalize_queue_table_name") + ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") -__all__ = ("BaseEventQueueStore", "normalize_event_channel_name", "normalize_queue_table_name") _IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") diff --git a/sqlspec/extensions/events/migrations/0001_create_event_queue.py b/sqlspec/extensions/events/migrations/0001_create_event_queue.py index 61fdc8cef..7259279d6 100644 --- a/sqlspec/extensions/events/migrations/0001_create_event_queue.py +++ b/sqlspec/extensions/events/migrations/0001_create_event_queue.py @@ -11,10 +11,10 @@ from sqlspec.extensions.events._store import BaseEventQueueStore from sqlspec.migrations.context import MigrationContext -logger = get_logger("sqlspec.events.migrations.queue") - __all__ = ("down", "up") +logger = get_logger("sqlspec.events.migrations.queue") + async def up(context: "MigrationContext | None" = None) -> "list[str]": """Return SQL statements that provision the queue table and indexes.""" diff --git a/sqlspec/extensions/flask/_utils.py b/sqlspec/extensions/flask/_utils.py index d6f8613f9..40d110713 100644 --- a/sqlspec/extensions/flask/_utils.py +++ b/sqlspec/extensions/flask/_utils.py @@ -64,7 +64,9 @@ def get_or_create_session(config_state: "FlaskConfigState", portal: "Portal | No connection = get_context_value(g, config_state.connection_key) session = config_state.config.driver_type( - connection=connection, statement_config=config_state.config.statement_config + connection=connection, + statement_config=config_state.config.statement_config, + driver_features=config_state.config.driver_features, ) set_context_value(g, cache_key, session) diff --git a/sqlspec/extensions/litestar/handlers.py b/sqlspec/extensions/litestar/handlers.py index 93114d8de..dbd5327c4 100644 --- a/sqlspec/extensions/litestar/handlers.py +++ b/sqlspec/extensions/litestar/handlers.py @@ -25,8 +25,6 @@ from sqlspec.config import DatabaseConfigProtocol, DriverT from sqlspec.typing import ConnectionT, PoolT -SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE} - __all__ = ( "SESSION_TERMINUS_ASGI_EVENTS", "autocommit_handler_maker", @@ -37,6 +35,8 @@ "session_provider_maker", ) +SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE} + def manual_handler_maker( connection_scope_key: str, is_async: bool = False diff --git a/sqlspec/extensions/litestar/migrations/0001_create_session_table.py b/sqlspec/extensions/litestar/migrations/0001_create_session_table.py index 0e3466491..a988e9cd1 100644 --- a/sqlspec/extensions/litestar/migrations/0001_create_session_table.py +++ b/sqlspec/extensions/litestar/migrations/0001_create_session_table.py @@ -10,10 +10,10 @@ from sqlspec.extensions.litestar.store import BaseSQLSpecStore from sqlspec.migrations.context import MigrationContext -logger = get_logger("sqlspec.migrations.litestar.session") - __all__ = ("down", "up") +logger = get_logger("sqlspec.migrations.litestar.session") + def _get_store_class(context: "MigrationContext | None") -> "type[BaseSQLSpecStore]": """Get the appropriate store class based on the config's module path. diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index 9188d7615..56c34349b 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -20,6 +20,7 @@ SyncConfigT, SyncDatabaseConfig, ) +from sqlspec.core._correlation import CorrelationExtractor from sqlspec.core._pagination import OffsetPagination from sqlspec.core.sqlcommenter import SQLCommenterContext from sqlspec.exceptions import ImproperConfigurationError, NotFoundError @@ -57,6 +58,22 @@ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase from sqlspec.loader import SQLFileLoader +__all__ = ( + "CORRELATION_STATE_KEY", + "DEFAULT_COMMIT_MODE", + "DEFAULT_CONNECTION_KEY", + "DEFAULT_CORRELATION_HEADER", + "DEFAULT_POOL_KEY", + "DEFAULT_SESSION_KEY", + "TRACE_CONTEXT_FALLBACK_HEADERS", + "CommitMode", + "CorrelationMiddleware", + "PluginConfigState", + "SQLSpecPlugin", + "_OffsetPaginationSchemaPlugin", + "not_found_error_handler", +) + logger = get_logger("sqlspec.extensions.litestar") CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"] @@ -78,22 +95,6 @@ CORRELATION_STATE_KEY = "sqlspec_correlation_id" _LITESTAR_NUMPY_ARRAY_TYPE: type[Any] | None = None -__all__ = ( - "CORRELATION_STATE_KEY", - "DEFAULT_COMMIT_MODE", - "DEFAULT_CONNECTION_KEY", - "DEFAULT_CORRELATION_HEADER", - "DEFAULT_POOL_KEY", - "DEFAULT_SESSION_KEY", - "TRACE_CONTEXT_FALLBACK_HEADERS", - "CommitMode", - "CorrelationMiddleware", - "PluginConfigState", - "SQLSpecPlugin", - "_OffsetPaginationSchemaPlugin", - "not_found_error_handler", -) - def not_found_error_handler(_request: "Request[Any, Any, Any]", exc: NotFoundError) -> NoReturn: """Translate :class:`sqlspec.exceptions.NotFoundError` into Litestar's HTTP 404. @@ -107,29 +108,30 @@ def not_found_error_handler(_request: "Request[Any, Any, Any]", exc: NotFoundErr class CorrelationMiddleware: - __slots__ = ("_app", "_headers") + __slots__ = ("_app", "_extractor", "_headers") def __init__(self, app: "ASGIApp", *, headers: tuple[str, ...]) -> None: self._app = app self._headers = headers + self._extractor = ( + CorrelationExtractor( + primary_header=headers[0], + additional_headers=headers[1:] if len(headers) > 1 else None, + auto_trace_headers=False, + ) + if headers + else None + ) async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: scope_type = scope.get("type") - if str(scope_type) != "http" or not self._headers: + if str(scope_type) != "http" or self._extractor is None: await self._app(scope, receive, send) return - header_value: str | None = None raw_headers = scope.get("headers") or [] - for header in self._headers: - for name, value in raw_headers: - if name.decode().lower() == header: - header_value = value.decode() - break - if header_value: - break - if not header_value: - header_value = CorrelationContext.generate() + header_dict = {name.decode().lower(): value.decode() for name, value in raw_headers} + header_value = self._extractor.extract(lambda header: header_dict.get(header)) previous_correlation_id = CorrelationContext.get() CorrelationContext.set(header_value) @@ -417,17 +419,13 @@ def store_sqlspec_in_state() -> None: if sqlspec_decoders: app_config.type_decoders = [*(app_config.type_decoders or []), *sqlspec_decoders] + new_middlewares: list[DefineMiddleware] = [] if self._correlation_headers: - middleware = DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers) - existing_middleware = list(app_config.middleware or []) - existing_middleware.append(middleware) - app_config.middleware = existing_middleware - + new_middlewares.append(DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers)) if self._enable_sqlcommenter_middleware: - sc_middleware = DefineMiddleware(SQLCommenterMiddleware) - existing_middleware = list(app_config.middleware or []) - existing_middleware.append(sc_middleware) - app_config.middleware = existing_middleware + new_middlewares.append(DefineMiddleware(SQLCommenterMiddleware)) + if new_middlewares: + app_config.middleware = [*(app_config.middleware or []), *new_middlewares] log_with_context( logger, @@ -853,7 +851,8 @@ def _get_plugin_state( return state self._raise_config_not_found(key) - return None + msg = "unreachable" + raise AssertionError(msg) def _get_available_keys(self) -> "list[str]": """Get a list of all available configuration keys for error messages.""" @@ -873,7 +872,7 @@ def _validate_dependency_keys(self) -> None: if len(set(pool_keys)) != len(pool_keys): self._raise_duplicate_pool_keys() - def _raise_missing_connection(self, connection_key: str) -> None: + def _raise_missing_connection(self, connection_key: str) -> NoReturn: """Raise error when connection is not found in scope.""" msg = f"No database connection found in scope for key '{connection_key}'. " msg += "Ensure the connection dependency is properly configured and available." diff --git a/sqlspec/extensions/litestar/store.py b/sqlspec/extensions/litestar/store.py index a57ec6477..a737df1f0 100644 --- a/sqlspec/extensions/litestar/store.py +++ b/sqlspec/extensions/litestar/store.py @@ -14,13 +14,14 @@ from typing_extensions import Self +__all__ = ("BaseSQLSpecStore",) + ConfigT = TypeVar("ConfigT") logger = get_logger("sqlspec.extensions.litestar.store") -__all__ = ("BaseSQLSpecStore",) VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") MAX_TABLE_NAME_LENGTH: Final = 63 diff --git a/sqlspec/extensions/sanic/extension.py b/sqlspec/extensions/sanic/extension.py index fa8257ff1..6b433faa5 100644 --- a/sqlspec/extensions/sanic/extension.py +++ b/sqlspec/extensions/sanic/extension.py @@ -62,7 +62,7 @@ class SQLSpecPlugin: db_ext = SQLSpecPlugin(sqlspec, app) """ - __slots__ = ("_config_states", "_lifecycle_listeners_added", "_request_middleware_added", "_sqlspec") + __slots__ = ("_config_states", "_extractor", "_lifecycle_listeners_added", "_request_middleware_added", "_sqlspec") def __init__(self, sqlspec: SQLSpec, app: "Sanic[Any, Any] | None" = None) -> None: """Initialize SQLSpec Sanic extension. @@ -81,6 +81,17 @@ def __init__(self, sqlspec: SQLSpec, app: "Sanic[Any, Any] | None" = None) -> No state = self._create_config_state(cfg, settings) self._config_states.append(state) + correlation_state = self._first_correlation_state() + self._extractor = ( + CorrelationExtractor( + primary_header=correlation_state.correlation_header, + additional_headers=correlation_state.correlation_headers, + auto_trace_headers=correlation_state.auto_trace_headers, + ) + if correlation_state is not None + else None + ) + if app is not None: self.init_app(app) @@ -267,16 +278,10 @@ def _set_correlation_context(self, request: Any) -> None: Args: request: Sanic request instance. """ - config_state = self._first_correlation_state() - if config_state is None: + if self._extractor is None: return - extractor = CorrelationExtractor( - primary_header=config_state.correlation_header, - additional_headers=config_state.correlation_headers, - auto_trace_headers=config_state.auto_trace_headers, - ) - correlation_id = extractor.extract(lambda header: request.headers.get(header)) + correlation_id = self._extractor.extract(lambda header: request.headers.get(header)) set_context_value(request.ctx, "_sqlspec_previous_correlation_id", CorrelationContext.get()) set_context_value(request.ctx, "_sqlspec_correlation_id", correlation_id) set_context_value(request.ctx, "correlation_id", correlation_id) diff --git a/sqlspec/loader.py b/sqlspec/loader.py index 0521508ad..5120f61ca 100644 --- a/sqlspec/loader.py +++ b/sqlspec/loader.py @@ -163,7 +163,15 @@ class SQLFileLoader: and retrieves them by name. """ - __slots__ = ("_files", "_queries", "_query_to_file", "_runtime", "encoding", "storage_registry") + __slots__ = ( + "_compiled_statements", + "_files", + "_queries", + "_query_to_file", + "_runtime", + "encoding", + "storage_registry", + ) def __init__( self, @@ -182,6 +190,7 @@ def __init__( self.encoding = encoding self.storage_registry = storage_registry or default_storage_registry + self._compiled_statements: dict[str, SQL] = {} self._queries: dict[str, NamedStatement] = {} self._files: dict[str, SQLFile] = {} self._query_to_file: dict[str, str] = {} @@ -232,6 +241,11 @@ def _generate_file_cache_key(self, path: str | Path) -> str: path_hash = hashlib.md5(path_str.encode(), usedforsecurity=False).hexdigest() return f"file:{path_hash[:16]}" + @staticmethod + def _compute_checksum(content: str) -> str: + """Compute MD5 checksum from already-read file content.""" + return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest() + def _calculate_file_checksum(self, path: str | Path) -> str: """Calculate checksum for file content validation. @@ -245,7 +259,7 @@ def _calculate_file_checksum(self, path: str | Path) -> str: SQLFileParseError: If file cannot be read. """ try: - return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest() + return self._compute_checksum(self._read_file_content(path)) except Exception as e: raise SQLFileParseError(str(path), str(path), e) from e @@ -266,6 +280,10 @@ def _is_file_unchanged(self, path: str | Path, cached_file: SQLFileCacheEntry) - else: return current_checksum == cached_file.sql_file.checksum + def _is_file_content_unchanged(self, content: str, cached_file: SQLFileCacheEntry) -> bool: + """Check if already-read file content matches cached checksum.""" + return self._compute_checksum(content) == cached_file.sql_file.checksum + def _read_file_content(self, path: str | Path) -> str: """Read file content using storage backend. @@ -498,29 +516,33 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> boo cache = get_cache() cached_file = cache.get_file(cache_key_str) - if ( - cached_file is not None - and isinstance(cached_file, SQLFileCacheEntry) - and self._is_file_unchanged(file_path, cached_file) - ): - self._files[path_str] = cached_file.sql_file - for name, statement in cached_file.parsed_statements.items(): - namespaced_name = f"{namespace}.{name}" if namespace else name - if namespaced_name in self._queries: - existing_file = self._query_to_file.get(namespaced_name, "unknown") - if existing_file != path_str: - raise SQLFileParseError( - path_str, - path_str, - ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"), - ) - self._queries[namespaced_name] = statement - self._query_to_file[namespaced_name] = path_str - if runtime is not None: - runtime.increment_metric("loader.cache.hit") - return True - - self._load_file_without_cache(file_path, namespace) + if cached_file is not None and isinstance(cached_file, SQLFileCacheEntry): + try: + file_content = self._read_file_content(file_path) + except Exception: + file_content = None + + if file_content is not None and self._is_file_content_unchanged(file_content, cached_file): + self._files[path_str] = cached_file.sql_file + for name, statement in cached_file.parsed_statements.items(): + namespaced_name = f"{namespace}.{name}" if namespace else name + if namespaced_name in self._queries: + existing_file = self._query_to_file.get(namespaced_name, "unknown") + if existing_file != path_str: + raise SQLFileParseError( + path_str, + path_str, + ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"), + ) + self._queries[namespaced_name] = statement + self._query_to_file[namespaced_name] = path_str + if runtime is not None: + runtime.increment_metric("loader.cache.hit") + return True + + self._load_file_without_cache(file_path, namespace, content=file_content) + else: + self._load_file_without_cache(file_path, namespace) if path_str in self._files: sql_file = self._files[path_str] @@ -541,16 +563,20 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> boo return True - def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) -> None: + def _load_file_without_cache( + self, file_path: str | Path, namespace: "str | None", content: "str | None" = None + ) -> None: """Load a single SQL file without using cache. Args: file_path: Path to the SQL file. namespace: Optional namespace prefix for queries. + content: Pre-read file content. If provided, skips the disk read. """ path_str = str(file_path) runtime = self._runtime - content = self._read_file_content(file_path) + if content is None: + content = self._read_file_content(file_path) statements = self._parse_sql_content(content, path_str) if not statements: @@ -663,6 +689,7 @@ def has_query(self, name: str) -> bool: def clear_cache(self) -> None: """Clear all cached files and queries.""" + self._compiled_statements.clear() self._files.clear() self._queries.clear() self._query_to_file.clear() @@ -713,6 +740,8 @@ def get_sql(self, name: str) -> "SQL": if safe_name not in self._queries: self._raise_statement_not_found(name, safe_name) + if safe_name in self._compiled_statements: + return self._compiled_statements[safe_name] parsed_statement = self._queries[safe_name] sqlglot_dialect = None @@ -724,4 +753,5 @@ def get_sql(self, name: str) -> "SQL": sql.compile() except Exception as exc: raise SQLFileParseError(name=name, path="", original_error=exc) from exc + self._compiled_statements[safe_name] = sql return sql diff --git a/sqlspec/migrations/base.py b/sqlspec/migrations/base.py index d71178b86..d03b1e897 100644 --- a/sqlspec/migrations/base.py +++ b/sqlspec/migrations/base.py @@ -927,6 +927,6 @@ def stamp(self, revision: str) -> "None | Awaitable[None]": ... @abstractmethod - def revision(self, message: str, file_type: str = "sql") -> "None | Awaitable[None]": + def revision(self, message: str, file_type: str | None = None) -> "None | Awaitable[None]": """Create a new migration file.""" ... diff --git a/sqlspec/migrations/commands.py b/sqlspec/migrations/commands.py index cca76fe74..a1693b0dc 100644 --- a/sqlspec/migrations/commands.py +++ b/sqlspec/migrations/commands.py @@ -722,6 +722,7 @@ def upgrade( self._validate_migration_schema(driver) self.tracker.ensure_tracking_table(driver) + # config auto_sync=False cannot be overridden by the call-site flag. if auto_sync and self.config.migration_config.get("auto_sync", True): self._synchronize_version_records( driver, use_logger=ul, echo=echo_value, summary_only=summary_value @@ -1661,6 +1662,7 @@ async def upgrade( await self._validate_migration_schema(driver) await self.tracker.ensure_tracking_table(driver) + # config auto_sync=False cannot be overridden by the call-site flag. if auto_sync and self.config.migration_config.get("auto_sync", True): await self._synchronize_version_records( driver, use_logger=ul, echo=echo_value, summary_only=summary_value diff --git a/sqlspec/migrations/context.py b/sqlspec/migrations/context.py index 90d8ab0b8..31a34bbc8 100644 --- a/sqlspec/migrations/context.py +++ b/sqlspec/migrations/context.py @@ -14,10 +14,10 @@ if TYPE_CHECKING: from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase -logger = get_logger("sqlspec.migrations.context") - __all__ = ("MigrationContext",) +logger = get_logger("sqlspec.migrations.context") + def _normalize_dialect_name(dialect: Any | None) -> "str | None": if dialect is None: diff --git a/sqlspec/observability/_common.py b/sqlspec/observability/_common.py index ec9bb3957..535cc3fe5 100644 --- a/sqlspec/observability/_common.py +++ b/sqlspec/observability/_common.py @@ -74,4 +74,4 @@ def compute_sql_hash(sql: str) -> str: Returns: SHA256 hash prefix. """ - return hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16] + return hashlib.sha256(sql.encode("utf-8"), usedforsecurity=False).hexdigest()[:16] diff --git a/sqlspec/observability/_config.py b/sqlspec/observability/_config.py index e7c638b95..c8387757e 100644 --- a/sqlspec/observability/_config.py +++ b/sqlspec/observability/_config.py @@ -3,6 +3,8 @@ from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any, Protocol, cast +from sqlspec.observability._dispatcher import LifecycleHook + if TYPE_CHECKING: # pragma: no cover - import cycle guard from sqlspec.config import LifecycleConfig from sqlspec.observability._formatters._base import CloudLogFormatter @@ -18,8 +20,6 @@ "TelemetryConfig", ) -LifecycleHook = Callable[[dict[str, Any]], None] - class StatementObserver(Protocol): """Protocol for callbacks that receive SQL execution events.""" @@ -328,12 +328,9 @@ def _merge_sampling(base: "SamplingConfig | None", override: "SamplingConfig | N if base is None: return override.copy() merged = base.copy() - if override.sample_rate != 1.0: - merged.sample_rate = override.sample_rate - if override.deterministic: - merged.deterministic = override.deterministic - if override.force_sample_on_error: - merged.force_sample_on_error = override.force_sample_on_error + merged.sample_rate = override.sample_rate + merged.deterministic = override.deterministic + merged.force_sample_on_error = override.force_sample_on_error if override.force_sample_slow_queries_ms is not None: merged.force_sample_slow_queries_ms = override.force_sample_slow_queries_ms return merged diff --git a/sqlspec/observability/_runtime.py b/sqlspec/observability/_runtime.py index 6934ec22f..6c7aa927e 100644 --- a/sqlspec/observability/_runtime.py +++ b/sqlspec/observability/_runtime.py @@ -265,9 +265,14 @@ def emit_statement_event( if not self._statement_observers: return + correlation_id = CorrelationContext.get() + sampled = True + if self.config.sampling is not None: + sampled = self.config.sampling.should_sample(correlation_id=correlation_id, duration_ms=duration_s * 1000) + if not sampled: + return sanitized_sql = self._redact_sql(sql) sanitized_params = self._redact_parameters(parameters) - correlation_id = CorrelationContext.get() logging_config = self.config.logging db_system = resolve_db_system(self.config_name) sql_hash = None @@ -301,6 +306,7 @@ def emit_statement_event( sql_original_length=sql_original_length, trace_id=trace_id, span_id=span_id, + sampled=sampled, ) for observer in self._statement_observers: observer(event) diff --git a/sqlspec/protocols.py b/sqlspec/protocols.py index 0e1363f45..5648955d2 100644 --- a/sqlspec/protocols.py +++ b/sqlspec/protocols.py @@ -626,8 +626,9 @@ async def write_arrow_async(self, path: "str | Path", table: "ArrowTable", **kwa msg = "Async arrow writing not implemented" raise NotImplementedError(msg) + # NOTE: Returns AsyncIterator directly; this is intentionally not async def. def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]": - """Async stream Arrow record batches from matching objects.""" + """Stream Arrow record batches from matching objects.""" msg = "Async arrow streaming not implemented" raise NotImplementedError(msg) diff --git a/sqlspec/storage/backends/base.py b/sqlspec/storage/backends/base.py index 6bed3fb86..bb7d9cfa2 100644 --- a/sqlspec/storage/backends/base.py +++ b/sqlspec/storage/backends/base.py @@ -368,6 +368,7 @@ async def write_arrow_async(self, path: str, table: ArrowTable, **kwargs: Any) - """Write Arrow table to storage asynchronously.""" raise NotImplementedError + # NOTE: Returns AsyncIterator directly; keep in sync with ObjectStoreProtocol. @abstractmethod def stream_arrow_async(self, pattern: str, **kwargs: Any) -> AsyncIterator[ArrowRecordBatch]: """Stream Arrow record batches from storage asynchronously.""" diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 5d5fa41d3..69131d8fc 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -203,10 +203,7 @@ def _resolve_path_for_local_store(self, path: "str | Path") -> str: return str(path) - def read_bytes_sync(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ignore[reportUnusedParameter] - """Read bytes using obstore synchronously.""" - resolved_path = self._resolve_path(path) - + def _read_bytes_resolved_sync(self, resolved_path: str) -> bytes: result = execute_sync_storage_operation( partial(_read_obstore_bytes, self.store, resolved_path), backend=self.backend_type, @@ -222,10 +219,12 @@ def read_bytes_sync(self, path: "str | Path", **kwargs: Any) -> bytes: # pyrigh ) return result - def write_bytes_sync(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] - """Write bytes using obstore synchronously.""" + def read_bytes_sync(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ignore[reportUnusedParameter] + """Read bytes using obstore synchronously.""" resolved_path = self._resolve_path(path) + return self._read_bytes_resolved_sync(resolved_path) + def _write_bytes_resolved_sync(self, resolved_path: str, data: bytes) -> None: execute_sync_storage_operation( partial(self.store.put, resolved_path, data), backend=self.backend_type, @@ -240,6 +239,11 @@ def write_bytes_sync(self, path: "str | Path", data: bytes, **kwargs: Any) -> No path=resolved_path, ) + def write_bytes_sync(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] + """Write bytes using obstore synchronously.""" + resolved_path = self._resolve_path(path) + self._write_bytes_resolved_sync(resolved_path, data) + def read_text_sync(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str: """Read text using obstore synchronously.""" return self.read_bytes_sync(path, **kwargs).decode(encoding) @@ -258,8 +262,11 @@ def list_objects_sync(self, prefix: str = "", recursive: bool = True, **kwargs: resolved_prefix = "" else: resolved_prefix = self.base_path or "" - items = self.store.list_with_delimiter(resolved_prefix) if not recursive else self.store.list(resolved_prefix) - paths = sorted(item["path"] for batch in items for item in batch) + if not recursive: + result = self.store.list_with_delimiter(resolved_prefix) + paths = sorted(item["path"] for item in result["objects"]) + else: + paths = sorted(item["path"] for batch in self.store.list(resolved_prefix) for item in batch) _log_storage_event( "storage.list", backend_type=self.backend_type, @@ -386,38 +393,24 @@ def glob_sync(self, pattern: str, **kwargs: Any) -> "list[str]": def get_metadata_sync(self, path: "str | Path", **kwargs: Any) -> "dict[str, object]": # pyright: ignore[reportUnusedParameter] """Get object metadata using obstore synchronously.""" - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) + resolved_path = self._resolve_path(path) + # Keep in sync with get_metadata_async. try: metadata = self.store.head(resolved_path) except Exception: return {"path": resolved_path, "exists": False} else: - if isinstance(metadata, dict): - result = { - "path": resolved_path, - "exists": True, - "size": metadata.get("size"), - "last_modified": metadata.get("last_modified"), - "e_tag": metadata.get("e_tag"), - "version": metadata.get("version"), - } - if metadata.get("metadata"): - result["custom_metadata"] = metadata["metadata"] - return result - result = { "path": resolved_path, "exists": True, - "size": metadata.size, - "last_modified": metadata.last_modified, - "e_tag": metadata.e_tag, - "version": metadata.version, + "size": metadata.get("size"), + "last_modified": metadata.get("last_modified"), + "e_tag": metadata.get("e_tag"), + "version": metadata.get("version"), } - - if metadata.metadata: - result["custom_metadata"] = metadata.metadata - + if metadata.get("metadata"): + result["custom_metadata"] = metadata["metadata"] return result def is_object_sync(self, path: "str | Path") -> bool: @@ -441,8 +434,8 @@ def is_path_sync(self, path: "str | Path") -> bool: def read_arrow_sync(self, path: "str | Path", **kwargs: Any) -> ArrowTable: """Read Arrow table using obstore synchronously.""" pq = import_pyarrow_parquet() - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - data = self.read_bytes_sync(resolved_path) + resolved_path = self._resolve_path(path) + data = self._read_bytes_resolved_sync(resolved_path) result = cast( "ArrowTable", execute_sync_storage_operation( @@ -465,7 +458,7 @@ def write_arrow_sync(self, path: "str | Path", table: ArrowTable, **kwargs: Any) """Write Arrow table using obstore synchronously.""" pa = import_pyarrow() pq = import_pyarrow_parquet() - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) + resolved_path = self._resolve_path(path) schema = table.schema if any(str(f.type).startswith("decimal64") for f in schema): @@ -490,7 +483,7 @@ def write_arrow_sync(self, path: "str | Path", table: ArrowTable, **kwargs: Any) path=resolved_path, ) buffer.seek(0) - self.write_bytes_sync(resolved_path, buffer.read()) + self._write_bytes_resolved_sync(resolved_path, buffer.read()) _log_storage_event( "storage.write", backend_type=self.backend_type, @@ -625,13 +618,7 @@ def sign_sync( msg = f"Failed to generate signed URL(s) for {resolved_paths}" raise StorageOperationFailedError(msg) from exc - async def read_bytes_async(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ignore[reportUnusedParameter] - """Read bytes from storage asynchronously.""" - if self._is_local_store: - resolved_path = self._resolve_path_for_local_store(path) - else: - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - + async def _read_bytes_resolved_async(self, resolved_path: str) -> bytes: result = await self.store.get_async(resolved_path) bytes_obj = await result.bytes_async() # pyright: ignore[reportAttributeAccessIssue] data = cast("bytes", bytes_obj.to_bytes()) @@ -645,13 +632,12 @@ async def read_bytes_async(self, path: "str | Path", **kwargs: Any) -> bytes: # ) return data - async def write_bytes_async(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] - """Write bytes to storage asynchronously.""" - if self._is_local_store: - resolved_path = self._resolve_path_for_local_store(path) - else: - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) + async def read_bytes_async(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ignore[reportUnusedParameter] + """Read bytes from storage asynchronously.""" + resolved_path = self._resolve_path(path) + return await self._read_bytes_resolved_async(resolved_path) + async def _write_bytes_resolved_async(self, resolved_path: str, data: bytes) -> None: await self.store.put_async(resolved_path, data) _log_storage_event( "storage.write", @@ -662,6 +648,11 @@ async def write_bytes_async(self, path: "str | Path", data: bytes, **kwargs: Any path=resolved_path, ) + async def write_bytes_async(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] + """Write bytes to storage asynchronously.""" + resolved_path = self._resolve_path(path) + await self._write_bytes_resolved_async(resolved_path, data) + async def stream_read_async( self, path: "str | Path", chunk_size: "int | None" = None, **kwargs: Any ) -> AsyncIterator[bytes]: @@ -815,6 +806,7 @@ async def get_metadata_async(self, path: "str | Path", **kwargs: Any) -> "dict[s resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) result: dict[str, object] = {} + # Keep in sync with get_metadata_sync. try: metadata = await self.store.head_async(resolved_path) result.update({ @@ -839,8 +831,8 @@ async def read_arrow_async(self, path: "str | Path", **kwargs: Any) -> ArrowTabl Uses async_() with storage limiter to offload blocking PyArrow I/O to thread pool. """ pq = import_pyarrow_parquet() - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - data = await self.read_bytes_async(resolved_path) + resolved_path = self._resolve_path(path) + data = await self._read_bytes_resolved_async(resolved_path) # Offload PyArrow parsing to thread pool result = await async_(pq.read_table)(io.BytesIO(data), **kwargs) @@ -862,7 +854,7 @@ async def write_arrow_async(self, path: "str | Path", table: ArrowTable, **kwarg to thread pool, preventing event loop blocking. """ pq = import_pyarrow_parquet() - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) + resolved_path = self._resolve_path(path) def _serialize() -> bytes: buffer = io.BytesIO() @@ -871,7 +863,7 @@ def _serialize() -> bytes: return buffer.read() data = await async_(_serialize)() - await self.write_bytes_async(resolved_path, data) + await self._write_bytes_resolved_async(resolved_path, data) _log_storage_event( "storage.write", diff --git a/sqlspec/storage/registry.py b/sqlspec/storage/registry.py index 62efc8509..f73a84740 100644 --- a/sqlspec/storage/registry.py +++ b/sqlspec/storage/registry.py @@ -85,6 +85,7 @@ def register_alias( if base_path: backend_config["base_path"] = base_path self._alias_configs[alias] = (backend_cls, uri, backend_config) + self.clear_cache(alias) log_with_context( logger, logging.DEBUG, diff --git a/sqlspec/typing.py b/sqlspec/typing.py index 38e5819cf..3ea4f7125 100644 --- a/sqlspec/typing.py +++ b/sqlspec/typing.py @@ -1,4 +1,3 @@ -# pyright: ignore[reportAttributeAccessIssue] """Public typing helpers and optional dependency aliases. This module is the supported import surface for SQLSpec typing utilities. @@ -7,9 +6,9 @@ API. """ -from collections.abc import Iterator +from collections.abc import Iterator, Mapping from functools import lru_cache -from typing import Annotated, Any, Literal, Protocol, TypeAlias, TypedDict, _TypedDict # pyright: ignore +from typing import Annotated, Any, Literal, Protocol, TypeAlias, TypedDict from typing_extensions import TypeVar @@ -77,6 +76,82 @@ trace, ) +__all__ = ( + "ALLOYDB_CONNECTOR_INSTALLED", + "ATTRS_INSTALLED", + "CATTRS_INSTALLED", + "CLOUD_SQL_CONNECTOR_INSTALLED", + "FSSPEC_INSTALLED", + "LITESTAR_INSTALLED", + "MSGSPEC_INSTALLED", + "NANOID_INSTALLED", + "NUMPY_INSTALLED", + "OBSTORE_INSTALLED", + "OPENTELEMETRY_INSTALLED", + "ORJSON_INSTALLED", + "PANDAS_INSTALLED", + "PGVECTOR_INSTALLED", + "POLARS_INSTALLED", + "PROMETHEUS_INSTALLED", + "PYARROW_INSTALLED", + "PYDANTIC_INSTALLED", + "PYDANTIC_USE_FAILFAST", + "UNSET", + "UUID_UTILS_INSTALLED", + "ArrowRecordBatch", + "ArrowRecordBatchReader", + "ArrowRecordBatchReaderProtocol", + "ArrowReturnFormat", + "ArrowSchema", + "ArrowSchemaProtocol", + "ArrowTable", + "AttrsInstance", + "BaseModel", + "ColumnMetadata", + "ConnectionT", + "Counter", + "DTOData", + "DataclassProtocol", + "DictLike", + "Empty", + "EmptyEnum", + "EmptyType", + "FailFast", + "ForeignKeyMetadata", + "Gauge", + "Histogram", + "IndexMetadata", + "NumpyArray", + "PandasDataFrame", + "PolarsDataFrame", + "PoolT", + "SchemaT", + "Span", + "StatementParameters", + "Status", + "StatusCode", + "Struct", + "SupportedSchemaModel", + "TableMetadata", + "Tracer", + "TypeAdapter", + "UnsetType", + "VersionCacheResult", + "VersionInfo", + "attrs_asdict", + "attrs_define", + "attrs_field", + "attrs_fields", + "attrs_has", + "cattrs_structure", + "cattrs_unstructure", + "convert", + "get_type_adapter", + "module_available", + "msgspec_fields", + "trace", +) + class DictLike(Protocol): """A protocol for objects that behave like a dictionary for reading.""" @@ -286,7 +361,7 @@ def __hash__(self) -> int: SupportedSchemaModel: TypeAlias = ( - DictLike | StructStub | BaseModelStub | DataclassProtocol | AttrsInstanceStub | _TypedDict + DictLike | StructStub | BaseModelStub | DataclassProtocol | AttrsInstanceStub | Mapping[str, Any] ) """Type alias for pydantic or msgspec models. @@ -326,80 +401,3 @@ def get_type_adapter(f: "type[T]") -> Any: if PYDANTIC_USE_FAILFAST: return TypeAdapter(Annotated[f, FailFast()]) return TypeAdapter(f) - - -__all__ = ( - "ALLOYDB_CONNECTOR_INSTALLED", - "ATTRS_INSTALLED", - "CATTRS_INSTALLED", - "CLOUD_SQL_CONNECTOR_INSTALLED", - "FSSPEC_INSTALLED", - "LITESTAR_INSTALLED", - "MSGSPEC_INSTALLED", - "NANOID_INSTALLED", - "NUMPY_INSTALLED", - "OBSTORE_INSTALLED", - "OPENTELEMETRY_INSTALLED", - "ORJSON_INSTALLED", - "PANDAS_INSTALLED", - "PGVECTOR_INSTALLED", - "POLARS_INSTALLED", - "PROMETHEUS_INSTALLED", - "PYARROW_INSTALLED", - "PYDANTIC_INSTALLED", - "PYDANTIC_USE_FAILFAST", - "UNSET", - "UUID_UTILS_INSTALLED", - "ArrowRecordBatch", - "ArrowRecordBatchReader", - "ArrowRecordBatchReaderProtocol", - "ArrowReturnFormat", - "ArrowSchema", - "ArrowSchemaProtocol", - "ArrowTable", - "AttrsInstance", - "BaseModel", - "ColumnMetadata", - "ConnectionT", - "Counter", - "DTOData", - "DataclassProtocol", - "DictLike", - "Empty", - "EmptyEnum", - "EmptyType", - "FailFast", - "ForeignKeyMetadata", - "Gauge", - "Histogram", - "IndexMetadata", - "NumpyArray", - "PandasDataFrame", - "PolarsDataFrame", - "PoolT", - "SchemaT", - "Span", - "StatementParameters", - "Status", - "StatusCode", - "Struct", - "SupportedSchemaModel", - "TableMetadata", - "Tracer", - "TypeAdapter", - "UnsetType", - "VersionCacheResult", - "VersionInfo", - "attrs_asdict", - "attrs_define", - "attrs_field", - "attrs_fields", - "attrs_has", - "cattrs_structure", - "cattrs_unstructure", - "convert", - "get_type_adapter", - "module_available", - "msgspec_fields", - "trace", -) diff --git a/sqlspec/utils/fixtures.py b/sqlspec/utils/fixtures.py index a2f8aeabd..1538b1892 100644 --- a/sqlspec/utils/fixtures.py +++ b/sqlspec/utils/fixtures.py @@ -9,8 +9,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from anyio import Path as AsyncPath - from sqlspec.storage import storage_registry from sqlspec.utils.serializers import from_json as decode_json from sqlspec.utils.serializers import schema_dump @@ -27,6 +25,10 @@ def _read_text_sync(path: "Path") -> str: return path.read_text(encoding="utf-8") +def _resolve_path_str(path: str) -> str: + return str(Path(path).resolve()) + + def _compress_text(content: str) -> bytes: return gzip.compress(content.encode("utf-8")) @@ -193,7 +195,7 @@ def write_fixture( """ if storage_backend == "local": uri = "file://" - storage_kwargs["base_path"] = str(Path(fixtures_path).resolve()) + storage_kwargs["base_path"] = _resolve_path_str(fixtures_path) else: uri = storage_backend @@ -237,7 +239,8 @@ async def write_fixture_async( """ if storage_backend == "local": uri = "file://" - storage_kwargs["base_path"] = str(await AsyncPath(fixtures_path).resolve()) + resolve_path = async_(_resolve_path_str) + storage_kwargs["base_path"] = await resolve_path(fixtures_path) else: uri = storage_backend diff --git a/sqlspec/utils/schema.py b/sqlspec/utils/schema.py index a226fc3a0..58e7ac4b0 100644 --- a/sqlspec/utils/schema.py +++ b/sqlspec/utils/schema.py @@ -15,6 +15,7 @@ from sqlspec.typing import ( CATTRS_INSTALLED, NUMPY_INSTALLED, + ForeignKeyMetadata, SchemaT, attrs_asdict, cattrs_structure, @@ -230,21 +231,6 @@ def _detect_schema_type(schema_type: type) -> "str | None": ) -def _is_foreign_key_metadata_type(schema_type: type) -> bool: - if schema_type.__name__ != "ForeignKeyMetadata": - return False - - # Check module for stronger guarantee without importing - module = getattr(schema_type, "__module__", "") - if "sqlspec" in module and ("driver" in module or "data_dictionary" in module): - return True - - slots = getattr(schema_type, "__slots__", None) - if not slots: - return False - return {"table_name", "column_name", "referenced_table", "referenced_column"}.issubset(set(slots)) - - def _convert_foreign_key_metadata(data: Any, schema_type: Any) -> Any: """Convert data to ForeignKeyMetadata schema type. @@ -517,7 +503,7 @@ def _get_schema_converter(schema_type: type) -> "Callable[[Any, Any], Any] | Non conv = _convert_pydantic elif is_attrs_schema(schema_type): conv = _convert_attrs - elif _is_foreign_key_metadata_type(schema_type): + elif isinstance(schema_type, type) and issubclass(schema_type, ForeignKeyMetadata): conv = _convert_foreign_key_metadata else: conv = None diff --git a/sqlspec/utils/serializers/_json.py b/sqlspec/utils/serializers/_json.py index ed324819f..72c7c144a 100644 --- a/sqlspec/utils/serializers/_json.py +++ b/sqlspec/utils/serializers/_json.py @@ -393,7 +393,9 @@ def get_default_serializer() -> JSONSerializer: if _default_serializer is None: _default_serializer = StandardLibSerializer() - assert _default_serializer is not None + if _default_serializer is None: + msg = "No JSON serializer available; this should be unreachable" + raise RuntimeError(msg) return _default_serializer diff --git a/sqlspec/utils/type_guards.py b/sqlspec/utils/type_guards.py index 6f115b5c2..d8c16f840 100644 --- a/sqlspec/utils/type_guards.py +++ b/sqlspec/utils/type_guards.py @@ -4,6 +4,7 @@ understand type narrowing, replacing defensive hasattr() and duck typing patterns. """ +import inspect from collections.abc import Sequence from collections.abc import Set as AbstractSet from dataclasses import Field @@ -203,7 +204,7 @@ def is_readable(obj: Any) -> "TypeGuard[ReadableProtocol]": def is_async_readable(obj: Any) -> "TypeGuard[AsyncReadableProtocol]": """Check if an object exposes an async read method.""" try: - return callable(obj.read) + return callable(obj.read) and inspect.iscoroutinefunction(obj.read) except AttributeError: return False diff --git a/tests/integration/adapters/aiosqlite/test_parameter_styles.py b/tests/integration/adapters/aiosqlite/test_parameter_styles.py index d8f8fa51f..6a5a34b34 100644 --- a/tests/integration/adapters/aiosqlite/test_parameter_styles.py +++ b/tests/integration/adapters/aiosqlite/test_parameter_styles.py @@ -377,3 +377,45 @@ async def test_aiosqlite_parameter_count_mismatch_with_none() -> None: result = await driver.execute("INSERT INTO test_param_count (col1, col2) VALUES (?, ?)", ("value1", None)) assert isinstance(result, SQLResult) assert result.rows_affected == 1 + + +async def test_aiosqlite_named_colon_style_insert_and_select() -> None: + """Named-colon parameters must work end-to-end via the sqlspec pipeline.""" + from sqlspec.core import SQL + + config = AiosqliteConfig(connection_config={"database": ":memory:"}) + + async with config.provide_session() as driver: + await driver.execute("CREATE TABLE test_named_colon (id INTEGER PRIMARY KEY, name TEXT, value INTEGER)") + + result = await driver.execute( + "INSERT INTO test_named_colon (id, name, value) VALUES (:id, :name, :value)", + {"id": 1, "name": "alice", "value": 42}, + ) + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + select_result = await driver.execute( + SQL("SELECT name, value FROM test_named_colon WHERE name = :target", target="alice") + ) + assert isinstance(select_result, SQLResult) + row = select_result.get_data()[0] + assert row["name"] == "alice" + assert row["value"] == 42 + + +async def test_aiosqlite_named_colon_style_execute_many() -> None: + """Named-colon parameters must work with execute_many.""" + config = AiosqliteConfig(connection_config={"database": ":memory:"}) + + async with config.provide_session() as driver: + await driver.execute("CREATE TABLE test_named_colon_many (id INTEGER PRIMARY KEY, name TEXT)") + + params = [{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}, {"id": 3, "name": "gamma"}] + result = await driver.execute_many("INSERT INTO test_named_colon_many (id, name) VALUES (:id, :name)", params) + assert isinstance(result, SQLResult) + assert result.rows_affected == 3 + + count_result = await driver.execute("SELECT COUNT(*) as cnt FROM test_named_colon_many") + assert isinstance(count_result, SQLResult) + assert count_result.get_data()[0]["cnt"] == 3 diff --git a/tests/unit/adapters/test_adbc/test_core.py b/tests/unit/adapters/test_adbc/test_core.py index a28f582b9..d6357dd69 100644 --- a/tests/unit/adapters/test_adbc/test_core.py +++ b/tests/unit/adapters/test_adbc/test_core.py @@ -154,3 +154,33 @@ def test_detect_dialect_introspection_match_beats_fallback() -> None: conn = SimpleNamespace(adbc_get_info=lambda: {"vendor_name": "duckdb", "driver_name": ""}) result = adbc_core.detect_dialect(conn, fallback_dialect="sqlite") assert result == "duckdb" + + +def test_handle_postgres_rollback_issues_rollback_for_pgvector() -> None: + """pgvector is PostgreSQL-compatible and must receive rollback handling.""" + executed: list[str] = [] + cursor = SimpleNamespace(execute=lambda sql: executed.append(sql)) + + adbc_core.handle_postgres_rollback("pgvector", cursor) + + assert executed == ["ROLLBACK"] + + +def test_handle_postgres_rollback_issues_rollback_for_paradedb() -> None: + """paradedb is PostgreSQL-compatible and must receive rollback handling.""" + executed: list[str] = [] + cursor = SimpleNamespace(execute=lambda sql: executed.append(sql)) + + adbc_core.handle_postgres_rollback("paradedb", cursor) + + assert executed == ["ROLLBACK"] + + +def test_normalize_postgres_empty_parameters_returns_none_for_pgvector() -> None: + """pgvector empty dict parameters should use the PostgreSQL empty-params path.""" + assert adbc_core.normalize_postgres_empty_parameters("pgvector", {}) is None + + +def test_normalize_postgres_empty_parameters_returns_none_for_paradedb() -> None: + """paradedb empty dict parameters should use the PostgreSQL empty-params path.""" + assert adbc_core.normalize_postgres_empty_parameters("paradedb", {}) is None diff --git a/tests/unit/adapters/test_aiosqlite/test_cursor_lifecycle.py b/tests/unit/adapters/test_aiosqlite/test_cursor_lifecycle.py new file mode 100644 index 000000000..5ce531d63 --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/test_cursor_lifecycle.py @@ -0,0 +1,62 @@ +"""Unit tests for AiosqliteCursor cleanup on exception paths.""" + +import sqlite3 + +import aiosqlite +import pytest + +from sqlspec.adapters.aiosqlite._typing import AiosqliteCursor + + +async def test_cursor_closed_after_normal_exit() -> None: + conn = await aiosqlite.connect(":memory:") + try: + async with AiosqliteCursor(conn) as cursor: + await cursor.execute("SELECT 1") + + with pytest.raises(sqlite3.ProgrammingError, match="closed cursor"): + cursor._cursor.execute("SELECT 1") + finally: + await conn.close() + + +async def test_cursor_closed_after_exception_in_block() -> None: + conn = await aiosqlite.connect(":memory:") + try: + captured_cursor = None + with pytest.raises(aiosqlite.OperationalError): + async with AiosqliteCursor(conn) as cursor: + captured_cursor = cursor + await cursor.execute("SELECT * FROM nonexistent_table_xyz") + + assert captured_cursor is not None + with pytest.raises(sqlite3.ProgrammingError, match="closed cursor"): + captured_cursor._cursor.execute("SELECT 1") + finally: + await conn.close() + + +async def test_cursor_closed_after_generic_exception_in_block() -> None: + conn = await aiosqlite.connect(":memory:") + try: + captured_cursor = None + with pytest.raises(ValueError, match="simulated error"): + async with AiosqliteCursor(conn) as cursor: + captured_cursor = cursor + raise ValueError("simulated error") + + assert captured_cursor is not None + with pytest.raises(sqlite3.ProgrammingError, match="closed cursor"): + captured_cursor._cursor.execute("SELECT 1") + finally: + await conn.close() + + +async def test_aexit_does_not_suppress_exception() -> None: + conn = await aiosqlite.connect(":memory:") + try: + with pytest.raises(aiosqlite.OperationalError): + async with AiosqliteCursor(conn) as cursor: + await cursor.execute("SELECT * FROM nonexistent_table_xyz") + finally: + await conn.close() diff --git a/tests/unit/adapters/test_aiosqlite/test_exception_mapping.py b/tests/unit/adapters/test_aiosqlite/test_exception_mapping.py new file mode 100644 index 000000000..25acce52b --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/test_exception_mapping.py @@ -0,0 +1,98 @@ +"""Unit tests for aiosqlite exception mapping parity with sqlite.""" + +import sqlite3 + +from sqlspec.adapters.aiosqlite.core import create_mapped_exception +from sqlspec.exceptions import DeadlockError, PermissionDeniedError, QueryTimeoutError + + +class _SqliteError(sqlite3.OperationalError): + def __init__(self, message: str, code: int | None = None, name: str | None = None) -> None: + super().__init__(message) + self.sqlite_errorcode = code + self.sqlite_errorname = name + + +def test_busy_error_code_maps_to_deadlock() -> None: + err = _SqliteError("database is locked", 5, "SQLITE_BUSY") + result = create_mapped_exception(err) + assert isinstance(result, DeadlockError) + assert result.__cause__ is err + + +def test_busy_error_name_maps_to_deadlock() -> None: + result = create_mapped_exception(_SqliteError("database is busy", None, "SQLITE_BUSY")) + assert isinstance(result, DeadlockError) + + +def test_locked_error_code_maps_to_deadlock() -> None: + err = _SqliteError("database table is locked", 6, "SQLITE_LOCKED") + result = create_mapped_exception(err) + assert isinstance(result, DeadlockError) + assert result.__cause__ is err + + +def test_locked_error_name_maps_to_deadlock() -> None: + result = create_mapped_exception(_SqliteError("table is locked", None, "SQLITE_LOCKED")) + assert isinstance(result, DeadlockError) + + +def test_locked_text_heuristic_maps_to_deadlock() -> None: + result = create_mapped_exception(sqlite3.OperationalError("database locked")) + assert isinstance(result, DeadlockError) + + +def test_busy_text_heuristic_maps_to_deadlock() -> None: + result = create_mapped_exception(sqlite3.OperationalError("database is busy, please retry")) + assert isinstance(result, DeadlockError) + + +def test_interrupt_error_code_maps_to_query_timeout() -> None: + err = _SqliteError("interrupted", 9, "SQLITE_INTERRUPT") + result = create_mapped_exception(err) + assert isinstance(result, QueryTimeoutError) + assert result.__cause__ is err + + +def test_interrupt_error_name_maps_to_query_timeout() -> None: + result = create_mapped_exception(_SqliteError("query was interrupted", None, "SQLITE_INTERRUPT")) + assert isinstance(result, QueryTimeoutError) + + +def test_interrupt_text_heuristic_maps_to_query_timeout() -> None: + result = create_mapped_exception(sqlite3.OperationalError("query was interrupted by application")) + assert isinstance(result, QueryTimeoutError) + + +def test_perm_error_code_maps_to_permission_denied() -> None: + err = _SqliteError("access permission denied", 3, "SQLITE_PERM") + result = create_mapped_exception(err) + assert isinstance(result, PermissionDeniedError) + assert result.__cause__ is err + + +def test_perm_error_name_maps_to_permission_denied() -> None: + result = create_mapped_exception(_SqliteError("access permission denied", None, "SQLITE_PERM")) + assert isinstance(result, PermissionDeniedError) + + +def test_readonly_error_code_maps_to_permission_denied() -> None: + err = _SqliteError("attempt to write a readonly database", 8, "SQLITE_READONLY") + result = create_mapped_exception(err) + assert isinstance(result, PermissionDeniedError) + assert result.__cause__ is err + + +def test_readonly_error_name_maps_to_permission_denied() -> None: + result = create_mapped_exception(_SqliteError("readonly database", None, "SQLITE_READONLY")) + assert isinstance(result, PermissionDeniedError) + + +def test_readonly_text_heuristic_maps_to_permission_denied() -> None: + result = create_mapped_exception(sqlite3.OperationalError("attempt to write a readonly database")) + assert isinstance(result, PermissionDeniedError) + + +def test_permission_denied_text_heuristic_maps_to_permission_denied() -> None: + result = create_mapped_exception(sqlite3.OperationalError("permission denied: cannot open /etc/passwd")) + assert isinstance(result, PermissionDeniedError) diff --git a/tests/unit/adapters/test_aiosqlite/test_pool_acquire.py b/tests/unit/adapters/test_aiosqlite/test_pool_acquire.py new file mode 100644 index 000000000..d12ff2ddb --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/test_pool_acquire.py @@ -0,0 +1,123 @@ +"""Tests for aiosqlite pool acquire hot path.""" + +import asyncio +import time +from collections.abc import Iterator +from threading import Thread +from typing import Any + +import pytest + +from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool + + +class _FakeConnection: + def __init__(self) -> None: + self.executed: list[str] = [] + + async def execute(self, sql: str) -> None: + self.executed.append(sql) + + async def commit(self) -> None: + return None + + async def rollback(self) -> None: + return None + + async def close(self) -> None: + return None + + +class _FakeConnectProxy: + def __init__(self, connection: _FakeConnection) -> None: + self._thread = Thread(target=lambda: None) + self._connection = connection + + def __await__(self) -> Any: + async def _resolve() -> _FakeConnection: + return self._connection + + return _resolve().__await__() + + +async def test_acquire_does_not_sleep_when_wal_not_initialized(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.adapters.aiosqlite import pool as pool_module + + proxy = _FakeConnectProxy(_FakeConnection()) + pool = AiosqliteConnectionPool({"database": "file:test_acquire?mode=memory&cache=shared", "uri": True}, pool_size=3) + + monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: proxy) + + pool_conn = await pool._create_connection() # pyright: ignore[reportPrivateUsage] + pool._queue.put_nowait(pool_conn) # pyright: ignore[reportPrivateUsage] + pool._wal_initialized = False # pyright: ignore[reportPrivateUsage] + + start = time.monotonic() + connection = await pool.acquire() + elapsed = time.monotonic() - start + + assert elapsed < 0.005 + assert connection is pool_conn + + await pool._retire_connection(pool_conn, reason="test_cleanup") # pyright: ignore[reportPrivateUsage] + + +async def test_acquire_multiple_concurrent_no_sleep_serialization(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.adapters.aiosqlite import pool as pool_module + + pool_size = 5 + pool = AiosqliteConnectionPool( + {"database": "file:concurrent_test?mode=memory&cache=shared", "uri": True}, pool_size=pool_size + ) + + proxies: Iterator[_FakeConnectProxy] = iter(_FakeConnectProxy(_FakeConnection()) for _ in range(pool_size)) + monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: next(proxies)) + + for _ in range(pool_size): + pool_conn = await pool._create_connection() # pyright: ignore[reportPrivateUsage] + pool._queue.put_nowait(pool_conn) # pyright: ignore[reportPrivateUsage] + + pool._wal_initialized = False # pyright: ignore[reportPrivateUsage] + + async def _acquire_and_release() -> float: + t0 = time.monotonic() + connection = await pool.acquire() + t1 = time.monotonic() + await pool.release(connection) + return t1 - t0 + + start = time.monotonic() + times = await asyncio.gather(*[_acquire_and_release() for _ in range(pool_size)]) + total = time.monotonic() - start + + assert total < pool_size * 0.005, f"total={total:.4f}s per-call={times}" + + +async def test_wal_ready_event_set_after_shared_cache_create_connection(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.adapters.aiosqlite import pool as pool_module + + pool = AiosqliteConnectionPool({"database": "file:wal_event_test?mode=memory&cache=shared", "uri": True}) + monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: _FakeConnectProxy(_FakeConnection())) + + assert not pool._wal_ready_event.is_set() # pyright: ignore[reportPrivateUsage] + + pool_conn = await pool._create_connection() # pyright: ignore[reportPrivateUsage] + try: + assert pool._wal_initialized is True # pyright: ignore[reportPrivateUsage] + assert pool._wal_ready_event.is_set() # pyright: ignore[reportPrivateUsage] + finally: + await pool._retire_connection(pool_conn, reason="test_cleanup") # pyright: ignore[reportPrivateUsage] + + +async def test_wal_ready_event_not_set_for_non_shared_cache(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.adapters.aiosqlite import pool as pool_module + + pool = AiosqliteConnectionPool({"database": ":memory:"}) + monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: _FakeConnectProxy(_FakeConnection())) + + pool_conn = await pool._create_connection() # pyright: ignore[reportPrivateUsage] + try: + assert pool._wal_initialized is False # pyright: ignore[reportPrivateUsage] + assert not pool._wal_ready_event.is_set() # pyright: ignore[reportPrivateUsage] + finally: + await pool._retire_connection(pool_conn, reason="test_cleanup") # pyright: ignore[reportPrivateUsage] diff --git a/tests/unit/adapters/test_aiosqlite/test_profile.py b/tests/unit/adapters/test_aiosqlite/test_profile.py new file mode 100644 index 000000000..bdaf6212b --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/test_profile.py @@ -0,0 +1,66 @@ +"""Unit tests for aiosqlite profile parity with sqlite.""" + +from sqlspec.adapters.aiosqlite.core import ( + build_profile, + build_statement_config, + default_statement_config, + driver_profile, +) +from sqlspec.core import ParameterStyle + + +def test_build_profile_includes_named_colon_in_supported_styles() -> None: + profile = build_profile() + assert ParameterStyle.NAMED_COLON in profile.supported_styles + + +def test_build_profile_includes_named_colon_in_supported_execution_styles() -> None: + profile = build_profile() + assert ParameterStyle.NAMED_COLON in profile.supported_execution_styles + + +def test_build_profile_default_style_remains_qmark() -> None: + profile = build_profile() + assert profile.default_style == ParameterStyle.QMARK + assert profile.default_execution_style == ParameterStyle.QMARK + + +def test_build_profile_qmark_still_supported() -> None: + profile = build_profile() + assert ParameterStyle.QMARK in profile.supported_styles + assert ParameterStyle.QMARK in profile.supported_execution_styles + + +def test_build_statement_config_disables_parameter_type_wrapping() -> None: + config = build_statement_config() + assert config.enable_parameter_type_wrapping is False + + +def test_default_statement_config_disables_parameter_type_wrapping() -> None: + assert default_statement_config.enable_parameter_type_wrapping is False + + +def test_driver_profile_module_singleton_has_named_colon() -> None: + assert ParameterStyle.NAMED_COLON in driver_profile.supported_styles + assert ParameterStyle.NAMED_COLON in driver_profile.supported_execution_styles + + +def test_aiosqlite_profile_parity_with_sqlite_profile() -> None: + from sqlspec.adapters.sqlite.core import build_profile as sqlite_build_profile + + aio_profile = build_profile() + sqlite_profile = sqlite_build_profile() + + assert aio_profile.supported_styles == sqlite_profile.supported_styles + assert aio_profile.supported_execution_styles == sqlite_profile.supported_execution_styles + assert aio_profile.default_style == sqlite_profile.default_style + assert aio_profile.default_execution_style == sqlite_profile.default_execution_style + + +def test_aiosqlite_statement_config_parity_with_sqlite() -> None: + from sqlspec.adapters.sqlite.core import build_statement_config as sqlite_build_statement_config + + aio_config = build_statement_config() + sqlite_config = sqlite_build_statement_config() + + assert aio_config.enable_parameter_type_wrapping == sqlite_config.enable_parameter_type_wrapping diff --git a/tests/unit/adapters/test_aiosqlite/test_type_converter.py b/tests/unit/adapters/test_aiosqlite/test_type_converter.py new file mode 100644 index 000000000..2018d5d76 --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/test_type_converter.py @@ -0,0 +1,109 @@ +"""Unit tests for the aiosqlite-local SQLite type converter copy.""" + +import json +from unittest.mock import MagicMock, patch + + +def test_import_is_from_aiosqlite_not_sqlite() -> None: + from sqlspec.adapters.aiosqlite.type_converter import register_type_handlers + + assert callable(register_type_handlers) + + +def test_aiosqlite_module_is_independent_of_sqlite_module() -> None: + import sqlspec.adapters.aiosqlite.type_converter as aio_mod + import sqlspec.adapters.sqlite.type_converter as sql_mod + + assert aio_mod is not sql_mod + + +def test_json_adapter_dict_default_serializer() -> None: + from sqlspec.adapters.aiosqlite.type_converter import json_adapter + + data = {"key": "value", "count": 42} + result = json_adapter(data) + + assert isinstance(result, str) + assert json.loads(result) == data + + +def test_json_adapter_list_default_serializer() -> None: + from sqlspec.adapters.aiosqlite.type_converter import json_adapter + + data = [1, 2, 3, "four"] + result = json_adapter(data) + + assert isinstance(result, str) + assert json.loads(result) == data + + +def test_json_adapter_custom_serializer() -> None: + from sqlspec.adapters.aiosqlite.type_converter import json_adapter + + called_with = [] + + def custom_serializer(value: object) -> str: + called_with.append(value) + return json.dumps(value) + + data = {"x": 1} + result = json_adapter(data, serializer=custom_serializer) + + assert result == json.dumps(data) + assert called_with == [data] + + +def test_json_converter_default_deserializer() -> None: + from sqlspec.adapters.aiosqlite.type_converter import json_converter + + data = {"key": "value"} + result = json_converter(json.dumps(data).encode("utf-8")) + + assert result == data + + +def test_register_type_handlers_calls_sqlite3_apis() -> None: + from sqlspec.adapters.aiosqlite.type_converter import register_type_handlers + + with patch("sqlite3.register_adapter") as mock_adapter, patch("sqlite3.register_converter") as mock_converter: + register_type_handlers() + + assert mock_adapter.call_count == 2 + mock_converter.assert_called_once() + + +def test_register_type_handlers_with_custom_serializers() -> None: + from sqlspec.adapters.aiosqlite.type_converter import register_type_handlers + + def custom_serializer(value: object) -> str: + return json.dumps(value) + + def custom_deserializer(value: str) -> object: + return json.loads(value) + + with patch("sqlite3.register_adapter"), patch("sqlite3.register_converter"): + register_type_handlers(json_serializer=custom_serializer, json_deserializer=custom_deserializer) + + +def test_unregister_type_handlers_is_noop() -> None: + from sqlspec.adapters.aiosqlite.type_converter import unregister_type_handlers + + unregister_type_handlers() + + +def test_aiosqlite_config_uses_aiosqlite_type_converter() -> None: + import sqlspec.adapters.aiosqlite.config as config_mod + import sqlspec.adapters.sqlite.type_converter as sqlite_type_converter + + aio_mock = MagicMock() + sqlite_mock = MagicMock() + + with ( + patch.object(config_mod, "register_type_handlers", aio_mock), + patch.object(sqlite_type_converter, "register_type_handlers", sqlite_mock), + ): + config = config_mod.AiosqliteConfig(driver_features={"enable_custom_adapters": True}) + config._register_type_adapters() # pyright: ignore[reportPrivateUsage] + + aio_mock.assert_called_once() + sqlite_mock.assert_not_called() diff --git a/tests/unit/adapters/test_arrow_odbc.py b/tests/unit/adapters/test_arrow_odbc.py index c1b3c9b42..6d7e1bfce 100644 --- a/tests/unit/adapters/test_arrow_odbc.py +++ b/tests/unit/adapters/test_arrow_odbc.py @@ -190,6 +190,33 @@ def test_arrow_odbc_field_config_uses_driver_name_for_dialect(monkeypatch: Any) assert session.dialect == "tsql" +def test_arrow_odbc_config_init_no_pre_super_assign_connection_string() -> None: + """Connection-string input should normalize and populate driver features.""" + connection_string = "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=db;" + + config = ArrowOdbcConfig(connection_config={"connection_string": connection_string}) + + assert config.connection_config == {"connection_string": connection_string} + assert config.driver_features.get("connection_string") == connection_string + + +def test_arrow_odbc_config_init_no_pre_super_assign_driver_key() -> None: + """Driver-key input should populate dbms_name without early slot writes.""" + config = ArrowOdbcConfig( + connection_config={"driver": "ODBC Driver 17 for SQL Server", "server": "myhost", "database": "mydb"} + ) + + assert config.connection_config["driver"] == "ODBC Driver 17 for SQL Server" + assert config.driver_features.get("dbms_name") == "ODBC Driver 17 for SQL Server" + + +def test_arrow_odbc_config_init_no_pre_super_assign_none_input() -> None: + """None input should normalize to an empty connection_config dict.""" + config = ArrowOdbcConfig(connection_config=None) + + assert config.connection_config == {} + + def test_arrow_odbc_mssql_driver_uses_tsql_statement_dialect() -> None: """SQL Server ODBC connections should compile with sqlglot's tsql dialect.""" connection = FakeConnection() diff --git a/tests/unit/adapters/test_asyncmy/test_config.py b/tests/unit/adapters/test_asyncmy/test_config.py index 40ce24e44..1ca6dc1b4 100644 --- a/tests/unit/adapters/test_asyncmy/test_config.py +++ b/tests/unit/adapters/test_asyncmy/test_config.py @@ -34,3 +34,56 @@ def deserializer(_: str) -> object: parameter_config = config.statement_config.parameter_config assert parameter_config.json_serializer is serializer assert parameter_config.json_deserializer is deserializer + + +def test_asyncmy_typing_all_exports_pool_and_dict_cursor() -> None: + """AsyncmyPool and AsyncmyDictCursor must be present in _typing.__all__.""" + from sqlspec.adapters.asyncmy import _typing + + assert "AsyncmyPool" in _typing.__all__ + assert "AsyncmyDictCursor" in _typing.__all__ + + +def test_asyncmy_typing_pool_is_importable() -> None: + """AsyncmyPool must be importable from sqlspec.adapters.asyncmy._typing.""" + from sqlspec.adapters.asyncmy._typing import AsyncmyPool + + assert AsyncmyPool is not None + + +def test_asyncmy_typing_dict_cursor_is_importable() -> None: + """AsyncmyDictCursor must be importable from sqlspec.adapters.asyncmy._typing.""" + from sqlspec.adapters.asyncmy._typing import AsyncmyDictCursor + + assert AsyncmyDictCursor is not None + + +def test_asyncmy_config_imports_pool_from_typing_not_vendor() -> None: + """AsyncmyPool in config module must be the same object as _typing.AsyncmyPool.""" + import sqlspec.adapters.asyncmy.config as config_mod + from sqlspec.adapters.asyncmy._typing import AsyncmyPool + + assert config_mod.AsyncmyPool is AsyncmyPool + + +def test_asyncmy_config_no_direct_vendor_type_imports() -> None: + """config.py must not expose raw vendor type names at module level.""" + import sqlspec.adapters.asyncmy.config as config_mod + + assert not hasattr(config_mod, "Pool") + assert not hasattr(config_mod, "Cursor") + assert not hasattr(config_mod, "DictCursor") + + +def test_asyncmy_provide_pool_return_annotation_is_asyncmy_pool() -> None: + """provide_pool must be annotated with AsyncmyPool, not the raw Pool alias.""" + from sqlspec.adapters.asyncmy.config import AsyncmyConfig + + assert AsyncmyConfig.provide_pool.__annotations__.get("return") == "AsyncmyPool" + + +def test_driver_profile_name_matches_registry_key() -> None: + """driver_profile.name must equal the registry key 'asyncmy'.""" + from sqlspec.adapters.asyncmy.core import driver_profile + + assert driver_profile.name == "asyncmy" diff --git a/tests/unit/adapters/test_asyncpg/test_config.py b/tests/unit/adapters/test_asyncpg/test_config.py index 073ec397d..20dce7666 100644 --- a/tests/unit/adapters/test_asyncpg/test_config.py +++ b/tests/unit/adapters/test_asyncpg/test_config.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock import pytest +from asyncpg.pool import PoolConnectionProxy, PoolConnectionProxyMeta from sqlspec.adapters.asyncpg._typing import AsyncpgSessionContext from sqlspec.adapters.asyncpg.config import AsyncpgConfig @@ -46,6 +47,11 @@ def deserializer(_: str) -> object: assert parameter_config.json_deserializer is deserializer +def test_asyncpg_config_connection_type_is_not_metaclass() -> None: + assert AsyncpgConfig.connection_type is PoolConnectionProxy + assert AsyncpgConfig.connection_type is not PoolConnectionProxyMeta + + def test_asyncpg_build_postgres_extension_probe_names_filters_disabled_features() -> None: """Only enabled extension probes should be returned.""" assert build_postgres_extension_probe_names({"enable_pgvector": True, "enable_paradedb": False}) == ["vector"] diff --git a/tests/unit/adapters/test_asyncpg/test_driver.py b/tests/unit/adapters/test_asyncpg/test_driver.py new file mode 100644 index 000000000..ae29b3480 --- /dev/null +++ b/tests/unit/adapters/test_asyncpg/test_driver.py @@ -0,0 +1,101 @@ +"""AsyncPG driver regression tests.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from sqlspec.adapters.asyncpg.core import default_statement_config +from sqlspec.adapters.asyncpg.driver import AsyncpgDriver +from sqlspec.core import SQL +from sqlspec.exceptions import SQLSpecError + + +class _CopyCursor: + def __init__(self) -> None: + self.copy_to_table = AsyncMock() + self.copy_from_query = AsyncMock() + self.execute = AsyncMock() + + +class _CompiledCopyDriver(AsyncpgDriver): + def _get_compiled_sql(self, *_args: object, **_kwargs: object) -> tuple[str, object]: + return "COPY (SELECT 1) FROM STDIN", None + + +@pytest.mark.anyio +async def test_asyncpg_copy_from_stdin_uses_copy_to_table() -> None: + config = default_statement_config.replace(execution_args={"postgres_copy_data": "1\tAlice"}) + statement = SQL("COPY users FROM STDIN", statement_config=config) + cursor = _CopyCursor() + driver = AsyncpgDriver(connection=object()) + + await driver._handle_copy_operation(cursor, statement) # pyright: ignore[reportPrivateUsage] + + cursor.copy_to_table.assert_awaited_once() + assert cursor.copy_to_table.await_args.args == ("users",) + assert cursor.copy_to_table.await_args.kwargs["source"].read() == b"1\tAlice" + cursor.copy_from_query.assert_not_awaited() + cursor.execute.assert_not_awaited() + + +@pytest.mark.anyio +async def test_asyncpg_copy_from_stdin_splits_schema_table() -> None: + config = default_statement_config.replace(execution_args={"postgres_copy_data": "1"}) + statement = SQL("COPY public.users FROM STDIN", statement_config=config) + cursor = _CopyCursor() + driver = AsyncpgDriver(connection=object()) + + await driver._handle_copy_operation(cursor, statement) # pyright: ignore[reportPrivateUsage] + + assert cursor.copy_to_table.await_args.args == ("users",) + assert cursor.copy_to_table.await_args.kwargs["schema_name"] == "public" + + +@pytest.mark.anyio +async def test_asyncpg_copy_from_stdin_forwards_copy_options() -> None: + config = default_statement_config.replace( + execution_args={ + "postgres_copy_data": "1,Alice", + "postgres_copy_columns": ("id", "name"), + "postgres_copy_format": "csv", + "postgres_copy_delimiter": ",", + } + ) + statement = SQL("COPY users (id, name) FROM STDIN", statement_config=config) + cursor = _CopyCursor() + driver = AsyncpgDriver(connection=object()) + + await driver._handle_copy_operation(cursor, statement) # pyright: ignore[reportPrivateUsage] + + assert cursor.copy_to_table.await_args.kwargs["columns"] == ("id", "name") + assert cursor.copy_to_table.await_args.kwargs["format"] == "csv" + assert cursor.copy_to_table.await_args.kwargs["delimiter"] == "," + + +@pytest.mark.anyio +async def test_asyncpg_copy_from_stdin_uses_metadata_table_fallback() -> None: + config = default_statement_config.replace( + execution_args={"postgres_copy_data": "1", "postgres_copy_table": "public.users"} + ) + statement = SimpleNamespace(operation_type="COPY_FROM", statement_config=config) + cursor = _CopyCursor() + driver = _CompiledCopyDriver(connection=object()) + + await driver._handle_copy_operation(cursor, statement) # pyright: ignore[reportPrivateUsage] + + assert cursor.copy_to_table.await_args.args == ("users",) + assert cursor.copy_to_table.await_args.kwargs["schema_name"] == "public" + + +@pytest.mark.anyio +async def test_asyncpg_copy_from_stdin_requires_table_name() -> None: + config = default_statement_config.replace(execution_args={"postgres_copy_data": "1"}) + statement = SimpleNamespace(operation_type="COPY_FROM", statement_config=config) + cursor = _CopyCursor() + driver = _CompiledCopyDriver(connection=object()) + + with pytest.raises(SQLSpecError, match="postgres_copy_table"): + await driver._handle_copy_operation(cursor, statement) # pyright: ignore[reportPrivateUsage] + + cursor.copy_to_table.assert_not_awaited() diff --git a/tests/unit/adapters/test_bigquery/test_core.py b/tests/unit/adapters/test_bigquery/test_core.py index 4311d87e5..d8d787a4d 100644 --- a/tests/unit/adapters/test_bigquery/test_core.py +++ b/tests/unit/adapters/test_bigquery/test_core.py @@ -2,7 +2,8 @@ from types import SimpleNamespace -from sqlspec.adapters.bigquery.core import collect_rows, resolve_column_names +from sqlspec.adapters.bigquery.core import build_profile, collect_rows, driver_profile, resolve_column_names +from sqlspec.core import ParameterStyle def _schema_field(name: str) -> SimpleNamespace: @@ -57,3 +58,22 @@ def test_collect_rows_uses_cache_when_column_names_not_precomputed() -> None: assert column_names_one == ["id", "name"] assert column_names_two is column_names_one assert len(cache) == 1 + + +def test_build_profile_does_not_advertise_qmark() -> None: + """BigQuery rejects positional params, so QMARK must not be advertised.""" + profile = build_profile() + + assert ParameterStyle.QMARK not in profile.supported_styles + + +def test_module_level_driver_profile_does_not_advertise_qmark() -> None: + """Module-level driver_profile should match the supported BigQuery style set.""" + assert ParameterStyle.QMARK not in driver_profile.supported_styles + + +def test_build_profile_named_at_is_only_supported_style() -> None: + """NAMED_AT is the only supported BigQuery parameter style.""" + profile = build_profile() + + assert profile.supported_styles == frozenset({ParameterStyle.NAMED_AT}) diff --git a/tests/unit/adapters/test_cockroach_asyncpg/test_config.py b/tests/unit/adapters/test_cockroach_asyncpg/test_config.py index 61e359fed..e72f5d0e3 100644 --- a/tests/unit/adapters/test_cockroach_asyncpg/test_config.py +++ b/tests/unit/adapters/test_cockroach_asyncpg/test_config.py @@ -8,11 +8,14 @@ from unittest.mock import AsyncMock, patch +import pytest + from sqlspec.adapters.cockroach_asyncpg import ( CockroachAsyncpgConfig, CockroachAsyncpgDriverFeatures, CockroachAsyncpgRetryConfig, ) +from sqlspec.core import StatementConfig class TestCockroachAsyncpgConfig: @@ -125,6 +128,57 @@ async def user_hook(connection: object) -> None: assert events == ["user"] +@pytest.mark.anyio +async def test_cockroach_asyncpg_init_connection_registers_pgvector_when_enabled() -> None: + config = CockroachAsyncpgConfig(driver_features={"enable_pgvector": True}) + connection = AsyncMock() + + with ( + patch("sqlspec.adapters.cockroach_asyncpg.config.register_json_codecs", new_callable=AsyncMock, create=True), + patch( + "sqlspec.adapters.cockroach_asyncpg.config.register_pgvector_support", new_callable=AsyncMock, create=True + ) as register_pgvector, + ): + await config._init_connection(connection) + + register_pgvector.assert_awaited_once_with(connection) + + +@pytest.mark.anyio +async def test_cockroach_asyncpg_init_connection_skips_pgvector_by_default() -> None: + config = CockroachAsyncpgConfig() + connection = AsyncMock() + + with ( + patch("sqlspec.adapters.cockroach_asyncpg.config.register_json_codecs", new_callable=AsyncMock, create=True), + patch( + "sqlspec.adapters.cockroach_asyncpg.config.register_pgvector_support", new_callable=AsyncMock, create=True + ) as register_pgvector, + ): + await config._init_connection(connection) + + register_pgvector.assert_not_awaited() + + +def test_cockroach_asyncpg_provide_session_uses_lazy_lambda() -> None: + config = CockroachAsyncpgConfig() + + session_config = config.provide_session()._statement_config # pyright: ignore[reportPrivateUsage] + + assert callable(session_config) + + +def test_cockroach_asyncpg_provide_session_reflects_post_construction_dialect_change() -> None: + config = CockroachAsyncpgConfig() + context = config.provide_session() + config.statement_config = StatementConfig(dialect="pgvector") + + session_config = context._statement_config # pyright: ignore[reportPrivateUsage] + + assert callable(session_config) + assert session_config().dialect == "pgvector" + + class TestCockroachAsyncpgDriverFeatures: """Tests for CockroachAsyncpgDriverFeatures TypedDict structure.""" diff --git a/tests/unit/adapters/test_cockroach_asyncpg/test_driver.py b/tests/unit/adapters/test_cockroach_asyncpg/test_driver.py new file mode 100644 index 000000000..92d34e437 --- /dev/null +++ b/tests/unit/adapters/test_cockroach_asyncpg/test_driver.py @@ -0,0 +1,39 @@ +"""CockroachDB AsyncPG driver regression tests.""" + +import pytest + +from sqlspec.adapters.cockroach_asyncpg.driver import CockroachAsyncpgDriver + + +class _RecordingCockroachAsyncpgDriver(CockroachAsyncpgDriver): + def __init__(self) -> None: + super().__init__(connection=object(), driver_features={"enable_auto_retry": False}) + self.calls: list[str] = [] + + async def _dispatch_execute_many_impl(self, cursor: object, statement: object) -> str: + self.calls.append("many") + return "many" + + async def _dispatch_execute_script_impl(self, cursor: object, statement: object) -> str: + self.calls.append("script") + return "script" + + +@pytest.mark.anyio +async def test_cockroach_asyncpg_non_retry_execute_many_uses_impl_wrapper() -> None: + driver = _RecordingCockroachAsyncpgDriver() + + result = await driver.dispatch_execute_many(object(), object()) # type: ignore[arg-type] + + assert result == "many" + assert driver.calls == ["many"] + + +@pytest.mark.anyio +async def test_cockroach_asyncpg_non_retry_execute_script_uses_impl_wrapper() -> None: + driver = _RecordingCockroachAsyncpgDriver() + + result = await driver.dispatch_execute_script(object(), object()) # type: ignore[arg-type] + + assert result == "script" + assert driver.calls == ["script"] diff --git a/tests/unit/adapters/test_cockroach_psycopg/test_config.py b/tests/unit/adapters/test_cockroach_psycopg/test_config.py index cf821c3a1..56e46db43 100644 --- a/tests/unit/adapters/test_cockroach_psycopg/test_config.py +++ b/tests/unit/adapters/test_cockroach_psycopg/test_config.py @@ -13,6 +13,7 @@ CockroachPsycopgRetryConfig, CockroachPsycopgSyncConfig, ) +from sqlspec.adapters.cockroach_psycopg.config import default_statement_config class TestCockroachPsycopgSyncConfig: @@ -159,6 +160,15 @@ def test_class_attributes(self) -> None: assert CockroachPsycopgAsyncConfig.supports_native_arrow_export is True assert CockroachPsycopgAsyncConfig.supports_native_arrow_import is True + def test_provide_session_uses_default_statement_config_constant_when_config_missing(self) -> None: + """Async session fallback should reuse the module-level default config.""" + config = CockroachPsycopgAsyncConfig() + config.statement_config = None # type: ignore[assignment] + + session_config = config.provide_session()._statement_config # pyright: ignore[reportPrivateUsage] + + assert session_config is default_statement_config + class TestCockroachPsycopgDriverFeatures: """Tests for CockroachPsycopgDriverFeatures TypedDict structure.""" diff --git a/tests/unit/adapters/test_cockroach_psycopg/test_driver.py b/tests/unit/adapters/test_cockroach_psycopg/test_driver.py new file mode 100644 index 000000000..1646b1792 --- /dev/null +++ b/tests/unit/adapters/test_cockroach_psycopg/test_driver.py @@ -0,0 +1,87 @@ +"""CockroachDB psycopg driver regression tests.""" + +import pytest + +from sqlspec.adapters.cockroach_psycopg.data_dictionary import ( + CockroachPsycopgAsyncDataDictionary, + CockroachPsycopgSyncDataDictionary, +) +from sqlspec.adapters.cockroach_psycopg.driver import CockroachPsycopgAsyncDriver, CockroachPsycopgSyncDriver + + +class _RecordingCockroachPsycopgSyncDriver(CockroachPsycopgSyncDriver): + def __init__(self) -> None: + super().__init__(connection=object(), driver_features={"enable_auto_retry": False}) + self.calls: list[str] = [] + + def _dispatch_execute_many_impl(self, cursor: object, statement: object) -> str: + self.calls.append("many") + return "many" + + def _dispatch_execute_script_impl(self, cursor: object, statement: object) -> str: + self.calls.append("script") + return "script" + + +class _RecordingCockroachPsycopgAsyncDriver(CockroachPsycopgAsyncDriver): + def __init__(self) -> None: + super().__init__(connection=object(), driver_features={"enable_auto_retry": False}) + self.calls: list[str] = [] + + async def _dispatch_execute_many_impl(self, cursor: object, statement: object) -> str: + self.calls.append("many") + return "many" + + async def _dispatch_execute_script_impl(self, cursor: object, statement: object) -> str: + self.calls.append("script") + return "script" + + +def test_cockroach_psycopg_sync_non_retry_execute_many_uses_impl_wrapper() -> None: + driver = _RecordingCockroachPsycopgSyncDriver() + + result = driver.dispatch_execute_many(object(), object()) # type: ignore[arg-type] + + assert result == "many" + assert driver.calls == ["many"] + + +def test_cockroach_psycopg_sync_non_retry_execute_script_uses_impl_wrapper() -> None: + driver = _RecordingCockroachPsycopgSyncDriver() + + result = driver.dispatch_execute_script(object(), object()) # type: ignore[arg-type] + + assert result == "script" + assert driver.calls == ["script"] + + +@pytest.mark.anyio +async def test_cockroach_psycopg_async_non_retry_execute_many_uses_impl_wrapper() -> None: + driver = _RecordingCockroachPsycopgAsyncDriver() + + result = await driver.dispatch_execute_many(object(), object()) # type: ignore[arg-type] + + assert result == "many" + assert driver.calls == ["many"] + + +@pytest.mark.anyio +async def test_cockroach_psycopg_async_non_retry_execute_script_uses_impl_wrapper() -> None: + driver = _RecordingCockroachPsycopgAsyncDriver() + + result = await driver.dispatch_execute_script(object(), object()) # type: ignore[arg-type] + + assert result == "script" + assert driver.calls == ["script"] + + +def test_cockroach_psycopg_sync_data_dictionary_uses_parent_slot() -> None: + driver = CockroachPsycopgSyncDriver(connection=object()) + + assert isinstance(driver.data_dictionary, CockroachPsycopgSyncDataDictionary) + + +def test_cockroach_psycopg_async_data_dictionary_uses_parent_slot() -> None: + driver = CockroachPsycopgAsyncDriver(connection=object()) + + assert isinstance(driver.data_dictionary, CockroachPsycopgAsyncDataDictionary) diff --git a/tests/unit/adapters/test_duckdb/test_driver_init.py b/tests/unit/adapters/test_duckdb/test_driver_init.py new file mode 100644 index 000000000..2cd40e28f --- /dev/null +++ b/tests/unit/adapters/test_duckdb/test_driver_init.py @@ -0,0 +1,58 @@ +"""Regression tests for DuckDBDriver statement config setup.""" + +from typing import Any +from unittest.mock import patch + +import duckdb + +from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.adapters.duckdb.core import apply_driver_features, default_statement_config +from sqlspec.adapters.duckdb.driver import DuckDBDriver + + +def test_apply_driver_features_called_once_when_statement_config_is_none() -> None: + connection = duckdb.connect(database=":memory:") + driver_features: dict[str, Any] = {"json_serializer": lambda value: "[]"} + + with patch("sqlspec.adapters.duckdb.driver.apply_driver_features", wraps=apply_driver_features) as mock_apply: + DuckDBDriver(connection=connection, statement_config=None, driver_features=driver_features) + + assert mock_apply.call_count == 1 + + +def test_apply_driver_features_not_called_when_statement_config_provided() -> None: + connection = duckdb.connect(database=":memory:") + driver_features: dict[str, Any] = {"json_serializer": lambda value: "[]"} + + with patch("sqlspec.adapters.duckdb.driver.apply_driver_features", wraps=apply_driver_features) as mock_apply: + DuckDBDriver(connection=connection, statement_config=default_statement_config, driver_features=driver_features) + + assert mock_apply.call_count == 0 + + +def test_custom_json_serializer_identity_preserved_through_session() -> None: + def custom_serializer(_: object) -> str: + return "custom" + + config = DuckDBConfig( + connection_config={"database": ":memory:"}, driver_features={"json_serializer": custom_serializer} + ) + + with config.provide_session() as driver: + serializer_in_driver = driver.statement_config.parameter_config.json_serializer + + assert serializer_in_driver is custom_serializer + + +def test_apply_driver_features_not_called_per_session_via_config() -> None: + def custom_serializer(_: object) -> str: + return "custom" + + with patch("sqlspec.adapters.duckdb.driver.apply_driver_features", wraps=apply_driver_features) as mock_apply: + config = DuckDBConfig( + connection_config={"database": ":memory:"}, driver_features={"json_serializer": custom_serializer} + ) + with config.provide_session(): + pass + + assert mock_apply.call_count == 0 diff --git a/tests/unit/adapters/test_duckdb/test_pool.py b/tests/unit/adapters/test_duckdb/test_pool.py new file mode 100644 index 000000000..49944bc5c --- /dev/null +++ b/tests/unit/adapters/test_duckdb/test_pool.py @@ -0,0 +1,46 @@ +"""Unit tests for DuckDB connection pool helpers.""" + +import pytest + +pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package") + +from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool, _validate_sql_identifier + + +@pytest.mark.parametrize("identifier", ["my_openai_secret", "openai", "S3", "s3", "r2", "secret_1"]) +def test_validate_sql_identifier_accepts_safe_identifiers(identifier: str) -> None: + _validate_sql_identifier(identifier, "secret_name") + + +@pytest.mark.parametrize("identifier", ["evil; DROP TABLE secrets--", "bad name", "S3); DROP TABLE--", "1bad", ""]) +def test_validate_sql_identifier_rejects_unsafe_identifiers(identifier: str) -> None: + with pytest.raises(ValueError, match="secret_name"): + _validate_sql_identifier(identifier, "secret_name") + + +def test_create_connection_raises_for_malicious_secret_name() -> None: + pool = DuckDBConnectionPool( + connection_config={"database": ":memory:"}, + secrets=[ + {"name": "evil; DROP TABLE secrets--", "secret_type": "s3", "value": {"key_id": "abc", "secret": "xyz"}} + ], + ) + + with pytest.raises(ValueError, match="secret_name"): + pool._create_connection() + + +def test_create_connection_raises_for_malicious_secret_type() -> None: + pool = DuckDBConnectionPool( + connection_config={"database": ":memory:"}, + secrets=[ + { + "name": "safe_secret", + "secret_type": "S3); DROP TABLE secrets--", + "value": {"key_id": "abc", "secret": "xyz"}, + } + ], + ) + + with pytest.raises(ValueError, match="secret_type"): + pool._create_connection() diff --git a/tests/unit/adapters/test_duckdb/test_pool_memory_leak.py b/tests/unit/adapters/test_duckdb/test_pool_memory_leak.py new file mode 100644 index 000000000..d50a8844c --- /dev/null +++ b/tests/unit/adapters/test_duckdb/test_pool_memory_leak.py @@ -0,0 +1,37 @@ +"""Regression tests for DuckDB pool write-only attribute removal.""" + +import pytest + +pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package") + +from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool + + +def test_pool_has_no_connection_times_attribute() -> None: + pool = DuckDBConnectionPool({"database": ":memory:"}) + + assert not hasattr(pool, "_connection_times") + + +def test_pool_has_no_created_connections_attribute() -> None: + pool = DuckDBConnectionPool({"database": ":memory:"}) + + assert not hasattr(pool, "_created_connections") + + +def test_pool_slots_do_not_contain_removed_attrs() -> None: + slots = DuckDBConnectionPool.__slots__ + + assert "_connection_times" not in slots + assert "_created_connections" not in slots + + +def test_pool_creates_connection_after_attribute_removal() -> None: + pool = DuckDBConnectionPool({"database": ":memory:"}) + try: + conn = pool.acquire() + row = conn.execute("SELECT 42").fetchone() + finally: + pool.close() + + assert row == (42,) diff --git a/tests/unit/adapters/test_mssql_python/test_config.py b/tests/unit/adapters/test_mssql_python/test_config.py index 062062878..fd0cad017 100644 --- a/tests/unit/adapters/test_mssql_python/test_config.py +++ b/tests/unit/adapters/test_mssql_python/test_config.py @@ -1,9 +1,11 @@ """Unit tests for mssql_python adapter configuration.""" +import warnings from typing import Any import pytest +import sqlspec.adapters.mssql_python.config as _mssql_config from sqlspec.adapters.mssql_python.config import MssqlPythonConfig, MssqlPythonConnectionPool @@ -21,6 +23,7 @@ def test_connection_pool_configures_mssql_python_pooling(monkeypatch: pytest.Mon """The SQLSpec pool wrapper should configure driver-level pooling before connect.""" calls: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] connection = DummyConnection() + monkeypatch.setattr(_mssql_config, "_POOLING_PARAMS", None) def fake_pooling(*args: Any, **kwargs: Any) -> None: calls.append(("pooling", args, kwargs)) @@ -48,6 +51,7 @@ def fake_connect(connection_string: str, **kwargs: Any) -> DummyConnection: def test_config_create_pool_splits_connection_and_pool_options(monkeypatch: pytest.MonkeyPatch) -> None: """MssqlPythonConfig should pass connection options and pool options to the right APIs.""" pooling_calls: list[dict[str, Any]] = [] + monkeypatch.setattr(_mssql_config, "_POOLING_PARAMS", None) def fake_pooling(**kwargs: Any) -> None: pooling_calls.append(kwargs) @@ -75,6 +79,7 @@ def test_config_connection_hook_runs_for_session_connections(monkeypatch: pytest """on_connection_create should run for connections acquired by provide_session().""" connection = DummyConnection() seen: list[DummyConnection] = [] + monkeypatch.setattr(_mssql_config, "_POOLING_PARAMS", None) monkeypatch.setattr("sqlspec.adapters.mssql_python.config.MSSQL_PYTHON_MODULE.pooling", lambda **_: None) monkeypatch.setattr( @@ -89,3 +94,41 @@ def test_config_connection_hook_runs_for_session_connections(monkeypatch: pytest pass assert seen == [connection] + + +def test_second_pool_warns_on_different_params(monkeypatch: pytest.MonkeyPatch) -> None: + """A second pool with different process-wide pooling params emits one warning.""" + monkeypatch.setattr(_mssql_config, "_POOLING_PARAMS", None) + monkeypatch.setattr("sqlspec.adapters.mssql_python.config.MSSQL_PYTHON_MODULE.pooling", lambda **_: None) + monkeypatch.setattr( + "sqlspec.adapters.mssql_python.config.MSSQL_PYTHON_MODULE.connect", lambda *_args, **_kwargs: DummyConnection() + ) + + MssqlPythonConnectionPool(connection_string="Server=srv1;", max_size=10, idle_timeout=60, enabled=True) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + MssqlPythonConnectionPool(connection_string="Server=srv2;", max_size=20, idle_timeout=60, enabled=True) + + assert len(caught) == 1 + message = str(caught[0].message) + assert "overwriting" in message + assert "(10, 60, True)" in message + assert "(20, 60, True)" in message + + +def test_second_pool_same_params_no_warn(monkeypatch: pytest.MonkeyPatch) -> None: + """A second pool with identical process-wide pooling params emits no warning.""" + monkeypatch.setattr(_mssql_config, "_POOLING_PARAMS", None) + monkeypatch.setattr("sqlspec.adapters.mssql_python.config.MSSQL_PYTHON_MODULE.pooling", lambda **_: None) + monkeypatch.setattr( + "sqlspec.adapters.mssql_python.config.MSSQL_PYTHON_MODULE.connect", lambda *_args, **_kwargs: DummyConnection() + ) + + MssqlPythonConnectionPool(connection_string="Server=srv1;", max_size=10, idle_timeout=60, enabled=True) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + MssqlPythonConnectionPool(connection_string="Server=srv2;", max_size=10, idle_timeout=60, enabled=True) + + assert len(caught) == 0 diff --git a/tests/unit/adapters/test_mssql_python/test_core.py b/tests/unit/adapters/test_mssql_python/test_core.py index e1e12791b..093cf6483 100644 --- a/tests/unit/adapters/test_mssql_python/test_core.py +++ b/tests/unit/adapters/test_mssql_python/test_core.py @@ -38,6 +38,24 @@ def test_build_connection_config_from_parts_formats_odbc_options() -> None: assert kwargs == {"autocommit": True} +def test_build_connection_config_no_duplicate_uid() -> None: + """Passing both 'uid' and 'user' produces exactly one UID= option.""" + connection_string, _ = build_connection_config({"server": "srv", "uid": "user1", "user": "user2"}) + + assert connection_string.count("UID=") == 1 + assert "UID=user1" in connection_string + assert "user=user2" not in connection_string + + +def test_build_connection_config_no_duplicate_pwd() -> None: + """Passing both 'pwd' and 'password' produces exactly one PWD= option.""" + connection_string, _ = build_connection_config({"server": "srv", "pwd": "secret1", "password": "secret2"}) + + assert connection_string.count("PWD=") == 1 + assert "PWD=secret1" in connection_string + assert "password=secret2" not in connection_string + + def test_create_mapped_exception_extracts_sql_server_error_number() -> None: """SQL Server native error numbers should map to specific SQLSpec exceptions.""" exc = MSSQL_PYTHON_MODULE.IntegrityError( diff --git a/tests/unit/adapters/test_mysql_execute_script.py b/tests/unit/adapters/test_mysql_execute_script.py new file mode 100644 index 000000000..d6ac61ae7 --- /dev/null +++ b/tests/unit/adapters/test_mysql_execute_script.py @@ -0,0 +1,109 @@ +"""Unit tests for MySQL-family execute_script parameter handling.""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from sqlspec.adapters.aiomysql.driver import AiomysqlDriver +from sqlspec.adapters.asyncmy.driver import AsyncmyDriver +from sqlspec.adapters.mysqlconnector.driver import MysqlConnectorAsyncDriver, MysqlConnectorSyncDriver +from sqlspec.adapters.pymysql.driver import PyMysqlDriver +from sqlspec.exceptions import SQLSpecError + +_ERROR_MESSAGE = "execute_script with parameters is not supported for multi-statement scripts" +_MULTI_STMT_SQL = "SELECT 1; SELECT 2" + + +def _make_statement(driver: Any) -> MagicMock: + statement = MagicMock() + statement.statement_config = driver.statement_config + return statement + + +def _make_sync_cursor() -> MagicMock: + cursor = MagicMock() + cursor.rowcount = 1 + cursor.description = None + cursor.lastrowid = None + return cursor + + +def _make_async_cursor() -> AsyncMock: + cursor = AsyncMock() + cursor.rowcount = 1 + cursor.description = None + cursor.lastrowid = None + return cursor + + +@pytest.mark.parametrize("driver_factory", [PyMysqlDriver, MysqlConnectorSyncDriver]) +def test_sync_mysql_execute_script_multi_statement_with_params_raises(driver_factory: type[Any]) -> None: + driver = driver_factory(connection=MagicMock()) + statement = _make_statement(driver) + + with patch.object(driver_factory, "_get_compiled_sql", return_value=(_MULTI_STMT_SQL, (42,))): + with pytest.raises(SQLSpecError, match=_ERROR_MESSAGE): + driver.dispatch_execute_script(_make_sync_cursor(), statement) + + +@pytest.mark.parametrize("driver_factory", [PyMysqlDriver, MysqlConnectorSyncDriver]) +def test_sync_mysql_execute_script_single_statement_with_params_executes(driver_factory: type[Any]) -> None: + driver = driver_factory(connection=MagicMock()) + statement = _make_statement(driver) + cursor = _make_sync_cursor() + + with patch.object(driver_factory, "_get_compiled_sql", return_value=("SELECT 1", (42,))): + result = driver.dispatch_execute_script(cursor, statement) + + assert result.statement_count == 1 + cursor.execute.assert_called_once() + + +@pytest.mark.parametrize("driver_factory", [PyMysqlDriver, MysqlConnectorSyncDriver]) +def test_sync_mysql_execute_script_multi_statement_without_params_executes(driver_factory: type[Any]) -> None: + driver = driver_factory(connection=MagicMock()) + statement = _make_statement(driver) + cursor = _make_sync_cursor() + + with patch.object(driver_factory, "_get_compiled_sql", return_value=(_MULTI_STMT_SQL, None)): + result = driver.dispatch_execute_script(cursor, statement) + + assert result.statement_count == 2 + assert result.successful_statements == 2 + + +@pytest.mark.parametrize("driver_factory", [AiomysqlDriver, AsyncmyDriver, MysqlConnectorAsyncDriver]) +async def test_async_mysql_execute_script_multi_statement_with_params_raises(driver_factory: type[Any]) -> None: + driver = driver_factory(connection=AsyncMock()) + statement = _make_statement(driver) + + with patch.object(driver_factory, "_get_compiled_sql", return_value=(_MULTI_STMT_SQL, (42,))): + with pytest.raises(SQLSpecError, match=_ERROR_MESSAGE): + await driver.dispatch_execute_script(_make_async_cursor(), statement) + + +@pytest.mark.parametrize("driver_factory", [AiomysqlDriver, AsyncmyDriver, MysqlConnectorAsyncDriver]) +async def test_async_mysql_execute_script_single_statement_with_params_executes(driver_factory: type[Any]) -> None: + driver = driver_factory(connection=AsyncMock()) + statement = _make_statement(driver) + cursor = _make_async_cursor() + + with patch.object(driver_factory, "_get_compiled_sql", return_value=("SELECT 1", (42,))): + result = await driver.dispatch_execute_script(cursor, statement) + + assert result.statement_count == 1 + cursor.execute.assert_awaited_once() + + +@pytest.mark.parametrize("driver_factory", [AiomysqlDriver, AsyncmyDriver, MysqlConnectorAsyncDriver]) +async def test_async_mysql_execute_script_multi_statement_without_params_executes(driver_factory: type[Any]) -> None: + driver = driver_factory(connection=AsyncMock()) + statement = _make_statement(driver) + cursor = _make_async_cursor() + + with patch.object(driver_factory, "_get_compiled_sql", return_value=(_MULTI_STMT_SQL, None)): + result = await driver.dispatch_execute_script(cursor, statement) + + assert result.statement_count == 2 + assert result.successful_statements == 2 diff --git a/tests/unit/adapters/test_oracledb/test_dispatch_execute_force_select.py b/tests/unit/adapters/test_oracledb/test_dispatch_execute_force_select.py new file mode 100644 index 000000000..3545977f3 --- /dev/null +++ b/tests/unit/adapters/test_oracledb/test_dispatch_execute_force_select.py @@ -0,0 +1,78 @@ +"""Unit tests for OracleSyncDriver.dispatch_execute force-select parity.""" + +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip("oracledb") + +from sqlspec.adapters.oracledb.driver import OracleSyncDriver + + +def _make_sync_driver() -> OracleSyncDriver: + connection = MagicMock() + connection.in_transaction = False + return OracleSyncDriver(connection=connection) + + +def test_sync_dispatch_execute_force_select_fallback() -> None: + """The sync driver must fetch rows when cursor metadata proves rows exist.""" + driver = _make_sync_driver() + statement = MagicMock() + statement.returns_rows.return_value = False + statement.operation_type = "COMMAND" + + cursor = MagicMock() + cursor.description = [("id", None, None, None, None, None, None)] + cursor.fetchall.return_value = [(42,)] + del cursor.statement_type + + with ( + patch.object(OracleSyncDriver, "_get_compiled_sql", return_value=("SELECT 1 FROM DUAL", {})), + patch("sqlspec.adapters.oracledb.driver.coerce_large_parameters_sync", return_value={}), + patch.object(OracleSyncDriver, "_resolve_row_metadata", return_value=(["id"], False)), + patch("sqlspec.adapters.oracledb.driver.collect_sync_rows", return_value=([(42,)], ["id"])), + ): + result = driver.dispatch_execute(cursor, statement) + + assert result.is_select_result is True + assert result.data_row_count == 1 + cursor.fetchall.assert_called_once() + + +def test_sync_dispatch_execute_no_force_select_when_returns_rows_false_and_no_description() -> None: + """The sync driver must keep DML behavior when no row metadata exists.""" + driver = _make_sync_driver() + statement = MagicMock() + statement.returns_rows.return_value = False + statement.operation_type = "COMMAND" + + cursor = MagicMock() + cursor.description = None + cursor.rowcount = 3 + del cursor.statement_type + + with ( + patch.object(OracleSyncDriver, "_get_compiled_sql", return_value=("DELETE FROM t", {})), + patch("sqlspec.adapters.oracledb.driver.coerce_large_parameters_sync", return_value={}), + patch("sqlspec.adapters.oracledb.driver.resolve_rowcount", return_value=3), + ): + result = driver.dispatch_execute(cursor, statement) + + assert result.is_select_result is False + assert result.rowcount_override == 3 + cursor.fetchall.assert_not_called() + + +def test_sync_and_async_dispatch_execute_both_call_should_force_select() -> None: + """Both Oracle dispatch implementations should use the force-select fallback.""" + import inspect + + from sqlspec.adapters.oracledb.driver import OracleAsyncDriver + + sync_src = inspect.getsource(OracleSyncDriver.dispatch_execute) + async_src = inspect.getsource(OracleAsyncDriver.dispatch_execute) + + assert "_should_force_select" in sync_src + assert "_should_force_select" in async_src + assert "is_select_like" in sync_src diff --git a/tests/unit/adapters/test_oracledb/test_vector_handlers.py b/tests/unit/adapters/test_oracledb/test_vector_handlers.py index f21437a4d..d53c5d5b4 100644 --- a/tests/unit/adapters/test_oracledb/test_vector_handlers.py +++ b/tests/unit/adapters/test_oracledb/test_vector_handlers.py @@ -55,7 +55,7 @@ def test_input_handler_claims_list_of_float() -> None: kwargs = cursor.var.call_args.kwargs assert kwargs["arraysize"] == 1 assert callable(kwargs["inconverter"]) - packed = kwargs["inconverter"](None) + packed = kwargs["inconverter"]([1.0, 2.0, 3.0]) assert isinstance(packed, array.array) assert packed.typecode == "f" assert list(packed) == [1.0, 2.0, 3.0] @@ -71,7 +71,7 @@ def test_input_handler_claims_tuple_of_float() -> None: result = _input_type_handler(cursor, (0.5, 0.25, 0.125), 1) assert result == "var-token" - packed = cursor.var.call_args.kwargs["inconverter"](None) + packed = cursor.var.call_args.kwargs["inconverter"]((0.5, 0.25, 0.125)) assert packed.typecode == "f" assert list(packed) == [0.5, 0.25, 0.125] @@ -85,7 +85,7 @@ def test_input_handler_packs_int_in_int8_range_as_int8() -> None: _input_type_handler(cursor, [-128, 0, 64, 127], 1) - packed = cursor.var.call_args.kwargs["inconverter"](None) + packed = cursor.var.call_args.kwargs["inconverter"]([-128, 0, 64, 127]) assert packed.typecode == "b" assert list(packed) == [-128, 0, 64, 127] @@ -99,7 +99,7 @@ def test_input_handler_packs_int_outside_int8_range_as_float32() -> None: _input_type_handler(cursor, [0, 200], 1) - packed = cursor.var.call_args.kwargs["inconverter"](None) + packed = cursor.var.call_args.kwargs["inconverter"]([0, 200]) assert packed.typecode == "f" assert list(packed) == [0.0, 200.0] @@ -294,3 +294,71 @@ def test_output_handler_default_format_when_attr_missing() -> None: assert result == "var-token" cursor.var.assert_called_once() + + +def test_vector_executemany_each_row_packed() -> None: + """Each executemany row must receive its own packed vector array.""" + from sqlspec.adapters.oracledb._vector_handlers import ( + _input_type_handler, # pyright: ignore[reportPrivateUsage] + _pack_sequence_converter, # pyright: ignore[reportPrivateUsage] + ) + + rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + cursor = _mock_cursor() + cursor.var.return_value = "var-token" + + result = _input_type_handler(cursor, rows[0], 1) + + assert result == "var-token" + inconverter = cursor.var.call_args.kwargs["inconverter"] + assert inconverter is _pack_sequence_converter + + packed_results = [inconverter(row) for row in rows] + + assert len(packed_results) == 3 + for row, packed in zip(rows, packed_results): + assert isinstance(packed, array.array) + assert packed.typecode == "f" + assert list(packed) == [float(value) for value in row] + assert packed_results[0] is not packed_results[1] + assert packed_results[1] is not packed_results[2] + assert packed_results[0] is not packed_results[2] + + +def test_pack_sequence_converter_is_module_level_named_function() -> None: + """The vector inconverter must be a named module-level function.""" + import inspect + + from sqlspec.adapters.oracledb._vector_handlers import ( + _pack_sequence_converter, # pyright: ignore[reportPrivateUsage] + ) + + assert inspect.isfunction(_pack_sequence_converter) + assert _pack_sequence_converter.__name__ == "_pack_sequence_converter" + assert "" not in repr(_pack_sequence_converter) + + +def test_pack_sequence_converter_float_list() -> None: + """_pack_sequence_converter packs float lists as float32.""" + from sqlspec.adapters.oracledb._vector_handlers import ( + _pack_sequence_converter, # pyright: ignore[reportPrivateUsage] + ) + + result = _pack_sequence_converter([1.5, 2.5, 3.5]) + + assert isinstance(result, array.array) + assert result.typecode == "f" + assert list(result) == [1.5, 2.5, 3.5] + + +def test_pack_sequence_converter_int8_range() -> None: + """_pack_sequence_converter packs small int lists as int8.""" + from sqlspec.adapters.oracledb._vector_handlers import ( + _pack_sequence_converter, # pyright: ignore[reportPrivateUsage] + ) + + result = _pack_sequence_converter([-128, 0, 127]) + + assert isinstance(result, array.array) + assert result.typecode == "b" + assert list(result) == [-128, 0, 127] diff --git a/tests/unit/adapters/test_psqlpy/test_driver.py b/tests/unit/adapters/test_psqlpy/test_driver.py new file mode 100644 index 000000000..27d391979 --- /dev/null +++ b/tests/unit/adapters/test_psqlpy/test_driver.py @@ -0,0 +1,60 @@ +"""Psqlpy driver regression tests.""" + +from types import SimpleNamespace +from typing import Any + +import pytest + +from sqlspec.adapters.psqlpy.core import default_statement_config +from sqlspec.adapters.psqlpy.driver import PsqlpyDriver +from sqlspec.exceptions import SQLSpecError + + +class _Cursor: + def __init__(self) -> None: + self.execute_calls: list[tuple[Any, ...]] = [] + + async def execute(self, *args: Any) -> str: + self.execute_calls.append(args) + return "OK" + + +class _Driver(PsqlpyDriver): + def __init__(self, compiled_sql: str, parameters: object = None) -> None: + super().__init__(connection=object()) + self.compiled_sql = compiled_sql + self.compiled_parameters = parameters + + def _get_compiled_sql(self, *_args: object, **_kwargs: object) -> tuple[str, object]: + return self.compiled_sql, self.compiled_parameters + + +@pytest.mark.anyio +async def test_psqlpy_execute_script_rejects_multi_statement_parameters() -> None: + driver = _Driver("INSERT INTO t VALUES ($1); INSERT INTO t VALUES ($1)", [1]) + statement = SimpleNamespace(statement_config=default_statement_config) + + with pytest.raises(SQLSpecError, match="multi-statement"): + await driver.dispatch_execute_script(_Cursor(), statement) + + +@pytest.mark.anyio +async def test_psqlpy_execute_script_uses_empty_params_for_each_sub_statement() -> None: + driver = _Driver("SELECT 1; SELECT 2") + statement = SimpleNamespace(statement_config=default_statement_config) + cursor = _Cursor() + + await driver.dispatch_execute_script(cursor, statement) + + assert cursor.execute_calls == [("SELECT 1", []), ("SELECT 2", [])] + + +@pytest.mark.anyio +async def test_psqlpy_execute_script_allows_single_statement_parameters_with_empty_driver_params() -> None: + driver = _Driver("INSERT INTO t VALUES ($1)", [1]) + statement = SimpleNamespace(statement_config=default_statement_config) + cursor = _Cursor() + + await driver.dispatch_execute_script(cursor, statement) + + assert cursor.execute_calls == [("INSERT INTO t VALUES ($1)", [])] diff --git a/tests/unit/adapters/test_psycopg/test_driver.py b/tests/unit/adapters/test_psycopg/test_driver.py new file mode 100644 index 000000000..54baecbb7 --- /dev/null +++ b/tests/unit/adapters/test_psycopg/test_driver.py @@ -0,0 +1,152 @@ +"""Psycopg driver regression tests.""" + +from collections.abc import Iterator +from types import SimpleNamespace +from typing import Any + +import pytest +from typing_extensions import Self + +from sqlspec.adapters.psycopg.core import default_statement_config +from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver +from sqlspec.exceptions import SQLSpecError + + +class _SyncCopyContext: + def __init__(self, rows: list[bytes] | None = None) -> None: + self.rows = rows or [] + self.writes: list[bytes] = [] + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_args: object) -> None: + return None + + def __iter__(self) -> Iterator[bytes]: + return iter(self.rows) + + def write(self, data: bytes) -> None: + self.writes.append(data) + + +class _SyncCursor: + def __init__(self) -> None: + self.rowcount = 0 + self.description = None + self.copy_calls: list[str] = [] + self.execute_calls: list[tuple[Any, ...]] = [] + self.copy_context = _SyncCopyContext([b"exported"]) + + def copy(self, sql: str) -> _SyncCopyContext: + self.copy_calls.append(sql) + return self.copy_context + + def execute(self, *args: Any) -> None: + self.execute_calls.append(args) + + +class _AsyncCopyContext: + def __init__(self, rows: list[bytes] | None = None) -> None: + self.rows = rows or [] + self.writes: list[bytes] = [] + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *_args: object) -> None: + return None + + def __aiter__(self) -> "_AsyncCopyContext": + self._iter = iter(self.rows) + return self + + async def __anext__(self) -> bytes: + try: + return next(self._iter) + except StopIteration as exc: + raise StopAsyncIteration from exc + + async def write(self, data: bytes) -> None: + self.writes.append(data) + + +class _AsyncCursor: + def __init__(self) -> None: + self.rowcount = 0 + self.description = None + self.copy_calls: list[str] = [] + self.execute_calls: list[tuple[Any, ...]] = [] + self.copy_context = _AsyncCopyContext([b"exported"]) + + def copy(self, sql: str) -> _AsyncCopyContext: + self.copy_calls.append(sql) + return self.copy_context + + async def execute(self, *args: Any) -> None: + self.execute_calls.append(args) + + +class _SyncDriver(PsycopgSyncDriver): + def __init__(self, compiled_sql: str, parameters: object = None) -> None: + super().__init__(connection=object()) + self.compiled_sql = compiled_sql + self.compiled_parameters = parameters + + def _get_compiled_sql(self, *_args: object, **_kwargs: object) -> tuple[str, object]: + return self.compiled_sql, self.compiled_parameters + + +class _AsyncDriver(PsycopgAsyncDriver): + def __init__(self, compiled_sql: str, parameters: object = None) -> None: + super().__init__(connection=object()) + self.compiled_sql = compiled_sql + self.compiled_parameters = parameters + + def _get_compiled_sql(self, *_args: object, **_kwargs: object) -> tuple[str, object]: + return self.compiled_sql, self.compiled_parameters + + +def test_psycopg_sync_execute_script_rejects_multi_statement_parameters() -> None: + driver = _SyncDriver("INSERT INTO t VALUES (%s); INSERT INTO t VALUES (%s)", [1]) + statement = SimpleNamespace(statement_config=default_statement_config) + + with pytest.raises(SQLSpecError, match="multi-statement"): + driver.dispatch_execute_script(_SyncCursor(), statement) + + +@pytest.mark.anyio +async def test_psycopg_async_execute_script_rejects_multi_statement_parameters() -> None: + driver = _AsyncDriver("INSERT INTO t VALUES (%s); INSERT INTO t VALUES (%s)", [1]) + statement = SimpleNamespace(statement_config=default_statement_config) + + with pytest.raises(SQLSpecError, match="multi-statement"): + await driver.dispatch_execute_script(_AsyncCursor(), statement) + + +@pytest.mark.anyio +async def test_psycopg_async_copy_from_uses_copy_for_program_variant() -> None: + driver = _AsyncDriver("COPY users FROM PROGRAM 'cat data.csv'") + statement = SimpleNamespace( + operation_type="COPY_FROM", parameters="payload", statement_config=default_statement_config + ) + cursor = _AsyncCursor() + + await driver.dispatch_special_handling(cursor, statement) + + assert cursor.copy_calls == ["COPY users FROM PROGRAM 'cat data.csv'"] + assert cursor.execute_calls == [] + + +@pytest.mark.anyio +async def test_psycopg_async_copy_to_uses_copy_for_file_variant() -> None: + driver = _AsyncDriver("COPY users TO '/tmp/users.csv'") + statement = SimpleNamespace(operation_type="COPY_TO", parameters=None, statement_config=default_statement_config) + cursor = _AsyncCursor() + + result = await driver.dispatch_special_handling(cursor, statement) + + assert cursor.copy_calls == ["COPY users TO '/tmp/users.csv'"] + assert cursor.execute_calls == [] + assert result is not None + assert result.data == [{"copy_output": "exported"}] diff --git a/tests/unit/adapters/test_spanner/test_config.py b/tests/unit/adapters/test_spanner/test_config.py index 761253266..2226a1f26 100644 --- a/tests/unit/adapters/test_spanner/test_config.py +++ b/tests/unit/adapters/test_spanner/test_config.py @@ -208,3 +208,32 @@ def snapshot(self, multi_use: bool = False): with config.provide_write_session() as driver: assert isinstance(driver.connection, _Txn) + + +def test_create_connection_delegates_to_get_database() -> None: + """create_connection should use get_database rather than rebuilding Database.""" + config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) + sentinel = object() + get_database_call_count = 0 + + class _DB: + def snapshot(self): + return sentinel + + def _get_database() -> _DB: + nonlocal get_database_call_count + get_database_call_count += 1 + return _DB() + + def _unexpected_get_client() -> object: + raise AssertionError("create_connection should delegate to get_database") + + config.get_database = _get_database # type: ignore[assignment] + config._get_client = _unexpected_get_client # type: ignore[method-assign] + config.connection_instance = object() # type: ignore[assignment] + + assert config.create_connection() is sentinel + assert get_database_call_count == 1 + + assert config.create_connection() is sentinel + assert get_database_call_count == 2 diff --git a/tests/unit/adapters/test_spanner/test_driver.py b/tests/unit/adapters/test_spanner/test_driver.py index cbc92cf75..d2cb033a9 100644 --- a/tests/unit/adapters/test_spanner/test_driver.py +++ b/tests/unit/adapters/test_spanner/test_driver.py @@ -1,3 +1,4 @@ +import inspect from unittest.mock import MagicMock, Mock import pyarrow as pa @@ -141,3 +142,74 @@ def _mock_infer_param_types(params: dict[str, object] | list[object] | tuple[obj assert result.telemetry["rows_processed"] == 2 assert infer_call_count == 1 mock_transaction.batch_update.assert_called_once() + + +def test_dispatch_execute_script_cte_select_detected_as_select(mock_transaction: MagicMock) -> None: + driver = SpannerSyncDriver(mock_transaction) + mock_result = MagicMock() + mock_result.__iter__ = MagicMock(return_value=iter([])) + mock_transaction.execute_sql = MagicMock(return_value=mock_result) + mock_transaction.execute_sql.return_value = mock_result + mock_transaction.execute_update.side_effect = AssertionError("execute_update called for SELECT CTE") + + statement = driver.prepare_statement( + "WITH cte AS (SELECT id FROM users) SELECT * FROM cte", statement_config=driver.statement_config + ) + driver.dispatch_execute_script(mock_transaction, statement) + + mock_transaction.execute_sql.assert_called_once() + mock_transaction.execute_update.assert_not_called() + + +def test_dispatch_execute_script_plain_select_detected_as_select(mock_transaction: MagicMock) -> None: + driver = SpannerSyncDriver(mock_transaction) + mock_result = MagicMock() + mock_result.__iter__ = MagicMock(return_value=iter([])) + mock_transaction.execute_sql = MagicMock(return_value=mock_result) + mock_transaction.execute_sql.return_value = mock_result + mock_transaction.execute_update.side_effect = AssertionError("execute_update should not be called for SELECT") + + statement = driver.prepare_statement("SELECT 1", statement_config=driver.statement_config) + driver.dispatch_execute_script(mock_transaction, statement) + + mock_transaction.execute_sql.assert_called_once() + mock_transaction.execute_update.assert_not_called() + + +def test_dispatch_execute_script_insert_detected_as_non_select(mock_transaction: MagicMock) -> None: + driver = SpannerSyncDriver(mock_transaction) + mock_transaction.execute_sql = MagicMock() + mock_transaction.execute_update.return_value = 1 + + statement = driver.prepare_statement( + "INSERT INTO users (id, name) VALUES (1, 'Alice')", statement_config=driver.statement_config + ) + driver.dispatch_execute_script(mock_transaction, statement) + + mock_transaction.execute_update.assert_called_once() + mock_transaction.execute_sql.assert_not_called() + + +def test_dispatch_execute_many_no_shadowed_local_aliases() -> None: + source = inspect.getsource(SpannerSyncDriver.dispatch_execute_many) + + assert "coerce_params = self._coerce_params" not in source + assert "infer_param_types = self._infer_param_types" not in source + assert "_coerce = self._coerce_params" in source + assert "_infer = self._infer_param_types" in source + assert "_coerce(cast" in source + assert "_infer(coerced_params)" in source + + +def test_module_level_coerce_params_not_shadowed_by_local_alias() -> None: + import sqlspec.adapters.spanner.driver as driver_module + from sqlspec.adapters.spanner.core import coerce_params as core_coerce_params + + assert driver_module.coerce_params is core_coerce_params + + +def test_module_level_infer_param_types_not_shadowed_by_local_alias() -> None: + import sqlspec.adapters.spanner.driver as driver_module + from sqlspec.adapters.spanner.core import infer_param_types as core_infer_param_types + + assert driver_module.infer_param_types is core_infer_param_types diff --git a/tests/unit/adapters/test_sqlite/test_pool_no_duplicate_typedef.py b/tests/unit/adapters/test_sqlite/test_pool_no_duplicate_typedef.py new file mode 100644 index 000000000..ccd673af6 --- /dev/null +++ b/tests/unit/adapters/test_sqlite/test_pool_no_duplicate_typedef.py @@ -0,0 +1,44 @@ +"""Regression tests for the sqlite pool module public surface.""" + + +def test_sqlite_connection_params_not_exported_from_pool() -> None: + import sqlspec.adapters.sqlite.pool as pool_mod + + assert not hasattr(pool_mod, "SqliteConnectionParams") + + +def test_sqlite_connection_pool_still_importable() -> None: + from sqlspec.adapters.sqlite.pool import SqliteConnectionPool + + assert SqliteConnectionPool is not None + + +def test_pool_module_all_unchanged() -> None: + import sqlspec.adapters.sqlite.pool as pool_mod + + assert pool_mod.__all__ == ("SqliteConnectionPool",) + + +def test_canonical_typedef_still_importable_from_config() -> None: + from sqlspec.adapters.sqlite.config import SqliteConnectionParams + + assert hasattr(SqliteConnectionParams, "__annotations__") or hasattr(SqliteConnectionParams, "__required_keys__") + + +def test_canonical_typedef_importable_from_package() -> None: + from sqlspec.adapters.sqlite import SqliteConnectionParams + + assert SqliteConnectionParams is not None + + +def test_pool_creates_connection_after_cleanup() -> None: + from sqlspec.adapters.sqlite.pool import SqliteConnectionPool + + pool = SqliteConnectionPool(connection_parameters={"database": ":memory:"}) + conn = pool.acquire() + cursor = conn.execute("SELECT 1 AS n") + row = cursor.fetchone() + pool.close() + + assert row is not None + assert row[0] == 1 diff --git a/tests/unit/adapters/test_sync_adapters.py b/tests/unit/adapters/test_sync_adapters.py index 6ec74f5f0..56cdb2bf6 100644 --- a/tests/unit/adapters/test_sync_adapters.py +++ b/tests/unit/adapters/test_sync_adapters.py @@ -508,6 +508,33 @@ def test_sync_driver_special_handling_integration(sqlite_sync_driver: SqliteDriv mock_special.assert_called_once() +def test_sync_driver_special_handling_takes_priority_over_script_in_fast_path(sqlite_sync_driver: SqliteDriver) -> None: + assert sqlite_sync_driver._observability is None or sqlite_sync_driver._observability.is_idle + + statement = SQL( + "INSERT INTO users (name) VALUES ('alice'); INSERT INTO users (name) VALUES ('bob');", + statement_config=sqlite_sync_driver.statement_config, + is_script=True, + ) + sentinel = SQLResult( + statement=SQL("SELECT 1", statement_config=sqlite_sync_driver.statement_config), rows_affected=999 + ) + + def _special_handling_returning_sentinel(_cursor: Any, stmt: SQL) -> SQLResult | None: + if stmt.is_script: + return sentinel + return None + + with ( + patch.object(sqlite_sync_driver, "dispatch_special_handling", side_effect=_special_handling_returning_sentinel), + patch.object(sqlite_sync_driver, "dispatch_execute_script") as mock_execute_script, + ): + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) + + assert result is sentinel + mock_execute_script.assert_not_called() + + def test_sync_driver_error_handling_in_dispatch(sqlite_sync_driver: SqliteDriver) -> None: """Test error handling during statement dispatch.""" statement = SQL("SELECT * FROM users", statement_config=sqlite_sync_driver.statement_config) diff --git a/tests/unit/base/test_sqlspec_class.py b/tests/unit/base/test_sqlspec_class.py index 2f27c6871..1f3dfaceb 100644 --- a/tests/unit/base/test_sqlspec_class.py +++ b/tests/unit/base/test_sqlspec_class.py @@ -152,8 +152,8 @@ def test_get_cache_stats_returns_statistics() -> None: def test_reset_cache_stats_clears_statistics() -> None: - """Test that reset_cache_stats clears all cache statistics.""" - SQLSpec.reset_cache_stats() + """Test that reset_stats_only clears all cache statistics.""" + SQLSpec.reset_stats_only() stats = SQLSpec.get_cache_stats() multi_stats = stats["namespaced"] @@ -292,7 +292,7 @@ def stats_worker() -> None: try: for _ in range(50): stats = SQLSpec.get_cache_stats() - SQLSpec.reset_cache_stats() + SQLSpec.reset_stats_only() multi_stats = stats["namespaced"] total_ops = multi_stats.hits + multi_stats.misses results.append(total_ops) @@ -565,7 +565,7 @@ def test_statistics_collection_during_configuration_changes() -> None: multi_stats = stats["namespaced"] assert hasattr(multi_stats, "hit_rate") - SQLSpec.reset_cache_stats() + SQLSpec.reset_stats_only() finally: SQLSpec.update_cache_config(original_config) diff --git a/tests/unit/builder/test_alter_table.py b/tests/unit/builder/test_alter_table.py new file mode 100644 index 000000000..16b9b5f8e --- /dev/null +++ b/tests/unit/builder/test_alter_table.py @@ -0,0 +1,50 @@ +"""Tests for AlterTable public API completeness.""" + +import pytest + +from sqlspec.builder._ddl import AlterTable +from sqlspec.exceptions import SQLBuilderError + + +def test_alter_table_in_schema_qualifies_table_name() -> None: + sql = AlterTable("users").in_schema("myschema").add_column("x", "INT").build().sql + + assert "myschema" in sql + assert "users" in sql + + +def test_alter_table_in_schema_is_chainable() -> None: + builder = AlterTable("t") + + assert builder.in_schema("s") is builder + + +def test_alter_table_set_column_default_renders_set_default() -> None: + sql = AlterTable("t").set_column_default("x", 0).build().sql + + assert "SET DEFAULT" in sql + assert "x" in sql + + +def test_alter_table_drop_column_default_renders_drop_default() -> None: + sql = AlterTable("t").drop_column_default("x").build().sql + + assert "DROP DEFAULT" in sql + assert "x" in sql + + +def test_alter_table_default_methods_are_chainable() -> None: + builder = AlterTable("t") + + assert builder.set_column_default("x", 0) is builder + assert builder.drop_column_default("x") is builder + + +def test_alter_table_set_column_default_empty_column_raises() -> None: + with pytest.raises(SQLBuilderError): + AlterTable("t").set_column_default("", 0) + + +def test_alter_table_drop_column_default_empty_column_raises() -> None: + with pytest.raises(SQLBuilderError): + AlterTable("t").drop_column_default("") diff --git a/tests/unit/builder/test_ddl_builder.py b/tests/unit/builder/test_ddl_builder.py new file mode 100644 index 000000000..af6d7391e --- /dev/null +++ b/tests/unit/builder/test_ddl_builder.py @@ -0,0 +1,87 @@ +"""Regression tests for DDL builder Wave 1 fixes.""" + +import inspect + +import pytest + +from sqlspec.builder._ddl import ( + CONSTRAINT_TYPE_FOREIGN_KEY, + VALID_FOREIGN_KEY_ACTIONS, + ColumnDefinition, + ConstraintDefinition, + CreateIndex, + CreateTable, + build_column_expression, + build_constraint_expression, +) +from sqlspec.exceptions import SQLBuilderError + + +def test_auto_increment_column_renders_auto_increment() -> None: + expr = build_column_expression(ColumnDefinition("id", "INT", auto_increment=True)) + + assert "AUTO_INCREMENT" in expr.sql() + + +def test_create_index_using_renders_using_clause() -> None: + result = CreateIndex("idx").on_table("t").columns("a").using("BTREE").build() + + assert "USING" in result.sql + assert "BTREE" in result.sql + + +def test_create_index_using_without_columns_renders_using_clause() -> None: + result = CreateIndex("idx").on_table("t").using("HASH").build() + + assert "USING" in result.sql + assert "HASH" in result.sql + + +def test_foreign_key_deferrable_initially_deferred_renders_clause() -> None: + result = ( + CreateTable("t") + .column("user_id", "INT") + .foreign_key_constraint("user_id", "users", "id", deferrable=True, initially_deferred=True) + .build() + ) + + assert "DEFERRABLE" in result.sql + assert "INITIALLY DEFERRED" in result.sql + + +def test_foreign_key_deferrable_initially_immediate_renders_clause() -> None: + constraint = ConstraintDefinition( + constraint_type=CONSTRAINT_TYPE_FOREIGN_KEY, + columns=["user_id"], + references_table="users", + references_columns=["id"], + deferrable=True, + initially_deferred=False, + ) + expr = build_constraint_expression(constraint) + + assert expr is not None + sql = expr.sql() + assert "DEFERRABLE" in sql + assert "INITIALLY IMMEDIATE" in sql + + +def test_valid_foreign_key_actions_excludes_none() -> None: + assert None not in VALID_FOREIGN_KEY_ACTIONS + + +def test_foreign_key_action_none_still_short_circuits_validation() -> None: + table = CreateTable("orders") + + table.foreign_key_constraint("user_id", "users", "id", on_delete=None, on_update=None) + + +def test_foreign_key_action_invalid_value_still_raises() -> None: + with pytest.raises(SQLBuilderError): + CreateTable("orders").foreign_key_constraint("user_id", "users", "id", on_delete="EXPLODE") + + +def test_ddl_module_uses_pep604_annotations() -> None: + import sqlspec.builder._ddl as ddl_module + + assert "Union" not in inspect.getsource(ddl_module) diff --git a/tests/unit/builder/test_sql_factory.py b/tests/unit/builder/test_sql_factory.py index ce3cf2010..8b22e98b8 100644 --- a/tests/unit/builder/test_sql_factory.py +++ b/tests/unit/builder/test_sql_factory.py @@ -770,6 +770,29 @@ def test_any_subquery() -> None: assert isinstance(any_expr, Any) +def test_not_any_factory_wraps_any_expression_in_not() -> None: + """Test SQLFactory.not_any_ returns NOT ANY instead of bare ANY.""" + result = SQLFactory.not_any_([1, 2, 3]) + + assert isinstance(result.expression, exp.Not) + assert isinstance(result.expression.this, exp.Any) + + +def test_any_factory_still_returns_any_expression() -> None: + """Test SQLFactory.any remains unchanged.""" + result = SQLFactory.any([1, 2]) + + assert isinstance(result.expression, exp.Any) + + +def test_not_any_factory_where_clause_contains_not() -> None: + """Test not_any_ emits NOT when composed into a WHERE clause.""" + query = sql.select("*").from_("users").where(SQLFactory.not_any_("SELECT id FROM blocked_users").expression) + stmt = query.build() + + assert "NOT" in stmt.sql.upper() + + def test_all_subquery() -> None: """Test ALL subquery functionality.""" subquery = sql.select("salary").from_("employees").where_eq("department", "Sales") diff --git a/tests/unit/config/test_assert_guards.py b/tests/unit/config/test_assert_guards.py new file mode 100644 index 000000000..f08d82ae2 --- /dev/null +++ b/tests/unit/config/test_assert_guards.py @@ -0,0 +1,48 @@ +"""Regression tests for assert-to-exception runtime guards.""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from sqlspec.adapters.sqlite.config import SqliteConfig +from sqlspec.migrations.commands import SyncMigrationCommands + + +def test_get_observability_runtime_raises_runtime_error_not_assertion_error() -> None: + """A broken attach_observability implementation should hit RuntimeError guard.""" + config = SqliteConfig(connection_config={"database": ":memory:"}) + config._observability_runtime = None + + with ( + patch.object(config, "attach_observability", return_value=None), + pytest.raises(RuntimeError, match="ObservabilityRuntime was not set"), + ): + config.get_observability_runtime() + + +def test_no_pool_sync_config_init_migrations_uses_default_directory(tmp_path: Path) -> None: + """init_migrations derives directory from migration_config when argument is None.""" + migration_dir = tmp_path / "migrations" + config = SqliteConfig( + connection_config={"database": str(tmp_path / "test.db")}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "init", return_value=None) as init: + config.init_migrations() + + init.assert_called_once_with(str(migration_dir), True) + + +def test_default_serializer_raises_runtime_error_if_fallback_does_not_set_serializer(monkeypatch) -> None: + """get_default_serializer should use RuntimeError instead of assert.""" + import sqlspec.utils.serializers._json as json_module + + monkeypatch.setattr(json_module, "_default_serializer", None) + monkeypatch.setattr(json_module, "MSGSPEC_INSTALLED", False) + monkeypatch.setattr(json_module, "ORJSON_INSTALLED", False) + monkeypatch.setattr(json_module, "StandardLibSerializer", lambda: None) + + with pytest.raises(RuntimeError, match="No JSON serializer available"): + json_module.get_default_serializer() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d36efcf8f..ecb7c0e1e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -25,6 +25,32 @@ ) from sqlspec.driver import ExecutionResult +__all__ = ( + "aiosqlite_async_driver", + "cache_config_disabled", + "cache_config_enabled", + "cache_statistics_tracker", + "complex_sql_with_joins", + "mock_lru_cache", + "parameter_style_config_advanced", + "parameter_style_config_basic", + "reset_cache_state", + "reset_global_state", + "sample_delete_sql", + "sample_insert_sql", + "sample_parameters_mixed", + "sample_parameters_named", + "sample_parameters_positional", + "sample_select_sql", + "sample_update_sql", + "sql_with_typed_parameters", + "sqlite_sync_driver", + "statement_config_mysql", + "statement_config_postgres", + "statement_config_sqlite", + "test_isolation", +) + class TestSqliteDriver(SqliteDriver): """Test-friendly SQLite driver that allows patching.""" @@ -64,33 +90,6 @@ async def aiosqlite_async_driver() -> AsyncGenerator[TestAiosqliteDriver, None]: await conn.close() -__all__ = ( - "aiosqlite_async_driver", - "cache_config_disabled", - "cache_config_enabled", - "cache_statistics_tracker", - "complex_sql_with_joins", - "mock_lru_cache", - "parameter_style_config_advanced", - "parameter_style_config_basic", - "reset_cache_state", - "reset_global_state", - "sample_delete_sql", - "sample_insert_sql", - "sample_parameters_mixed", - "sample_parameters_named", - "sample_parameters_positional", - "sample_select_sql", - "sample_update_sql", - "sql_with_typed_parameters", - "sqlite_sync_driver", - "statement_config_mysql", - "statement_config_postgres", - "statement_config_sqlite", - "test_isolation", -) - - @pytest.fixture def parameter_style_config_basic() -> ParameterStyleConfig: """Basic parameter style configuration for simple test cases.""" diff --git a/tests/unit/core/test_cache.py b/tests/unit/core/test_cache.py index ac906df9a..ff0cf10ac 100644 --- a/tests/unit/core/test_cache.py +++ b/tests/unit/core/test_cache.py @@ -36,6 +36,7 @@ get_default_cache, log_cache_stats, reset_cache_stats, + reset_stats_only, update_cache_config, ) @@ -560,7 +561,7 @@ def test_namespaced_cache_namespace_isolation() -> None: def test_get_cache_statistics_aggregation() -> None: """Test cache statistics aggregation.""" - reset_cache_stats() + reset_stats_only() stats = get_cache_statistics() assert isinstance(stats, dict) @@ -568,16 +569,18 @@ def test_get_cache_statistics_aggregation() -> None: assert "namespaced" in stats -def test_reset_cache_stats_function() -> None: - """Test resetting all cache statistics.""" +def test_reset_stats_only_function_preserves_entries() -> None: + """Test resetting cache statistics without evicting entries.""" default_cache = get_default_cache() multi_cache = get_cache() test_key = CacheKey(("test",)) + default_cache.put(test_key, "cached") + multi_cache.put_statement("key", "cached") default_cache.get(test_key) multi_cache.get_statement("key") - reset_cache_stats() + reset_stats_only() default_stats = default_cache.get_stats() multi_stats = multi_cache.get_stats() @@ -586,11 +589,29 @@ def test_reset_cache_stats_function() -> None: assert default_stats.misses == 0 assert multi_stats.hits == 0 assert multi_stats.misses == 0 + assert test_key in default_cache + assert multi_cache.get_statement("key") == "cached" + + +def test_reset_cache_stats_function_warns_and_clears_entries() -> None: + """Test deprecated cache reset alias still clears cache entries.""" + default_cache = get_default_cache() + multi_cache = get_cache() + + test_key = CacheKey(("deprecated-reset",)) + default_cache.put(test_key, "cached") + multi_cache.put_statement("deprecated-key", "cached") + + with pytest.warns(DeprecationWarning, match="clear_all_caches\\(\\).*reset_stats_only\\(\\)"): + reset_cache_stats() + + assert test_key not in default_cache + assert multi_cache.get_statement("deprecated-key") is None def test_log_cache_stats_function(caplog: pytest.LogCaptureFixture) -> None: """Test logging cache statistics.""" - reset_cache_stats() + reset_stats_only() caplog.set_level(logging.DEBUG, logger="sqlspec.cache") log_cache_stats() diff --git a/tests/unit/core/test_compiler.py b/tests/unit/core/test_compiler.py index bf8201c35..24f18ea32 100644 --- a/tests/unit/core/test_compiler.py +++ b/tests/unit/core/test_compiler.py @@ -14,6 +14,7 @@ 8. Performance characteristics - Compilation speed and efficiency testing """ +import inspect import logging import threading import time @@ -29,6 +30,7 @@ from sqlspec.core import ( CompiledSQL, + OperationProfile, OperationType, ParameterProcessor, ParameterProfile, @@ -613,6 +615,96 @@ def to_delete(expression: exp.Expr, parameters: Any) -> "tuple[exp.Expr, Any]": assert "DELETE" in result.compiled_sql +def test_parameter_casts_detected_with_statement_transformer() -> None: + """Parameter casts are detected from the final transformed AST.""" + parameter_config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NUMERIC, + supported_parameter_styles={ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NUMERIC}, + default_execution_parameter_style=ParameterStyle.NUMERIC, + ) + + def rename_table(expression: exp.Expr, parameters: Any) -> "tuple[exp.Expr, Any]": + updated = expression.transform(lambda node: exp.to_table("accounts") if isinstance(node, exp.Table) else node) + return updated, parameters + + config = StatementConfig( + dialect="postgres", + parameter_config=parameter_config, + enable_caching=True, + enable_parsing=True, + enable_validation=False, + statement_transformers=(rename_table,), + ) + processor = SQLProcessor(config) + + result = processor.compile("SELECT $1::text FROM users", [1]) + + assert result.parameter_casts == {1: "TEXT"} + assert "accounts" in result.compiled_sql + + +def test_parse_cache_entry_does_not_store_parameter_casts() -> None: + """Parse cache entries should store expression, operation, and profile only.""" + parameter_config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NUMERIC, + supported_parameter_styles={ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NUMERIC}, + default_execution_parameter_style=ParameterStyle.NUMERIC, + ) + config = StatementConfig( + dialect="postgres", + parameter_config=parameter_config, + enable_caching=True, + enable_parsing=True, + enable_validation=False, + ) + processor = SQLProcessor(config) + + processor.compile("SELECT $1::text FROM users", [1]) + cache_entry = next(iter(processor._parse_cache.values())) + + assert len(cache_entry) == 3 + + +@requires_interpreted +def test_detect_parameter_casts_called_exactly_once_with_transformer() -> None: + """Transformers should not trigger duplicate parameter-cast AST walks.""" + parameter_config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NUMERIC, + supported_parameter_styles={ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NUMERIC}, + default_execution_parameter_style=ParameterStyle.NUMERIC, + ) + + def pass_through(expression: exp.Expr, parameters: Any) -> "tuple[exp.Expr, Any]": + return expression, parameters + + config = StatementConfig( + dialect="postgres", + parameter_config=parameter_config, + enable_caching=False, + enable_parsing=True, + enable_validation=False, + statement_transformers=(pass_through,), + ) + + original_detect_parameter_casts = SQLProcessor._detect_parameter_casts + detect_parameter_casts_calls = 0 + + def count_detect_parameter_casts(expression: "exp.Expr | None") -> "dict[int, str]": + nonlocal detect_parameter_casts_calls + detect_parameter_casts_calls += 1 + return original_detect_parameter_casts(expression) + + processor = SQLProcessor(config) + + with patch.object(SQLProcessor, "_detect_parameter_casts", staticmethod(count_detect_parameter_casts)): + processor.compile("SELECT $1::text FROM users", [1]) + + assert detect_parameter_casts_calls == 1 + + def test_parsing_enabled_optimization( basic_statement_config: "StatementConfig", sample_sql_queries: "dict[str, str]" ) -> None: @@ -1090,3 +1182,52 @@ def test_compile_with_pipeline_passes_expression() -> None: _, kwargs = mock_compile.call_args assert kwargs["expression"] is expression + + +def test_c11_mypyc_native_class_pass(basic_statement_config: "StatementConfig") -> None: + """Verify c11 native-class optimization invariants and compiler behavior.""" + for cls in (OperationProfile, CompiledSQL, SQLProcessor): + source = inspect.getsource(cls) + assert source.startswith("@final\n@mypyc_attr(allow_interpreted_subclasses=False)\n") + + for name in ( + "_normalize_expression_override", + "_unpack_parse_cache_entry", + "_detect_operation_type", + "_detect_parameter_casts", + "_build_operation_profile", + "_make_parse_cache_key", + ): + assert isinstance(SQLProcessor.__dict__.get(name), staticmethod) + + source = inspect.getsource(SQLProcessor._compile_uncached) + assert "dialect_str = str(self._config.dialect)" not in source + + processor = SQLProcessor(basic_statement_config) + select_result = processor.compile("SELECT * FROM users WHERE id = ?", (42,)) + assert select_result.operation_type == "SELECT" + assert select_result.operation_profile.returns_rows is True + + insert_result = processor.compile("INSERT INTO users (name) VALUES (?)", ("alice",)) + assert insert_result.operation_type == "INSERT" + assert insert_result.operation_profile.modifies_rows is True + + ddl_result = processor.compile("CREATE TABLE test (id INTEGER PRIMARY KEY)", None) + assert ddl_result.operation_type == "DDL" + + result1 = processor.compile("SELECT id FROM users WHERE id = ?", (1,)) + result2 = processor.compile("SELECT id FROM users WHERE id = ?", (2,)) + assert result1.compiled_sql == result2.compiled_sql + assert processor._cache_hits >= 1 + + assert SQLProcessor._detect_operation_type(sqlglot.parse_one("SELECT 1")) == "SELECT" + assert SQLProcessor._make_parse_cache_key("SELECT 1", None) == ("default", "SELECT 1") + assert SQLProcessor._make_parse_cache_key("SELECT 1", "postgres") == ("postgres", "SELECT 1") + assert SQLProcessor._normalize_expression_override(None, "SELECT 1", "SELECT 1") is None + + profile = SQLProcessor._build_operation_profile(sqlglot.parse_one("SELECT 1"), "SELECT") + assert profile.returns_rows is True + assert profile.modifies_rows is False + + casts = SQLProcessor._detect_parameter_casts(sqlglot.parse_one("SELECT ?")) + assert isinstance(casts, dict) diff --git a/tests/unit/core/test_filters.py b/tests/unit/core/test_filters.py index 109eff8af..d16e4ba51 100644 --- a/tests/unit/core/test_filters.py +++ b/tests/unit/core/test_filters.py @@ -552,6 +552,100 @@ def test_not_null_filter_cache_key_uniqueness() -> None: assert key1 == key3 +def _where_condition(statement: SQL) -> exp.Expression: + expr = statement.expression + assert isinstance(expr, exp.Select) + where_expr = expr.args.get("where") + assert isinstance(where_expr, exp.Where) + condition = where_expr.this + assert isinstance(condition, exp.Expression) + return condition + + +def test_before_after_filter_col_nodes_are_independent() -> None: + filter_obj = BeforeAfterFilter("created_at", before=datetime(2023, 12, 31), after=datetime(2023, 1, 1)) + + condition = _where_condition(apply_filter(SQL("SELECT * FROM items"), filter_obj)) + + assert isinstance(condition, exp.And) + lt_cond = condition.this + gt_cond = condition.expression + assert isinstance(lt_cond, exp.LT) + assert isinstance(gt_cond, exp.GT) + lt_col = lt_cond.this + gt_col = gt_cond.this + + assert lt_col is not gt_col + lt_col.replace(exp.column("x")) + assert lt_cond.this.sql() == "x" + assert gt_cond.this.sql() == "created_at" + + +def test_on_before_after_filter_col_nodes_are_independent() -> None: + from sqlspec.core.filters import OnBeforeAfterFilter + + filter_obj = OnBeforeAfterFilter( + "created_at", on_or_before=datetime(2023, 12, 31), on_or_after=datetime(2023, 1, 1) + ) + + condition = _where_condition(apply_filter(SQL("SELECT * FROM items"), filter_obj)) + + assert isinstance(condition, exp.And) + lte_cond = condition.this + gte_cond = condition.expression + assert isinstance(lte_cond, exp.LTE) + assert isinstance(gte_cond, exp.GTE) + lte_col = lte_cond.this + gte_col = gte_cond.this + + assert lte_col is not gte_col + lte_col.replace(exp.column("x")) + assert lte_cond.this.sql() == "x" + assert gte_cond.this.sql() == "created_at" + + +def test_search_filter_set_case_placeholder_nodes_are_independent() -> None: + filter_obj = SearchFilter(field_name={"name", "description"}, value="oracle") + + condition = _where_condition(apply_filter(SQL("SELECT * FROM items"), filter_obj)) + + assert isinstance(condition, exp.Or) + left_like = condition.this + right_like = condition.expression + assert isinstance(left_like, exp.Like) + assert isinstance(right_like, exp.Like) + left_placeholder = left_like.expression + right_placeholder = right_like.expression + + assert left_placeholder is not right_placeholder + left_placeholder.replace(exp.Placeholder(this="replacement")) + assert left_like.expression.sql() == ":replacement" + assert right_like.expression.sql() == ":search_value" + + +def test_not_in_search_filter_set_case_placeholder_nodes_are_independent() -> None: + filter_obj = NotInSearchFilter(field_name={"name", "description"}, value="oracle") + + condition = _where_condition(apply_filter(SQL("SELECT * FROM items"), filter_obj)) + + assert isinstance(condition, exp.And) + left_not = condition.this + right_not = condition.expression + assert isinstance(left_not, exp.Not) + assert isinstance(right_not, exp.Not) + left_like = left_not.this + right_like = right_not.this + assert isinstance(left_like, exp.Like) + assert isinstance(right_like, exp.Like) + left_placeholder = left_like.expression + right_placeholder = right_like.expression + + assert left_placeholder is not right_placeholder + left_placeholder.replace(exp.Placeholder(this="replacement")) + assert left_like.expression.sql() == ":replacement" + assert right_like.expression.sql() == ":not_search_value" + + # --- Bug #379 regression tests: multi-field placeholder independence --- diff --git a/tests/unit/core/test_fingerprint_call_count.py b/tests/unit/core/test_fingerprint_call_count.py new file mode 100644 index 000000000..051f55691 --- /dev/null +++ b/tests/unit/core/test_fingerprint_call_count.py @@ -0,0 +1,62 @@ +# pyright: reportPrivateUsage = false +"""Regression tests for structural fingerprint reuse in the shared pipeline.""" + +from typing import Any +from unittest.mock import patch + +import pytest + +from sqlspec.core.parameters._processor import structural_fingerprint, value_fingerprint +from sqlspec.core.parameters._types import ParameterStyle, ParameterStyleConfig +from sqlspec.core.pipeline import reset_statement_pipeline_cache +from sqlspec.core.statement import SQL, StatementConfig + + +@pytest.fixture(autouse=True) +def reset_pipeline() -> None: + reset_statement_pipeline_cache() + + +def _make_config(*, needs_static_script_compilation: bool = False) -> StatementConfig: + return StatementConfig( + parameter_config=ParameterStyleConfig( + default_parameter_style=ParameterStyle.NAMED_COLON, + supported_parameter_styles={ParameterStyle.NAMED_COLON}, + default_execution_parameter_style=ParameterStyle.NAMED_COLON, + supported_execution_parameter_styles={ParameterStyle.NAMED_COLON}, + needs_static_script_compilation=needs_static_script_compilation, + ) + ) + + +def test_structural_fingerprint_called_once_per_fresh_compile() -> None: + config = _make_config() + statement = SQL("SELECT * FROM users WHERE id = :id", statement_config=config, id=1) + call_count = 0 + + def counting_fingerprint(parameters: Any, is_many: bool = False) -> Any: + nonlocal call_count + call_count += 1 + return structural_fingerprint(parameters, is_many=is_many) + + with patch("sqlspec.core.statement.structural_fingerprint", side_effect=counting_fingerprint): + with patch("sqlspec.core.compiler.structural_fingerprint", side_effect=counting_fingerprint): + statement.compile() + + assert call_count == 1 + + +def test_static_script_path_does_not_forward_structural_fingerprint() -> None: + config = _make_config(needs_static_script_compilation=True) + statement = SQL("SELECT :id", statement_config=config, id=1) + value_fingerprint_calls = 0 + + def counting_value_fingerprint(parameters: Any) -> Any: + nonlocal value_fingerprint_calls + value_fingerprint_calls += 1 + return value_fingerprint(parameters) + + with patch("sqlspec.core.compiler.value_fingerprint", side_effect=counting_value_fingerprint): + statement.compile() + + assert value_fingerprint_calls >= 1 diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index fa67f8a23..9caa7df04 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -10,8 +10,9 @@ import json import math +import warnings from collections.abc import Callable, Sequence -from datetime import date, datetime +from datetime import date, datetime, time from decimal import Decimal from importlib import import_module from typing import Any @@ -26,6 +27,7 @@ ParameterConverter, ParameterInfo, ParameterProcessor, + ParameterProfile, ParameterStyle, ParameterStyleConfig, ParameterValidator, @@ -38,6 +40,7 @@ replace_placeholders_with_literals, wrap_with_type, ) +from sqlspec.core.parameters import _converter as _converter_module # Detect whether the core parameters module is mypyc-compiled. # When compiled, `patch.object` on C-extension classes is a no-op, @@ -48,6 +51,7 @@ from sqlspec.utils.serializers import from_json, to_json _VALIDATOR_COMPILED = (_validator_module.__file__ or "").endswith((".so", ".pyd")) +_CONVERTER_COMPILED = (_converter_module.__file__ or "").endswith((".so", ".pyd")) _ADAPTER_DRIVER_MODULES: "tuple[str, ...]" = ( "sqlspec.adapters.adbc.driver", @@ -380,9 +384,10 @@ def test_replace_null_parameters_with_literals_numeric_dialect() -> None: """Null parameters should render as literals and shrink parameter list.""" expression = sqlglot.parse_one("INSERT INTO test VALUES ($1, $2)", dialect="postgres") - modified_expression, cleaned_params = replace_null_parameters_with_literals( - expression, (42, None), dialect="postgres" - ) + with pytest.warns(DeprecationWarning, match="round-trip AST serialization|parameter_profile"): + modified_expression, cleaned_params = replace_null_parameters_with_literals( + expression, (42, None), dialect="postgres" + ) assert modified_expression.sql(dialect="postgres") == "INSERT INTO test VALUES ($1, NULL)" assert cleaned_params == (42,) @@ -392,9 +397,10 @@ def test_replace_null_parameters_with_literals_named_bigquery() -> None: """Named parameters should inline NULL literals and drop None values.""" expression = sqlglot.parse_one("INSERT INTO test VALUES (@id, @desc)", dialect="bigquery") - modified_expression, cleaned_params = replace_null_parameters_with_literals( - expression, {"id": 1, "desc": None}, dialect="bigquery" - ) + with pytest.warns(DeprecationWarning, match="round-trip AST serialization|parameter_profile"): + modified_expression, cleaned_params = replace_null_parameters_with_literals( + expression, {"id": 1, "desc": None}, dialect="bigquery" + ) assert modified_expression.sql(dialect="bigquery") == "INSERT INTO test VALUES (@id, NULL)" assert cleaned_params == {"id": 1} @@ -783,15 +789,69 @@ def output_transformer(sql: str, params: Any) -> tuple[str, Any]: assert config.strict_named_parameters is False -def test_parameter_style_config_hash() -> None: - """Test ParameterStyleConfig hash method for caching.""" +def test_parameter_style_config_hash_value() -> None: + """Test ParameterStyleConfig built-in hash for caching.""" config1 = ParameterStyleConfig(ParameterStyle.QMARK) config2 = ParameterStyleConfig(ParameterStyle.QMARK) config3 = ParameterStyleConfig(ParameterStyle.NAMED_COLON) - assert config1.hash() == config2.hash() + assert hash(config1) == hash(config2) + + assert hash(config1) != hash(config3) + + +def test_parameter_style_config_hash_differs_for_different_output_transformers() -> None: + """Test different output transformer callables produce different hashes.""" + + def transformer_a(sql: str, params: Any) -> tuple[str, Any]: + return sql, params + + def transformer_b(sql: str, params: Any) -> tuple[str, Any]: + return sql.upper(), params + + config_a = ParameterStyleConfig(ParameterStyle.QMARK, output_transformer=transformer_a) + config_b = ParameterStyleConfig(ParameterStyle.QMARK, output_transformer=transformer_b) + + assert hash(config_a) != hash(config_b) + + +def test_parameter_style_config_hash_same_for_same_output_transformer_reference() -> None: + """Test shared output transformer references produce the same hash.""" + + def transformer(sql: str, params: Any) -> tuple[str, Any]: + return sql, params + + config_a = ParameterStyleConfig(ParameterStyle.QMARK, output_transformer=transformer) + config_b = ParameterStyleConfig(ParameterStyle.QMARK, output_transformer=transformer) - assert config1.hash() != config3.hash() + assert hash(config_a) == hash(config_b) + + +def test_parameter_style_config_hash_differs_for_different_ast_transformers() -> None: + """Test different AST transformer callables produce different hashes.""" + + def transformer_a(expr: Any, params: Any, profile: Any) -> tuple[Any, Any]: + return expr, params + + def transformer_b(expr: Any, params: Any, profile: Any) -> tuple[Any, Any]: + return expr, params + + config_a = ParameterStyleConfig(ParameterStyle.QMARK, ast_transformer=transformer_a) + config_b = ParameterStyleConfig(ParameterStyle.QMARK, ast_transformer=transformer_b) + + assert hash(config_a) != hash(config_b) + + +def test_parameter_style_config_hash_same_for_same_ast_transformer_reference() -> None: + """Test shared AST transformer references produce the same hash.""" + + def transformer(expr: Any, params: Any, profile: Any) -> tuple[Any, Any]: + return expr, params + + config_a = ParameterStyleConfig(ParameterStyle.QMARK, ast_transformer=transformer) + config_b = ParameterStyleConfig(ParameterStyle.QMARK, ast_transformer=transformer) + + assert hash(config_a) == hash(config_b) @pytest.fixture @@ -1204,6 +1264,71 @@ class MyInt(int): assert result.parameters == [5] +def test_coerce_parameter_types_returns_decimal_scalar(processor: "ParameterProcessor") -> None: + """Test scalar type coercion results are preserved.""" + result = processor._coerce_parameter_types(Decimal("3.14"), {Decimal: float}) # pyright: ignore[reportPrivateUsage] + + assert result == 3.14 + assert isinstance(result, float) + + +def test_coerce_parameter_types_returns_custom_scalar(processor: "ParameterProcessor") -> None: + """Test custom scalar coercion results are preserved.""" + + class CustomValue: + pass + + result = processor._coerce_parameter_types( # pyright: ignore[reportPrivateUsage] + CustomValue(), {CustomValue: lambda value: "coerced"} + ) + + assert result == "coerced" + + +def test_process_type_coercion_preserves_scalar_parameter(processor: "ParameterProcessor") -> None: + """Test process() does not discard a coerced scalar parameter.""" + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.QMARK, + supported_execution_parameter_styles={ParameterStyle.QMARK}, + default_execution_parameter_style=ParameterStyle.QMARK, + type_coercion_map={Decimal: float}, + ) + + result = processor.process("SELECT * FROM prices WHERE amount = ?", Decimal("19.99"), config, wrap_types=False) + + assert result.parameters == [19.99] + + +def test_replace_null_parameters_without_profile_emits_deprecation_warning() -> None: + """No-profile NULL pruning warns about the fallback AST round trip.""" + expression = sqlglot.parse_one("SELECT * FROM t WHERE id = $1", dialect="postgres") + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + replace_null_parameters_with_literals(expression, (42,), dialect="postgres") + + deprecation_warnings = [warning for warning in caught if issubclass(warning.category, DeprecationWarning)] + assert deprecation_warnings + assert any( + "round-trip" in str(warning.message).lower() or "parameter_profile" in str(warning.message) + for warning in deprecation_warnings + ) + + +def test_replace_null_parameters_with_profile_emits_no_warning() -> None: + """Supplying parameter_profile avoids the deprecated fallback path.""" + validator = ParameterValidator() + expression = sqlglot.parse_one("SELECT * FROM t WHERE id = $1", dialect="postgres") + profile = ParameterProfile(validator.extract_parameters(expression.sql(dialect="postgres"))) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + replace_null_parameters_with_literals(expression, (42,), dialect="postgres", parameter_profile=profile) + + deprecation_warnings = [warning for warning in caught if issubclass(warning.category, DeprecationWarning)] + assert not deprecation_warnings + + def test_resolve_type_coercion_supports_virtual_abc_fallback() -> None: """ABC-registered coercions should still resolve for builtin sequence payloads.""" type_map: dict[type, Callable[[Any], Any]] = {Sequence: lambda value: tuple(value)} # type: ignore[dict-item] @@ -1383,8 +1508,8 @@ def test_processor_cache_key_uses_shared_tuple_helper() -> None: ) -def test_singledispatch_type_wrapping_performance() -> None: - """Test that singledispatch provides efficient type-based dispatch.""" +def test_isinstance_type_wrapping() -> None: + """Test that isinstance dispatch correctly wraps typed parameters.""" bool_result = wrap_with_type(True) decimal_result = wrap_with_type(Decimal("123.45")) @@ -1397,6 +1522,67 @@ def test_singledispatch_type_wrapping_performance() -> None: assert not isinstance(string_result, TypedParameter) +def test_parameter_style_constants_are_module_frozensets() -> None: + """Style membership constants are hoisted for compiled hot paths.""" + from sqlspec.core.parameters import _types + + named_styles = getattr(_types, "_NAMED_STYLES", None) + positional_styles = getattr(_types, "_POSITIONAL_STYLES", None) + named_style_values = getattr(_types, "_NAMED_STYLE_VALUES", None) + positional_style_values = getattr(_types, "_POSITIONAL_STYLE_VALUES", None) + + assert isinstance(named_styles, frozenset) + assert isinstance(positional_styles, frozenset) + assert isinstance(named_style_values, frozenset) + assert isinstance(positional_style_values, frozenset) + assert named_styles == frozenset({ + ParameterStyle.NAMED_COLON, + ParameterStyle.NAMED_AT, + ParameterStyle.NAMED_DOLLAR, + ParameterStyle.NAMED_PYFORMAT, + }) + assert positional_styles == frozenset({ + ParameterStyle.QMARK, + ParameterStyle.NUMERIC, + ParameterStyle.POSITIONAL_COLON, + ParameterStyle.POSITIONAL_PYFORMAT, + }) + assert named_style_values == frozenset(style.value for style in named_styles) + assert positional_style_values == frozenset(style.value for style in positional_styles) + + +def test_wrap_with_type_isinstance_chain_semantics() -> None: + """The plain dispatch chain preserves prior type-wrapping semantics.""" + bool_result = wrap_with_type(False) + decimal_result = wrap_with_type(Decimal("3.14")) + datetime_result = wrap_with_type(datetime(2024, 1, 1, 12, 30)) + date_result = wrap_with_type(date(2024, 1, 1)) + time_result = wrap_with_type(time(12, 30)) + bytes_result = wrap_with_type(b"hello") + + assert isinstance(bool_result, TypedParameter) + assert bool_result.original_type is bool + assert isinstance(decimal_result, TypedParameter) + assert decimal_result.original_type is Decimal + assert isinstance(datetime_result, TypedParameter) + assert datetime_result.original_type is datetime + assert isinstance(date_result, TypedParameter) + assert date_result.original_type is date + assert isinstance(time_result, TypedParameter) + assert time_result.original_type is time + assert isinstance(bytes_result, TypedParameter) + assert bytes_result.original_type is bytes + assert wrap_with_type(42) == 42 + + +def test_parameter_style_config_hash_wrapper_removed() -> None: + """ParameterStyleConfig only exposes __hash__, not a wrapper method.""" + config = ParameterStyleConfig(default_parameter_style=ParameterStyle.NAMED_COLON) + + assert not hasattr(config, "hash") + assert isinstance(hash(config), int) + + def test_hash_map_placeholder_generation() -> None: """Test that placeholder generation uses O(1) hash map lookups.""" converter = ParameterConverter() @@ -1950,7 +2136,8 @@ def test_converted_parameters_transformers_null_pruning(processor: ParameterProc expression = sqlglot.parse_one(sql, dialect="postgres") # Test null pruning - _transformed_expr, transformed_params = replace_null_parameters_with_literals(expression, parameters) + with pytest.warns(DeprecationWarning, match="round-trip AST serialization|parameter_profile"): + _transformed_expr, transformed_params = replace_null_parameters_with_literals(expression, parameters) # Should return concrete type - list or tuple assert transformed_params is not None @@ -2548,3 +2735,84 @@ def test_duplicate_named_params_cache_hit_embedding_query( result2 = processor.process(sql, params, config) assert len(result2.parameters) == 3, f"Cache hit produced {len(result2.parameters)} params, expected 3" + + +class _CountingParameterConverter(ParameterConverter): + """ParameterConverter that records conversion-plan builds for regression tests.""" + + def __init__(self) -> None: + super().__init__() + self.build_conversion_plan_calls = 0 + + def _build_conversion_plan( + self, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + ) -> "tuple[list[ParameterInfo], dict[str, int]]": + self.build_conversion_plan_calls += 1 + return super()._build_conversion_plan(param_info, target_style) # pyright: ignore[reportPrivateUsage] + + +def _counting_processor() -> "tuple[ParameterProcessor, _CountingParameterConverter]": + converter = _CountingParameterConverter() + return ParameterProcessor(converter=converter, cache_max_size=0, validator_cache_max_size=0), converter + + +@pytest.mark.skipif( + _CONVERTER_COMPILED, reason="interpreted subclass cannot override mypyc-compiled ParameterConverter methods" +) +def test_build_conversion_plan_called_once_in_execution_conversion() -> None: + """Execution-style conversion should build one shared conversion plan.""" + processor, converter = _counting_processor() + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NAMED_COLON, + default_execution_parameter_style=ParameterStyle.NUMERIC, + supported_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NUMERIC}, + ) + + result = processor.process("SELECT * FROM t WHERE id = :id AND name = :name", {"id": 1, "name": "alice"}, config) + + assert converter.build_conversion_plan_calls == 1 + assert result.sql == "SELECT * FROM t WHERE id = $1 AND name = $2" + + +@pytest.mark.skipif( + _CONVERTER_COMPILED, reason="interpreted subclass cannot override mypyc-compiled ParameterConverter methods" +) +def test_build_conversion_plan_called_once_in_requires_mapping_branch() -> None: + """Mapping-normalization pre-processing should build one shared conversion plan.""" + processor, converter = _counting_processor() + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NUMERIC, + default_execution_parameter_style=ParameterStyle.NAMED_PYFORMAT, + supported_parameter_styles={ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NAMED_PYFORMAT}, + ) + + result = processor.process_for_execution( + "SELECT * FROM t WHERE id = $1 AND name = $2", {"id": 42, "name": "bob"}, config + ) + + assert converter.build_conversion_plan_calls == 1 + assert result.parameters == {"1": 42, "2": "bob"} + + +@pytest.mark.skipif( + _CONVERTER_COMPILED, reason="interpreted subclass cannot override mypyc-compiled ParameterConverter methods" +) +def test_build_conversion_plan_called_once_execute_many_preserve_original_params() -> None: + """The execute_many preserve-original path should build one shared conversion plan.""" + processor, converter = _counting_processor() + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NAMED_COLON, + default_execution_parameter_style=ParameterStyle.NUMERIC, + supported_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NUMERIC}, + preserve_original_params_for_many=True, + ) + rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + + result = processor.process("INSERT INTO t (a, b) VALUES (:a, :b)", rows, config, is_many=True) + + assert converter.build_conversion_plan_calls == 1 + assert result.sql == "INSERT INTO t (a, b) VALUES ($1, $2)" + assert result.parameters is rows diff --git a/tests/unit/core/test_pipeline.py b/tests/unit/core/test_pipeline.py new file mode 100644 index 000000000..ce0b15aeb --- /dev/null +++ b/tests/unit/core/test_pipeline.py @@ -0,0 +1,59 @@ +# pyright: reportPrivateUsage = false +"""Unit tests for the shared statement pipeline registry.""" + +from unittest.mock import patch + +import sqlspec.core.pipeline as pipeline_module +from sqlspec.core import SQL, get_pipeline_metrics, reset_pipeline_registry +from sqlspec.core.pipeline import StatementPipelineRegistry +from sqlspec.core.statement import StatementConfig + + +def test_record_pipeline_metrics_constant_is_bool() -> None: + assert isinstance(getattr(pipeline_module, "_RECORD_PIPELINE_METRICS", None), bool) + + +def test_os_getenv_not_called_during_compile() -> None: + registry = StatementPipelineRegistry() + config = StatementConfig() + + with patch("os.getenv") as mock_getenv: + registry.compile(config, "SELECT 1", {}) + + mock_getenv.assert_not_called() + + +def test_os_getenv_not_called_during_metrics() -> None: + registry = StatementPipelineRegistry() + + with patch("os.getenv") as mock_getenv: + registry.metrics() + + mock_getenv.assert_not_called() + + +def test_record_pipeline_metrics_patch_controls_metrics_output() -> None: + with patch.object(pipeline_module, "_RECORD_PIPELINE_METRICS", True): + reset_pipeline_registry() + SQL("SELECT 42").compile() + SQL("SELECT 42").compile() + metrics = get_pipeline_metrics() + + reset_pipeline_registry() + assert metrics + assert sum(entry.get("hits", 0) for entry in metrics) >= 1 + + +def test_fingerprint_cache_uses_cached_slot_value() -> None: + registry = StatementPipelineRegistry() + config = StatementConfig() + + first_fingerprint = registry._fingerprint_config(config) + assert first_fingerprint.startswith("pipeline::") + assert config._fingerprint_cache == first_fingerprint + + with patch("hashlib.blake2b") as mock_blake2b: + second_fingerprint = registry._fingerprint_config(config) + + mock_blake2b.assert_not_called() + assert second_fingerprint == first_fingerprint diff --git a/tests/unit/core/test_result.py b/tests/unit/core/test_result.py index 5fff5c564..d8d20c8ce 100644 --- a/tests/unit/core/test_result.py +++ b/tests/unit/core/test_result.py @@ -392,6 +392,54 @@ class User: assert mocked_to_schema.call_count == 1 +def test_get_schema_row_uses_requested_row_after_get_data_cache_warm() -> None: + @dataclass + class User: + id: int + name: str + + rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + result = SQLResult(statement=SQL("SELECT id, name FROM users"), data=rows, rows_affected=2) + + cached_rows = result.get_data(schema_type=User) + row = result._get_schema_row(User, rows[1]) + + assert row == User(id=2, name="Bob") + assert row is not cached_rows[0] + + +def test_get_schema_row_uses_requested_row_after_all_cache_warm() -> None: + @dataclass + class User: + id: int + name: str + + rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + result = SQLResult(statement=SQL("SELECT id, name FROM users"), data=rows, rows_affected=2) + + cached_rows = result.all(schema_type=User) + row = result._get_schema_row(User, rows[1]) + + assert row == User(id=2, name="Bob") + assert row is not cached_rows[0] + + +def test_get_schema_row_populates_single_row_cache_when_rows_cache_is_warm() -> None: + @dataclass + class User: + id: int + name: str + + rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + result = SQLResult(statement=SQL("SELECT id, name FROM users"), data=rows, rows_affected=2) + + result.all(schema_type=User) + row = result._get_schema_row(User, rows[1]) + + assert result._schema_row_cache is not None + assert result._schema_row_cache[User] is row + + @pytest.mark.skipif(_RESULT_BASE_COMPILED, reason="patch.object cannot intercept mypyc-compiled modules") def test_sql_result_reuses_cached_single_row_schema_conversion() -> None: """Repeated single-row schema access should not re-run to_schema().""" diff --git a/tests/unit/core/test_statement.py b/tests/unit/core/test_statement.py index 58078d5ca..2d3c069e9 100644 --- a/tests/unit/core/test_statement.py +++ b/tests/unit/core/test_statement.py @@ -15,7 +15,9 @@ 8. Edge cases - Complex queries, comments, string literals """ -import os +import copy +import logging +import pickle from typing import Any from unittest.mock import MagicMock, patch @@ -102,6 +104,15 @@ def test_statement_config_replace_invalid_attribute() -> None: config.replace(invalid_attr="value") +@pytest.mark.parametrize("private_field", ["_fingerprint_cache", "_hash_cache", "_is_frozen"]) +def test_statement_config_replace_rejects_private_fields(private_field: str) -> None: + """Test StatementConfig.replace() rejects private cache fields.""" + config = StatementConfig(parameter_config=DEFAULT_PARAMETER_CONFIG) + + with pytest.raises(TypeError, match=f"{private_field!r} is not a field"): + config.replace(**{private_field: "value"}) + + def test_statement_config_hash_equality() -> None: """Test StatementConfig hash and equality methods.""" config1 = StatementConfig(parameter_config=DEFAULT_PARAMETER_CONFIG, dialect="sqlite") @@ -138,7 +149,7 @@ def test_statement_config_driver_required_attributes() -> None: assert hasattr(config.parameter_config, "supported_parameter_styles") assert hasattr(config.parameter_config, "type_coercion_map") assert hasattr(config.parameter_config, "output_transformer") - assert callable(config.parameter_config.hash) + assert isinstance(hash(config.parameter_config), int) def test_processed_state_initialization() -> None: @@ -322,8 +333,80 @@ def test_sql_initialization_with_expression() -> None: expr = exp.select("*").from_("users") stmt = SQL(expr) - assert "SELECT" in stmt._raw_sql - assert "users" in stmt._raw_sql + assert stmt._raw_sql == "" + assert stmt._raw_expression is expr + + +def test_lazy_raw_sql_property_materializes_once() -> None: + """Accessing raw_sql materializes a deferred expression once.""" + expr = exp.select("*").from_("users") + stmt = SQL(expr) + + raw_sql = stmt.raw_sql + + assert "SELECT" in raw_sql + assert "users" in raw_sql + assert stmt._raw_sql == raw_sql + assert stmt.raw_sql == raw_sql + + +@requires_interpreted +def test_lazy_raw_sql_chain_defers_expression_serialization() -> None: + """Chained expression modifiers defer SQL serialization until raw_sql access.""" + expr = exp.select("*").from_("users") + sql_call_count = 0 + original_sql = exp.Expression.sql + + def tracking_sql(self: "exp.Expression", *args: Any, **kwargs: Any) -> str: + nonlocal sql_call_count + sql_call_count += 1 + return original_sql(self, *args, **kwargs) + + with patch.object(exp.Expression, "sql", tracking_sql): + stmt = SQL(expr).where_eq("id", 1).limit(10).offset(5) + assert sql_call_count == 0 + + materialized = stmt.raw_sql + assert sql_call_count == 1 + assert "LIMIT" in materialized + assert "OFFSET" in materialized + + assert stmt.raw_sql == materialized + assert sql_call_count == 1 + + +def test_lazy_raw_sql_compile_materializes_deferred_expression() -> None: + """compile() materializes deferred raw SQL before entering the pipeline.""" + expr = exp.select("id", "name").from_("products").where("active = 1") + stmt = SQL(expr) + + compiled_sql, params = stmt.compile() + + assert "SELECT" in compiled_sql + assert "products" in compiled_sql + assert params == [] + + +def test_lazy_raw_sql_pickle_and_deepcopy_roundtrip() -> None: + """Deferred SQL objects materialize correctly for pickle and deepcopy.""" + stmt = SQL(exp.select("*").from_("items").where("price > 0")) + + restored = pickle.loads(pickle.dumps(stmt)) + copied = copy.deepcopy(stmt) + + assert restored.raw_sql == stmt.raw_sql + assert copied.raw_sql == stmt.raw_sql + assert "items" in restored.raw_sql + + +def test_lazy_raw_sql_repr_hash_and_equality_materialize_consistently() -> None: + """repr, hash, and equality use materialized raw SQL for deferred expressions.""" + stmt1 = SQL(exp.select("*").from_("users").where("active = 1")) + stmt2 = SQL(exp.select("*").from_("users").where("active = 1")) + + assert stmt1 == stmt2 + assert hash(stmt1) == hash(stmt2) + assert "users" in repr(stmt1) def test_sql_initialization_with_custom_config() -> None: @@ -388,7 +471,14 @@ def test_sql_single_pass_processing_triggered_by_sql_property() -> None: compiled_sql, params = stmt.compile() - mock_compile.assert_called_once_with(stmt._statement_config, stmt._raw_sql, [], is_many=False, expression=None) + mock_compile.assert_called_once_with( + stmt._statement_config, + stmt._raw_sql, + [], + is_many=False, + expression=None, + param_fingerprint=structural_fingerprint([], is_many=False), + ) assert compiled_sql == "SELECT * FROM users" assert params == [] @@ -408,7 +498,12 @@ def test_sql_compilation_uses_raw_expression() -> None: compiled_sql, params = stmt.compile() mock_compile.assert_called_once_with( - stmt._statement_config, stmt._raw_sql, [], is_many=False, expression=expression + stmt._statement_config, + stmt._raw_sql, + [], + is_many=False, + expression=expression, + param_fingerprint=structural_fingerprint([], is_many=False), ) assert compiled_sql == "SELECT * FROM users" assert params == [] @@ -952,7 +1047,7 @@ def test_sql_where_method_compatibility() -> None: assert "WHERE" not in stmt._raw_sql - assert "WHERE" in where_stmt._raw_sql or "id > 10" in where_stmt._raw_sql + assert "WHERE" in where_stmt.raw_sql or "id > 10" in where_stmt.raw_sql def test_sql_where_method_with_expression() -> None: @@ -964,7 +1059,7 @@ def test_sql_where_method_with_expression() -> None: assert where_stmt is not stmt - assert where_stmt._raw_sql != stmt._raw_sql + assert where_stmt.raw_sql != stmt.raw_sql def test_sql_order_by_method_compatibility() -> None: @@ -976,7 +1071,7 @@ def test_sql_order_by_method_compatibility() -> None: assert "ORDER BY" not in stmt._raw_sql - assert "ORDER BY" in order_stmt._raw_sql or "id" in order_stmt._raw_sql + assert "ORDER BY" in order_stmt.raw_sql or "id" in order_stmt.raw_sql def test_sql_order_by_method_with_expression() -> None: @@ -987,7 +1082,7 @@ def test_sql_order_by_method_with_expression() -> None: assert order_stmt is not stmt - assert order_stmt._raw_sql != stmt._raw_sql + assert order_stmt.raw_sql != stmt.raw_sql def test_sql_order_by_method_without_parsing() -> None: @@ -997,7 +1092,7 @@ def test_sql_order_by_method_without_parsing() -> None: order_stmt = stmt.order_by("id DESC") - assert "ORDER BY" in order_stmt._raw_sql or "DESC" in order_stmt._raw_sql + assert "ORDER BY" in order_stmt.raw_sql or "DESC" in order_stmt.raw_sql def test_sql_filters_property_compatibility() -> None: @@ -1057,6 +1152,47 @@ def test_sql_has_errors_property_compatibility() -> None: assert stmt.has_errors is True +def test_sql_has_errors_does_not_copy_validation_errors() -> None: + """Test SQL.has_errors reads validation state without copying.""" + + class CopyTrackingList(list[str]): + def __init__(self, values: list[str]) -> None: + super().__init__(values) + self.copy_calls = 0 + + def copy(self) -> list[str]: + self.copy_calls += 1 + return super().copy() + + errors = CopyTrackingList(["boom"]) + stmt = SQL("SELECT * FROM invalid") + stmt._processed_state = ProcessedState(compiled_sql="", execution_parameters=None, validation_errors=errors) + + assert stmt.has_errors is True + assert errors.copy_calls == 0 + + +def test_handle_compile_failure_no_stderr(capfd: pytest.CaptureFixture[str]) -> None: + """Test compile failure fallback does not write directly to stderr.""" + stmt = SQL("SELECT 1") + + state = stmt._handle_compile_failure(RuntimeError("simulated pipeline failure")) + captured = capfd.readouterr() + + assert captured.err == "" + assert "simulated pipeline failure" in state.validation_errors + + +def test_handle_compile_failure_logs_at_debug(caplog: pytest.LogCaptureFixture) -> None: + """Test compile failure fallback logs through the statement logger.""" + stmt = SQL("SELECT 1") + + with caplog.at_level(logging.DEBUG, logger="sqlspec.core.statement"): + stmt._handle_compile_failure(RuntimeError("check debug log")) + + assert any("check debug log" in record.message for record in caplog.records) + + @requires_interpreted def test_sql_single_parse_guarantee() -> None: """Test SQL guarantees single parse operation.""" @@ -1356,10 +1492,9 @@ def test_processed_state_parameter_profile_exposed() -> None: def test_shared_pipeline_metrics_respects_debug_flag() -> None: """Shared pipeline metrics emit data only when debug flag is enabled.""" + import sqlspec.core.pipeline as pipeline_module - previous = os.environ.get("SQLSPEC_DEBUG_PIPELINE_CACHE") - os.environ["SQLSPEC_DEBUG_PIPELINE_CACHE"] = "1" - try: + with patch.object(pipeline_module, "_RECORD_PIPELINE_METRICS", True): reset_pipeline_registry() SQL("SELECT 1").compile() @@ -1368,12 +1503,8 @@ def test_shared_pipeline_metrics_respects_debug_flag() -> None: metrics = get_pipeline_metrics() total_hits = sum(entry.get("hits", 0) for entry in metrics) assert total_hits >= 1 - finally: - if previous is None: - os.environ.pop("SQLSPEC_DEBUG_PIPELINE_CACHE", None) - else: - os.environ["SQLSPEC_DEBUG_PIPELINE_CACHE"] = previous - reset_pipeline_registry() + + reset_pipeline_registry() def test_sql_builder_from_named_parameters() -> None: diff --git a/tests/unit/data_dictionary/test_registry.py b/tests/unit/data_dictionary/test_registry.py new file mode 100644 index 000000000..2887380f3 --- /dev/null +++ b/tests/unit/data_dictionary/test_registry.py @@ -0,0 +1,50 @@ +"""Regression tests for data dictionary registry loading state.""" + +from types import SimpleNamespace + +import sqlspec.data_dictionary._registry as registry + + +def test_dialects_loaded_annotation_is_bool() -> None: + """_DIALECTS_LOADED should be annotated to avoid Literal[False] narrowing.""" + assert registry.__annotations__["_DIALECTS_LOADED"] is bool + + +def test_load_default_dialects_sets_loaded_flag(monkeypatch) -> None: + """_load_default_dialects sets the loaded flag after importing dialects.""" + calls: list[str] = [] + monkeypatch.setattr(registry, "_DIALECTS_LOADED", False) + monkeypatch.setattr(registry.importlib, "import_module", lambda name: calls.append(name)) + + registry._load_default_dialects() + + assert calls == ["sqlspec.data_dictionary.dialects"] + assert registry._DIALECTS_LOADED is True + + +def test_load_default_dialects_is_idempotent(monkeypatch) -> None: + """_load_default_dialects should not import again after the flag is set.""" + monkeypatch.setattr(registry, "_DIALECTS_LOADED", True) + + def fail_import(name: str) -> None: + raise AssertionError(name) + + monkeypatch.setattr(registry.importlib, "import_module", fail_import) + + registry._load_default_dialects() + + +def test_get_dialect_config_triggers_load(monkeypatch) -> None: + """get_dialect_config calls the default dialect loader before lookup.""" + calls = 0 + config = SimpleNamespace(name="example") + monkeypatch.setitem(registry._DIALECT_CONFIGS, "example", config) + + def fake_load() -> None: + nonlocal calls + calls += 1 + + monkeypatch.setattr(registry, "_load_default_dialects", fake_load) + + assert registry.get_dialect_config("example") is config + assert calls == 1 diff --git a/tests/unit/dialects/test_spanner_ordering.py b/tests/unit/dialects/test_spanner_ordering.py new file mode 100644 index 000000000..937ce29c6 --- /dev/null +++ b/tests/unit/dialects/test_spanner_ordering.py @@ -0,0 +1,129 @@ +"""Regression tests for Spanner CREATE TABLE property ordering.""" + +from sqlglot import Dialect, exp, parse_one + +import sqlspec.dialects # noqa: F401 +from sqlspec.dialects.spanner._generators import _bq_create_transform, _is_post_schema_spanner_property + + +def _render_spanner(sql: str) -> str: + return parse_one(sql, dialect="spanner").sql(dialect="spanner") + + +def _assert_token_after_column_list(rendered: str, token: str) -> None: + assert f") {token}" in rendered + + +def test_interleave_appears_after_column_list() -> None: + rendered = _render_spanner( + """ + CREATE TABLE child ( + parent_id STRING(36), + child_id INT64, + PRIMARY KEY (parent_id, child_id) + ) INTERLEAVE IN PARENT parent_table ON DELETE CASCADE + """ + ) + + assert "INTERLEAVE IN PARENT parent_table ON DELETE CASCADE" in rendered + _assert_token_after_column_list(rendered, "INTERLEAVE IN PARENT") + + +def test_interleave_roundtrip_is_stable() -> None: + sql = """ + CREATE TABLE child ( + parent_id STRING(36), + child_id INT64, + PRIMARY KEY (parent_id, child_id) + ) INTERLEAVE IN PARENT parent_table ON DELETE CASCADE + """ + + first = _render_spanner(sql) + assert _render_spanner(first) == first + + +def test_row_deletion_policy_appears_after_column_list() -> None: + rendered = _render_spanner( + """ + CREATE TABLE logs ( + id STRING(36), + created_at TIMESTAMP, + PRIMARY KEY (id) + ) ROW DELETION POLICY (OLDER_THAN(created_at, INTERVAL 30 DAY)) + """ + ) + + assert "ROW DELETION POLICY" in rendered + _assert_token_after_column_list(rendered, "ROW DELETION POLICY") + + +def test_row_deletion_policy_roundtrip_is_stable() -> None: + sql = """ + CREATE TABLE logs ( + id STRING(36), + created_at TIMESTAMP, + PRIMARY KEY (id) + ) ROW DELETION POLICY (OLDER_THAN(created_at, INTERVAL 30 DAY)) + """ + + first = _render_spanner(sql) + assert _render_spanner(first) == first + + +def test_spanner_property_moves_after_non_spanner_property() -> None: + parsed = parse_one( + """ + CREATE TABLE child ( + parent_id STRING(36), + child_id INT64, + PRIMARY KEY (parent_id, child_id) + ) INTERLEAVE IN PARENT parent_table ON DELETE CASCADE + """, + dialect="spanner", + ) + assert isinstance(parsed, exp.Create) + properties = parsed.args["properties"] + spanner_property = properties.expressions[0] + non_spanner_property = exp.Property( + this=exp.Literal.string("enable_change_stream_capture"), value=exp.Boolean(this=True) + ) + assert _is_post_schema_spanner_property(spanner_property) + assert not _is_post_schema_spanner_property(non_spanner_property) + + properties.set("expressions", [spanner_property, non_spanner_property]) + generator = Dialect.get_or_raise("spanner").generator() + _bq_create_transform(generator, parsed) + + assert properties.expressions == [non_spanner_property, spanner_property] + + +def test_two_spanner_properties_order_preserved() -> None: + rendered = _render_spanner( + """ + CREATE TABLE child ( + parent_id STRING(36), + child_id INT64, + expires_at TIMESTAMP, + PRIMARY KEY (parent_id, child_id) + ) INTERLEAVE IN PARENT parent_table ON DELETE NO ACTION + ROW DELETION POLICY (OLDER_THAN(expires_at, INTERVAL 7 DAY)) + """ + ) + + assert rendered.find("INTERLEAVE IN PARENT") < rendered.find("ROW DELETION POLICY") + + +def test_two_spanner_properties_order_preserved_reverse_input() -> None: + rendered = _render_spanner( + """ + CREATE TABLE child ( + parent_id STRING(36), + child_id INT64, + expires_at TIMESTAMP, + PRIMARY KEY (parent_id, child_id) + ) ROW DELETION POLICY (OLDER_THAN(expires_at, INTERVAL 7 DAY)) + INTERLEAVE IN PARENT parent_table ON DELETE NO ACTION + """ + ) + + assert rendered.find("ROW DELETION POLICY") < rendered.find("INTERLEAVE IN PARENT") diff --git a/tests/unit/driver/test_query_cache.py b/tests/unit/driver/test_query_cache.py index 93e03f69f..7021a3c0d 100644 --- a/tests/unit/driver/test_query_cache.py +++ b/tests/unit/driver/test_query_cache.py @@ -6,7 +6,15 @@ import pytest -from sqlspec.core import SQL, OperationProfile, OperationType, ParameterProfile, ProcessedState +from sqlspec.core import ( + SQL, + OperationProfile, + OperationType, + ParameterInfo, + ParameterProfile, + ParameterStyle, + ProcessedState, +) from sqlspec.driver._query_cache import CachedQuery from sqlspec.exceptions import SQLSpecError @@ -17,6 +25,7 @@ def _make_cached( operation_type: OperationType = "SELECT", column_names: list[str] | None = None, operation_profile: OperationProfile | None = None, + parameter_profile: ParameterProfile | None = None, processed_state: ProcessedState | None = None, ) -> CachedQuery: if operation_profile is None: @@ -27,7 +36,7 @@ def _make_cached( ) return CachedQuery( compiled_sql=compiled_sql, - parameter_profile=ParameterProfile(), + parameter_profile=parameter_profile or ParameterProfile(), input_named_parameters=(), applied_wrap_types=False, parameter_casts={}, @@ -152,6 +161,35 @@ def test_sync_stmt_cache_execute_direct_re_raises_mapped_exception(sqlite_sync_d sqlite_sync_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) +def test_stmt_cache_prepare_direct_named_style_sets_needs_rebind(sqlite_sync_driver: Any, monkeypatch: Any) -> None: + profile = ParameterProfile([ParameterInfo("id", ParameterStyle.NAMED_COLON, 0, 0, ":id")]) + cached = _make_cached( + compiled_sql="SELECT :id", + param_count=1, + parameter_profile=profile, + processed_state=ProcessedState( + compiled_sql="SELECT :id", execution_parameters=[1], operation_type="SELECT", parameter_profile=profile + ), + ) + sqlite_sync_driver._stmt_cache.set("SELECT :id", cached) + sqlite_sync_driver._stmt_cache_enabled = True + called = False + + def _fake_rebind(params: tuple[Any, ...] | list[Any], cached_query: CachedQuery) -> tuple[Any, ...] | list[Any]: + nonlocal called + called = True + assert params == (1,) + assert cached_query is cached + return params + + monkeypatch.setattr(sqlite_sync_driver, "stmt_cache_rebind", _fake_rebind) + + prepared = sqlite_sync_driver._stmt_cache_prepare_direct("SELECT :id", (1,)) + + assert prepared is not None + assert called is True + + @pytest.mark.anyio async def test_async_execute_uses_fast_path_when_eligible(aiosqlite_async_driver: Any, monkeypatch: Any) -> None: sentinel = object() diff --git a/tests/unit/exceptions/test_wrap_exceptions.py b/tests/unit/exceptions/test_wrap_exceptions.py new file mode 100644 index 000000000..3d3aec55d --- /dev/null +++ b/tests/unit/exceptions/test_wrap_exceptions.py @@ -0,0 +1,46 @@ +"""Regression tests for wrap_exceptions.""" + +import inspect + +import pytest + +from sqlspec.exceptions import RepositoryError, SQLSpecError, wrap_exceptions + + +def test_wrap_exceptions_suppresses_single_type() -> None: + """suppress= silently swallows matching exceptions.""" + with wrap_exceptions(suppress=ValueError): + raise ValueError("suppressed") + + +def test_wrap_exceptions_suppresses_tuple_of_types() -> None: + """suppress=(, ...) silently swallows matching exceptions.""" + with wrap_exceptions(suppress=(ValueError, TypeError)): + raise TypeError("suppressed") + + +def test_wrap_exceptions_wraps_unmatched_suppressed_type() -> None: + """Non-matching exceptions are still wrapped.""" + with pytest.raises(RepositoryError): + with wrap_exceptions(suppress=ValueError): + raise RuntimeError("not suppressed") + + +def test_wrap_exceptions_sqlspec_error_passes_through() -> None: + """SQLSpecError is reraised when it is not explicitly suppressed.""" + original = SQLSpecError("already mapped") + + with pytest.raises(SQLSpecError) as exc_info: + with wrap_exceptions(): + raise original + + assert exc_info.value is original + + +def test_wrap_exceptions_does_not_guard_suppress_classinfo_shape() -> None: + """wrap_exceptions should delegate classinfo handling directly to isinstance.""" + source = inspect.getsource(wrap_exceptions) + + assert "isinstance(suppress, type)" not in source + assert "isinstance(suppress, tuple)" not in source + assert "isinstance(exc, suppress)" in source diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index cf866bef1..76025da6a 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -9,7 +9,7 @@ """ import importlib.util -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import pytest @@ -22,6 +22,7 @@ from google.genai import types from sqlspec.extensions.adk.converters import ( + compute_update_marker, event_to_record, filter_temp_state, merge_scoped_state, @@ -215,6 +216,29 @@ def test_merge_scoped_state_does_not_mutate_session_state() -> None: assert session == original +def test_merge_scoped_state_accepts_explicit_empty_app_state() -> None: + app_state: dict[str, object] = {} + merged = merge_scoped_state(session_state={"session": "value"}, app_state=app_state) + assert merged == {"session": "value"} + + +def test_merge_scoped_state_accepts_explicit_empty_user_state() -> None: + user_state: dict[str, object] = {} + merged = merge_scoped_state(session_state={"session": "value"}, user_state=user_state) + assert merged == {"session": "value"} + + +def test_compute_update_marker_normalizes_naive_datetime_to_utc() -> None: + marker = compute_update_marker(datetime(2024, 1, 15, 12, 0, 0)) + assert marker == "2024-01-15T12:00:00.000000+00:00" + + +def test_compute_update_marker_normalizes_aware_datetime_to_utc() -> None: + eastern = timezone(timedelta(hours=-5)) + marker = compute_update_marker(datetime(2024, 1, 15, 7, 0, 0, tzinfo=eastern)) + assert marker == "2024-01-15T12:00:00.000000+00:00" + + # --------------------------------------------------------------------------- # event_to_record — signature and structure # --------------------------------------------------------------------------- diff --git a/tests/unit/extensions/test_events/test_channel_nack_guard.py b/tests/unit/extensions/test_events/test_channel_nack_guard.py new file mode 100644 index 000000000..f2d8669b0 --- /dev/null +++ b/tests/unit/extensions/test_events/test_channel_nack_guard.py @@ -0,0 +1,67 @@ +# pyright: reportPrivateUsage=false, reportAttributeAccessIssue=false +"""Regression tests for event channel nack capability guards.""" + +import pytest + +from sqlspec.adapters.aiosqlite import AiosqliteConfig +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.events import AsyncEventChannel, SyncEventChannel + + +class _SyncBackend: + backend_name = "sync-test" + + def __init__(self, *, supports_sync: bool) -> None: + self.supports_sync = supports_sync + self.nacked: list[str] = [] + + def nack(self, event_id: str) -> None: + self.nacked.append(event_id) + + +class _AsyncBackend: + backend_name = "async-test" + + def __init__(self, *, supports_async: bool) -> None: + self.supports_async = supports_async + self.nacked: list[str] = [] + + async def nack(self, event_id: str) -> None: + self.nacked.append(event_id) + + +def test_sync_event_channel_nack_rejects_backend_without_sync_support(tmp_path) -> None: + channel = SyncEventChannel(SqliteConfig(connection_config={"database": str(tmp_path / "events.db")})) + channel._backend = _SyncBackend(supports_sync=False) # type: ignore[assignment] + + with pytest.raises(ImproperConfigurationError, match="does not support sync nack"): + channel.nack("event-1") + + +def test_sync_event_channel_nack_delegates_when_supported(tmp_path) -> None: + backend = _SyncBackend(supports_sync=True) + channel = SyncEventChannel(SqliteConfig(connection_config={"database": str(tmp_path / "events.db")})) + channel._backend = backend # type: ignore[assignment] + + channel.nack("event-1") + + assert backend.nacked == ["event-1"] + + +async def test_async_event_channel_nack_rejects_backend_without_async_support(tmp_path) -> None: + channel = AsyncEventChannel(AiosqliteConfig(connection_config={"database": str(tmp_path / "events.db")})) + channel._backend = _AsyncBackend(supports_async=False) # type: ignore[assignment] + + with pytest.raises(ImproperConfigurationError, match="does not support async nack"): + await channel.nack("event-1") + + +async def test_async_event_channel_nack_delegates_when_supported(tmp_path) -> None: + backend = _AsyncBackend(supports_async=True) + channel = AsyncEventChannel(AiosqliteConfig(connection_config={"database": str(tmp_path / "events.db")})) + channel._backend = backend # type: ignore[assignment] + + await channel.nack("event-1") + + assert backend.nacked == ["event-1"] diff --git a/tests/unit/extensions/test_flask/test_utils.py b/tests/unit/extensions/test_flask/test_utils.py new file mode 100644 index 000000000..2c8ec3f76 --- /dev/null +++ b/tests/unit/extensions/test_flask/test_utils.py @@ -0,0 +1,64 @@ +"""Tests for Flask extension utility helpers.""" + +from typing import Any + +import pytest + +pytest.importorskip("flask") + +from flask import Flask, g + +from sqlspec.extensions.flask import FlaskConfigState +from sqlspec.extensions.flask._utils import get_or_create_session + + +class _Driver: + def __init__(self, *, connection: Any, statement_config: Any, driver_features: dict[str, Any]) -> None: + self.connection = connection + self.statement_config = statement_config + self.driver_features = driver_features + + +class _Config: + driver_type = _Driver + driver_features = {"returning_support": True} + statement_config = object() + + +def _make_state() -> FlaskConfigState: + return FlaskConfigState( + config=_Config(), # type: ignore[arg-type] + connection_key="sqlspec_connection", + session_key="db_session", + commit_mode="manual", + extra_commit_statuses=None, + extra_rollback_statuses=None, + is_async=False, + disable_di=False, + ) + + +def test_get_or_create_session_passes_driver_features() -> None: + app = Flask(__name__) + state = _make_state() + connection = object() + + with app.app_context(): + setattr(g, state.connection_key, connection) + session = get_or_create_session(state, portal=None) + + assert session.connection is connection + assert session.statement_config is _Config.statement_config + assert session.driver_features == _Config.driver_features + + +def test_get_or_create_session_returns_cached_session() -> None: + app = Flask(__name__) + state = _make_state() + + with app.app_context(): + setattr(g, state.connection_key, object()) + first = get_or_create_session(state, portal=None) + second = get_or_create_session(state, portal=None) + + assert second is first diff --git a/tests/unit/extensions/test_litestar/test_correlation_middleware.py b/tests/unit/extensions/test_litestar/test_correlation_middleware.py index 9b610a50e..399861fa1 100644 --- a/tests/unit/extensions/test_litestar/test_correlation_middleware.py +++ b/tests/unit/extensions/test_litestar/test_correlation_middleware.py @@ -1,10 +1,33 @@ """Tests for Litestar correlation middleware behavior.""" +from collections.abc import Awaitable, Callable from typing import Any, cast -from sqlspec.extensions.litestar.plugin import CorrelationMiddleware +import pytest + +from sqlspec.extensions.litestar.plugin import CORRELATION_STATE_KEY, CorrelationMiddleware, get_sqlspec_scope_state from sqlspec.utils.correlation import CorrelationContext +pytestmark = pytest.mark.anyio + + +async def _noop_receive() -> dict[str, str]: + return {"type": "http.request"} + + +async def _noop_send(_message: Any) -> None: + return None + + +async def _call_middleware( + scope: dict[str, Any], + app: Callable[[Any, Any, Any], Awaitable[None]], + *, + headers: tuple[str, ...] = ("x-request-id",), +) -> None: + middleware = CorrelationMiddleware(app, headers=headers) + await middleware(cast("Any", scope), cast("Any", _noop_receive), cast("Any", _noop_send)) + async def test_litestar_correlation_middleware_restores_previous_correlation_id() -> None: CorrelationContext.set("outer") @@ -13,18 +36,74 @@ async def test_litestar_correlation_middleware_restores_previous_correlation_id( async def app(_scope: Any, _receive: Any, _send: Any) -> None: seen["cid"] = CorrelationContext.get() - middleware = CorrelationMiddleware(app, headers=("x-request-id",)) scope = {"type": "http", "headers": [(b"x-request-id", b"inner")]} - async def receive() -> Any: - return {"type": "http.request"} + try: + await _call_middleware(scope, app) + assert seen["cid"] == "inner" + assert CorrelationContext.get() == "outer" + finally: + CorrelationContext.clear() + + +def test_correlation_middleware_slots_include_cached_extractor() -> None: + assert CorrelationMiddleware.__slots__ == ("_app", "_extractor", "_headers") + + +async def test_litestar_correlation_middleware_sanitizes_header_value() -> None: + seen: dict[str, Any] = {} + raw_value = f" {'x' * 200} " + + async def app(scope: Any, _receive: Any, _send: Any) -> None: + seen["cid"] = CorrelationContext.get() + seen["scope_cid"] = get_sqlspec_scope_state(scope, CORRELATION_STATE_KEY) + + scope = {"type": "http", "headers": [(b"x-request-id", raw_value.encode())]} + + await _call_middleware(scope, app) + + assert seen["cid"] == "x" * 128 + assert seen["scope_cid"] == "x" * 128 + assert get_sqlspec_scope_state(scope, CORRELATION_STATE_KEY) is None + - async def send(_message: Any) -> None: - return None +async def test_litestar_correlation_middleware_uses_additional_header() -> None: + seen: dict[str, Any] = {} + + async def app(_scope: Any, _receive: Any, _send: Any) -> None: + seen["cid"] = CorrelationContext.get() + + scope = {"type": "http", "headers": [(b"x-correlation-id", b"secondary")]} + + await _call_middleware(scope, app, headers=("x-request-id", "x-correlation-id")) + + assert seen["cid"] == "secondary" + + +async def test_litestar_correlation_middleware_generates_id_when_headers_missing() -> None: + seen: dict[str, Any] = {} + + async def app(_scope: Any, _receive: Any, _send: Any) -> None: + seen["cid"] = CorrelationContext.get() + + scope = {"type": "http", "headers": []} + + await _call_middleware(scope, app) + + assert isinstance(seen["cid"], str) + assert seen["cid"] + + +async def test_litestar_correlation_middleware_skips_non_http_scope() -> None: + CorrelationContext.set("outer") + seen: dict[str, Any] = {} + + async def app(_scope: Any, _receive: Any, _send: Any) -> None: + seen["cid"] = CorrelationContext.get() try: - await middleware(cast("Any", scope), cast("Any", receive), cast("Any", send)) - assert seen["cid"] == "inner" + await _call_middleware({"type": "websocket", "headers": [(b"x-request-id", b"inner")]}, app) + assert seen["cid"] == "outer" assert CorrelationContext.get() == "outer" finally: CorrelationContext.clear() diff --git a/tests/unit/extensions/test_litestar/test_handlers.py b/tests/unit/extensions/test_litestar/test_handlers.py index a9b6bff9b..a4a71eb51 100644 --- a/tests/unit/extensions/test_litestar/test_handlers.py +++ b/tests/unit/extensions/test_litestar/test_handlers.py @@ -22,6 +22,8 @@ if TYPE_CHECKING: from litestar.types import Message, Scope +pytestmark = pytest.mark.anyio + async def test_async_manual_handler_closes_connection() -> None: """Test async manual handler closes connection on terminus event.""" diff --git a/tests/unit/extensions/test_litestar/test_on_app_init_middleware.py b/tests/unit/extensions/test_litestar/test_on_app_init_middleware.py new file mode 100644 index 000000000..3e684b8ac --- /dev/null +++ b/tests/unit/extensions/test_litestar/test_on_app_init_middleware.py @@ -0,0 +1,76 @@ +"""Tests for SQLSpecPlugin Litestar middleware assembly.""" + +from litestar.config.app import AppConfig +from litestar.middleware import DefineMiddleware + +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +from sqlspec.base import SQLSpec +from sqlspec.core import StatementConfig +from sqlspec.extensions.litestar.plugin import CorrelationMiddleware, SQLCommenterMiddleware, SQLSpecPlugin + + +class _ExistingMiddleware: + pass + + +def _build_plugin(*, correlation: bool = False, sqlcommenter: bool = False) -> SQLSpecPlugin: + sqlspec = SQLSpec() + sqlspec.add_config( + AiosqliteConfig( + connection_config={"database": ":memory:"}, + statement_config=StatementConfig(enable_sqlcommenter=sqlcommenter), + extension_config={ + "litestar": { + "enable_correlation_middleware": correlation, + "enable_sqlcommenter_middleware": sqlcommenter, + } + }, + ) + ) + return SQLSpecPlugin(sqlspec=sqlspec) + + +def _middleware_types(app_config: AppConfig) -> list[type[object]]: + return [middleware.middleware for middleware in app_config.middleware or []] + + +def test_on_app_init_appends_both_middlewares_when_enabled() -> None: + app_config = AppConfig() + + _build_plugin(correlation=True, sqlcommenter=True).on_app_init(app_config) + + assert _middleware_types(app_config)[-2:] == [CorrelationMiddleware, SQLCommenterMiddleware] + + +def test_on_app_init_appends_only_correlation_middleware() -> None: + app_config = AppConfig() + + _build_plugin(correlation=True, sqlcommenter=False).on_app_init(app_config) + + assert _middleware_types(app_config) == [CorrelationMiddleware] + + +def test_on_app_init_appends_only_sqlcommenter_middleware() -> None: + app_config = AppConfig() + + _build_plugin(correlation=False, sqlcommenter=True).on_app_init(app_config) + + assert _middleware_types(app_config) == [SQLCommenterMiddleware] + + +def test_on_app_init_appends_no_observability_middlewares_when_disabled() -> None: + app_config = AppConfig() + + _build_plugin(correlation=False, sqlcommenter=False).on_app_init(app_config) + + assert app_config.middleware == [] + + +def test_on_app_init_preserves_existing_middlewares() -> None: + existing = DefineMiddleware(_ExistingMiddleware) + app_config = AppConfig(middleware=[existing]) + + _build_plugin(correlation=True, sqlcommenter=True).on_app_init(app_config) + + assert app_config.middleware[0] is existing + assert _middleware_types(app_config)[1:] == [CorrelationMiddleware, SQLCommenterMiddleware] diff --git a/tests/unit/extensions/test_litestar/test_raise_missing_connection.py b/tests/unit/extensions/test_litestar/test_raise_missing_connection.py new file mode 100644 index 000000000..25d5ab792 --- /dev/null +++ b/tests/unit/extensions/test_litestar/test_raise_missing_connection.py @@ -0,0 +1,73 @@ +"""Regression tests for Litestar missing-connection control flow.""" + +from typing import NoReturn + +from litestar.config.app import AppConfig + +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +from sqlspec.base import SQLSpec +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.litestar.plugin import SQLSpecPlugin + + +def _build_plugin() -> SQLSpecPlugin: + sqlspec = SQLSpec() + sqlspec.add_config(AiosqliteConfig(connection_config={"database": ":memory:"})) + return SQLSpecPlugin(sqlspec=sqlspec) + + +def test_raise_missing_connection_always_raises() -> None: + plugin = _build_plugin() + + try: + plugin._raise_missing_connection("db_connection") # pyright: ignore[reportPrivateUsage] + except ImproperConfigurationError as exc: + assert "db_connection" in str(exc) + else: # pragma: no cover - defensive assertion + msg = "_raise_missing_connection should not return" + raise AssertionError(msg) + + +def test_get_plugin_state_raises_for_unknown_key() -> None: + plugin = _build_plugin() + plugin.on_app_init(AppConfig()) + + try: + plugin._get_plugin_state("unknown") # pyright: ignore[reportPrivateUsage] + except KeyError as exc: + assert "unknown" in str(exc) + else: # pragma: no cover - defensive assertion + msg = "_get_plugin_state should raise for unknown keys" + raise AssertionError(msg) + + +def test_provide_request_connection_raises_when_connection_missing() -> None: + plugin = _build_plugin() + state = AppConfig().state + scope = {"type": "http"} + + try: + plugin.provide_request_connection("db_connection", state, scope) # type: ignore[arg-type] + except ImproperConfigurationError as exc: + assert "db_connection" in str(exc) + else: # pragma: no cover - defensive assertion + msg = "provide_request_connection should raise when no connection is scoped" + raise AssertionError(msg) + + +def test_provide_request_session_raises_when_connection_missing() -> None: + plugin = _build_plugin() + state = AppConfig().state + scope = {"type": "http"} + + try: + plugin.provide_request_session("db_connection", state, scope) # type: ignore[arg-type] + except ImproperConfigurationError as exc: + assert "db_connection" in str(exc) + else: # pragma: no cover - defensive assertion + msg = "provide_request_session should raise when no connection is scoped" + raise AssertionError(msg) + + +def test_raise_missing_connection_annotation_is_noreturn() -> None: + assert SQLSpecPlugin._raise_missing_connection.__annotations__["return"] is NoReturn diff --git a/tests/unit/extensions/test_sanic/test_extractor_init.py b/tests/unit/extensions/test_sanic/test_extractor_init.py new file mode 100644 index 000000000..1c222ee58 --- /dev/null +++ b/tests/unit/extensions/test_sanic/test_extractor_init.py @@ -0,0 +1,63 @@ +"""Tests for Sanic correlation extractor lifecycle.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from sqlspec.core import CorrelationExtractor +from sqlspec.extensions.sanic import SQLSpecPlugin +from sqlspec.utils.correlation import CorrelationContext + +pytestmark = pytest.mark.anyio + + +def _make_config(*, enable_correlation: bool = False) -> MagicMock: + config = MagicMock() + config.supports_connection_pooling = True + config.extension_config = { + "sanic": { + "disable_di": True, + "enable_correlation_middleware": enable_correlation, + "enable_sqlcommenter_middleware": False, + } + } + return config + + +def _make_request(correlation_id: str) -> SimpleNamespace: + return SimpleNamespace( + app=SimpleNamespace(ctx=SimpleNamespace()), + ctx=SimpleNamespace(), + endpoint=None, + headers={"x-request-id": correlation_id}, + path="/test", + uri_template=None, + ) + + +def test_sanic_plugin_extractor_is_none_when_correlation_disabled() -> None: + plugin = SQLSpecPlugin(MagicMock(configs={"default": _make_config(enable_correlation=False)})) + + assert plugin._extractor is None # pyright: ignore[reportPrivateUsage] + + +def test_sanic_plugin_builds_extractor_when_correlation_enabled() -> None: + plugin = SQLSpecPlugin(MagicMock(configs={"default": _make_config(enable_correlation=True)})) + + assert isinstance(plugin._extractor, CorrelationExtractor) # pyright: ignore[reportPrivateUsage] + + +async def test_sanic_plugin_reuses_same_extractor_across_requests() -> None: + plugin = SQLSpecPlugin(MagicMock(configs={"default": _make_config(enable_correlation=True)})) + extractor_id = id(plugin._extractor) # pyright: ignore[reportPrivateUsage] + + try: + for index in range(5): + request = _make_request(f"request-{index}") + plugin._set_correlation_context(request) # pyright: ignore[reportPrivateUsage] + assert CorrelationContext.get() == f"request-{index}" + assert id(plugin._extractor) == extractor_id # pyright: ignore[reportPrivateUsage] + plugin._restore_correlation_context(request, SimpleNamespace(headers={})) # pyright: ignore[reportPrivateUsage] + finally: + CorrelationContext.clear() diff --git a/tests/unit/extensions/test_sanic/test_lifecycle.py b/tests/unit/extensions/test_sanic/test_lifecycle.py index 70ab820e1..c34dedfdb 100644 --- a/tests/unit/extensions/test_sanic/test_lifecycle.py +++ b/tests/unit/extensions/test_sanic/test_lifecycle.py @@ -13,6 +13,8 @@ from sqlspec.extensions.sanic import SQLSpecPlugin from sqlspec.extensions.sanic._utils import has_context_value +pytestmark = pytest.mark.anyio + async def test_startup_creates_pool_on_app_context() -> None: """before_server_start should create configured pools on app.ctx.""" diff --git a/tests/unit/extensions/test_sanic/test_middleware.py b/tests/unit/extensions/test_sanic/test_middleware.py index e3eeba544..785f182f2 100644 --- a/tests/unit/extensions/test_sanic/test_middleware.py +++ b/tests/unit/extensions/test_sanic/test_middleware.py @@ -14,6 +14,8 @@ from sqlspec.extensions.sanic import SQLSpecPlugin from sqlspec.extensions.sanic._utils import has_context_value +pytestmark = pytest.mark.anyio + async def test_request_middleware_acquires_pooled_connection() -> None: """Request middleware should expose pooled connections on request.ctx.""" diff --git a/tests/unit/extensions/test_sanic/test_observability.py b/tests/unit/extensions/test_sanic/test_observability.py index e547f1967..ee694c344 100644 --- a/tests/unit/extensions/test_sanic/test_observability.py +++ b/tests/unit/extensions/test_sanic/test_observability.py @@ -14,6 +14,8 @@ from sqlspec.extensions.sanic import SQLSpecPlugin from sqlspec.utils.correlation import CorrelationContext +pytestmark = pytest.mark.anyio + async def test_correlation_context_is_set_and_restored() -> None: """Correlation middleware should use request headers and restore prior context.""" diff --git a/tests/unit/loader/test_cache_integration.py b/tests/unit/loader/test_cache_integration.py index 76bc79555..20618389a 100644 --- a/tests/unit/loader/test_cache_integration.py +++ b/tests/unit/loader/test_cache_integration.py @@ -175,7 +175,7 @@ def test_cache_hit_scenario(mock_cache_setup: tuple[Mock, Mock, SQLFileLoader]) mock_cache.get_file.return_value = cached_file - with patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True): + with patch("sqlspec.loader.SQLFileLoader._is_file_content_unchanged", return_value=True): loader._load_single_file(tf.name, None) mock_cache.get_file.assert_called_once() @@ -231,7 +231,7 @@ def test_cache_invalidation_on_file_change(mock_cache_setup: tuple[Mock, Mock, S mock_cache.get_file.return_value = cached_file - with patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=False): + with patch("sqlspec.loader.SQLFileLoader._is_file_content_unchanged", return_value=False): loader._load_single_file(tf.name, None) mock_cache.get_file.assert_called_once() @@ -376,7 +376,7 @@ def test_cache_restoration_with_namespace(tmp_path: Path) -> None: with ( patch("sqlspec.loader.get_cache_config") as mock_config, patch("sqlspec.loader.get_cache") as mock_cache_factory, - patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True), + patch("sqlspec.loader.SQLFileLoader._is_file_content_unchanged", return_value=True), ): mock_cache_config = Mock() mock_cache_config.compiled_cache_enabled = True @@ -484,7 +484,7 @@ def test_cache_sharing_between_loaders() -> None: loader1 = SQLFileLoader() shared_cache.get_file.return_value = None - with patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True): + with patch("sqlspec.loader.SQLFileLoader._is_file_content_unchanged", return_value=True): loader1._load_single_file(tf.name, None) shared_cache.put_file.assert_called_once() @@ -498,7 +498,7 @@ def test_cache_sharing_between_loaders() -> None: shared_cache.reset_mock() shared_cache.get_file.return_value = cached_file - with patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True): + with patch("sqlspec.loader.SQLFileLoader._is_file_content_unchanged", return_value=True): loader2._load_single_file(tf.name, None) shared_cache.get_file.assert_called_once() @@ -617,7 +617,7 @@ def test_cache_hit_performance_benefit() -> None: mock_cache.get_file.return_value = cached_file mock_cache.reset_mock() - with patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True): + with patch("sqlspec.loader.SQLFileLoader._is_file_content_unchanged", return_value=True): loader2._load_single_file(tf.name, None) assert len(loader2._queries) == 100 diff --git a/tests/unit/loader/test_get_sql_caching.py b/tests/unit/loader/test_get_sql_caching.py new file mode 100644 index 000000000..bd7808d2f --- /dev/null +++ b/tests/unit/loader/test_get_sql_caching.py @@ -0,0 +1,65 @@ +"""Regression tests for SQLFileLoader.get_sql compiled statement caching.""" + +from collections.abc import Callable + +from sqlspec.core import SQL +from sqlspec.loader import SQLFileLoader + + +def test_get_sql_compiles_once_per_statement(monkeypatch) -> None: + """Repeated get_sql calls should reuse the compiled SQL object.""" + compile_calls = 0 + original_compile: Callable[[SQL], tuple[str, object]] = SQL.compile + + def counting_compile(sql: SQL) -> tuple[str, object]: + nonlocal compile_calls + compile_calls += 1 + return original_compile(sql) + + monkeypatch.setattr(SQL, "compile", counting_compile) + + loader = SQLFileLoader() + loader.add_named_sql("find-user", "SELECT * FROM users WHERE id = :id") + + first = loader.get_sql("find-user") + second = loader.get_sql("find_user") + + assert first is second + assert compile_calls == 1 + + +def test_get_sql_compiles_each_unique_statement_once(monkeypatch) -> None: + """The compiled SQL cache is keyed by normalized statement name.""" + compile_calls = 0 + original_compile: Callable[[SQL], tuple[str, object]] = SQL.compile + + def counting_compile(sql: SQL) -> tuple[str, object]: + nonlocal compile_calls + compile_calls += 1 + return original_compile(sql) + + monkeypatch.setattr(SQL, "compile", counting_compile) + + loader = SQLFileLoader() + loader.add_named_sql("find-user", "SELECT * FROM users WHERE id = :id") + loader.add_named_sql("list-users", "SELECT * FROM users") + + loader.get_sql("find-user") + loader.get_sql("find-user") + loader.get_sql("list-users") + loader.get_sql("list-users") + + assert compile_calls == 2 + + +def test_clear_cache_clears_compiled_statements() -> None: + """clear_cache clears the compiled SQL object cache.""" + loader = SQLFileLoader() + loader.add_named_sql("find-user", "SELECT * FROM users WHERE id = :id") + + loader.get_sql("find-user") + assert loader._compiled_statements + + loader.clear_cache() + + assert loader._compiled_statements == {} diff --git a/tests/unit/loader/test_sql_file_loader.py b/tests/unit/loader/test_sql_file_loader.py index ad0319d90..3aaba1d84 100644 --- a/tests/unit/loader/test_sql_file_loader.py +++ b/tests/unit/loader/test_sql_file_loader.py @@ -50,6 +50,78 @@ def test_named_statement_no_dialect() -> None: assert stmt.start_line == 0 +def test_compute_checksum_matches_sqlfile_checksum() -> None: + """_compute_checksum must produce the same hash that SQLFile stores.""" + content = "SELECT 1;" + + assert SQLFileLoader._compute_checksum(content) == SQLFile(content=content, path="test.sql").checksum + + +def test_is_file_content_unchanged_matching() -> None: + """_is_file_content_unchanged returns True when content matches cached checksum.""" + content = "SELECT * FROM users;" + cached = SQLFileCacheEntry(sql_file=SQLFile(content=content, path="test.sql"), parsed_statements={}) + + assert SQLFileLoader()._is_file_content_unchanged(content, cached) is True + + +def test_load_file_without_cache_accepts_pre_read_content(monkeypatch, tmp_path: Path) -> None: + """_load_file_without_cache uses pre-read content without reading the file again.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("-- name: from_disk\nSELECT 0;") + loader = SQLFileLoader() + + def fail_read(self: SQLFileLoader, path: object) -> str: + msg = f"unexpected read: {path}" + raise AssertionError(msg) + + monkeypatch.setattr(SQLFileLoader, "_read_file_content", fail_read) + + loader._load_file_without_cache(sql_file, None, content="-- name: from_memory\nSELECT 1;") + + assert loader.has_query("from_memory") + + +def test_load_single_file_reads_once_on_stale_cache(monkeypatch, tmp_path: Path) -> None: + """A stale cache entry should not cause a checksum read and a parse read.""" + sql_file = tmp_path / "queries.sql" + original = "-- name: get_user\nSELECT id FROM users WHERE id = :id;" + updated = "-- name: get_user\nSELECT id, name FROM users WHERE id = :id;" + sql_file.write_text(updated) + cached = SQLFileCacheEntry( + sql_file=SQLFile(content=original, path=str(sql_file)), + parsed_statements={"get_user": NamedStatement("get_user", "SELECT id FROM users WHERE id = :id;")}, + ) + + class Cache: + def get_file(self, key: str) -> SQLFileCacheEntry: + return cached + + def put_file(self, key: str, value: object) -> None: + return None + + monkeypatch.setattr("sqlspec.loader.get_cache", Cache) + monkeypatch.setattr( + "sqlspec.loader.get_cache_config", lambda: type("Config", (), {"compiled_cache_enabled": True})() + ) + + loader = SQLFileLoader() + read_count = 0 + original_read = SQLFileLoader._read_file_content + + def counting_read(self: SQLFileLoader, path: str | Path) -> str: + nonlocal read_count + read_count += 1 + return original_read(self, path) + + monkeypatch.setattr(SQLFileLoader, "_read_file_content", counting_read) + + loader._load_single_file(sql_file, None) + + assert read_count == 1 + assert loader.get_query_text("get_user") == "SELECT id, name FROM users WHERE id = :id;" + + @requires_interpreted def test_named_statement_slots() -> None: """Test that NamedStatement uses __slots__.""" diff --git a/tests/unit/migrations/test_commands_contract.py b/tests/unit/migrations/test_commands_contract.py new file mode 100644 index 000000000..d4ff42ecc --- /dev/null +++ b/tests/unit/migrations/test_commands_contract.py @@ -0,0 +1,24 @@ +"""Contract tests for migration command signatures.""" + +import inspect + +from sqlspec.migrations.base import BaseMigrationCommands +from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands + + +def test_base_revision_file_type_default_matches_concrete_commands() -> None: + base_param = inspect.signature(BaseMigrationCommands.revision).parameters["file_type"] + sync_param = inspect.signature(SyncMigrationCommands.revision).parameters["file_type"] + async_param = inspect.signature(AsyncMigrationCommands.revision).parameters["file_type"] + + assert base_param.default is None + assert sync_param.default is None + assert async_param.default is None + + +def test_upgrade_documents_config_auto_sync_as_unoverridable() -> None: + sync_source = inspect.getsource(SyncMigrationCommands.upgrade) + async_source = inspect.getsource(AsyncMigrationCommands.upgrade) + + assert "config auto_sync=False cannot be overridden by the call-site flag" in sync_source + assert "config auto_sync=False cannot be overridden by the call-site flag" in async_source diff --git a/tests/unit/observability/test_observability_core.py b/tests/unit/observability/test_observability_core.py index b3d9675f9..80b35d000 100644 --- a/tests/unit/observability/test_observability_core.py +++ b/tests/unit/observability/test_observability_core.py @@ -603,10 +603,20 @@ def test_resolve_db_system_unknown() -> None: def test_compute_sql_hash() -> None: sql = "SELECT 1" - expected = hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16] + expected = hashlib.sha256(sql.encode("utf-8"), usedforsecurity=False).hexdigest()[:16] assert compute_sql_hash(sql) == expected +def test_compute_sql_hash_returns_16_char_hex_string() -> None: + first_hash = compute_sql_hash("SELECT 1") + second_hash = compute_sql_hash("SELECT 2") + + assert len(first_hash) == 16 + assert first_hash != second_hash + assert first_hash.islower() + assert all(character in "0123456789abcdef" for character in first_hash) + + def test_get_trace_context_without_otel(monkeypatch) -> None: real_import = builtins.__import__ diff --git a/tests/unit/observability/test_observability_sampling.py b/tests/unit/observability/test_observability_sampling.py index f587d6a9e..1958de398 100644 --- a/tests/unit/observability/test_observability_sampling.py +++ b/tests/unit/observability/test_observability_sampling.py @@ -3,6 +3,8 @@ import pytest from sqlspec.observability import AWSLogFormatter, GCPLogFormatter, ObservabilityConfig, SamplingConfig, StatementEvent +from sqlspec.observability._config import LifecycleHook as ConfigLifecycleHook +from sqlspec.observability._dispatcher import LifecycleHook as DispatcherLifecycleHook # ============================================================================= # ObservabilityConfig Sampling Tests @@ -101,6 +103,24 @@ def test_sampling_merge_partial_override() -> None: assert merged.sampling.force_sample_slow_queries_ms == 100.0 +def test_lifecycle_hook_alias_matches_dispatcher() -> None: + assert ConfigLifecycleHook is DispatcherLifecycleHook + + +def test_sampling_merge_allows_false_boolean_overrides() -> None: + base = ObservabilityConfig(sampling=SamplingConfig(sample_rate=0.5, deterministic=True, force_sample_on_error=True)) + override = ObservabilityConfig( + sampling=SamplingConfig(sample_rate=1.0, deterministic=False, force_sample_on_error=False) + ) + + merged = ObservabilityConfig.merge(base, override) + + assert merged.sampling is not None + assert merged.sampling.sample_rate == 1.0 + assert merged.sampling.deterministic is False + assert merged.sampling.force_sample_on_error is False + + # ============================================================================= # ObservabilityConfig Sampling Equality Tests # ============================================================================= diff --git a/tests/unit/observability/test_runtime_sampling.py b/tests/unit/observability/test_runtime_sampling.py new file mode 100644 index 000000000..f1c3594ca --- /dev/null +++ b/tests/unit/observability/test_runtime_sampling.py @@ -0,0 +1,83 @@ +"""Regression tests for runtime statement-event sampling.""" + +from typing import Any, cast + +from sqlspec.observability import ObservabilityConfig, ObservabilityRuntime, SamplingConfig, StatementEvent +from sqlspec.observability._config import StatementObserver +from sqlspec.utils.correlation import CorrelationContext + + +def _emit_statement(runtime: ObservabilityRuntime, *, duration_s: float = 0.001) -> None: + runtime.emit_statement_event( + sql="SELECT 1", + parameters=None, + driver="DummyDriver", + operation="SELECT", + execution_mode="single", + is_many=False, + is_script=False, + rows_affected=1, + duration_s=duration_s, + storage_backend=None, + ) + + +def test_emit_statement_event_skips_observers_when_sampling_rejects() -> None: + observed: list[StatementEvent] = [] + runtime = ObservabilityRuntime( + ObservabilityConfig( + sampling=SamplingConfig(sample_rate=0.0, force_sample_on_error=False, force_sample_slow_queries_ms=None), + statement_observers=(cast(StatementObserver, observed.append),), + ) + ) + + _emit_statement(runtime) + + assert observed == [] + + +def test_emit_statement_event_marks_sampled_when_sampling_accepts() -> None: + observed: list[StatementEvent] = [] + runtime = ObservabilityRuntime( + ObservabilityConfig( + sampling=SamplingConfig(sample_rate=1.0), statement_observers=(cast(StatementObserver, observed.append),) + ) + ) + + _emit_statement(runtime) + + assert len(observed) == 1 + assert observed[0].sampled is True + + +def test_emit_statement_event_force_samples_slow_query() -> None: + observed: list[StatementEvent] = [] + runtime = ObservabilityRuntime( + ObservabilityConfig( + sampling=SamplingConfig(sample_rate=0.0, force_sample_slow_queries_ms=10.0), + statement_observers=(cast(StatementObserver, observed.append),), + ) + ) + + _emit_statement(runtime, duration_s=0.1) + + assert len(observed) == 1 + assert observed[0].sampled is True + + +def test_emit_statement_event_without_sampling_preserves_previous_behavior() -> None: + observed: list[dict[str, Any]] = [] + + def observer(event: StatementEvent) -> None: + observed.append(event.as_dict()) + + runtime = ObservabilityRuntime( + ObservabilityConfig(sampling=None, statement_observers=(cast(StatementObserver, observer),)) + ) + + with CorrelationContext.context("sample-correlation"): + _emit_statement(runtime) + + assert len(observed) == 1 + assert observed[0]["correlation_id"] == "sample-correlation" + assert observed[0]["sampled"] is True diff --git a/tests/unit/storage/test_obstore_backend.py b/tests/unit/storage/test_obstore_backend.py index 30b29c55a..f47224d7a 100644 --- a/tests/unit/storage/test_obstore_backend.py +++ b/tests/unit/storage/test_obstore_backend.py @@ -13,6 +13,107 @@ from sqlspec.storage.backends.obstore import ObStoreBackend +class _FakeBytes: + def to_bytes(self) -> bytes: + return b"parquet" + + +class _FakeGetResult: + def bytes(self) -> _FakeBytes: + return _FakeBytes() + + async def bytes_async(self) -> _FakeBytes: + return _FakeBytes() + + +class _FakeStore: + def __init__(self) -> None: + self.get_paths: list[str] = [] + self.put_paths: list[str] = [] + + def get(self, path: str) -> _FakeGetResult: + self.get_paths.append(path) + return _FakeGetResult() + + def put(self, path: str, data: bytes) -> None: + self.put_paths.append(path) + + async def get_async(self, path: str) -> _FakeGetResult: + self.get_paths.append(path) + return _FakeGetResult() + + async def put_async(self, path: str, data: bytes) -> None: + self.put_paths.append(path) + + +class _FakeListingStore: + def __init__(self) -> None: + self.list_paths: list[str] = [] + self.list_with_delimiter_paths: list[str] = [] + + def list(self, path: str) -> Any: + self.list_paths.append(path) + return iter([ + [{"path": "file1.txt", "size": 5, "last_modified": None, "e_tag": None, "version": None}], + [ + {"path": "subdir/file2.txt", "size": 8, "last_modified": None, "e_tag": None, "version": None}, + {"path": "subdir/file3.txt", "size": 12, "last_modified": None, "e_tag": None, "version": None}, + ], + ]) + + def list_with_delimiter(self, path: str) -> dict[str, Any]: + self.list_with_delimiter_paths.append(path) + return { + "objects": [ + {"path": "dir/file_a.txt", "size": 10, "last_modified": None, "e_tag": None, "version": None}, + {"path": "dir/file_b.txt", "size": 20, "last_modified": None, "e_tag": None, "version": None}, + ], + "common_prefixes": ["dir/subdir/"], + } + + +class _FakeHeadStore: + def __init__(self, metadata: dict[str, object] | None = None, *, missing: bool = False) -> None: + self.head_paths: list[str] = [] + self.head_async_paths: list[str] = [] + self.metadata = metadata or { + "path": "object.txt", + "size": 12, + "last_modified": None, + "e_tag": "etag", + "version": "1", + "metadata": {"owner": "sqlspec"}, + } + self.missing = missing + + def head(self, path: str) -> dict[str, object]: + self.head_paths.append(path) + if self.missing: + raise FileNotFoundError(path) + return self.metadata + + async def head_async(self, path: str) -> dict[str, object]: + self.head_async_paths.append(path) + if self.missing: + raise FileNotFoundError(path) + return self.metadata + + +class _FakeParquet: + def __init__(self, result: Any) -> None: + self.result = result + + def read_table(self, source: Any, **kwargs: Any) -> Any: + return self.result + + def write_table(self, table: Any, sink: Any, **kwargs: Any) -> None: + sink.write(b"parquet") + + +class _FakeArrowTable: + schema: tuple[Any, ...] = () + + @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") def test_init_with_file_uri(tmp_path: Path) -> None: """Test initialization with file:// URI.""" @@ -138,6 +239,38 @@ def test_list_objects(tmp_path: Path) -> None: assert any("file3.txt" in obj for obj in all_objects) +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_list_objects_sync_non_recursive_uses_list_with_delimiter(tmp_path: Path) -> None: + """Non-recursive list results are ListResult dicts, not ListStream batches.""" + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}") + fake_store = _FakeListingStore() + store.store = fake_store + + result = store.list_objects_sync(recursive=False) + + assert fake_store.list_with_delimiter_paths == [""] + assert fake_store.list_paths == [] + assert result == ["dir/file_a.txt", "dir/file_b.txt"] + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_list_objects_sync_recursive_uses_list_stream(tmp_path: Path) -> None: + """Recursive lists still consume the ListStream batch iterator.""" + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}") + fake_store = _FakeListingStore() + store.store = fake_store + + result = store.list_objects_sync(recursive=True) + + assert fake_store.list_paths == [""] + assert fake_store.list_with_delimiter_paths == [] + assert result == ["file1.txt", "subdir/file2.txt", "subdir/file3.txt"] + + @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") def test_glob(tmp_path: Path) -> None: """Test glob pattern matching.""" @@ -172,6 +305,53 @@ def test_get_metadata(tmp_path: Path) -> None: assert metadata["exists"] is True +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_get_metadata_sync_uses_local_store_path_resolution(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}", base_path="data") + fake_store = _FakeHeadStore() + store.store = fake_store + + metadata = store.get_metadata_sync("nested/object.txt") + + assert fake_store.head_paths == ["nested/object.txt"] + assert metadata["path"] == "nested/object.txt" + assert metadata["exists"] is True + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_get_metadata_sync_returns_typed_dict_keys(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}") + store.store = _FakeHeadStore() + + metadata = store.get_metadata_sync("object.txt") + + assert metadata == { + "path": "object.txt", + "exists": True, + "size": 12, + "last_modified": None, + "e_tag": "etag", + "version": "1", + "custom_metadata": {"owner": "sqlspec"}, + } + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_get_metadata_sync_returns_absent_result_for_missing_object(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}") + store.store = _FakeHeadStore(missing=True) + + metadata = store.get_metadata_sync("missing.txt") + + assert metadata == {"path": "missing.txt", "exists": False} + + @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") def test_is_object_and_is_path(tmp_path: Path) -> None: """Test is_object and is_path operations.""" @@ -208,6 +388,39 @@ def test_write_and_read_arrow(tmp_path: Path) -> None: assert result.equals(table) +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_read_arrow_sync_resolves_cloud_base_path_once(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.storage.backends import obstore as obstore_module + from sqlspec.storage.backends.obstore import ObStoreBackend + + expected_table = object() + store = ObStoreBackend("memory://", base_path="mybase") + fake_store = _FakeStore() + store.store = fake_store + monkeypatch.setattr(obstore_module, "import_pyarrow_parquet", lambda: _FakeParquet(expected_table)) + + result = store.read_arrow_sync("key") + + assert result is expected_table + assert fake_store.get_paths == ["mybase/key"] + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +def test_write_arrow_sync_resolves_cloud_base_path_once(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.storage.backends import obstore as obstore_module + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend("memory://", base_path="mybase") + fake_store = _FakeStore() + store.store = fake_store + monkeypatch.setattr(obstore_module, "import_pyarrow", lambda: object()) + monkeypatch.setattr(obstore_module, "import_pyarrow_parquet", lambda: _FakeParquet(object())) + + store.write_arrow_sync("key", _FakeArrowTable()) + + assert fake_store.put_paths == ["mybase/key"] + + @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow missing") def test_stream_arrow(tmp_path: Path) -> None: """Test stream Arrow record batches.""" @@ -380,6 +593,67 @@ async def test_async_get_metadata(tmp_path: Path) -> None: assert metadata["exists"] is True +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +async def test_get_metadata_async_uses_local_store_path_resolution(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}", base_path="data") + fake_store = _FakeHeadStore() + store.store = fake_store + + metadata = await store.get_metadata_async("nested/object.txt") + + assert fake_store.head_async_paths == ["nested/object.txt"] + assert metadata["path"] == "nested/object.txt" + assert metadata["exists"] is True + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +async def test_get_metadata_async_returns_typed_dict_keys(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}") + store.store = _FakeHeadStore() + + metadata = await store.get_metadata_async("object.txt") + + assert metadata == { + "path": "object.txt", + "exists": True, + "size": 12, + "last_modified": None, + "e_tag": "etag", + "version": "1", + "custom_metadata": {"owner": "sqlspec"}, + } + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +async def test_get_metadata_async_returns_absent_result_for_missing_object(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend(f"file://{tmp_path}") + store.store = _FakeHeadStore(missing=True) + + metadata = await store.get_metadata_async("missing.txt") + + assert metadata == {"path": "missing.txt", "exists": False} + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +async def test_get_metadata_sync_and_async_result_keys_match(tmp_path: Path) -> None: + from sqlspec.storage.backends.obstore import ObStoreBackend + + metadata = {"path": "object.txt", "size": 42, "last_modified": None, "e_tag": "same", "version": "7"} + store = ObStoreBackend(f"file://{tmp_path}") + store.store = _FakeHeadStore(metadata) + + sync_metadata = store.get_metadata_sync("object.txt") + async_metadata = await store.get_metadata_async("object.txt") + + assert sync_metadata == async_metadata + + @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow missing") async def test_async_write_and_read_arrow(tmp_path: Path) -> None: """Test async write and read Arrow table operations.""" @@ -403,6 +677,38 @@ async def test_async_write_and_read_arrow(tmp_path: Path) -> None: assert result.equals(table) +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +async def test_read_arrow_async_resolves_cloud_base_path_once(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.storage.backends import obstore as obstore_module + from sqlspec.storage.backends.obstore import ObStoreBackend + + expected_table = object() + store = ObStoreBackend("memory://", base_path="mybase") + fake_store = _FakeStore() + store.store = fake_store + monkeypatch.setattr(obstore_module, "import_pyarrow_parquet", lambda: _FakeParquet(expected_table)) + + result = await store.read_arrow_async("key") + + assert result is expected_table + assert fake_store.get_paths == ["mybase/key"] + + +@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore missing") +async def test_write_arrow_async_resolves_cloud_base_path_once(monkeypatch: pytest.MonkeyPatch) -> None: + from sqlspec.storage.backends import obstore as obstore_module + from sqlspec.storage.backends.obstore import ObStoreBackend + + store = ObStoreBackend("memory://", base_path="mybase") + fake_store = _FakeStore() + store.store = fake_store + monkeypatch.setattr(obstore_module, "import_pyarrow_parquet", lambda: _FakeParquet(object())) + + await store.write_arrow_async("key", _FakeArrowTable()) + + assert fake_store.put_paths == ["mybase/key"] + + @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow missing") async def test_async_stream_arrow(tmp_path: Path) -> None: """Test async stream Arrow record batches.""" diff --git a/tests/unit/storage/test_protocol_conformance.py b/tests/unit/storage/test_protocol_conformance.py new file mode 100644 index 000000000..2fa8b4c27 --- /dev/null +++ b/tests/unit/storage/test_protocol_conformance.py @@ -0,0 +1,28 @@ +"""Regression tests for ObjectStoreProtocol stream_arrow_async conformance.""" + +import inspect +from pathlib import Path + +from sqlspec.protocols import ObjectStoreProtocol +from sqlspec.storage.backends.local import LocalStore + + +def test_protocol_stream_arrow_async_is_not_coroutine_function() -> None: + """ObjectStoreProtocol.stream_arrow_async is a plain def returning AsyncIterator.""" + assert not inspect.iscoroutinefunction(ObjectStoreProtocol.stream_arrow_async) + + +def test_local_store_satisfies_protocol(tmp_path: Path) -> None: + """LocalStore structurally satisfies ObjectStoreProtocol.""" + assert isinstance(LocalStore(base_path=tmp_path), ObjectStoreProtocol) + + +def test_local_store_stream_arrow_async_is_not_coroutine_function() -> None: + """Concrete backends should not expose stream_arrow_async as async def.""" + assert not inspect.iscoroutinefunction(LocalStore.stream_arrow_async) + + +def test_stream_arrow_async_notes_are_present() -> None: + """Source comments document why stream_arrow_async is not async def.""" + assert "Returns AsyncIterator directly" in Path("sqlspec/protocols.py").read_text() + assert "Returns AsyncIterator directly" in Path("sqlspec/storage/backends/base.py").read_text() diff --git a/tests/unit/storage/test_storage_registry.py b/tests/unit/storage/test_storage_registry.py index 011926108..6c47a53ff 100644 --- a/tests/unit/storage/test_storage_registry.py +++ b/tests/unit/storage/test_storage_registry.py @@ -164,6 +164,38 @@ def test_register_alias_with_backend_override(tmp_path: Path) -> None: assert backend.backend_type == "local" +def test_register_alias_evicts_stale_cache(tmp_path: Path) -> None: + """Re-registering an alias should evict the previously cached backend.""" + registry = StorageRegistry() + + first_root = tmp_path / "first" + second_root = tmp_path / "second" + registry.register_alias("test_store", f"file://{first_root}", backend="local") + backend_a = registry.get("test_store") + + registry.register_alias("test_store", f"file://{second_root}", backend="local") + backend_b = registry.get("test_store") + + assert backend_b is not backend_a + assert str(backend_b.base_path) == str(second_root.resolve()) + + +def test_register_alias_evicts_stale_cache_with_kwargs(tmp_path: Path) -> None: + """Alias cache eviction must also clear keyed instances created with kwargs.""" + registry = StorageRegistry() + + first_root = tmp_path / "first" + second_root = tmp_path / "second" + registry.register_alias("test_store", f"file://{first_root}", backend="local") + backend_a = registry.get("test_store", cache_label="stable") + + registry.register_alias("test_store", f"file://{second_root}", backend="local") + backend_b = registry.get("test_store", cache_label="stable") + + assert backend_b is not backend_a + assert str(backend_b.base_path) == str(second_root.resolve()) + + def test_cache_functionality(tmp_path: Path) -> None: """Test registry caching.""" registry = StorageRegistry() diff --git a/tests/unit/utils/test_fixtures.py b/tests/unit/utils/test_fixtures.py index 2310ca7ea..d25ff6195 100644 --- a/tests/unit/utils/test_fixtures.py +++ b/tests/unit/utils/test_fixtures.py @@ -6,7 +6,9 @@ """ import gzip +import importlib import json +import sys import zipfile from pathlib import Path from typing import Any @@ -438,6 +440,16 @@ async def test_write_fixture_async_custom_backend(mock_registry: Mock) -> None: mock_storage.write_text_async.assert_called_once() +async def test_write_fixture_async_imports_without_anyio(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + monkeypatch.setitem(sys.modules, "anyio", None) + sys.modules.pop("sqlspec.utils.fixtures", None) + + fixtures_module = importlib.import_module("sqlspec.utils.fixtures") + await fixtures_module.write_fixture_async(str(tmp_path), "without_anyio", {"ok": True}) + + assert (tmp_path / "without_anyio.json").exists() + + def test_write_read_roundtrip(tmp_path: Path) -> None: """Test complete write and read roundtrip.""" original_data: Any = { diff --git a/tests/unit/utils/test_to_value_type.py b/tests/unit/utils/test_to_value_type.py index 27c4c38a8..4a489ab12 100644 --- a/tests/unit/utils/test_to_value_type.py +++ b/tests/unit/utils/test_to_value_type.py @@ -21,6 +21,67 @@ # ============================================================================= +def test_foreign_key_metadata_recognized_via_issubclass() -> None: + """Real ForeignKeyMetadata resolves to a schema converter.""" + from sqlspec.typing import ForeignKeyMetadata + from sqlspec.utils.schema import _get_schema_converter + + assert _get_schema_converter(ForeignKeyMetadata) is not None + + +def test_foreign_key_metadata_conversion_roundtrip() -> None: + """Dict with FK column data converts to ForeignKeyMetadata.""" + from sqlspec.typing import ForeignKeyMetadata + + result = to_value_type( + { + "table_name": "orders", + "column_name": "user_id", + "referenced_table": "users", + "referenced_column": "id", + "constraint_name": "fk_orders_user", + }, + ForeignKeyMetadata, + ) + + assert isinstance(result, ForeignKeyMetadata) + assert result.table_name == "orders" + assert result.column_name == "user_id" + assert result.referenced_table == "users" + assert result.referenced_column == "id" + assert result.constraint_name == "fk_orders_user" + + +def test_unrelated_class_named_foreign_key_metadata_not_matched() -> None: + """A same-named third-party class must not use the FK metadata converter.""" + from sqlspec.utils.schema import _get_schema_converter + + class ForeignKeyMetadata: + __slots__ = ("column_name", "referenced_column", "referenced_table", "table_name") + + assert _get_schema_converter(ForeignKeyMetadata) is None + + +def test_foreign_key_metadata_list_conversion() -> None: + """A list of FK dictionaries converts to FK metadata instances.""" + from sqlspec.typing import ForeignKeyMetadata + from sqlspec.utils.schema import _convert_foreign_key_metadata + + result = _convert_foreign_key_metadata( + [ + {"table_name": "orders", "column_name": "user_id", "referenced_table": "users", "referenced_column": "id"}, + {"table_name": "items", "column_name": "order_id", "referenced_table": "orders", "referenced_column": "id"}, + ], + ForeignKeyMetadata, + ) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, ForeignKeyMetadata) for item in result) + assert result[0].table_name == "orders" + assert result[1].table_name == "items" + + class TestIdentityConversions: """Tests for identity conversions where value is already the correct type.""" diff --git a/tests/unit/utils/test_type_guards.py b/tests/unit/utils/test_type_guards.py index 1385c7d56..ff6832318 100644 --- a/tests/unit/utils/test_type_guards.py +++ b/tests/unit/utils/test_type_guards.py @@ -35,6 +35,7 @@ has_expressions_attribute, has_parent_attribute, has_this_attribute, + is_async_readable, is_attrs_instance, is_attrs_instance_with_field, is_attrs_instance_without_field, @@ -58,6 +59,7 @@ is_pydantic_model, is_pydantic_model_with_field, is_pydantic_model_without_field, + is_readable, is_schema, is_schema_or_dict, is_schema_or_dict_with_field, @@ -156,6 +158,28 @@ def __init__(self, value: Any) -> None: self.value = value +class SyncReadable: + def read(self) -> str: + return "sync" + + +class AsyncReadable: + async def read(self) -> str: + return "async" + + +def test_is_readable_accepts_sync_read_method() -> None: + assert is_readable(SyncReadable()) is True + + +def test_is_async_readable_rejects_sync_read_method() -> None: + assert is_async_readable(SyncReadable()) is False + + +def test_is_async_readable_accepts_async_read_method() -> None: + assert is_async_readable(AsyncReadable()) is True + + def test_is_dataclass_instance_with_valid_dataclass() -> None: """Test is_dataclass_instance returns True for dataclass instances.""" instance = SampleDataclass(name="test", age=25) diff --git a/tests/unit/utils/test_typing_module.py b/tests/unit/utils/test_typing_module.py new file mode 100644 index 000000000..e85312e7c --- /dev/null +++ b/tests/unit/utils/test_typing_module.py @@ -0,0 +1,26 @@ +"""Regression tests for sqlspec.typing portability.""" + +import typing +from collections.abc import Mapping +from pathlib import Path + + +def test_typing_module_source_does_not_reference_private_typeddict() -> None: + """sqlspec.typing must not import or reference typing._TypedDict.""" + assert "_TypedDict" not in Path("sqlspec/typing.py").read_text() + + +def test_supported_schema_model_includes_mapping() -> None: + """SupportedSchemaModel should include Mapping[str, Any].""" + from sqlspec.typing import SupportedSchemaModel + + args = typing.get_args(SupportedSchemaModel) + + assert any(getattr(arg, "__origin__", None) is Mapping for arg in args) + + +def test_typing_module_has_no_private_typeddict_name() -> None: + """sqlspec.typing should not expose the private _TypedDict name.""" + import sqlspec.typing as typing_module + + assert not hasattr(typing_module, "_TypedDict") diff --git a/tools/scripts/bench.py b/tools/scripts/bench.py index ea3479ade..617100c4f 100644 --- a/tools/scripts/bench.py +++ b/tools/scripts/bench.py @@ -29,22 +29,6 @@ from sqlspec.utils.schema import transform_dict_keys from sqlspec.utils.text import camelize -# Pool leak detection helper -_leaked_pools: list[str] = [] - - -def _is_compiled() -> bool: - """Detect if sqlspec driver modules are mypyc-compiled.""" - try: - from sqlspec.driver import _sync - - return hasattr(_sync, "__file__") and (_sync.__file__ or "").endswith(".so") - except ImportError: - return False - - -SQLSPEC_LABEL = "sqlspec (mypyc)" if _is_compiled() else "sqlspec" - __all__ = ( "main", "print_benchmark_table", @@ -99,6 +83,22 @@ def _is_compiled() -> bool: "sqlspec_sqlite_write_heavy", ) +# Pool leak detection helper +_leaked_pools: list[str] = [] + + +def _is_compiled() -> bool: + """Detect if sqlspec driver modules are mypyc-compiled.""" + try: + from sqlspec.driver import _sync + + return hasattr(_sync, "__file__") and (_sync.__file__ or "").endswith(".so") + except ImportError: + return False + + +SQLSPEC_LABEL = "sqlspec (mypyc)" if _is_compiled() else "sqlspec" + ROWS_TO_INSERT = 10_000 POOL_SIZE = 5 # Default pool size for async adapters diff --git a/tools/scripts/mypyc_boundary_map.py b/tools/scripts/mypyc_boundary_map.py index 8ea2c47ab..0175b4f7f 100644 --- a/tools/scripts/mypyc_boundary_map.py +++ b/tools/scripts/mypyc_boundary_map.py @@ -5,11 +5,6 @@ from pathlib import Path from typing import Any -try: - import tomllib # type: ignore[import-not-found] -except ModuleNotFoundError: # pragma: no cover - import tomli as tomllib - __all__ = ( "ANY_AUDIT_SEAMS", "CONFIG_RUNTIME_BOUNDARIES", @@ -24,6 +19,11 @@ "load_mypyc_patterns", ) +try: + import tomllib # type: ignore[import-not-found] +except ModuleNotFoundError: # pragma: no cover + import tomli as tomllib + CONFIG_RUNTIME_BOUNDARIES: tuple[dict[str, Any], ...] = ( { diff --git a/tools/scripts/mypyc_inventory.py b/tools/scripts/mypyc_inventory.py index af0b39476..b8616a8dd 100644 --- a/tools/scripts/mypyc_inventory.py +++ b/tools/scripts/mypyc_inventory.py @@ -8,11 +8,6 @@ from pathlib import Path from typing import Any -try: - import tomllib # type: ignore[import-not-found] -except ModuleNotFoundError: # pragma: no cover - import tomli as tomllib - __all__ = ( "HOT_SURFACE_CLASSIFICATIONS", "build_inventory", @@ -24,6 +19,12 @@ "main", ) +try: + import tomllib # type: ignore[import-not-found] +except ModuleNotFoundError: # pragma: no cover + import tomli as tomllib + + CANDIDATE_CLASSIFICATIONS = {"compile_now", "helper_split_first", "prove_separately"} SURFACE_ORDER = ("compiled", "candidate", "keep_interpreted", "interpreted")