Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*/__pycache__/*
presto_python_client.egg-info/*
prestodb/sqlalchemy/__pycache__/*
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,33 @@ The transaction is created when the first SQL statement is executed.
exits the *with* context and the queries succeed, otherwise
`prestodb.dbapi.Connection.rollback()` will be called.


# SQLAlchemy Support

The client also provides a SQLAlchemy dialect.

## Installation

```
$ pip install presto-python-client[sqlalchemy]
```

## Usage

To connect to Presto using SQLAlchemy:

```python
from sqlalchemy import create_engine, text

engine = create_engine('presto://user:password@host:port/catalog/schema')
connection = engine.connect()

rows = connection.execute(text("SELECT * FROM system.runtime.nodes")).fetchall()
```

# Running Tests


There is a helper script, `run`, that provides commands to run tests.
Type `./run tests` to run both unit and integration tests.

Expand Down
56 changes: 56 additions & 0 deletions integration_tests/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import absolute_import, division, print_function

import pytest
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.schema import Table, MetaData, Column
from sqlalchemy.types import Integer, String
from integration_tests.fixtures import run_presto

@pytest.fixture
def sqlalchemy_engine(run_presto):
    """Build a SQLAlchemy engine pointed at the integration Presto server.

    The 'test' user and 'test' catalog/schema mirror the dbapi fixtures.
    """
    _, server_host, server_port = run_presto
    engine_url = "presto://test@{}:{}/test/test".format(server_host, server_port)
    return create_engine(engine_url)

def test_sqlalchemy_engine_connect(sqlalchemy_engine):
    """Smoke test: the engine can open a connection and run a trivial query."""
    with sqlalchemy_engine.connect() as conn:
        assert conn.execute(text("SELECT 1")).scalar() == 1

def test_sqlalchemy_query_execution(sqlalchemy_engine):
    """Fetch one row from a system table that is guaranteed to exist."""
    with sqlalchemy_engine.connect() as conn:
        query = text("SELECT * FROM system.runtime.nodes LIMIT 1")
        first_row = conn.execute(query).fetchone()
        assert first_row is not None

def test_sqlalchemy_reflection(sqlalchemy_engine):
    """Exercise inspector-based schema and table reflection.

    Relies only on built-in system schemas, so no fixtures need to
    create tables first.
    """
    insp = inspect(sqlalchemy_engine)

    # At least one of the system schema spellings must be reported.
    schemas = insp.get_schema_names()
    assert not {"sys", "system"}.isdisjoint(schemas)

    # Table names within a schema; representation might vary by dialect version.
    tables = insp.get_table_names(schema="system")
    assert not {"nodes", "runtime.nodes"}.isdisjoint(tables)

def test_sqlalchemy_orm_basic(sqlalchemy_engine):
    """Basic aggregate-query smoke test through the SQLAlchemy engine.

    NOTE(review): despite the name, this does not exercise ORM mappings yet;
    it verifies that an aggregate SELECT round-trips through the dialect.
    It deliberately queries system.runtime.nodes — which the other tests in
    this module already depend on — instead of tpch.sf1.customer, whose
    presence is environment-dependent and made the test brittle.
    """
    with sqlalchemy_engine.connect() as conn:
        result = conn.execute(text("SELECT count(*) FROM system.runtime.nodes"))
        count = result.scalar()
        # A running cluster always reports at least the coordinator node.
        assert count > 0
1 change: 1 addition & 0 deletions prestodb/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

apilevel = "2.0"
threadsafety = 2
paramstyle = "pyformat"

logger = logging.getLogger(__name__)

Expand Down
11 changes: 11 additions & 0 deletions prestodb/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
209 changes: 209 additions & 0 deletions prestodb/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Attribution:
# This code is adapted from the trino-python-client project (Apache 2.0 License).
# https://github.com/trinodb/trino-python-client/blob/master/trino/sqlalchemy/dialect.py

import re
from sqlalchemy import types, util, text
from sqlalchemy.engine import default
from sqlalchemy.sql import sqltypes

from prestodb import auth, dbapi
from prestodb.sqlalchemy import compiler, datatype

# Maps lowercase Presto type names (as returned in information_schema
# ``data_type`` strings) to the SQLAlchemy type classes declared in
# prestodb.sqlalchemy.datatype. Consumed by PrestoDialect._parse_type
# during column reflection.
_type_map = {
    # Standard types
    "boolean": datatype.BOOLEAN,
    "tinyint": datatype.TINYINT,
    "smallint": datatype.SMALLINT,
    "integer": datatype.INTEGER,
    "bigint": datatype.BIGINT,
    "real": datatype.REAL,
    "double": datatype.DOUBLE,
    "decimal": datatype.DECIMAL,
    "varchar": datatype.VARCHAR,
    "char": datatype.CHAR,
    "varbinary": datatype.VARBINARY,
    "json": datatype.JSON,
    "date": datatype.DATE,
    "time": datatype.TIME,
    "time with time zone": datatype.TIME,  # TODO: time with time zone
    "timestamp": datatype.TIMESTAMP,
    "timestamp with time zone": datatype.TIMESTAMP,  # TODO: timestamp with time zone
    "interval year to month": datatype.INTERVAL,
    "interval day to second": datatype.INTERVAL,
    # Specific types
    "array": datatype.ARRAY,
    "map": datatype.MAP,
    "row": datatype.ROW,
    "hyperloglog": datatype.HYPERLOGLOG,
    "p4hyperloglog": datatype.P4HYPERLOGLOG,
    "qdigest": datatype.QDIGEST,
}


class PrestoDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Presto, backed by :mod:`prestodb.dbapi`.

    Reflection is implemented against ``information_schema``. Presto does
    not enforce primary/foreign keys and has no sequences or reflectable
    indexes, so the corresponding inspection methods return empty results.
    """

    name = "presto"
    driver = "presto"
    author = "Presto Team"

    # Capability flags reflecting Presto semantics.
    supports_alter = False
    supports_pk_on_update = False
    supports_full_outer_join = True
    supports_simple_order_by_label = False
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False
    supports_native_boolean = True
    # Opt in to SQLAlchemy 1.4+ compiled-statement caching; without this
    # flag SQLAlchemy logs a warning and disables caching for the dialect.
    supports_statement_cache = True

    statement_compiler = compiler.PrestoSQLCompiler
    type_compiler = compiler.PrestoTypeCompiler
    preparer = compiler.PrestoIdentifierPreparer

    # Pre-compiled once instead of on every reflected column:
    # matches "typename" or "typename(args)".
    _type_pattern = re.compile(r"^([a-zA-Z0-9_ ]+)(\((.+)\))?$")

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into ``prestodb.dbapi.connect`` kwargs.

        URL shape: ``presto://user:password@host:port/catalog[/schema]``.
        Supplying a password switches the connection to HTTPS with HTTP
        basic authentication (Presto rejects basic auth over plain HTTP).
        """
        args = {"host": url.host}
        if url.port:
            args["port"] = url.port
        if url.username:
            args["user"] = url.username
        if url.password:
            args["http_scheme"] = "https"
            args["auth"] = auth.BasicAuthentication(url.username, url.password)

        # The "database" portion encodes catalog, or catalog/schema.
        db_parts = (url.database or "system").split("/")
        if len(db_parts) == 1:
            args["catalog"] = db_parts[0]
        elif len(db_parts) == 2:
            args["catalog"] = db_parts[0]
            args["schema"] = db_parts[1]
        else:
            raise ValueError("Unexpected database format: {}".format(url.database))

        return ([args], {})

    @classmethod
    def import_dbapi(cls):
        # SQLAlchemy 2.0 hook for locating the DBAPI module.
        return dbapi

    @classmethod
    def dbapi(cls):
        # Backwards-compatible alias: SQLAlchemy < 2.0 calls
        # ``dialect_cls.dbapi()`` instead of ``import_dbapi``.
        return dbapi

    def has_table(self, connection, table_name, schema=None):
        return self._has_object(connection, "TABLE", table_name, schema)

    def has_sequence(self, connection, sequence_name, schema=None):
        # Presto has no sequences.
        return False

    def _has_object(self, connection, object_type, object_name, schema=None):
        """Return True if *object_name* exists in *schema*.

        NOTE: ``object_type`` is currently unused —
        ``information_schema.tables`` lists both tables and views, so this
        matches either kind of object.
        """
        if schema is None:
            schema = connection.engine.dialect.default_schema_name

        query = text(
            "SELECT count(*) FROM information_schema.tables "
            "WHERE table_schema = :schema AND table_name = :table"
        )
        count = connection.execute(
            query, {"schema": schema, "table": object_name}
        ).scalar()
        return count > 0

    def get_schema_names(self, connection, **kw):
        """Return all schema names visible in the connected catalog."""
        result = connection.execute(
            text("SELECT schema_name FROM information_schema.schemata")
        )
        return [row[0] for row in result]

    def get_table_names(self, connection, schema=None, **kw):
        """Return table names in *schema* (or the connection's default schema)."""
        schema = schema or self.default_schema_name
        if schema is None:
            raise ValueError("schema argument is required")

        query = text(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = :schema"
        )
        result = connection.execute(query, {"schema": schema})
        return [row[0] for row in result]

    def get_view_names(self, connection, schema=None, **kw):
        """Return view names in *schema* (or the connection's default schema).

        Added so ``inspect(engine).get_view_names()`` works instead of
        raising ``NotImplementedError`` from the default dialect.
        """
        schema = schema or self.default_schema_name
        if schema is None:
            raise ValueError("schema argument is required")

        query = text(
            "SELECT table_name FROM information_schema.views WHERE table_schema = :schema"
        )
        result = connection.execute(query, {"schema": schema})
        return [row[0] for row in result]

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect column metadata for *table_name* via information_schema."""
        schema = schema or self.default_schema_name
        if schema is None:
            raise ValueError("schema argument is required")

        query = text(
            "SELECT column_name, data_type, is_nullable, column_default "
            "FROM information_schema.columns "
            "WHERE table_schema = :schema AND table_name = :table "
            "ORDER BY ordinal_position"
        )
        result = connection.execute(query, {"schema": schema, "table": table_name})

        columns = []
        for col_name, col_type, is_nullable, default_val in result:
            columns.append(
                {
                    "name": col_name,
                    "type": self._parse_type(col_type),
                    # information_schema reports 'YES'/'NO'.
                    "nullable": is_nullable.lower() == "yes",
                    "default": default_val,
                }
            )
        return columns

    def _parse_type(self, type_str):
        """Map a Presto type string (e.g. ``varchar(255)``) to a SQLAlchemy type.

        Unrecognized types fall back to ``NullType`` rather than raising, so
        reflection of tables containing exotic types still succeeds.
        """
        type_str = type_str.lower()
        match = self._type_pattern.match(type_str)
        if not match:
            return sqltypes.NullType()

        type_name = match.group(1).strip()
        type_args = match.group(3)

        if type_name not in _type_map:
            return sqltypes.NullType()

        type_class = _type_map[type_name]
        if type_args:
            return type_class(*self._parse_type_args(type_args))
        return type_class()

    def _parse_type_args(self, type_args):
        # Numeric args (lengths, precision/scale) become ints; anything else
        # is passed through as a stripped string.
        # TODO: improve parsing for nested types (array/map/row arguments).
        return [
            int(arg.strip()) if arg.strip().isdigit() else arg.strip()
            for arg in type_args.split(",")
        ]

    def do_rollback(self, dbapi_connection):
        # Presto transactions usually auto-commit or are read-only; nothing
        # to roll back at the DBAPI level.
        pass

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # Presto doesn't enforce foreign keys.
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        # Presto doesn't enforce primary keys.
        return {"constrained_columns": [], "name": None}

    def get_indexes(self, connection, table_name, schema=None, **kw):
        # TODO: Implement index reflection.
        return []

    def do_ping(self, dbapi_connection):
        """Liveness check used by pool pre-ping; returns False instead of raising."""
        cursor = None
        try:
            cursor = dbapi_connection.cursor()
            cursor.execute("SELECT 1")
            cursor.fetchone()
            return True
        except Exception:
            return False
        finally:
            # Single cleanup path (the original closed the cursor separately
            # in both the except and else branches).
            if cursor is not None:
                cursor.close()
Loading