From fab43f49c0896c62b60fa3139a2215d8865e67d6 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 09:01:10 -0500 Subject: [PATCH 1/7] PYTHON-5672 Refactor connection checkout to reduce layers Replace the three @asynccontextmanager layers on the connection checkout hot path with class-based async context managers, and eliminate _MongoClientErrorHandler by absorbing it into _ClientCheckout. - _PoolCheckout replaces Pool.checkout() generator CM - _ClientCheckout replaces _checkout() generator CM and absorbs all of _MongoClientErrorHandler (contribute_socket, handle, SDAM error logic) - _ClientReadCheckout extends _ClientCheckout to apply single-topology read preference adjustment (formerly _conn_from_server()) - active_contexts.add() consolidated into _get_conn(), avoiding a separate lock acquisition on the hot path; deduped so new connections (already tracked by connect()) are not double-added - Connection leak fixed: self._conn assigned before event publishing so checkin runs if a CMAP listener raises in __aenter__ - _ClientCheckout.for_existing_conn() classmethod handles the _run_operation() getMore path that needs SDAM handling around an already-checked-out connection --- pymongo/asynchronous/mongo_client.py | 185 ++++++++++++-------- pymongo/asynchronous/pool.py | 188 ++++++++++++--------- pymongo/asynchronous/server.py | 11 +- pymongo/synchronous/mongo_client.py | 185 ++++++++++++-------- pymongo/synchronous/pool.py | 188 ++++++++++++--------- pymongo/synchronous/server.py | 11 +- test/asynchronous/pymongo_mocks.py | 2 +- test/asynchronous/test_read_preferences.py | 2 +- test/pymongo_mocks.py | 2 +- test/test_read_preferences.py | 2 +- 10 files changed, 469 insertions(+), 307 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 6aeea53f4c..84595c5105 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -41,7 +41,6 @@ import weakref from collections import defaultdict from collections.abc import AsyncGenerator, Collection, Coroutine, Mapping, MutableMapping, Sequence -from contextlib import AbstractAsyncContextManager from typing import ( TYPE_CHECKING, Any, @@ -65,6 +64,7 @@ from pymongo.asynchronous.helpers import ( _RetryPolicy, ) +from pymongo.asynchronous.pool import _PoolCheckout from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -1780,39 +1780,8 @@ async def _get_topology(self) -> Topology: self._opened = True return self._topology - @contextlib.asynccontextmanager - async def _checkout( - self, server: Server, session: Optional[AsyncClientSession] - ) -> AsyncGenerator[AsyncConnection, None]: - in_txn = session and session.in_transaction - async with _MongoClientErrorHandler(self, server, session) as err_handler: - # Reuse the pinned connection, if it exists. - if in_txn and session and session._pinned_connection: - err_handler.contribute_socket(session._pinned_connection) - yield session._pinned_connection - return - async with await server.checkout(handler=err_handler) as conn: - # Pin this session to the selected server or connection. - if ( - in_txn - and session - and server.description.server_type - in ( - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - ): - session._pin(server, conn) - err_handler.contribute_socket(conn) - if ( - self._encrypter - and not self._encrypter._bypass_auto_encryption - and conn.max_wire_version < 8 - ): - raise ConfigurationError( - "Auto-encryption requires a minimum MongoDB version of 4.2" - ) - yield conn + def _checkout(self, server: Server, session: Optional[AsyncClientSession]) -> _ClientCheckout: + return _ClientCheckout(self, server, session) async def _select_server( self, @@ -1863,41 +1832,22 @@ async def _select_server( async def _conn_for_writes( self, session: Optional[AsyncClientSession], operation: str - ) -> AbstractAsyncContextManager[AsyncConnection]: + ) -> _ClientCheckout: server = await self._select_server(writable_server_selector, session, operation) return self._checkout(server, session) - @contextlib.asynccontextmanager - async def _conn_from_server( + def _conn_from_server( self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession] - ) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" - # Get a connection for a server matching the read preference, and yield - # conn with the effective read preference. The Server Selection - # Spec says not to send any $readPreference to standalones and to - # always send primaryPreferred when directly connected to a repl set - # member. - # Thread safe: if the type is single it cannot change. - # NOTE: We already opened the Topology when selecting a server so there's no need - # to call _get_topology() again. - single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single - async with self._checkout(server, session) as conn: - if single: - if conn.is_repl and not (session and session.in_transaction): - # Use primary preferred to ensure any repl set member - # can handle the request. - read_preference = ReadPreference.PRIMARY_PREFERRED - elif conn.is_standalone: - # Don't send read preference to standalones. - read_preference = ReadPreference.PRIMARY - yield conn, read_preference + return _ClientReadCheckout(self, server, session, read_preference) async def _conn_for_reads( self, read_preference: _ServerMode, session: Optional[AsyncClientSession], operation: str, - ) -> AbstractAsyncContextManager[tuple[AsyncConnection, _ServerMode]]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" server = await self._select_server(read_preference, session, operation) return self._conn_from_server(read_preference, server, session) @@ -1925,8 +1875,12 @@ async def _run_operation( ) async with operation.conn_mgr._lock: - async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] - err_handler.contribute_socket(operation.conn_mgr.conn) + async with _ClientCheckout.for_existing_conn( + self, + server, + operation.session, # type: ignore[arg-type] + operation.conn_mgr.conn, + ): return await server.run_operation( operation.conn_mgr.conn, operation, @@ -2667,10 +2621,18 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong exc_to_check._add_error_label("RetryableWriteError") -class _MongoClientErrorHandler: - """Handle errors raised when executing an operation.""" +class _ClientCheckout: + """Context manager for checking out a connection from the pool. + + Absorbs the former _MongoClientErrorHandler and the @asynccontextmanager + _checkout() method into a single class-based CM to eliminate generator + overhead on the hot path. + """ __slots__ = ( + "_existing_conn", + "_pool_checkout", + "_server", "client", "completed_handshake", "handled", @@ -2704,6 +2666,9 @@ def __init__( self.completed_handshake = False self.service_id: Optional[ObjectId] = None self.handled = False + self._existing_conn: Optional[AsyncConnection] = None + self._pool_checkout: Optional[_PoolCheckout] = None + self._server = server def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None: """Provide socket information to the error handler.""" @@ -2741,16 +2706,102 @@ async def handle( assert self.client._topology is not None await self.client._topology.handle_error(self.server_address, err_ctx) - async def __aenter__(self) -> _MongoClientErrorHandler: - return self + async def __aenter__(self) -> AsyncConnection: + if self._existing_conn is not None: + return self._existing_conn + server = self._server + session = self.session + in_txn = session and session.in_transaction + # Reuse the pinned connection, if it exists. + if in_txn and session and session._pinned_connection: + self.contribute_socket(session._pinned_connection) + return session._pinned_connection + pool_checkout = _PoolCheckout(server.pool, self) + conn = await pool_checkout.__aenter__() + self._pool_checkout = pool_checkout + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + self.contribute_socket(conn) + if ( + self.client._encrypter + and not self.client._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError("Auto-encryption requires a minimum MongoDB version of 4.2") + return conn async def __aexit__( self, - exc_type: Optional[type[Exception]], - exc_val: Optional[Exception], + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - return await self.handle(exc_type, exc_val) + # Perform SDAM error handling while the connection is still checked out. + await self.handle(exc_type, exc_val) + if self._pool_checkout is not None: + await self._pool_checkout.__aexit__(exc_type, exc_val, exc_tb) + + @classmethod + def for_existing_conn( + cls, + client: AsyncMongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[AsyncClientSession], + conn: AsyncConnection, + ) -> _ClientCheckout: + """Return a _ClientCheckout for an already-checked-out connection. + + Used when SDAM error handling is needed around an existing connection + without performing a new pool checkout (e.g. re-running a getMore). + """ + checkout = cls(client, server, session) + checkout.contribute_socket(conn) + checkout._existing_conn = conn + return checkout + + +class _ClientReadCheckout(_ClientCheckout): + """Context manager for read operations. + + Extends _ClientCheckout to apply the single-topology read preference + adjustment (formerly in _conn_from_server()) and return the effective + read preference alongside the connection. + """ + + __slots__ = ("_effective_read_pref",) + + def __init__( + self, + client: AsyncMongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[AsyncClientSession], + read_preference: _ServerMode, + ) -> None: + super().__init__(client, server, session) + self._effective_read_pref: _ServerMode = read_preference + + async def __aenter__(self) -> tuple[AsyncConnection, _ServerMode]: # type: ignore[override] + conn = await super().__aenter__() + # The Server Selection Spec says not to send any $readPreference to + # standalones and to always send primaryPreferred when directly + # connected to a replica set member. + # Thread safe: topology type cannot change once set to Single. + single = self.client._topology.description.topology_type == TOPOLOGY_TYPE.Single + if single: + if conn.is_repl and not (self.session and self.session.in_transaction): + self._effective_read_pref = ReadPreference.PRIMARY_PREFERRED + elif conn.is_standalone: + self._effective_read_pref = ReadPreference.PRIMARY + return conn, self._effective_read_pref class _ClientConnectionRetryable(Generic[T]): diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 60acb93fcd..aeda81c12e 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -16,15 +16,13 @@ import asyncio import collections -import contextlib import logging import os import socket import ssl -import sys import time import weakref -from collections.abc import AsyncGenerator, Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -89,11 +87,13 @@ from pymongo.socket_checker import SocketChecker if TYPE_CHECKING: + from types import TracebackType + from bson import CodecOptions from bson.objectid import ObjectId from pymongo.asynchronous.auth import _AuthContext from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler + from pymongo.asynchronous.mongo_client import AsyncMongoClient, _ClientCheckout from pymongo.compression_support import ( SnappyContext, ZlibContext, @@ -745,7 +745,7 @@ def __init__( # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: set[AsyncConnection] = set() + self._pinned_sockets: set[AsyncConnection] = set() self.ncursors = 0 self.ntxns = 0 @@ -981,7 +981,7 @@ def _handle_connection_error(self, error: BaseException) -> None: error._add_error_label("SystemOverloadedError") error._add_error_label("RetryableError") - async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection: + async def connect(self, handler: Optional[_ClientCheckout] = None) -> AsyncConnection: """Connect to Mongo and return a new AsyncConnection. Can raise ConnectionFailure. @@ -1077,84 +1077,23 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A return conn - @contextlib.asynccontextmanager - async def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AsyncGenerator[AsyncConnection, None]: - """Get a connection from the pool. Use with a "with" statement. + def checkout(self, handler: Optional[_ClientCheckout] = None) -> _PoolCheckout: + """Get a connection from the pool. Use with an "async with" statement. - Returns a :class:`AsyncConnection` object wrapping a connected - :class:`socket.socket`. + Returns a :class:`_PoolCheckout` context manager that yields a + :class:`AsyncConnection` object wrapping a connected socket. - This method should always be used in a with-statement:: + This method should always be used in an async-with-statement:: - with pool.get_conn() as connection: + async with pool.checkout() as connection: connection.send_message(msg) data = connection.receive_message(op_code, request_id) Can raise ConnectionFailure or OperationFailure. - :param handler: A _MongoClientErrorHandler. + :param handler: A _ClientCheckout error handler. """ - listeners = self.opts._event_listeners - checkout_started_time = time.monotonic() - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_check_out_started(self.address) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_STARTED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - ) - - conn = await self._get_conn(checkout_started_time, handler=handler) - - duration = time.monotonic() - checkout_started_time - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_checked_out(self.address, conn.id, duration) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - driverConnectionId=conn.id, - durationMS=duration, - ) - try: - async with self.lock: - self.active_contexts.add(conn.cancel_context) - yield conn - # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException: - # Exception in caller. Ensure the connection gets returned. - # Note that when pinned is True, the session owns the - # connection and it is responsible for checking the connection - # back into the pool. - pinned = conn.pinned_txn or conn.pinned_cursor - if handler: - # Perform SDAM error handling rules while the connection is - # still checked out. - exc_type, exc_val, _ = sys.exc_info() - await handler.handle(exc_type, exc_val) - if not pinned and conn.active: - await self.checkin(conn) - raise - if conn.pinned_txn: - async with self.lock: - self.__pinned_sockets.add(conn) - self.ntxns += 1 - elif conn.pinned_cursor: - async with self.lock: - self.__pinned_sockets.add(conn) - self.ncursors += 1 - elif conn.active: - await self.checkin(conn) + return _PoolCheckout(self, handler) def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> None: if self.state != PoolState.READY: @@ -1183,7 +1122,7 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> ) async def _get_conn( - self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None + self, checkout_started_time: float, handler: Optional[_ClientCheckout] = None ) -> AsyncConnection: """Get or create a AsyncConnection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. @@ -1242,6 +1181,7 @@ async def _get_conn( conn = None incremented = False emitted_event = False + is_new_conn = False try: async with self.lock: self.active_sockets += 1 @@ -1273,6 +1213,7 @@ async def _get_conn( else: # We need to create a new connection try: conn = await self.connect(handler=handler) + is_new_conn = True finally: async with self._max_connecting_cond: self._pending -= 1 @@ -1309,6 +1250,11 @@ async def _get_conn( raise conn.active = True + # connect() already adds cancel_context for new connections; only add + # here for reused connections taken from the idle pool. + if not is_new_conn: + async with self.lock: + self.active_contexts.add(conn.cancel_context) return conn async def checkin(self, conn: AsyncConnection) -> None: @@ -1321,7 +1267,7 @@ async def checkin(self, conn: AsyncConnection) -> None: conn.active = False conn.pinned_txn = False conn.pinned_cursor = False - self.__pinned_sockets.discard(conn) + self._pinned_sockets.discard(conn) listeners = self.opts._event_listeners async with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1463,3 +1409,91 @@ def __del__(self) -> None: if _IS_SYNC: for conn in self.conns: conn.close_conn(None) # type: ignore[unused-coroutine] + + +class _PoolCheckout: + """Class-based context manager for pool connection checkout.""" + + __slots__ = ("_checkout_started_time", "_conn", "_handler", "_pool") + + def __init__( + self, + pool: Pool, + handler: Optional[_ClientCheckout] = None, + ) -> None: + self._pool = pool + self._handler = handler + self._conn: Optional[AsyncConnection] = None + self._checkout_started_time: float = 0.0 + + async def __aenter__(self) -> AsyncConnection: + pool = self._pool + self._checkout_started_time = time.monotonic() + checkout_started_time = self._checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_check_out_started(pool.address) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_STARTED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + ) + + conn = await pool._get_conn(checkout_started_time, handler=self._handler) + self._conn = conn + try: + duration = time.monotonic() - checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_checked_out( + pool.address, conn.id, duration + ) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + driverConnectionId=conn.id, + durationMS=duration, + ) + except BaseException: + await pool.checkin(conn) + self._conn = None + raise + return conn + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + conn = self._conn + if conn is None: + return + pool = self._pool + if exc_type is not None: + # Exception in caller. Ensure the connection gets returned. + # Note that when pinned is True, the session owns the connection + # and is responsible for checking it back into the pool. + # SDAM error handling is performed by _ClientCheckout.__aexit__ + # before this method is called. + pinned = conn.pinned_txn or conn.pinned_cursor + if not pinned and conn.active: + await pool.checkin(conn) + else: + if conn.pinned_txn: + async with pool.lock: + pool._pinned_sockets.add(conn) + pool.ntxns += 1 + elif conn.pinned_cursor: + async with pool.lock: + pool._pinned_sockets.add(conn) + pool.ncursors += 1 + elif conn.active: + await pool.checkin(conn) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 9a6984f486..9d78a4cba2 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -17,7 +17,6 @@ from __future__ import annotations import logging -from contextlib import AbstractAsyncContextManager from datetime import datetime from typing import ( TYPE_CHECKING, @@ -42,9 +41,9 @@ from weakref import ReferenceType from bson.objectid import ObjectId - from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler + from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.monitor import Monitor - from pymongo.asynchronous.pool import AsyncConnection, Pool + from pymongo.asynchronous.pool import AsyncConnection, Pool, _PoolCheckout from pymongo.monitoring import _EventListeners from pymongo.read_preferences import _ServerMode from pymongo.server_description import ServerDescription @@ -227,10 +226,8 @@ async def run_operation( return response - async def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AbstractAsyncContextManager[AsyncConnection]: - return self.pool.checkout(handler) + def checkout(self) -> _PoolCheckout: + return self.pool.checkout() @property def description(self) -> ServerDescription: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 6b7c5d9c98..6ad564902a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -41,7 +41,6 @@ import weakref from collections import defaultdict from collections.abc import Collection, Generator, Mapping, MutableMapping, Sequence -from contextlib import AbstractContextManager from typing import ( TYPE_CHECKING, Any, @@ -109,6 +108,7 @@ from pymongo.synchronous.helpers import ( _RetryPolicy, ) +from pymongo.synchronous.pool import _PoolCheckout from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -1777,39 +1777,8 @@ def _get_topology(self) -> Topology: self._opened = True return self._topology - @contextlib.contextmanager - def _checkout( - self, server: Server, session: Optional[ClientSession] - ) -> Generator[Connection, None]: - in_txn = session and session.in_transaction - with _MongoClientErrorHandler(self, server, session) as err_handler: - # Reuse the pinned connection, if it exists. - if in_txn and session and session._pinned_connection: - err_handler.contribute_socket(session._pinned_connection) - yield session._pinned_connection - return - with server.checkout(handler=err_handler) as conn: - # Pin this session to the selected server or connection. - if ( - in_txn - and session - and server.description.server_type - in ( - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - ): - session._pin(server, conn) - err_handler.contribute_socket(conn) - if ( - self._encrypter - and not self._encrypter._bypass_auto_encryption - and conn.max_wire_version < 8 - ): - raise ConfigurationError( - "Auto-encryption requires a minimum MongoDB version of 4.2" - ) - yield conn + def _checkout(self, server: Server, session: Optional[ClientSession]) -> _ClientCheckout: + return _ClientCheckout(self, server, session) def _select_server( self, @@ -1858,43 +1827,22 @@ def _select_server( session._unpin() raise - def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> AbstractContextManager[Connection]: + def _conn_for_writes(self, session: Optional[ClientSession], operation: str) -> _ClientCheckout: server = self._select_server(writable_server_selector, session, operation) return self._checkout(server, session) - @contextlib.contextmanager def _conn_from_server( self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] - ) -> Generator[tuple[Connection, _ServerMode], None]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" - # Get a connection for a server matching the read preference, and yield - # conn with the effective read preference. The Server Selection - # Spec says not to send any $readPreference to standalones and to - # always send primaryPreferred when directly connected to a repl set - # member. - # Thread safe: if the type is single it cannot change. - # NOTE: We already opened the Topology when selecting a server so there's no need - # to call _get_topology() again. - single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._checkout(server, session) as conn: - if single: - if conn.is_repl and not (session and session.in_transaction): - # Use primary preferred to ensure any repl set member - # can handle the request. - read_preference = ReadPreference.PRIMARY_PREFERRED - elif conn.is_standalone: - # Don't send read preference to standalones. - read_preference = ReadPreference.PRIMARY - yield conn, read_preference + return _ClientReadCheckout(self, server, session, read_preference) def _conn_for_reads( self, read_preference: _ServerMode, session: Optional[ClientSession], operation: str, - ) -> AbstractContextManager[tuple[Connection, _ServerMode]]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" server = self._select_server(read_preference, session, operation) return self._conn_from_server(read_preference, server, session) @@ -1922,8 +1870,12 @@ def _run_operation( ) with operation.conn_mgr._lock: - with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] - err_handler.contribute_socket(operation.conn_mgr.conn) + with _ClientCheckout.for_existing_conn( + self, + server, + operation.session, # type: ignore[arg-type] + operation.conn_mgr.conn, + ): return server.run_operation( operation.conn_mgr.conn, operation, @@ -2658,10 +2610,18 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong exc_to_check._add_error_label("RetryableWriteError") -class _MongoClientErrorHandler: - """Handle errors raised when executing an operation.""" +class _ClientCheckout: + """Context manager for checking out a connection from the pool. + + Absorbs the former _MongoClientErrorHandler and the @asynccontextmanager + _checkout() method into a single class-based CM to eliminate generator + overhead on the hot path. + """ __slots__ = ( + "_existing_conn", + "_pool_checkout", + "_server", "client", "completed_handshake", "handled", @@ -2695,6 +2655,9 @@ def __init__( self.completed_handshake = False self.service_id: Optional[ObjectId] = None self.handled = False + self._existing_conn: Optional[Connection] = None + self._pool_checkout: Optional[_PoolCheckout] = None + self._server = server def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None: """Provide socket information to the error handler.""" @@ -2732,16 +2695,102 @@ def handle( assert self.client._topology is not None self.client._topology.handle_error(self.server_address, err_ctx) - def __enter__(self) -> _MongoClientErrorHandler: - return self + def __enter__(self) -> Connection: + if self._existing_conn is not None: + return self._existing_conn + server = self._server + session = self.session + in_txn = session and session.in_transaction + # Reuse the pinned connection, if it exists. + if in_txn and session and session._pinned_connection: + self.contribute_socket(session._pinned_connection) + return session._pinned_connection + pool_checkout = _PoolCheckout(server.pool, self) + conn = pool_checkout.__enter__() + self._pool_checkout = pool_checkout + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + self.contribute_socket(conn) + if ( + self.client._encrypter + and not self.client._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError("Auto-encryption requires a minimum MongoDB version of 4.2") + return conn def __exit__( self, - exc_type: Optional[type[Exception]], - exc_val: Optional[Exception], + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - return self.handle(exc_type, exc_val) + # Perform SDAM error handling while the connection is still checked out. + self.handle(exc_type, exc_val) + if self._pool_checkout is not None: + self._pool_checkout.__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def for_existing_conn( + cls, + client: MongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[ClientSession], + conn: Connection, + ) -> _ClientCheckout: + """Return a _ClientCheckout for an already-checked-out connection. + + Used when SDAM error handling is needed around an existing connection + without performing a new pool checkout (e.g. re-running a getMore). + """ + checkout = cls(client, server, session) + checkout.contribute_socket(conn) + checkout._existing_conn = conn + return checkout + + +class _ClientReadCheckout(_ClientCheckout): + """Context manager for read operations. + + Extends _ClientCheckout to apply the single-topology read preference + adjustment (formerly in _conn_from_server()) and return the effective + read preference alongside the connection. + """ + + __slots__ = ("_effective_read_pref",) + + def __init__( + self, + client: MongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[ClientSession], + read_preference: _ServerMode, + ) -> None: + super().__init__(client, server, session) + self._effective_read_pref: _ServerMode = read_preference + + def __enter__(self) -> tuple[Connection, _ServerMode]: # type: ignore[override] + conn = super().__enter__() + # The Server Selection Spec says not to send any $readPreference to + # standalones and to always send primaryPreferred when directly + # connected to a replica set member. + # Thread safe: topology type cannot change once set to Single. + single = self.client._topology.description.topology_type == TOPOLOGY_TYPE.Single + if single: + if conn.is_repl and not (self.session and self.session.in_transaction): + self._effective_read_pref = ReadPreference.PRIMARY_PREFERRED + elif conn.is_standalone: + self._effective_read_pref = ReadPreference.PRIMARY + return conn, self._effective_read_pref class _ClientConnectionRetryable(Generic[T]): diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index b3929b674a..84d09742a1 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -16,15 +16,13 @@ import asyncio import collections -import contextlib import logging import os import socket import ssl -import sys import time import weakref -from collections.abc import Generator, Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -89,6 +87,8 @@ from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: + from types import TracebackType + from bson import CodecOptions from bson.objectid import ObjectId from pymongo.compression_support import ( @@ -101,7 +101,7 @@ from pymongo.read_preferences import _ServerMode from pymongo.synchronous.auth import _AuthContext from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.synchronous.mongo_client import MongoClient, _ClientCheckout from pymongo.typings import _Address, _CollationIn from pymongo.write_concern import WriteConcern @@ -743,7 +743,7 @@ def __init__( # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: set[Connection] = set() + self._pinned_sockets: set[Connection] = set() self.ncursors = 0 self.ntxns = 0 @@ -977,7 +977,7 @@ def _handle_connection_error(self, error: BaseException) -> None: error._add_error_label("SystemOverloadedError") error._add_error_label("RetryableError") - def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: + def connect(self, handler: Optional[_ClientCheckout] = None) -> Connection: """Connect to Mongo and return a new Connection. Can raise ConnectionFailure. @@ -1073,84 +1073,23 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect return conn - @contextlib.contextmanager - def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> Generator[Connection, None]: - """Get a connection from the pool. Use with a "with" statement. + def checkout(self, handler: Optional[_ClientCheckout] = None) -> _PoolCheckout: + """Get a connection from the pool. Use with an "with" statement. - Returns a :class:`Connection` object wrapping a connected - :class:`socket.socket`. + Returns a :class:`_PoolCheckout` context manager that yields a + :class:`Connection` object wrapping a connected socket. - This method should always be used in a with-statement:: + This method should always be used in an async-with-statement:: - with pool.get_conn() as connection: + with pool.checkout() as connection: connection.send_message(msg) data = connection.receive_message(op_code, request_id) Can raise ConnectionFailure or OperationFailure. - :param handler: A _MongoClientErrorHandler. + :param handler: A _ClientCheckout error handler. """ - listeners = self.opts._event_listeners - checkout_started_time = time.monotonic() - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_check_out_started(self.address) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_STARTED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - ) - - conn = self._get_conn(checkout_started_time, handler=handler) - - duration = time.monotonic() - checkout_started_time - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_checked_out(self.address, conn.id, duration) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - driverConnectionId=conn.id, - durationMS=duration, - ) - try: - with self.lock: - self.active_contexts.add(conn.cancel_context) - yield conn - # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException: - # Exception in caller. Ensure the connection gets returned. - # Note that when pinned is True, the session owns the - # connection and it is responsible for checking the connection - # back into the pool. - pinned = conn.pinned_txn or conn.pinned_cursor - if handler: - # Perform SDAM error handling rules while the connection is - # still checked out. - exc_type, exc_val, _ = sys.exc_info() - handler.handle(exc_type, exc_val) - if not pinned and conn.active: - self.checkin(conn) - raise - if conn.pinned_txn: - with self.lock: - self.__pinned_sockets.add(conn) - self.ntxns += 1 - elif conn.pinned_cursor: - with self.lock: - self.__pinned_sockets.add(conn) - self.ncursors += 1 - elif conn.active: - self.checkin(conn) + return _PoolCheckout(self, handler) def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> None: if self.state != PoolState.READY: @@ -1179,7 +1118,7 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> ) def _get_conn( - self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None + self, checkout_started_time: float, handler: Optional[_ClientCheckout] = None ) -> Connection: """Get or create a Connection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. @@ -1238,6 +1177,7 @@ def _get_conn( conn = None incremented = False emitted_event = False + is_new_conn = False try: with self.lock: self.active_sockets += 1 @@ -1269,6 +1209,7 @@ def _get_conn( else: # We need to create a new connection try: conn = self.connect(handler=handler) + is_new_conn = True finally: with self._max_connecting_cond: self._pending -= 1 @@ -1305,6 +1246,11 @@ def _get_conn( raise conn.active = True + # connect() already adds cancel_context for new connections; only add + # here for reused connections taken from the idle pool. + if not is_new_conn: + with self.lock: + self.active_contexts.add(conn.cancel_context) return conn def checkin(self, conn: Connection) -> None: @@ -1317,7 +1263,7 @@ def checkin(self, conn: Connection) -> None: conn.active = False conn.pinned_txn = False conn.pinned_cursor = False - self.__pinned_sockets.discard(conn) + self._pinned_sockets.discard(conn) listeners = self.opts._event_listeners with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1459,3 +1405,91 @@ def __del__(self) -> None: if _IS_SYNC: for conn in self.conns: conn.close_conn(None) # type: ignore[unused-coroutine] + + +class _PoolCheckout: + """Class-based context manager for pool connection checkout.""" + + __slots__ = ("_checkout_started_time", "_conn", "_handler", "_pool") + + def __init__( + self, + pool: Pool, + handler: Optional[_ClientCheckout] = None, + ) -> None: + self._pool = pool + self._handler = handler + self._conn: Optional[Connection] = None + self._checkout_started_time: float = 0.0 + + def __enter__(self) -> Connection: + pool = self._pool + self._checkout_started_time = time.monotonic() + checkout_started_time = self._checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_check_out_started(pool.address) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_STARTED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + ) + + conn = pool._get_conn(checkout_started_time, handler=self._handler) + self._conn = conn + try: + duration = time.monotonic() - checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_checked_out( + pool.address, conn.id, duration + ) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + driverConnectionId=conn.id, + durationMS=duration, + ) + except BaseException: + pool.checkin(conn) + self._conn = None + raise + return conn + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + conn = self._conn + if conn is None: + return + pool = self._pool + if exc_type is not None: + # Exception in caller. Ensure the connection gets returned. + # Note that when pinned is True, the session owns the connection + # and is responsible for checking it back into the pool. + # SDAM error handling is performed by _ClientCheckout.__aexit__ + # before this method is called. + pinned = conn.pinned_txn or conn.pinned_cursor + if not pinned and conn.active: + pool.checkin(conn) + else: + if conn.pinned_txn: + with pool.lock: + pool._pinned_sockets.add(conn) + pool.ntxns += 1 + elif conn.pinned_cursor: + with pool.lock: + pool._pinned_sockets.add(conn) + pool.ncursors += 1 + elif conn.active: + pool.checkin(conn) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 7aa017134a..dfe28f3490 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -17,7 +17,6 @@ from __future__ import annotations import logging -from contextlib import AbstractContextManager from datetime import datetime from typing import ( TYPE_CHECKING, @@ -45,9 +44,9 @@ from pymongo.monitoring import _EventListeners from pymongo.read_preferences import _ServerMode from pymongo.server_description import ServerDescription - from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.monitor import Monitor - from pymongo.synchronous.pool import Connection, Pool + from pymongo.synchronous.pool import Connection, Pool, _PoolCheckout from pymongo.typings import _DocumentOut _IS_SYNC = True @@ -227,10 +226,8 @@ def run_operation( return response - def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AbstractContextManager[Connection]: - return self.pool.checkout(handler) + def checkout(self) -> _PoolCheckout: + return self.pool.checkout() @property def description(self) -> ServerDescription: diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index 4413cbe43b..e71be6a502 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -43,7 +43,7 @@ def __init__(self, client, pair, *args, **kwargs): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.asynccontextmanager - async def checkout(self, handler=None): + async def checkout(self, handler=None): # type: ignore[override] client = self.client host_and_port = f"{self.mock_host}:{self.mock_port}" if host_and_port in client.mock_down_hosts: diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py index 9f92a39920..7801926550 100644 --- a/test/asynchronous/test_read_preferences.py +++ b/test/asynchronous/test_read_preferences.py @@ -334,7 +334,7 @@ async def _conn_for_reads(self, read_preference, session, operation): return context @contextlib.asynccontextmanager - async def _conn_from_server(self, read_preference, server, session): + async def _conn_from_server(self, read_preference, server, session): # type: ignore[override] context = super()._conn_from_server(read_preference, server, session) async with context as (conn, read_preference): await self.record_a_read(conn.address) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 206c1a61b7..a3092bffe9 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -42,7 +42,7 @@ def __init__(self, client, pair, *args, **kwargs): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager - def checkout(self, handler=None): + def checkout(self, handler=None): # type: ignore[override] client = self.client host_and_port = f"{self.mock_host}:{self.mock_port}" if host_and_port in client.mock_down_hosts: diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 27c8d0704a..65c25735f5 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -314,7 +314,7 @@ def _conn_for_reads(self, read_preference, session, operation): return context @contextlib.contextmanager - def _conn_from_server(self, read_preference, server, session): + def _conn_from_server(self, read_preference, server, session): # type: ignore[override] context = super()._conn_from_server(read_preference, server, session) with context as (conn, read_preference): self.record_a_read(conn.address) From 9b09a4e781cfff36bee62363ba651252398992e5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 09:22:12 -0500 Subject: [PATCH 2/7] PYTHON-5672 Add __slots__ and lazy _deprioritized_servers to _ClientConnectionRetryable Eliminates per-instance __dict__ allocation and defers the _deprioritized_servers list creation until it is actually needed (only on sharded retry paths). --- pymongo/asynchronous/mongo_client.py | 28 +++++++++++++++++++++++++++- pymongo/synchronous/mongo_client.py | 28 +++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 84595c5105..d85ac7f46b 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2807,6 +2807,30 @@ async def __aenter__(self) -> tuple[AsyncConnection, _ServerMode]: # type: igno class _ClientConnectionRetryable(Generic[T]): """Responsible for executing retryable connections on read or write operations""" + __slots__ = ( + "_address", + "_always_retryable", + "_attempt_number", + "_bulk", + "_client", + "_deprioritized_servers", + "_func", + "_is_aggregate_write", + "_is_read", + "_is_run_command", + "_last_error", + "_max_retries", + "_operation", + "_operation_id", + "_read_pref", + "_retry_policy", + "_retryable", + "_retrying", + "_server", + "_server_selector", + "_session", + ) + def __init__( self, mongo_client: AsyncMongoClient, # type: ignore[type-arg] @@ -2839,7 +2863,7 @@ def __init__( ) self._address = address self._server: Server = None # type: ignore - self._deprioritized_servers: list[Server] = [] + self._deprioritized_servers: Optional[list[Server]] = None self._operation = operation self._operation_id = operation_id self._attempt_number = 0 @@ -2979,6 +3003,8 @@ async def run(self) -> T: self._client.topology_description.topology_type_name == "Sharded" or (overloaded and self._client.options.enable_overload_retargeting) ): + if self._deprioritized_servers is None: + self._deprioritized_servers = [] self._deprioritized_servers.append(self._server) self._always_retryable = always_retryable diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 6ad564902a..d4a156d143 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2796,6 +2796,30 @@ def __enter__(self) -> tuple[Connection, _ServerMode]: # type: ignore[override] class _ClientConnectionRetryable(Generic[T]): """Responsible for executing retryable connections on read or write operations""" + __slots__ = ( + "_address", + "_always_retryable", + "_attempt_number", + "_bulk", + "_client", + "_deprioritized_servers", + "_func", + "_is_aggregate_write", + "_is_read", + "_is_run_command", + "_last_error", + "_max_retries", + "_operation", + "_operation_id", + "_read_pref", + "_retry_policy", + "_retryable", + "_retrying", + "_server", + "_server_selector", + "_session", + ) + def __init__( self, mongo_client: MongoClient, # type: ignore[type-arg] @@ -2828,7 +2852,7 @@ def __init__( ) self._address = address self._server: Server = None # type: ignore - self._deprioritized_servers: list[Server] = [] + self._deprioritized_servers: Optional[list[Server]] = None self._operation = operation self._operation_id = operation_id self._attempt_number = 0 @@ -2968,6 +2992,8 @@ def run(self) -> T: self._client.topology_description.topology_type_name == "Sharded" or (overloaded and self._client.options.enable_overload_retargeting) ): + if self._deprioritized_servers is None: + self._deprioritized_servers = [] self._deprioritized_servers.append(self._server) self._always_retryable = always_retryable From 9e5dcb2d9f712fc29d150eaa501d95af8c70cac2 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 13:03:35 -0500 Subject: [PATCH 3/7] PYTHON-5672 Fix SDAM error handling when pool checkout raises in __aenter__ When _PoolCheckout.__aenter__() raises (e.g. connect() fails), Python does not call _ClientCheckout.__aexit__(), so handle() was never invoked and the topology never learned about the failure. This caused test_5_check_out_fails_ connection_error to fail (missing PoolClearedEvent) and broke failover tests. Fix: wrap both the pool checkout call and the post-checkout setup in try/except BaseException blocks inside _ClientCheckout.__aenter__(), calling handle() and (for post-checkout failures) checking the connection back in before re-raising. --- pymongo/asynchronous/mongo_client.py | 52 ++++++++++++++++++---------- pymongo/synchronous/mongo_client.py | 52 ++++++++++++++++++---------- 2 files changed, 66 insertions(+), 38 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index d85ac7f46b..e0b5b7fcad 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2717,26 +2717,40 @@ async def __aenter__(self) -> AsyncConnection: self.contribute_socket(session._pinned_connection) return session._pinned_connection pool_checkout = _PoolCheckout(server.pool, self) - conn = await pool_checkout.__aenter__() + try: + conn = await pool_checkout.__aenter__() + except BaseException as exc: + # __aenter__ raised — pool already cleaned up internally. + # Run SDAM error handling so the topology learns about the failure. + await self.handle(type(exc), exc) + raise self._pool_checkout = pool_checkout - # Pin this session to the selected server or connection. - if ( - in_txn - and session - and server.description.server_type - in ( - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - ): - session._pin(server, conn) - self.contribute_socket(conn) - if ( - self.client._encrypter - and not self.client._encrypter._bypass_auto_encryption - and conn.max_wire_version < 8 - ): - raise ConfigurationError("Auto-encryption requires a minimum MongoDB version of 4.2") + try: + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + self.contribute_socket(conn) + if ( + self.client._encrypter + and not self.client._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError( + "Auto-encryption requires a minimum MongoDB version of 4.2" + ) + except BaseException as exc: + await self.handle(type(exc), exc) + await pool_checkout.__aexit__(type(exc), exc, None) + self._pool_checkout = None + raise return conn async def __aexit__( diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index d4a156d143..74a588f8d1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2706,26 +2706,40 @@ def __enter__(self) -> Connection: self.contribute_socket(session._pinned_connection) return session._pinned_connection pool_checkout = _PoolCheckout(server.pool, self) - conn = pool_checkout.__enter__() + try: + conn = pool_checkout.__enter__() + except BaseException as exc: + # __aenter__ raised — pool already cleaned up internally. + # Run SDAM error handling so the topology learns about the failure. + self.handle(type(exc), exc) + raise self._pool_checkout = pool_checkout - # Pin this session to the selected server or connection. - if ( - in_txn - and session - and server.description.server_type - in ( - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - ): - session._pin(server, conn) - self.contribute_socket(conn) - if ( - self.client._encrypter - and not self.client._encrypter._bypass_auto_encryption - and conn.max_wire_version < 8 - ): - raise ConfigurationError("Auto-encryption requires a minimum MongoDB version of 4.2") + try: + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + self.contribute_socket(conn) + if ( + self.client._encrypter + and not self.client._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError( + "Auto-encryption requires a minimum MongoDB version of 4.2" + ) + except BaseException as exc: + self.handle(type(exc), exc) + pool_checkout.__exit__(type(exc), exc, None) + self._pool_checkout = None + raise return conn def __exit__( From 65d0cbe8ebae9dab7ea6a6dfabf5592b96c17553 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 13:52:22 -0500 Subject: [PATCH 4/7] PYTHON-5672 Fix MockPool.checkout() bypass in _ClientCheckout Constructing _PoolCheckout directly bypassed pool.checkout() overrides used in MockPool, causing mock_down_hosts checks to be skipped and network-error tests to fail. --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index e0b5b7fcad..942b6ae61f 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2716,7 +2716,7 @@ async def __aenter__(self) -> AsyncConnection: if in_txn and session and session._pinned_connection: self.contribute_socket(session._pinned_connection) return session._pinned_connection - pool_checkout = _PoolCheckout(server.pool, self) + pool_checkout = server.pool.checkout(self) try: conn = await pool_checkout.__aenter__() except BaseException as exc: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 74a588f8d1..3eba75a16b 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2705,7 +2705,7 @@ def __enter__(self) -> Connection: if in_txn and session and session._pinned_connection: self.contribute_socket(session._pinned_connection) return session._pinned_connection - pool_checkout = _PoolCheckout(server.pool, self) + pool_checkout = server.pool.checkout(self) try: conn = pool_checkout.__enter__() except BaseException as exc: From be245734c0dde0a2f0b7685dffd144061ed8360c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 15:28:43 -0500 Subject: [PATCH 5/7] PYTHON-5672 Add coverage for _PoolCheckout leak guard; remove dead server.checkout() server.checkout() was a thin wrapper around pool.checkout() that nothing called after _ClientCheckout was changed to call server.pool.checkout() directly. Add a test that mocks publish_connection_checked_out to raise and verifies the connection is returned to the pool rather than leaked. --- pymongo/asynchronous/server.py | 5 +---- pymongo/synchronous/server.py | 5 +---- test/asynchronous/test_pooling.py | 28 ++++++++++++++++++++++++++++ test/test_pooling.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 8 deletions(-) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 9d78a4cba2..1bc40ae9b4 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -43,7 +43,7 @@ from bson.objectid import ObjectId from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.monitor import Monitor - from pymongo.asynchronous.pool import AsyncConnection, Pool, _PoolCheckout + from pymongo.asynchronous.pool import AsyncConnection, Pool from pymongo.monitoring import _EventListeners from pymongo.read_preferences import _ServerMode from pymongo.server_description import ServerDescription @@ -226,9 +226,6 @@ async def run_operation( return response - def checkout(self) -> _PoolCheckout: - return self.pool.checkout() - @property def description(self) -> ServerDescription: return self._description diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index dfe28f3490..362c80cf10 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -46,7 +46,7 @@ from pymongo.server_description import ServerDescription from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.monitor import Monitor - from pymongo.synchronous.pool import Connection, Pool, _PoolCheckout + from pymongo.synchronous.pool import Connection, Pool from pymongo.typings import _DocumentOut _IS_SYNC = True @@ -226,9 +226,6 @@ def run_operation( return response - def checkout(self) -> _PoolCheckout: - return self.pool.checkout() - @property def description(self) -> ServerDescription: return self._description diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index 96b603ec10..398c2f5ec0 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -215,6 +215,34 @@ async def test_get_socket_and_exception(self): self.assertEqual(1, len(cx_pool.conns)) + async def test_checkout_event_listener_failure_no_leak(self): + # Connection is returned to the pool when publish_connection_checked_out raises. + from unittest.mock import patch + + from pymongo.monitoring import _EventListeners + from test.utils_shared import CMAPListener + + cx_pool = await self.create_pool( + max_pool_size=1, event_listeners=_EventListeners([CMAPListener()]) + ) + + with patch.object( + cx_pool.opts._event_listeners, + "publish_connection_checked_out", + side_effect=RuntimeError("simulated failure"), + ): + with self.assertRaises(RuntimeError): + async with cx_pool.checkout(): + pass + + # Connection was returned to the pool — not leaked. + self.assertEqual(1, len(cx_pool.conns)) + self.assertEqual(0, cx_pool.active_sockets) + + # Pool is still functional. + async with cx_pool.checkout(): + pass + async def test_pool_removes_closed_socket(self): # Test that Pool removes explicitly closed socket. cx_pool = await self.create_pool() diff --git a/test/test_pooling.py b/test/test_pooling.py index 47266dd166..529ab7e82b 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -215,6 +215,34 @@ def test_get_socket_and_exception(self): self.assertEqual(1, len(cx_pool.conns)) + def test_checkout_event_listener_failure_no_leak(self): + # Connection is returned to the pool when publish_connection_checked_out raises. + from unittest.mock import patch + + from pymongo.monitoring import _EventListeners + from test.utils_shared import CMAPListener + + cx_pool = self.create_pool( + max_pool_size=1, event_listeners=_EventListeners([CMAPListener()]) + ) + + with patch.object( + cx_pool.opts._event_listeners, + "publish_connection_checked_out", + side_effect=RuntimeError("simulated failure"), + ): + with self.assertRaises(RuntimeError): + with cx_pool.checkout(): + pass + + # Connection was returned to the pool — not leaked. + self.assertEqual(1, len(cx_pool.conns)) + self.assertEqual(0, cx_pool.active_sockets) + + # Pool is still functional. + with cx_pool.checkout(): + pass + def test_pool_removes_closed_socket(self): # Test that Pool removes explicitly closed socket. cx_pool = self.create_pool() From b3ef09cb370dfb9245edf49f0e2217880230d7c7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 16:42:13 -0500 Subject: [PATCH 6/7] PYTHON-5672 Test _ClientCheckout post-checkout setup failure returns connection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Covers the except BaseException block in _ClientCheckout.__aenter__() that fires when post-checkout work (session pinning, ConfigurationError check) raises — verifying the connection is returned to the pool and not leaked. --- test/asynchronous/test_client.py | 23 +++++++++++++++++++++++ test/test_client.py | 21 +++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 5da186931a..bc46b5f9bb 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -818,6 +818,29 @@ async def test_max_idle_time_checkout(self): self.assertEqual(conn, new_con) self.assertEqual(1, len(server._pool.conns)) + async def test_client_checkout_setup_failure_returns_connection(self): + # Verify that the connection is returned to the pool when an exception + # is raised during _ClientCheckout.__aenter__ post-checkout setup + # (e.g. session pinning or the auto-encryption wire-version check). + from unittest.mock import patch + + from pymongo.asynchronous.mongo_client import _ClientCheckout + + client = await self.async_rs_or_single_client() + server = await (await client._get_topology()).select_server( + writable_server_selector, _Op.TEST + ) + pool = server.pool + + checkout = _ClientCheckout(client, server, None) + with patch.object(checkout, "contribute_socket", side_effect=RuntimeError("simulated")): + with self.assertRaises(RuntimeError): + async with checkout: + pass + + # Connection was returned to pool, not leaked. + self.assertEqual(0, pool.active_sockets) + async def test_constants(self): """This test uses AsyncMongoClient explicitly to make sure that host and port are not overloaded. diff --git a/test/test_client.py b/test/test_client.py index b37b5e57ac..11052b2375 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -793,6 +793,27 @@ def test_max_idle_time_checkout(self): self.assertEqual(conn, new_con) self.assertEqual(1, len(server._pool.conns)) + def test_client_checkout_setup_failure_returns_connection(self): + # Verify that the connection is returned to the pool when an exception + # is raised during _ClientCheckout.__aenter__ post-checkout setup + # (e.g. session pinning or the auto-encryption wire-version check). + from unittest.mock import patch + + from pymongo.synchronous.mongo_client import _ClientCheckout + + client = self.rs_or_single_client() + server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST) + pool = server.pool + + checkout = _ClientCheckout(client, server, None) + with patch.object(checkout, "contribute_socket", side_effect=RuntimeError("simulated")): + with self.assertRaises(RuntimeError): + with checkout: + pass + + # Connection was returned to pool, not leaked. + self.assertEqual(0, pool.active_sockets) + def test_constants(self): """This test uses MongoClient explicitly to make sure that host and port are not overloaded. From 389fc8ec8350f2ac49235074b25ce8cdba1973c2 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 25 Jun 2026 17:13:04 -0500 Subject: [PATCH 7/7] PYTHON-5672 Fix test_client_checkout_setup_failure on freethreaded Python MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit patch.object cannot shadow a method on an instance when __slots__ is defined — the attribute is read-only. Use a subclass override instead. --- test/asynchronous/test_client.py | 16 +++++++++------- test/test_client.py | 16 +++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index bc46b5f9bb..44206788ac 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -822,21 +822,23 @@ async def test_client_checkout_setup_failure_returns_connection(self): # Verify that the connection is returned to the pool when an exception # is raised during _ClientCheckout.__aenter__ post-checkout setup # (e.g. session pinning or the auto-encryption wire-version check). - from unittest.mock import patch - + # Use a subclass to override contribute_socket because __slots__ prevents + # instance-level patching of methods. from pymongo.asynchronous.mongo_client import _ClientCheckout + class _BrokenSetupCheckout(_ClientCheckout): + def contribute_socket(self, conn, completed_handshake=True): + raise RuntimeError("simulated failure in post-checkout setup") + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( writable_server_selector, _Op.TEST ) pool = server.pool - checkout = _ClientCheckout(client, server, None) - with patch.object(checkout, "contribute_socket", side_effect=RuntimeError("simulated")): - with self.assertRaises(RuntimeError): - async with checkout: - pass + with self.assertRaises(RuntimeError): + async with _BrokenSetupCheckout(client, server, None): + pass # Connection was returned to pool, not leaked. self.assertEqual(0, pool.active_sockets) diff --git a/test/test_client.py b/test/test_client.py index 11052b2375..b6e2191c5f 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -797,19 +797,21 @@ def test_client_checkout_setup_failure_returns_connection(self): # Verify that the connection is returned to the pool when an exception # is raised during _ClientCheckout.__aenter__ post-checkout setup # (e.g. session pinning or the auto-encryption wire-version check). - from unittest.mock import patch - + # Use a subclass to override contribute_socket because __slots__ prevents + # instance-level patching of methods. from pymongo.synchronous.mongo_client import _ClientCheckout + class _BrokenSetupCheckout(_ClientCheckout): + def contribute_socket(self, conn, completed_handshake=True): + raise RuntimeError("simulated failure in post-checkout setup") + client = self.rs_or_single_client() server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST) pool = server.pool - checkout = _ClientCheckout(client, server, None) - with patch.object(checkout, "contribute_socket", side_effect=RuntimeError("simulated")): - with self.assertRaises(RuntimeError): - with checkout: - pass + with self.assertRaises(RuntimeError): + with _BrokenSetupCheckout(client, server, None): + pass # Connection was returned to pool, not leaked. self.assertEqual(0, pool.active_sockets)