2020from typing import (
2121 TYPE_CHECKING ,
2222 Any ,
23- AsyncContextManager ,
2423 Callable ,
2524 Coroutine ,
2625 Generic ,
5857 AsyncCursor ,
5958 AsyncRawBatchCursor ,
6059)
61- from pymongo .asynchronous .helpers import _retry_overload
6260from pymongo .collation import validate_collation_or_none
6361from pymongo .common import _ecoc_coll_name , _esc_coll_name
6462from pymongo .errors import (
@@ -573,11 +571,6 @@ async def watch(
573571 await change_stream ._initialize_cursor ()
574572 return change_stream
575573
576- async def _conn_for_writes (
577- self , session : Optional [AsyncClientSession ], operation : str
578- ) -> AsyncContextManager [AsyncConnection ]:
579- return await self ._database .client ._conn_for_writes (session , operation )
580-
581574 async def _command (
582575 self ,
583576 conn : AsyncConnection ,
@@ -654,7 +647,10 @@ async def _create_helper(
654647 if "size" in options :
655648 options ["size" ] = float (options ["size" ])
656649 cmd .update (options )
657- async with await self ._conn_for_writes (session , operation = _Op .CREATE ) as conn :
650+
651+ async def inner (
652+ session : Optional [AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
653+ ) -> None :
658654 if qev2_required and conn .max_wire_version < 21 :
659655 raise ConfigurationError (
660656 "Driver support of Queryable Encryption is incompatible with server. "
@@ -671,6 +667,8 @@ async def _create_helper(
671667 session = session ,
672668 )
673669
670+ await self .database .client ._retryable_write (False , inner , session , _Op .CREATE )
671+
674672 async def _create (
675673 self ,
676674 options : MutableMapping [str , Any ],
@@ -2229,7 +2227,6 @@ async def create_indexes(
22292227 return await self ._create_indexes (indexes , session , ** kwargs )
22302228
22312229 @_csot .apply
2232- @_retry_overload
22332230 async def _create_indexes (
22342231 self , indexes : Sequence [IndexModel ], session : Optional [AsyncClientSession ], ** kwargs : Any
22352232 ) -> list [str ]:
@@ -2243,7 +2240,10 @@ async def _create_indexes(
22432240 command (like maxTimeMS) can be passed as keyword arguments.
22442241 """
22452242 names = []
2246- async with await self ._conn_for_writes (session , operation = _Op .CREATE_INDEXES ) as conn :
2243+
2244+ async def inner (
2245+ session : Optional [AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
2246+ ) -> list [str ]:
22472247 supports_quorum = conn .max_wire_version >= 9
22482248
22492249 def gen_indexes () -> Iterator [Mapping [str , Any ]]:
@@ -2272,7 +2272,11 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]:
22722272 write_concern = self ._write_concern_for (session ),
22732273 session = session ,
22742274 )
2275- return names
2275+ return names
2276+
2277+ return await self .database .client ._retryable_write (
2278+ False , inner , session , _Op .CREATE_INDEXES
2279+ )
22762280
22772281 async def create_index (
22782282 self ,
@@ -2474,7 +2478,6 @@ async def drop_index(
24742478 await self ._drop_index (index_or_name , session , comment , ** kwargs )
24752479
24762480 @_csot .apply
2477- @_retry_overload
24782481 async def _drop_index (
24792482 self ,
24802483 index_or_name : _IndexKeyHint ,
@@ -2493,7 +2496,10 @@ async def _drop_index(
24932496 cmd .update (kwargs )
24942497 if comment is not None :
24952498 cmd ["comment" ] = comment
2496- async with await self ._conn_for_writes (session , operation = _Op .DROP_INDEXES ) as conn :
2499+
2500+ async def inner (
2501+ session : Optional [AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
2502+ ) -> None :
24972503 await self ._command (
24982504 conn ,
24992505 cmd ,
@@ -2503,6 +2509,8 @@ async def _drop_index(
25032509 session = session ,
25042510 )
25052511
2512+ await self .database .client ._retryable_write (False , inner , session , _Op .DROP_INDEXES )
2513+
25062514 async def list_indexes (
25072515 self ,
25082516 session : Optional [AsyncClientSession ] = None ,
@@ -2766,17 +2774,22 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]:
27662774 cmd = {"createSearchIndexes" : self .name , "indexes" : list (gen_indexes ())}
27672775 cmd .update (kwargs )
27682776
2769- async with await self . _conn_for_writes (
2770- session , operation = _Op . CREATE_SEARCH_INDEXES
2771- ) as conn :
2777+ async def inner (
2778+ session : Optional [ AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
2779+ ) -> list [ str ] :
27722780 resp = await self ._command (
27732781 conn ,
27742782 cmd ,
27752783 read_preference = ReadPreference .PRIMARY ,
27762784 codec_options = _UNICODE_REPLACE_CODEC_OPTIONS ,
2785+ session = session ,
27772786 )
27782787 return [index ["name" ] for index in resp ["indexesCreated" ]]
27792788
2789+ return self .database .client ._retryable_write (
2790+ False , inner , session , _Op .CREATE_SEARCH_INDEXES
2791+ )
2792+
27802793 async def drop_search_index (
27812794 self ,
27822795 name : str ,
@@ -2802,15 +2815,21 @@ async def drop_search_index(
28022815 cmd .update (kwargs )
28032816 if comment is not None :
28042817 cmd ["comment" ] = comment
2805- async with await self ._conn_for_writes (session , operation = _Op .DROP_SEARCH_INDEXES ) as conn :
2818+
2819+ async def inner (
2820+ session : Optional [AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
2821+ ) -> None :
28062822 await self ._command (
28072823 conn ,
28082824 cmd ,
28092825 read_preference = ReadPreference .PRIMARY ,
28102826 allowable_errors = ["ns not found" , 26 ],
28112827 codec_options = _UNICODE_REPLACE_CODEC_OPTIONS ,
2828+ session = session ,
28122829 )
28132830
2831+ return self .database .client ._retryable_write (False , inner , session , _Op .DROP_SEARCH_INDEXES )
2832+
28142833 async def update_search_index (
28152834 self ,
28162835 name : str ,
@@ -2838,15 +2857,21 @@ async def update_search_index(
28382857 cmd .update (kwargs )
28392858 if comment is not None :
28402859 cmd ["comment" ] = comment
2841- async with await self ._conn_for_writes (session , operation = _Op .UPDATE_SEARCH_INDEX ) as conn :
2860+
2861+ async def inner (
2862+ session : Optional [AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
2863+ ) -> None :
28422864 await self ._command (
28432865 conn ,
28442866 cmd ,
28452867 read_preference = ReadPreference .PRIMARY ,
28462868 allowable_errors = ["ns not found" , 26 ],
28472869 codec_options = _UNICODE_REPLACE_CODEC_OPTIONS ,
2870+ session = session ,
28482871 )
28492872
2873+ return self .database .client ._retryable_write (False , inner , session , _Op .UPDATE_SEARCH_INDEX )
2874+
28502875 async def options (
28512876 self ,
28522877 session : Optional [AsyncClientSession ] = None ,
@@ -3075,7 +3100,6 @@ async def aggregate_raw_batches(
30753100 )
30763101
30773102 @_csot .apply
3078- @_retry_overload
30793103 async def rename (
30803104 self ,
30813105 new_name : str ,
@@ -3127,17 +3151,21 @@ async def rename(
31273151 if comment is not None :
31283152 cmd ["comment" ] = comment
31293153 write_concern = self ._write_concern_for_cmd (cmd , session )
3154+ client = self ._database .client
31303155
3131- async with await self ._conn_for_writes (session , operation = _Op .RENAME ) as conn :
3132- async with self ._database .client ._tmp_session (session ) as s :
3133- return await conn .command (
3134- "admin" ,
3135- cmd ,
3136- write_concern = write_concern ,
3137- parse_write_concern_error = True ,
3138- session = s ,
3139- client = self ._database .client ,
3140- )
3156+ async def inner (
3157+ session : Optional [AsyncClientSession ], conn : AsyncConnection , _retryable_write : bool
3158+ ) -> MutableMapping [str , Any ]:
3159+ return await conn .command (
3160+ "admin" ,
3161+ cmd ,
3162+ write_concern = write_concern ,
3163+ parse_write_concern_error = True ,
3164+ session = session ,
3165+ client = client ,
3166+ )
3167+
3168+ return client ._retryable_write (False , inner , session , _Op .RENAME )
31413169
31423170 async def distinct (
31433171 self ,
0 commit comments