Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ docdb = [
"sshtunnel"
]
rds = [
"psycopg2-binary==2.9.5",
"pandas>=2.0.0,<2.2.0",
"SQLAlchemy==1.4.49"
"redshift-connector>=2.0.0",
]
helpers = [
"aind-data-schema>=1.1.0,<2.0",
Expand Down
182 changes: 132 additions & 50 deletions src/aind_data_access_api/rds_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from typing import Optional, Union

import pandas as pd
import sqlalchemy.engine
import redshift_connector
from pydantic import AliasChoices, Field, SecretStr, model_validator
from pydantic_settings import SettingsConfigDict
from sqlalchemy import create_engine, engine, text
from sqlalchemy.engine.cursor import CursorResult
from typing_extensions import Self

from aind_data_access_api.credentials import CoreCredentials
Expand Down Expand Up @@ -48,7 +46,7 @@ class Client:
def __init__(
self,
credentials: RDSCredentials,
drivername: Optional[str] = "postgresql",
drivername: Optional[str] = None,
):
"""
Construct a client to interface with relational database.
Expand All @@ -57,31 +55,35 @@ def __init__(
credentials : CoreCredentials

drivername: Optional[str]
Combination of dialect[+driver] where the dialect is
the database name such as ``mysql``, ``oracle``, ``postgresql``,
etc. and the optional driver name is a DBAPI such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc.
Deprecated parameter, kept for backward compatibility.
Ignored when using redshift-connector.

"""
self.credentials = credentials
self.drivername = drivername
if drivername is not None:
import warnings
warnings.warn(
"drivername parameter is deprecated and ignored when using "
"redshift-connector",
DeprecationWarning,
stacklevel=2
)

@property
def _engine(self) -> sqlalchemy.engine.Engine:
"""Create a sqlalchemy engine:
https://docs.sqlalchemy.org/en/20/core/engines.html
def _get_connection(self) -> redshift_connector.Connection:
"""Create a redshift-connector connection.

Returns: sqlalchemy.engine.Engine
Returns
-------
redshift_connector.Connection
A connection to the Redshift database.
"""
connection_url = engine.URL.create(
drivername=self.drivername,
username=self.credentials.username,
password=self.credentials.password.get_secret_value(),
return redshift_connector.connect(
host=self.credentials.host,
database=self.credentials.database,
port=self.credentials.port,
database=self.credentials.database,
user=self.credentials.username,
password=self.credentials.password.get_secret_value(),
)
return create_engine(connection_url)

def append_df_to_table(
self,
Expand All @@ -96,23 +98,32 @@ def append_df_to_table(
df : pd.Dataframe
table_name : str
dtype: Optional[Union[dict, str]]
Note: dtype parameter is kept for backward compatibility
but is not directly supported by redshift-connector.

Returns
-------
None

"""
# to_sql method has types str | None, but also allows for callable
# Suppressing type check warning.
# noinspection PyTypeChecker
df.to_sql(
name=table_name,
con=self._engine,
dtype=dtype,
method="multi",
if_exists="append",
index=False, # Redshift doesn't support index=True
)
conn = self._get_connection()
try:
cursor = conn.cursor()
# Create INSERT statements from dataframe
columns = ", ".join([f'"{col}"' for col in df.columns])
placeholders = ", ".join(["%s"] * len(df.columns))
insert_query = (
f'INSERT INTO "{table_name}" ({columns}) '
f'VALUES ({placeholders})'
)

# Execute batch insert
for _, row in df.iterrows():
cursor.execute(insert_query, tuple(row))

conn.commit()
finally:
conn.close()
return None

def overwrite_table_with_df(
Expand All @@ -128,24 +139,75 @@ def overwrite_table_with_df(
df : pd.Dataframe
table_name : str
dtype: Optional[Union[dict, str]]
Note: dtype parameter is kept for backward compatibility
but is not directly supported by redshift-connector.

Returns
-------
None
"""
# to_sql method has types str | None, but also allows for callable
# Suppressing type check warning.
# noinspection PyTypeChecker
df.to_sql(
name=table_name,
con=self._engine,
dtype=dtype,
method="multi",
if_exists="replace",
index=False, # Redshift doesn't support index=True
)
conn = self._get_connection()
try:
cursor = conn.cursor()
# Drop and recreate table
cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')

# Create table with columns from dataframe
column_defs = []
for col in df.columns:
# Infer SQL type from pandas dtype
dtype_str = self._infer_sql_type(df[col].dtype)
column_defs.append(f'"{col}" {dtype_str}')

create_query = (
f'CREATE TABLE "{table_name}" '
f'({", ".join(column_defs)})'
)
cursor.execute(create_query)

# Insert data
columns = ", ".join([f'"{col}"' for col in df.columns])
placeholders = ", ".join(["%s"] * len(df.columns))
insert_query = (
f'INSERT INTO "{table_name}" ({columns}) '
f'VALUES ({placeholders})'
)

for _, row in df.iterrows():
cursor.execute(insert_query, tuple(row))

conn.commit()
finally:
conn.close()
return None

def _infer_sql_type(self, dtype) -> str:
"""
Infer SQL type from pandas dtype.

Parameters
----------
dtype : pandas dtype

Returns
-------
str
SQL type string
"""
dtype_str = str(dtype)
if "int" in dtype_str:
return "INTEGER"
elif "float" in dtype_str:
return "FLOAT"
elif "bool" in dtype_str:
return "BOOLEAN"
elif "datetime" in dtype_str:
return "TIMESTAMP"
elif "object" in dtype_str:
return "VARCHAR(MAX)"
else:
return "VARCHAR(MAX)"

def read_table(
self, table_name: str, where_clause: Optional[str] = None
) -> pd.DataFrame:
Expand All @@ -165,16 +227,29 @@ def read_table(
A pandas dataframe created from the sql table.

"""
with self._engine.begin() as conn:
conn = self._get_connection()
try:
cursor = conn.cursor()
query = (
f'SELECT * FROM "{table_name}"'
if where_clause is None
else f'SELECT * FROM "{table_name}" WHERE {where_clause}'
)
df = pd.read_sql_query(sql=text(query), con=conn)
cursor.execute(query)

# Fetch column names
columns = [desc[0] for desc in cursor.description]

# Fetch all rows
rows = cursor.fetchall()

# Create dataframe
df = pd.DataFrame(rows, columns=columns)
finally:
conn.close()
return df

def execute_query(self, query: str) -> CursorResult:
def execute_query(self, query: str) -> redshift_connector.Cursor:
"""
Run a sql query against the database
Parameters
Expand All @@ -183,10 +258,17 @@ def execute_query(self, query: str) -> CursorResult:

Returns
-------
CursorResult
The result of the query.
redshift_connector.Cursor
The cursor object after executing the query.
Note: The connection is closed after query execution,
so fetch results before returning if needed.

"""
with self._engine.begin() as conn:
result = conn.execute(text(query))
return result
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(query)
conn.commit()
finally:
conn.close()
return cursor
Loading
Loading