Skip to content

Commit d4f558a

Browse files
committed
PYTHON-5662 - Add support for server selection's deprioritized servers to all topologies
1 parent 3093a7c commit d4f558a

33 files changed

+1547
-28
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,8 +2825,7 @@ async def run(self) -> T:
28252825
if self._last_error is None:
28262826
self._last_error = exc
28272827

2828-
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
2829-
self._deprioritized_servers.append(self._server)
2828+
self._deprioritized_servers.append(self._server)
28302829

28312830
def _is_not_eligible_for_retry(self) -> bool:
28322831
"""Checks if the exchange is not eligible for retry"""

pymongo/asynchronous/topology.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ async def select_servers(
265265
server_selection_timeout: Optional[float] = None,
266266
address: Optional[_Address] = None,
267267
operation_id: Optional[int] = None,
268+
deprioritized_servers: Optional[list[Server]] = None,
268269
) -> list[Server]:
269270
"""Return a list of Servers matching selector, or time out.
270271
@@ -292,7 +293,12 @@ async def select_servers(
292293

293294
async with self._lock:
294295
server_descriptions = await self._select_servers_loop(
295-
selector, server_timeout, operation, operation_id, address
296+
selector,
297+
server_timeout,
298+
operation,
299+
operation_id,
300+
address,
301+
deprioritized_servers=deprioritized_servers,
296302
)
297303

298304
return [
@@ -306,6 +312,7 @@ async def _select_servers_loop(
306312
operation: str,
307313
operation_id: Optional[int],
308314
address: Optional[_Address],
315+
deprioritized_servers: Optional[list[Server]] = None,
309316
) -> list[ServerDescription]:
310317
"""select_servers() guts. Hold the lock when calling this."""
311318
now = time.monotonic()
@@ -324,7 +331,12 @@ async def _select_servers_loop(
324331
)
325332

326333
server_descriptions = self._description.apply_selector(
327-
selector, address, custom_selector=self._settings.server_selector
334+
selector,
335+
address,
336+
custom_selector=self._settings.server_selector,
337+
deprioritized_servers=[server.description for server in deprioritized_servers]
338+
if deprioritized_servers
339+
else None,
328340
)
329341

330342
while not server_descriptions:
@@ -385,7 +397,12 @@ async def _select_server(
385397
operation_id: Optional[int] = None,
386398
) -> Server:
387399
servers = await self.select_servers(
388-
selector, operation, server_selection_timeout, address, operation_id
400+
selector,
401+
operation,
402+
server_selection_timeout,
403+
address,
404+
operation_id,
405+
deprioritized_servers,
389406
)
390407
servers = _filter_servers(servers, deprioritized_servers)
391408
if len(servers) == 1:

pymongo/server_selectors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,16 @@ class Selection:
3434

3535
@classmethod
3636
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
37-
known_servers = topology_description.known_servers
37+
candidate_servers = topology_description.candidate_servers
3838
primary = None
39-
for sd in known_servers:
39+
for sd in candidate_servers:
4040
if sd.server_type == SERVER_TYPE.RSPrimary:
4141
primary = sd
4242
break
4343

4444
return Selection(
4545
topology_description,
46-
topology_description.known_servers,
46+
topology_description.candidate_servers,
4747
topology_description.common_wire_version,
4848
primary,
4949
)

pymongo/synchronous/mongo_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,8 +2815,7 @@ def run(self) -> T:
28152815
if self._last_error is None:
28162816
self._last_error = exc
28172817

2818-
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
2819-
self._deprioritized_servers.append(self._server)
2818+
self._deprioritized_servers.append(self._server)
28202819

28212820
def _is_not_eligible_for_retry(self) -> bool:
28222821
"""Checks if the exchange is not eligible for retry"""

pymongo/synchronous/topology.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def select_servers(
265265
server_selection_timeout: Optional[float] = None,
266266
address: Optional[_Address] = None,
267267
operation_id: Optional[int] = None,
268+
deprioritized_servers: Optional[list[Server]] = None,
268269
) -> list[Server]:
269270
"""Return a list of Servers matching selector, or time out.
270271
@@ -292,7 +293,12 @@ def select_servers(
292293

293294
with self._lock:
294295
server_descriptions = self._select_servers_loop(
295-
selector, server_timeout, operation, operation_id, address
296+
selector,
297+
server_timeout,
298+
operation,
299+
operation_id,
300+
address,
301+
deprioritized_servers=deprioritized_servers,
296302
)
297303

298304
return [
@@ -306,6 +312,7 @@ def _select_servers_loop(
306312
operation: str,
307313
operation_id: Optional[int],
308314
address: Optional[_Address],
315+
deprioritized_servers: Optional[list[Server]] = None,
309316
) -> list[ServerDescription]:
310317
"""select_servers() guts. Hold the lock when calling this."""
311318
now = time.monotonic()
@@ -324,7 +331,12 @@ def _select_servers_loop(
324331
)
325332

326333
server_descriptions = self._description.apply_selector(
327-
selector, address, custom_selector=self._settings.server_selector
334+
selector,
335+
address,
336+
custom_selector=self._settings.server_selector,
337+
deprioritized_servers=[server.description for server in deprioritized_servers]
338+
if deprioritized_servers
339+
else None,
328340
)
329341

330342
while not server_descriptions:
@@ -385,7 +397,12 @@ def _select_server(
385397
operation_id: Optional[int] = None,
386398
) -> Server:
387399
servers = self.select_servers(
388-
selector, operation, server_selection_timeout, address, operation_id
400+
selector,
401+
operation,
402+
server_selection_timeout,
403+
address,
404+
operation_id,
405+
deprioritized_servers,
389406
)
390407
servers = _filter_servers(servers, deprioritized_servers)
391408
if len(servers) == 1:

pymongo/topology_description.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(
8585
self._server_descriptions = server_descriptions
8686
self._max_set_version = max_set_version
8787
self._max_election_id = max_election_id
88+
self._candidate_servers = list(self._server_descriptions.values())
8889

8990
# The heartbeat_frequency is used in staleness estimates.
9091
self._topology_settings = topology_settings
@@ -248,6 +249,11 @@ def readable_servers(self) -> list[ServerDescription]:
248249
"""List of readable Servers."""
249250
return [s for s in self._server_descriptions.values() if s.is_readable]
250251

252+
@property
253+
def candidate_servers(self) -> list[ServerDescription]:
254+
"""List of Servers excluding deprioritized servers."""
255+
return self._candidate_servers
256+
251257
@property
252258
def common_wire_version(self) -> Optional[int]:
253259
"""Minimum of all servers' max wire versions, or None."""
@@ -283,11 +289,24 @@ def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerD
283289
if (cast(float, s.round_trip_time) - fastest) <= threshold
284290
]
285291

292+
def _filter_servers(
293+
self, deprioritized_servers: Optional[list[ServerDescription]] = None
294+
) -> None:
295+
"""Filter out deprioritized servers from a list of server candidates."""
296+
if not deprioritized_servers:
297+
self._candidate_servers = self.known_servers
298+
else:
299+
filtered = [
300+
server for server in self.known_servers if server not in deprioritized_servers
301+
]
302+
self._candidate_servers = filtered or self.known_servers
303+
286304
def apply_selector(
287305
self,
288306
selector: Any,
289307
address: Optional[_Address] = None,
290308
custom_selector: Optional[_ServerSelector] = None,
309+
deprioritized_servers: Optional[list[ServerDescription]] = None,
291310
) -> list[ServerDescription]:
292311
"""List of servers matching the provided selector(s).
293312
@@ -324,9 +343,10 @@ def apply_selector(
324343
description = self.server_descriptions().get(address)
325344
return [description] if description and description.is_server_type_known else []
326345

346+
self._filter_servers(deprioritized_servers)
327347
# Primary selection fast path.
328348
if self.topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary and type(selector) is Primary:
329-
for sd in self._server_descriptions.values():
349+
for sd in self._candidate_servers:
330350
if sd.server_type == SERVER_TYPE.RSPrimary:
331351
sds = [sd]
332352
if custom_selector:

test/asynchronous/test_retryable_reads.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import threading
2222
from test.asynchronous.utils import async_set_fail_point
2323

24+
from pymongo import ReadPreference
2425
from pymongo.errors import OperationFailure
2526

2627
sys.path[0:0] = [""]
@@ -182,6 +183,44 @@ async def test_retryable_reads_are_retried_on_a_different_mongos_when_one_is_ava
182183
# Assert that both events occurred on different mongos.
183184
assert listener.failed_events[0].connection_id != listener.failed_events[1].connection_id
184185

186+
@async_client_context.require_replica_set
187+
@async_client_context.require_failCommand_fail_point
188+
async def test_retryable_reads_are_retried_on_a_different_replica_when_one_is_available(self):
189+
fail_command = {
190+
"configureFailPoint": "failCommand",
191+
"mode": {"times": 1},
192+
"data": {"failCommands": ["find"], "errorCode": 6},
193+
}
194+
195+
replica_clients = []
196+
197+
for node in async_client_context.nodes:
198+
client = await self.async_rs_or_single_client(*node, directConnection=True)
199+
await async_set_fail_point(client, fail_command)
200+
replica_clients.append(client)
201+
202+
listener = OvertCommandListener()
203+
client = await self.async_rs_or_single_client(
204+
event_listeners=[listener],
205+
retryReads=True,
206+
directConnection=False,
207+
readPreference="secondaryPreferred",
208+
)
209+
210+
with self.assertRaises(OperationFailure):
211+
await client.t.t.find_one({})
212+
213+
# Disable failpoints on each node
214+
for client in replica_clients:
215+
fail_command["mode"] = "off"
216+
await async_set_fail_point(client, fail_command)
217+
218+
self.assertEqual(len(listener.failed_events), 2)
219+
self.assertEqual(len(listener.succeeded_events), 0)
220+
221+
# Assert that both events occurred on different nodes.
222+
assert listener.failed_events[0].connection_id != listener.failed_events[1].connection_id
223+
185224
@async_client_context.require_multiple_mongoses
186225
@async_client_context.require_failCommand_fail_point
187226
async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are_available(
@@ -218,6 +257,38 @@ async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are
218257
# Assert that both events occurred on the same mongos.
219258
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id
220259

260+
@async_client_context.require_replica_set
261+
@async_client_context.require_failCommand_fail_point
262+
async def test_retryable_reads_are_retried_on_the_same_replica_when_no_others_are_available(
263+
self
264+
):
265+
fail_command = {
266+
"configureFailPoint": "failCommand",
267+
"mode": {"times": 1},
268+
"data": {"failCommands": ["find"], "errorCode": 6},
269+
}
270+
271+
node_client = await self.async_rs_or_single_client(*list(async_client_context.nodes)[0])
272+
await async_set_fail_point(node_client, fail_command)
273+
274+
listener = OvertCommandListener()
275+
client = await self.async_rs_or_single_client(
276+
event_listeners=[listener],
277+
retryReads=True,
278+
)
279+
280+
await client.t.t.find_one({})
281+
282+
# Disable failpoints
283+
fail_command["mode"] = "off"
284+
await async_set_fail_point(node_client, fail_command)
285+
286+
self.assertEqual(len(listener.failed_events), 1)
287+
self.assertEqual(len(listener.succeeded_events), 1)
288+
289+
# Assert that both events occurred on the same node.
290+
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id
291+
221292
@async_client_context.require_failCommand_fail_point
222293
async def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
223294
listener = OvertCommandListener()

test/asynchronous/utils_selection_tests.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from bson import json_util
3636
from pymongo.asynchronous.settings import TopologySettings
3737
from pymongo.asynchronous.topology import Topology
38-
from pymongo.common import HEARTBEAT_FREQUENCY
38+
from pymongo.common import HEARTBEAT_FREQUENCY, clean_node
3939
from pymongo.errors import AutoReconnect, ConfigurationError
4040
from pymongo.operations import _Op
4141
from pymongo.server_selectors import writable_server_selector
@@ -95,12 +95,21 @@ async def run_scenario(self):
9595
# "Eligible servers" is defined in the server selection spec as
9696
# the set of servers matching both the ReadPreference's mode
9797
# and tag sets.
98-
top_latency = await create_topology(scenario_def)
98+
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
9999

100100
# "In latency window" is defined in the server selection
101101
# spec as the subset of suitable_servers that falls within the
102102
# allowable latency window.
103-
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
103+
top_latency = await create_topology(scenario_def)
104+
105+
top_suitable_deprioritized_servers = [
106+
top_suitable.get_server_by_address(clean_node(server["address"]))
107+
for server in scenario_def.get("deprioritized_servers", [])
108+
]
109+
top_latency_deprioritized_servers = [
110+
top_latency.get_server_by_address(clean_node(server["address"]))
111+
for server in scenario_def.get("deprioritized_servers", [])
112+
]
104113

105114
# Create server selector.
106115
if scenario_def.get("operation") == "write":
@@ -120,21 +129,37 @@ async def run_scenario(self):
120129
# Select servers.
121130
if not scenario_def.get("suitable_servers"):
122131
with self.assertRaises(AutoReconnect):
123-
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)
132+
await top_suitable.select_server(
133+
pref,
134+
_Op.TEST,
135+
server_selection_timeout=0,
136+
deprioritized_servers=top_suitable_deprioritized_servers,
137+
)
124138

125139
return
126140

127141
if not scenario_def["in_latency_window"]:
128142
with self.assertRaises(AutoReconnect):
129-
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)
143+
await top_latency.select_server(
144+
pref,
145+
_Op.TEST,
146+
server_selection_timeout=0,
147+
deprioritized_servers=top_latency_deprioritized_servers,
148+
)
130149

131150
return
132151

133152
actual_suitable_s = await top_suitable.select_servers(
134-
pref, _Op.TEST, server_selection_timeout=0
153+
pref,
154+
_Op.TEST,
155+
server_selection_timeout=0,
156+
deprioritized_servers=top_suitable_deprioritized_servers,
135157
)
136158
actual_latency_s = await top_latency.select_servers(
137-
pref, _Op.TEST, server_selection_timeout=0
159+
pref,
160+
_Op.TEST,
161+
server_selection_timeout=0,
162+
deprioritized_servers=top_latency_deprioritized_servers,
138163
)
139164

140165
expected_suitable_servers = {}

0 commit comments

Comments (0)