diff --git a/pyproject.toml b/pyproject.toml index c26d208..d9b9d31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,7 @@ docdb = [ "sshtunnel" ] rds = [ - "psycopg2-binary==2.9.5", - "pandas>=2.0.0,<2.2.0", - "SQLAlchemy==1.4.49" + "redshift-connector>=2.0.0", ] helpers = [ "aind-data-schema>=1.1.0,<2.0", diff --git a/src/aind_data_access_api/rds_tables.py b/src/aind_data_access_api/rds_tables.py index 5a2b5a1..daa4c20 100644 --- a/src/aind_data_access_api/rds_tables.py +++ b/src/aind_data_access_api/rds_tables.py @@ -3,11 +3,9 @@ from typing import Optional, Union import pandas as pd -import sqlalchemy.engine +import redshift_connector from pydantic import AliasChoices, Field, SecretStr, model_validator from pydantic_settings import SettingsConfigDict -from sqlalchemy import create_engine, engine, text -from sqlalchemy.engine.cursor import CursorResult from typing_extensions import Self from aind_data_access_api.credentials import CoreCredentials @@ -48,7 +46,7 @@ class Client: def __init__( self, credentials: RDSCredentials, - drivername: Optional[str] = "postgresql", + drivername: Optional[str] = None, ): """ Construct a client to interface with relational database. @@ -57,31 +55,35 @@ def __init__( credentials : CoreCredentials drivername: Optional[str] - Combination of dialect[+driver] where the dialect is - the database name such as ``mysql``, ``oracle``, ``postgresql``, - etc. and the optional driver name is a DBAPI such as - ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. + Deprecated parameter, kept for backward compatibility. + Ignored when using redshift-connector. """ self.credentials = credentials - self.drivername = drivername + if drivername is not None: + import warnings + warnings.warn( + "drivername parameter is deprecated and ignored when using " + "redshift-connector", + DeprecationWarning, + stacklevel=2 + ) - @property - def _engine(self) -> sqlalchemy.engine.Engine: - """Create a sqlalchemy engine: - https://docs.sqlalchemy.org/en/20/core/engines.html + def _get_connection(self) -> redshift_connector.Connection: + """Create a redshift-connector connection. - Returns: sqlalchemy.engine.Engine + Returns + ------- + redshift_connector.Connection + A connection to the Redshift database. """ - connection_url = engine.URL.create( - drivername=self.drivername, - username=self.credentials.username, - password=self.credentials.password.get_secret_value(), + return redshift_connector.connect( host=self.credentials.host, - database=self.credentials.database, port=self.credentials.port, + database=self.credentials.database, + user=self.credentials.username, + password=self.credentials.password.get_secret_value(), ) - return create_engine(connection_url) def append_df_to_table( self, @@ -96,23 +98,32 @@ def append_df_to_table( df : pd.Dataframe table_name : str dtype: Optional[Union[dict, str]] + Note: dtype parameter is kept for backward compatibility + but is not directly supported by redshift-connector. Returns ------- None """ - # to_sql method has types str | None, but also allows for callable - # Suppressing type check warning. - # noinspection PyTypeChecker - df.to_sql( - name=table_name, - con=self._engine, - dtype=dtype, - method="multi", - if_exists="append", - index=False, # Redshift doesn't support index=True - ) + conn = self._get_connection() + try: + cursor = conn.cursor() + # Create INSERT statements from dataframe + columns = ", ".join([f'"{col}"' for col in df.columns]) + placeholders = ", ".join(["%s"] * len(df.columns)) + insert_query = ( + f'INSERT INTO "{table_name}" ({columns}) ' + f'VALUES ({placeholders})' + ) + + # Execute batch insert + for _, row in df.iterrows(): + cursor.execute(insert_query, tuple(row)) + + conn.commit() + finally: + conn.close() return None def overwrite_table_with_df( @@ -128,24 +139,75 @@ def overwrite_table_with_df( df : pd.Dataframe table_name : str dtype: Optional[Union[dict, str]] + Note: dtype parameter is kept for backward compatibility + but is not directly supported by redshift-connector. Returns ------- None """ - # to_sql method has types str | None, but also allows for callable - # Suppressing type check warning. - # noinspection PyTypeChecker - df.to_sql( - name=table_name, - con=self._engine, - dtype=dtype, - method="multi", - if_exists="replace", - index=False, # Redshift doesn't support index=True - ) + conn = self._get_connection() + try: + cursor = conn.cursor() + # Drop and recreate table + cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"') + + # Create table with columns from dataframe + column_defs = [] + for col in df.columns: + # Infer SQL type from pandas dtype + dtype_str = self._infer_sql_type(df[col].dtype) + column_defs.append(f'"{col}" {dtype_str}') + + create_query = ( + f'CREATE TABLE "{table_name}" ' + f'({", ".join(column_defs)})' + ) + cursor.execute(create_query) + + # Insert data + columns = ", ".join([f'"{col}"' for col in df.columns]) + placeholders = ", ".join(["%s"] * len(df.columns)) + insert_query = ( + f'INSERT INTO "{table_name}" ({columns}) ' + f'VALUES ({placeholders})' + ) + + for _, row in df.iterrows(): + cursor.execute(insert_query, tuple(row)) + + conn.commit() + finally: + conn.close() return None + def _infer_sql_type(self, dtype) -> str: + """ + Infer SQL type from pandas dtype. + + Parameters + ---------- + dtype : pandas dtype + + Returns + ------- + str + SQL type string + """ + dtype_str = str(dtype) + if "int" in dtype_str: + return "INTEGER" + elif "float" in dtype_str: + return "FLOAT" + elif "bool" in dtype_str: + return "BOOLEAN" + elif "datetime" in dtype_str: + return "TIMESTAMP" + elif "object" in dtype_str: + return "VARCHAR(MAX)" + else: + return "VARCHAR(MAX)" + def read_table( self, table_name: str, where_clause: Optional[str] = None ) -> pd.DataFrame: @@ -165,16 +227,29 @@ def read_table( A pandas dataframe created from the sql table. """ - with self._engine.begin() as conn: + conn = self._get_connection() + try: + cursor = conn.cursor() query = ( f'SELECT * FROM "{table_name}"' if where_clause is None else f'SELECT * FROM "{table_name}" WHERE {where_clause}' ) - df = pd.read_sql_query(sql=text(query), con=conn) + cursor.execute(query) + + # Fetch column names + columns = [desc[0] for desc in cursor.description] + + # Fetch all rows + rows = cursor.fetchall() + + # Create dataframe + df = pd.DataFrame(rows, columns=columns) + finally: + conn.close() return df - def execute_query(self, query: str) -> CursorResult: + def execute_query(self, query: str) -> redshift_connector.Cursor: """ Run a sql query against the database Parameters @@ -183,10 +258,17 @@ def execute_query(self, query: str) -> CursorResult: Returns ------- - CursorResult - The result of the query. + redshift_connector.Cursor + The cursor object after executing the query. + Note: The connection is closed after query execution, + so fetch results before returning if needed. """ - with self._engine.begin() as conn: - result = conn.execute(text(query)) - return result + conn = self._get_connection() + try: + cursor = conn.cursor() + cursor.execute(query) + conn.commit() + finally: + conn.close() + return cursor diff --git a/tests/test_rds_tables.py b/tests/test_rds_tables.py index 9322312..cb0deee 100644 --- a/tests/test_rds_tables.py +++ b/tests/test_rds_tables.py @@ -1,10 +1,9 @@ """Test rds_tables module.""" import unittest -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, patch import pandas as pd -from sqlalchemy import text from aind_data_access_api.rds_tables import Client, RDSCredentials @@ -45,9 +44,8 @@ def test_validate_database_name(self): class TestClient(unittest.TestCase): """Tests methods in the Client class.""" - @patch("pandas.read_sql_query") - @patch("sqlalchemy.engine.Engine.begin") - def test_read_table(self, mock_engine: MagicMock, mock_pd_read: MagicMock): + @patch("redshift_connector.connect") + def test_read_table(self, mock_connect: MagicMock): """Tests that read_table returns a pandas df.""" rds_client = Client( credentials=RDSCredentials( @@ -58,39 +56,37 @@ def test_read_table(self, mock_engine: MagicMock, mock_pd_read: MagicMock): ), ) - mock_pd_read.return_value = pd.DataFrame() + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_connect.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.description = [("col1",), ("col2",)] + mock_cursor.fetchall.return_value = [("a", 1), ("b", 2)] df1 = rds_client.read_table("some_table", where_clause=None) df2 = rds_client.read_table( "some_table", where_clause="subject_id=0", ) - query = text('SELECT * FROM "some_table"') - query2 = text('SELECT * FROM "some_table" WHERE subject_id=0') - mock_engine.assert_has_calls( - [ - call(), - call().__enter__(), - call().__exit__(None, None, None), - call(), - call().__enter__(), - call().__exit__(None, None, None), - ] - ) - self.assertEqual( - mock_pd_read.mock_calls[0].kwargs["sql"].text, query.text - ) + + # Verify connections were made + self.assertEqual(mock_connect.call_count, 2) + + # Verify queries were executed + execute_calls = mock_cursor.execute.call_args_list + self.assertEqual(execute_calls[0][0][0], 'SELECT * FROM "some_table"') self.assertEqual( - mock_pd_read.mock_calls[1].kwargs["sql"].text, query2.text + execute_calls[1][0][0], + 'SELECT * FROM "some_table" WHERE subject_id=0', ) - self.assertTrue(df1.empty) - self.assertTrue(df2.empty) - - @patch("pandas.DataFrame.to_sql") - @patch("aind_data_access_api.rds_tables.Client._engine") - def test_overwrite_table_with_df( - self, mock_engine: MagicMock, mock_to_sql: MagicMock - ): + + # Verify dataframes were created + self.assertEqual(len(df1), 2) + self.assertEqual(len(df2), 2) + + @patch("redshift_connector.connect") + def test_overwrite_table_with_df(self, mock_connect: MagicMock): """Test overwrite table method""" rds_client = Client( credentials=RDSCredentials( @@ -101,24 +97,32 @@ def test_overwrite_table_with_df( ), ) + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_connect.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + df1 = pd.DataFrame([["a", 1], ["b", 2]], columns=["foo", "bar"]) - mock_engine.return_value = MagicMock() rds_client.overwrite_table_with_df(df1, "some_table") - mock_to_sql.assert_called_once_with( - name="some_table", - con=rds_client._engine, - dtype=None, - method="multi", - if_exists="replace", - index=False, + + # Verify cursor methods were called + self.assertTrue(mock_cursor.execute.called) + execute_calls = [ + call[0][0] for call in mock_cursor.execute.call_args_list + ] + # Should have DROP TABLE and CREATE TABLE calls + self.assertTrue( + any("DROP TABLE" in call for call in execute_calls) ) + self.assertTrue( + any("CREATE TABLE" in call for call in execute_calls) + ) + self.assertTrue(mock_conn.commit.called) - @patch("pandas.DataFrame.to_sql") - @patch("aind_data_access_api.rds_tables.Client._engine") - def test_append_df_to_table( - self, mock_engine: MagicMock, mock_to_sql: MagicMock - ): + @patch("redshift_connector.connect") + def test_append_df_to_table(self, mock_connect: MagicMock): """Test append df to table method""" rds_client = Client( credentials=RDSCredentials( @@ -129,21 +133,22 @@ def test_append_df_to_table( ), ) + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_connect.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + df1 = pd.DataFrame([["a", 1], ["b", 2]], columns=["foo", "bar"]) - mock_engine.return_value = MagicMock() rds_client.append_df_to_table(df1, "some_table") - mock_to_sql.assert_called_once_with( - name="some_table", - con=rds_client._engine, - dtype=None, - method="multi", - if_exists="append", - index=False, - ) - @patch("sqlalchemy.engine.Engine.begin") - def test_execute_query(self, mock_engine: MagicMock): + # Verify cursor execute was called for inserts + self.assertEqual(mock_cursor.execute.call_count, 2) # 2 rows + self.assertTrue(mock_conn.commit.called) + + @patch("redshift_connector.connect") + def test_execute_query(self, mock_connect: MagicMock): """Tests that a sql query gets executed.""" rds_client = Client( credentials=RDSCredentials( @@ -153,12 +158,21 @@ def test_execute_query(self, mock_engine: MagicMock): database="db", ), ) - mock_exec = mock_engine.return_value.__enter__.return_value.execute - mock_exec.return_value = "some result" + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_connect.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + res = rds_client.execute_query('SELECT * FROM "some_table"') - self.assertEqual("some result", res) - input_text = mock_exec.mock_calls[0].args[0].text - self.assertEqual('SELECT * FROM "some_table"', input_text) + + # Verify query was executed + mock_cursor.execute.assert_called_once_with( + 'SELECT * FROM "some_table"' + ) + self.assertTrue(mock_conn.commit.called) + self.assertEqual(res, mock_cursor) if __name__ == "__main__":