Skip to content
Draft
3 changes: 3 additions & 0 deletions .evergreen/resync-specs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ do
change-streams|change_streams)
cpjson change-streams/tests/ change_streams/
;;
client-backpressure|client_backpressure)
cpjson client-backpressure/tests client-backpressure
;;
client-side-encryption|csfle|fle)
cpjson client-side-encryption/tests/ client-side-encryption/spec
cpjson client-side-encryption/corpus/ client-side-encryption/corpus
Expand Down
9 changes: 4 additions & 5 deletions .evergreen/scripts/setup_tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import base64
import io
import os
import platform
import shutil
import stat
import tarfile
from pathlib import Path
from urllib import request

Expand Down Expand Up @@ -117,9 +115,10 @@ def setup_libmongocrypt():
LOGGER.info(f"Fetching {url}...")
with request.urlopen(request.Request(url), timeout=15.0) as response: # noqa: S310
if response.status == 200:
fileobj = io.BytesIO(response.read())
with tarfile.open("libmongocrypt.tar.gz", fileobj=fileobj) as fid:
fid.extractall(Path.cwd() / "libmongocrypt")
with Path("libmongocrypt.tar.gz").open("wb") as f:
f.write(response.read())
Path("libmongocrypt").mkdir()
run_command("tar -xzf libmongocrypt.tar.gz -C libmongocrypt")
LOGGER.info(f"Fetching {url}... done.")

run_command("ls -la libmongocrypt")
Expand Down
16 changes: 9 additions & 7 deletions justfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# See https://just.systems/man/en/ for instructions
set shell := ["bash", "-c"]
# Do not modify the lock file when running justfile commands.
export UV_FROZEN := "1"

# Commonly used command segments.
typing_run := "uv run --group typing --extra aws --extra encryption --extra ocsp --extra snappy --extra test --extra zstd"
Expand All @@ -16,7 +14,7 @@ default:

[private]
resync:
@uv sync --quiet --frozen
@uv sync --quiet

install:
bash .evergreen/scripts/setup-dev-env.sh
Expand Down Expand Up @@ -50,12 +48,12 @@ typing-pyright: && resync
{{typing_run}} pyright -p strict_pyrightconfig.json test/test_typing_strict.py

[group('lint')]
lint: && resync
uv run pre-commit run --all-files
lint *args="": && resync
uvx pre-commit run --all-files {{args}}

[group('lint')]
lint-manual: && resync
uv run pre-commit run --all-files --hook-stage manual
lint-manual *args="": && resync
uvx pre-commit run --all-files --hook-stage manual {{args}}

[group('test')]
test *args="-v --durations=5 --maxfail=10": && resync
Expand All @@ -73,6 +71,10 @@ setup-tests *args="":
teardown-tests:
bash .evergreen/scripts/teardown-tests.sh

[group('test')]
integration-tests:
bash integration_tests/run.sh

[group('server')]
run-server *args="":
bash .evergreen/scripts/run-server.sh {{args}}
Expand Down
14 changes: 13 additions & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,21 @@ async def _execute_command(
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))

retryable_label_error = (
hasattr(error, "details")
and isinstance(error.details, dict)
and "errorLabels" in error.details
and isinstance(error.details["errorLabels"], list)
and "RetryableError" in error.details["errorLabels"]
)

# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)
Expand Down
86 changes: 57 additions & 29 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Coroutine,
Generic,
Expand Down Expand Up @@ -58,7 +57,6 @@
AsyncCursor,
AsyncRawBatchCursor,
)
from pymongo.asynchronous.helpers import _retry_overload
from pymongo.collation import validate_collation_or_none
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.errors import (
Expand Down Expand Up @@ -573,11 +571,6 @@ async def watch(
await change_stream._initialize_cursor()
return change_stream

async def _conn_for_writes(
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
return await self._database.client._conn_for_writes(session, operation)

async def _command(
self,
conn: AsyncConnection,
Expand Down Expand Up @@ -654,7 +647,10 @@ async def _create_helper(
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
async with await self._conn_for_writes(session, operation=_Op.CREATE) as conn:

async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
Expand All @@ -671,6 +667,8 @@ async def _create_helper(
session=session,
)

await self.database.client._retryable_write(False, inner, session, _Op.CREATE)

async def _create(
self,
options: MutableMapping[str, Any],
Expand Down Expand Up @@ -2229,7 +2227,6 @@ async def create_indexes(
return await self._create_indexes(indexes, session, **kwargs)

@_csot.apply
@_retry_overload
async def _create_indexes(
self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any
) -> list[str]:
Expand All @@ -2243,7 +2240,10 @@ async def _create_indexes(
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
async with await self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:

async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9

def gen_indexes() -> Iterator[Mapping[str, Any]]:
Expand Down Expand Up @@ -2272,7 +2272,11 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]:
write_concern=self._write_concern_for(session),
session=session,
)
return names
return names

return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_INDEXES
)

async def create_index(
self,
Expand Down Expand Up @@ -2474,7 +2478,6 @@ async def drop_index(
await self._drop_index(index_or_name, session, comment, **kwargs)

@_csot.apply
@_retry_overload
async def _drop_index(
self,
index_or_name: _IndexKeyHint,
Expand All @@ -2493,7 +2496,10 @@ async def _drop_index(
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:

async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
Expand All @@ -2503,6 +2509,8 @@ async def _drop_index(
session=session,
)

await self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)

async def list_indexes(
self,
session: Optional[AsyncClientSession] = None,
Expand Down Expand Up @@ -2766,17 +2774,22 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]:
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs)

async with await self._conn_for_writes(
session, operation=_Op.CREATE_SEARCH_INDEXES
) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
resp = await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
return [index["name"] for index in resp["indexesCreated"]]

return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)

async def drop_search_index(
self,
name: str,
Expand All @@ -2802,15 +2815,21 @@ async def drop_search_index(
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:

async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)

await self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)

async def update_search_index(
self,
name: str,
Expand Down Expand Up @@ -2838,15 +2857,21 @@ async def update_search_index(
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:

async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)

await self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)

async def options(
self,
session: Optional[AsyncClientSession] = None,
Expand Down Expand Up @@ -3075,7 +3100,6 @@ async def aggregate_raw_batches(
)

@_csot.apply
@_retry_overload
async def rename(
self,
new_name: str,
Expand Down Expand Up @@ -3127,17 +3151,21 @@ async def rename(
if comment is not None:
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client

async with await self._conn_for_writes(session, operation=_Op.RENAME) as conn:
async with self._database.client._tmp_session(session) as s:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=s,
client=self._database.client,
)
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> MutableMapping[str, Any]:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=session,
client=client,
)

return await client._retryable_write(False, inner, session, _Op.RENAME)

async def distinct(
self,
Expand Down
Loading
Loading