Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def __init__(
oss_cluster_maint_notifications_handler,
parser,
)
self._processed_start_maint_notifications = set()
self._skipped_end_maint_notifications = set()

@abstractmethod
def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
Expand Down Expand Up @@ -667,6 +669,22 @@ def maintenance_state(self) -> MaintenanceState:
def maintenance_state(self, state: "MaintenanceState"):
self._maintenance_state = state

def add_maint_start_notification(self, id: int):
    """Record the id of a processed maintenance-start notification."""
    self._processed_start_maint_notifications.add(id)

def get_processed_start_notifications(self) -> set:
    """Return the ids of processed maintenance-start notifications.

    NOTE(review): this exposes the internal set directly, so callers see
    live updates and could mutate the tracked state.
    """
    return self._processed_start_maint_notifications

def add_skipped_end_notification(self, id: int):
    """Record the id of a maintenance-end notification that was skipped."""
    self._skipped_end_maint_notifications.add(id)

def get_skipped_end_notifications(self) -> set:
    """Return the ids of skipped maintenance-end notifications.

    NOTE(review): this exposes the internal set directly, so callers see
    live updates and could mutate the tracked state.
    """
    return self._skipped_end_maint_notifications

def reset_received_notifications(self):
    """Empty both maintenance-notification bookkeeping sets in place.

    Clears the processed-start and skipped-end id sets without replacing
    them, so any outside references to the sets remain valid.
    """
    for tracked in (
        self._processed_start_maint_notifications,
        self._skipped_end_maint_notifications,
    ):
        tracked.clear()

def getpeername(self):
"""
Returns the peer name of the connection.
Expand Down
80 changes: 50 additions & 30 deletions redis/maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,12 +894,14 @@ def handle_notification(self, notification: MaintenanceNotification):
return

if notification_type:
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
self.handle_maintenance_start_notification(
MaintenanceState.MAINTENANCE, notification
)
else:
self.handle_maintenance_completed_notification()

def handle_maintenance_start_notification(
self, maintenance_state: MaintenanceState
self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
):
if (
self.connection.maintenance_state == MaintenanceState.MOVING
Expand All @@ -913,6 +915,11 @@ def handle_maintenance_start_notification(
)
# extend the timeout for all created connections
self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
if isinstance(notification, OSSNodeMigratingNotification):
# add the notification id to the set of processed start maint notifications
# this is used to skip the unrelaxing of the timeouts if we have received more than
# one start notification before the final end notification
self.connection.add_maint_start_notification(notification.id)

def handle_maintenance_completed_notification(self):
# Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
Expand All @@ -926,6 +933,9 @@ def handle_maintenance_completed_notification(self):
# timeouts by providing -1 as the relaxed timeout
self.connection.update_current_socket_timeout(-1)
self.connection.maintenance_state = MaintenanceState.NONE
# reset the sets that keep track of received start maint
# notifications and skipped end maint notifications
self.connection.reset_received_notifications()


class OSSMaintNotificationsHandler:
Expand Down Expand Up @@ -1004,35 +1014,45 @@ def handle_oss_maintenance_completed_notification(
disconnect_startup_nodes_pools=False,
additional_startup_nodes_info=additional_startup_nodes_info,
)
# mark for reconnect all in use connections to the node - this will force them to
# disconnect after they complete their current commands
# Some of them might be used by sub sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in (
current_node.redis_connection.connection_pool._get_in_use_connections()
):
conn.mark_for_reconnect()
with current_node.redis_connection.connection_pool._lock:
# mark for reconnect all in use connections to the node - this will force them to
# disconnect after they complete their current commands
# Some of them might be used by pub/sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
conn.mark_for_reconnect()

if (
current_node
not in self.cluster_client.nodes_manager.nodes_cache.values()
):
# disconnect all free connections to the node - this node will be dropped
# from the cluster, so we don't need to revert the timeouts
for conn in current_node.redis_connection.connection_pool._get_free_connections():
conn.disconnect()
else:
if self.config.is_relaxed_timeouts_enabled():
# reset the timeouts for the node to which the connection is connected
# TODO: add check if other maintenance ops are in progress for the same node - CAE-1038
# and if so, don't reset the timeouts
for conn in (
*current_node.redis_connection.connection_pool._get_in_use_connections(),
*current_node.redis_connection.connection_pool._get_free_connections(),
):
conn.reset_tmp_settings(reset_relaxed_timeout=True)
conn.update_current_socket_timeout(relaxed_timeout=-1)
conn.maintenance_state = MaintenanceState.NONE
if (
current_node
not in self.cluster_client.nodes_manager.nodes_cache.values()
):
# disconnect all free connections to the node - this node will be dropped
# from the cluster, so we don't need to revert the timeouts
for conn in current_node.redis_connection.connection_pool._get_free_connections():
conn.disconnect()
else:
if self.config.is_relaxed_timeouts_enabled():
# reset the timeouts for the node to which the connection is connected
# Perform check if other maintenance ops are in progress for the same node
# and if so, don't reset the timeouts and wait for the last maintenance
# to complete
for conn in (
*current_node.redis_connection.connection_pool._get_in_use_connections(),
*current_node.redis_connection.connection_pool._get_free_connections(),
):
if (
len(conn.get_processed_start_notifications())
> len(conn.get_skipped_end_notifications()) + 1
):
# we have received more start notifications than end notifications
# for this connection - we should not reset the timeouts
# and add the notification id to the set of skipped end notifications
conn.add_skipped_end_notification(notification.id)
else:
conn.reset_tmp_settings(reset_relaxed_timeout=True)
conn.update_current_socket_timeout(relaxed_timeout=-1)
conn.maintenance_state = MaintenanceState.NONE
conn.reset_received_notifications()

# mark the notification as processed
self._processed_notifications.add(notification)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1006,17 +1006,17 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges(
self.cluster.set("anyprefix:{3}:k", "VAL")
# this functionality is part of CAE-1038 and will be added later
# validate the timeout is still relaxed
# self._validate_connections_states(
# self.cluster,
# [
# ConnectionStateExpectation(
# node_port=NODE_PORT_1,
# changed_connections_count=1,
# state=MaintenanceState.MAINTENANCE,
# relaxed_timeout=self.config.relaxed_timeout,
# ),
# ],
# )
self._validate_connections_states(
self.cluster,
[
ConnectionStateExpectation(
node_port=NODE_PORT_1,
changed_connections_count=1,
state=MaintenanceState.MAINTENANCE,
relaxed_timeout=self.config.relaxed_timeout,
),
],
)
smigrated_node_1_2 = RespTranslator.oss_maint_notification_to_resp(
"SMIGRATED 15 0.0.0.0:15381 3000-4000"
)
Expand Down
16 changes: 11 additions & 5 deletions tests/maint_notifications/test_maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,9 @@ def test_handle_notification_migrating(self):
self.handler, "handle_maintenance_start_notification"
) as mock_handle:
self.handler.handle_notification(notification)
mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE)
mock_handle.assert_called_once_with(
MaintenanceState.MAINTENANCE, notification
)

def test_handle_notification_migrated(self):
"""Test handling of NodeMigratedNotification."""
Expand All @@ -884,7 +886,9 @@ def test_handle_notification_failing_over(self):
self.handler, "handle_maintenance_start_notification"
) as mock_handle:
self.handler.handle_notification(notification)
mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE)
mock_handle.assert_called_once_with(
MaintenanceState.MAINTENANCE, notification
)

def test_handle_notification_failed_over(self):
"""Test handling of NodeFailedOverNotification."""
Expand All @@ -911,7 +915,7 @@ def test_handle_maintenance_start_notification_disabled(self):
handler = MaintNotificationsConnectionHandler(self.mock_connection, config)

result = handler.handle_maintenance_start_notification(
MaintenanceState.MAINTENANCE
MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5)
)

assert result is None
Expand All @@ -922,7 +926,7 @@ def test_handle_maintenance_start_notification_moving_state(self):
self.mock_connection.maintenance_state = MaintenanceState.MOVING

result = self.handler.handle_maintenance_start_notification(
MaintenanceState.MAINTENANCE
MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5)
)
assert result is None
self.mock_connection.update_current_socket_timeout.assert_not_called()
Expand All @@ -931,7 +935,9 @@ def test_handle_maintenance_start_notification_success(self):
"""Test successful maintenance start notification handling for migrating."""
self.mock_connection.maintenance_state = MaintenanceState.NONE

self.handler.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
self.handler.handle_maintenance_start_notification(
MaintenanceState.MAINTENANCE, NodeMigratingNotification(id=1, ttl=5)
)

assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
Expand Down