From 0f1f8263984a9ba51f5e409206e3fbaa1dd1da6c Mon Sep 17 00:00:00 2001
From: Saquib Saifee
Date: Sat, 10 Jan 2026 15:22:15 -0500
Subject: [PATCH 1/9] fix: Add paramstyle to DBAPI for SQLAlchemy compliance

---
 prestodb/dbapi.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/prestodb/dbapi.py b/prestodb/dbapi.py
index cc60bf7..2ce9a69 100644
--- a/prestodb/dbapi.py
+++ b/prestodb/dbapi.py
@@ -38,6 +38,7 @@
 
 apilevel = "2.0"
 threadsafety = 2
+paramstyle = "pyformat"
 
 logger = logging.getLogger(__name__)
 

From 1f436eecc525da48db908f53c852c35eb1a020ae Mon Sep 17 00:00:00 2001
From: Saquib Saifee
Date: Sat, 10 Jan 2026 15:22:35 -0500
Subject: [PATCH 2/9] feat: Implement native SQLAlchemy dialect

---
 prestodb/sqlalchemy/__init__.py |  11 ++
 prestodb/sqlalchemy/base.py     | 198 ++++++++++++++++++++++++++++++
 prestodb/sqlalchemy/compiler.py | 209 ++++++++++++++++++++++++++++++++
 prestodb/sqlalchemy/datatype.py | 120 ++++++++++++++++++
 4 files changed, 538 insertions(+)
 create mode 100644 prestodb/sqlalchemy/__init__.py
 create mode 100644 prestodb/sqlalchemy/base.py
 create mode 100644 prestodb/sqlalchemy/compiler.py
 create mode 100644 prestodb/sqlalchemy/datatype.py

diff --git a/prestodb/sqlalchemy/__init__.py b/prestodb/sqlalchemy/__init__.py
new file mode 100644
index 0000000..4d9a924
--- /dev/null
+++ b/prestodb/sqlalchemy/__init__.py
@@ -0,0 +1,11 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/prestodb/sqlalchemy/base.py b/prestodb/sqlalchemy/base.py
new file mode 100644
index 0000000..16a07da
--- /dev/null
+++ b/prestodb/sqlalchemy/base.py
@@ -0,0 +1,198 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Attribution:
+# This code is adapted from the trino-python-client project (Apache 2.0 License).
+# https://github.com/trinodb/trino-python-client/blob/master/trino/sqlalchemy/dialect.py + +from sqlalchemy import types, util +from sqlalchemy.engine import default +from sqlalchemy.sql import sqltypes + +from prestodb import auth, dbapi +from prestodb.sqlalchemy import compiler, datatype + +_type_map = { + # Standard types + "boolean": datatype.BOOLEAN, + "tinyint": datatype.TINYINT, + "smallint": datatype.SMALLINT, + "integer": datatype.INTEGER, + "bigint": datatype.BIGINT, + "real": datatype.REAL, + "double": datatype.DOUBLE, + "decimal": datatype.DECIMAL, + "varchar": datatype.VARCHAR, + "char": datatype.CHAR, + "varbinary": datatype.VARBINARY, + "json": datatype.JSON, + "date": datatype.DATE, + "time": datatype.TIME, + "time with time zone": datatype.TIME, # TODO: time with time zone + "timestamp": datatype.TIMESTAMP, + "timestamp with time zone": datatype.TIMESTAMP, # TODO: timestamp with time zone + "interval year to month": datatype.INTERVAL, + "interval day to second": datatype.INTERVAL, + # Specific types + "array": datatype.ARRAY, + "map": datatype.MAP, + "row": datatype.ROW, + "hyperloglog": datatype.HYPERLOGLOG, + "p4hyperloglog": datatype.P4HYPERLOGLOG, + "qdigest": datatype.QDIGEST, +} + + +class PrestoDialect(default.DefaultDialect): + name = "presto" + driver = "presto" + author = "Presto Team" + supports_alter = False + supports_pk_on_update = False + supports_full_outer_join = True + supports_simple_order_by_label = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + supports_native_boolean = True + + statement_compiler = compiler.PrestoSQLCompiler + type_compiler = compiler.PrestoTypeCompiler + preparer = compiler.PrestoIdentifierPreparer + + def create_connect_args(self, url): + args = {"host": url.host} + if url.port: + args["port"] = url.port + if url.username: + args["user"] = url.username + if url.password: + args["http_scheme"] = "https" + args["auth"] = auth.BasicAuthentication(url.username, url.password) + + db_parts = (url.database or "system").split("/") + if len(db_parts) == 1: + args["catalog"] = db_parts[0] + elif len(db_parts) == 2: + args["catalog"] = db_parts[0] + args["schema"] = db_parts[1] + else: + raise ValueError("Unexpected database format: {}".format(url.database)) + + return ([args], {}) + + @classmethod + def import_dbapi(cls): + return dbapi + + def has_table(self, connection, table_name, schema=None): + return self._has_object(connection, "TABLE", table_name, schema) + + def has_sequence(self, connection, sequence_name, schema=None): + return False + + def _has_object(self, connection, object_type, object_name, schema=None): + if schema is None: + schema = connection.engine.dialect.default_schema_name + + return ( + connection.execute( + "SELECT count(*) FROM information_schema.tables " + "WHERE table_schema = '{}' AND table_name = '{}'".format( + schema, object_name + ) + ).scalar() + > 0 + ) + + def get_schema_names(self, connection, **kw): + result = connection.execute("SHOW SCHEMAS") + return [row[0] for row in result] + + def get_table_names(self, connection, schema=None, **kw): + schema = schema or self.default_schema_name + if schema is None: + raise ValueError("schema argument is required") + result = connection.execute("SHOW TABLES FROM {}".format(schema)) + return [row[0] for row in result] + + def get_columns(self, connection, table_name, schema=None, **kw): + schema = schema or self.default_schema_name + if schema is None: + raise ValueError("schema argument is required") + query = "SHOW COLUMNS FROM 
{}.{}".format(schema, table_name) + result = connection.execute(query) + columns = [] + for row in result: + # Column(Column, Type, Extra, Comment) + col_name = row[0] + col_type = row[1] + # extra = row[2] + # comment = row[3] + columns.append( + { + "name": col_name, + "type": self._parse_type(col_type), + "nullable": True, # TODO: check nullability + "default": None, + } + ) + return columns + + def _parse_type(self, type_str): + type_str = type_str.lower() + match = util.re.match(r"^([a-zA-Z0-9_]+)(\((.+)\))?$", type_str) + if not match: + return sqltypes.NullType() + + type_name = match.group(1) + type_args = match.group(3) + + if type_name in _type_map: + type_class = _type_map[type_name] + if type_args: + return type_class(*self._parse_type_args(type_args)) + return type_class() + return sqltypes.NullType() + + def _parse_type_args(self, type_args): + # TODO: improve parsing for nested types + return [int(a) if a.isdigit() else a for a in type_args.split(",")] + + def do_rollback(self, dbapi_connection): + # Presto transactions usually auto-commit or are read-only + pass + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # Presto doesn't enforce foreign keys + return [] + + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + # Presto doesn't enforce primary keys + return {"constrained_columns": [], "name": None} + + def get_indexes(self, connection, table_name, schema=None, **kw): + # TODO: Implement index reflection + return [] + + def do_ping(self, dbapi_connection): + cursor = None + try: + cursor = dbapi_connection.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + except Exception: + if cursor: + cursor.close() + return False + else: + cursor.close() + return True diff --git a/prestodb/sqlalchemy/compiler.py b/prestodb/sqlalchemy/compiler.py new file mode 100644 index 0000000..8023d41 --- /dev/null +++ b/prestodb/sqlalchemy/compiler.py @@ -0,0 +1,209 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Attribution: +# This code is adapted from the trino-python-client project (Apache 2.0 License). 
+# https://github.com/trinodb/trino-python-client/blob/master/trino/sqlalchemy/compiler.py + +from sqlalchemy.sql import compiler + + +class PrestoSQLCompiler(compiler.SQLCompiler): + def visit_char_length_func(self, fn, **kw): + return "length{}".format(self.function_argspec(fn, **kw)) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += " LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def visit_lambda_element(self, element, **kw): + # Lambda expression not fully supported in standard SQLCompiler yet + return super(PrestoSQLCompiler, self).visit_lambda_element(element, **kw) + + +class PrestoTypeCompiler(compiler.GenericTypeCompiler): + def visit_DOUBLE(self, type_, **kw): + return "DOUBLE" + + def visit_REAL(self, type_, **kw): + return "REAL" + + def visit_TINYINT(self, type_, **kw): + return "TINYINT" + + def visit_SMALLINT(self, type_, **kw): + return "SMALLINT" + + def visit_INTEGER(self, type_, **kw): + return "INTEGER" + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_VARCHAR(self, type_, **kw): + if type_.length is None: + return "VARCHAR" + return "VARCHAR(%d)" % type_.length + + def visit_CHAR(self, type_, **kw): + if type_.length is None: + return "CHAR" + return "CHAR(%d)" % type_.length + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY" + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_FLOAT(self, type_, **kw): + return "DOUBLE" + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return "DECIMAL" + if type_.scale is None: + return "DECIMAL(%d)" % (type_.precision) + return "DECIMAL(%d, %d)" % (type_.precision, type_.scale) + + def visit_DECIMAL(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_DATETIME(self, type_, **kw): + return "TIMESTAMP" + + def visit_CLOB(self, type_, **kw): + return "VARCHAR" + + def visit_NCLOB(self, type_, **kw): + return "VARCHAR" + + def visit_TEXT(self, type_, **kw): + return "VARCHAR" + + def visit_BLOB(self, type_, **kw): + return "VARBINARY" + + def visit_BOOLEAN(self, type_, **kw): + return "BOOLEAN" + + def visit_ARRAY(self, type_, **kw): + return "ARRAY(%s)" % self.process(type_.item_type, **kw) + + def visit_MAP(self, type_, **kw): + return "MAP(%s, %s)" % ( + self.process(type_.key_type, **kw), + self.process(type_.value_type, **kw), + ) + + def visit_ROW(self, type_, **kw): + items = [ + "%s %s" % (name, self.process(attr_type, **kw)) + for name, attr_type in type_.attr_types + ] + return "ROW(%s)" % ", ".join(items) + + def visit_HYPERLOGLOG(self, type_, **kw): + return "HyperLogLog" + + def visit_QDIGEST(self, type_, **kw): + return "QDigest" + + def visit_P4HYPERLOGLOG(self, type_, **kw): + return "P4HyperLogLog" + + +class PrestoIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = { + "alter", + "and", + "as", + "between", + "by", + "case", + "cast", + "constraint", + "create", + "cross", + "cube", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "deallocate", + "delete", + "describe", + "distinct", + "drop", + "else", + "end", + "escape", + "except", + "execute", + "exists", + "extract", + "false", + "for", + "from", + "full", + "group", + 
"grouping", + "having", + "in", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "left", + "like", + "localtime", + "localtimestamp", + "natural", + "normalize", + "not", + "null", + "on", + "or", + "order", + "outer", + "prepare", + "recursive", + "right", + "rollup", + "select", + "table", + "then", + "true", + "uescape", + "union", + "unnest", + "using", + "values", + "when", + "where", + "with", + } diff --git a/prestodb/sqlalchemy/datatype.py b/prestodb/sqlalchemy/datatype.py new file mode 100644 index 0000000..524b9b8 --- /dev/null +++ b/prestodb/sqlalchemy/datatype.py @@ -0,0 +1,120 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Attribution: +# This code is adapted from the trino-python-client project (Apache 2.0 License). +# https://github.com/trinodb/trino-python-client/blob/master/trino/sqlalchemy/datatype.py + +from sqlalchemy import types + + +class DOUBLE(types.Float): + __visit_name__ = "DOUBLE" + + +class REAL(types.Float): + __visit_name__ = "REAL" + + +class BOOLEAN(types.Boolean): + __visit_name__ = "BOOLEAN" + + +class TINYINT(types.Integer): + __visit_name__ = "TINYINT" + + +class SMALLINT(types.Integer): + __visit_name__ = "SMALLINT" + + +class INTEGER(types.Integer): + __visit_name__ = "INTEGER" + + +class BIGINT(types.BigInteger): + __visit_name__ = "BIGINT" + + +class DECIMAL(types.DECIMAL): + __visit_name__ = "DECIMAL" + + +class VARCHAR(types.String): + __visit_name__ = "VARCHAR" + + +class CHAR(types.String): + __visit_name__ = "CHAR" + + +class VARBINARY(types.LargeBinary): + __visit_name__ = "VARBINARY" + + +class JSON(types.JSON): + __visit_name__ = "JSON" + + +class DATE(types.Date): + __visit_name__ = "DATE" + + +class TIME(types.Time): + __visit_name__ = "TIME" + + +class TIMESTAMP(types.TIMESTAMP): + __visit_name__ = "TIMESTAMP" + + +class INTERVAL(types.TypeEngine): + __visit_name__ = "INTERVAL" + + def __init__(self, start, end=None, precision=None): + self.start = start + self.end = end + self.precision = precision + + +class ARRAY(types.TypeEngine): + __visit_name__ = "ARRAY" + + def __init__(self, item_type): + self.item_type = item_type + + +class MAP(types.TypeEngine): + __visit_name__ = "MAP" + + def __init__(self, key_type, value_type): + self.key_type = key_type + self.value_type = value_type + + +class ROW(types.TypeEngine): + __visit_name__ = "ROW" + + def __init__(self, attr_types): + self.attr_types = attr_types + + +class HYPERLOGLOG(types.TypeEngine): + __visit_name__ = "HYPERLOGLOG" + + +class QDIGEST(types.TypeEngine): + __visit_name__ = "QDIGEST" + + +class P4HYPERLOGLOG(types.TypeEngine): + __visit_name__ = "P4HYPERLOGLOG" From cd525409b277bc1da8b62bab7fa7fa7250f34386 Mon Sep 17 00:00:00 2001 From: Saquib Saifee Date: Sat, 10 Jan 2026 15:27:39 -0500 Subject: [PATCH 3/9] build: Register sqlalchemy dialect entry point and dependencies --- .gitignore | 3 +++ setup.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 
index 0000000..7b7f654
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*/__pycache__/*
+presto_python_client.egg-info/*
+prestodb/sqlalchemy/__pycache__/*
diff --git a/setup.py b/setup.py
index 8bda465..1d1bb69 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,8 @@
 google_auth_require = ["google_auth"]
 
+sqlalchemy_require = ["sqlalchemy"]
+
 all_require = [kerberos_require, google_auth_require]
 
 tests_require = all_require + ["httpretty", "pytest", "pytest-runner"]
 
@@ -41,7 +43,7 @@
     author_email="presto-users@googlegroups.com",
     version=version,
     url="https://github.com/prestodb/presto-python-client",
-    packages=["prestodb"],
+    packages=["prestodb", "prestodb.sqlalchemy"],
     package_data={"": ["LICENSE", "README.md"]},
     description="Client for the Presto distributed SQL Engine",
     long_description=textwrap.dedent(
@@ -72,11 +74,17 @@
         "Topic :: Database :: Front-Ends",
     ],
     install_requires=["click", "requests", "six"],
+    entry_points={
+        "sqlalchemy.dialects": [
+            "presto = prestodb.sqlalchemy.base:PrestoDialect",
+        ],
+    },
     extras_require={
         "all": all_require,
         "kerberos": kerberos_require,
         "google_auth": google_auth_require,
         "tests": tests_require,
+        "sqlalchemy": sqlalchemy_require,
         ':python_version=="2.7"': py27_require,
     },
 )

From 05006a82e3b0a762509b5fc4d1f6ea714827a1aa Mon Sep 17 00:00:00 2001
From: Saquib Saifee
Date: Sat, 10 Jan 2026 15:28:01 -0500
Subject: [PATCH 4/9] docs: Add usage instructions for SQLAlchemy dialect

---
 README.md | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/README.md b/README.md
index 5e4cd18..27d7745 100644
--- a/README.md
+++ b/README.md
@@ -111,8 +111,33 @@ The transaction is created when the first SQL statement is executed.
 exits the *with* context and the queries succeed, otherwise
 `prestodb.dbapi.Connection.rollback()' will be called.
 
+
+# SQLAlchemy Support
+
+The client also provides a SQLAlchemy dialect.
+
+## Installation
+
+```
+$ pip install presto-python-client[sqlalchemy]
+```
+
+## Usage
+
+To connect to Presto using SQLAlchemy:
+
+```python
+from sqlalchemy import create_engine
+
+engine = create_engine('presto://user:password@host:port/catalog/schema')
+connection = engine.connect()
+
+rows = connection.execute("SELECT * FROM system.runtime.nodes").fetchall()
+```
+
 
 # Running Tests
+
 
 There is a helper scripts, `run`, that provides commands to run tests.
 Type `./run tests` to run both unit and integration tests.
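Note on the README example above: `connection.execute("SELECT ...")` with a plain string is the SQLAlchemy 1.3-era calling convention. The sketch below is illustrative only — the coordinator address, user, and catalog/schema are placeholders — and shows the same connection exercised on SQLAlchemy 1.4+/2.0, where raw SQL must be wrapped in `text()`, together with the reflection methods this series adds (`get_schema_names`, `get_table_names`, `get_columns`).

```python
# Illustrative sketch; host, port, user, and catalog/schema are placeholders.
from sqlalchemy import create_engine, inspect, text

engine = create_engine("presto://user@localhost:8080/system/runtime")

# On SQLAlchemy 1.4+/2.0, raw SQL must be wrapped in text().
with engine.connect() as connection:
    rows = connection.execute(text("SELECT * FROM system.runtime.nodes")).fetchall()
    print(rows)

# Schema reflection is routed through the dialect's get_schema_names(),
# get_table_names(), and get_columns() implementations.
inspector = inspect(engine)
print(inspector.get_schema_names())
print(inspector.get_table_names(schema="runtime"))
for column in inspector.get_columns("nodes", schema="runtime"):
    print(column["name"], column["type"], column["nullable"])
```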
From 92b856b78c2ce52b76c874aa7dec275693e3f47b Mon Sep 17 00:00:00 2001 From: Saquib Saifee Date: Sat, 10 Jan 2026 15:58:33 -0500 Subject: [PATCH 5/9] fix: Correct all_require flattening in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1d1bb69..917935b 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ sqlalchemy_require = ["sqlalchemy"] -all_require = [kerberos_require, google_auth_require] +all_require = [kerberos_require, google_auth_require, sqlalchemy_require] tests_require = all_require + ["httpretty", "pytest", "pytest-runner"] From 81e2e783fb6d951a41e8707d16c50e005708ce8e Mon Sep 17 00:00:00 2001 From: Saquib Saifee Date: Sat, 10 Jan 2026 15:58:45 -0500 Subject: [PATCH 6/9] test: Add unit and integration tests for SQLAlchemy dialect --- .../test_sqlalchemy_integration.py | 56 +++++++++ tests/test_sqlalchemy.py | 112 ++++++++++++++++++ 2 files changed, 168 insertions(+) create mode 100644 integration_tests/test_sqlalchemy_integration.py create mode 100644 tests/test_sqlalchemy.py diff --git a/integration_tests/test_sqlalchemy_integration.py b/integration_tests/test_sqlalchemy_integration.py new file mode 100644 index 0000000..84c21e9 --- /dev/null +++ b/integration_tests/test_sqlalchemy_integration.py @@ -0,0 +1,56 @@ +from __future__ import absolute_import, division, print_function + +import pytest +from sqlalchemy import create_engine, inspect, text +from sqlalchemy.schema import Table, MetaData, Column +from sqlalchemy.types import Integer, String +from integration_tests.fixtures import run_presto + +@pytest.fixture +def sqlalchemy_engine(run_presto): + _, host, port = run_presto + # Construct the SQLAlchemy URL. + # Note: 'test' user and 'test' catalog/schema match the dbapi fixtures. + url = "presto://test@{}:{}/test/test".format(host, port) + engine = create_engine(url) + return engine + +def test_sqlalchemy_engine_connect(sqlalchemy_engine): + with sqlalchemy_engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + +def test_sqlalchemy_query_execution(sqlalchemy_engine): + with sqlalchemy_engine.connect() as conn: + # Using a system table that is guaranteed to exist + result = conn.execute(text("SELECT * FROM system.runtime.nodes LIMIT 1")) + row = result.fetchone() + assert row is not None + +def test_sqlalchemy_reflection(sqlalchemy_engine): + # This requires tables to exist. + # tpch is usually available in the test environment (referenced in test_dbapi.py) + insp = inspect(sqlalchemy_engine) + + # Check schemas + schemas = insp.get_schema_names() + assert "sys" in schemas or "system" in schemas + + # Check tables in a specific schema (e.g. system.runtime) + tables = insp.get_table_names(schema="system") + assert "nodes" in tables or "runtime.nodes" in tables # Representation might vary + +def test_sqlalchemy_orm_basic(sqlalchemy_engine): + # Basic table definition + metadata = MetaData() + # we use a known table from tpch to avoid needing CREATE TABLE rights or persistence + # tpch.sf1.customer + # but that might be read-only. 
+ + # For integration test without write access, we typically verify SELECTs + # If we need to write, we arguably should rely on the test_dbapi.py establishing environment + + with sqlalchemy_engine.connect() as conn: + result = conn.execute(text("SELECT count(*) FROM tpch.sf1.customer")) + count = result.scalar() + assert count > 0 diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py new file mode 100644 index 0000000..9ed0bdd --- /dev/null +++ b/tests/test_sqlalchemy.py @@ -0,0 +1,112 @@ +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.engine.url import make_url +from prestodb.sqlalchemy.base import PrestoDialect +from prestodb import auth +import prestodb.dbapi + +# Mocking the interaction with the DBAPI for unit tests +class MockCursor: + def __init__(self, display_list=None): + self.description = [] + self._display_list = display_list or [] + + def execute(self, operation, parameters=None): + pass + + def fetchall(self): + return self._display_list + + def close(self): + pass + +class MockConnection: + def __init__(self, host, **kwargs): + self.host = host + self.kwargs = kwargs + self._cursor = MockCursor() + + def cursor(self): + return self._cursor + + def close(self): + pass + + def commit(self): + pass + +@pytest.fixture +def mock_dbapi(monkeypatch): + def connect(*args, **kwargs): + return MockConnection(*args, **kwargs) + + monkeypatch.setattr(prestodb.dbapi, "connect", connect) + return prestodb.dbapi + +def test_engine_creation(): + url = "presto://user:password@localhost:8080/catalog/schema" + # Registry might not be loaded in test environment without setup.py install, + # so we might need to manually register if it fails, but ideally via entrypoints. + # For unit test, we can pass the dialect class directly to create_engine is not robust + # but for "presto://", it relies on entry points. + # Alternatively we rely on the implementation in base which imports dbapi. 
+ + # We will test the dialect logic directly first + dialect = PrestoDialect() + u = make_url(url) + connect_args = dialect.create_connect_args(u) + + args = connect_args[0][0] + assert args["host"] == "localhost" + assert args["port"] == 8080 + assert args["user"] == "user" + assert args["catalog"] == "catalog" + assert args["schema"] == "schema" + assert args["http_scheme"] == "https" + assert isinstance(args["auth"], auth.BasicAuthentication) + +def test_type_parsing(): + dialect = PrestoDialect() + + # Simple types + assert str(dialect._parse_type("integer")) == "INTEGER" + assert str(dialect._parse_type("varchar")) == "VARCHAR" + + # Types with args + assert str(dialect._parse_type("varchar(255)")) == "VARCHAR(255)" + assert str(dialect._parse_type("decimal(10, 2)")) == "DECIMAL(10, 2)" + + # Multi-word types (Fixed by recent patch) + assert str(dialect._parse_type("timestamp with time zone")) == "TIMESTAMP" + assert str(dialect._parse_type("time with time zone")) == "TIME" + +def test_type_parsing_case_insensitive(): + dialect = PrestoDialect() + assert str(dialect._parse_type("INTEGER")) == "INTEGER" + assert str(dialect._parse_type("Varchar(10)")) == "VARCHAR(10)" + +def test_reflection_queries_generated(): + # Verify that reflection methods generate the expected SQL (using info schema) + # We can mock the connection.execute and check the query + dialect = PrestoDialect() + + class InspectableConnection: + def __init__(self): + self.queries = [] + self.engine = type('Engine', (), {'dialect': dialect})() + + def execute(self, sql, params=None): + self.queries.append((sql, params)) + return type('Result', (), {'scalar': lambda: 0, '__iter__': lambda x: iter([])})() + + conn = InspectableConnection() + + dialect.get_table_names(conn, schema="test_schema") + last_query, last_params = conn.queries[-1] + assert "information_schema.tables" in str(last_query) + assert last_params["schema"] == "test_schema" + + dialect.get_columns(conn, "test_table", schema="test_schema") + last_query, last_params = conn.queries[-1] + assert "information_schema.columns" in str(last_query) + assert last_params["table"] == "test_table" From 611ac9ef10ae6e8f01cd869ccddaec2684e6293b Mon Sep 17 00:00:00 2001 From: Saquib Saifee Date: Sat, 10 Jan 2026 15:58:58 -0500 Subject: [PATCH 7/9] fix: Resolve regex and security issues in SQLAlchemy dialect --- prestodb/sqlalchemy/base.py | 47 ++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/prestodb/sqlalchemy/base.py b/prestodb/sqlalchemy/base.py index 16a07da..9b03300 100644 --- a/prestodb/sqlalchemy/base.py +++ b/prestodb/sqlalchemy/base.py @@ -14,7 +14,8 @@ # This code is adapted from the trino-python-client project (Apache 2.0 License). 
# https://github.com/trinodb/trino-python-client/blob/master/trino/sqlalchemy/dialect.py -from sqlalchemy import types, util +import re +from sqlalchemy import types, util, text from sqlalchemy.engine import default from sqlalchemy.sql import sqltypes @@ -103,57 +104,65 @@ def _has_object(self, connection, object_type, object_name, schema=None): if schema is None: schema = connection.engine.dialect.default_schema_name + query = text( + "SELECT count(*) FROM information_schema.tables " + "WHERE table_schema = :schema AND table_name = :table" + ) return ( connection.execute( - "SELECT count(*) FROM information_schema.tables " - "WHERE table_schema = '{}' AND table_name = '{}'".format( - schema, object_name - ) + query, {"schema": schema, "table": object_name} ).scalar() > 0 ) def get_schema_names(self, connection, **kw): - result = connection.execute("SHOW SCHEMAS") + result = connection.execute("SELECT schema_name FROM information_schema.schemata") return [row[0] for row in result] def get_table_names(self, connection, schema=None, **kw): schema = schema or self.default_schema_name if schema is None: raise ValueError("schema argument is required") - result = connection.execute("SHOW TABLES FROM {}".format(schema)) + + query = text( + "SELECT table_name FROM information_schema.tables WHERE table_schema = :schema" + ) + result = connection.execute(query, {"schema": schema}) return [row[0] for row in result] def get_columns(self, connection, table_name, schema=None, **kw): schema = schema or self.default_schema_name if schema is None: raise ValueError("schema argument is required") - query = "SHOW COLUMNS FROM {}.{}".format(schema, table_name) - result = connection.execute(query) + + query = text( + "SELECT column_name, data_type, is_nullable, column_default " + "FROM information_schema.columns " + "WHERE table_schema = :schema AND table_name = :table " + "ORDER BY ordinal_position" + ) + result = connection.execute(query, {"schema": schema, "table": table_name}) + columns = [] for row in result: - # Column(Column, Type, Extra, Comment) - col_name = row[0] - col_type = row[1] - # extra = row[2] - # comment = row[3] + col_name, col_type, is_nullable, default_val = row columns.append( { "name": col_name, "type": self._parse_type(col_type), - "nullable": True, # TODO: check nullability - "default": None, + "nullable": is_nullable.lower() == "yes", + "default": default_val, } ) return columns def _parse_type(self, type_str): type_str = type_str.lower() - match = util.re.match(r"^([a-zA-Z0-9_]+)(\((.+)\))?$", type_str) + match = re.match(r"^([a-zA-Z0-9_ ]+)(\((.+)\))?$", type_str) if not match: return sqltypes.NullType() - type_name = match.group(1) + type_name = match.group(1).strip() type_args = match.group(3) if type_name in _type_map: @@ -165,7 +174,7 @@ def _parse_type(self, type_str): def _parse_type_args(self, type_args): # TODO: improve parsing for nested types - return [int(a) if a.isdigit() else a for a in type_args.split(",")] + return [int(a.strip()) if a.strip().isdigit() else a.strip() for a in type_args.split(",")] def do_rollback(self, dbapi_connection): # Presto transactions usually auto-commit or are read-only From d3b93f8efb37e6ae908dd0925f1956820839217d Mon Sep 17 00:00:00 2001 From: Saquib Saifee Date: Sat, 10 Jan 2026 16:06:17 -0500 Subject: [PATCH 8/9] Update setup.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 917935b..92d015c 
100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@
 sqlalchemy_require = ["sqlalchemy"]
 
-all_require = [kerberos_require, google_auth_require, sqlalchemy_require]
+all_require = kerberos_require + google_auth_require + sqlalchemy_require
 
 tests_require = all_require + ["httpretty", "pytest", "pytest-runner"]
 
 

From f633c14754670ec0807b2c3a781396d0aabb8310 Mon Sep 17 00:00:00 2001
From: Saquib Saifee
Date: Sat, 10 Jan 2026 16:06:38 -0500
Subject: [PATCH 9/9] Update prestodb/sqlalchemy/base.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
---
 prestodb/sqlalchemy/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/prestodb/sqlalchemy/base.py b/prestodb/sqlalchemy/base.py
index 9b03300..ba68001 100644
--- a/prestodb/sqlalchemy/base.py
+++ b/prestodb/sqlalchemy/base.py
@@ -116,7 +116,9 @@ def _has_object(self, connection, object_type, object_name, schema=None):
         )
 
     def get_schema_names(self, connection, **kw):
-        result = connection.execute("SELECT schema_name FROM information_schema.schemata")
+        result = connection.execute(
+            text("SELECT schema_name FROM information_schema.schemata")
+        )
         return [row[0] for row in result]
 
     def get_table_names(self, connection, schema=None, **kw):
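
The unit tests in PATCH 6 note that the `presto://` URL scheme resolves through the `sqlalchemy.dialects` entry point declared in setup.py, which only exists once the package is installed. For a plain source checkout (e.g. when running the tests directly), SQLAlchemy's dialect registry can be primed by hand; a minimal sketch, assuming the module layout introduced in PATCH 2:

```python
# Sketch for a source checkout where the setup.py entry point
# ("presto = prestodb.sqlalchemy.base:PrestoDialect") is not installed.
from sqlalchemy import create_engine
from sqlalchemy.dialects import registry

# Mirror the entry point manually: URL scheme, module path, class name.
registry.register("presto", "prestodb.sqlalchemy.base", "PrestoDialect")

# presto:// URLs now resolve without the installed entry point.
# Host, port, user, and catalog/schema below are placeholders.
engine = create_engine("presto://user@localhost:8080/system/runtime")
```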