2 changes: 1 addition & 1 deletion pymongo/asynchronous/mongo_client.py
@@ -2825,7 +2825,7 @@ async def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if self._server is not None:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
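This change drops the sharded-only gate: the server that served a failed attempt is now recorded for deprioritization on any topology, so replica-set retries can also prefer a different member. A minimal sketch of the retry pattern, with hypothetical names (`attempt`, `pick_server`) standing in for the client's retry-loop internals:

def run_with_retry(attempt, pick_server, max_attempts=2):
    # Servers that already failed an attempt; selection tries to avoid them.
    deprioritized = []
    last_error = None
    for _ in range(max_attempts):
        server = pick_server(deprioritized)
        try:
            return attempt(server)
        except ConnectionError as exc:
            last_error = exc
            if server is not None:
                # Remember the failed server so the next attempt avoids it
                # whenever another eligible server exists.
                deprioritized.append(server)
    raise last_error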
37 changes: 20 additions & 17 deletions pymongo/asynchronous/topology.py
@@ -265,6 +265,7 @@ async def select_servers(
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.

@@ -292,7 +293,12 @@

async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)

return [
@@ -306,6 +312,7 @@ async def _select_servers_loop(
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
@@ -324,7 +331,12 @@ async def _select_servers_loop(
)

server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)

while not server_descriptions:
@@ -385,9 +397,13 @@ async def _select_server(
operation_id: Optional[int] = None,
) -> Server:
servers = await self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
@@ -1112,16 +1128,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
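Because the selection entry points now pass `deprioritized_servers` all the way down to `apply_selector`, the module-level `_filter_servers` helper is deleted; its drop-but-never-empty rule moves into `TopologyDescription` and is applied by address. A self-contained sketch of that rule, assuming only that each candidate exposes an `address` attribute, as `ServerDescription` does:

def filter_with_fallback(candidates, deprioritized):
    # Drop deprioritized candidates, but never empty the pool: if every
    # candidate is deprioritized, retrying on one of them beats failing.
    if not deprioritized:
        return candidates
    avoid = {s.address for s in deprioritized}
    filtered = [s for s in candidates if s.address not in avoid]
    return filtered or candidates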
6 changes: 3 additions & 3 deletions pymongo/server_selectors.py
@@ -34,16 +34,16 @@ class Selection:

@classmethod
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
known_servers = topology_description.known_servers
candidate_servers = topology_description.candidate_servers
primary = None
for sd in known_servers:
for sd in candidate_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
primary = sd
break

return Selection(
topology_description,
topology_description.known_servers,
topology_description.candidate_servers,
topology_description.common_wire_version,
primary,
)
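`Selection.from_topology_description` is the funnel through which read-preference selectors see the topology, so switching it from `known_servers` to `candidate_servers` means every selector now starts from the deprioritization-filtered pool. A condensed model of the construction (a stand-in dataclass, not the real `Selection` signature):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class SelectionSketch:
    server_descriptions: list[Any]
    primary: Optional[Any] = None

def from_description(description: Any) -> SelectionSketch:
    candidates = description.candidate_servers  # filtered pool, not known_servers
    # "RSPrimary" stands in for the SERVER_TYPE.RSPrimary constant.
    primary = next(
        (sd for sd in candidates if sd.server_type == "RSPrimary"), None
    )
    return SelectionSketch(candidates, primary)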
2 changes: 1 addition & 1 deletion pymongo/synchronous/mongo_client.py
@@ -2815,7 +2815,7 @@ def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if self._server is not None:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
37 changes: 20 additions & 17 deletions pymongo/synchronous/topology.py
@@ -265,6 +265,7 @@ def select_servers(
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.

@@ -292,7 +293,12 @@

with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)

return [
@@ -306,6 +312,7 @@ def _select_servers_loop(
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
@@ -324,7 +331,12 @@ def _select_servers_loop(
)

server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)

while not server_descriptions:
@@ -385,9 +397,13 @@ def _select_server(
operation_id: Optional[int] = None,
) -> Server:
servers = self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
@@ -1110,16 +1126,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
38 changes: 37 additions & 1 deletion pymongo/topology_description.py
@@ -85,6 +85,7 @@ def __init__(
self._server_descriptions = server_descriptions
self._max_set_version = max_set_version
self._max_election_id = max_election_id
self._candidate_servers = list(self._server_descriptions.values())

# The heartbeat_frequency is used in staleness estimates.
self._topology_settings = topology_settings
@@ -248,6 +249,11 @@ def readable_servers(self) -> list[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]

@property
def candidate_servers(self) -> list[ServerDescription]:
"""List of Servers excluding deprioritized servers."""
return self._candidate_servers

@property
def common_wire_version(self) -> Optional[int]:
"""Minimum of all servers' max wire versions, or None."""
@@ -283,11 +289,27 @@ def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerD
if (cast(float, s.round_trip_time) - fastest) <= threshold
]

def _filter_servers(
self, deprioritized_servers: Optional[list[ServerDescription]] = None
) -> None:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
self._candidate_servers = self.known_servers
else:
deprioritized_addresses = {sd.address for sd in deprioritized_servers}
filtered = [
server
for server in self.known_servers
if server.address not in deprioritized_addresses
]
self._candidate_servers = filtered or self.known_servers

def apply_selector(
self,
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None,
deprioritized_servers: Optional[list[ServerDescription]] = None,
) -> list[ServerDescription]:
"""List of servers matching the provided selector(s).

@@ -324,21 +346,35 @@ def apply_selector(
description = self.server_descriptions().get(address)
return [description] if description and description.is_server_type_known else []

self._filter_servers(deprioritized_servers)
# Primary selection fast path.
if self.topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary and type(selector) is Primary:
for sd in self._server_descriptions.values():
for sd in self._candidate_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
sds = [sd]
if custom_selector:
sds = custom_selector(sds)
return sds
# The primary is among the deprioritized servers; fall back to it rather than fail.
if deprioritized_servers:
for sd in deprioritized_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
sds = [sd]
if custom_selector:
sds = custom_selector(sds)
return sds
# No primary found, return an empty list.
return []

selection = Selection.from_topology_description(self)
# Ignore read preference for sharded clusters.
if self.topology_type != TOPOLOGY_TYPE.Sharded:
selection = selector(selection)
# No suitable servers found; apply the read preference again, this time including deprioritized servers.
if not selection and deprioritized_servers:
self._filter_servers(None)
selection = Selection.from_topology_description(self)
selection = selector(selection)

# Apply custom selector followed by localThresholdMS.
if custom_selector is not None and selection:
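Taken together, `apply_selector` now filters candidates up front and widens the pool only when filtering would make selection fail. A simplified model of that two-pass flow, with plain lists and functions standing in for the `TopologyDescription`/`Selection` machinery:

def select_with_deprioritization(candidates, selector, deprioritized=None):
    # Pass 1: prefer non-deprioritized servers, falling back to the full
    # pool if every candidate is deprioritized.
    avoid = {s.address for s in (deprioritized or [])}
    pool = [s for s in candidates if s.address not in avoid] or candidates
    chosen = selector(pool)
    if chosen or not deprioritized:
        return chosen
    # Pass 2: nothing matched the read preference; retry the selector with
    # the deprioritized servers back in the pool.
    return selector(candidates)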
71 changes: 71 additions & 0 deletions test/asynchronous/test_retryable_reads.py
@@ -21,6 +21,7 @@
import threading
from test.asynchronous.utils import async_set_fail_point

from pymongo import ReadPreference
from pymongo.errors import OperationFailure

sys.path[0:0] = [""]
@@ -182,6 +183,44 @@ async def test_retryable_reads_are_retried_on_a_different_mongos_when_one_is_ava
# Assert that both events occurred on different mongos.
assert listener.failed_events[0].connection_id != listener.failed_events[1].connection_id

@async_client_context.require_replica_set
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_a_different_replica_when_one_is_available(self):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 6},
}

replica_clients = []

for node in async_client_context.nodes:
client = await self.async_rs_or_single_client(*node, directConnection=True)
await async_set_fail_point(client, fail_command)
replica_clients.append(client)

listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
directConnection=False,
readPreference="secondaryPreferred",
)

with self.assertRaises(OperationFailure):
await client.t.t.find_one({})

# Disable failpoints on each node
for client in replica_clients:
fail_command["mode"] = "off"
await async_set_fail_point(client, fail_command)

self.assertEqual(len(listener.failed_events), 2)
self.assertEqual(len(listener.succeeded_events), 0)

# Assert that both events occurred on different nodes.
assert listener.failed_events[0].connection_id != listener.failed_events[1].connection_id

@async_client_context.require_multiple_mongoses
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are_available(
@@ -218,6 +257,38 @@ async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are
# Assert that both events occurred on the same mongos.
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id

@async_client_context.require_replica_set
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_replica_when_no_others_are_available(
self
):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 6},
}

node_client = await self.async_rs_or_single_client(*list(async_client_context.nodes)[0])
await async_set_fail_point(node_client, fail_command)

listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
)

await client.t.t.find_one({})

# Disable failpoints
fail_command["mode"] = "off"
await async_set_fail_point(node_client, fail_command)

self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)

# Assert that both events occurred on the same node.
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id

@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
listener = OvertCommandListener()
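The new tests arm the server-side failCommand fail point so that the first `find` against a node fails with error code 6 (HostUnreachable, a retryable error), then check which node served the retry. For reference, a minimal synchronous sketch of toggling such a fail point over a direct connection; the `async_set_fail_point` test helper is assumed to wrap the same admin command:

from pymongo import MongoClient

def set_fail_point(client: MongoClient, config: dict) -> None:
    # configureFailPoint is an admin command; the first key in the
    # document names the command.
    client.admin.command(config)

client = MongoClient("localhost", 27017, directConnection=True)
set_fail_point(client, {
    "configureFailPoint": "failCommand",
    "mode": {"times": 1},  # trip once, then disarm automatically
    "data": {"failCommands": ["find"], "errorCode": 6},
})
# ... run the operation under test, then clean up:
set_fail_point(client, {"configureFailPoint": "failCommand", "mode": "off"})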