Skip to content
Draft
6 changes: 3 additions & 3 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
cast,
)

from pymongo import _csot
from pymongo import _csot, common
from pymongo.errors import (
OperationFailure,
)
Expand Down Expand Up @@ -76,9 +76,9 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


_MAX_RETRIES = 5
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10

DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1

Expand Down Expand Up @@ -128,7 +128,7 @@ class _RetryPolicy:
def __init__(
self,
token_bucket: _TokenBucket,
attempts: int = _MAX_RETRIES,
attempts: int = common._MAX_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
adaptive_retry: bool = False,
Expand Down
6 changes: 4 additions & 2 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,9 @@ def __init__(
)

self._retry_policy = _RetryPolicy(
_TokenBucket(), adaptive_retry=self._options.adaptive_retries
_TokenBucket(),
attempts=self._options.max_retries,
adaptive_retry=self._options.adaptive_retries,
)

self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
Expand Down Expand Up @@ -2930,7 +2932,7 @@ async def run(self) -> T:
transaction.set_starting()
transaction.attempt = 0

if (
if self._client.options.enable_overload_retargeting and (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
Expand Down
26 changes: 26 additions & 0 deletions pymongo/client_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ def __init__(
if "adaptive_retries" in options
else options.get("adaptiveretries", common.ADAPTIVE_RETRIES)
)
self.__max_retries = (
options.get("max_retries", common._MAX_RETRIES)
if "max_retries" in options
else options.get("maxretries", common._MAX_RETRIES)
)
self.__enable_overload_retargeting = (
options.get("enable_overload_retargeting", common.ENABLE_OVERLOAD_RETARGETING)
if "enable_overload_retargeting" in options
else options.get("enableoverloadretargeting", common.ENABLE_OVERLOAD_RETARGETING)
)

@property
def _options(self) -> Mapping[str, Any]:
Expand Down Expand Up @@ -359,3 +369,19 @@ def adaptive_retries(self) -> bool:
.. versionadded:: 4.XX
"""
return self.__adaptive_retries

@property
def max_retries(self) -> int:
"""The configured maxRetries option.

.. versionadded:: 4.XX
"""
return self.__max_retries

@property
def enable_overload_retargeting(self) -> bool:
"""The configured enableOverloadRetargeting option.

.. versionadded:: 4.XX
"""
return self.__enable_overload_retargeting
8 changes: 8 additions & 0 deletions pymongo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@
# Default value for adaptiveRetries
ADAPTIVE_RETRIES = False

# Default value for max retries
_MAX_RETRIES = 2

# Default value for enableOverloadRetargeting
ENABLE_OVERLOAD_RETARGETING = False

# Auth mechanism properties that must raise an error instead of warning if they invalidate.
_MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"]

Expand Down Expand Up @@ -776,6 +782,8 @@ def validate_server_monitoring_mode(option: str, value: str) -> str:
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list,
"adaptive_retries": validate_boolean_or_string,
"max_retries": validate_non_negative_integer,
"enable_overload_retargeting": validate_boolean_or_string,
}

# Dictionary where keys are any URI option name, and values are the
Expand Down
6 changes: 3 additions & 3 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
cast,
)

from pymongo import _csot
from pymongo import _csot, common
from pymongo.errors import (
OperationFailure,
)
Expand Down Expand Up @@ -76,9 +76,9 @@ def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


_MAX_RETRIES = 5
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10

DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1

Expand Down Expand Up @@ -128,7 +128,7 @@ class _RetryPolicy:
def __init__(
self,
token_bucket: _TokenBucket,
attempts: int = _MAX_RETRIES,
attempts: int = common._MAX_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
adaptive_retry: bool = False,
Expand Down
6 changes: 4 additions & 2 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,9 @@ def __init__(
)

self._retry_policy = _RetryPolicy(
_TokenBucket(), adaptive_retry=self._options.adaptive_retries
_TokenBucket(),
attempts=self._options.max_retries,
adaptive_retry=self._options.adaptive_retries,
)

self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
Expand Down Expand Up @@ -2920,7 +2922,7 @@ def run(self) -> T:
transaction.set_starting()
transaction.attempt = 0

if (
if self._client.options.enable_overload_retargeting and (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
Expand Down
230 changes: 230 additions & 0 deletions test/asynchronous/test_backpressure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test Client Backpressure spec."""
from __future__ import annotations

import asyncio
import sys

import pymongo

sys.path[0:0] = [""]

from test.asynchronous import (
AsyncIntegrationTest,
AsyncPyMongoTestCase,
async_client_context,
unittest,
)

from pymongo.asynchronous import helpers
from pymongo.asynchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket
from pymongo.errors import PyMongoError

_IS_SYNC = False

# Mock an system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError"],
},
}


class TestBackpressure(AsyncIntegrationTest):
    """Integration tests for retry-on-overload (backpressure) behavior."""

    RUN_ON_LOAD_BALANCER = True

    @async_client_context.require_failCommand_appName
    async def test_retry_overload_error_command(self):
        """A raw command is retried up to _MAX_RETRIES on overload errors."""
        await self.db.t.insert_one({"x": 1})

        # Ensure command is retried on overload error.
        fail_many = mock_overload_error.copy()
        fail_many["mode"] = {"times": _MAX_RETRIES}
        async with self.fail_point(fail_many):
            await self.db.command("find", "t")

        # Ensure command stops retrying after _MAX_RETRIES.
        fail_too_many = mock_overload_error.copy()
        fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
        async with self.fail_point(fail_too_many):
            with self.assertRaises(PyMongoError) as error:
                await self.db.command("find", "t")

        self.assertIn("RetryableError", str(error.exception))

    @async_client_context.require_failCommand_appName
    async def test_retry_overload_error_find(self):
        """find_one is retried up to _MAX_RETRIES on overload errors."""
        await self.db.t.insert_one({"x": 1})

        # Ensure command is retried on overload error.
        fail_many = mock_overload_error.copy()
        fail_many["mode"] = {"times": _MAX_RETRIES}
        async with self.fail_point(fail_many):
            await self.db.t.find_one()

        # Ensure command stops retrying after _MAX_RETRIES.
        fail_too_many = mock_overload_error.copy()
        fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
        async with self.fail_point(fail_too_many):
            with self.assertRaises(PyMongoError) as error:
                await self.db.t.find_one()

        self.assertIn("RetryableError", str(error.exception))

    @async_client_context.require_failCommand_appName
    async def test_retry_overload_error_insert_one(self):
        """insert_one is retried up to _MAX_RETRIES on overload errors.

        Note: the original version of this test was a copy of the find_one
        test and never exercised insert; it now actually calls insert_one
        (the fail point's failCommands list already includes "insert").
        """
        await self.db.t.insert_one({"x": 1})

        # Ensure command is retried on overload error.
        fail_many = mock_overload_error.copy()
        fail_many["mode"] = {"times": _MAX_RETRIES}
        async with self.fail_point(fail_many):
            await self.db.t.insert_one({"x": 2})

        # Ensure command stops retrying after _MAX_RETRIES.
        fail_too_many = mock_overload_error.copy()
        fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
        async with self.fail_point(fail_too_many):
            with self.assertRaises(PyMongoError) as error:
                await self.db.t.insert_one({"x": 3})

        self.assertIn("RetryableError", str(error.exception))

    @async_client_context.require_failCommand_appName
    async def test_retry_overload_error_update_many(self):
        """update_many is retried via the "RetryableError" label.

        Even though update_many is not a retryable write operation, it will
        still be retried via the "RetryableError" error label.
        """
        await self.db.t.insert_one({"x": 1})

        # Ensure command is retried on overload error.
        fail_many = mock_overload_error.copy()
        fail_many["mode"] = {"times": _MAX_RETRIES}
        async with self.fail_point(fail_many):
            await self.db.t.update_many({}, {"$set": {"x": 2}})

        # Ensure command stops retrying after _MAX_RETRIES.
        fail_too_many = mock_overload_error.copy()
        fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
        async with self.fail_point(fail_too_many):
            with self.assertRaises(PyMongoError) as error:
                await self.db.t.update_many({}, {"$set": {"x": 2}})

        self.assertIn("RetryableError", str(error.exception))

    @async_client_context.require_failCommand_appName
    async def test_retry_overload_error_getMore(self):
        """getMore is retried up to _MAX_RETRIES on overload errors."""
        coll = self.db.t
        await coll.insert_many([{"x": 1} for _ in range(10)])

        # Ensure command is retried on overload error.
        fail_many = {
            "configureFailPoint": "failCommand",
            "mode": {"times": _MAX_RETRIES},
            "data": {
                "failCommands": ["getMore"],
                "errorCode": 462,  # IngressRequestRateLimitExceeded
                "errorLabels": ["RetryableError"],
            },
        }
        cursor = coll.find(batch_size=2)
        await cursor.next()
        async with self.fail_point(fail_many):
            await cursor.to_list()

        # Ensure command stops retrying after _MAX_RETRIES.
        fail_too_many = fail_many.copy()
        fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
        cursor = coll.find(batch_size=2)
        await cursor.next()
        async with self.fail_point(fail_too_many):
            with self.assertRaises(PyMongoError) as error:
                await cursor.to_list()

        self.assertIn("RetryableError", str(error.exception))

    @async_client_context.require_failCommand_appName
    async def test_limit_retry_command(self):
        """Retries stop once the client's token bucket is exhausted."""
        client = await self.async_rs_or_single_client()
        client._retry_policy.token_bucket.tokens = 1
        db = client.pymongo_test
        await db.t.insert_one({"x": 1})

        # Ensure command is retried once on overload error.
        fail_many = mock_overload_error.copy()
        fail_many["mode"] = {"times": 1}
        async with self.fail_point(fail_many):
            await db.command("find", "t")

        # Ensure command stops retrying when there are no tokens left.
        fail_too_many = mock_overload_error.copy()
        fail_too_many["mode"] = {"times": 2}
        async with self.fail_point(fail_too_many):
            with self.assertRaises(PyMongoError) as error:
                await db.command("find", "t")

        self.assertIn("RetryableError", str(error.exception))


class TestRetryPolicy(AsyncPyMongoTestCase):
    """Unit tests for _RetryPolicy and its token-bucket accounting."""

    async def test_retry_policy(self):
        # _MAX_RETRIES now lives in pymongo.common (it was removed from
        # helpers in this change), so reference it from its new home.
        from pymongo import common

        capacity = 10
        retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity))
        self.assertEqual(retry_policy.attempts, common._MAX_RETRIES)
        self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL)
        self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX)
        # Attempts within the limit are allowed; one past the limit is not.
        for i in range(1, common._MAX_RETRIES + 1):
            self.assertTrue(await retry_policy.should_retry(i, 0))
        self.assertFalse(await retry_policy.should_retry(common._MAX_RETRIES + 1, 0))
        # Drain the remaining tokens one first-attempt retry at a time.
        for i in range(capacity - common._MAX_RETRIES):
            self.assertTrue(await retry_policy.should_retry(1, 0))
        # No tokens left, should not retry.
        self.assertFalse(await retry_policy.should_retry(1, 0))
        self.assertEqual(retry_policy.token_bucket.tokens, 0)

        # record_success should generate tokens.
        for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)):
            await retry_policy.record_success(retry=False)
        self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2)
        for i in range(2):
            self.assertTrue(await retry_policy.should_retry(1, 0))
        self.assertFalse(await retry_policy.should_retry(1, 0))

        # Recording a successful retry should return 1 additional token.
        await retry_policy.record_success(retry=True)
        self.assertAlmostEqual(
            retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN
        )
        self.assertTrue(await retry_policy.should_retry(1, 0))
        self.assertFalse(await retry_policy.should_retry(1, 0))
        self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN)

    async def test_retry_policy_csot(self):
        """should_retry honors the remaining CSOT budget."""
        retry_policy = _RetryPolicy(_TokenBucket())
        self.assertTrue(await retry_policy.should_retry(1, 0.5))
        with pymongo.timeout(0.5):
            self.assertTrue(await retry_policy.should_retry(1, 0))
            self.assertTrue(await retry_policy.should_retry(1, 0.1))
            # Would exceed the timeout, should not retry.
            self.assertFalse(await retry_policy.should_retry(1, 1.0))
        self.assertTrue(await retry_policy.should_retry(1, 1.0))


if __name__ == "__main__":
unittest.main()
Loading
Loading