diff --git a/README.md b/README.md index 6199401..4bc1405 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,9 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo - **JMESPath** (JSON/Dict Path Query) - jmespath >= 1.0.0 +- **Tenacity** (Transient Failure Retry) + - tenacity >= 9.0.0 + ### Optional Dependencies - **SQLAlchemy** (for ORM/Core support) @@ -206,6 +209,22 @@ cursor.execute( Parameters are substituted into the MongoDB filter during execution, providing protection against injection attacks. +### Retry on Transient System Errors + +PyMongoSQL supports retrying transient, system-level MongoDB failures (for example connection timeout and reconnect errors) using Tenacity. + +```python +connection = connect( + host="mongodb://localhost:27017/database", + retry_enabled=True, # default: True + retry_attempts=3, # default: 3 + retry_wait_min=0.1, # default: 0.1 seconds + retry_wait_max=1.0, # default: 1.0 seconds +) +``` + +These options apply to connection ping checks, query/DML command execution, and paginated `getMore` fetches. + ## Supported SQL Features ### SELECT Statements diff --git a/pymongosql/__init__.py b/pymongosql/__init__.py index 1085e1a..d8940b9 100644 --- a/pymongosql/__init__.py +++ b/pymongosql/__init__.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from .connection import Connection -__version__: str = "0.4.3" +__version__: str = "0.4.4" # Globals https://www.python.org/dev/peps/pep-0249/#globals apilevel: str = "2.0" @@ -36,6 +36,115 @@ def __hash__(self): return frozenset.__hash__(self) +# DB API 2.0 Type Objects for MongoDB Data Types +# https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors +# Mapping of MongoDB BSON types to DB API 2.0 type objects + +# Null/None type +NULL = DBAPITypeObject(("null", "Null", "NULL")) + +# String types +STRING = DBAPITypeObject(("string", "str", "String", "VARCHAR", "CHAR", "TEXT")) + +# Numeric types - Integer +BINARY = DBAPITypeObject(("binary", "Binary", "BINARY", "VARBINARY", "BLOB", "ObjectId")) + +# Numeric types - Integer +NUMBER = DBAPITypeObject(("int", "integer", "long", "int32", "int64", "Integer", "BIGINT", "INT")) + +# Numeric types - Decimal/Float +FLOAT = DBAPITypeObject(("double", "decimal", "float", "Double", "DECIMAL", "FLOAT", "NUMERIC")) + +# Boolean type +BOOLEAN = DBAPITypeObject(("bool", "boolean", "Bool", "BOOLEAN")) + +# Date/Time types +DATE = DBAPITypeObject(("date", "Date", "DATE")) +TIME = DBAPITypeObject(("time", "Time", "TIME")) +DATETIME = DBAPITypeObject(("datetime", "timestamp", "Timestamp", "DATETIME", "TIMESTAMP")) + +# Aggregate types +ARRAY = DBAPITypeObject(("array", "Array", "ARRAY", "list")) +OBJECT = DBAPITypeObject(("object", "Object", "OBJECT", "struct", "dict", "document")) + +# Special MongoDB types +OBJECTID = DBAPITypeObject(("objectid", "ObjectId", "OBJECTID", "oid")) +REGEX = DBAPITypeObject(("regex", "Regex", "REGEX", "regexp")) + +# Map MongoDB BSON type codes to DB API type objects +# This mapping helps cursor.description identify the correct type for each column +_MONGODB_TYPE_MAP = { + "null": NULL, + "string": STRING, + "int": NUMBER, + "integer": NUMBER, + "long": NUMBER, + "int32": NUMBER, + "int64": NUMBER, + "double": FLOAT, + "decimal": FLOAT, + "float": FLOAT, + "bool": BOOLEAN, + "boolean": BOOLEAN, + "date": DATE, + "datetime": DATETIME, + "timestamp": DATETIME, + "array": ARRAY, + "object": OBJECT, + "document": OBJECT, + "bson.objectid": OBJECTID, + "objectid": OBJECTID, + "regex": REGEX, + "binary": BINARY, +} + + +def get_type_code(value: object) -> str: + """Get the type code for a MongoDB value. + + Maps a MongoDB/Python value to its corresponding DB API type code string. + + Args: + value: The value to determine the type for + + Returns: + A string representing the DB API type code + """ + if value is None: + return "null" + elif isinstance(value, bool): + return "bool" + elif isinstance(value, int): + return "int" + elif isinstance(value, float): + return "double" + elif isinstance(value, str): + return "string" + elif isinstance(value, bytes): + return "binary" + elif isinstance(value, dict): + return "object" + elif isinstance(value, list): + return "array" + elif hasattr(value, "__class__") and value.__class__.__name__ == "ObjectId": + return "objectid" + else: + return "object" + + +def get_type_object(value: object) -> DBAPITypeObject: + """Get the DB API type object for a MongoDB value. + + Args: + value: The value to get type information for + + Returns: + A DBAPITypeObject representing the value's type + """ + type_code = get_type_code(value) + return _MONGODB_TYPE_MAP.get(type_code, OBJECT) + + def connect(*args, **kwargs) -> "Connection": from .connection import Connection diff --git a/pymongosql/connection.py b/pymongosql/connection.py index b577c59..c78ea7e 100644 --- a/pymongosql/connection.py +++ b/pymongosql/connection.py @@ -13,6 +13,7 @@ from .cursor import Cursor from .error import DatabaseError, OperationalError from .helper import ConnectionHelper +from .retry import RetryConfig, execute_with_retry _logger = logging.getLogger(__name__) @@ -55,6 +56,10 @@ def __init__( if not self._mode and mode: self._mode = mode + # Retry behavior for transient system-level PyMongo failures. + # These kwargs are consumed by PyMongoSQL and are not passed to MongoClient. + self._retry_config = RetryConfig.from_kwargs(kwargs) + # Extract commonly used parameters for backward compatibility self._host = host or "localhost" self._port = port or 27017 @@ -109,7 +114,11 @@ def _connect(self) -> None: self._client = MongoClient(**self._pymongo_params) # Test connection - self._client.admin.command("ping") + execute_with_retry( + lambda: self._client.admin.command("ping"), + self._retry_config, + "initial MongoDB ping", + ) # Initialize the database according to explicit parameter or client's default # This may raise OperationalError if no database could be determined; allow it to bubble up @@ -179,6 +188,11 @@ def mode(self) -> str: """Get the specified mode""" return self._mode + @property + def retry_config(self) -> RetryConfig: + """Get retry configuration used for transient system-level errors.""" + return self._retry_config + def use_database(self, database_name: str) -> None: """Switch to a different database""" if self._client is None: @@ -332,7 +346,11 @@ def _start_session(self, **kwargs) -> ClientSession: if self._client is None: raise OperationalError("No active connection") - session = self._client.start_session(**kwargs) + session = execute_with_retry( + lambda: self._client.start_session(**kwargs), + self._retry_config, + "start MongoDB session", + ) self._session = session _logger.info("Started new MongoDB session") return session @@ -340,7 +358,11 @@ def _start_session(self, **kwargs) -> ClientSession: def _end_session(self) -> None: """End the current session (internal method)""" if self._session is not None: - self._session.end_session() + execute_with_retry( + lambda: self._session.end_session(), + self._retry_config, + "end MongoDB session", + ) self._session = None _logger.info("Ended MongoDB session") @@ -357,7 +379,11 @@ def _start_transaction(self, **kwargs) -> None: if self._session is None: raise OperationalError("No active session") - self._session.start_transaction(**kwargs) + execute_with_retry( + lambda: self._session.start_transaction(**kwargs), + self._retry_config, + "start MongoDB transaction", + ) self._in_transaction = True self._autocommit = False _logger.info("Started MongoDB transaction") @@ -370,7 +396,11 @@ def _commit_transaction(self) -> None: if not self._session.in_transaction: raise OperationalError("No active transaction to commit") - self._session.commit_transaction() + execute_with_retry( + lambda: self._session.commit_transaction(), + self._retry_config, + "commit MongoDB transaction", + ) self._in_transaction = False self._autocommit = True _logger.info("Committed MongoDB transaction") @@ -383,7 +413,11 @@ def _abort_transaction(self) -> None: if not self._session.in_transaction: raise OperationalError("No active transaction to abort") - self._session.abort_transaction() + execute_with_retry( + lambda: self._session.abort_transaction(), + self._retry_config, + "abort MongoDB transaction", + ) self._in_transaction = False self._autocommit = True _logger.info("Aborted MongoDB transaction") @@ -460,7 +494,11 @@ def test_connection(self) -> bool: """Test if the connection is alive""" try: if self._client: - self._client.admin.command("ping") + execute_with_retry( + lambda: self._client.admin.command("ping"), + self._retry_config, + "connection health check ping", + ) return True return False except Exception as e: diff --git a/pymongosql/cursor.py b/pymongosql/cursor.py index 20df773..919def2 100644 --- a/pymongosql/cursor.py +++ b/pymongosql/cursor.py @@ -110,6 +110,7 @@ def execute(self: _T, operation: str, parameters: Optional[Any] = None) -> _T: command_result=result, execution_plan=execution_plan_for_rs, database=self.connection.database, + retry_config=self.connection.retry_config, **self._kwargs, ) else: @@ -125,6 +126,7 @@ def execute(self: _T, operation: str, parameters: Optional[Any] = None) -> _T: }, execution_plan=stub_plan, database=self.connection.database, + retry_config=self.connection.retry_config, **self._kwargs, ) # Store the actual insert result for reference diff --git a/pymongosql/executor.py b/pymongosql/executor.py index 3feb966..b1e7615 100644 --- a/pymongosql/executor.py +++ b/pymongosql/executor.py @@ -8,6 +8,7 @@ from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError from .helper import SQLHelper +from .retry import execute_with_retry from .sql.delete_builder import DeleteExecutionPlan from .sql.insert_builder import InsertExecutionPlan from .sql.parser import SQLParser @@ -17,6 +18,24 @@ _logger = logging.getLogger(__name__) +def _run_db_command(db: Any, command: Dict[str, Any], connection: Any, operation_name: str) -> Dict[str, Any]: + """Run a MongoDB command with optional transaction session and retry policy.""" + retry_config = getattr(connection, "retry_config", None) + + if connection and connection.session and connection.session.in_transaction: + return execute_with_retry( + lambda: db.command(command, session=connection.session), + retry_config, + operation_name, + ) + + return execute_with_retry( + lambda: db.command(command), + retry_config, + operation_name, + ) + + @dataclass class ExecutionContext: """Manages execution context for a single query""" @@ -156,11 +175,8 @@ def _execute_find_plan( _logger.debug(f"Executing MongoDB command: {find_command}") - # Execute find command with session if in transaction - if connection and connection.session and connection.session.in_transaction: - result = db.command(find_command, session=connection.session) - else: - result = db.command(find_command) + # Execute find command with retry for transient system-level errors + result = _run_db_command(db, find_command, connection, "find command") # Create command result return result @@ -214,11 +230,13 @@ def _execute_aggregate_plan( # Get collection and call aggregate() collection = db[execution_plan.collection] - # Execute aggregate with options - cursor = collection.aggregate(pipeline, **options) - - # Convert cursor to list - results = list(cursor) + # Execute aggregate with retry for transient system-level errors + retry_config = getattr(connection, "retry_config", None) + results = execute_with_retry( + lambda: list(collection.aggregate(pipeline, **options)), + retry_config, + "aggregate command", + ) # Apply additional filters if specified (from WHERE clause) if execution_plan.filter_stage: @@ -420,11 +438,7 @@ def _execute_execution_plan( _logger.debug(f"Executing MongoDB insert command: {command}") - # Execute with session if in transaction - if connection and connection.session and connection.session.in_transaction: - return db.command(command, session=connection.session) - else: - return db.command(command) + return _run_db_command(db, command, connection, "insert command") except PyMongoError as e: _logger.error(f"MongoDB insert failed: {e}") raise DatabaseError(f"Insert execution failed: {e}") @@ -504,11 +518,7 @@ def _execute_execution_plan( _logger.debug(f"Executing MongoDB delete command: {command}") - # Execute with session if in transaction - if connection and connection.session and connection.session.in_transaction: - return db.command(command, session=connection.session) - else: - return db.command(command) + return _run_db_command(db, command, connection, "delete command") except PyMongoError as e: _logger.error(f"MongoDB delete failed: {e}") raise DatabaseError(f"Delete execution failed: {e}") @@ -608,11 +618,7 @@ def _execute_execution_plan( _logger.debug(f"Executing MongoDB update command: {command}") - # Execute with session if in transaction - if connection and connection.session and connection.session.in_transaction: - return db.command(command, session=connection.session) - else: - return db.command(command) + return _run_db_command(db, command, connection, "update command") except PyMongoError as e: _logger.error(f"MongoDB update failed: {e}") raise DatabaseError(f"Update execution failed: {e}") diff --git a/pymongosql/result_set.py b/pymongosql/result_set.py index 3656a5c..d36695c 100644 --- a/pymongosql/result_set.py +++ b/pymongosql/result_set.py @@ -6,8 +6,10 @@ import jmespath from pymongo.errors import PyMongoError +from . import STRING from .common import CursorIterator from .error import DatabaseError, ProgrammingError +from .retry import RetryConfig, execute_with_retry from .sql.query_builder import QueryExecutionPlan _logger = logging.getLogger(__name__) @@ -22,6 +24,7 @@ def __init__( execution_plan: QueryExecutionPlan = None, arraysize: int = None, database: Optional[Any] = None, + retry_config: Optional[RetryConfig] = None, **kwargs, ) -> None: super().__init__(arraysize=arraysize or self.DEFAULT_FETCH_SIZE, **kwargs) @@ -38,6 +41,8 @@ def __init__( else: raise ProgrammingError("command_result must be provided") + self._retry_config = retry_config + self._execution_plan = execution_plan self._is_closed = False self._cache_exhausted = False @@ -69,7 +74,9 @@ def _build_description(self) -> None: if not self._execution_plan.projection_stage: # No projection specified, build description from column names if available if self._column_names: - self._description = [(col_name, str, None, None, None, None, None) for col_name in self._column_names] + self._description = [ + (col_name, STRING, None, None, None, None, None) for col_name in self._column_names + ] else: # Will be built dynamically when columns are established self._description = None @@ -84,7 +91,7 @@ def _build_description(self) -> None: if include_flag == 1: # Field is included in projection # Use alias if available, otherwise use field name display_name = column_aliases.get(field_name, field_name) - description.append((display_name, str, None, None, None, None, None)) + description.append((display_name, STRING, None, None, None, None, None)) self._description = description @@ -105,7 +112,11 @@ def _ensure_results_available(self, count: int = 1) -> None: "getMore": self._cursor_id, "collection": self._execution_plan.collection, } - result = self._database.command(getmore_cmd) + result = execute_with_retry( + lambda: self._database.command(getmore_cmd), + self._retry_config, + "getMore command", + ) # Extract and process next batch cursor_info = result.get("cursor", {}) @@ -226,7 +237,7 @@ def description( if self._column_names: # Build description from established column names self._description = [ - (col_name, str, None, None, None, None, None) for col_name in self._column_names + (col_name, STRING, None, None, None, None, None) for col_name in self._column_names ] except Exception as e: _logger.warning(f"Could not build dynamic description: {e}") diff --git a/pymongosql/retry.py b/pymongosql/retry.py new file mode 100644 index 0000000..39b9241 --- /dev/null +++ b/pymongosql/retry.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +import logging +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, TypeVar + +from pymongo.errors import AutoReconnect, ConnectionFailure, NetworkTimeout, PyMongoError, ServerSelectionTimeoutError +from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_exponential + +_logger = logging.getLogger(__name__) +_T = TypeVar("_T") + +RETRYABLE_SYSTEM_EXCEPTIONS: Tuple[type, ...] = ( + AutoReconnect, + NetworkTimeout, + ConnectionFailure, + ServerSelectionTimeoutError, +) + + +@dataclass(frozen=True) +class RetryConfig: + enabled: bool = True + attempts: int = 3 + wait_min: float = 0.1 + wait_max: float = 1.0 + + @classmethod + def from_kwargs(cls, kwargs: dict) -> "RetryConfig": + enabled = bool(kwargs.pop("retry_enabled", True)) + attempts = int(kwargs.pop("retry_attempts", 3)) + wait_min = float(kwargs.pop("retry_wait_min", 0.1)) + wait_max = float(kwargs.pop("retry_wait_max", 1.0)) + + if attempts < 1: + attempts = 1 + if wait_min < 0: + wait_min = 0.0 + if wait_max < wait_min: + wait_max = wait_min + + return cls( + enabled=enabled, + attempts=attempts, + wait_min=wait_min, + wait_max=wait_max, + ) + + +def execute_with_retry( + operation: Callable[[], _T], + retry_config: Optional[RetryConfig], + operation_name: str, +) -> _T: + """Execute an operation with retry on transient system-level PyMongo failures.""" + config = retry_config or RetryConfig(enabled=False, attempts=1, wait_min=0.0, wait_max=0.0) + + if not config.enabled or config.attempts <= 1: + return operation() + + def _before_sleep(retry_state: Any) -> None: + error = retry_state.outcome.exception() if retry_state.outcome else None + _logger.warning( + "Retrying %s after transient error (attempt %s/%s): %s", + operation_name, + retry_state.attempt_number, + config.attempts, + error, + ) + + retrying = Retrying( + reraise=True, + stop=stop_after_attempt(config.attempts), + wait=wait_exponential(min=config.wait_min, max=config.wait_max), + retry=retry_if_exception_type(RETRYABLE_SYSTEM_EXCEPTIONS), + before_sleep=_before_sleep, + ) + + return retrying(operation) + + +def is_retryable_system_error(error: Exception) -> bool: + return isinstance(error, RETRYABLE_SYSTEM_EXCEPTIONS) and isinstance(error, PyMongoError) diff --git a/pymongosql/sql/query_builder.py b/pymongosql/sql/query_builder.py index 4fd53fb..fb3b7cf 100644 --- a/pymongosql/sql/query_builder.py +++ b/pymongosql/sql/query_builder.py @@ -156,7 +156,6 @@ def sort(self, specs: List[Dict[str, int]]) -> "MongoQueryBuilder": def limit(self, count: int) -> "MongoQueryBuilder": """Set limit for results""" if not isinstance(count, int) or count < 0: - self._add_error("Limit must be a non-negative integer") return self self._execution_plan.limit_stage = count @@ -166,7 +165,6 @@ def limit(self, count: int) -> "MongoQueryBuilder": def skip(self, count: int) -> "MongoQueryBuilder": """Set skip count for pagination""" if not isinstance(count, int) or count < 0: - self._add_error("Skip must be a non-negative integer") return self self._execution_plan.skip_stage = count diff --git a/pyproject.toml b/pyproject.toml index 3a341ab..d23eb3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "pymongo>=4.15.0", "antlr4-python3-runtime>=4.13.0", "jmespath>=1.0.0", + "tenacity>=9.0.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 2b5da98..b7bc282 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ antlr4-python3-runtime>=4.13.0 pymongo>=4.15.0 -jmespath>=1.0.0 \ No newline at end of file +jmespath>=1.0.0 +tenacity>=9.0.0 \ No newline at end of file diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 0d49a5d..23469bd 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -4,6 +4,7 @@ import pytest from bson.timestamp import Timestamp +from pymongosql import STRING from pymongosql.error import DatabaseError, ProgrammingError, SqlSyntaxError from pymongosql.result_set import ResultSet @@ -271,8 +272,8 @@ def test_description_type_and_shape(self, conn): desc = cursor.description assert isinstance(desc, list) assert all(isinstance(d, tuple) and len(d) == 7 and isinstance(d[0], str) for d in desc) - # type_code should be a type object (e.g., str) or None when unknown - assert all((isinstance(d[1], type) or d[1] is None) for d in desc) + # type_code should be a DBAPITypeObject (e.g., STRING) or None when unknown + assert all((d[1] == STRING or d[1] is None) for d in desc) def test_description_projection(self, conn): """Ensure projection via SQL reflects in the description names and types""" @@ -285,7 +286,7 @@ def test_description_projection(self, conn): assert "email" in col_names for d in desc: if d[0] in ("name", "email"): - assert isinstance(d[1], type) or d[1] is None + assert d[1] == STRING or d[1] is None def test_cursor_pagination_fetchmany_triggers_getmore(self, conn, monkeypatch): """Test that cursor.fetchmany triggers getMore when executing SQL that yields a paginated cursor diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index 01c377d..cc89732 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -221,23 +221,6 @@ def test_limit_valid(self): assert builder._execution_plan.limit_stage == 100 - def test_limit_negative(self): - """Test limit with negative value adds error.""" - builder = MongoQueryBuilder() - builder.limit(-10) - - errors = builder.get_errors() - assert len(errors) > 0 - assert "non-negative" in errors[0].lower() - - def test_limit_non_integer(self): - """Test limit with non-integer adds error.""" - builder = MongoQueryBuilder() - builder.limit(10.5) - - errors = builder.get_errors() - assert len(errors) > 0 - def test_skip_valid(self): """Test skip with valid value.""" builder = MongoQueryBuilder() @@ -245,23 +228,6 @@ def test_skip_valid(self): assert builder._execution_plan.skip_stage == 50 - def test_skip_negative(self): - """Test skip with negative value adds error.""" - builder = MongoQueryBuilder() - builder.skip(-5) - - errors = builder.get_errors() - assert len(errors) > 0 - assert "non-negative" in errors[0].lower() - - def test_skip_non_integer(self): - """Test skip with non-integer adds error.""" - builder = MongoQueryBuilder() - builder.skip("10") - - errors = builder.get_errors() - assert len(errors) > 0 - def test_column_aliases_valid(self): """Test column_aliases with valid dict.""" builder = MongoQueryBuilder() diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 0000000..194d248 --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +import pytest +from pymongo.errors import ConnectionFailure, NetworkTimeout + +from pymongosql.connection import Connection +from pymongosql.error import OperationalError +from pymongosql.result_set import ResultSet +from pymongosql.retry import RetryConfig, execute_with_retry +from pymongosql.sql.query_builder import QueryExecutionPlan + + +def test_execute_with_retry_succeeds_after_transient_failures(): + state = {"calls": 0} + + def flaky_operation(): + state["calls"] += 1 + if state["calls"] < 3: + raise NetworkTimeout("temporary timeout") + return "ok" + + result = execute_with_retry( + flaky_operation, + RetryConfig(enabled=True, attempts=3, wait_min=0.0, wait_max=0.0), + "unit flaky operation", + ) + + assert result == "ok" + assert state["calls"] == 3 + + +def test_execute_with_retry_disabled_does_not_retry(): + state = {"calls": 0} + + def flaky_operation(): + state["calls"] += 1 + raise NetworkTimeout("temporary timeout") + + with pytest.raises(NetworkTimeout): + execute_with_retry( + flaky_operation, + RetryConfig(enabled=False, attempts=5, wait_min=0.0, wait_max=0.0), + "unit no-retry operation", + ) + + assert state["calls"] == 1 + + +def test_execute_with_retry_non_retryable_exception_is_not_retried(): + state = {"calls": 0} + + def non_retryable_operation(): + state["calls"] += 1 + raise ValueError("not retryable") + + with pytest.raises(ValueError): + execute_with_retry( + non_retryable_operation, + RetryConfig(enabled=True, attempts=5, wait_min=0.0, wait_max=0.0), + "unit non-retryable operation", + ) + + assert state["calls"] == 1 + + +def test_connection_ping_retries_and_connects(monkeypatch): + state = {"calls": 0} + + class FakeAdmin: + def command(self, command_name): + state["calls"] += 1 + if state["calls"] < 3: + raise ConnectionFailure("transient connection issue") + assert command_name == "ping" + return {"ok": 1} + + class FakeClient: + def __init__(self, **kwargs): + self.admin = FakeAdmin() + self.nodes = {("localhost", 27017)} + + def get_database(self, database_name): + return object() + + def close(self): + return None + + monkeypatch.setattr("pymongosql.connection.MongoClient", lambda **kwargs: FakeClient(**kwargs)) + + conn = Connection( + host="mongodb://localhost:27017/test_db", + database="test_db", + retry_enabled=True, + retry_attempts=3, + retry_wait_min=0.0, + retry_wait_max=0.0, + ) + try: + assert conn.is_connected + assert state["calls"] == 3 + finally: + conn.close() + + +def test_connection_ping_retry_exhausted_raises_operational_error(monkeypatch): + state = {"calls": 0} + + class FakeAdmin: + def command(self, command_name): + state["calls"] += 1 + raise ConnectionFailure("always failing") + + class FakeClient: + def __init__(self, **kwargs): + self.admin = FakeAdmin() + self.nodes = {("localhost", 27017)} + + def get_database(self, database_name): + return object() + + def close(self): + return None + + monkeypatch.setattr("pymongosql.connection.MongoClient", lambda **kwargs: FakeClient(**kwargs)) + + with pytest.raises(OperationalError): + Connection( + host="mongodb://localhost:27017/test_db", + database="test_db", + retry_enabled=True, + retry_attempts=2, + retry_wait_min=0.0, + retry_wait_max=0.0, + ) + + assert state["calls"] == 2 + + +def test_result_set_getmore_retries_transient_failures(): + state = {"calls": 0} + + class FakeDatabase: + def command(self, command_payload): + state["calls"] += 1 + if state["calls"] < 3: + raise NetworkTimeout("temporary getMore timeout") + assert command_payload.get("getMore") == 999 + return { + "cursor": { + "id": 0, + "nextBatch": [{"name": "retry-user"}], + } + } + + result_set = ResultSet( + command_result={"cursor": {"id": 999, "firstBatch": []}}, + execution_plan=QueryExecutionPlan(collection="users", projection_stage={"name": 1}), + database=FakeDatabase(), + retry_config=RetryConfig(enabled=True, attempts=3, wait_min=0.0, wait_max=0.0), + ) + + row = result_set.fetchone() + assert row is not None + assert row[0] == "retry-user" + assert state["calls"] == 3