From b17206b390e59eb306e551daea8a9cd6936fbe67 Mon Sep 17 00:00:00 2001 From: Sophia Chu Date: Mon, 1 Dec 2025 16:39:48 -0800 Subject: [PATCH 1/3] feat: srw --- aws_advanced_python_wrapper/plugin_service.py | 4 + .../read_write_splitting_plugin.py | 650 ++++++++++--- ...dvanced_python_wrapper_messages.properties | 8 +- .../simple_read_write_splitting_plugin.py | 343 +++++++ .../utils/properties.py | 470 +++++---- pyproject.toml | 3 +- .../container/test_read_write_splitting.py | 585 +++++++++-- .../unit/test_read_write_splitting_plugin.py | 915 ++++++++++++++---- 8 files changed, 2339 insertions(+), 639 deletions(-) create mode 100644 aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index f3dc5fbc..8fc41614 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -82,6 +82,8 @@ from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.read_write_splitting_plugin import \ ReadWriteSplittingPluginFactory +from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \ + SimpleReadWriteSplittingPluginFactory from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory from aws_advanced_python_wrapper.utils.cache_map import CacheMap from aws_advanced_python_wrapper.utils.decorators import \ @@ -760,6 +762,7 @@ class PluginManager(CanReleaseResources): "host_monitoring_v2": HostMonitoringV2PluginFactory, "failover": FailoverPluginFactory, "read_write_splitting": ReadWriteSplittingPluginFactory, + "srw": SimpleReadWriteSplittingPluginFactory, "fastest_response_strategy": FastestResponseStrategyPluginFactory, "stale_dns": StaleDnsPluginFactory, "custom_endpoint": CustomEndpointPluginFactory, @@ -784,6 +787,7 @@ class PluginManager(CanReleaseResources): AuroraConnectionTrackerPluginFactory: 100, StaleDnsPluginFactory: 200, ReadWriteSplittingPluginFactory: 300, + SimpleReadWriteSplittingPluginFactory: 310, FailoverPluginFactory: 400, HostMonitoringPluginFactory: 500, HostMonitoringV2PluginFactory: 510, diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 98999aa8..125bbb5f 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -15,7 +15,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Set, Tuple if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -23,7 +23,9 @@ from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.utils.properties import Properties - from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager + from aws_advanced_python_wrapper.connection_provider import ( + ConnectionProviderManager, + ) from aws_advanced_python_wrapper.errors import (AwsWrapperError, FailoverError, ReadWriteSplittingError) @@ -38,80 +40,66 @@ logger = Logger(__name__) -class ReadWriteSplittingPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider", - "connect", - "notify_connection_changed", - "Connection.set_read_only"} +class 
ReadWriteSplittingConnectionManager(Plugin): + """Base class that manages connection switching logic.""" + + _SUBSCRIBED_METHODS: Set[str] = { + "init_host_provider", + "connect", + "notify_connection_changed", + "Connection.set_read_only", + } _POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider" - def __init__(self, plugin_service: PluginService, props: Properties): - self._plugin_service = plugin_service - self._properties = props - self._host_list_provider_service: HostListProviderService + def __init__( + self, + plugin_service: PluginService, + props: Properties, + connection_handler: ConnectionHandler, + ): + self._plugin_service: PluginService = plugin_service + self._properties: Properties = props + self._connection_handler: ConnectionHandler = connection_handler self._writer_connection: Optional[Connection] = None self._reader_connection: Optional[Connection] = None + self._writer_host_info: Optional[HostInfo] = None self._reader_host_info: Optional[HostInfo] = None - self._conn_provider_manager: ConnectionProviderManager = self._plugin_service.get_connection_provider_manager() + self._conn_provider_manager: ConnectionProviderManager = ( + self._plugin_service.get_connection_provider_manager() + ) self._is_reader_conn_from_internal_pool: bool = False self._is_writer_conn_from_internal_pool: bool = False self._in_read_write_split: bool = False - self._reader_selector_strategy: str = "" - strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(self._properties) - if strategy is not None: - self._reader_selector_strategy = strategy - else: - default_strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.default_value - if default_strategy is not None: - self._reader_selector_strategy = default_strategy - @property def subscribed_methods(self) -> Set[str]: return self._SUBSCRIBED_METHODS def init_host_provider( - self, - props: Properties, - host_list_provider_service: HostListProviderService, - init_host_provider_func: Callable): - self._host_list_provider_service = host_list_provider_service + self, + props: Properties, + host_list_provider_service: HostListProviderService, + init_host_provider_func: Callable, + ): + self._connection_handler.host_list_provider_service = host_list_provider_service init_host_provider_func() def connect( - self, - target_driver_func: Callable, - driver_dialect: DriverDialect, - host_info: HostInfo, - props: Properties, - is_initial_connection: bool, - connect_func: Callable) -> Connection: - if not self._plugin_service.accepts_strategy(host_info.role, self._reader_selector_strategy): - raise AwsWrapperError( - Messages.get_formatted("ReadWriteSplittingPlugin.UnsupportedHostInfoSelectorStrategy", - self._reader_selector_strategy)) - - current_conn = connect_func() - - if not is_initial_connection or self._host_list_provider_service.is_static_host_list_provider(): - return current_conn - - current_role = self._plugin_service.get_host_role(current_conn) - if current_role is None or current_role == HostRole.UNKNOWN: - self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole") - - current_host = self._plugin_service.initial_connection_host_info - if current_host is not None: - if current_role == current_host.role: - return current_conn - - updated_host = deepcopy(current_host) - updated_host.role = current_role - self._host_list_provider_service.initial_connection_host_info = updated_host - - return current_conn - - def 
notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnectionSuggestedAction: + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable, + ) -> Connection: + return self._connection_handler.get_verified_initial_connection( + host_info, props, is_initial_connection, connect_func + ) + + def notify_connection_changed( + self, changes: Set[ConnectionEvent] + ) -> OldConnectionSuggestedAction: self._update_internal_connection_info() if self._in_read_write_split: @@ -119,26 +107,47 @@ def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnect return OldConnectionSuggestedAction.NO_OPINION - def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: + def execute( + self, + target: type, + method_name: str, + execute_func: Callable, + *args: Any, + **kwargs: Any, + ) -> Any: driver_dialect = self._plugin_service.driver_dialect conn: Optional[Connection] = driver_dialect.get_connection_from_obj(target) - current_conn: Optional[Connection] = driver_dialect.unwrap_connection(self._plugin_service.current_connection) + current_conn: Optional[Connection] = driver_dialect.unwrap_connection( + self._plugin_service.current_connection + ) if conn is not None and conn != current_conn: - msg = Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target) + msg = Messages.get_formatted( + "PluginManager.MethodInvokedAgainstOldConnection", target + ) raise AwsWrapperError(msg) - if method_name == "Connection.set_read_only" and args is not None and len(args) > 0: + if ( + method_name == "Connection.set_read_only" + and args is not None + and len(args) > 0 + ): self._switch_connection_if_required(args[0]) try: return execute_func() except Exception as ex: if isinstance(ex, FailoverError): - logger.debug("ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand", method_name) + logger.debug( + "ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand", + method_name, + ) self._close_idle_connections() else: - logger.debug("ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand", method_name) + logger.debug( + "ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand", + method_name, + ) raise ex def _update_internal_connection_info(self): @@ -147,24 +156,49 @@ def _update_internal_connection_info(self): if current_conn is None or current_host is None: return - if current_host.role == HostRole.WRITER: + if self._connection_handler.should_update_writer_with_current_conn( + current_conn, current_host, self._writer_connection + ): self._set_writer_connection(current_conn, current_host) - else: + elif self._connection_handler.should_update_reader_with_current_conn( + current_conn, current_host, self._reader_connection + ): self._set_reader_connection(current_conn, current_host) - def _set_writer_connection(self, writer_conn: Connection, writer_host_info: HostInfo): + def _set_writer_connection( + self, writer_conn: Connection, writer_host_info: HostInfo + ): self._writer_connection = writer_conn - logger.debug("ReadWriteSplittingPlugin.SetWriterConnection", writer_host_info.url) - - def _set_reader_connection(self, reader_conn: Connection, reader_host_info: HostInfo): + self._writer_host_info = writer_host_info + logger.debug( + "ReadWriteSplittingPlugin.SetWriterConnection", writer_host_info.url + ) + + def _set_reader_connection( + self, reader_conn: Connection, reader_host_info: 
HostInfo + ): self._reader_connection = reader_conn self._reader_host_info = reader_host_info - logger.debug("ReadWriteSplittingPlugin.SetReaderConnection", reader_host_info.url) + logger.debug( + "ReadWriteSplittingPlugin.SetReaderConnection", reader_host_info.url + ) + + def _initialize_writer_connection(self): + conn, writer_host = self._connection_handler.open_new_writer_connection() - def _get_new_writer_connection(self, writer_host: HostInfo): - conn = self._plugin_service.connect(writer_host, self._properties, self) - provider = self._conn_provider_manager.get_connection_provider(writer_host, self._properties) - self._is_writer_conn_from_internal_pool = (ReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME in str(type(provider))) + if conn is None: + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.FailedToConnectToWriter" + ) + return + + provider = self._conn_provider_manager.get_connection_provider( + writer_host, self._properties + ) + self._is_writer_conn_from_internal_pool = ( + ReadWriteSplittingConnectionManager._POOL_PROVIDER_CLASS_NAME + in str(type(provider)) + ) self._set_writer_connection(conn, writer_host) self._switch_current_connection_to(conn, writer_host) @@ -172,184 +206,480 @@ def _switch_connection_if_required(self, read_only: bool): current_conn = self._plugin_service.current_connection driver_dialect = self._plugin_service.driver_dialect - if (current_conn is not None and - driver_dialect is not None and driver_dialect.is_closed(current_conn)): - self._log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection") + if ( + current_conn is not None + and driver_dialect is not None + and driver_dialect.is_closed(current_conn) + ): + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection" + ) - if current_conn is not None and driver_dialect.can_execute_query(current_conn): - try: - self._plugin_service.refresh_host_list() - except Exception: - pass # Swallow exception - - hosts = self._plugin_service.hosts - if hosts is None or len(hosts) == 0: - self._log_and_raise_exception("ReadWriteSplittingPlugin.EmptyHostList") + self._connection_handler.refresh_and_store_host_list( + current_conn, driver_dialect + ) current_host = self._plugin_service.current_host_info if current_host is None: - self._log_and_raise_exception("ReadWriteSplittingPlugin.UnavailableHostInfo") + self.log_and_raise_exception("ReadWriteSplittingPlugin.UnavailableHostInfo") return if read_only: - if not self._plugin_service.is_in_transaction and current_host.role != HostRole.READER: + if ( + not self._plugin_service.is_in_transaction + and not self._connection_handler.is_reader_host(current_host) + ): try: - self._switch_to_reader_connection(hosts) + self._switch_to_reader_connection() except Exception: if not self._is_connection_usable(current_conn, driver_dialect): - self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToReader") + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.ErrorSwitchingToReader" + ) return - logger.warning("ReadWriteSplittingPlugin.FallbackToWriter", current_host.url) - elif current_host.role != HostRole.WRITER: + logger.warning( + "ReadWriteSplittingPlugin.FallbackToCurrentConnection", + current_host.url, + ) + elif not self._connection_handler.is_writer_host(current_host): if self._plugin_service.is_in_transaction: - self._log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction") + self.log_and_raise_exception( + 
"ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction" + ) try: - self._switch_to_writer_connection(hosts) + self._switch_to_writer_connection() except Exception: - self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToWriter") + self.log_and_raise_exception( + "ReadWriteSplittingPlugin.ErrorSwitchingToWriter" + ) - def _switch_current_connection_to(self, new_conn: Connection, new_conn_host: HostInfo): + def _switch_current_connection_to( + self, new_conn: Connection, new_conn_host: HostInfo + ): current_conn = self._plugin_service.current_connection if current_conn == new_conn: return self._plugin_service.set_current_connection(new_conn, new_conn_host) - logger.debug("ReadWriteSplittingPlugin.SettingCurrentConnection", new_conn_host.url) + logger.debug( + "ReadWriteSplittingPlugin.SettingCurrentConnection", new_conn_host.url + ) - def _switch_to_writer_connection(self, hosts: Tuple[HostInfo, ...]): + def _switch_to_writer_connection(self): current_host = self._plugin_service.current_host_info current_conn = self._plugin_service.current_connection driver_dialect = self._plugin_service.driver_dialect - if (current_host is not None and current_host.role == HostRole.WRITER and - self._is_connection_usable(current_conn, driver_dialect)): - return - - writer_host = self._get_writer(hosts) - if writer_host is None: + if ( + current_host is not None + and self._connection_handler.is_writer_host(current_host) + and self._is_connection_usable(current_conn, driver_dialect) + ): + # Already connected to the intended writer. return self._in_read_write_split = True if not self._is_connection_usable(self._writer_connection, driver_dialect): - self._get_new_writer_connection(writer_host) - elif self._writer_connection is not None: - self._switch_current_connection_to(self._writer_connection, writer_host) + self._initialize_writer_connection() + elif self._writer_connection is not None and self._writer_host_info is not None: + self._switch_current_connection_to( + self._writer_connection, self._writer_host_info + ) if self._is_reader_conn_from_internal_pool: self._close_connection_if_idle(self._reader_connection) - logger.debug("ReadWriteSplittingPlugin.SwitchedFromReaderToWriter", writer_host.url) + logger.debug( + "ReadWriteSplittingPlugin.SwitchedFromReaderToWriter", + self._writer_host_info.url, + ) - def _switch_to_reader_connection(self, hosts: Tuple[HostInfo, ...]): + def _switch_to_reader_connection(self): current_host = self._plugin_service.current_host_info current_conn = self._plugin_service.current_connection driver_dialect = self._plugin_service.driver_dialect - if (current_host is not None and current_host.role == HostRole.READER and - self._is_connection_usable(current_conn, driver_dialect)): + if ( + current_host is not None + and self._connection_handler.is_reader_host(current_host) + and self._is_connection_usable(current_conn, driver_dialect) + ): + # Already connected to the intended reader. return - hostnames = [host_info.host for host_info in hosts] - if self._reader_host_info is not None and self._reader_host_info.host not in hostnames: - # The old reader cannot be used anymore because it is no longer in the list of allowed hosts. + if ( + self._reader_connection is not None + and not self._connection_handler.old_reader_can_be_used( + self._reader_host_info + ) + ): + # The old reader cannot be used anymore, close it. 
self._close_connection_if_idle(self._reader_connection) self._in_read_write_split = True if not self._is_connection_usable(self._reader_connection, driver_dialect): - self._initialize_reader_connection(hosts) + self._initialize_reader_connection() elif self._reader_connection is not None and self._reader_host_info is not None: try: - self._switch_current_connection_to(self._reader_connection, self._reader_host_info) - logger.debug("ReadWriteSplittingPlugin.SwitchedFromWriterToReader", self._reader_host_info.url) + self._switch_current_connection_to( + self._reader_connection, self._reader_host_info + ) + logger.debug( + "ReadWriteSplittingPlugin.SwitchedFromWriterToReader", + self._reader_host_info.url, + ) except Exception: - logger.debug("ReadWriteSplittingPlugin.ErrorSwitchingToCachedReader", self._reader_host_info.url) - - self._reader_connection.close() + logger.debug( + "ReadWriteSplittingPlugin.ErrorSwitchingToCachedReader", + self._reader_host_info.url, + ) + + ReadWriteSplittingConnectionManager.close_connection( + self._reader_connection + ) self._reader_connection = None self._reader_host_info = None - self._initialize_reader_connection(hosts) + self._initialize_reader_connection() if self._is_writer_conn_from_internal_pool: self._close_connection_if_idle(self._writer_connection) - def _initialize_reader_connection(self, hosts: Tuple[HostInfo, ...]): - if len(hosts) == 1: - writer_host = self._get_writer(hosts) - if writer_host is not None: - if not self._is_connection_usable(self._writer_connection, self._plugin_service.driver_dialect): - self._get_new_writer_connection(writer_host) - logger.warning("ReadWriteSplittingPlugin.NoReadersFound", writer_host.url) - return - - conn: Optional[Connection] = None - reader_host: Optional[HostInfo] = None + def _initialize_reader_connection(self): + if self._connection_handler.need_connect_to_writer(): + if not self._is_connection_usable( + self._writer_connection, self._plugin_service.driver_dialect + ): + self._initialize_writer_connection() + logger.warning( + "ReadWriteSplittingPlugin.NoReadersFound", self._writer_host_info.url + ) + return - conn_attempts = len(self._plugin_service.hosts) * 2 - for _ in range(conn_attempts): - host = self._plugin_service.get_host_info_by_strategy(HostRole.READER, self._reader_selector_strategy) - if host is not None: - try: - conn = self._plugin_service.connect(host, self._properties, self) - provider = self._conn_provider_manager.get_connection_provider(host, self._properties) - self._is_reader_conn_from_internal_pool = (ReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME in str(type(provider))) - reader_host = host - break - except Exception: - logger.warning("ReadWriteSplittingPlugin.FailedToConnectToReader", host.url) + conn, reader_host = self._connection_handler.open_new_reader_connection() if conn is None or reader_host is None: - self._log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable") + self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable") return - logger.debug("ReadWriteSplittingPlugin.SuccessfullyConnectedToReader", reader_host.url) + logger.debug( + "ReadWriteSplittingPlugin.SuccessfullyConnectedToReader", reader_host.url + ) + + provider = self._conn_provider_manager.get_connection_provider( + reader_host, self._properties + ) + self._is_reader_conn_from_internal_pool = ( + ReadWriteSplittingConnectionManager._POOL_PROVIDER_CLASS_NAME + in str(type(provider)) + ) self._set_reader_connection(conn, reader_host) 
 self._switch_current_connection_to(conn, reader_host)
-        logger.debug("ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url)
+        logger.debug(
+            "ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url
+        )
 
     def _close_connection_if_idle(self, internal_conn: Optional[Connection]):
+        if internal_conn is None:
+            return
+
         current_conn = self._plugin_service.current_connection
         driver_dialect = self._plugin_service.driver_dialect
+
         try:
-            if (internal_conn is not None and internal_conn != current_conn and
-                    self._is_connection_usable(internal_conn, driver_dialect)):
+            if internal_conn != current_conn and self._is_connection_usable(
+                internal_conn, driver_dialect
+            ):
                 internal_conn.close()
                 if internal_conn == self._writer_connection:
                     self._writer_connection = None
+                    self._writer_host_info = None
                 if internal_conn == self._reader_connection:
                     self._reader_connection = None
                     self._reader_host_info = None
         except Exception:
-            pass  # Swallow exception
+            # Ignore exceptions during cleanup - connection might already be dead
+            pass
 
     def _close_idle_connections(self):
         logger.debug("ReadWriteSplittingPlugin.ClosingInternalConnections")
         self._close_connection_if_idle(self._reader_connection)
         self._close_connection_if_idle(self._writer_connection)
 
+        # Always clear cached references even if connections couldn't be closed
+        self._reader_connection = None
+        self._reader_host_info = None
+        self._writer_connection = None
+        self._writer_host_info = None
+
     @staticmethod
-    def _log_and_raise_exception(log_msg: str):
+    def log_and_raise_exception(log_msg: str):
         logger.error(log_msg)
         raise ReadWriteSplittingError(Messages.get(log_msg))
 
     @staticmethod
-    def _is_connection_usable(conn: Optional[Connection], driver_dialect: Optional[DriverDialect]):
-        return conn is not None and driver_dialect is not None and not driver_dialect.is_closed(conn)
+    def _is_connection_usable(
+        conn: Optional[Connection], driver_dialect: Optional[DriverDialect]
+    ):
+        if conn is None or driver_dialect is None:
+            return False
+        try:
+            return not driver_dialect.is_closed(conn)
+        except Exception:
+            # If we cannot determine connection state, assume unavailable.
+            return False
 
     @staticmethod
-    def _get_writer(hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
-        for host in hosts:
+    def close_connection(connection: Optional[Connection]):
+        if connection is not None:
+            try:
+                connection.close()
+            except Exception:
+                # Swallow exception
+                return
+
+
+class ConnectionHandler(Protocol):
+    """Protocol for handling writer/reader connection logic."""
+
+    @property
+    def host_list_provider_service(self) -> Optional[HostListProviderService]:
+        """Getter for the 'host_list_provider_service' attribute."""
+        ...
+
+    @host_list_provider_service.setter
+    def host_list_provider_service(self, new_value: HostListProviderService) -> None:
+        """Setter for the 'host_list_provider_service' attribute."""
+        ...
+
+    def open_new_writer_connection(
+        self,
+    ) -> tuple[Optional[Connection], Optional[HostInfo]]:
+        """Open a writer connection."""
+        ...
+
+    def open_new_reader_connection(
+        self,
+    ) -> tuple[Optional[Connection], Optional[HostInfo]]:
+        """Open a reader connection."""
+        ...
+
+    def get_verified_initial_connection(
+        self,
+        host_info: HostInfo,
+        props: Properties,
+        is_initial_connection: bool,
+        connect_func: Callable,
+    ) -> Connection:
+        """Verify initial connection or return normal workflow."""
+        ...
+
+    def should_update_writer_with_current_conn(
+        self, current_conn: Connection, current_host: HostInfo, writer_conn: Connection
+    ) -> bool:
+        """Return true if the current connection fits the criteria of a writer connection."""
+        ...
+
+    def should_update_reader_with_current_conn(
+        self, current_conn: Connection, current_host: HostInfo, reader_conn: Connection
+    ) -> bool:
+        """Return true if the current connection fits the criteria of a reader connection."""
+        ...
+
+    def is_writer_host(self, current_host: HostInfo) -> bool:
+        """Return true if the current host fits the criteria of a writer host."""
+        ...
+
+    def is_reader_host(self, current_host: HostInfo) -> bool:
+        """Return true if the current host fits the criteria of a reader host."""
+        ...
+
+    def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
+        """Return true if the previously cached reader host can still be switched to."""
+        ...
+
+    def need_connect_to_writer(self) -> bool:
+        """Return true if switching to a reader should instead connect to the writer."""
+        ...
+
+    def refresh_and_store_host_list(
+        self, current_conn: Optional[Connection], driver_dialect: DriverDialect
+    ):
+        """Refresh the host list and then store it."""
+        ...
+
+
+class TopologyBasedConnectionHandler(ConnectionHandler):
+    """Topology based implementation of connection handling logic."""
+
+    def __init__(self, plugin_service: PluginService, props: Properties):
+        self._plugin_service: PluginService = plugin_service
+        self._properties: Properties = props
+        self._host_list_provider_service: Optional[HostListProviderService] = None
+        strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(self._properties)
+        if strategy is not None:
+            self._reader_selector_strategy = strategy
+        else:
+            default_strategy = (
+                WrapperProperties.READER_HOST_SELECTOR_STRATEGY.default_value
+            )
+            if default_strategy is not None:
+                self._reader_selector_strategy = default_strategy
+        self._hosts: Tuple[HostInfo, ...]
= () + + @property + def host_list_provider_service(self) -> Optional[HostListProviderService]: + return self._host_list_provider_service + + @host_list_provider_service.setter + def host_list_provider_service(self, new_value: HostListProviderService) -> None: + self._host_list_provider_service = new_value + + def open_new_writer_connection( + self, + ) -> tuple[Optional[Connection], Optional[HostInfo]]: + writer_host = self._get_writer() + if writer_host is None: + return None, None + + conn = self._plugin_service.connect(writer_host, self._properties, None) + + return conn, writer_host + + def open_new_reader_connection( + self, + ) -> tuple[Optional[Connection], Optional[HostInfo]]: + conn: Optional[Connection] = None + reader_host: Optional[HostInfo] = None + + conn_attempts = len(self._plugin_service.hosts) * 2 + for _ in range(conn_attempts): + host = self._plugin_service.get_host_info_by_strategy( + HostRole.READER, self._reader_selector_strategy + ) + if host is not None: + try: + conn = self._plugin_service.connect(host, self._properties, None) + reader_host = host + break + except Exception: + logger.warning( + "ReadWriteSplittingPlugin.FailedToConnectToReader", host.url + ) + + return conn, reader_host + + def get_verified_initial_connection( + self, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable, + ) -> Connection: + if not self._plugin_service.accepts_strategy( + host_info.role, self._reader_selector_strategy + ): + raise AwsWrapperError( + Messages.get_formatted( + "ReadWriteSplittingPlugin.UnsupportedHostInfoSelectorStrategy", + self._reader_selector_strategy, + ) + ) + + current_conn = connect_func() + + if not is_initial_connection or ( + self._host_list_provider_service is not None + and self._host_list_provider_service.is_static_host_list_provider() + ): + return current_conn + + current_role = self._plugin_service.get_host_role(current_conn) + if current_role is None or current_role == HostRole.UNKNOWN: + ReadWriteSplittingConnectionManager.log_and_raise_exception( + "ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole" + ) + + current_host = self._plugin_service.initial_connection_host_info + if current_host is not None: + if current_role == current_host.role: + return current_conn + + updated_host = deepcopy(current_host) + updated_host.role = current_role + if self._host_list_provider_service is not None: + self._host_list_provider_service.initial_connection_host_info = ( + updated_host + ) + + return current_conn + + def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool: + hostnames = [host_info.host for host_info in self._hosts] + return reader_host_info is not None and reader_host_info.host in hostnames + + def need_connect_to_writer(self) -> bool: + if len(self._hosts) == 1: + return self._get_writer() is not None + return False + + def refresh_and_store_host_list( + self, current_conn: Optional[Connection], driver_dialect: DriverDialect + ): + if current_conn is not None and driver_dialect.can_execute_query(current_conn): + try: + self._plugin_service.refresh_host_list() + except Exception: + pass # Swallow exception + + hosts = self._plugin_service.hosts + if hosts is None or len(hosts) == 0: + ReadWriteSplittingConnectionManager.log_and_raise_exception( + "ReadWriteSplittingPlugin.EmptyHostList" + ) + + self._hosts = hosts + + def should_update_writer_with_current_conn( + self, current_conn, current_host: HostInfo, writer_conn: Connection + ) -> bool: + return 
self.is_writer_host(current_host) + + def should_update_reader_with_current_conn( + self, current_conn, current_host, reader_conn: Connection + ) -> bool: + return True + + def is_writer_host(self, current_host: HostInfo) -> bool: + return current_host.role == HostRole.WRITER + + def is_reader_host(self, current_host) -> bool: + return current_host.role == HostRole.READER + + def _get_writer(self) -> Optional[HostInfo]: + for host in self._hosts: if host.role == HostRole.WRITER: return host - ReadWriteSplittingPlugin._log_and_raise_exception("ReadWriteSplittingPlugin.NoWriterFound") - + ReadWriteSplittingConnectionManager.log_and_raise_exception( + "ReadWriteSplittingPlugin.NoWriterFound" + ) return None +class ReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager): + def __init__(self, plugin_service, props: Properties): + # The read/write splitting plugin handles connections based on topology. + connection_handler = TopologyBasedConnectionHandler( + plugin_service, + props, + ) + + super().__init__(plugin_service, props, connection_handler) + + class ReadWriteSplittingPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: return ReadWriteSplittingPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 65f16806..9a58abd8 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -286,7 +286,7 @@ OpenTelemetryFactory.WrongParameterType="[OpenTelemetryFactory] Wrong parameter Plugin.UnsupportedMethod=[Plugin] '{}' is not supported by this plugin. -PluginManager.ConfigurationProfileNotFound=PluginManager] Configuration profile '{}' not found. +PluginManager.ConfigurationProfileNotFound=[PluginManager] Configuration profile '{}' not found. PluginManager.InvalidPlugin=[PluginManager] Invalid plugin requested: '{}'. PluginManager.MethodInvokedAgainstOldConnection = [PluginManager] The internal connection has changed since '{}' was created. This is likely due to failover or read-write splitting functionality. To ensure you are using the updated connection, please re-create Cursor objects after failover and/or setting readonly. PluginManager.PipelineNone=[PluginManager] A pipeline was requested but the created pipeline evaluated to None. 
@@ -357,8 +357,9 @@ ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole=[ReadWriteSplittingPl ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand=[ReadWriteSplittingPlugin] Detected an exception while executing a command: '{}' ReadWriteSplittingPlugin.ExecutingAgainstOldConnection=[ReadWriteSplittingPlugin] Executing method against old connection: '{}' ReadWriteSplittingPlugin.FailedToConnectToReader=[ReadWriteSplittingPlugin] Failed to connect to reader host: '{}' +ReadWriteSplittingPlugin.FailedToConnectToWriter=[ReadWriteSplittingPlugin] Failed to connect to writer host: '{}' ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand=[ReadWriteSplittingPlugin] Detected a failover exception while executing a command: '{}' -ReadWriteSplittingPlugin.FallbackToWriter=[ReadWriteSplittingPlugin] Failed to switch to a reader; the current writer will be used as a fallback: '{}' +ReadWriteSplittingPlugin.FallbackToCurrentConnection=[ReadWriteSplittingPlugin] Failed to switch to a reader; the current connection will be used as a fallback: '{}' ReadWriteSplittingPlugin.NoReadersAvailable=[ReadWriteSplittingPlugin] The plugin was unable to establish a reader connection to any reader instance. ReadWriteSplittingPlugin.NoReadersFound=[ReadWriteSplittingPlugin] A reader instance was requested via set_read_only, but there are no readers in the host list. The current writer will be used as a fallback: '{}' ReadWriteSplittingPlugin.NoWriterFound=[ReadWriteSplittingPlugin] No writer was found in the current host list. This may occur if the writer is not in the list of allowed hosts. @@ -382,6 +383,9 @@ RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs= [RoundRobinHostSelector WeightedRandomHostSelector.WeightedRandomInvalidHostWeightPairs= [WeightedRandomHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. +SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter=[SimpleReadWriteSplittingPlugin] Configuration parameter {} is required. +SimpleReadWriteSplittingPlugin.IncorrectConfiguration=[SimpleReadWriteSplittingPlugin] Unable to verify connections with this current configuration. Ensure a correct value is provided to the configuration parameter {}. + SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None. SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties. diff --git a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py new file mode 100644 index 00000000..bbf503c1 --- /dev/null +++ b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py @@ -0,0 +1,343 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from time import perf_counter_ns, sleep +from typing import TYPE_CHECKING, Callable, Optional + +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.read_write_splitting_plugin import ( + ConnectionHandler, ReadWriteSplittingConnectionManager) +from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.host_list_provider import HostListProviderService + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.properties import Properties + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.plugin import PluginFactory +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import WrapperProperties + + +class EndpointBasedConnectionHandler(ConnectionHandler): + """Endpoint based implementation of connection handling logic.""" + + def __init__(self, plugin_service: PluginService, props: Properties): + srw_read_endpoint = WrapperProperties.SRW_READ_ENDPOINT.get(props) + if srw_read_endpoint is None: + raise AwsWrapperError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter", + WrapperProperties.SRW_READ_ENDPOINT.name, + ) + ) + self._read_endpoint: str = srw_read_endpoint + + srw_write_endpoint = WrapperProperties.SRW_WRITE_ENDPOINT.get(props) + if srw_write_endpoint is None: + raise AwsWrapperError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter", + WrapperProperties.SRW_WRITE_ENDPOINT.name, + ) + ) + self._write_endpoint: str = srw_write_endpoint + + self._verify_new_connections: bool = ( + WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS.get_bool(props) + ) + if self._verify_new_connections is True: + srw_connect_retry_timeout_ms: int = ( + WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.get_int(props) + ) + if srw_connect_retry_timeout_ms <= 0: + raise ValueError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.IncorrectConfiguration", + WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name, + ) + ) + self._connect_retry_timeout_ms: int = srw_connect_retry_timeout_ms + + srw_connect_retry_interval_ms: int = ( + WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.get_int(props) + ) + if srw_connect_retry_interval_ms <= 0: + raise ValueError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.IncorrectConfiguration", + WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name, + ) + ) + self._connect_retry_interval_ms: int = srw_connect_retry_interval_ms + + 
self._verify_opened_connection_type: Optional[HostRole] = ( + EndpointBasedConnectionHandler._parse_connection_type( + WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.get(props) + ) + ) + + self._plugin_service: PluginService = plugin_service + self._properties: Properties = props + self._rds_utils: RdsUtils = RdsUtils() + self._host_list_provider_service: Optional[HostListProviderService] = None + self._write_endpoint_host_info: HostInfo = self._create_host_info( + self._write_endpoint, HostRole.WRITER + ) + self._read_endpoint_host_info: HostInfo = self._create_host_info( + self._read_endpoint, HostRole.READER + ) + + @property + def host_list_provider_service(self) -> Optional[HostListProviderService]: + return self._host_list_provider_service + + @host_list_provider_service.setter + def host_list_provider_service(self, new_value: HostListProviderService) -> None: + self._host_list_provider_service = new_value + + def open_new_writer_connection( + self, + ) -> tuple[Optional[Connection], Optional[HostInfo]]: + conn: Optional[Connection] = None + if self._verify_new_connections: + conn = self._get_verified_connection( + self._properties, self._write_endpoint_host_info, HostRole.WRITER + ) + else: + conn = self._plugin_service.connect( + self._write_endpoint_host_info, self._properties, None + ) + + return conn, self._write_endpoint_host_info + + def open_new_reader_connection( + self, + ) -> tuple[Optional[Connection], Optional[HostInfo]]: + conn: Optional[Connection] = None + if self._verify_new_connections: + conn = self._get_verified_connection( + self._properties, self._read_endpoint_host_info, HostRole.READER + ) + else: + conn = self._plugin_service.connect( + self._read_endpoint_host_info, self._properties, None + ) + + return conn, self._read_endpoint_host_info + + def get_verified_initial_connection( + self, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable, + ) -> Connection: + if not is_initial_connection or not self._verify_new_connections: + return connect_func() + + url_type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host) + + conn: Optional[Connection] = None + + if ( + url_type == RdsUrlType.RDS_WRITER_CLUSTER + or self._verify_opened_connection_type == HostRole.WRITER + ): + conn = self._get_verified_connection( + props, host_info, HostRole.WRITER, connect_func + ) + elif ( + url_type == RdsUrlType.RDS_READER_CLUSTER + or self._verify_opened_connection_type == HostRole.READER + ): + conn = self._get_verified_connection( + props, host_info, HostRole.READER, connect_func + ) + + if conn is None: + conn = connect_func() + + self._set_initial_connection_host_info(conn, host_info) + return conn + + def _set_initial_connection_host_info(self, conn: Connection, host_info: HostInfo): + if self._host_list_provider_service is None: + return + + self._host_list_provider_service.initial_connection_host_info = host_info + + def _get_verified_connection( + self, + props: Properties, + host_info: HostInfo, + role: HostRole, + connect_func: Optional[Callable] = None, + ) -> Optional[Connection]: + end_time_nano = perf_counter_ns() + (self._connect_retry_timeout_ms * 1000000) + + candidate_conn: Optional[Connection] + + while perf_counter_ns() < end_time_nano: + candidate_conn = None + + try: + if connect_func is not None: + candidate_conn = connect_func() + elif host_info is not None: + candidate_conn = self._plugin_service.connect( + host_info, props, None + ) + else: + return None + + if candidate_conn is None: + 
self._delay() + continue + + actual_role = self._plugin_service.get_host_role(candidate_conn) + + if actual_role != role: + ReadWriteSplittingConnectionManager.close_connection(candidate_conn) + self._delay() + continue + + return candidate_conn + + except Exception: + ReadWriteSplittingConnectionManager.close_connection(candidate_conn) + self._delay() + + return None + + def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool: + # Assume that the old reader can always be used, no topology-based information to check. + return True + + def need_connect_to_writer(self) -> bool: + # SetReadOnly(true) will always connect to the read_endpoint, and not the writer. + return False + + def refresh_and_store_host_list( + self, current_conn: Optional[Connection], driver_dialect: DriverDialect + ): + # Endpoint based connections do not require a host list. + return + + def should_update_writer_with_current_conn( + self, current_conn: Connection, current_host: HostInfo, writer_conn: Connection + ) -> bool: + return ( + self.is_writer_host(current_host) + and current_conn != writer_conn + and ( + not self._verify_new_connections + or self._plugin_service.get_host_role(current_conn) == HostRole.WRITER + ) + ) + + def should_update_reader_with_current_conn( + self, current_conn: Connection, current_host: HostInfo, reader_conn: Connection + ) -> bool: + return ( + self.is_reader_host(current_host) + and current_conn != reader_conn + and ( + not self._verify_new_connections + or self._plugin_service.get_host_role(current_conn) == HostRole.READER + ) + ) + + def is_writer_host(self, current_host: HostInfo) -> bool: + return ( + current_host.host.casefold() == self._write_endpoint.casefold() + or current_host.url.casefold() == self._write_endpoint.casefold() + ) + + def is_reader_host(self, current_host: HostInfo) -> bool: + return ( + current_host.host.casefold() == self._read_endpoint.casefold() + or current_host.url.casefold() == self._read_endpoint.casefold() + ) + + def _create_host_info(self, endpoint, role: HostRole) -> HostInfo: + endpoint = endpoint.strip() + host = endpoint + port = self._plugin_service.database_dialect.default_port + colon_index = endpoint.rfind(":") + + if colon_index != -1: + port_str = endpoint[colon_index + 1:] + if port_str.isdigit(): + host = endpoint[:colon_index] + port = int(port_str) + else: + if ( + self._host_list_provider_service is not None + and self._host_list_provider_service.initial_connection_host_info + is not None + and self._host_list_provider_service.initial_connection_host_info.port + != HostInfo.NO_PORT + ): + port = ( + self._host_list_provider_service.initial_connection_host_info.port + ) + + return HostInfo( + host=host, port=port, role=role, availability=HostAvailability.AVAILABLE + ) + + def _delay(self): + sleep(self._connect_retry_interval_ms / 1000) + + @staticmethod + def _parse_connection_type(phase_str: Optional[str]) -> HostRole: + if not phase_str: + return HostRole.UNKNOWN + + phase_upper = phase_str.lower() + if phase_upper == "reader": + return HostRole.READER + elif phase_upper == "writer": + return HostRole.WRITER + else: + raise ValueError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.IncorrectConfiguration", + WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.name, + ) + ) + + +class SimpleReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager): + def __init__(self, plugin_service, props: Properties): + # The simple read/write splitting plugin handles connections based on configuration parameter 
endpoints. + connection_handler = EndpointBasedConnectionHandler( + plugin_service, + props, + ) + + super().__init__(plugin_service, props, connection_handler) + + +class SimpleReadWriteSplittingPluginFactory(PluginFactory): + def get_instance(self, plugin_service, props: Properties): + return SimpleReadWriteSplittingPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 1d813404..4bbde03e 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -26,7 +26,9 @@ def put_if_absent(self, key: str, value: Any): class WrapperProperty: - def __init__(self, name: str, description: str, default_value: Optional[Any] = None): + def __init__( + self, name: str, description: str, default_value: Optional[Any] = None + ): self.name = name self.default_value = default_value self.description = description @@ -76,390 +78,500 @@ class WrapperProperties: DEFAULT_PLUGINS = "aurora_connection_tracker,failover,host_monitoring_v2" _DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - PROFILE_NAME = WrapperProperty("profile_name", "Driver configuration profile name", None) + PROFILE_NAME = WrapperProperty( + "profile_name", "Driver configuration profile name", None + ) PLUGINS = WrapperProperty( - "plugins", - "Comma separated list of connection plugin codes", - DEFAULT_PLUGINS) + "plugins", "Comma separated list of connection plugin codes", DEFAULT_PLUGINS + ) USER = WrapperProperty("user", "Driver user name") PASSWORD = WrapperProperty("password", "Driver password") DATABASE = WrapperProperty("database", "Driver database name") CONNECT_TIMEOUT_SEC = WrapperProperty( "connect_timeout", - "Max number of seconds to wait for a connection to be established before timing out.") + "Max number of seconds to wait for a connection to be established before timing out.", + ) SOCKET_TIMEOUT_SEC = WrapperProperty( "socket_timeout", - "Max number of seconds to wait for a SQL query to complete before timing out.") + "Max number of seconds to wait for a SQL query to complete before timing out.", + ) TCP_KEEPALIVE = WrapperProperty( - "tcp_keepalive", "Enable TCP keepalive functionality.") + "tcp_keepalive", "Enable TCP keepalive functionality." 
+ ) TCP_KEEPALIVE_TIME_SEC = WrapperProperty( - "tcp_keepalive_time", "Number of seconds to wait before sending an initial keepalive probe.") + "tcp_keepalive_time", + "Number of seconds to wait before sending an initial keepalive probe.", + ) TCP_KEEPALIVE_INTERVAL_SEC = WrapperProperty( "tcp_keepalive_interval", - "Number of seconds to wait before sending additional keepalive probes after the initial probe has been sent.") + "Number of seconds to wait before sending additional keepalive probes after the initial probe has been sent.", + ) TCP_KEEPALIVE_PROBES = WrapperProperty( - "tcp_keepalive_probes", "Number of keepalive probes to send before concluding that the connection is invalid.") + "tcp_keepalive_probes", + "Number of keepalive probes to send before concluding that the connection is invalid.", + ) TRANSFER_SESSION_STATE_ON_SWITCH = WrapperProperty( - "transfer_session_state_on_switch", "Enables session state transfer to a new connection", True) + "transfer_session_state_on_switch", + "Enables session state transfer to a new connection", + True, + ) RESET_SESSION_STATE_ON_CLOSE = WrapperProperty( "reset_session_state_on_close", "Enables to reset connection session state before closing it.", - True) + True, + ) ROLLBACK_ON_SWITCH = WrapperProperty( "rollback_on_switch", "Enables to rollback a current transaction being in progress when switching to a new connection.", - True) + True, + ) # RdsHostListProvider TOPOLOGY_REFRESH_MS = WrapperProperty( "topology_refresh_ms", """Cluster topology refresh rate in milliseconds. The cached topology for the cluster will be invalidated after the specified time, after which it will be updated during the next interaction with the connection.""", - 30_000) + 30_000, + ) CLUSTER_ID = WrapperProperty( "cluster_id", """A unique identifier for the cluster. Connections with the same cluster id share a cluster topology cache. If - unspecified, a cluster id is automatically created for AWS RDS clusters.""") + unspecified, a cluster id is automatically created for AWS RDS clusters.""", + ) CLUSTER_INSTANCE_HOST_PATTERN = WrapperProperty( "cluster_instance_host_pattern", """The cluster instance DNS pattern that will be used to build a complete instance endpoint. A "?" character in this pattern should be used as a placeholder for cluster instance names. This pattern is required to be specified for IP address or custom domain connections to AWS RDS clusters. Otherwise, if unspecified, the - pattern will be automatically created for AWS RDS clusters.""") + pattern will be automatically created for AWS RDS clusters.""", + ) - IAM_HOST = WrapperProperty("iam_host", "Overrides the host that is used to generate the IAM token.") + IAM_HOST = WrapperProperty( + "iam_host", "Overrides the host that is used to generate the IAM token." + ) IAM_DEFAULT_PORT = WrapperProperty( "iam_default_port", - "Overrides default port that is used to generate the IAM token.") - IAM_REGION = WrapperProperty("iam_region", "Overrides AWS region that is used to generate the IAM token.") + "Overrides default port that is used to generate the IAM token.", + ) + IAM_REGION = WrapperProperty( + "iam_region", "Overrides AWS region that is used to generate the IAM token." 
+ ) IAM_EXPIRATION = WrapperProperty( "iam_expiration", "IAM token cache expiration in seconds", - _DEFAULT_TOKEN_EXPIRATION_SEC) + _DEFAULT_TOKEN_EXPIRATION_SEC, + ) SECRETS_MANAGER_SECRET_ID = WrapperProperty( - "secrets_manager_secret_id", - "The name or the ARN of the secret to retrieve.") + "secrets_manager_secret_id", "The name or the ARN of the secret to retrieve." + ) SECRETS_MANAGER_SECRET_USERNAME_KEY = WrapperProperty( "secrets_manager_secret_username_key", "The key of the secret to retrieve, which contains the username.", - "username") + "username", + ) SECRETS_MANAGER_SECRET_PASSWORD_KEY = WrapperProperty( "secrets_manager_secret_password_key", "The key of the secret to retrieve, which contains the password.", - "password" + "password", ) SECRETS_MANAGER_REGION = WrapperProperty( - "secrets_manager_region", - "The region of the secret to retrieve.", - "us-east-1") + "secrets_manager_region", "The region of the secret to retrieve.", "us-east-1" + ) SECRETS_MANAGER_ENDPOINT = WrapperProperty( - "secrets_manager_endpoint", - "The endpoint of the secret to retrieve.") + "secrets_manager_endpoint", "The endpoint of the secret to retrieve." + ) SECRETS_MANAGER_EXPIRATION = WrapperProperty( "secrets_manager_expiration", "Secret cache expiration in seconds", - 60 * 60 * 24 * 365) + 60 * 60 * 24 * 365, + ) - DIALECT = WrapperProperty("wrapper_dialect", "A unique identifier for the supported database dialect.") + DIALECT = WrapperProperty( + "wrapper_dialect", "A unique identifier for the supported database dialect." + ) AUXILIARY_QUERY_TIMEOUT_SEC = WrapperProperty( "auxiliary_query_timeout_sec", """Network timeout, in seconds, used for auxiliary queries to the database. This timeout applies to queries executed by the wrapper driver to gain info about the connected database. 
It does not apply to queries requested by the driver client.""", - 5) + 5, + ) # HostMonitoringPlugin FAILURE_DETECTION_ENABLED = WrapperProperty( "failure_detection_enabled", "Enable failure detection logic in the HostMonitoringPlugin.", - True) + True, + ) FAILURE_DETECTION_TIME_MS = WrapperProperty( "failure_detection_time_ms", "Interval in milliseconds between sending SQL to the server and the first connection check.", - 30_000) + 30_000, + ) FAILURE_DETECTION_INTERVAL_MS = WrapperProperty( "failure_detection_interval_ms", "Interval in milliseconds between consecutive connection checks.", - 5_000) + 5_000, + ) FAILURE_DETECTION_COUNT = WrapperProperty( "failure_detection_count", "Number of failed connection checks before considering the database host unavailable.", - 3) + 3, + ) MONITOR_DISPOSAL_TIME_MS = WrapperProperty( "monitor_disposal_time_ms", "Interval in milliseconds after which a monitor should be considered inactive and marked for disposal.", - 600_000) # 10 minutes + 600_000, + ) # 10 minutes # Failover ENABLE_FAILOVER = WrapperProperty( - "enable_failover", - "Enable/disable cluster-aware failover logic", - True) + "enable_failover", "Enable/disable cluster-aware failover logic", True + ) FAILOVER_MODE = WrapperProperty( "failover_mode", - "Decide which host role (writer, reader, or either) to connect to during failover.") + "Decide which host role (writer, reader, or either) to connect to during failover.", + ) FAILOVER_TIMEOUT_SEC = WrapperProperty( "failover_timeout_sec", "Maximum allowed time in seconds for the failover process.", - 300) # 5 minutes + 300, + ) # 5 minutes FAILOVER_CLUSTER_TOPOLOGY_REFRESH_RATE_SEC = WrapperProperty( "failover_cluster_topology_refresh_rate_sec", """Cluster topology refresh rate in seconds during a writer failover process. During the writer failover process, cluster topology may be refreshed at a faster pace than normal to speed up discovery of the newly promoted writer.""", - 2) + 2, + ) FAILOVER_WRITER_RECONNECT_INTERVAL_SEC = WrapperProperty( "failover_writer_reconnect_interval_sec", "Interval of time in seconds to wait between attempts to reconnect to a failed writer during a writer failover process.", - 2) + 2, + ) FAILOVER_READER_CONNECT_TIMEOUT_SEC = WrapperProperty( "failover_reader_connect_timeout_sec", "Reader connection attempt timeout in seconds during a reader failover process.", - 30) + 30, + ) # CustomEndpointPlugin CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS = WrapperProperty( "custom_endpoint_info_refresh_rate_ms", "Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds.", - 30_000) + 30_000, + ) CUSTOM_ENDPOINT_IDLE_MONITOR_EXPIRATION_MS = WrapperProperty( "custom_endpoint_idle_monitor_expiration_ms", "Controls how long a monitor should run without use before expiring and being removed, in milliseconds.", - 900_000) # 15 minutes + 900_000, + ) # 15 minutes WAIT_FOR_CUSTOM_ENDPOINT_INFO = WrapperProperty( "wait_for_custom_endpoint_info", """Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. 
Note that disabling this may result in occasional connections to instances outside of the custom endpoint.""", - True) + True, + ) WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = WrapperProperty( "wait_for_custom_endpoint_info_timeout_ms", """Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds.""", - 5_000) + 5_000, + ) # Host Availability Strategy DEFAULT_HOST_AVAILABILITY_STRATEGY = WrapperProperty( "default_host_availability_strategy", "An override for specifying the default host availability change strategy.", - "" + "", ) HOST_AVAILABILITY_STRATEGY_MAX_RETRIES = WrapperProperty( "host_availability_strategy_max_retries", "Max number of retries for checking a host's availability.", - "5" + "5", ) HOST_AVAILABILITY_STRATEGY_INITIAL_BACKOFF_TIME = WrapperProperty( "host_availability_strategy_initial_backoff_time", "The initial backoff time in seconds.", - "30" + "30", ) # Driver Dialect DRIVER_DIALECT = WrapperProperty( - "wrapper_driver_dialect", - "A unique identifier for the target driver dialect.") + "wrapper_driver_dialect", "A unique identifier for the target driver dialect." + ) # Read/Write Splitting READER_HOST_SELECTOR_STRATEGY = WrapperProperty( "reader_host_selector_strategy", "The strategy that should be used to select a new reader host.", - "random") + "random", + ) # Plugin Sorting AUTO_SORT_PLUGIN_ORDER = WrapperProperty( "auto_sort_wrapper_plugin_order", "This flag is enabled by default, meaning that the plugins order will be automatically adjusted. " "Disable it at your own risk or if you really need plugins to be executed in a particular order.", - True) + True, + ) # Host Selector - ROUND_ROBIN_DEFAULT_WEIGHT = WrapperProperty("round_robin_default_weight", - "The default weight for any hosts that have not been " + - "configured with the `round_robin_host_weight_pairs` parameter.", - 1) + ROUND_ROBIN_DEFAULT_WEIGHT = WrapperProperty( + "round_robin_default_weight", + "The default weight for any hosts that have not been " + + "configured with the `round_robin_host_weight_pairs` parameter.", + 1, + ) - ROUND_ROBIN_HOST_WEIGHT_PAIRS = WrapperProperty("round_robin_host_weight_pairs", - "Comma separated list of database host-weight pairs in the format of `:`.", - "") + ROUND_ROBIN_HOST_WEIGHT_PAIRS = WrapperProperty( + "round_robin_host_weight_pairs", + "Comma separated list of database host-weight pairs in the format of `:`.", + "", + ) - WEIGHTED_RANDOM_DEFAULT_WEIGHT = WrapperProperty("weighted_random_default_weight", "The default weight for any hosts that have not been " + - "configured with the `weighted_random_host_weight_pairs` parameter.", - 1) + WEIGHTED_RANDOM_DEFAULT_WEIGHT = WrapperProperty( + "weighted_random_default_weight", + "The default weight for any hosts that have not been " + + "configured with the `weighted_random_host_weight_pairs` parameter.", + 1, + ) - WEIGHTED_RANDOM_HOST_WEIGHT_PAIRS = WrapperProperty("weighted_random_host_weight_pairs", - "Comma separated list of database host-weight pairs in the format of `:`.", - "") + WEIGHTED_RANDOM_HOST_WEIGHT_PAIRS = WrapperProperty( + "weighted_random_host_weight_pairs", + "Comma separated list of database host-weight pairs in the format of `:`.", + "", + ) # Federated Auth Plugin - IDP_ENDPOINT = WrapperProperty("idp_endpoint", - "The hosting URL of the Identity Provider", - None) + IDP_ENDPOINT = WrapperProperty( + "idp_endpoint", "The hosting URL of the Identity Provider", None + ) - IDP_PORT = 
WrapperProperty("idp_port", - "The hosting port of the Identity Provider", - 443) + IDP_PORT = WrapperProperty( + "idp_port", "The hosting port of the Identity Provider", 443 + ) - RELAYING_PARTY_ID = WrapperProperty("rp_identifier", - "The relaying party identifier", - "urn:amazon:webservices") + RELAYING_PARTY_ID = WrapperProperty( + "rp_identifier", "The relaying party identifier", "urn:amazon:webservices" + ) - IAM_ROLE_ARN = WrapperProperty("iam_role_arn", - "The ARN of the IAM Role that is to be assumed.", - None) + IAM_ROLE_ARN = WrapperProperty( + "iam_role_arn", "The ARN of the IAM Role that is to be assumed.", None + ) - IAM_IDP_ARN = WrapperProperty("iam_idp_arn", - "The ARN of the Identity Provider", - None) + IAM_IDP_ARN = WrapperProperty( + "iam_idp_arn", "The ARN of the Identity Provider", None + ) - IAM_TOKEN_EXPIRATION = WrapperProperty("iam_token_expiration", - "IAM token cache expiration in seconds", - 15 * 60 - 30) + IAM_TOKEN_EXPIRATION = WrapperProperty( + "iam_token_expiration", "IAM token cache expiration in seconds", 15 * 60 - 30 + ) - IDP_USERNAME = WrapperProperty("idp_username", - "The federated user name", - None) + IDP_USERNAME = WrapperProperty("idp_username", "The federated user name", None) - IDP_PASSWORD = WrapperProperty("idp_password", - "The federated user password", - None) + IDP_PASSWORD = WrapperProperty("idp_password", "The federated user password", None) - HTTP_REQUEST_TIMEOUT = WrapperProperty("http_request_connect_timeout", - "The timeout value in seconds to send the HTTP request data used by the FederatedAuthPlugin", - 60) + HTTP_REQUEST_TIMEOUT = WrapperProperty( + "http_request_connect_timeout", + "The timeout value in seconds to send the HTTP request data used by the FederatedAuthPlugin", + 60, + ) - SSL_SECURE = WrapperProperty("ssl_secure", - "Whether the SSL session is to be secure and the server's certificates will be verified." - " We do not recommend disabling this for production use.", - True) + SSL_SECURE = WrapperProperty( + "ssl_secure", + "Whether the SSL session is to be secure and the server's certificates will be verified." 
+ " We do not recommend disabling this for production use.", + True, + ) - IDP_NAME = WrapperProperty("idp_name", - "The name of the Identity Provider implementation used", - "adfs") + IDP_NAME = WrapperProperty( + "idp_name", "The name of the Identity Provider implementation used", "adfs" + ) - DB_USER = WrapperProperty("db_user", - "The database user used to access the database", - None) + DB_USER = WrapperProperty( + "db_user", "The database user used to access the database", None + ) # Okta - APP_ID = WrapperProperty("app_id", "The ID of the AWS application configured on Okta", None) + APP_ID = WrapperProperty( + "app_id", "The ID of the AWS application configured on Okta", None + ) # Fastest Response Strategy - RESPONSE_MEASUREMENT_INTERVAL_MS = WrapperProperty("response_measurement_interval_ms", - "Interval in milliseconds between measuring response time to a database host", - 30_000) + RESPONSE_MEASUREMENT_INTERVAL_MS = WrapperProperty( + "response_measurement_interval_ms", + "Interval in milliseconds between measuring response time to a database host", + 30_000, + ) # Limitless - LIMITLESS_MONITOR_DISPOSAL_TIME_MS = WrapperProperty("limitless_transaction_router_monitor_disposal_time_ms", - "Interval in milliseconds for an Limitless router monitor to be " - "considered inactive and to be disposed.", - 600_000) - - LIMITLESS_INTERVAL_MILLIS = WrapperProperty("limitless_transaction_router_monitor_interval_ms", - "Interval in millis between polling for Limitless Transaction Routers to the database.", - 7_500) - - WAIT_FOR_ROUTER_INFO = WrapperProperty("limitless_wait_for_transaction_router_info", - "If the cache of transaction router info is empty " - "and a new connection is made, this property toggles whether " - "the plugin will wait and synchronously fetch transaction router info before selecting a transaction " - "router to connect to, or to fall back to using the provided DB Shard Group endpoint URL.", - True) - - GET_ROUTER_RETRY_INTERVAL_MS = WrapperProperty("limitless_get_transaction_router_retry_interval_ms", - "Interval in milliseconds between retries fetching Limitless Transaction Router information.", - 300) - - GET_ROUTER_MAX_RETRIES = WrapperProperty("limitless_get_transaction_router_max_retries", - "Max number of connection retries the Limitless Connection Plugin will attempt.", - 5) - - MAX_RETRIES_MS = WrapperProperty("limitless_max_retries_ms", - "Interval in milliseconds between polling for Limitless Transaction Routers to the database.", - 7_500) + LIMITLESS_MONITOR_DISPOSAL_TIME_MS = WrapperProperty( + "limitless_transaction_router_monitor_disposal_time_ms", + "Interval in milliseconds for an Limitless router monitor to be " + "considered inactive and to be disposed.", + 600_000, + ) + + LIMITLESS_INTERVAL_MILLIS = WrapperProperty( + "limitless_transaction_router_monitor_interval_ms", + "Interval in millis between polling for Limitless Transaction Routers to the database.", + 7_500, + ) + + WAIT_FOR_ROUTER_INFO = WrapperProperty( + "limitless_wait_for_transaction_router_info", + "If the cache of transaction router info is empty " + "and a new connection is made, this property toggles whether " + "the plugin will wait and synchronously fetch transaction router info before selecting a transaction " + "router to connect to, or to fall back to using the provided DB Shard Group endpoint URL.", + True, + ) + + GET_ROUTER_RETRY_INTERVAL_MS = WrapperProperty( + "limitless_get_transaction_router_retry_interval_ms", + "Interval in milliseconds between retries fetching 
Limitless Transaction Router information.", + 300, + ) + + GET_ROUTER_MAX_RETRIES = WrapperProperty( + "limitless_get_transaction_router_max_retries", + "Max number of connection retries the Limitless Connection Plugin will attempt.", + 5, + ) + + MAX_RETRIES_MS = WrapperProperty( + "limitless_max_retries_ms", + "Interval in milliseconds between polling for Limitless Transaction Routers to the database.", + 7_500, + ) # Blue/Green BG_CONNECT_TIMEOUT_MS = WrapperProperty( "bg_connect_timeout_ms", "Connect timeout (in msec) during Blue/Green Deployment switchover.", - 30_000) + 30_000, + ) BG_ID = WrapperProperty( "bg_id", "Blue/Green Deployment identifier that helps the driver to distinguish different deployments.", - "1") + "1", + ) BG_INTERVAL_BASELINE_MS = WrapperProperty( "bg_interval_baseline_ms", "Baseline Blue/Green Deployment status checking interval (in msec).", - 60_000) + 60_000, + ) BG_INTERVAL_INCREASED_MS = WrapperProperty( "bg_interval_increased_ms", "Increased Blue/Green Deployment status checking interval (in msec).", - 1_000) + 1_000, + ) BG_INTERVAL_HIGH_MS = WrapperProperty( "bg_interval_high_ms", "High Blue/Green Deployment status checking interval (in msec).", - 100) + 100, + ) BG_SWITCHOVER_TIMEOUT_MS = WrapperProperty( "bg_switchover_timeout_ms", "Blue/Green Deployment switchover timeout (in msec).", - 180_000) # 3 minutes + 180_000, + ) # 3 minutes BG_SUSPEND_NEW_BLUE_CONNECTIONS = WrapperProperty( "bg_suspend_new_blue_connections", "Enables Blue/Green Deployment switchover to suspend new blue connection requests while the " "switchover process is in progress.", - False) + False, + ) # Telemetry ENABLE_TELEMETRY = WrapperProperty( - "enable_telemetry", - "Enables telemetry and observability of the driver.", - False + "enable_telemetry", "Enables telemetry and observability of the driver.", False ) TELEMETRY_SUBMIT_TOPLEVEL = WrapperProperty( "telemetry_submit_toplevel", "Force submitting traces related to Python calls as top level traces.", - False + False, ) TELEMETRY_TRACES_BACKEND = WrapperProperty( "telemetry_traces_backend", "Method to export telemetry traces of the driver.", - None + None, ) TELEMETRY_METRICS_BACKEND = WrapperProperty( "telemetry_metrics_backend", "Method to export telemetry metrics of the driver.", - None + None, ) TELEMETRY_FAILOVER_ADDITIONAL_TOP_TRACE = WrapperProperty( "telemetry_failover_additional_top_trace", "Post an additional top-level trace for failover process.", - False + False, ) READER_INITIAL_HOST_SELECTOR_STRATEGY = WrapperProperty( "reader_initial_connection_host_selector_strategy", "The strategy that should be used to select a new reader host while opening a new connection.", - "random") + "random", + ) OPEN_CONNECTION_RETRY_TIMEOUT_MS = WrapperProperty( "open_connection_retry_timeout_ms", "Maximum allowed time for the retries opening a connection.", - 30000 + 30000, ) OPEN_CONNECTION_RETRY_INTERVAL_MS = WrapperProperty( "open_connection_retry_interval_ms", "Time between each retry of opening a connection.", - 1000 + 1000, + ) + + # Simple Read/Write Splitting + SRW_READ_ENDPOINT = WrapperProperty( + "srw_read_endpoint", + "The read-only endpoint that should be used to connect to a reader.", + None, + ) + + SRW_WRITE_ENDPOINT = WrapperProperty( + "srw_write_endpoint", + "The read-write/cluster endpoint that should be used to connect to the writer.", + None, + ) + + SRW_VERIFY_NEW_CONNECTIONS = WrapperProperty( + "srw_verify_new_connections", + "Enables role-verification for new connections made by the Simple Read/Write 
Splitting Plugin..", + True, + ) + + SRW_VERIFY_INITIAL_CONNECTION_TYPE = WrapperProperty( + "srw_verify_initial_connection_type", + "Force to verify an initial connection to be either a writer or a reader.", + None, + ) + + SRW_CONNECT_RETRY_TIMEOUT_MS = WrapperProperty( + "srw_connect_retry_timeout_ms", + "Maximum allowed time in milliseconds for the plugin to retry opening a connection.", + 60000, + ) + + SRW_CONNECT_RETRY_INTERVAL_MS = WrapperProperty( + "srw_connect_retry_interval_ms", + "Time in milliseconds between each retry of opening a connection.", + 1000, ) @@ -485,7 +597,9 @@ def parse_pg_scheme_url(conn_info: str) -> Properties: elif conn_info.startswith("postgres://"): to_parse = conn_info[len("postgres://"):] else: - raise AwsWrapperError(Messages.get_formatted("PropertiesUtils.InvalidPgSchemeUrl", conn_info)) + raise AwsWrapperError( + Messages.get_formatted("PropertiesUtils.InvalidPgSchemeUrl", conn_info) + ) # Example URL: postgresql://user:password@host:port/dbname?some_prop=some_value # More examples here: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING @@ -495,7 +609,9 @@ def parse_pg_scheme_url(conn_info: str) -> Properties: password_separator = user_spec.find(":") if password_separator >= 0: props[WrapperProperties.USER.name] = user_spec[:password_separator] - props[WrapperProperties.PASSWORD.name] = user_spec[password_separator + 1:host_separator] + props[WrapperProperties.PASSWORD.name] = user_spec[ + password_separator + 1: host_separator + ] else: props[WrapperProperties.USER.name] = user_spec to_parse = to_parse[host_separator + 1:] @@ -513,7 +629,11 @@ def parse_pg_scheme_url(conn_info: str) -> Properties: host_spec = to_parse if host_spec.find(",") >= 0: - raise AwsWrapperError(Messages.get_formatted("PropertiesUtils.MultipleHostsNotSupported", conn_info)) + raise AwsWrapperError( + Messages.get_formatted( + "PropertiesUtils.MultipleHostsNotSupported", conn_info + ) + ) # host_spec may be a percent-encoded unix domain socket, eg '%2Fvar%2Flib%2Fpostgresql'. # When stored as a kwarg instead of a connection string property, it should be decoded. 
@@ -521,7 +641,7 @@ def parse_pg_scheme_url(conn_info: str) -> Properties: if host_spec.startswith("["): # IPv6 addresses should be enclosed in square brackets, eg 'postgresql://[2001:db8::1234]/dbname' host_end = host_spec.find("]") - props["host"] = host_spec[:host_end + 1] + props["host"] = host_spec[: host_end + 1] host_spec = host_spec[host_end + 1:] if len(host_spec) > 0: props["port"] = host_spec[1:] @@ -544,11 +664,17 @@ def parse_pg_scheme_url(conn_info: str) -> Properties: if props_separator >= 0: # Connection string properties must be percent-decoded when stored as kwargs - props.update(PropertiesUtils.parse_key_values(to_parse, separator="&", percent_decode=True)) + props.update( + PropertiesUtils.parse_key_values( + to_parse, separator="&", percent_decode=True + ) + ) return props @staticmethod - def parse_key_values(conn_info: str, separator: str = " ", percent_decode: bool = False) -> Properties: + def parse_key_values( + conn_info: str, separator: str = " ", percent_decode: bool = False + ) -> Properties: props = Properties() to_parse = conn_info @@ -558,7 +684,11 @@ def parse_key_values(conn_info: str, separator: str = " ", percent_decode: bool equals_i = to_parse.find("=") key_end = sep_i if -1 < sep_i < equals_i else equals_i if key_end == -1: - raise AwsWrapperError(Messages.get_formatted("PropertiesUtils.ErrorParsingConnectionString", conn_info)) + raise AwsWrapperError( + Messages.get_formatted( + "PropertiesUtils.ErrorParsingConnectionString", conn_info + ) + ) key = to_parse[0:key_end] to_parse = to_parse[equals_i + 1:].lstrip() @@ -576,7 +706,10 @@ def parse_key_values(conn_info: str, separator: str = " ", percent_decode: bool @staticmethod def remove_wrapper_props(props: Properties): - persisting_properties = [WrapperProperties.USER.name, WrapperProperties.PASSWORD.name] + persisting_properties = [ + WrapperProperties.USER.name, + WrapperProperties.PASSWORD.name, + ] for attr_name, attr_val in WrapperProperties.__dict__.items(): if isinstance(attr_val, WrapperProperty): @@ -584,7 +717,11 @@ def remove_wrapper_props(props: Properties): if attr_val.name not in persisting_properties: props.pop(attr_val.name, None) - monitor_prop_keys = [key for key in props if key.startswith(PropertiesUtils._MONITORING_PROPERTY_PREFIX)] + monitor_prop_keys = [ + key + for key in props + if key.startswith(PropertiesUtils._MONITORING_PROPERTY_PREFIX) + ] for key in monitor_prop_keys: props.pop(key, None) @@ -618,6 +755,7 @@ def create_monitoring_properties(props: Properties) -> Properties: monitoring_properties = copy.deepcopy(props) for property_key in list(monitoring_properties.keys()): if property_key.startswith(PropertiesUtils._MONITORING_PROPERTY_PREFIX): - monitoring_properties[property_key[len(PropertiesUtils._MONITORING_PROPERTY_PREFIX):]] = \ - monitoring_properties.pop(property_key) + monitoring_properties[ + property_key[len(PropertiesUtils._MONITORING_PROPERTY_PREFIX):] + ] = monitoring_properties.pop(property_key) return monitoring_properties diff --git a/pyproject.toml b/pyproject.toml index ae8796e2..30e2b005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] filterwarnings = [ 'ignore:cache could not write path', - 'ignore:could not create cache path' + 'ignore:could not create cache path', + 'ignore:Exception during reset or similar:pytest.PytestUnhandledThreadExceptionWarning' ] diff --git a/tests/integration/container/test_read_write_splitting.py 
b/tests/integration/container/test_read_write_splitting.py index 6cf94137..89b2f87f 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -66,55 +66,141 @@ def setup_method(self, request): MonitoringThreadContainer.clean_up() gc.collect() - @pytest.fixture(scope='class') + # Plugin configurations + @pytest.fixture( + params=[("read_write_splitting", "read_write_splitting"), ("srw", "srw")] + ) + def plugin_config(self, request): + return request.param + + @pytest.fixture(scope="class") def rds_utils(self): region: str = TestEnvironment.get_current().get_info().get_region() return RdsTestUtility(region) @pytest.fixture(autouse=True) def clear_caches(self): + # Clear wrapper caches RdsHostListProvider._topology_cache.clear() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() - @pytest.fixture(scope='class') - def props(self): - p: Properties = Properties({"plugins": "read_write_splitting", "socket_timeout": 10, "connect_timeout": 10, "autocommit": True}) + # Force DNS refresh by clearing any internal connection pools + ConnectionProviderManager.release_resources() + ConnectionProviderManager.reset_provider() + + @pytest.fixture + def props(self, plugin_config, conn_utils): + plugin_name, plugin_value = plugin_config + p: Properties = Properties( + { + "plugins": plugin_value, + "socket_timeout": 10, + "connect_timeout": 10, + "autocommit": True, + } + ) - if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features() \ - or TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + # Add simple plugin specific configuration + if plugin_name == "srw": + WrapperProperties.SRW_WRITE_ENDPOINT.set(p, conn_utils.writer_cluster_host) + WrapperProperties.SRW_READ_ENDPOINT.set(p, conn_utils.reader_cluster_host) + WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.set(p, "30000") + WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.set(p, "1000") + + if ( + TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED + in TestEnvironment.get_current().get_features() + or TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED + in TestEnvironment.get_current().get_features() + ): WrapperProperties.ENABLE_TELEMETRY.set(p, "True") WrapperProperties.TELEMETRY_SUBMIT_TOPLEVEL.set(p, "True") - if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): + if ( + TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED + in TestEnvironment.get_current().get_features() + ): WrapperProperties.TELEMETRY_TRACES_BACKEND.set(p, "XRAY") - if TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + if ( + TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED + in TestEnvironment.get_current().get_features() + ): WrapperProperties.TELEMETRY_METRICS_BACKEND.set(p, "OTLP") return p - @pytest.fixture(scope='class') - def failover_props(self): - return { - "plugins": "read_write_splitting,failover", + @pytest.fixture + def failover_props(self, plugin_config, conn_utils): + plugin_name, plugin_value = plugin_config + props = { + "plugins": f"{plugin_value},failover", "socket_timeout": 10, "connect_timeout": 10, - "autocommit": True + "autocommit": True, } - - @pytest.fixture(scope='class') - def proxied_props(self, props, conn_utils): + # Add simple plugin specific configuration + if plugin_name == "srw": + WrapperProperties.SRW_WRITE_ENDPOINT.set( + props, 
conn_utils.writer_cluster_host + ) + WrapperProperties.SRW_READ_ENDPOINT.set( + props, conn_utils.reader_cluster_host + ) + + return props + + @pytest.fixture + def proxied_props(self, props, plugin_config, conn_utils): + plugin_name, _ = plugin_config props_copy = props.copy() - endpoint_suffix = TestEnvironment.get_current().get_proxy_database_info().get_instance_endpoint_suffix() - WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.set(props_copy, f"?.{endpoint_suffix}:{conn_utils.proxy_port}") + + # Add simple plugin specific configuration + if plugin_name == "srw": + WrapperProperties.SRW_WRITE_ENDPOINT.set( + props_copy, + f"{conn_utils.proxy_writer_cluster_host}:{conn_utils.proxy_port}", + ) + WrapperProperties.SRW_READ_ENDPOINT.set( + props_copy, + f"{conn_utils.proxy_reader_cluster_host}:{conn_utils.proxy_port}", + ) + + endpoint_suffix = ( + TestEnvironment.get_current() + .get_proxy_database_info() + .get_instance_endpoint_suffix() + ) + WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.set( + props_copy, f"?.{endpoint_suffix}:{conn_utils.proxy_port}" + ) return props_copy - @pytest.fixture(scope='class') - def proxied_failover_props(self, failover_props, conn_utils): + @pytest.fixture + def proxied_failover_props(self, failover_props, plugin_config, conn_utils): + plugin_name, _ = plugin_config props_copy = failover_props.copy() - endpoint_suffix = TestEnvironment.get_current().get_proxy_database_info().get_instance_endpoint_suffix() - WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.set(props_copy, f"?.{endpoint_suffix}:{conn_utils.proxy_port}") + + # Add simple plugin specific configuration + if plugin_name == "srw": + WrapperProperties.SRW_WRITE_ENDPOINT.set( + props_copy, + f"{conn_utils.proxy_writer_cluster_host}:{conn_utils.proxy_port}", + ) + WrapperProperties.SRW_READ_ENDPOINT.set( + props_copy, + f"{conn_utils.proxy_reader_cluster_host}:{conn_utils.proxy_port}", + ) + + endpoint_suffix = ( + TestEnvironment.get_current() + .get_proxy_database_info() + .get_instance_endpoint_suffix() + ) + WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.set( + props_copy, f"?.{endpoint_suffix}:{conn_utils.proxy_port}" + ) return props_copy @pytest.fixture(autouse=True) @@ -126,9 +212,12 @@ def cleanup_connection_provider(self): ProxyHelper.enable_all_connectivity() def test_connect_to_writer__switch_read_only( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, test_driver: TestDriver, props, conn_utils, rds_utils + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) as conn: writer_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -152,11 +241,26 @@ def test_connect_to_writer__switch_read_only( assert reader_id == current_id def test_connect_to_reader__switch_read_only( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + props, + conn_utils, + rds_utils, + plugin_config, + ): + plugin_name, _ = plugin_config + if plugin_name != "read_write_splitting": + pytest.skip( + "Test only applies to read_write_splitting plugin: srw does not connect to instances" + ) target_driver_connect = DriverHelper.get_connect_func(test_driver) reader_instance = 
test_environment.get_instances()[1] with AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_connect_params(reader_instance.get_host()), **props) as conn: + target_driver_connect, + **conn_utils.get_connect_params(reader_instance.get_host()), + **props, + ) as conn: reader_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -168,10 +272,14 @@ def test_connect_to_reader__switch_read_only( assert reader_id != writer_id def test_connect_to_reader_cluster__switch_read_only( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, test_driver: TestDriver, props, conn_utils, rds_utils + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) with AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_connect_params(conn_utils.reader_cluster_host), **props) as conn: + target_driver_connect, + **conn_utils.get_connect_params(conn_utils.reader_cluster_host), + **props, + ) as conn: reader_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -183,9 +291,12 @@ def test_connect_to_reader_cluster__switch_read_only( assert reader_id != writer_id def test_set_read_only_false__read_only_transaction( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, test_driver: TestDriver, props, conn_utils, rds_utils + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) as conn: writer_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -210,9 +321,12 @@ def test_set_read_only_false__read_only_transaction( assert writer_id == current_id def test_set_read_only_false_in_transaction( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, test_driver: TestDriver, props, conn_utils, rds_utils + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) as conn: writer_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -235,9 +349,12 @@ def test_set_read_only_false_in_transaction( assert writer_id == current_id def test_set_read_only_true_in_transaction( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, test_driver: TestDriver, props, conn_utils, rds_utils + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) as conn: writer_id = rds_utils.query_instance_id(conn) cursor = conn.cursor() @@ -258,16 +375,31 @@ def test_set_read_only_true_in_transaction( @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED]) @enable_on_num_instances(min_instances=3) def test_set_read_only_true__all_readers_down( - self, test_environment: TestEnvironment, test_driver: TestDriver, proxied_props, conn_utils, rds_utils): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + proxied_props, + conn_utils, + 
rds_utils, + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) connect_params = conn_utils.get_proxy_connect_params() - with AwsWrapperConnection.connect(target_driver_connect, **connect_params, **proxied_props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **connect_params, **proxied_props + ) as conn: writer_id = rds_utils.query_instance_id(conn) - instance_ids = [instance.get_instance_id() for instance in test_environment.get_instances()] + # Disable all reader instance ids and reader cluster endpoint. + instance_ids = [ + instance.get_instance_id() + for instance in test_environment.get_instances() + ] for i in range(1, len(instance_ids)): ProxyHelper.disable_connectivity(instance_ids[i]) + ProxyHelper.disable_connectivity( + test_environment.get_proxy_database_info().get_cluster_read_only_endpoint() + ) conn.read_only = True current_id = rds_utils.query_instance_id(conn) @@ -283,30 +415,53 @@ def test_set_read_only_true__all_readers_down( assert writer_id != current_id def test_set_read_only_true__closed_connection( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, test_driver: TestDriver, props, conn_utils, rds_utils + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) - conn = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) + conn = AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) conn.close() with pytest.raises(AwsWrapperError): conn.read_only = True - # @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED]) + @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED]) @pytest.mark.skip def test_set_read_only_false__all_instances_down( - self, test_environment: TestEnvironment, test_driver: TestDriver, proxied_props, conn_utils, rds_utils): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + proxied_props, + conn_utils, + rds_utils, + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) reader = test_environment.get_proxy_instances()[1] with AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_proxy_connect_params(reader.get_host()), **proxied_props) as conn: + target_driver_connect, + **conn_utils.get_proxy_connect_params(reader.get_host()), + **proxied_props, + ) as conn: ProxyHelper.disable_all_connectivity() with pytest.raises(AwsWrapperError): conn.read_only = False def test_execute__old_connection( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, rds_utils): + self, + test_driver: TestDriver, + props: Properties, + conn_utils, + rds_utils, + plugin_config, + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: + WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS.set(props, "False") + with AwsWrapperConnection.connect( + target_driver_connect, + **conn_utils.get_connect_params(conn_utils.writer_cluster_host), + **props, + ) as conn: writer_id = rds_utils.query_instance_id(conn) old_cursor = conn.cursor() @@ -325,20 +480,35 @@ def test_execute__old_connection( current_id = rds_utils.query_instance_id(conn) assert reader_id == current_id - @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + 
@enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, + TestEnvironmentFeatures.FAILOVER_SUPPORTED]) @enable_on_num_instances(min_instances=3) def test_failover_to_new_writer__switch_read_only( - self, test_environment: TestEnvironment, test_driver: TestDriver, - proxied_failover_props, conn_utils, rds_utils): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + proxied_failover_props, + conn_utils, + rds_utils, + ): target_driver_connect = DriverHelper.get_connect_func(test_driver) connect_params = conn_utils.get_proxy_connect_params() - with AwsWrapperConnection.connect(target_driver_connect, **connect_params, **proxied_failover_props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **connect_params, **proxied_failover_props + ) as conn: original_writer_id = rds_utils.query_instance_id(conn) - instance_ids = [instance.get_instance_id() for instance in test_environment.get_instances()] + # Disable all reader instance ids and reader cluster endpoint. + instance_ids = [ + instance.get_instance_id() + for instance in test_environment.get_instances() + ] for i in range(1, len(instance_ids)): ProxyHelper.disable_connectivity(instance_ids[i]) + ProxyHelper.disable_connectivity( + test_environment.get_proxy_database_info().get_cluster_read_only_endpoint() + ) # Force internal reader connection to the writer instance conn.read_only = True @@ -368,15 +538,32 @@ def test_failover_to_new_writer__switch_read_only( @enable_on_num_instances(min_instances=3) @disable_on_engines([DatabaseEngine.MYSQL]) def test_failover_to_new_reader__switch_read_only( - self, test_environment: TestEnvironment, test_driver: TestDriver, - proxied_failover_props, conn_utils, rds_utils, plugins): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + proxied_failover_props, + conn_utils, + rds_utils, + plugin_config, + plugins + ): + plugin_name, _ = plugin_config + if plugin_name != "read_write_splitting": + # Disabling the reader connection in srw, the srwReadEndpoint, results in defaulting to the writer not connecting to another reader. 
+ pytest.skip( + "Test only applies to read_write_splitting plugin: reader connection failover" + ) + WrapperProperties.FAILOVER_MODE.set(proxied_failover_props, "reader-or-writer") WrapperProperties.PLUGINS.set(proxied_failover_props, plugins) target_driver_connect = DriverHelper.get_connect_func(test_driver) with AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_proxy_connect_params(), **proxied_failover_props) as conn: + target_driver_connect, + **conn_utils.get_proxy_connect_params(), + **proxied_failover_props, + ) as conn: writer_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -384,8 +571,14 @@ def test_failover_to_new_reader__switch_read_only( assert writer_id != reader_id instances = test_environment.get_instances() - other_reader_id = next(( - instance.get_instance_id() for instance in instances[1:] if instance.get_instance_id() != reader_id), None) + other_reader_id = next( + ( + instance.get_instance_id() + for instance in instances[1:] + if instance.get_instance_id() != reader_id + ), + None, + ) if other_reader_id is None: pytest.fail("Could not acquire alternate reader ID") @@ -412,18 +605,29 @@ def test_failover_to_new_reader__switch_read_only( current_id = rds_utils.query_instance_id(conn) assert other_reader_id == current_id - @pytest.mark.parametrize("plugins", ["read_write_splitting,failover,host_monitoring", "read_write_splitting,failover,host_monitoring_v2"]) + @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2"]) @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @enable_on_num_instances(min_instances=3) @disable_on_engines([DatabaseEngine.MYSQL]) def test_failover_reader_to_writer__switch_read_only( - self, test_environment: TestEnvironment, test_driver: TestDriver, - proxied_failover_props, conn_utils, rds_utils, plugins): - WrapperProperties.PLUGINS.set(proxied_failover_props, plugins) + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + proxied_failover_props, + conn_utils, + rds_utils, + plugin_config, + plugins + ): + plugin_name, _ = plugin_config + WrapperProperties.PLUGINS.set(proxied_failover_props, plugin_name + "," + plugins) target_driver_connect = DriverHelper.get_connect_func(test_driver) with AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_proxy_connect_params(), **proxied_failover_props) as conn: + target_driver_connect, + **conn_utils.get_proxy_connect_params(), + **proxied_failover_props, + ) as conn: writer_id = rds_utils.query_instance_id(conn) conn.read_only = True @@ -435,6 +639,9 @@ def test_failover_reader_to_writer__switch_read_only( instance_id = instance.get_instance_id() if instance_id != writer_id: ProxyHelper.disable_connectivity(instance_id) + ProxyHelper.disable_connectivity( + test_environment.get_proxy_database_info().get_cluster_read_only_endpoint() + ) rds_utils.assert_first_query_throws(conn, FailoverSuccessError) assert not conn.is_closed @@ -450,18 +657,105 @@ def test_failover_reader_to_writer__switch_read_only( current_id = rds_utils.query_instance_id(conn) assert writer_id == current_id + def test_incorrect_reader_endpoint( + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + conn_utils, + rds_utils, + plugin_config, + ): + plugin_name, plugin_value = plugin_config + if plugin_name != "srw": + pytest.skip( + "Test only applies to simple_read_write_splitting plugin: uses srwReadEndpoint property" + ) + + props = 
Properties( + {"plugins": plugin_value, "connect_timeout": 30, "autocommit": True} + ) + port = ( + test_environment.get_info().get_database_info().get_cluster_endpoint_port() + ) + writer_endpoint = conn_utils.writer_cluster_host + + # Set both endpoints to writer (incorrect reader endpoint) + WrapperProperties.SRW_WRITE_ENDPOINT.set(props, f"{writer_endpoint}:{port}") + WrapperProperties.SRW_READ_ENDPOINT.set(props, f"{writer_endpoint}:{port}") + + target_driver_connect = DriverHelper.get_connect_func(test_driver) + with AwsWrapperConnection.connect( + target_driver_connect, + **conn_utils.get_connect_params(conn_utils.writer_cluster_host), + **props, + ) as conn: + writer_connection_id = rds_utils.query_instance_id(conn) + + # Switch to reader successfully + conn.read_only = True + reader_connection_id = rds_utils.query_instance_id(conn) + # Should stay on writer as fallback since reader endpoint points to a writer + assert writer_connection_id == reader_connection_id + + # Going to the write endpoint will be the same connection again + conn.read_only = False + final_connection_id = rds_utils.query_instance_id(conn) + assert writer_connection_id == final_connection_id + + def test_autocommit_state_preserved_across_connection_switches( + self, test_driver: TestDriver, props, conn_utils, rds_utils, plugin_config + ): + plugin_name, _ = plugin_config + if plugin_name != "srw": + pytest.skip( + "Test only applies to simple_read_write_splitting plugin: autocommit impacts srw verification" + ) + + target_driver_connect = DriverHelper.get_connect_func(test_driver) + WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS.set(props, "False") + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) as conn: + # Set autocommit to False on writer + conn.autocommit = False + assert conn.autocommit is False + writer_connection_id = rds_utils.query_instance_id(conn) + conn.commit() + + # Switch to reader - autocommit should remain False + conn.read_only = True + assert conn.autocommit is False + reader_connection_id = rds_utils.query_instance_id(conn) + assert writer_connection_id != reader_connection_id + conn.commit() + + # Change autocommit on reader + conn.autocommit = True + assert conn.autocommit is True + + # Switch back to writer - autocommit should be True + conn.read_only = False + assert conn.autocommit is True + final_writer_connection_id = rds_utils.query_instance_id(conn) + assert writer_connection_id == final_writer_connection_id + def test_pooled_connection__reuses_cached_connection( - self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + self, test_driver: TestDriver, conn_utils, props + ): provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}) ConnectionProviderManager.set_connection_provider(provider) target_driver_connect = DriverHelper.get_connect_func(test_driver) - conn1 = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) + conn1 = AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) assert isinstance(conn1.target_connection, PoolProxiedConnection) driver_conn1 = conn1.target_connection.driver_connection conn1.close() - conn2 = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) + conn2 = AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) assert isinstance(conn2.target_connection, 
PoolProxiedConnection) driver_conn2 = conn2.target_connection.driver_connection conn2.close() @@ -471,12 +765,15 @@ def test_pooled_connection__reuses_cached_connection( @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_pooled_connection__failover( - self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, failover_props): + self, test_driver: TestDriver, rds_utils, conn_utils, failover_props + ): provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}) ConnectionProviderManager.set_connection_provider(provider) target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **failover_props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **failover_props + ) as conn: assert isinstance(conn.target_connection, PoolProxiedConnection) initial_driver_conn = conn.target_connection.driver_connection initial_writer_id = rds_utils.query_instance_id(conn) @@ -493,7 +790,9 @@ def test_pooled_connection__failover( assert initial_driver_conn is not new_driver_conn # New connection to the original writer (now a reader) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **failover_props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **failover_props + ) as conn: current_id = rds_utils.query_instance_id(conn) assert initial_writer_id == current_id @@ -505,14 +804,17 @@ def test_pooled_connection__failover( @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_pooled_connection__cluster_url_failover( - self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, failover_props): + self, test_driver: TestDriver, rds_utils, conn_utils, failover_props + ): provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}) ConnectionProviderManager.set_connection_provider(provider) target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, - **conn_utils.get_connect_params(conn_utils.writer_cluster_host), - **failover_props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, + **conn_utils.get_connect_params(conn_utils.writer_cluster_host), + **failover_props, + ) as conn: # The internal connection pool should not be used if the connection is established via a cluster URL. 
assert 0 == len(SqlAlchemyPooledConnectionProvider._database_pools) @@ -532,25 +834,41 @@ def test_pooled_connection__cluster_url_failover( new_driver_conn = conn.target_connection assert initial_driver_conn is not new_driver_conn - @pytest.mark.parametrize("plugins", ["read_write_splitting,failover,host_monitoring", "read_write_splitting,failover,host_monitoring_v2"]) - @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED, TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, + @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2"]) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED, + TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @disable_on_engines([DatabaseEngine.MYSQL]) def test_pooled_connection__failover_failed( - self, test_environment: TestEnvironment, test_driver: TestDriver, - rds_utils, conn_utils, proxied_failover_props, plugins): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + rds_utils, + conn_utils, + proxied_failover_props, + plugin_config, + plugins + ): + plugin_name, _ = plugin_config writer_host = test_environment.get_writer().get_host() - provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}, None, lambda host_info, props: writer_host in host_info.host) + provider = SqlAlchemyPooledConnectionProvider( + lambda _, __: {"pool_size": 1}, + None, + lambda host_info, props: writer_host in host_info.host, + ) ConnectionProviderManager.set_connection_provider(provider) WrapperProperties.FAILOVER_TIMEOUT_SEC.set(proxied_failover_props, "1") WrapperProperties.FAILURE_DETECTION_TIME_MS.set(proxied_failover_props, "1000") WrapperProperties.FAILURE_DETECTION_COUNT.set(proxied_failover_props, "1") - WrapperProperties.PLUGINS.set(proxied_failover_props, plugins) + WrapperProperties.PLUGINS.set(proxied_failover_props, plugin_name + "," + plugins) target_driver_connect = DriverHelper.get_connect_func(test_driver) with AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_proxy_connect_params(), **proxied_failover_props) as conn: + target_driver_connect, + **conn_utils.get_proxy_connect_params(), + **proxied_failover_props, + ) as conn: assert isinstance(conn.target_connection, PoolProxiedConnection) initial_driver_conn = conn.target_connection.driver_connection writer_id = rds_utils.query_instance_id(conn) @@ -561,7 +879,10 @@ def test_pooled_connection__failover_failed( ProxyHelper.enable_all_connectivity() conn = AwsWrapperConnection.connect( - target_driver_connect, **conn_utils.get_proxy_connect_params(), **proxied_failover_props) + target_driver_connect, + **conn_utils.get_proxy_connect_params(), + **proxied_failover_props, + ) current_writer_id = rds_utils.query_instance_id(conn) assert writer_id == current_writer_id @@ -574,12 +895,15 @@ def test_pooled_connection__failover_failed( @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_pooled_connection__failover_in_transaction( - self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, failover_props): + self, test_driver: TestDriver, rds_utils, conn_utils, failover_props + ): provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}) ConnectionProviderManager.set_connection_provider(provider) target_driver_connect = DriverHelper.get_connect_func(test_driver) - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **failover_props) as 
conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **failover_props + ) as conn: assert isinstance(conn.target_connection, PoolProxiedConnection) initial_driver_conn = conn.target_connection.driver_connection initial_writer_id = rds_utils.query_instance_id(conn) @@ -599,7 +923,9 @@ def test_pooled_connection__failover_in_transaction( new_driver_conn = conn.target_connection assert initial_driver_conn is not new_driver_conn - with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **failover_props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **failover_props + ) as conn: current_id = rds_utils.query_instance_id(conn) assert initial_writer_id == current_id @@ -610,7 +936,13 @@ def test_pooled_connection__failover_in_transaction( assert initial_driver_conn is not current_driver_conn def test_pooled_connection__different_users( - self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, props): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + rds_utils, + conn_utils, + props, + ): privileged_user_props = conn_utils.get_connect_params().copy() limited_user_props = conn_utils.get_connect_params().copy() limited_user_name = "limited_user" @@ -627,47 +959,79 @@ def test_pooled_connection__different_users( target_driver_connect = DriverHelper.get_connect_func(test_driver) try: - with AwsWrapperConnection.connect(target_driver_connect, **privileged_user_props, **props) as conn: + with AwsWrapperConnection.connect( + target_driver_connect, **privileged_user_props, **props + ) as conn: assert isinstance(conn.target_connection, PoolProxiedConnection) privileged_driver_conn = conn.target_connection.driver_connection with conn.cursor() as cursor: cursor.execute(f"DROP USER IF EXISTS {limited_user_name}") - rds_utils.create_user(conn, limited_user_name, limited_user_password) + rds_utils.create_user( + conn, limited_user_name, limited_user_password + ) engine = test_environment.get_engine() if engine == DatabaseEngine.MYSQL: db = test_environment.get_database_info().get_default_db_name() # MySQL needs this extra command to allow the limited user to connect to the default database - cursor.execute(f"GRANT ALL PRIVILEGES ON {db}.* TO {limited_user_name}") + cursor.execute( + f"GRANT ALL PRIVILEGES ON {db}.* TO {limited_user_name}" + ) # Validate that the privileged connection established above is not reused and that the new connection is # correctly established under the limited user - with AwsWrapperConnection.connect(target_driver_connect, **limited_user_props, **props) as conn2: - assert isinstance(conn2.target_connection, PoolProxiedConnection) + with AwsWrapperConnection.connect( + target_driver_connect, **limited_user_props, **props + ) as conn2: + assert isinstance( + conn2.target_connection, PoolProxiedConnection + ) limited_driver_conn = conn2.target_connection.driver_connection assert privileged_driver_conn is not limited_driver_conn with conn2.cursor() as cursor2: with pytest.raises(Exception): # The limited user does not have create permissions on the default database, so this should fail - cursor2.execute(f"CREATE DATABASE {limited_user_new_db}") + cursor2.execute( + f"CREATE DATABASE {limited_user_new_db}" + ) with pytest.raises(Exception): - AwsWrapperConnection.connect(target_driver_connect, **wrong_user_right_password_props, - **props) + AwsWrapperConnection.connect( + 
target_driver_connect, + **wrong_user_right_password_props, + **props, + ) finally: - conn = AwsWrapperConnection.connect(target_driver_connect, **privileged_user_props, **props) + conn = AwsWrapperConnection.connect( + target_driver_connect, **privileged_user_props, **props + ) cursor = conn.cursor() cursor.execute(f"DROP DATABASE IF EXISTS {limited_user_new_db}") cursor.execute(f"DROP USER IF EXISTS {limited_user_name}") @enable_on_num_instances(min_instances=5) def test_pooled_connection__least_connections( - self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, props): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + rds_utils, + conn_utils, + props, + plugin_config, + ): + plugin_name, _ = plugin_config + if plugin_name != "read_write_splitting": + pytest.skip( + "Test only applies to read_write_splitting plugin: reader host selector strategy" + ) + WrapperProperties.READER_HOST_SELECTOR_STRATEGY.set(props, "least_connections") instances = test_environment.get_instances() - provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": len(instances)}) + provider = SqlAlchemyPooledConnectionProvider( + lambda _, __: {"pool_size": len(instances)} + ) ConnectionProviderManager.set_connection_provider(provider) connections = [] @@ -676,7 +1040,9 @@ def test_pooled_connection__least_connections( try: # Assume one writer and [size - 1] readers. Create an internal connection pool for each reader. for _ in range(len(instances) - 1): - conn = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) + conn = AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) connections.append(conn) conn.read_only = True @@ -695,7 +1061,20 @@ def test_pooled_connection__least_connections( @enable_on_num_instances(min_instances=5) def test_pooled_connection__least_connections__pool_mapping( - self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, props): + self, + test_environment: TestEnvironment, + test_driver: TestDriver, + rds_utils, + conn_utils, + props, + plugin_config, + ): + plugin_name, _ = plugin_config + if plugin_name != "read_write_splitting": + pytest.skip( + "Test only applies to read_write_splitting plugin: reader host selector strategy" + ) + WrapperProperties.READER_HOST_SELECTOR_STRATEGY.set(props, "least_connections") # We will be testing all instances excluding the writer and overloaded reader. Each instance @@ -707,7 +1086,7 @@ def test_pooled_connection__least_connections__pool_mapping( provider = SqlAlchemyPooledConnectionProvider( lambda _, __: {"pool_size": num_test_connections}, # Create a new pool for each instance-arbitrary_prop combination - lambda host_info, conn_props: f"{host_info.url}-{len(SqlAlchemyPooledConnectionProvider._database_pools)}" + lambda host_info, conn_props: f"{host_info.url}-{len(SqlAlchemyPooledConnectionProvider._database_pools)}", ) ConnectionProviderManager.set_connection_provider(provider) @@ -720,14 +1099,20 @@ def test_pooled_connection__least_connections__pool_mapping( # with each pool consisting of just one connection. The total connection count for the # instance should be overloaded_reader_connection_count despite being spread across multiple # pools. 
- conn = AwsWrapperConnection.connect(target_driver_connect, - **conn_utils.get_connect_params(reader_to_overload.get_host()), - **props) + conn = AwsWrapperConnection.connect( + target_driver_connect, + **conn_utils.get_connect_params(reader_to_overload.get_host()), + **props, + ) connections.append(conn) - assert overloaded_reader_connection_count == len(SqlAlchemyPooledConnectionProvider._database_pools) + assert overloaded_reader_connection_count == len( + SqlAlchemyPooledConnectionProvider._database_pools + ) for _ in range(num_test_connections): - conn = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) + conn = AwsWrapperConnection.connect( + target_driver_connect, **conn_utils.get_connect_params(), **props + ) connections.append(conn) conn.read_only = True diff --git a/tests/unit/test_read_write_splitting_plugin.py b/tests/unit/test_read_write_splitting_plugin.py index e77f67e2..f87a2817 100644 --- a/tests/unit/test_read_write_splitting_plugin.py +++ b/tests/unit/test_read_write_splitting_plugin.py @@ -17,18 +17,28 @@ import psycopg import pytest -from aws_advanced_python_wrapper.errors import FailoverSuccessError +from aws_advanced_python_wrapper.errors import (AwsWrapperError, + FailoverSuccessError, + ReadWriteSplittingError) from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.pep249 import Error from aws_advanced_python_wrapper.read_write_splitting_plugin import \ ReadWriteSplittingPlugin +from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \ + SimpleReadWriteSplittingPlugin from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider from aws_advanced_python_wrapper.utils.notifications import \ OldConnectionSuggestedAction -from aws_advanced_python_wrapper.utils.properties import Properties +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from tests.unit.utils.unit_test_utils import AnyInstanceOf + +# Common test data +WRITE_ENDPOINT = "writer.cluster-xyz.us-east-1.rds.amazonaws.com" +READ_ENDPOINT = "reader.cluster-xyz.us-east-1.rds.amazonaws.com" +TEST_PORT = 5432 -default_props = Properties() writer_host = HostInfo(host="instance0", role=HostRole.WRITER) reader_host1 = HostInfo(host="instance1", role=HostRole.READER) reader_host2 = HostInfo(host="instance2", role=HostRole.READER) @@ -37,15 +47,42 @@ default_hosts: List[HostInfo] = [writer_host, reader_host1, reader_host2, reader_host3] single_reader_topology: List[HostInfo] = [writer_host, reader_host1] +# Simple plugin specific hosts +simple_writer_host = HostInfo(host=WRITE_ENDPOINT, port=TEST_PORT, role=HostRole.WRITER) +simple_reader_host = HostInfo(host=READ_ENDPOINT, port=TEST_PORT, role=HostRole.READER) +any_host = AnyInstanceOf(HostInfo) + + +# Plugin configurations +@pytest.fixture( + params=[ + ("read_write_splitting", ReadWriteSplittingPlugin), + ("srw", SimpleReadWriteSplittingPlugin), + ] +) +def plugin_config(request): + return request.param + @pytest.fixture -def host_list_provider_service_mock(mocker): - return mocker.MagicMock() +def plugin_props(plugin_config): + plugin_name, _ = plugin_config + props = Properties() + if plugin_name == "srw": + props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT + props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT + props[WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name] = "600" + 
props[WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name] = "10" + return props @pytest.fixture -def changes_mock(mocker): - return mocker.MagicMock() +def host_list_provider_service_mock(mocker, plugin_config): + plugin_name, _ = plugin_config + mock = mocker.MagicMock() + if plugin_name == "srw": + mock.initial_connection_host_info = simple_writer_host + return mock @pytest.fixture @@ -74,7 +111,7 @@ def connect_func_mock(mocker): @pytest.fixture -def driver_dialect_mock(mocker, writer_conn_mock): +def driver_dialect_mock(mocker, writer_conn_mock, closed_writer_conn_mock): def is_closed_side_effect(conn): return conn == closed_writer_conn_mock @@ -82,226 +119,400 @@ def is_closed_side_effect(conn): driver_dialect_mock.is_closed.side_effect = is_closed_side_effect driver_dialect_mock.get_connection_from_obj.return_value = writer_conn_mock driver_dialect_mock.unwrap_connection.return_value = writer_conn_mock - + driver_dialect_mock.can_execute_query.return_value = True return driver_dialect_mock @pytest.fixture -def plugin_service_mock(mocker, driver_dialect_mock): +def plugin_service_mock(mocker, driver_dialect_mock, writer_conn_mock): plugin_service_mock = mocker.MagicMock() plugin_service_mock.driver_dialect = driver_dialect_mock plugin_service_mock.hosts = default_hosts + plugin_service_mock.current_connection = writer_conn_mock + plugin_service_mock.current_host_info = writer_host + plugin_service_mock.is_in_transaction = False + plugin_service_mock.get_host_role.return_value = HostRole.WRITER return plugin_service_mock -def test_set_read_only_true(plugin_service_mock): +@pytest.fixture +def plugin_instance( + plugin_config, plugin_service_mock, plugin_props, host_list_provider_service_mock +): + _, plugin_class = plugin_config + plugin = plugin_class(plugin_service_mock, plugin_props) + plugin._connection_handler._host_list_provider_service = ( + host_list_provider_service_mock + ) + + return plugin + + +# Common tests for both plugins +def test_set_read_only_true( + plugin_instance, plugin_service_mock, plugin_config, reader_conn_mock +): + plugin_name, _ = plugin_config plugin_service_mock.current_connection = writer_conn_mock - plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 - plugin_service_mock.hosts = single_reader_topology plugin_service_mock.connect.return_value = reader_conn_mock - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._reader_connection = None + if plugin_name == "read_write_splitting": + plugin_service_mock.current_host_info = writer_host + plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + plugin_service_mock.hosts = single_reader_topology + plugin_instance._reader_connection = None + + plugin_instance._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_called_once_with( + reader_conn_mock, reader_host1 + ) + assert plugin_instance._reader_connection == reader_conn_mock + else: + plugin_service_mock.current_host_info = simple_writer_host + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + + plugin_instance._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_called_with( + reader_conn_mock, any_host + ) + + +def test_set_read_only_false( + plugin_instance, + plugin_service_mock, + plugin_config, + reader_conn_mock, + writer_conn_mock, +): + plugin_name, 
_ = plugin_config + plugin_service_mock.current_connection = reader_conn_mock + plugin_service_mock.connect.return_value = writer_conn_mock + + if plugin_name == "read_write_splitting": + plugin_service_mock.current_host_info = reader_host1 + plugin_service_mock.hosts = single_reader_topology + plugin_instance._writer_host_info = writer_host + + plugin_instance._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_called_once_with( + writer_conn_mock, writer_host + ) + assert plugin_instance._writer_connection == writer_conn_mock + else: + plugin_service_mock.current_host_info = simple_reader_host + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + + plugin_instance._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_called_with( + writer_conn_mock, any_host + ) + + +def test_set_read_only_true_already_on_reader( + plugin_instance, plugin_service_mock, plugin_config, reader_conn_mock +): + plugin_name, _ = plugin_config + plugin_service_mock.current_connection = reader_conn_mock + plugin_instance._reader_connection = reader_conn_mock - plugin._switch_connection_if_required(True) + if plugin_name == "read_write_splitting": + plugin_service_mock.current_host_info = reader_host1 + else: + plugin_service_mock.current_host_info = simple_reader_host - plugin_service_mock.set_current_connection.assert_called_once() - plugin_service_mock.set_current_connection.assert_called_once_with(reader_conn_mock, reader_host1) + plugin_instance._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_not_called() - assert plugin._reader_connection == reader_conn_mock +def test_set_read_only_false_already_on_writer( + plugin_instance, plugin_service_mock, plugin_config, writer_conn_mock +): + plugin_name, _ = plugin_config + plugin_service_mock.current_connection = writer_conn_mock + plugin_instance._writer_connection = writer_conn_mock -def test_set_read_only_false(plugin_service_mock): - plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.current_host_info = reader_host1 - plugin_service_mock.hosts = single_reader_topology + if plugin_name == "read_write_splitting": + plugin_service_mock.current_host_info = writer_host + plugin_instance._writer_host_info = writer_host + else: + plugin_service_mock.current_host_info = simple_writer_host - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = writer_conn_mock + plugin_instance._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_not_called() - plugin._switch_connection_if_required(False) - plugin_service_mock.set_current_connection.assert_called_once() - plugin_service_mock.set_current_connection.assert_called_once_with(writer_conn_mock, writer_host) +def test_set_read_only_false_in_transaction( + plugin_instance, plugin_service_mock, reader_conn_mock +): + plugin_service_mock.current_connection = reader_conn_mock + plugin_service_mock.is_in_transaction = True + plugin_service_mock.current_host_info = simple_reader_host - assert plugin._writer_connection == writer_conn_mock + with pytest.raises(ReadWriteSplittingError): + plugin_instance._switch_connection_if_required(False) -def test_set_read_only_true_already_on_reader(plugin_service_mock): +def 
test_set_read_only_true_in_transaction_already_on_reader( + plugin_instance, plugin_service_mock, reader_conn_mock +): plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.current_host_info = reader_host1 + plugin_service_mock.is_in_transaction = True + plugin_service_mock.current_host_info = simple_reader_host + plugin_instance._writer_connection = None + plugin_instance._reader_connection = reader_conn_mock - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = None - plugin._reader_connection = reader_conn_mock + plugin_instance._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_not_called() + assert plugin_instance._reader_connection == reader_conn_mock + assert plugin_instance._writer_connection is None - plugin._switch_connection_if_required(True) - plugin_service_mock.set_current_connection.assert_not_called() +def test_set_read_only_on_closed_connection( + plugin_instance, plugin_service_mock, closed_writer_conn_mock +): + plugin_service_mock.current_connection = closed_writer_conn_mock + plugin_instance._writer_connection = closed_writer_conn_mock + plugin_instance._reader_connection = None + + with pytest.raises(ReadWriteSplittingError): + plugin_instance._switch_connection_if_required(True) - assert plugin._reader_connection == reader_conn_mock - assert plugin._writer_connection is None + plugin_service_mock.set_current_connection.assert_not_called() + assert plugin_instance._reader_connection is None -def test_set_read_only_false_already_on_writer(plugin_service_mock): +def test_notify_connection_change( + plugin_instance, plugin_service_mock, writer_conn_mock +): + plugin_instance._in_read_write_split = False plugin_service_mock.current_connection = writer_conn_mock - plugin_service_mock.current_host_info = writer_host + plugin_service_mock.current_host_info = simple_writer_host - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = writer_conn_mock - plugin._reader_connection = None + suggestion = plugin_instance.notify_connection_changed(set()) + assert suggestion == OldConnectionSuggestedAction.NO_OPINION + assert plugin_instance._writer_connection == writer_conn_mock - plugin._switch_connection_if_required(False) + plugin_instance._writer_connection = None + plugin_instance._in_read_write_split = True + suggestion = plugin_instance.notify_connection_changed(set()) + assert suggestion == OldConnectionSuggestedAction.PRESERVE + assert plugin_instance._writer_connection == writer_conn_mock - plugin_service_mock.set_current_connection.assert_not_called() - assert plugin._writer_connection == writer_conn_mock - assert plugin._reader_connection is None +def test_connect_non_initial_connection( + plugin_instance, plugin_config, connect_func_mock, writer_conn_mock, mocker +): + plugin_name, _ = plugin_config + connect_func_mock.return_value = writer_conn_mock + if plugin_name == "read_write_splitting": + plugin_instance._writer_connection = writer_conn_mock + plugin_instance._writer_host_info = writer_host + plugin_instance._reader_connection = None + + conn = plugin_instance.connect( + mocker.MagicMock(), + mocker.MagicMock(), + writer_host, + Properties(), + False, + connect_func_mock, + ) + + assert conn == writer_conn_mock + connect_func_mock.assert_called() + else: + result = 
plugin_instance.connect( + None, None, simple_writer_host, Properties(), False, connect_func_mock + ) + + assert result == writer_conn_mock + connect_func_mock.assert_called_once() + + +def test_close_pooled_reader_connection_after_set_read_only( + plugin_props, + plugin_service_mock, + plugin_config, + mocker, + reader_conn_mock, + writer_conn_mock, +): + plugin_name, plugin_class = plugin_config + + def connect_side_effect(host: HostInfo, props, plugin): + if ( + host in [reader_host1, reader_host2, reader_host3] + or host.host == READ_ENDPOINT + ): + return reader_conn_mock + elif host == writer_host or host.host == WRITE_ENDPOINT: + return writer_conn_mock + return None -def test_set_read_only_false_in_transaction(plugin_service_mock): - plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.current_host_info = reader_host1 - plugin_service_mock.is_in_transaction = True + plugin_service_mock.connect.side_effect = connect_side_effect + plugin_service_mock.current_host_info = mocker.MagicMock( + side_effect=[writer_host, writer_host, reader_host1] + ) - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = None - plugin._reader_connection = reader_conn_mock + if plugin_name == "read_write_splitting": + plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + else: + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + + provider = SqlAlchemyPooledConnectionProvider( + lambda _, __: {"pool_size": 3}, None, 180000000000, 600000000000 # 3 minutes + ) # 10 minutes + + conn_provider_manager_mock = mocker.MagicMock() + conn_provider_manager_mock.get_connection_provider.return_value = provider + plugin_service_mock.get_connection_provider_manager.return_value = ( + conn_provider_manager_mock + ) + + plugin = plugin_class(plugin_service_mock, plugin_props) + + spy = mocker.spy(plugin, "_close_connection_if_idle") plugin._switch_connection_if_required(True) + plugin._switch_connection_if_required(False) - plugin_service_mock.set_current_connection.assert_not_called() + spy.assert_called_once_with(reader_conn_mock) + assert spy.call_count == 1 - assert plugin._reader_connection == reader_conn_mock - assert plugin._writer_connection is None +def test_close_pooled_writer_connection_after_set_read_only( + plugin_service_mock, + plugin_config, + plugin_props, + mocker, + reader_conn_mock, + writer_conn_mock, +): + plugin_name, plugin_class = plugin_config + + def connect_side_effect(host: HostInfo, props, plugin): + if ( + host in [reader_host1, reader_host2, reader_host3] + or host.host == READ_ENDPOINT + ): + return reader_conn_mock + elif host == writer_host or host.host == WRITE_ENDPOINT: + return writer_conn_mock + return None -def test_set_read_only_true_one_host(plugin_service_mock): - plugin_service_mock.hosts = [writer_host] + plugin_service_mock.connect.side_effect = connect_side_effect + plugin_service_mock.current_host_info = mocker.MagicMock( + side_effect=[writer_host, writer_host, reader_host1, reader_host1, writer_host] + ) - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = writer_conn_mock + if plugin_name == "read_write_splitting": + plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + else: + 
plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) - plugin._switch_connection_if_required(True) + provider = SqlAlchemyPooledConnectionProvider( + lambda _, __: {"pool_size": 3}, None, 180000000000, 600000000000 # 3 minutes + ) # 10 minutes - plugin_service_mock.set_current_connection.assert_not_called() + conn_provider_manager_mock = mocker.MagicMock() + conn_provider_manager_mock.get_connection_provider.return_value = provider + plugin_service_mock.get_connection_provider_manager.return_value = ( + conn_provider_manager_mock + ) - assert plugin._writer_connection == writer_conn_mock - assert plugin._reader_connection is None + plugin = plugin_class(plugin_service_mock, plugin_props) + spy = mocker.spy(plugin, "_close_connection_if_idle") -def test_set_read_only_false_writer_connection_fails(plugin_service_mock): - def connect_side_effect(host_info: HostInfo, props: Properties): - if host_info == writer_host and props == default_props: + plugin._switch_connection_if_required(True) + plugin._switch_connection_if_required(False) + plugin._switch_connection_if_required(True) + + spy.assert_called_with(writer_conn_mock) + assert spy.call_count == 2 + + +def test_set_read_only_false_writer_connection_fails( + plugin_instance, plugin_service_mock, reader_conn_mock +): + def connect_side_effect(host_info: HostInfo, props: Properties, plugin): + if ( + host_info == writer_host or host_info.host == WRITE_ENDPOINT + ) and props == Properties(): raise Error("Connection Error") plugin_service_mock.connect.side_effect = connect_side_effect plugin_service_mock.current_connection = reader_conn_mock plugin_service_mock.current_host_info = reader_host1 plugin_service_mock.hosts = single_reader_topology - - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = None - plugin._reader_connection = reader_conn_mock + plugin_instance._writer_connection = None + plugin_instance._reader_connection = reader_conn_mock with pytest.raises(Error): - plugin._switch_connection_if_required(False) + plugin_instance._switch_connection_if_required(False) plugin_service_mock.set_current_connection.assert_not_called() -def test_set_read_only_true_reader_connection_failed(plugin_service_mock): - def connect_side_effect(host_info: HostInfo, props: Properties): - if ((host_info == reader_host1 or host_info == reader_host2 or host_info == reader_host3) - and props == default_props): +def test_set_read_only_true_reader_connection_failed( + plugin_instance, plugin_service_mock, writer_conn_mock +): + def connect_side_effect(host_info: HostInfo, props: Properties, plugin): + if ( + host_info == reader_host1 + or host_info == reader_host2 + or host_info == reader_host3 + ) and props == Properties(): raise Error("Connection Error") plugin_service_mock.connect.side_effect = connect_side_effect + plugin_instance._writer_connection = writer_conn_mock + plugin_instance._writer_host_info = writer_host + plugin_instance._reader_connection = None - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = writer_conn_mock - plugin._reader_connection = None - - plugin._switch_connection_if_required(True) - - plugin_service_mock.set_current_connection.assert_not_called() - - assert plugin._reader_connection is None - - -def 
test_set_read_only_on_closed_connection(plugin_service_mock): - plugin_service_mock.current_connection = closed_writer_conn_mock - - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = closed_writer_conn_mock - plugin._reader_connection = None - - with pytest.raises(Error): - plugin._switch_connection_if_required(True) + plugin_instance._switch_connection_if_required(True) plugin_service_mock.set_current_connection.assert_not_called() - - assert plugin._reader_connection is None + assert plugin_instance._reader_connection is None -def test_execute_failover_to_new_writer(plugin_service_mock, writer_conn_mock): +def test_execute_failover_to_new_writer( + plugin_instance, plugin_service_mock, writer_conn_mock, new_writer_conn_mock +): def execute_func(): raise FailoverSuccessError plugin_service_mock.current_connection = new_writer_conn_mock - - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = writer_conn_mock - plugin._reader_connection = None + plugin_instance._writer_connection = writer_conn_mock + plugin_instance._writer_host_info = writer_host + plugin_instance._reader_connection = None with pytest.raises(Error): - plugin.execute(None, "Statement.execute_query", execute_func) + plugin_instance.execute(None, "Statement.execute_query", execute_func) writer_conn_mock.close.assert_called_once() -def test_notify_connection_change(plugin_service_mock): - plugin_service_mock.current_connection = writer_conn_mock - plugin_service_mock.current_host_info = writer_host - - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - - suggestion = plugin.notify_connection_changed(changes_mock) - - assert suggestion == OldConnectionSuggestedAction.NO_OPINION - assert plugin._writer_connection == writer_conn_mock - - -def test_connect_non_initial_connection( - mocker, plugin_service_mock, connect_func_mock, host_list_provider_service_mock): - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - plugin._writer_connection = writer_conn_mock - plugin._reader_connection = None - - connect_func_mock.return_value = writer_conn_mock - conn = plugin.connect( - mocker.MagicMock(), mocker.MagicMock(), writer_host, default_props, False, connect_func_mock) - - assert conn == writer_conn_mock - - connect_func_mock.assert_called() - host_list_provider_service_mock.initial_connection_host_info.assert_not_called() - - -def test_connect_incorrect_host_role(mocker, plugin_service_mock, connect_func_mock, host_list_provider_service_mock): +def test_connect_incorrect_host_role( + plugin_instance, + plugin_service_mock, + plugin_config, + mocker, + connect_func_mock, + reader_conn_mock, + host_list_provider_service_mock, +): + plugin_name, _ = plugin_config reader_host_incorrect_role = HostInfo(host="instance-4", role=HostRole.WRITER) def get_host_role_side_effect(conn): @@ -313,23 +524,60 @@ def get_host_role_side_effect(conn): plugin_service_mock.initial_connection_host_info = reader_host_incorrect_role host_list_provider_service_mock.is_static_host_list_provider.return_value = False - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = 
host_list_provider_service_mock - connect_func_mock.return_value = reader_conn_mock - conn = plugin.connect( - mocker.MagicMock(), mocker.MagicMock(), writer_host, default_props, True, connect_func_mock) + conn = plugin_instance.connect( + mocker.MagicMock(), + mocker.MagicMock(), + writer_host, + Properties(), + True, + connect_func_mock, + ) assert conn == reader_conn_mock connect_func_mock.assert_called() updated_host = host_list_provider_service_mock.initial_connection_host_info - assert updated_host.host == reader_host_incorrect_role.host - assert updated_host.role != reader_host_incorrect_role.role - assert updated_host.role == HostRole.READER + if plugin_name == "read_write_splitting": + assert updated_host.host == reader_host_incorrect_role.host + assert updated_host.role != reader_host_incorrect_role.role + assert updated_host.role == HostRole.READER + else: + assert updated_host == writer_host + + +# Tests for the Read/Write Splitting Plugin +def test_set_read_only_true_one_host( + plugin_service_mock, plugin_config, plugin_instance +): + plugin_name, _ = plugin_config + if plugin_name != "read_write_splitting": + pytest.skip("Test only applies to Read/Write Splitting Plugin") + plugin_service_mock.hosts = [writer_host] + + plugin_instance._writer_connection = writer_conn_mock + plugin_instance._writer_host_info = writer_host + + plugin_instance._switch_connection_if_required(True) + + plugin_service_mock.set_current_connection.assert_not_called() + assert plugin_instance._writer_connection == writer_conn_mock + assert plugin_instance._reader_connection is None + + +def test_connect_error_updating_host( + plugin_service_mock, + plugin_config, + plugin_instance, + host_list_provider_service_mock, + connect_func_mock, + mocker, +): + plugin_name, _ = plugin_config + if plugin_name != "read_write_splitting": + pytest.skip("Test only applies to Read/Write Splitting Plugin") -def test_connect_error_updating_host(mocker, plugin_service_mock, connect_func_mock, host_list_provider_service_mock): def get_host_role_side_effect(conn): if conn == reader_conn_mock: return None @@ -338,81 +586,328 @@ def get_host_role_side_effect(conn): plugin_service_mock.get_host_role.side_effect = get_host_role_side_effect host_list_provider_service_mock.is_static_host_list_provider.return_value = False - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) - plugin._host_list_provider_service = host_list_provider_service_mock - connect_func_mock.return_value = reader_conn_mock with pytest.raises(Error): - plugin.connect( - mocker.MagicMock(), mocker.MagicMock(), writer_host, default_props, True, connect_func_mock) + plugin_instance.connect( + mocker.MagicMock(), + mocker.MagicMock(), + writer_host, + Properties(), + True, + connect_func_mock, + ) host_list_provider_service_mock.initial_connection_host_info.assert_not_called() -def test_close_pooled_reader_connection_after_set_read_only(mocker, plugin_service_mock): - def connect_side_effect(host, props, plugin): - if host in [reader_host1, reader_host2, reader_host3]: - return reader_conn_mock - elif host == writer_host: - return writer_conn_mock - return None +# Tests for the Simple Read/Write Splitting Plugin +def test_constructor_missing_write_endpoint(plugin_service_mock, plugin_config): + plugin_name, _ = plugin_config + if plugin_name != "srw": + pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - plugin_service_mock.connect.side_effect = connect_side_effect - plugin_service_mock.current_host_info = 
mocker.MagicMock(side_effect=[writer_host, writer_host, reader_host1]) - plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + props = Properties() + props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT + # Missing write endpoint - provider = SqlAlchemyPooledConnectionProvider( - lambda _, __: {"pool_size": 3}, - None, - 180000000000, # 3 minutes - 600000000000) # 10 minutes + with pytest.raises(AwsWrapperError): + SimpleReadWriteSplittingPlugin(plugin_service_mock, props) - conn_provider_manager_mock = mocker.MagicMock() - conn_provider_manager_mock.get_connection_provider.return_value = provider - plugin_service_mock.get_connection_provider_manager.return_value = conn_provider_manager_mock - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) +def test_constructor_missing_read_endpoint(plugin_service_mock, plugin_config): + plugin_name, _ = plugin_config + if plugin_name != "srw": + pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - spy = mocker.spy(plugin, "_close_connection_if_idle") + props = Properties() + props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT + # Missing read endpoint - plugin._switch_connection_if_required(True) - plugin._switch_connection_if_required(False) + with pytest.raises(AwsWrapperError): + SimpleReadWriteSplittingPlugin(plugin_service_mock, props) - spy.assert_called_once_with(reader_conn_mock) - assert spy.call_count == 1 +def test_constructor_invalid_initial_connection_type( + plugin_service_mock, plugin_config +): + plugin_name, _ = plugin_config + if plugin_name != "srw": + pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") -def test_close_pooled_writer_connection_after_set_read_only(mocker, plugin_service_mock): - def connect_side_effect(host, props, plugin): - if host in [reader_host1, reader_host2, reader_host3]: - return reader_conn_mock - elif host == writer_host: - return writer_conn_mock - return None + props = Properties() + props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT + props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT + props[WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.name] = ( + "other" # "writer", "reader" are the only valid options + ) - plugin_service_mock.connect.side_effect = connect_side_effect - plugin_service_mock.current_host_info = ( - mocker.MagicMock(side_effect=[writer_host, writer_host, reader_host1, reader_host1, writer_host])) - plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + with pytest.raises(ValueError): + SimpleReadWriteSplittingPlugin(plugin_service_mock, props) - provider = SqlAlchemyPooledConnectionProvider( - lambda _, __: {"pool_size": 3}, - None, - 180000000000, # 3 minutes - 600000000000) # 10 minutes - conn_provider_manager_mock = mocker.MagicMock() - conn_provider_manager_mock.get_connection_provider.return_value = provider - plugin_service_mock.get_connection_provider_manager.return_value = conn_provider_manager_mock +def test_connect_verification_disabled( + plugin_service_mock, plugin_config, connect_func_mock, writer_conn_mock +): + plugin_name, _ = plugin_config + if plugin_name != "srw": + pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - plugin = ReadWriteSplittingPlugin(plugin_service_mock, default_props) + props = Properties() + props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT + props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT + 
props[WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS.name] = False
+
+    connect_func_mock.return_value = writer_conn_mock
+
+    plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, props)
+
+    result = plugin.connect(
+        None, None, simple_writer_host, props, True, connect_func_mock
+    )
+
+    assert result == writer_conn_mock
+    connect_func_mock.assert_called_once()
+
+
+def test_connect_writer_cluster_endpoint(
+    plugin_config,
+    plugin_instance,
+    plugin_service_mock,
+    plugin_props,
+    connect_func_mock,
+    writer_conn_mock,
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    writer_cluster_host = HostInfo(
+        host="test-cluster.cluster-xyz.us-east-1.rds.amazonaws.com",
+        port=TEST_PORT,
+        role=HostRole.WRITER,
+    )
+
+    connect_func_mock.return_value = writer_conn_mock
+    plugin_service_mock.get_host_role.return_value = HostRole.WRITER
+
+    result = plugin_instance.connect(
+        None, None, writer_cluster_host, plugin_props, True, connect_func_mock
+    )
+
+    assert result == writer_conn_mock
+    connect_func_mock.assert_called_once()
+    assert plugin_service_mock.get_host_role.call_count == 1
+
+
+def test_connect_reader_cluster_endpoint(
+    plugin_config,
+    plugin_instance,
+    plugin_props,
+    plugin_service_mock,
+    connect_func_mock,
+    reader_conn_mock,
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    reader_cluster_host = HostInfo(
+        host="test-cluster.cluster-ro-xyz.us-east-1.rds.amazonaws.com",
+        port=TEST_PORT,
+        role=HostRole.READER,
+    )
+
+    connect_func_mock.return_value = reader_conn_mock
+    plugin_service_mock.get_host_role.return_value = HostRole.READER
+
+    result = plugin_instance.connect(
+        None, None, reader_cluster_host, plugin_props, True, connect_func_mock
+    )
+
+    assert result == reader_conn_mock
+    connect_func_mock.assert_called_once()
+    assert plugin_service_mock.get_host_role.call_count == 1
+
+
+def test_connect_verification_fails_fallback(
+    plugin_config,
+    plugin_service_mock,
+    connect_func_mock,
+    writer_conn_mock,
+    host_list_provider_service_mock,
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    writer_cluster_host = HostInfo(
+        host="test-cluster.cluster-xyz.us-east-1.rds.amazonaws.com",
+        port=TEST_PORT,
+        role=HostRole.WRITER,
+    )
+
+    props = Properties()
+    props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT
+    props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT
+    props[WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name] = "5"  # Short timeout
+    props[WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name] = (
+        "6"  # Interval > timeout ensures only one call before falling back
+    )
+
+    connect_func_mock.return_value = writer_conn_mock
+    plugin_service_mock.get_host_role.return_value = HostRole.READER  # Wrong role
+
+    plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, props)
+    plugin._connection_handler.host_list_provider_service = (
+        host_list_provider_service_mock
+    )
+
+    result = plugin.connect(
+        None, None, writer_cluster_host, props, True, connect_func_mock
+    )
+
+    assert result == writer_conn_mock
+    assert connect_func_mock.call_count == 2
+    assert plugin_service_mock.get_host_role.call_count == 1
+
+
+def test_connect_non_rds_cluster_endpoint(
+    plugin_config,
+    plugin_instance,
+    plugin_service_mock,
+    plugin_props,
+    connect_func_mock,
+    writer_conn_mock,
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    custom_host = HostInfo(
+        host="custom-db.example.com", port=TEST_PORT, role=HostRole.WRITER
+    )
+
+    connect_func_mock.return_value = writer_conn_mock
+
+    result = plugin_instance.connect(
+        None, None, custom_host, plugin_props, True, connect_func_mock
+    )
+
+    assert result == writer_conn_mock
+    connect_func_mock.assert_called_once()
+    assert plugin_service_mock.get_host_role.call_count == 0
+
+
+def test_connect_non_rds_cluster_endpoint_with_verification(
+    plugin_config,
+    plugin_service_mock,
+    plugin_props,
+    connect_func_mock,
+    writer_conn_mock,
+    host_list_provider_service_mock,
+    mocker,
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    custom_host = HostInfo(
+        host="custom-db.example.com", port=TEST_PORT, role=HostRole.WRITER
+    )
+
+    props = Properties()
+    props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT
+    props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT
+    props[WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.name] = (
+        "writer"  # Forces verification to a writer
+    )
+
+    connect_func_mock.return_value = writer_conn_mock
+    plugin_service_mock.get_host_role = mocker.MagicMock(
+        side_effect=[HostRole.READER, HostRole.WRITER]
+    )
+
+    plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, props)
+    plugin._connection_handler.host_list_provider_service = (
+        host_list_provider_service_mock
+    )
+
+    result = plugin.connect(
+        None, None, custom_host, plugin_props, True, connect_func_mock
+    )
+
+    assert result == writer_conn_mock
+    assert connect_func_mock.call_count == 2
+    assert plugin_service_mock.get_host_role.call_count == 2
+
+
+def test_wrong_role_connection_writer_endpoint_to_reader(
+    plugin_service_mock, reader_conn_mock, plugin_config, plugin_instance
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    plugin_service_mock.current_connection = reader_conn_mock
+    plugin_service_mock.current_host_info = simple_reader_host
+    plugin_service_mock.connect.return_value = reader_conn_mock
+    plugin_service_mock.get_host_role.return_value = (
+        HostRole.READER
+    )  # Wrong role for writer
+
+    with pytest.raises(ReadWriteSplittingError):
+        plugin_instance._switch_connection_if_required(False)
+
+
+def test_get_verified_connection_wrong_role_retry_reader(
+    plugin_config,
+    plugin_instance,
+    plugin_service_mock,
+    reader_conn_mock,
+    writer_conn_mock,
+):
+    plugin_name, _ = plugin_config
+    if plugin_name != "srw":
+        pytest.skip("Test only applies to Simple Read/Write Splitting Plugin")
+
+    plugin_service_mock.current_connection = writer_conn_mock
+    plugin_service_mock.current_host_info = simple_writer_host
+
+    # First call returns wrong role, second call returns correct role
+    plugin_service_mock.connect.side_effect = [writer_conn_mock, reader_conn_mock]
+    plugin_service_mock.get_host_role.side_effect = lambda conn: (
+        HostRole.READER if conn == reader_conn_mock else HostRole.WRITER
+    )
+
+    plugin_instance._switch_connection_if_required(True)
+
+    assert 
plugin_service_mock.connect.call_count == 2 + writer_conn_mock.close.assert_called_once() + + +def test_get_verified_connection_sql_exception_retry( + plugin_config, + plugin_instance, + plugin_service_mock, + reader_conn_mock, + writer_conn_mock, +): + plugin_name, _ = plugin_config + if plugin_name != "srw": + pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") + + plugin_service_mock.current_connection = writer_conn_mock + plugin_service_mock.current_host_info = simple_writer_host + + # First call raises exception, second call succeeds + plugin_service_mock.connect.side_effect = [ + Error("Connection failed"), + reader_conn_mock, + ] + plugin_service_mock.get_host_role.return_value = HostRole.READER + + plugin_instance._switch_connection_if_required(True) + + assert plugin_service_mock.connect.call_count == 2 + assert plugin_instance._reader_connection == reader_conn_mock From c7253b55c2fee1e2b65b15a366acbed4dfc22d10 Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Thu, 4 Dec 2025 18:39:33 -0800 Subject: [PATCH 2/3] refactor: config properties retrieval and validation (#1053) --- .../simple_read_write_splitting_plugin.py | 84 +++++++++---------- .../utils/properties.py | 46 +++++----- tests/unit/test_properties_utils.py | 76 ++++++++++++++++- 3 files changed, 140 insertions(+), 66 deletions(-) diff --git a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py index bbf503c1..677a540b 100644 --- a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py @@ -15,7 +15,7 @@ from __future__ import annotations from time import perf_counter_ns, sleep -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Type from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.read_write_splitting_plugin import ( @@ -28,7 +28,7 @@ from aws_advanced_python_wrapper.host_list_provider import HostListProviderService from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService - from aws_advanced_python_wrapper.utils.properties import Properties + from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperty from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole @@ -41,53 +41,24 @@ class EndpointBasedConnectionHandler(ConnectionHandler): """Endpoint based implementation of connection handling logic.""" def __init__(self, plugin_service: PluginService, props: Properties): - srw_read_endpoint = WrapperProperties.SRW_READ_ENDPOINT.get(props) - if srw_read_endpoint is None: - raise AwsWrapperError( - Messages.get_formatted( - "SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter", - WrapperProperties.SRW_READ_ENDPOINT.name, - ) - ) - self._read_endpoint: str = srw_read_endpoint - - srw_write_endpoint = WrapperProperties.SRW_WRITE_ENDPOINT.get(props) - if srw_write_endpoint is None: - raise AwsWrapperError( - Messages.get_formatted( - "SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter", - WrapperProperties.SRW_WRITE_ENDPOINT.name, - ) - ) - self._write_endpoint: str = srw_write_endpoint + self._read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter( + WrapperProperties.SRW_READ_ENDPOINT, 
props, str, required=True + ) + self._write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter( + WrapperProperties.SRW_WRITE_ENDPOINT, props, str, required=True + ) - self._verify_new_connections: bool = ( - WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS.get_bool(props) + self._verify_new_connections: bool = EndpointBasedConnectionHandler._verify_parameter( + WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS, props, bool ) - if self._verify_new_connections is True: - srw_connect_retry_timeout_ms: int = ( - WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.get_int(props) - ) - if srw_connect_retry_timeout_ms <= 0: - raise ValueError( - Messages.get_formatted( - "SimpleReadWriteSplittingPlugin.IncorrectConfiguration", - WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name, - ) - ) - self._connect_retry_timeout_ms: int = srw_connect_retry_timeout_ms - srw_connect_retry_interval_ms: int = ( - WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.get_int(props) + if self._verify_new_connections: + self._connect_retry_timeout_ms: int = EndpointBasedConnectionHandler._verify_parameter( + WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS, props, int, lambda x: x > 0 + ) + self._connect_retry_interval_ms: int = EndpointBasedConnectionHandler._verify_parameter( + WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS, props, int, lambda x: x > 0 ) - if srw_connect_retry_interval_ms <= 0: - raise ValueError( - Messages.get_formatted( - "SimpleReadWriteSplittingPlugin.IncorrectConfiguration", - WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name, - ) - ) - self._connect_retry_interval_ms: int = srw_connect_retry_interval_ms self._verify_opened_connection_type: Optional[HostRole] = ( EndpointBasedConnectionHandler._parse_connection_type( @@ -305,6 +276,29 @@ def _create_host_info(self, endpoint, role: HostRole) -> HostInfo: host=host, port=port, role=role, availability=HostAvailability.AVAILABLE ) + T = TypeVar('T') + + @staticmethod + def _verify_parameter(prop: WrapperProperty, props: Properties, expected_type: Type[T], validator=None, required=False): + value = prop.get_type(props, expected_type) + if required: + if value is None: + raise AwsWrapperError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter", + prop.name, + ) + ) + + if validator and not validator(value): + raise ValueError( + Messages.get_formatted( + "SimpleReadWriteSplittingPlugin.IncorrectConfiguration", + prop.name, + ) + ) + return value + def _delay(self): sleep(self._connect_retry_interval_ms / 1000) diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 4bbde03e..d1e9c0d1 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import copy -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, TypeVar, Type from urllib.parse import unquote from aws_advanced_python_wrapper.errors import AwsWrapperError @@ -25,6 +26,9 @@ def put_if_absent(self, key: str, value: Any): self[key] = value +T = TypeVar('T') + + class WrapperProperty: def __init__( self, name: str, description: str, default_value: Optional[Any] = None @@ -34,41 +38,43 @@ def __init__( self.description = description def __str__(self): - return f"WrapperProperty(name={self.name}, default_value={self.default_value}" + return f"WrapperProperty(name={self.name}, default_value={self.default_value})" def get(self, props: Properties) -> Optional[str]: if self.default_value: return props.get(self.name, self.default_value) return props.get(self.name) + def get_type(self, props: Properties, type_class: Type[T]) -> T: + value = props.get(self.name, self.default_value) if self.default_value else props.get(self.name) + if value is None: + if type_class == int: + return -1 # type: ignore + elif type_class == float: + return -1.0 # type: ignore + elif type_class == bool: + return False # type: ignore + else: + return None # type: ignore + if type_class == bool: + if isinstance(value, bool): + return value # type: ignore + return value.lower() == "true" if isinstance(value, str) else bool(value) # type: ignore + return type_class(value) # type: ignore + def get_or_default(self, props: Properties) -> str: if not self.default_value: raise ValueError(f"No default value found for property {self}") return props.get(self.name, self.default_value) def get_int(self, props: Properties) -> int: - if self.default_value: - return int(props.get(self.name, self.default_value)) - - val = props.get(self.name) - return int(val) if val else -1 + return self.get_type(props, int) def get_float(self, props: Properties) -> float: - if self.default_value: - return float(props.get(self.name, self.default_value)) - - val = props.get(self.name) - return float(val) if val else -1 + return self.get_type(props, float) def get_bool(self, props: Properties) -> bool: - if not self.default_value: - value = props.get(self.name) - else: - value = props.get(self.name, self.default_value) - if isinstance(value, bool): - return value - else: - return value is not None and value.lower() == "true" + return self.get_type(props, bool) def set(self, props: Properties, value: Any): props[self.name] = value diff --git a/tests/unit/test_properties_utils.py b/tests/unit/test_properties_utils.py index eae7ab26..3834352e 100644 --- a/tests/unit/test_properties_utils.py +++ b/tests/unit/test_properties_utils.py @@ -16,7 +16,8 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.properties import (Properties, - PropertiesUtils) + PropertiesUtils, + WrapperProperty) @pytest.mark.parametrize( @@ -93,3 +94,76 @@ def test_create_monitoring_properties(expected, test_props): props_copy = test_props.copy() props_copy = PropertiesUtils.create_monitoring_properties(props_copy) assert expected == props_copy + + +@pytest.mark.parametrize("expected, props, type_class", [ + # Int type tests + (123, Properties({"test_prop": "123"}), int), + (-1, Properties(), int), + (456, Properties({"test_prop": 456}), int), + + # Float type tests + (12.5, Properties({"test_prop": "12.5"}), float), + (-1.0, Properties(), float), + (3.14, Properties({"test_prop": 3.14}), float), + + # Bool type tests + (True, Properties({"test_prop": "true"}), bool), + (True, 
Properties({"test_prop": "TRUE"}), bool), + (False, Properties({"test_prop": "false"}), bool), + (True, Properties({"test_prop": True}), bool), + (False, Properties({"test_prop": False}), bool), + (False, Properties(), bool), + + # String type tests + ("test_value", Properties({"test_prop": "test_value"}), str), + (None, Properties(), str), + ("", Properties({"test_prop": ""}), str), +]) +def test_get_type(expected, props, type_class): + wrapper_prop = WrapperProperty("test_prop", "Test property") + result = wrapper_prop.get_type(props, type_class) + assert result == expected + + +def test_get_type_with_default(): + wrapper_prop = WrapperProperty("test_prop", "Test property", "default_value") + props = Properties() + result = wrapper_prop.get_type(props, str) + assert result == "default_value" + + +@pytest.mark.parametrize("expected, props", [ + (123, Properties({"test_prop": "123"})), + (-1, Properties()), + (456, Properties({"test_prop": 456})), +]) +def test_get_int(expected, props): + wrapper_prop = WrapperProperty("test_prop", "Test property") + result = wrapper_prop.get_int(props) + assert result == expected + + +@pytest.mark.parametrize("expected, props", [ + (12.5, Properties({"test_prop": "12.5"})), + (-1.0, Properties()), + (3.14, Properties({"test_prop": 3.14})), +]) +def test_get_float(expected, props): + wrapper_prop = WrapperProperty("test_prop", "Test property") + result = wrapper_prop.get_float(props) + assert result == expected + + +@pytest.mark.parametrize("expected, props", [ + (True, Properties({"test_prop": "true"})), + (True, Properties({"test_prop": "TRUE"})), + (False, Properties({"test_prop": "false"})), + (True, Properties({"test_prop": True})), + (False, Properties({"test_prop": False})), + (False, Properties()), +]) +def test_get_bool(expected, props): + wrapper_prop = WrapperProperty("test_prop", "Test property") + result = wrapper_prop.get_bool(props) + assert result == expected From ed6384eedd59c337984cb1ce649a47a2f50dbc09 Mon Sep 17 00:00:00 2001 From: Sophia Chu Date: Fri, 5 Dec 2025 15:51:46 -0800 Subject: [PATCH 3/3] refactor: parametrize unit tests --- .../read_write_splitting_plugin.py | 54 +- .../simple_read_write_splitting_plugin.py | 86 +- .../utils/properties.py | 7 +- .../unit/test_read_write_splitting_plugin.py | 750 ++++++++---------- 4 files changed, 372 insertions(+), 525 deletions(-) diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 125bbb5f..1baa1c49 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -50,6 +50,7 @@ class ReadWriteSplittingConnectionManager(Plugin): "Connection.set_read_only", } _POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider" + _CLOSE_METHOD = "Connection.close" def __init__( self, @@ -94,8 +95,7 @@ def connect( connect_func: Callable, ) -> Connection: return self._connection_handler.get_verified_initial_connection( - host_info, props, is_initial_connection, connect_func - ) + host_info, is_initial_connection, lambda x: self._plugin_service.connect(x, props, self), connect_func) def notify_connection_changed( self, changes: Set[ConnectionEvent] @@ -140,13 +140,13 @@ def execute( if isinstance(ex, FailoverError): logger.debug( "ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand", - method_name, + method_name ) self._close_idle_connections() else: logger.debug( 
"ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand", - method_name, + method_name ) raise ex @@ -184,13 +184,13 @@ def _set_reader_connection( ) def _initialize_writer_connection(self): - conn, writer_host = self._connection_handler.open_new_writer_connection() + conn, writer_host = self._connection_handler.open_new_writer_connection(lambda x: self._plugin_service.connect(x, self._properties, self)) if conn is None: self.log_and_raise_exception( "ReadWriteSplittingPlugin.FailedToConnectToWriter" ) - return + return None provider = self._conn_provider_manager.get_connection_provider( writer_host, self._properties @@ -335,9 +335,7 @@ def _switch_to_reader_connection(self): self._reader_host_info.url, ) - ReadWriteSplittingConnectionManager.close_connection( - self._reader_connection - ) + ReadWriteSplittingConnectionManager.close_connection(self._reader_connection, driver_dialect) self._reader_connection = None self._reader_host_info = None self._initialize_reader_connection() @@ -356,7 +354,7 @@ def _initialize_reader_connection(self): ) return - conn, reader_host = self._connection_handler.open_new_reader_connection() + conn, reader_host = self._connection_handler.open_new_reader_connection(lambda x: self._plugin_service.connect(x, self._properties, self)) if conn is None or reader_host is None: self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable") @@ -392,7 +390,7 @@ def _close_connection_if_idle(self, internal_conn: Optional[Connection]): if internal_conn != current_conn and self._is_connection_usable( internal_conn, driver_dialect ): - internal_conn.close() + driver_dialect.execute(ReadWriteSplittingConnectionManager._CLOSE_METHOD, lambda: internal_conn.close()) if internal_conn == self._writer_connection: self._writer_connection = None self._writer_host_info = None @@ -420,10 +418,8 @@ def log_and_raise_exception(log_msg: str): raise ReadWriteSplittingError(Messages.get(log_msg)) @staticmethod - def _is_connection_usable( - conn: Optional[Connection], driver_dialect: Optional[DriverDialect] - ): - if conn is None or driver_dialect is None: + def _is_connection_usable(conn: Optional[Connection], driver_dialect: DriverDialect): + if conn is None: return False try: return not driver_dialect.is_closed(conn) @@ -432,10 +428,10 @@ def _is_connection_usable( return False @staticmethod - def close_connection(connection: Optional[Connection]): - if connection is not None: + def close_connection(conn: Optional[Connection], driver_dialect: DriverDialect): + if conn is not None: try: - connection.close() + driver_dialect.execute(ReadWriteSplittingConnectionManager._CLOSE_METHOD, lambda: conn.close()) except Exception: # Swallow exception return @@ -456,12 +452,14 @@ def host_list_provider_service(self, new_value: int) -> None: def open_new_writer_connection( self, + plugin_service_connect_func: Callable[[HostInfo], Connection], ) -> tuple[Optional[Connection], Optional[HostInfo]]: """Open a writer connection.""" ... def open_new_reader_connection( self, + plugin_service_connect_func: Callable[[HostInfo], Connection], ) -> tuple[Optional[Connection], Optional[HostInfo]]: """Open a reader connection.""" ... 
@@ -469,8 +467,8 @@ def open_new_reader_connection( def get_verified_initial_connection( self, host_info: HostInfo, - props: Properties, is_initial_connection: bool, + plugin_service_connect_func: Callable[[HostInfo], Connection], connect_func: Callable, ) -> Connection: """Verify initial connection or return normal workflow.""" @@ -516,9 +514,8 @@ class TopologyBasedConnectionHandler(ConnectionHandler): def __init__(self, plugin_service: PluginService, props: Properties): self._plugin_service: PluginService = plugin_service - self._properties: Properties = props self._host_list_provider_service: Optional[HostListProviderService] = None - strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(self._properties) + strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props) if strategy is not None: self._reader_selector_strategy = strategy else: @@ -539,17 +536,19 @@ def host_list_provider_service(self, new_value: HostListProviderService) -> None def open_new_writer_connection( self, + plugin_service_connect_func: Callable[[HostInfo], Connection], ) -> tuple[Optional[Connection], Optional[HostInfo]]: writer_host = self._get_writer() if writer_host is None: return None, None - conn = self._plugin_service.connect(writer_host, self._properties, None) + conn = plugin_service_connect_func(writer_host) return conn, writer_host def open_new_reader_connection( self, + plugin_service_connect_func: Callable[[HostInfo], Connection], ) -> tuple[Optional[Connection], Optional[HostInfo]]: conn: Optional[Connection] = None reader_host: Optional[HostInfo] = None @@ -561,7 +560,7 @@ def open_new_reader_connection( ) if host is not None: try: - conn = self._plugin_service.connect(host, self._properties, None) + conn = plugin_service_connect_func(host) reader_host = host break except Exception: @@ -574,8 +573,8 @@ def open_new_reader_connection( def get_verified_initial_connection( self, host_info: HostInfo, - props: Properties, is_initial_connection: bool, + plugin_service_connect_func: Callable[[HostInfo], Connection], connect_func: Callable, ) -> Connection: if not self._plugin_service.accepts_strategy( @@ -670,12 +669,9 @@ def _get_writer(self) -> Optional[HostInfo]: class ReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager): - def __init__(self, plugin_service, props: Properties): + def __init__(self, plugin_service: PluginService, props: Properties): # The read/write splitting plugin handles connections based on topology. 
- connection_handler = TopologyBasedConnectionHandler( - plugin_service, - props, - ) + connection_handler = TopologyBasedConnectionHandler(plugin_service, props) super().__init__(plugin_service, props, connection_handler) diff --git a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py index 677a540b..43c81f73 100644 --- a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py @@ -15,7 +15,7 @@ from __future__ import annotations from time import perf_counter_ns, sleep -from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Type +from typing import TYPE_CHECKING, Callable, Optional, Type, TypeVar from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.read_write_splitting_plugin import ( @@ -67,15 +67,10 @@ def __init__(self, plugin_service: PluginService, props: Properties): ) self._plugin_service: PluginService = plugin_service - self._properties: Properties = props self._rds_utils: RdsUtils = RdsUtils() self._host_list_provider_service: Optional[HostListProviderService] = None - self._write_endpoint_host_info: HostInfo = self._create_host_info( - self._write_endpoint, HostRole.WRITER - ) - self._read_endpoint_host_info: HostInfo = self._create_host_info( - self._read_endpoint, HostRole.READER - ) + self._write_endpoint_host_info: HostInfo = self._create_host_info(self._write_endpoint, HostRole.WRITER) + self._read_endpoint_host_info: HostInfo = self._create_host_info(self._read_endpoint, HostRole.READER) @property def host_list_provider_service(self) -> Optional[HostListProviderService]: @@ -87,39 +82,29 @@ def host_list_provider_service(self, new_value: HostListProviderService) -> None def open_new_writer_connection( self, + plugin_service_connect_func: Callable[[HostInfo], Connection], ) -> tuple[Optional[Connection], Optional[HostInfo]]: - conn: Optional[Connection] = None if self._verify_new_connections: - conn = self._get_verified_connection( - self._properties, self._write_endpoint_host_info, HostRole.WRITER - ) - else: - conn = self._plugin_service.connect( - self._write_endpoint_host_info, self._properties, None - ) + return self._get_verified_connection(self._write_endpoint_host_info, HostRole.WRITER, plugin_service_connect_func), \ + self._write_endpoint_host_info - return conn, self._write_endpoint_host_info + return plugin_service_connect_func(self._write_endpoint_host_info), self._write_endpoint_host_info def open_new_reader_connection( self, + plugin_service_connect_func: Callable[[HostInfo], Connection], ) -> tuple[Optional[Connection], Optional[HostInfo]]: - conn: Optional[Connection] = None if self._verify_new_connections: - conn = self._get_verified_connection( - self._properties, self._read_endpoint_host_info, HostRole.READER - ) - else: - conn = self._plugin_service.connect( - self._read_endpoint_host_info, self._properties, None - ) + return self._get_verified_connection(self._read_endpoint_host_info, HostRole.READER, plugin_service_connect_func), \ + self._read_endpoint_host_info - return conn, self._read_endpoint_host_info + return plugin_service_connect_func(self._read_endpoint_host_info), self._read_endpoint_host_info def get_verified_initial_connection( self, host_info: HostInfo, - props: Properties, is_initial_connection: bool, + plugin_service_connect_func: Callable[[HostInfo], Connection], connect_func: Callable, ) -> Connection: if not 
is_initial_connection or not self._verify_new_connections: @@ -133,24 +118,20 @@ def get_verified_initial_connection( url_type == RdsUrlType.RDS_WRITER_CLUSTER or self._verify_opened_connection_type == HostRole.WRITER ): - conn = self._get_verified_connection( - props, host_info, HostRole.WRITER, connect_func - ) + conn = self._get_verified_connection(host_info, HostRole.WRITER, plugin_service_connect_func, connect_func) elif ( url_type == RdsUrlType.RDS_READER_CLUSTER or self._verify_opened_connection_type == HostRole.READER ): - conn = self._get_verified_connection( - props, host_info, HostRole.READER, connect_func - ) + conn = self._get_verified_connection(host_info, HostRole.READER, plugin_service_connect_func, connect_func) if conn is None: conn = connect_func() - self._set_initial_connection_host_info(conn, host_info) + self._set_initial_connection_host_info(host_info) return conn - def _set_initial_connection_host_info(self, conn: Connection, host_info: HostInfo): + def _set_initial_connection_host_info(self, host_info: HostInfo): if self._host_list_provider_service is None: return @@ -158,9 +139,9 @@ def _set_initial_connection_host_info(self, conn: Connection, host_info: HostInf def _get_verified_connection( self, - props: Properties, host_info: HostInfo, role: HostRole, + plugin_service_connect_func: Callable[[HostInfo], Connection], connect_func: Optional[Callable] = None, ) -> Optional[Connection]: end_time_nano = perf_counter_ns() + (self._connect_retry_timeout_ms * 1000000) @@ -174,9 +155,7 @@ def _get_verified_connection( if connect_func is not None: candidate_conn = connect_func() elif host_info is not None: - candidate_conn = self._plugin_service.connect( - host_info, props, None - ) + candidate_conn = plugin_service_connect_func(host_info) else: return None @@ -187,14 +166,14 @@ def _get_verified_connection( actual_role = self._plugin_service.get_host_role(candidate_conn) if actual_role != role: - ReadWriteSplittingConnectionManager.close_connection(candidate_conn) + ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect) self._delay() continue return candidate_conn except Exception: - ReadWriteSplittingConnectionManager.close_connection(candidate_conn) + ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect) self._delay() return None @@ -249,28 +228,18 @@ def is_reader_host(self, current_host: HostInfo) -> bool: or current_host.url.casefold() == self._read_endpoint.casefold() ) - def _create_host_info(self, endpoint, role: HostRole) -> HostInfo: + def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo: endpoint = endpoint.strip() host = endpoint - port = self._plugin_service.database_dialect.default_port + port = self._plugin_service.database_dialect.default_port if not self._plugin_service.current_host_info.is_port_specified() \ + else self._plugin_service.current_host_info.port colon_index = endpoint.rfind(":") if colon_index != -1: + host = endpoint[:colon_index] port_str = endpoint[colon_index + 1:] if port_str.isdigit(): - host = endpoint[:colon_index] port = int(port_str) - else: - if ( - self._host_list_provider_service is not None - and self._host_list_provider_service.initial_connection_host_info - is not None - and self._host_list_provider_service.initial_connection_host_info.port - != HostInfo.NO_PORT - ): - port = ( - self._host_list_provider_service.initial_connection_host_info.port - ) return HostInfo( host=host, port=port, role=role, 
availability=HostAvailability.AVAILABLE @@ -322,12 +291,9 @@ def _parse_connection_type(phase_str: Optional[str]) -> HostRole: class SimpleReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager): - def __init__(self, plugin_service, props: Properties): + def __init__(self, plugin_service: PluginService, props: Properties): # The simple read/write splitting plugin handles connections based on configuration parameter endpoints. - connection_handler = EndpointBasedConnectionHandler( - plugin_service, - props, - ) + connection_handler = EndpointBasedConnectionHandler(plugin_service, props) super().__init__(plugin_service, props, connection_handler) diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index d1e9c0d1..9c48172f 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -13,7 +13,7 @@ # limitations under the License. import copy -from typing import Any, Dict, Optional, TypeVar, Type +from typing import Any, Dict, Optional, Type, TypeVar from urllib.parse import unquote from aws_advanced_python_wrapper.errors import AwsWrapperError @@ -558,13 +558,14 @@ class WrapperProperties: SRW_VERIFY_NEW_CONNECTIONS = WrapperProperty( "srw_verify_new_connections", - "Enables role-verification for new connections made by the Simple Read/Write Splitting Plugin..", + "Enables role-verification for new connections made by the Simple Read/Write Splitting Plugin.", True, ) SRW_VERIFY_INITIAL_CONNECTION_TYPE = WrapperProperty( "srw_verify_initial_connection_type", - "Force to verify an initial connection to be either a writer or a reader.", + "The role of the initial connection. Valid values are 'reader' or 'writer'. If this value is set, " + + "the wrapper will verify whether the initial connection matches the specified type.", None, ) diff --git a/tests/unit/test_read_write_splitting_plugin.py b/tests/unit/test_read_write_splitting_plugin.py index f87a2817..d80f4689 100644 --- a/tests/unit/test_read_write_splitting_plugin.py +++ b/tests/unit/test_read_write_splitting_plugin.py @@ -36,7 +36,7 @@ # Common test data WRITE_ENDPOINT = "writer.cluster-xyz.us-east-1.rds.amazonaws.com" -READ_ENDPOINT = "reader.cluster-xyz.us-east-1.rds.amazonaws.com" +READ_ENDPOINT = "reader.cluster-ro-xyz.us-east-1.rds.amazonaws.com" TEST_PORT = 5432 writer_host = HostInfo(host="instance0", role=HostRole.WRITER) @@ -54,34 +54,25 @@ # Plugin configurations -@pytest.fixture( - params=[ - ("read_write_splitting", ReadWriteSplittingPlugin), - ("srw", SimpleReadWriteSplittingPlugin), - ] -) -def plugin_config(request): - return request.param +@pytest.fixture +def props(): + return Properties() @pytest.fixture -def plugin_props(plugin_config): - plugin_name, _ = plugin_config +def srw_props(): props = Properties() - if plugin_name == "srw": - props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT - props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT - props[WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name] = "600" - props[WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name] = "10" + props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT + props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT + props[WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name] = "600" + props[WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name] = "10" return props @pytest.fixture -def host_list_provider_service_mock(mocker, plugin_config): - plugin_name, _ = plugin_config +def 
host_list_provider_service_mock(mocker): mock = mocker.MagicMock() - if plugin_name == "srw": - mock.initial_connection_host_info = simple_writer_host + mock.initial_connection_host_info = simple_writer_host return mock @@ -120,6 +111,7 @@ def is_closed_side_effect(conn): driver_dialect_mock.get_connection_from_obj.return_value = writer_conn_mock driver_dialect_mock.unwrap_connection.return_value = writer_conn_mock driver_dialect_mock.can_execute_query.return_value = True + driver_dialect_mock.execute.side_effect = lambda method, func: func() return driver_dialect_mock @@ -137,214 +129,188 @@ def plugin_service_mock(mocker, driver_dialect_mock, writer_conn_mock): @pytest.fixture -def plugin_instance( - plugin_config, plugin_service_mock, plugin_props, host_list_provider_service_mock -): - _, plugin_class = plugin_config - plugin = plugin_class(plugin_service_mock, plugin_props) - plugin._connection_handler._host_list_provider_service = ( - host_list_provider_service_mock - ) +def read_write_splitting_plugin(plugin_service_mock, props, host_list_provider_service_mock): + plugin = ReadWriteSplittingPlugin(plugin_service_mock, props) + plugin._connection_handler._host_list_provider_service = host_list_provider_service_mock return plugin -# Common tests for both plugins -def test_set_read_only_true( - plugin_instance, plugin_service_mock, plugin_config, reader_conn_mock -): - plugin_name, _ = plugin_config - plugin_service_mock.current_connection = writer_conn_mock - plugin_service_mock.connect.return_value = reader_conn_mock +@pytest.fixture +def srw_plugin(plugin_service_mock, srw_props, host_list_provider_service_mock): + plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, srw_props) + plugin._connection_handler._host_list_provider_service = host_list_provider_service_mock - if plugin_name == "read_write_splitting": - plugin_service_mock.current_host_info = writer_host - plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 - plugin_service_mock.hosts = single_reader_topology - plugin_instance._reader_connection = None + return plugin - plugin_instance._switch_connection_if_required(True) - plugin_service_mock.set_current_connection.assert_called_once_with( - reader_conn_mock, reader_host1 - ) - assert plugin_instance._reader_connection == reader_conn_mock - else: - plugin_service_mock.current_host_info = simple_writer_host - plugin_service_mock.get_host_role.side_effect = lambda conn: ( - HostRole.READER if conn == reader_conn_mock else HostRole.WRITER - ) - plugin_instance._switch_connection_if_required(True) - plugin_service_mock.set_current_connection.assert_called_with( - reader_conn_mock, any_host - ) +# Tests for both plugins +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_set_read_only_false_in_transaction(request, plugin_service_mock, reader_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) - -def test_set_read_only_false( - plugin_instance, - plugin_service_mock, - plugin_config, - reader_conn_mock, - writer_conn_mock, -): - plugin_name, _ = plugin_config plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.connect.return_value = writer_conn_mock - - if plugin_name == "read_write_splitting": - plugin_service_mock.current_host_info = reader_host1 - plugin_service_mock.hosts = single_reader_topology - plugin_instance._writer_host_info = writer_host + plugin_service_mock.is_in_transaction = True + plugin_service_mock.current_host_info = 
simple_reader_host - plugin_instance._switch_connection_if_required(False) - plugin_service_mock.set_current_connection.assert_called_once_with( - writer_conn_mock, writer_host - ) - assert plugin_instance._writer_connection == writer_conn_mock - else: - plugin_service_mock.current_host_info = simple_reader_host - plugin_service_mock.get_host_role.side_effect = lambda conn: ( - HostRole.READER if conn == reader_conn_mock else HostRole.WRITER - ) + with pytest.raises(ReadWriteSplittingError): + plugin._switch_connection_if_required(False) - plugin_instance._switch_connection_if_required(False) - plugin_service_mock.set_current_connection.assert_called_with( - writer_conn_mock, any_host - ) +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_set_read_only_true_in_transaction_already_on_reader(request, plugin_service_mock, reader_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) -def test_set_read_only_true_already_on_reader( - plugin_instance, plugin_service_mock, plugin_config, reader_conn_mock -): - plugin_name, _ = plugin_config plugin_service_mock.current_connection = reader_conn_mock - plugin_instance._reader_connection = reader_conn_mock - - if plugin_name == "read_write_splitting": - plugin_service_mock.current_host_info = reader_host1 - else: - plugin_service_mock.current_host_info = simple_reader_host + plugin_service_mock.is_in_transaction = True + plugin_service_mock.current_host_info = simple_reader_host + plugin._writer_connection = None + plugin._reader_connection = reader_conn_mock - plugin_instance._switch_connection_if_required(True) + plugin._switch_connection_if_required(True) plugin_service_mock.set_current_connection.assert_not_called() + assert plugin._reader_connection == reader_conn_mock + assert plugin._writer_connection is None -def test_set_read_only_false_already_on_writer( - plugin_instance, plugin_service_mock, plugin_config, writer_conn_mock -): - plugin_name, _ = plugin_config - plugin_service_mock.current_connection = writer_conn_mock - plugin_instance._writer_connection = writer_conn_mock +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_set_read_only_on_closed_connection(request, plugin_service_mock, closed_writer_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) - if plugin_name == "read_write_splitting": - plugin_service_mock.current_host_info = writer_host - plugin_instance._writer_host_info = writer_host - else: - plugin_service_mock.current_host_info = simple_writer_host + plugin_service_mock.current_connection = closed_writer_conn_mock + plugin._writer_connection = closed_writer_conn_mock + plugin._reader_connection = None + + with pytest.raises(ReadWriteSplittingError): + plugin._switch_connection_if_required(True) - plugin_instance._switch_connection_if_required(False) plugin_service_mock.set_current_connection.assert_not_called() + assert plugin._reader_connection is None -def test_set_read_only_false_in_transaction( - plugin_instance, plugin_service_mock, reader_conn_mock -): - plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.is_in_transaction = True - plugin_service_mock.current_host_info = simple_reader_host +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_notify_connection_change(request, plugin_service_mock, writer_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) + + 
plugin._in_read_write_split = False + plugin_service_mock.current_connection = writer_conn_mock + plugin_service_mock.current_host_info = simple_writer_host + + suggestion = plugin.notify_connection_changed(set()) + assert suggestion == OldConnectionSuggestedAction.NO_OPINION + assert plugin._writer_connection == writer_conn_mock + + plugin._writer_connection = None + plugin._in_read_write_split = True + suggestion = plugin.notify_connection_changed(set()) + assert suggestion == OldConnectionSuggestedAction.PRESERVE + assert plugin._writer_connection == writer_conn_mock - with pytest.raises(ReadWriteSplittingError): - plugin_instance._switch_connection_if_required(False) +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_set_read_only_false_writer_connection_fails(request, plugin_service_mock, reader_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) -def test_set_read_only_true_in_transaction_already_on_reader( - plugin_instance, plugin_service_mock, reader_conn_mock -): + def connect_side_effect(host_info: HostInfo, props: Properties, plugin): + if ( + host_info == writer_host or host_info.host == WRITE_ENDPOINT + ) and props == Properties(): + raise Error("Connection Error") + + plugin_service_mock.connect.side_effect = connect_side_effect plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.is_in_transaction = True - plugin_service_mock.current_host_info = simple_reader_host - plugin_instance._writer_connection = None - plugin_instance._reader_connection = reader_conn_mock + plugin_service_mock.current_host_info = reader_host1 + plugin_service_mock.hosts = single_reader_topology + plugin._writer_connection = None + plugin._reader_connection = reader_conn_mock + + with pytest.raises(Error): + plugin._switch_connection_if_required(False) - plugin_instance._switch_connection_if_required(True) plugin_service_mock.set_current_connection.assert_not_called() - assert plugin_instance._reader_connection == reader_conn_mock - assert plugin_instance._writer_connection is None -def test_set_read_only_on_closed_connection( - plugin_instance, plugin_service_mock, closed_writer_conn_mock -): - plugin_service_mock.current_connection = closed_writer_conn_mock - plugin_instance._writer_connection = closed_writer_conn_mock - plugin_instance._reader_connection = None +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_set_read_only_true_reader_connection_failed(request, plugin_service_mock, writer_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) - with pytest.raises(ReadWriteSplittingError): - plugin_instance._switch_connection_if_required(True) + def connect_side_effect(host_info: HostInfo, props: Properties, plugin): + if ( + host_info == reader_host1 + or host_info == reader_host2 + or host_info == reader_host3 + ) and props == Properties(): + raise Error("Connection Error") + + plugin_service_mock.connect.side_effect = connect_side_effect + plugin._writer_connection = writer_conn_mock + plugin._writer_host_info = writer_host + plugin._reader_connection = None + + plugin._switch_connection_if_required(True) plugin_service_mock.set_current_connection.assert_not_called() - assert plugin_instance._reader_connection is None + assert plugin._reader_connection is None -def test_notify_connection_change( - plugin_instance, plugin_service_mock, writer_conn_mock -): - plugin_instance._in_read_write_split = False - 
plugin_service_mock.current_connection = writer_conn_mock - plugin_service_mock.current_host_info = simple_writer_host +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_execute_failover_to_new_writer(request, plugin_service_mock, writer_conn_mock, new_writer_conn_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) - suggestion = plugin_instance.notify_connection_changed(set()) - assert suggestion == OldConnectionSuggestedAction.NO_OPINION - assert plugin_instance._writer_connection == writer_conn_mock + def execute_func(): + raise FailoverSuccessError - plugin_instance._writer_connection = None - plugin_instance._in_read_write_split = True - suggestion = plugin_instance.notify_connection_changed(set()) - assert suggestion == OldConnectionSuggestedAction.PRESERVE - assert plugin_instance._writer_connection == writer_conn_mock + plugin_service_mock.current_connection = new_writer_conn_mock + plugin._writer_connection = writer_conn_mock + plugin._writer_host_info = writer_host + plugin._reader_connection = None + with pytest.raises(Error): + plugin.execute(None, "Statement.execute_query", execute_func) -def test_connect_non_initial_connection( - plugin_instance, plugin_config, connect_func_mock, writer_conn_mock, mocker -): - plugin_name, _ = plugin_config - connect_func_mock.return_value = writer_conn_mock + writer_conn_mock.close.assert_called_once() - if plugin_name == "read_write_splitting": - plugin_instance._writer_connection = writer_conn_mock - plugin_instance._writer_host_info = writer_host - plugin_instance._reader_connection = None - conn = plugin_instance.connect( - mocker.MagicMock(), - mocker.MagicMock(), - writer_host, - Properties(), - False, - connect_func_mock, - ) +@pytest.mark.parametrize("plugin_fixture", ["read_write_splitting_plugin", "srw_plugin"]) +def test_connect_incorrect_host_role( + request, plugin_service_mock, mocker, connect_func_mock, reader_conn_mock, host_list_provider_service_mock, plugin_fixture): + plugin = request.getfixturevalue(plugin_fixture) + reader_host_incorrect_role = HostInfo(host="instance-4", role=HostRole.WRITER) - assert conn == writer_conn_mock - connect_func_mock.assert_called() - else: - result = plugin_instance.connect( - None, None, simple_writer_host, Properties(), False, connect_func_mock - ) + def get_host_role_side_effect(conn): + if conn == reader_conn_mock: + return HostRole.READER + return HostRole.WRITER - assert result == writer_conn_mock - connect_func_mock.assert_called_once() + plugin_service_mock.get_host_role.side_effect = get_host_role_side_effect + plugin_service_mock.initial_connection_host_info = reader_host_incorrect_role + host_list_provider_service_mock.is_static_host_list_provider.return_value = False + connect_func_mock.return_value = reader_conn_mock + conn = plugin.connect( + mocker.MagicMock(), + mocker.MagicMock(), + writer_host, + Properties(), + True, + connect_func_mock, + ) -def test_close_pooled_reader_connection_after_set_read_only( - plugin_props, - plugin_service_mock, - plugin_config, - mocker, - reader_conn_mock, - writer_conn_mock, -): - plugin_name, plugin_class = plugin_config + assert conn == reader_conn_mock + connect_func_mock.assert_called() + updated_host = host_list_provider_service_mock.initial_connection_host_info + if plugin_fixture == "read_write_splitting_plugin": + assert updated_host.host == reader_host_incorrect_role.host + assert updated_host.role != reader_host_incorrect_role.role + assert updated_host.role 
== HostRole.READER + else: + assert updated_host == writer_host + + +@pytest.mark.parametrize("plugin_type", ["read_write_splitting_plugin", "srw_plugin"]) +def test_close_pooled_reader_connection_after_set_read_only( + props, srw_props, plugin_service_mock, mocker, reader_conn_mock, writer_conn_mock, plugin_type): def connect_side_effect(host: HostInfo, props, plugin): if ( host in [reader_host1, reader_host2, reader_host3] @@ -360,13 +326,6 @@ def connect_side_effect(host: HostInfo, props, plugin): side_effect=[writer_host, writer_host, reader_host1] ) - if plugin_name == "read_write_splitting": - plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 - else: - plugin_service_mock.get_host_role.side_effect = lambda conn: ( - HostRole.READER if conn == reader_conn_mock else HostRole.WRITER - ) - provider = SqlAlchemyPooledConnectionProvider( lambda _, __: {"pool_size": 3}, None, 180000000000, 600000000000 # 3 minutes ) # 10 minutes @@ -377,7 +336,14 @@ def connect_side_effect(host: HostInfo, props, plugin): conn_provider_manager_mock ) - plugin = plugin_class(plugin_service_mock, plugin_props) + if plugin_type == "read_write_splitting_plugin": + plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + plugin = ReadWriteSplittingPlugin(plugin_service_mock, props) + else: + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, srw_props) spy = mocker.spy(plugin, "_close_connection_if_idle") @@ -388,16 +354,9 @@ def connect_side_effect(host: HostInfo, props, plugin): assert spy.call_count == 1 +@pytest.mark.parametrize("plugin_type", ["read_write_splitting_plugin", "srw_plugin"]) def test_close_pooled_writer_connection_after_set_read_only( - plugin_service_mock, - plugin_config, - plugin_props, - mocker, - reader_conn_mock, - writer_conn_mock, -): - plugin_name, plugin_class = plugin_config - + plugin_service_mock, props, srw_props, mocker, reader_conn_mock, writer_conn_mock, plugin_type): def connect_side_effect(host: HostInfo, props, plugin): if ( host in [reader_host1, reader_host2, reader_host3] @@ -413,13 +372,6 @@ def connect_side_effect(host: HostInfo, props, plugin): side_effect=[writer_host, writer_host, reader_host1, reader_host1, writer_host] ) - if plugin_name == "read_write_splitting": - plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 - else: - plugin_service_mock.get_host_role.side_effect = lambda conn: ( - HostRole.READER if conn == reader_conn_mock else HostRole.WRITER - ) - provider = SqlAlchemyPooledConnectionProvider( lambda _, __: {"pool_size": 3}, None, 180000000000, 600000000000 # 3 minutes ) # 10 minutes @@ -430,7 +382,14 @@ def connect_side_effect(host: HostInfo, props, plugin): conn_provider_manager_mock ) - plugin = plugin_class(plugin_service_mock, plugin_props) + if plugin_type == "read_write_splitting_plugin": + plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 + plugin = ReadWriteSplittingPlugin(plugin_service_mock, props) + else: + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + plugin = SimpleReadWriteSplittingPlugin(plugin_service_mock, srw_props) spy = mocker.spy(plugin, "_close_connection_if_idle") @@ -442,142 +401,96 @@ def connect_side_effect(host: HostInfo, props, plugin): assert spy.call_count == 2 -def 
test_set_read_only_false_writer_connection_fails( - plugin_instance, plugin_service_mock, reader_conn_mock -): - def connect_side_effect(host_info: HostInfo, props: Properties, plugin): - if ( - host_info == writer_host or host_info.host == WRITE_ENDPOINT - ) and props == Properties(): - raise Error("Connection Error") +# Tests for the Read/Write Splitting Plugin +def test_set_read_only_true_read_write_splitting(read_write_splitting_plugin, plugin_service_mock, reader_conn_mock): + plugin_service_mock.current_connection = writer_conn_mock + plugin_service_mock.connect.return_value = reader_conn_mock - plugin_service_mock.connect.side_effect = connect_side_effect - plugin_service_mock.current_connection = reader_conn_mock - plugin_service_mock.current_host_info = reader_host1 + plugin_service_mock.current_host_info = writer_host + plugin_service_mock.get_host_info_by_strategy.return_value = reader_host1 plugin_service_mock.hosts = single_reader_topology - plugin_instance._writer_connection = None - plugin_instance._reader_connection = reader_conn_mock - - with pytest.raises(Error): - plugin_instance._switch_connection_if_required(False) - - plugin_service_mock.set_current_connection.assert_not_called() + read_write_splitting_plugin._reader_connection = None + read_write_splitting_plugin._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_called_once_with( + reader_conn_mock, reader_host1 + ) + assert read_write_splitting_plugin._reader_connection == reader_conn_mock -def test_set_read_only_true_reader_connection_failed( - plugin_instance, plugin_service_mock, writer_conn_mock -): - def connect_side_effect(host_info: HostInfo, props: Properties, plugin): - if ( - host_info == reader_host1 - or host_info == reader_host2 - or host_info == reader_host3 - ) and props == Properties(): - raise Error("Connection Error") - plugin_service_mock.connect.side_effect = connect_side_effect - plugin_instance._writer_connection = writer_conn_mock - plugin_instance._writer_host_info = writer_host - plugin_instance._reader_connection = None +def test_set_read_only_false_read_write_splitting( + read_write_splitting_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock,): + plugin_service_mock.current_connection = reader_conn_mock + plugin_service_mock.connect.return_value = writer_conn_mock - plugin_instance._switch_connection_if_required(True) + plugin_service_mock.current_host_info = reader_host1 + plugin_service_mock.hosts = single_reader_topology + read_write_splitting_plugin._writer_host_info = writer_host - plugin_service_mock.set_current_connection.assert_not_called() - assert plugin_instance._reader_connection is None + read_write_splitting_plugin._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_called_once_with( + writer_conn_mock, writer_host + ) + assert read_write_splitting_plugin._writer_connection == writer_conn_mock -def test_execute_failover_to_new_writer( - plugin_instance, plugin_service_mock, writer_conn_mock, new_writer_conn_mock -): - def execute_func(): - raise FailoverSuccessError +def test_set_read_only_true_already_on_reader_read_write_splitting( + read_write_splitting_plugin, plugin_service_mock, reader_conn_mock): + plugin_service_mock.current_connection = reader_conn_mock + read_write_splitting_plugin._reader_connection = reader_conn_mock + plugin_service_mock.current_host_info = reader_host1 - plugin_service_mock.current_connection = new_writer_conn_mock - plugin_instance._writer_connection = 
writer_conn_mock - plugin_instance._writer_host_info = writer_host - plugin_instance._reader_connection = None + read_write_splitting_plugin._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_not_called() - with pytest.raises(Error): - plugin_instance.execute(None, "Statement.execute_query", execute_func) - writer_conn_mock.close.assert_called_once() +def test_set_read_only_false_already_on_writer_read_write_splitting( + read_write_splitting_plugin, plugin_service_mock, writer_conn_mock): + plugin_service_mock.current_connection = writer_conn_mock + read_write_splitting_plugin._writer_connection = writer_conn_mock + plugin_service_mock.current_host_info = writer_host + read_write_splitting_plugin._writer_host_info = writer_host + read_write_splitting_plugin._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_not_called() -def test_connect_incorrect_host_role( - plugin_instance, - plugin_service_mock, - plugin_config, - mocker, - connect_func_mock, - reader_conn_mock, - host_list_provider_service_mock, -): - plugin_name, _ = plugin_config - reader_host_incorrect_role = HostInfo(host="instance-4", role=HostRole.WRITER) - def get_host_role_side_effect(conn): - if conn == reader_conn_mock: - return HostRole.READER - return HostRole.WRITER +def test_connect_non_initial_connection_read_write_splitting( + read_write_splitting_plugin, connect_func_mock, writer_conn_mock, mocker): + connect_func_mock.return_value = writer_conn_mock - plugin_service_mock.get_host_role.side_effect = get_host_role_side_effect - plugin_service_mock.initial_connection_host_info = reader_host_incorrect_role - host_list_provider_service_mock.is_static_host_list_provider.return_value = False + read_write_splitting_plugin._writer_connection = writer_conn_mock + read_write_splitting_plugin._writer_host_info = writer_host + read_write_splitting_plugin._reader_connection = None - connect_func_mock.return_value = reader_conn_mock - conn = plugin_instance.connect( + conn = read_write_splitting_plugin.connect( mocker.MagicMock(), mocker.MagicMock(), writer_host, Properties(), - True, + False, connect_func_mock, ) - assert conn == reader_conn_mock + assert conn == writer_conn_mock connect_func_mock.assert_called() - updated_host = host_list_provider_service_mock.initial_connection_host_info - if plugin_name == "read_write_splitting": - assert updated_host.host == reader_host_incorrect_role.host - assert updated_host.role != reader_host_incorrect_role.role - assert updated_host.role == HostRole.READER - else: - assert updated_host == writer_host - - -# Tests for the Read/Write Splitting Plugin -def test_set_read_only_true_one_host( - plugin_service_mock, plugin_config, plugin_instance -): - plugin_name, _ = plugin_config - if plugin_name != "read_write_splitting": - pytest.skip("Test only applies to Read/Write Splitting Plugin") +def test_set_read_only_true_one_host_read_write_splitting(plugin_service_mock, read_write_splitting_plugin): plugin_service_mock.hosts = [writer_host] - plugin_instance._writer_connection = writer_conn_mock - plugin_instance._writer_host_info = writer_host + read_write_splitting_plugin._writer_connection = writer_conn_mock + read_write_splitting_plugin._writer_host_info = writer_host - plugin_instance._switch_connection_if_required(True) + read_write_splitting_plugin._switch_connection_if_required(True) plugin_service_mock.set_current_connection.assert_not_called() - assert plugin_instance._writer_connection == writer_conn_mock - 
assert plugin_instance._reader_connection is None - - -def test_connect_error_updating_host( - plugin_service_mock, - plugin_config, - plugin_instance, - host_list_provider_service_mock, - connect_func_mock, - mocker, -): - plugin_name, _ = plugin_config - if plugin_name != "read_write_splitting": - pytest.skip("Test only applies to Read/Write Splitting Plugin") + assert read_write_splitting_plugin._writer_connection == writer_conn_mock + assert read_write_splitting_plugin._reader_connection is None + +def test_connect_error_updating_host_read_write_splitting( + plugin_service_mock, read_write_splitting_plugin, host_list_provider_service_mock, connect_func_mock, mocker): def get_host_role_side_effect(conn): if conn == reader_conn_mock: return None @@ -589,7 +502,7 @@ def get_host_role_side_effect(conn): connect_func_mock.return_value = reader_conn_mock with pytest.raises(Error): - plugin_instance.connect( + read_write_splitting_plugin.connect( mocker.MagicMock(), mocker.MagicMock(), writer_host, @@ -598,15 +511,76 @@ def get_host_role_side_effect(conn): connect_func_mock, ) - host_list_provider_service_mock.initial_connection_host_info.assert_not_called() + # Verify initial_connection_host_info wasn't modified + assert host_list_provider_service_mock.initial_connection_host_info == simple_writer_host # Tests for the Simple Read/Write Splitting Plugin -def test_constructor_missing_write_endpoint(plugin_service_mock, plugin_config): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") +def test_set_read_only_true_srw(srw_plugin, plugin_service_mock, reader_conn_mock): + plugin_service_mock.current_connection = writer_conn_mock + plugin_service_mock.connect.return_value = reader_conn_mock + + plugin_service_mock.current_host_info = simple_writer_host + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + + srw_plugin._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_called_with( + reader_conn_mock, any_host + ) + + assert srw_plugin._reader_connection == reader_conn_mock + + +def test_set_read_only_false_srw( + srw_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock,): + plugin_service_mock.current_connection = reader_conn_mock + plugin_service_mock.connect.return_value = writer_conn_mock + + plugin_service_mock.current_host_info = simple_reader_host + plugin_service_mock.get_host_role.side_effect = lambda conn: ( + HostRole.READER if conn == reader_conn_mock else HostRole.WRITER + ) + + srw_plugin._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_called_with( + writer_conn_mock, any_host + ) + assert srw_plugin._writer_connection == writer_conn_mock + + +def test_set_read_only_true_already_on_reader_srw( + srw_plugin, plugin_service_mock, reader_conn_mock): + plugin_service_mock.current_connection = reader_conn_mock + srw_plugin._reader_connection = reader_conn_mock + plugin_service_mock.current_host_info = simple_reader_host + srw_plugin._switch_connection_if_required(True) + plugin_service_mock.set_current_connection.assert_not_called() + + +def test_set_read_only_false_already_on_writer_srw(srw_plugin, plugin_service_mock, writer_conn_mock): + plugin_service_mock.current_connection = writer_conn_mock + srw_plugin._writer_connection = writer_conn_mock + plugin_service_mock.current_host_info = simple_writer_host + + 
srw_plugin._switch_connection_if_required(False) + plugin_service_mock.set_current_connection.assert_not_called() + + +def test_connect_non_initial_connection_srw(srw_plugin, connect_func_mock, writer_conn_mock, mocker): + connect_func_mock.return_value = writer_conn_mock + + result = srw_plugin.connect( + None, None, simple_writer_host, Properties(), False, connect_func_mock + ) + + assert result == writer_conn_mock + connect_func_mock.assert_called_once() + + +def test_constructor_missing_write_endpoint_srw(plugin_service_mock): props = Properties() props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT # Missing write endpoint @@ -615,11 +589,7 @@ def test_constructor_missing_write_endpoint(plugin_service_mock, plugin_config): SimpleReadWriteSplittingPlugin(plugin_service_mock, props) -def test_constructor_missing_read_endpoint(plugin_service_mock, plugin_config): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_constructor_missing_read_endpoint_srw(plugin_service_mock): props = Properties() props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT # Missing read endpoint @@ -628,13 +598,7 @@ def test_constructor_missing_read_endpoint(plugin_service_mock, plugin_config): SimpleReadWriteSplittingPlugin(plugin_service_mock, props) -def test_constructor_invalid_initial_connection_type( - plugin_service_mock, plugin_config -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_constructor_invalid_initial_connection_type_srw(plugin_service_mock): props = Properties() props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT @@ -646,13 +610,7 @@ def test_constructor_invalid_initial_connection_type( SimpleReadWriteSplittingPlugin(plugin_service_mock, props) -def test_connect_verification_disabled( - plugin_service_mock, plugin_config, connect_func_mock, writer_conn_mock -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_connect_verification_disabled_srw(plugin_service_mock, connect_func_mock, writer_conn_mock): props = Properties() props[WrapperProperties.SRW_WRITE_ENDPOINT.name] = WRITE_ENDPOINT props[WrapperProperties.SRW_READ_ENDPOINT.name] = READ_ENDPOINT @@ -670,18 +628,8 @@ def test_connect_verification_disabled( connect_func_mock.assert_called_once() -def test_connect_writer_cluster_endpoint( - plugin_config, - plugin_instance, - plugin_service_mock, - plugin_props, - connect_func_mock, - writer_conn_mock, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_connect_writer_cluster_endpoint_srw( + srw_plugin, plugin_service_mock, srw_props, connect_func_mock, writer_conn_mock): writer_cluster_host = HostInfo( host="test-cluster.cluster-xyz.us-east-1.rds.amazonaws.com", port=TEST_PORT, @@ -691,8 +639,8 @@ def test_connect_writer_cluster_endpoint( connect_func_mock.return_value = writer_conn_mock plugin_service_mock.get_host_role.return_value = HostRole.WRITER - result = plugin_instance.connect( - None, None, writer_cluster_host, plugin_props, True, connect_func_mock + result = srw_plugin.connect( + None, None, writer_cluster_host, srw_props, True, connect_func_mock ) assert result == writer_conn_mock @@ -700,18 +648,8 @@ def 
test_connect_writer_cluster_endpoint( assert plugin_service_mock.get_host_role.call_count == 1 -def test_connect_reader_cluster_endpoint( - plugin_config, - plugin_instance, - plugin_props, - plugin_service_mock, - connect_func_mock, - reader_conn_mock, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_connect_reader_cluster_endpoint_srw( + srw_plugin, srw_props, plugin_service_mock, connect_func_mock, reader_conn_mock): reader_cluster_host = HostInfo( host="test-cluster.cluster-ro-xyz.us-east-1.rds.amazonaws.com", port=TEST_PORT, @@ -721,8 +659,8 @@ def test_connect_reader_cluster_endpoint( connect_func_mock.return_value = reader_conn_mock plugin_service_mock.get_host_role.return_value = HostRole.READER - result = plugin_instance.connect( - None, None, reader_cluster_host, plugin_props, True, connect_func_mock + result = srw_plugin.connect( + None, None, reader_cluster_host, srw_props, True, connect_func_mock ) assert result == reader_conn_mock @@ -730,17 +668,8 @@ def test_connect_reader_cluster_endpoint( assert plugin_service_mock.get_host_role.call_count == 1 -def test_connect_verification_fails_fallback( - plugin_config, - plugin_service_mock, - connect_func_mock, - writer_conn_mock, - host_list_provider_service_mock, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_connect_verification_fails_fallback_srw( + plugin_service_mock, connect_func_mock, writer_conn_mock, host_list_provider_service_mock): writer_cluster_host = HostInfo( host="test-cluster.cluster-xyz.us-east-1.rds.amazonaws.com", port=TEST_PORT, @@ -772,26 +701,16 @@ def test_connect_verification_fails_fallback( assert plugin_service_mock.get_host_role.call_count == 1 -def test_connect_non_rds_cluster_endpoint( - plugin_config, - plugin_instance, - plugin_service_mock, - plugin_props, - connect_func_mock, - writer_conn_mock, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_connect_non_rds_cluster_endpoint_srw( + srw_plugin, plugin_service_mock, srw_props, connect_func_mock, writer_conn_mock): custom_host = HostInfo( host="custom-db.example.com", port=TEST_PORT, role=HostRole.WRITER ) connect_func_mock.return_value = writer_conn_mock - result = plugin_instance.connect( - None, None, custom_host, plugin_props, True, connect_func_mock + result = srw_plugin.connect( + None, None, custom_host, srw_props, True, connect_func_mock ) assert result == writer_conn_mock @@ -799,18 +718,8 @@ def test_connect_non_rds_cluster_endpoint( assert plugin_service_mock.get_host_role.call_count == 0 -def test_connect_non_rds_cluster_endpoint_with_verification( - plugin_config, - plugin_service_mock, - plugin_props, - connect_func_mock, - writer_conn_mock, - mocker, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_connect_non_rds_cluster_endpoint_with_verification_srw( + plugin_service_mock, connect_func_mock, writer_conn_mock, mocker,): custom_host = HostInfo( host="custom-db.example.com", port=TEST_PORT, role=HostRole.WRITER ) @@ -834,7 +743,7 @@ def test_connect_non_rds_cluster_endpoint_with_verification( ) result = plugin.connect( - None, None, custom_host, plugin_props, True, connect_func_mock + None, None, custom_host, props, True, connect_func_mock ) 
assert result == writer_conn_mock @@ -842,13 +751,8 @@ def test_connect_non_rds_cluster_endpoint_with_verification( assert plugin_service_mock.get_host_role.call_count == 2 -def test_wrong_role_connection_writer_endpoint_to_reader( - plugin_service_mock, reader_conn_mock, plugin_config, plugin_instance -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_wrong_role_connection_writer_endpoint_to_reader_srw( + plugin_service_mock, reader_conn_mock, srw_plugin): plugin_service_mock.current_connection = reader_conn_mock plugin_service_mock.current_host_info = simple_reader_host plugin_service_mock.connect.return_value = reader_conn_mock @@ -857,20 +761,10 @@ def test_wrong_role_connection_writer_endpoint_to_reader( ) # Wrong role for writer with pytest.raises(ReadWriteSplittingError): - plugin_instance._switch_connection_if_required(False) + srw_plugin._switch_connection_if_required(False) -def test_get_verified_connection_wrong_role_retry_reader( - plugin_config, - plugin_instance, - plugin_service_mock, - reader_conn_mock, - writer_conn_mock, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_get_verified_connection_wrong_role_retry_reader_srw(srw_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock,): plugin_service_mock.current_connection = writer_conn_mock plugin_service_mock.current_host_info = simple_writer_host @@ -880,23 +774,13 @@ def test_get_verified_connection_wrong_role_retry_reader( HostRole.READER if conn == reader_conn_mock else HostRole.WRITER ) - plugin_instance._switch_connection_if_required(True) + srw_plugin._switch_connection_if_required(True) assert plugin_service_mock.connect.call_count == 2 writer_conn_mock.close.assert_called_once() -def test_get_verified_connection_sql_exception_retry( - plugin_config, - plugin_instance, - plugin_service_mock, - reader_conn_mock, - writer_conn_mock, -): - plugin_name, _ = plugin_config - if plugin_name != "srw": - pytest.skip("Test only applies to Simple Read/Write Splitting Plugin") - +def test_get_verified_connection_sql_exception_retry_srw(srw_plugin, plugin_service_mock, reader_conn_mock, writer_conn_mock): plugin_service_mock.current_connection = writer_conn_mock plugin_service_mock.current_host_info = simple_writer_host @@ -907,7 +791,7 @@ def test_get_verified_connection_sql_exception_retry( ] plugin_service_mock.get_host_role.return_value = HostRole.READER - plugin_instance._switch_connection_if_required(True) + srw_plugin._switch_connection_if_required(True) assert plugin_service_mock.connect.call_count == 2 - assert plugin_instance._reader_connection == reader_conn_mock + assert srw_plugin._reader_connection == reader_conn_mock
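
Usage sketch (hedged, for reviewers): a minimal example of how the simple read/write splitting plugin introduced by this patch could be enabled through connection properties. This is an illustration only, not part of the patch. The plugin code "srw", the endpoint property names "srw_write_endpoint" / "srw_read_endpoint", the psycopg-based AwsWrapperConnection.connect call, and the read_only toggle are assumptions based on this patch and the wrapper's documented connect API; adjust names to whatever the released configuration actually exposes.

    import psycopg
    from aws_advanced_python_wrapper import AwsWrapperConnection

    # Assumed plugin code and property names; endpoints below are placeholders.
    with AwsWrapperConnection.connect(
        psycopg.Connection.connect,
        host="writer.cluster-xyz.us-east-1.rds.amazonaws.com",
        dbname="postgres",
        user="user",
        password="password",
        plugins="srw",  # assumed registration code for the simple read/write splitting plugin
        srw_write_endpoint="writer.cluster-xyz.us-east-1.rds.amazonaws.com",
        srw_read_endpoint="reader.cluster-ro-xyz.us-east-1.rds.amazonaws.com",
        srw_verify_new_connections=True,
        autocommit=True,  # avoid switching roles while a transaction is open
    ) as conn:
        with conn.cursor() as cursor:
            cursor.execute("SELECT 1")  # executed against the writer endpoint
        conn.read_only = True           # set_read_only routes subsequent work to the reader endpoint
        with conn.cursor() as cursor:
            cursor.execute("SELECT 1")  # executed against the reader endpoint

If the wrapper is used this way, switching read_only while a transaction is in progress is expected to raise ReadWriteSplittingError, matching the parametrized in-transaction tests above.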