Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,5 @@ tools/scripts/profiles/*.prof
.beads-credential-key
.codex
.mypy_worker*
data/
.antigravitycli/
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ ignore = [
"PLW0603", # Using the global statement to update is discouraged
"PLW0108", # Replace lambda expression with a def
"RUF067", # ruff - init should only have import statements
"PLW0717", # too many statements in try block
]
select = ["ALL"]

Expand Down
4 changes: 2 additions & 2 deletions sqlspec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@
from sqlspec.typing import ConnectionT, PoolT, SchemaT, StatementParameters, SupportedSchemaModel
from sqlspec.utils.logging import suppress_erroneous_sqlglot_log_messages

suppress_erroneous_sqlglot_log_messages()

__all__ = (
"SQL",
"ArrowResult",
Expand Down Expand Up @@ -165,3 +163,5 @@
"typing",
"utils",
)

suppress_erroneous_sqlglot_log_messages()
165 changes: 83 additions & 82 deletions sqlspec/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,89 @@

from sqlspec.utils.module_loader import dependency_flag, module_available

__all__ = (
"ALLOYDB_CONNECTOR_INSTALLED",
"ATTRS_INSTALLED",
"CATTRS_INSTALLED",
"CLOUD_SQL_CONNECTOR_INSTALLED",
"FSSPEC_INSTALLED",
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"NANOID_INSTALLED",
"NUMPY_INSTALLED",
"OBSTORE_INSTALLED",
"OPENTELEMETRY_INSTALLED",
"ORJSON_INSTALLED",
"PANDAS_INSTALLED",
"PGVECTOR_INSTALLED",
"POLARS_INSTALLED",
"PROMETHEUS_INSTALLED",
"PYARROW_INSTALLED",
"PYDANTIC_INSTALLED",
"UNSET",
"UNSET_STUB",
"UUID_UTILS_INSTALLED",
"ArrowRecordBatch",
"ArrowRecordBatchReader",
"ArrowRecordBatchReaderProtocol",
"ArrowRecordBatchResult",
"ArrowSchema",
"ArrowSchemaProtocol",
"ArrowTable",
"ArrowTableResult",
"AttrsInstance",
"AttrsInstanceStub",
"BaseModel",
"BaseModelStub",
"Counter",
"DTOData",
"DTODataStub",
"DataclassProtocol",
"Empty",
"EmptyEnum",
"EmptyType",
"FailFast",
"FailFastStub",
"Gauge",
"Histogram",
"NumpyArray",
"NumpyArrayStub",
"PandasDataFrame",
"PandasDataFrameProtocol",
"PolarsDataFrame",
"PolarsDataFrameProtocol",
"Span",
"Status",
"StatusCode",
"Struct",
"StructStub",
"T",
"T_co",
"Tracer",
"TypeAdapter",
"TypeAdapterStub",
"UnsetType",
"UnsetTypeStub",
"attrs_asdict",
"attrs_asdict_stub",
"attrs_define",
"attrs_define_stub",
"attrs_field",
"attrs_field_stub",
"attrs_fields",
"attrs_fields_stub",
"attrs_has",
"attrs_has_stub",
"cattrs_structure",
"cattrs_unstructure",
"convert",
"convert_stub",
"module_available",
"msgspec_fields",
"msgspec_fields_stub",
"trace",
)


@runtime_checkable
class DataclassProtocol(Protocol):
Expand Down Expand Up @@ -630,85 +713,3 @@ def labels(self, *labelvalues: str, **labelkwargs: str) -> _MetricInstance:
ALLOYDB_CONNECTOR_INSTALLED = dependency_flag("google.cloud.alloydb.connector")
NANOID_INSTALLED = dependency_flag("fastnanoid")
UUID_UTILS_INSTALLED = dependency_flag("uuid_utils")
__all__ = (
"ALLOYDB_CONNECTOR_INSTALLED",
"ATTRS_INSTALLED",
"CATTRS_INSTALLED",
"CLOUD_SQL_CONNECTOR_INSTALLED",
"FSSPEC_INSTALLED",
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"NANOID_INSTALLED",
"NUMPY_INSTALLED",
"OBSTORE_INSTALLED",
"OPENTELEMETRY_INSTALLED",
"ORJSON_INSTALLED",
"PANDAS_INSTALLED",
"PGVECTOR_INSTALLED",
"POLARS_INSTALLED",
"PROMETHEUS_INSTALLED",
"PYARROW_INSTALLED",
"PYDANTIC_INSTALLED",
"UNSET",
"UNSET_STUB",
"UUID_UTILS_INSTALLED",
"ArrowRecordBatch",
"ArrowRecordBatchReader",
"ArrowRecordBatchReaderProtocol",
"ArrowRecordBatchResult",
"ArrowSchema",
"ArrowSchemaProtocol",
"ArrowTable",
"ArrowTableResult",
"AttrsInstance",
"AttrsInstanceStub",
"BaseModel",
"BaseModelStub",
"Counter",
"DTOData",
"DTODataStub",
"DataclassProtocol",
"Empty",
"EmptyEnum",
"EmptyType",
"FailFast",
"FailFastStub",
"Gauge",
"Histogram",
"NumpyArray",
"NumpyArrayStub",
"PandasDataFrame",
"PandasDataFrameProtocol",
"PolarsDataFrame",
"PolarsDataFrameProtocol",
"Span",
"Status",
"StatusCode",
"Struct",
"StructStub",
"T",
"T_co",
"Tracer",
"TypeAdapter",
"TypeAdapterStub",
"UnsetType",
"UnsetTypeStub",
"attrs_asdict",
"attrs_asdict_stub",
"attrs_define",
"attrs_define_stub",
"attrs_field",
"attrs_field_stub",
"attrs_fields",
"attrs_fields_stub",
"attrs_has",
"attrs_has_stub",
"cattrs_structure",
"cattrs_unstructure",
"convert",
"convert_stub",
"module_available",
"msgspec_fields",
"msgspec_fields_stub",
"trace",
)
5 changes: 2 additions & 3 deletions sqlspec/adapters/adbc/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
AdbcConnection = _AdbcConnection
AdbcRawCursor = _AdbcRawCursor

__all__ = ("AdbcConnection", "AdbcCursor", "AdbcRawCursor", "AdbcSessionContext")


class AdbcCursor:
"""Context manager for cursor management."""
Expand Down Expand Up @@ -106,6 +108,3 @@ def __exit__(
self._release_connection(self._connection)
self._connection = None
return None


__all__ = ("AdbcConnection", "AdbcCursor", "AdbcRawCursor", "AdbcSessionContext")
3 changes: 2 additions & 1 deletion sqlspec/adapters/adbc/adk/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from sqlspec.adapters.adbc.config import AdbcConfig
from sqlspec.extensions.adk import MemoryRecord

__all__ = ("AdbcADKMemoryStore", "AdbcADKStore")

logger = get_logger("sqlspec.adapters.adbc.adk.store")

__all__ = ("AdbcADKMemoryStore", "AdbcADKStore")

DIALECT_POSTGRESQL: Final = "postgresql"
DIALECT_SQLITE: Final = "sqlite"
Expand Down
4 changes: 2 additions & 2 deletions sqlspec/adapters/adbc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def handle_postgres_rollback(dialect: str, cursor: Any, logger: Any | None = Non
cursor: Database cursor to execute rollback.
logger: Optional logger for diagnostics.
"""
if dialect != "postgres":
if not is_postgres_dialect(dialect):
return
try:
cursor.execute("ROLLBACK")
Expand All @@ -457,7 +457,7 @@ def normalize_postgres_empty_parameters(dialect: str, parameters: Any) -> Any:
Returns:
Normalized parameter payload.
"""
if dialect == "postgres" and isinstance(parameters, dict) and not parameters:
if is_postgres_dialect(dialect) and isinstance(parameters, dict) and not parameters:
return None
return parameters

Expand Down
4 changes: 2 additions & 2 deletions sqlspec/adapters/adbc/data_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from sqlspec.exceptions import SQLFileNotFoundError
from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo

__all__ = ("AdbcDataDictionary",)

if TYPE_CHECKING:
from sqlspec.adapters.adbc.driver import AdbcDriver
from sqlspec.core import SQL

__all__ = ("AdbcDataDictionary",)


@mypyc_attr(allow_interpreted_subclasses=True, native_class=False)
class AdbcDataDictionary(SyncDataDictionaryBase):
Expand Down
19 changes: 9 additions & 10 deletions sqlspec/adapters/aiomysql/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ def close(self) -> Any: ...
AiomysqlDictCursor = _AiomysqlDictCursor
AiomysqlPool = _AiomysqlPool

__all__ = (
"AiomysqlConnection",
"AiomysqlCursor",
"AiomysqlDictCursor",
"AiomysqlPool",
"AiomysqlRawCursor",
"AiomysqlSessionContext",
)


class AiomysqlCursor:
"""Context manager for aiomysql cursor operations.
Expand Down Expand Up @@ -125,13 +134,3 @@ async def __aexit__(
await self._release_connection(self._connection)
self._connection = None
return None


__all__ = (
"AiomysqlConnection",
"AiomysqlCursor",
"AiomysqlDictCursor",
"AiomysqlPool",
"AiomysqlRawCursor",
"AiomysqlSessionContext",
)
4 changes: 4 additions & 0 deletions sqlspec/adapters/aiomysql/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "Execu
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)

if prepared_parameters is not None and len(statements) > 1:
msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements"
raise SQLSpecError(msg)

successful_count = 0
last_cursor = cursor

Expand Down
3 changes: 2 additions & 1 deletion sqlspec/adapters/aiomysql/litestar/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
if TYPE_CHECKING:
from sqlspec.adapters.aiomysql.config import AiomysqlConfig

__all__ = ("AiomysqlStore",)

logger = get_logger("sqlspec.adapters.aiomysql.litestar.store")

__all__ = ("AiomysqlStore",)

MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146

Expand Down
7 changes: 2 additions & 5 deletions sqlspec/adapters/aiosqlite/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
AiosqliteConnection = _AiosqliteConnection
AiosqliteRawCursor = aiosqlite.Cursor

__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteRawCursor", "AiosqliteSessionContext")


class AiosqliteCursor:
"""Async context manager for AIOSQLite cursors."""
Expand All @@ -44,8 +46,6 @@ async def __aenter__(self) -> "AiosqliteRawCursor":
async def __aexit__(
self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
) -> None:
if exc_type is not None:
return
if self.cursor is not None:
with contextlib.suppress(Exception):
await self.cursor.close()
Expand Down Expand Up @@ -104,6 +104,3 @@ async def __aexit__(
await self._release_connection(self._connection)
self._connection = None
return None


__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteRawCursor", "AiosqliteSessionContext")
4 changes: 2 additions & 2 deletions sqlspec/adapters/aiosqlite/adk/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
from sqlspec.extensions.adk import MemoryRecord

__all__ = ("AiosqliteADKMemoryStore", "AiosqliteADKStore")


SECONDS_PER_DAY = 86400.0
JULIAN_EPOCH = 2440587.5
SQLITE_TABLE_NOT_FOUND_ERROR: Final = "no such table"

__all__ = ("AiosqliteADKMemoryStore", "AiosqliteADKStore")


def _datetime_to_julian(dt: datetime) -> float:
"""Convert datetime to Julian Day number for SQLite storage.
Expand Down
2 changes: 1 addition & 1 deletion sqlspec/adapters/aiosqlite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AiosqlitePoolConnection,
AiosqlitePoolConnectionContext,
)
from sqlspec.adapters.sqlite.type_converter import register_type_handlers
from sqlspec.adapters.aiosqlite.type_converter import register_type_handlers
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.driver._async import AsyncPoolConnectionContext, AsyncPoolSessionFactory
from sqlspec.utils.config_tools import normalize_connection_config
Expand Down
Loading
Loading