Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymongo/asynchronous/command_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def _send_message(self, operation: _GetMore) -> None:
client = self._collection.database.client
try:
response = await client._run_operation(
operation, self._unpack_response, address=self._address
operation, self._run_with_conn, address=self._address
)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS:
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:

try:
response = await client._run_operation(
operation, self._unpack_response, address=self._address
operation, self._run_with_conn, address=self._address
)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS or self._exhaust:
Expand Down
119 changes: 117 additions & 2 deletions pymongo/asynchronous/cursor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,57 @@

from __future__ import annotations

import datetime
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo import _csot
from pymongo.asynchronous.command_runner import run_cursor_command
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.cursor_shared import _AgnosticCursorBase
from pymongo.lock import _async_create_lock
from pymongo.typings import _DocumentType
from pymongo.message import _GetMore, _OpMsg, _Query
from pymongo.response import PinnedResponse, Response
from pymongo.typings import _DocumentOut, _DocumentType

if TYPE_CHECKING:
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.read_preferences import _ServerMode

_IS_SYNC = False

_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}


def _split_message(
message: Union[tuple[int, Any], tuple[int, Any, int]],
) -> tuple[int, Any, int]:
"""Return request_id, data, max_doc_size.

:param message: (request_id, data, max_doc_size) or (request_id, data)
"""
if len(message) == 3:
return message # type: ignore[return-value]
# get_more and kill_cursors messages don't include BSON documents.
request_id, data = message # type: ignore[misc]
return request_id, data, 0


async def _operation_to_command(
operation: Union[_Query, _GetMore],
conn: AsyncConnection,
use_cmd: bool,
) -> tuple[dict[str, Any], str]:
cmd, db = operation.as_command(conn, use_cmd)
if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment]
operation.db, cmd, operation.codec_options
)
operation.update_command(cmd)
return cmd, db


class _ConnectionManager:
"""Used with exhaust cursors to ensure the connection is returned."""
Expand Down Expand Up @@ -66,6 +103,84 @@ def session(self) -> Optional[AsyncClientSession]:
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
...

@abstractmethod
def _unpack_response(
self,
response: _OpMsg,
cursor_id: Optional[int],
codec_options: Any,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]: ...

@_handle_reauth
async def _run_with_conn(
self,
conn: AsyncConnection,
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
) -> Response:
"""Execute a cursor operation on the given connection and return a Response."""
client = self._collection.database.client
use_cmd = operation.use_command(conn)
more_to_come = bool(operation.conn_mgr and operation.conn_mgr.more_to_come)
cmd, dbn = await _operation_to_command(operation, conn, use_cmd)
if more_to_come:
request_id, data, max_doc_size = 0, b"", 0
else:
message = operation.get_message(read_preference, conn, use_cmd)
request_id, data, max_doc_size = _split_message(message)
user_fields = _CURSOR_DOC_FIELDS if use_cmd else None
docs, reply, duration = await run_cursor_command(
conn,
cmd,
dbn,
request_id,
data,
client=client,
session=operation.session, # type: ignore[arg-type]
listeners=client._event_listeners,
address=conn.address,
start=datetime.datetime.now(),
codec_options=operation.codec_options,
user_fields=user_fields,
command_name=operation.name,
pool_opts=conn.opts,
max_doc_size=max_doc_size,
more_to_come=more_to_come,
unpack_res=self._unpack_response,
cursor_id=operation.cursor_id,
)
assert reply is not None
if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type]
conn.pin_cursor()
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the more_to_come flag is set.
more_to_come = reply.more_to_come
else:
# In OP_REPLY, the server keeps sending until cursor_id is 0.
more_to_come = bool(operation.exhaust and reply.cursor_id)
if operation.conn_mgr:
operation.conn_mgr.update_exhaust(more_to_come)
return PinnedResponse(
data=reply,
address=conn.address,
conn=conn,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs, # type: ignore[arg-type]
more_to_come=more_to_come,
)
return Response(
data=reply,
address=conn.address,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs, # type: ignore[arg-type]
)

async def _die_lock(self) -> None:
"""Closes this cursor."""
try:
Expand Down
25 changes: 7 additions & 18 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1906,13 +1906,14 @@ async def _conn_for_reads(
async def _run_operation(
self,
operation: Union[_Query, _GetMore],
unpack_res: Callable, # type: ignore[type-arg]
execute_fn: Callable, # type: ignore[type-arg]
address: Optional[_Address] = None,
) -> Response:
"""Run a _Query/_GetMore operation and return a Response.

:param operation: a _Query or _GetMore object.
:param unpack_res: A callable that decodes the wire protocol response.
:param execute_fn: A callable ``(conn, operation, read_preference) -> Response``
that executes the operation on a given connection.
:param address: Optional address when sending a message
to a specific server, used for getMore.
"""
Expand All @@ -1927,30 +1928,18 @@ async def _run_operation(
async with operation.conn_mgr._lock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return await server.run_operation(
operation.conn_mgr.conn,
operation,
operation.read_preference,
self._event_listeners,
unpack_res,
self,
return await execute_fn(
operation.conn_mgr.conn, operation, operation.read_preference
)

async def _cmd(
_session: Optional[AsyncClientSession],
server: Server,
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> Response:
operation.reset() # Reset op in case of retry.
return await server.run_operation(
conn,
operation,
read_preference,
self._event_listeners,
unpack_res,
self,
)
return await execute_fn(conn, operation, read_preference)

return await self._retryable_read(
_cmd,
Expand Down
136 changes: 1 addition & 135 deletions pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,31 @@

import logging
from contextlib import AbstractAsyncContextManager
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)

from pymongo.asynchronous.command_runner import run_cursor_command
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.logger import (
_SDAM_LOGGER,
_debug_log,
_SDAMStatusMessage,
)
from pymongo.message import _GetMore, _OpMsg, _Query
from pymongo.response import PinnedResponse, Response

if TYPE_CHECKING:
from queue import Queue
from weakref import ReferenceType

from bson.objectid import ObjectId
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
from pymongo.asynchronous.mongo_client import _MongoClientErrorHandler
from pymongo.asynchronous.monitor import Monitor
from pymongo.asynchronous.pool import AsyncConnection, Pool
from pymongo.monitoring import _EventListeners
from pymongo.read_preferences import _ServerMode
from pymongo.server_description import ServerDescription
from pymongo.typings import _DocumentOut

_IS_SYNC = False

_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}


class Server:
def __init__(
Expand Down Expand Up @@ -118,115 +107,6 @@ def request_check(self) -> None:
"""Check the server's state soon."""
self._monitor.request_check()

async def operation_to_command(
self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False
) -> tuple[dict[str, Any], str]:
cmd, db = operation.as_command(conn, apply_timeout)
# Support auto encryption
if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment]
operation.db, cmd, operation.codec_options
)
operation.update_command(cmd)

return cmd, db

@_handle_reauth
async def run_operation(
self,
conn: AsyncConnection,
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., list[_DocumentOut]],
client: AsyncMongoClient[Any],
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.

This method is used only to run _Query/_GetMore operations from
cursors.
Can raise ConnectionFailure, OperationFailure, etc.

:param conn: An AsyncConnection instance.
:param operation: A _Query or _GetMore object.
:param read_preference: The read preference to use.
:param listeners: Instance of _EventListeners or None.
:param unpack_res: A callable that decodes the wire protocol response.
:param client: An AsyncMongoClient instance.
"""
assert listeners is not None
start = datetime.now()

use_cmd = operation.use_command(conn)
more_to_come = bool(operation.conn_mgr and operation.conn_mgr.more_to_come)
cmd, dbn = await self.operation_to_command(operation, conn, use_cmd)
if more_to_come:
request_id = 0
data = b""
max_doc_size = 0
else:
message = operation.get_message(read_preference, conn, use_cmd)
request_id, data, max_doc_size = self._split_message(message)

user_fields = _CURSOR_DOC_FIELDS if use_cmd else None

docs, reply, duration = await run_cursor_command(
conn,
cmd,
dbn,
request_id,
data,
client=client,
session=operation.session, # type: ignore[arg-type]
listeners=listeners,
address=conn.address,
start=start,
codec_options=operation.codec_options,
user_fields=user_fields,
command_name=operation.name,
pool_opts=conn.opts,
max_doc_size=max_doc_size,
more_to_come=more_to_come,
unpack_res=unpack_res,
cursor_id=operation.cursor_id,
)
assert reply is not None

response: Response

if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type]
conn.pin_cursor()
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the
# more_to_come flag is set.
more_to_come = reply.more_to_come
else:
# In OP_REPLY, the server keeps sending until cursor_id is 0.
more_to_come = bool(operation.exhaust and reply.cursor_id)
if operation.conn_mgr:
operation.conn_mgr.update_exhaust(more_to_come)
response = PinnedResponse(
data=reply,
address=self._description.address,
conn=conn,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs, # type: ignore[arg-type]
more_to_come=more_to_come,
)
else:
response = Response(
data=reply,
address=self._description.address,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs, # type: ignore[arg-type]
)

return response

async def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None
) -> AbstractAsyncContextManager[AsyncConnection]:
Expand All @@ -245,19 +125,5 @@ def description(self, server_description: ServerDescription) -> None:
def pool(self) -> Pool:
return self._pool

def _split_message(
self, message: Union[tuple[int, Any], tuple[int, Any, int]]
) -> tuple[int, Any, int]:
"""Return request_id, data, max_doc_size.

:param message: (request_id, data, max_doc_size) or (request_id, data)
"""
if len(message) == 3:
return message # type: ignore[return-value]
else:
# get_more and kill_cursors messages don't include BSON documents.
request_id, data = message # type: ignore[misc]
return request_id, data, 0

def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._description!r}>"
Loading
Loading