diff --git a/redis/connection.py b/redis/connection.py index c9a3221b0b..e8dc39a0d6 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -319,6 +319,8 @@ def __init__( oss_cluster_maint_notifications_handler, parser, ) + self._processed_start_maint_notifications = set() + self._skipped_end_maint_notifications = set() @abstractmethod def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]: @@ -667,6 +669,22 @@ def maintenance_state(self) -> MaintenanceState: def maintenance_state(self, state: "MaintenanceState"): self._maintenance_state = state + def add_maint_start_notification(self, id: int): + self._processed_start_maint_notifications.add(id) + + def get_processed_start_notifications(self) -> set: + return self._processed_start_maint_notifications + + def add_skipped_end_notification(self, id: int): + self._skipped_end_maint_notifications.add(id) + + def get_skipped_end_notifications(self) -> set: + return self._skipped_end_maint_notifications + + def reset_received_notifications(self): + self._processed_start_maint_notifications.clear() + self._skipped_end_maint_notifications.clear() + def getpeername(self): """ Returns the peer name of the connection. diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index c1274bdad5..da5ac9c217 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -894,12 +894,14 @@ def handle_notification(self, notification: MaintenanceNotification): return if notification_type: - self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE) + self.handle_maintenance_start_notification( + MaintenanceState.MAINTENANCE, notification + ) else: self.handle_maintenance_completed_notification() def handle_maintenance_start_notification( - self, maintenance_state: MaintenanceState + self, maintenance_state: MaintenanceState, notification: MaintenanceNotification ): if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -913,6 +915,11 @@ def handle_maintenance_start_notification( ) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relaxed_timeout) + if isinstance(notification, OSSNodeMigratingNotification): + # add the notification id to the set of processed start maint notifications + # this is used to skip the unrelaxing of the timeouts if we have received more than + # one start notification before the the final end notification + self.connection.add_maint_start_notification(notification.id) def handle_maintenance_completed_notification(self): # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled @@ -926,6 +933,9 @@ def handle_maintenance_completed_notification(self): # timeouts by providing -1 as the relaxed timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE + # reset the sets that keep track of received start maint + # notifications and skipped end maint notifications + self.connection.reset_received_notifications() class OSSMaintNotificationsHandler: @@ -1004,35 +1014,45 @@ def handle_oss_maintenance_completed_notification( disconnect_startup_nodes_pools=False, additional_startup_nodes_info=additional_startup_nodes_info, ) - # mark for reconnect all in use connections to the node - this will force them to - # disconnect after they complete their current commands - # Some of them might be used by sub sub and we don't know which ones - so we disconnect - # all in flight connections after they are done with current command execution - for conn in ( - current_node.redis_connection.connection_pool._get_in_use_connections() - ): - conn.mark_for_reconnect() + with current_node.redis_connection.connection_pool._lock: + # mark for reconnect all in use connections to the node - this will force them to + # disconnect after they complete their current commands + # Some of them might be used by sub sub and we don't know which ones - so we disconnect + # all in flight connections after they are done with current command execution + for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): + conn.mark_for_reconnect() - if ( - current_node - not in self.cluster_client.nodes_manager.nodes_cache.values() - ): - # disconnect all free connections to the node - this node will be dropped - # from the cluster, so we don't need to revert the timeouts - for conn in current_node.redis_connection.connection_pool._get_free_connections(): - conn.disconnect() - else: - if self.config.is_relaxed_timeouts_enabled(): - # reset the timeouts for the node to which the connection is connected - # TODO: add check if other maintenance ops are in progress for the same node - CAE-1038 - # and if so, don't reset the timeouts - for conn in ( - *current_node.redis_connection.connection_pool._get_in_use_connections(), - *current_node.redis_connection.connection_pool._get_free_connections(), - ): - conn.reset_tmp_settings(reset_relaxed_timeout=True) - conn.update_current_socket_timeout(relaxed_timeout=-1) - conn.maintenance_state = MaintenanceState.NONE + if ( + current_node + not in self.cluster_client.nodes_manager.nodes_cache.values() + ): + # disconnect all free connections to the node - this node will be dropped + # from the cluster, so we don't need to revert the timeouts + for conn in current_node.redis_connection.connection_pool._get_free_connections(): + conn.disconnect() + else: + if self.config.is_relaxed_timeouts_enabled(): + # reset the timeouts for the node to which the connection is connected + # Perform check if other maintenance ops are in progress for the same node + # and if so, don't reset the timeouts and wait for the last maintenance + # to complete + for conn in ( + *current_node.redis_connection.connection_pool._get_in_use_connections(), + *current_node.redis_connection.connection_pool._get_free_connections(), + ): + if ( + len(conn.get_processed_start_notifications()) + > len(conn.get_skipped_end_notifications()) + 1 + ): + # we have received more start notifications than end notifications + # for this connection - we should not reset the timeouts + # and add the notification id to the set of skipped end notifications + conn.add_skipped_end_notification(notification.id) + else: + conn.reset_tmp_settings(reset_relaxed_timeout=True) + conn.update_current_socket_timeout(relaxed_timeout=-1) + conn.maintenance_state = MaintenanceState.NONE + conn.reset_received_notifications() # mark the notification as processed self._processed_notifications.add(notification) diff --git a/tests/maint_notifications/test_cluster_maint_notifications_handling.py b/tests/maint_notifications/test_cluster_maint_notifications_handling.py index 4302d486f2..e49f5c6131 100644 --- a/tests/maint_notifications/test_cluster_maint_notifications_handling.py +++ b/tests/maint_notifications/test_cluster_maint_notifications_handling.py @@ -1006,17 +1006,17 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( self.cluster.set("anyprefix:{3}:k", "VAL") # this functionality is part of CAE-1038 and will be added later # validate the timeout is still relaxed - # self._validate_connections_states( - # self.cluster, - # [ - # ConnectionStateExpectation( - # node_port=NODE_PORT_1, - # changed_connections_count=1, - # state=MaintenanceState.MAINTENANCE, - # relaxed_timeout=self.config.relaxed_timeout, - # ), - # ], - # ) + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) smigrated_node_1_2 = RespTranslator.oss_maint_notification_to_resp( "SMIGRATED 15 0.0.0.0:15381 3000-4000" ) diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index 91126a8cb1..47a27a48cf 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -864,7 +864,9 @@ def test_handle_notification_migrating(self): self.handler, "handle_maintenance_start_notification" ) as mock_handle: self.handler.handle_notification(notification) - mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) + mock_handle.assert_called_once_with( + MaintenanceState.MAINTENANCE, notification + ) def test_handle_notification_migrated(self): """Test handling of NodeMigratedNotification.""" @@ -884,7 +886,9 @@ def test_handle_notification_failing_over(self): self.handler, "handle_maintenance_start_notification" ) as mock_handle: self.handler.handle_notification(notification) - mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) + mock_handle.assert_called_once_with( + MaintenanceState.MAINTENANCE, notification + ) def test_handle_notification_failed_over(self): """Test handling of NodeFailedOverNotification.""" @@ -911,7 +915,7 @@ def test_handle_maintenance_start_notification_disabled(self): handler = MaintNotificationsConnectionHandler(self.mock_connection, config) result = handler.handle_maintenance_start_notification( - MaintenanceState.MAINTENANCE + MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5) ) assert result is None @@ -922,7 +926,7 @@ def test_handle_maintenance_start_notification_moving_state(self): self.mock_connection.maintenance_state = MaintenanceState.MOVING result = self.handler.handle_maintenance_start_notification( - MaintenanceState.MAINTENANCE + MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5) ) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() @@ -931,7 +935,9 @@ def test_handle_maintenance_start_notification_success(self): """Test successful maintenance start notification handling for migrating.""" self.mock_connection.maintenance_state = MaintenanceState.NONE - self.handler.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE) + self.handler.handle_maintenance_start_notification( + MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5) + ) assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)