diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 6aeea53f4c..942b6ae61f 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -41,7 +41,6 @@ import weakref from collections import defaultdict from collections.abc import AsyncGenerator, Collection, Coroutine, Mapping, MutableMapping, Sequence -from contextlib import AbstractAsyncContextManager from typing import ( TYPE_CHECKING, Any, @@ -65,6 +64,7 @@ from pymongo.asynchronous.helpers import ( _RetryPolicy, ) +from pymongo.asynchronous.pool import _PoolCheckout from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -1780,39 +1780,8 @@ async def _get_topology(self) -> Topology: self._opened = True return self._topology - @contextlib.asynccontextmanager - async def _checkout( - self, server: Server, session: Optional[AsyncClientSession] - ) -> AsyncGenerator[AsyncConnection, None]: - in_txn = session and session.in_transaction - async with _MongoClientErrorHandler(self, server, session) as err_handler: - # Reuse the pinned connection, if it exists. - if in_txn and session and session._pinned_connection: - err_handler.contribute_socket(session._pinned_connection) - yield session._pinned_connection - return - async with await server.checkout(handler=err_handler) as conn: - # Pin this session to the selected server or connection. - if ( - in_txn - and session - and server.description.server_type - in ( - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - ): - session._pin(server, conn) - err_handler.contribute_socket(conn) - if ( - self._encrypter - and not self._encrypter._bypass_auto_encryption - and conn.max_wire_version < 8 - ): - raise ConfigurationError( - "Auto-encryption requires a minimum MongoDB version of 4.2" - ) - yield conn + def _checkout(self, server: Server, session: Optional[AsyncClientSession]) -> _ClientCheckout: + return _ClientCheckout(self, server, session) async def _select_server( self, @@ -1863,41 +1832,22 @@ async def _select_server( async def _conn_for_writes( self, session: Optional[AsyncClientSession], operation: str - ) -> AbstractAsyncContextManager[AsyncConnection]: + ) -> _ClientCheckout: server = await self._select_server(writable_server_selector, session, operation) return self._checkout(server, session) - @contextlib.asynccontextmanager - async def _conn_from_server( + def _conn_from_server( self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession] - ) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" - # Get a connection for a server matching the read preference, and yield - # conn with the effective read preference. The Server Selection - # Spec says not to send any $readPreference to standalones and to - # always send primaryPreferred when directly connected to a repl set - # member. - # Thread safe: if the type is single it cannot change. - # NOTE: We already opened the Topology when selecting a server so there's no need - # to call _get_topology() again. - single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single - async with self._checkout(server, session) as conn: - if single: - if conn.is_repl and not (session and session.in_transaction): - # Use primary preferred to ensure any repl set member - # can handle the request. - read_preference = ReadPreference.PRIMARY_PREFERRED - elif conn.is_standalone: - # Don't send read preference to standalones. - read_preference = ReadPreference.PRIMARY - yield conn, read_preference + return _ClientReadCheckout(self, server, session, read_preference) async def _conn_for_reads( self, read_preference: _ServerMode, session: Optional[AsyncClientSession], operation: str, - ) -> AbstractAsyncContextManager[tuple[AsyncConnection, _ServerMode]]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" server = await self._select_server(read_preference, session, operation) return self._conn_from_server(read_preference, server, session) @@ -1925,8 +1875,12 @@ async def _run_operation( ) async with operation.conn_mgr._lock: - async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] - err_handler.contribute_socket(operation.conn_mgr.conn) + async with _ClientCheckout.for_existing_conn( + self, + server, + operation.session, # type: ignore[arg-type] + operation.conn_mgr.conn, + ): return await server.run_operation( operation.conn_mgr.conn, operation, @@ -2667,10 +2621,18 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong exc_to_check._add_error_label("RetryableWriteError") -class _MongoClientErrorHandler: - """Handle errors raised when executing an operation.""" +class _ClientCheckout: + """Context manager for checking out a connection from the pool. + + Absorbs the former _MongoClientErrorHandler and the @asynccontextmanager + _checkout() method into a single class-based CM to eliminate generator + overhead on the hot path. + """ __slots__ = ( + "_existing_conn", + "_pool_checkout", + "_server", "client", "completed_handshake", "handled", @@ -2704,6 +2666,9 @@ def __init__( self.completed_handshake = False self.service_id: Optional[ObjectId] = None self.handled = False + self._existing_conn: Optional[AsyncConnection] = None + self._pool_checkout: Optional[_PoolCheckout] = None + self._server = server def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None: """Provide socket information to the error handler.""" @@ -2741,21 +2706,145 @@ async def handle( assert self.client._topology is not None await self.client._topology.handle_error(self.server_address, err_ctx) - async def __aenter__(self) -> _MongoClientErrorHandler: - return self + async def __aenter__(self) -> AsyncConnection: + if self._existing_conn is not None: + return self._existing_conn + server = self._server + session = self.session + in_txn = session and session.in_transaction + # Reuse the pinned connection, if it exists. + if in_txn and session and session._pinned_connection: + self.contribute_socket(session._pinned_connection) + return session._pinned_connection + pool_checkout = server.pool.checkout(self) + try: + conn = await pool_checkout.__aenter__() + except BaseException as exc: + # __aenter__ raised — pool already cleaned up internally. + # Run SDAM error handling so the topology learns about the failure. + await self.handle(type(exc), exc) + raise + self._pool_checkout = pool_checkout + try: + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + self.contribute_socket(conn) + if ( + self.client._encrypter + and not self.client._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError( + "Auto-encryption requires a minimum MongoDB version of 4.2" + ) + except BaseException as exc: + await self.handle(type(exc), exc) + await pool_checkout.__aexit__(type(exc), exc, None) + self._pool_checkout = None + raise + return conn async def __aexit__( self, - exc_type: Optional[type[Exception]], - exc_val: Optional[Exception], + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - return await self.handle(exc_type, exc_val) + # Perform SDAM error handling while the connection is still checked out. + await self.handle(exc_type, exc_val) + if self._pool_checkout is not None: + await self._pool_checkout.__aexit__(exc_type, exc_val, exc_tb) + + @classmethod + def for_existing_conn( + cls, + client: AsyncMongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[AsyncClientSession], + conn: AsyncConnection, + ) -> _ClientCheckout: + """Return a _ClientCheckout for an already-checked-out connection. + + Used when SDAM error handling is needed around an existing connection + without performing a new pool checkout (e.g. re-running a getMore). + """ + checkout = cls(client, server, session) + checkout.contribute_socket(conn) + checkout._existing_conn = conn + return checkout + + +class _ClientReadCheckout(_ClientCheckout): + """Context manager for read operations. + + Extends _ClientCheckout to apply the single-topology read preference + adjustment (formerly in _conn_from_server()) and return the effective + read preference alongside the connection. + """ + + __slots__ = ("_effective_read_pref",) + + def __init__( + self, + client: AsyncMongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[AsyncClientSession], + read_preference: _ServerMode, + ) -> None: + super().__init__(client, server, session) + self._effective_read_pref: _ServerMode = read_preference + + async def __aenter__(self) -> tuple[AsyncConnection, _ServerMode]: # type: ignore[override] + conn = await super().__aenter__() + # The Server Selection Spec says not to send any $readPreference to + # standalones and to always send primaryPreferred when directly + # connected to a replica set member. + # Thread safe: topology type cannot change once set to Single. + single = self.client._topology.description.topology_type == TOPOLOGY_TYPE.Single + if single: + if conn.is_repl and not (self.session and self.session.in_transaction): + self._effective_read_pref = ReadPreference.PRIMARY_PREFERRED + elif conn.is_standalone: + self._effective_read_pref = ReadPreference.PRIMARY + return conn, self._effective_read_pref class _ClientConnectionRetryable(Generic[T]): """Responsible for executing retryable connections on read or write operations""" + __slots__ = ( + "_address", + "_always_retryable", + "_attempt_number", + "_bulk", + "_client", + "_deprioritized_servers", + "_func", + "_is_aggregate_write", + "_is_read", + "_is_run_command", + "_last_error", + "_max_retries", + "_operation", + "_operation_id", + "_read_pref", + "_retry_policy", + "_retryable", + "_retrying", + "_server", + "_server_selector", + "_session", + ) + def __init__( self, mongo_client: AsyncMongoClient, # type: ignore[type-arg] @@ -2788,7 +2877,7 @@ def __init__( ) self._address = address self._server: Server = None # type: ignore - self._deprioritized_servers: list[Server] = [] + self._deprioritized_servers: Optional[list[Server]] = None self._operation = operation self._operation_id = operation_id self._attempt_number = 0 @@ -2928,6 +3017,8 @@ async def run(self) -> T: self._client.topology_description.topology_type_name == "Sharded" or (overloaded and self._client.options.enable_overload_retargeting) ): + if self._deprioritized_servers is None: + self._deprioritized_servers = [] self._deprioritized_servers.append(self._server) self._always_retryable = always_retryable diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 60acb93fcd..aeda81c12e 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -16,15 +16,13 @@ import asyncio import collections -import contextlib import logging import os import socket import ssl -import sys import time import weakref -from collections.abc import AsyncGenerator, Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -89,11 +87,13 @@ from pymongo.socket_checker import SocketChecker if TYPE_CHECKING: + from types import TracebackType + from bson import CodecOptions from bson.objectid import ObjectId from pymongo.asynchronous.auth import _AuthContext from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler + from pymongo.asynchronous.mongo_client import AsyncMongoClient, _ClientCheckout from pymongo.compression_support import ( SnappyContext, ZlibContext, @@ -745,7 +745,7 @@ def __init__( # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: set[AsyncConnection] = set() + self._pinned_sockets: set[AsyncConnection] = set() self.ncursors = 0 self.ntxns = 0 @@ -981,7 +981,7 @@ def _handle_connection_error(self, error: BaseException) -> None: error._add_error_label("SystemOverloadedError") error._add_error_label("RetryableError") - async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection: + async def connect(self, handler: Optional[_ClientCheckout] = None) -> AsyncConnection: """Connect to Mongo and return a new AsyncConnection. Can raise ConnectionFailure. @@ -1077,84 +1077,23 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A return conn - @contextlib.asynccontextmanager - async def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AsyncGenerator[AsyncConnection, None]: - """Get a connection from the pool. Use with a "with" statement. + def checkout(self, handler: Optional[_ClientCheckout] = None) -> _PoolCheckout: + """Get a connection from the pool. Use with an "async with" statement. - Returns a :class:`AsyncConnection` object wrapping a connected - :class:`socket.socket`. + Returns a :class:`_PoolCheckout` context manager that yields a + :class:`AsyncConnection` object wrapping a connected socket. - This method should always be used in a with-statement:: + This method should always be used in an async-with-statement:: - with pool.get_conn() as connection: + async with pool.checkout() as connection: connection.send_message(msg) data = connection.receive_message(op_code, request_id) Can raise ConnectionFailure or OperationFailure. - :param handler: A _MongoClientErrorHandler. + :param handler: A _ClientCheckout error handler. """ - listeners = self.opts._event_listeners - checkout_started_time = time.monotonic() - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_check_out_started(self.address) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_STARTED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - ) - - conn = await self._get_conn(checkout_started_time, handler=handler) - - duration = time.monotonic() - checkout_started_time - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_checked_out(self.address, conn.id, duration) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - driverConnectionId=conn.id, - durationMS=duration, - ) - try: - async with self.lock: - self.active_contexts.add(conn.cancel_context) - yield conn - # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException: - # Exception in caller. Ensure the connection gets returned. - # Note that when pinned is True, the session owns the - # connection and it is responsible for checking the connection - # back into the pool. - pinned = conn.pinned_txn or conn.pinned_cursor - if handler: - # Perform SDAM error handling rules while the connection is - # still checked out. - exc_type, exc_val, _ = sys.exc_info() - await handler.handle(exc_type, exc_val) - if not pinned and conn.active: - await self.checkin(conn) - raise - if conn.pinned_txn: - async with self.lock: - self.__pinned_sockets.add(conn) - self.ntxns += 1 - elif conn.pinned_cursor: - async with self.lock: - self.__pinned_sockets.add(conn) - self.ncursors += 1 - elif conn.active: - await self.checkin(conn) + return _PoolCheckout(self, handler) def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> None: if self.state != PoolState.READY: @@ -1183,7 +1122,7 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> ) async def _get_conn( - self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None + self, checkout_started_time: float, handler: Optional[_ClientCheckout] = None ) -> AsyncConnection: """Get or create a AsyncConnection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. @@ -1242,6 +1181,7 @@ async def _get_conn( conn = None incremented = False emitted_event = False + is_new_conn = False try: async with self.lock: self.active_sockets += 1 @@ -1273,6 +1213,7 @@ async def _get_conn( else: # We need to create a new connection try: conn = await self.connect(handler=handler) + is_new_conn = True finally: async with self._max_connecting_cond: self._pending -= 1 @@ -1309,6 +1250,11 @@ async def _get_conn( raise conn.active = True + # connect() already adds cancel_context for new connections; only add + # here for reused connections taken from the idle pool. + if not is_new_conn: + async with self.lock: + self.active_contexts.add(conn.cancel_context) return conn async def checkin(self, conn: AsyncConnection) -> None: @@ -1321,7 +1267,7 @@ async def checkin(self, conn: AsyncConnection) -> None: conn.active = False conn.pinned_txn = False conn.pinned_cursor = False - self.__pinned_sockets.discard(conn) + self._pinned_sockets.discard(conn) listeners = self.opts._event_listeners async with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1463,3 +1409,91 @@ def __del__(self) -> None: if _IS_SYNC: for conn in self.conns: conn.close_conn(None) # type: ignore[unused-coroutine] + + +class _PoolCheckout: + """Class-based context manager for pool connection checkout.""" + + __slots__ = ("_checkout_started_time", "_conn", "_handler", "_pool") + + def __init__( + self, + pool: Pool, + handler: Optional[_ClientCheckout] = None, + ) -> None: + self._pool = pool + self._handler = handler + self._conn: Optional[AsyncConnection] = None + self._checkout_started_time: float = 0.0 + + async def __aenter__(self) -> AsyncConnection: + pool = self._pool + self._checkout_started_time = time.monotonic() + checkout_started_time = self._checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_check_out_started(pool.address) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_STARTED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + ) + + conn = await pool._get_conn(checkout_started_time, handler=self._handler) + self._conn = conn + try: + duration = time.monotonic() - checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_checked_out( + pool.address, conn.id, duration + ) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + driverConnectionId=conn.id, + durationMS=duration, + ) + except BaseException: + await pool.checkin(conn) + self._conn = None + raise + return conn + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + conn = self._conn + if conn is None: + return + pool = self._pool + if exc_type is not None: + # Exception in caller. Ensure the connection gets returned. + # Note that when pinned is True, the session owns the connection + # and is responsible for checking it back into the pool. + # SDAM error handling is performed by _ClientCheckout.__aexit__ + # before this method is called. + pinned = conn.pinned_txn or conn.pinned_cursor + if not pinned and conn.active: + await pool.checkin(conn) + else: + if conn.pinned_txn: + async with pool.lock: + pool._pinned_sockets.add(conn) + pool.ntxns += 1 + elif conn.pinned_cursor: + async with pool.lock: + pool._pinned_sockets.add(conn) + pool.ncursors += 1 + elif conn.active: + await pool.checkin(conn) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 9a6984f486..1bc40ae9b4 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -17,7 +17,6 @@ from __future__ import annotations import logging -from contextlib import AbstractAsyncContextManager from datetime import datetime from typing import ( TYPE_CHECKING, @@ -42,7 +41,7 @@ from weakref import ReferenceType from bson.objectid import ObjectId - from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler + from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.monitor import Monitor from pymongo.asynchronous.pool import AsyncConnection, Pool from pymongo.monitoring import _EventListeners @@ -227,11 +226,6 @@ async def run_operation( return response - async def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AbstractAsyncContextManager[AsyncConnection]: - return self.pool.checkout(handler) - @property def description(self) -> ServerDescription: return self._description diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 6b7c5d9c98..3eba75a16b 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -41,7 +41,6 @@ import weakref from collections import defaultdict from collections.abc import Collection, Generator, Mapping, MutableMapping, Sequence -from contextlib import AbstractContextManager from typing import ( TYPE_CHECKING, Any, @@ -109,6 +108,7 @@ from pymongo.synchronous.helpers import ( _RetryPolicy, ) +from pymongo.synchronous.pool import _PoolCheckout from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -1777,39 +1777,8 @@ def _get_topology(self) -> Topology: self._opened = True return self._topology - @contextlib.contextmanager - def _checkout( - self, server: Server, session: Optional[ClientSession] - ) -> Generator[Connection, None]: - in_txn = session and session.in_transaction - with _MongoClientErrorHandler(self, server, session) as err_handler: - # Reuse the pinned connection, if it exists. - if in_txn and session and session._pinned_connection: - err_handler.contribute_socket(session._pinned_connection) - yield session._pinned_connection - return - with server.checkout(handler=err_handler) as conn: - # Pin this session to the selected server or connection. - if ( - in_txn - and session - and server.description.server_type - in ( - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - ): - session._pin(server, conn) - err_handler.contribute_socket(conn) - if ( - self._encrypter - and not self._encrypter._bypass_auto_encryption - and conn.max_wire_version < 8 - ): - raise ConfigurationError( - "Auto-encryption requires a minimum MongoDB version of 4.2" - ) - yield conn + def _checkout(self, server: Server, session: Optional[ClientSession]) -> _ClientCheckout: + return _ClientCheckout(self, server, session) def _select_server( self, @@ -1858,43 +1827,22 @@ def _select_server( session._unpin() raise - def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> AbstractContextManager[Connection]: + def _conn_for_writes(self, session: Optional[ClientSession], operation: str) -> _ClientCheckout: server = self._select_server(writable_server_selector, session, operation) return self._checkout(server, session) - @contextlib.contextmanager def _conn_from_server( self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] - ) -> Generator[tuple[Connection, _ServerMode], None]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" - # Get a connection for a server matching the read preference, and yield - # conn with the effective read preference. The Server Selection - # Spec says not to send any $readPreference to standalones and to - # always send primaryPreferred when directly connected to a repl set - # member. - # Thread safe: if the type is single it cannot change. - # NOTE: We already opened the Topology when selecting a server so there's no need - # to call _get_topology() again. - single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._checkout(server, session) as conn: - if single: - if conn.is_repl and not (session and session.in_transaction): - # Use primary preferred to ensure any repl set member - # can handle the request. - read_preference = ReadPreference.PRIMARY_PREFERRED - elif conn.is_standalone: - # Don't send read preference to standalones. - read_preference = ReadPreference.PRIMARY - yield conn, read_preference + return _ClientReadCheckout(self, server, session, read_preference) def _conn_for_reads( self, read_preference: _ServerMode, session: Optional[ClientSession], operation: str, - ) -> AbstractContextManager[tuple[Connection, _ServerMode]]: + ) -> _ClientReadCheckout: assert read_preference is not None, "read_preference must not be None" server = self._select_server(read_preference, session, operation) return self._conn_from_server(read_preference, server, session) @@ -1922,8 +1870,12 @@ def _run_operation( ) with operation.conn_mgr._lock: - with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] - err_handler.contribute_socket(operation.conn_mgr.conn) + with _ClientCheckout.for_existing_conn( + self, + server, + operation.session, # type: ignore[arg-type] + operation.conn_mgr.conn, + ): return server.run_operation( operation.conn_mgr.conn, operation, @@ -2658,10 +2610,18 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong exc_to_check._add_error_label("RetryableWriteError") -class _MongoClientErrorHandler: - """Handle errors raised when executing an operation.""" +class _ClientCheckout: + """Context manager for checking out a connection from the pool. + + Absorbs the former _MongoClientErrorHandler and the @asynccontextmanager + _checkout() method into a single class-based CM to eliminate generator + overhead on the hot path. + """ __slots__ = ( + "_existing_conn", + "_pool_checkout", + "_server", "client", "completed_handshake", "handled", @@ -2695,6 +2655,9 @@ def __init__( self.completed_handshake = False self.service_id: Optional[ObjectId] = None self.handled = False + self._existing_conn: Optional[Connection] = None + self._pool_checkout: Optional[_PoolCheckout] = None + self._server = server def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None: """Provide socket information to the error handler.""" @@ -2732,21 +2695,145 @@ def handle( assert self.client._topology is not None self.client._topology.handle_error(self.server_address, err_ctx) - def __enter__(self) -> _MongoClientErrorHandler: - return self + def __enter__(self) -> Connection: + if self._existing_conn is not None: + return self._existing_conn + server = self._server + session = self.session + in_txn = session and session.in_transaction + # Reuse the pinned connection, if it exists. + if in_txn and session and session._pinned_connection: + self.contribute_socket(session._pinned_connection) + return session._pinned_connection + pool_checkout = server.pool.checkout(self) + try: + conn = pool_checkout.__enter__() + except BaseException as exc: + # __aenter__ raised — pool already cleaned up internally. + # Run SDAM error handling so the topology learns about the failure. + self.handle(type(exc), exc) + raise + self._pool_checkout = pool_checkout + try: + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + self.contribute_socket(conn) + if ( + self.client._encrypter + and not self.client._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError( + "Auto-encryption requires a minimum MongoDB version of 4.2" + ) + except BaseException as exc: + self.handle(type(exc), exc) + pool_checkout.__exit__(type(exc), exc, None) + self._pool_checkout = None + raise + return conn def __exit__( self, - exc_type: Optional[type[Exception]], - exc_val: Optional[Exception], + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - return self.handle(exc_type, exc_val) + # Perform SDAM error handling while the connection is still checked out. + self.handle(exc_type, exc_val) + if self._pool_checkout is not None: + self._pool_checkout.__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def for_existing_conn( + cls, + client: MongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[ClientSession], + conn: Connection, + ) -> _ClientCheckout: + """Return a _ClientCheckout for an already-checked-out connection. + + Used when SDAM error handling is needed around an existing connection + without performing a new pool checkout (e.g. re-running a getMore). + """ + checkout = cls(client, server, session) + checkout.contribute_socket(conn) + checkout._existing_conn = conn + return checkout + + +class _ClientReadCheckout(_ClientCheckout): + """Context manager for read operations. + + Extends _ClientCheckout to apply the single-topology read preference + adjustment (formerly in _conn_from_server()) and return the effective + read preference alongside the connection. + """ + + __slots__ = ("_effective_read_pref",) + + def __init__( + self, + client: MongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[ClientSession], + read_preference: _ServerMode, + ) -> None: + super().__init__(client, server, session) + self._effective_read_pref: _ServerMode = read_preference + + def __enter__(self) -> tuple[Connection, _ServerMode]: # type: ignore[override] + conn = super().__enter__() + # The Server Selection Spec says not to send any $readPreference to + # standalones and to always send primaryPreferred when directly + # connected to a replica set member. + # Thread safe: topology type cannot change once set to Single. + single = self.client._topology.description.topology_type == TOPOLOGY_TYPE.Single + if single: + if conn.is_repl and not (self.session and self.session.in_transaction): + self._effective_read_pref = ReadPreference.PRIMARY_PREFERRED + elif conn.is_standalone: + self._effective_read_pref = ReadPreference.PRIMARY + return conn, self._effective_read_pref class _ClientConnectionRetryable(Generic[T]): """Responsible for executing retryable connections on read or write operations""" + __slots__ = ( + "_address", + "_always_retryable", + "_attempt_number", + "_bulk", + "_client", + "_deprioritized_servers", + "_func", + "_is_aggregate_write", + "_is_read", + "_is_run_command", + "_last_error", + "_max_retries", + "_operation", + "_operation_id", + "_read_pref", + "_retry_policy", + "_retryable", + "_retrying", + "_server", + "_server_selector", + "_session", + ) + def __init__( self, mongo_client: MongoClient, # type: ignore[type-arg] @@ -2779,7 +2866,7 @@ def __init__( ) self._address = address self._server: Server = None # type: ignore - self._deprioritized_servers: list[Server] = [] + self._deprioritized_servers: Optional[list[Server]] = None self._operation = operation self._operation_id = operation_id self._attempt_number = 0 @@ -2919,6 +3006,8 @@ def run(self) -> T: self._client.topology_description.topology_type_name == "Sharded" or (overloaded and self._client.options.enable_overload_retargeting) ): + if self._deprioritized_servers is None: + self._deprioritized_servers = [] self._deprioritized_servers.append(self._server) self._always_retryable = always_retryable diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index b3929b674a..84d09742a1 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -16,15 +16,13 @@ import asyncio import collections -import contextlib import logging import os import socket import ssl -import sys import time import weakref -from collections.abc import Generator, Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -89,6 +87,8 @@ from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: + from types import TracebackType + from bson import CodecOptions from bson.objectid import ObjectId from pymongo.compression_support import ( @@ -101,7 +101,7 @@ from pymongo.read_preferences import _ServerMode from pymongo.synchronous.auth import _AuthContext from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.synchronous.mongo_client import MongoClient, _ClientCheckout from pymongo.typings import _Address, _CollationIn from pymongo.write_concern import WriteConcern @@ -743,7 +743,7 @@ def __init__( # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: set[Connection] = set() + self._pinned_sockets: set[Connection] = set() self.ncursors = 0 self.ntxns = 0 @@ -977,7 +977,7 @@ def _handle_connection_error(self, error: BaseException) -> None: error._add_error_label("SystemOverloadedError") error._add_error_label("RetryableError") - def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: + def connect(self, handler: Optional[_ClientCheckout] = None) -> Connection: """Connect to Mongo and return a new Connection. Can raise ConnectionFailure. @@ -1073,84 +1073,23 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect return conn - @contextlib.contextmanager - def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> Generator[Connection, None]: - """Get a connection from the pool. Use with a "with" statement. + def checkout(self, handler: Optional[_ClientCheckout] = None) -> _PoolCheckout: + """Get a connection from the pool. Use with an "with" statement. - Returns a :class:`Connection` object wrapping a connected - :class:`socket.socket`. + Returns a :class:`_PoolCheckout` context manager that yields a + :class:`Connection` object wrapping a connected socket. - This method should always be used in a with-statement:: + This method should always be used in an async-with-statement:: - with pool.get_conn() as connection: + with pool.checkout() as connection: connection.send_message(msg) data = connection.receive_message(op_code, request_id) Can raise ConnectionFailure or OperationFailure. - :param handler: A _MongoClientErrorHandler. + :param handler: A _ClientCheckout error handler. """ - listeners = self.opts._event_listeners - checkout_started_time = time.monotonic() - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_check_out_started(self.address) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_STARTED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - ) - - conn = self._get_conn(checkout_started_time, handler=handler) - - duration = time.monotonic() - checkout_started_time - if self.enabled_for_cmap: - assert listeners is not None - listeners.publish_connection_checked_out(self.address, conn.id, duration) - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - driverConnectionId=conn.id, - durationMS=duration, - ) - try: - with self.lock: - self.active_contexts.add(conn.cancel_context) - yield conn - # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException: - # Exception in caller. Ensure the connection gets returned. - # Note that when pinned is True, the session owns the - # connection and it is responsible for checking the connection - # back into the pool. - pinned = conn.pinned_txn or conn.pinned_cursor - if handler: - # Perform SDAM error handling rules while the connection is - # still checked out. - exc_type, exc_val, _ = sys.exc_info() - handler.handle(exc_type, exc_val) - if not pinned and conn.active: - self.checkin(conn) - raise - if conn.pinned_txn: - with self.lock: - self.__pinned_sockets.add(conn) - self.ntxns += 1 - elif conn.pinned_cursor: - with self.lock: - self.__pinned_sockets.add(conn) - self.ncursors += 1 - elif conn.active: - self.checkin(conn) + return _PoolCheckout(self, handler) def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> None: if self.state != PoolState.READY: @@ -1179,7 +1118,7 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> ) def _get_conn( - self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None + self, checkout_started_time: float, handler: Optional[_ClientCheckout] = None ) -> Connection: """Get or create a Connection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. @@ -1238,6 +1177,7 @@ def _get_conn( conn = None incremented = False emitted_event = False + is_new_conn = False try: with self.lock: self.active_sockets += 1 @@ -1269,6 +1209,7 @@ def _get_conn( else: # We need to create a new connection try: conn = self.connect(handler=handler) + is_new_conn = True finally: with self._max_connecting_cond: self._pending -= 1 @@ -1305,6 +1246,11 @@ def _get_conn( raise conn.active = True + # connect() already adds cancel_context for new connections; only add + # here for reused connections taken from the idle pool. + if not is_new_conn: + with self.lock: + self.active_contexts.add(conn.cancel_context) return conn def checkin(self, conn: Connection) -> None: @@ -1317,7 +1263,7 @@ def checkin(self, conn: Connection) -> None: conn.active = False conn.pinned_txn = False conn.pinned_cursor = False - self.__pinned_sockets.discard(conn) + self._pinned_sockets.discard(conn) listeners = self.opts._event_listeners with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1459,3 +1405,91 @@ def __del__(self) -> None: if _IS_SYNC: for conn in self.conns: conn.close_conn(None) # type: ignore[unused-coroutine] + + +class _PoolCheckout: + """Class-based context manager for pool connection checkout.""" + + __slots__ = ("_checkout_started_time", "_conn", "_handler", "_pool") + + def __init__( + self, + pool: Pool, + handler: Optional[_ClientCheckout] = None, + ) -> None: + self._pool = pool + self._handler = handler + self._conn: Optional[Connection] = None + self._checkout_started_time: float = 0.0 + + def __enter__(self) -> Connection: + pool = self._pool + self._checkout_started_time = time.monotonic() + checkout_started_time = self._checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_check_out_started(pool.address) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_STARTED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + ) + + conn = pool._get_conn(checkout_started_time, handler=self._handler) + self._conn = conn + try: + duration = time.monotonic() - checkout_started_time + if pool.enabled_for_cmap: + assert pool.opts._event_listeners is not None + pool.opts._event_listeners.publish_connection_checked_out( + pool.address, conn.id, duration + ) + if pool.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, + clientId=pool._client_id, + serverHost=pool.address[0], + serverPort=pool.address[1], + driverConnectionId=conn.id, + durationMS=duration, + ) + except BaseException: + pool.checkin(conn) + self._conn = None + raise + return conn + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + conn = self._conn + if conn is None: + return + pool = self._pool + if exc_type is not None: + # Exception in caller. Ensure the connection gets returned. + # Note that when pinned is True, the session owns the connection + # and is responsible for checking it back into the pool. + # SDAM error handling is performed by _ClientCheckout.__aexit__ + # before this method is called. + pinned = conn.pinned_txn or conn.pinned_cursor + if not pinned and conn.active: + pool.checkin(conn) + else: + if conn.pinned_txn: + with pool.lock: + pool._pinned_sockets.add(conn) + pool.ntxns += 1 + elif conn.pinned_cursor: + with pool.lock: + pool._pinned_sockets.add(conn) + pool.ncursors += 1 + elif conn.active: + pool.checkin(conn) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 7aa017134a..362c80cf10 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -17,7 +17,6 @@ from __future__ import annotations import logging -from contextlib import AbstractContextManager from datetime import datetime from typing import ( TYPE_CHECKING, @@ -45,7 +44,7 @@ from pymongo.monitoring import _EventListeners from pymongo.read_preferences import _ServerMode from pymongo.server_description import ServerDescription - from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.monitor import Monitor from pymongo.synchronous.pool import Connection, Pool from pymongo.typings import _DocumentOut @@ -227,11 +226,6 @@ def run_operation( return response - def checkout( - self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AbstractContextManager[Connection]: - return self.pool.checkout(handler) - @property def description(self) -> ServerDescription: return self._description diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index 4413cbe43b..e71be6a502 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -43,7 +43,7 @@ def __init__(self, client, pair, *args, **kwargs): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.asynccontextmanager - async def checkout(self, handler=None): + async def checkout(self, handler=None): # type: ignore[override] client = self.client host_and_port = f"{self.mock_host}:{self.mock_port}" if host_and_port in client.mock_down_hosts: diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 5da186931a..44206788ac 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -818,6 +818,31 @@ async def test_max_idle_time_checkout(self): self.assertEqual(conn, new_con) self.assertEqual(1, len(server._pool.conns)) + async def test_client_checkout_setup_failure_returns_connection(self): + # Verify that the connection is returned to the pool when an exception + # is raised during _ClientCheckout.__aenter__ post-checkout setup + # (e.g. session pinning or the auto-encryption wire-version check). + # Use a subclass to override contribute_socket because __slots__ prevents + # instance-level patching of methods. + from pymongo.asynchronous.mongo_client import _ClientCheckout + + class _BrokenSetupCheckout(_ClientCheckout): + def contribute_socket(self, conn, completed_handshake=True): + raise RuntimeError("simulated failure in post-checkout setup") + + client = await self.async_rs_or_single_client() + server = await (await client._get_topology()).select_server( + writable_server_selector, _Op.TEST + ) + pool = server.pool + + with self.assertRaises(RuntimeError): + async with _BrokenSetupCheckout(client, server, None): + pass + + # Connection was returned to pool, not leaked. + self.assertEqual(0, pool.active_sockets) + async def test_constants(self): """This test uses AsyncMongoClient explicitly to make sure that host and port are not overloaded. diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index 96b603ec10..398c2f5ec0 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -215,6 +215,34 @@ async def test_get_socket_and_exception(self): self.assertEqual(1, len(cx_pool.conns)) + async def test_checkout_event_listener_failure_no_leak(self): + # Connection is returned to the pool when publish_connection_checked_out raises. + from unittest.mock import patch + + from pymongo.monitoring import _EventListeners + from test.utils_shared import CMAPListener + + cx_pool = await self.create_pool( + max_pool_size=1, event_listeners=_EventListeners([CMAPListener()]) + ) + + with patch.object( + cx_pool.opts._event_listeners, + "publish_connection_checked_out", + side_effect=RuntimeError("simulated failure"), + ): + with self.assertRaises(RuntimeError): + async with cx_pool.checkout(): + pass + + # Connection was returned to the pool — not leaked. + self.assertEqual(1, len(cx_pool.conns)) + self.assertEqual(0, cx_pool.active_sockets) + + # Pool is still functional. + async with cx_pool.checkout(): + pass + async def test_pool_removes_closed_socket(self): # Test that Pool removes explicitly closed socket. cx_pool = await self.create_pool() diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py index 9f92a39920..7801926550 100644 --- a/test/asynchronous/test_read_preferences.py +++ b/test/asynchronous/test_read_preferences.py @@ -334,7 +334,7 @@ async def _conn_for_reads(self, read_preference, session, operation): return context @contextlib.asynccontextmanager - async def _conn_from_server(self, read_preference, server, session): + async def _conn_from_server(self, read_preference, server, session): # type: ignore[override] context = super()._conn_from_server(read_preference, server, session) async with context as (conn, read_preference): await self.record_a_read(conn.address) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 206c1a61b7..a3092bffe9 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -42,7 +42,7 @@ def __init__(self, client, pair, *args, **kwargs): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager - def checkout(self, handler=None): + def checkout(self, handler=None): # type: ignore[override] client = self.client host_and_port = f"{self.mock_host}:{self.mock_port}" if host_and_port in client.mock_down_hosts: diff --git a/test/test_client.py b/test/test_client.py index b37b5e57ac..b6e2191c5f 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -793,6 +793,29 @@ def test_max_idle_time_checkout(self): self.assertEqual(conn, new_con) self.assertEqual(1, len(server._pool.conns)) + def test_client_checkout_setup_failure_returns_connection(self): + # Verify that the connection is returned to the pool when an exception + # is raised during _ClientCheckout.__aenter__ post-checkout setup + # (e.g. session pinning or the auto-encryption wire-version check). + # Use a subclass to override contribute_socket because __slots__ prevents + # instance-level patching of methods. + from pymongo.synchronous.mongo_client import _ClientCheckout + + class _BrokenSetupCheckout(_ClientCheckout): + def contribute_socket(self, conn, completed_handshake=True): + raise RuntimeError("simulated failure in post-checkout setup") + + client = self.rs_or_single_client() + server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST) + pool = server.pool + + with self.assertRaises(RuntimeError): + with _BrokenSetupCheckout(client, server, None): + pass + + # Connection was returned to pool, not leaked. + self.assertEqual(0, pool.active_sockets) + def test_constants(self): """This test uses MongoClient explicitly to make sure that host and port are not overloaded. diff --git a/test/test_pooling.py b/test/test_pooling.py index 47266dd166..529ab7e82b 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -215,6 +215,34 @@ def test_get_socket_and_exception(self): self.assertEqual(1, len(cx_pool.conns)) + def test_checkout_event_listener_failure_no_leak(self): + # Connection is returned to the pool when publish_connection_checked_out raises. + from unittest.mock import patch + + from pymongo.monitoring import _EventListeners + from test.utils_shared import CMAPListener + + cx_pool = self.create_pool( + max_pool_size=1, event_listeners=_EventListeners([CMAPListener()]) + ) + + with patch.object( + cx_pool.opts._event_listeners, + "publish_connection_checked_out", + side_effect=RuntimeError("simulated failure"), + ): + with self.assertRaises(RuntimeError): + with cx_pool.checkout(): + pass + + # Connection was returned to the pool — not leaked. + self.assertEqual(1, len(cx_pool.conns)) + self.assertEqual(0, cx_pool.active_sockets) + + # Pool is still functional. + with cx_pool.checkout(): + pass + def test_pool_removes_closed_socket(self): # Test that Pool removes explicitly closed socket. cx_pool = self.create_pool() diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 27c8d0704a..65c25735f5 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -314,7 +314,7 @@ def _conn_for_reads(self, read_preference, session, operation): return context @contextlib.contextmanager - def _conn_from_server(self, read_preference, server, session): + def _conn_from_server(self, read_preference, server, session): # type: ignore[override] context = super()._conn_from_server(read_preference, server, session) with context as (conn, read_preference): self.record_a_read(conn.address)