diff --git a/newsfragments/1235.feature.rst b/newsfragments/1235.feature.rst new file mode 100644 index 00000000..6c047099 --- /dev/null +++ b/newsfragments/1235.feature.rst @@ -0,0 +1,3 @@ +Added async PostgreSQL fixture support via ``postgresql_async`` factory and ``AsyncDatabaseJanitor``. +Added configurable fixture ``scope`` parameter to ``postgresql``, ``postgresql_async``, ``postgresql_proc``, and ``postgresql_noproc`` factories (defaults preserved: ``"function"`` for client fixtures, ``"session"`` for process fixtures). +Added optional ``async`` extra (``pip install pytest-postgresql[async]``) providing ``pytest-asyncio`` and ``aiofiles`` dependencies. diff --git a/pyproject.toml b/pyproject.toml index a9cb59a1..06e162b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,12 @@ dependencies = [ ] requires-python = ">= 3.10" +[project.optional-dependencies] +async = [ + "pytest-asyncio >= 0.21", + "aiofiles >= 23.0" +] + [project.urls] "Source" = "https://github.com/dbfixtures/pytest-postgresql" "Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues" diff --git a/pytest_postgresql/factories/__init__.py b/pytest_postgresql/factories/__init__.py index d6bd2f64..002304cb 100644 --- a/pytest_postgresql/factories/__init__.py +++ b/pytest_postgresql/factories/__init__.py @@ -17,8 +17,8 @@ # along with pytest-postgresql. If not, see . """Fixture factories for postgresql fixtures.""" -from pytest_postgresql.factories.client import postgresql +from pytest_postgresql.factories.client import postgresql, postgresql_async from pytest_postgresql.factories.noprocess import postgresql_noproc from pytest_postgresql.factories.process import PortType, postgresql_proc -__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "PortType") +__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "postgresql_async", "PortType") diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index 5fb5a5be..26a701d1 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -17,23 +17,30 @@ # along with pytest-postgresql. If not, see . """Fixture factory for postgresql client.""" -from typing import Callable, Iterator +from typing import AsyncIterator, Callable, Iterator import psycopg import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection from pytest import FixtureRequest +try: + import pytest_asyncio +except ImportError: + pytest_asyncio = None # type: ignore[assignment] + from pytest_postgresql.config import get_config from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.executor_noop import NoopExecutor -from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor +from pytest_postgresql.types import FixtureScopeT def postgresql( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, + scope: FixtureScopeT = "function", ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return connection fixture factory for PostgreSQL. @@ -41,12 +48,13 @@ def postgresql( :param dbname: database name :param isolation_level: optional postgresql isolation level defaults to server's default + :param scope: fixture scope; by default "function" which is recommended. :returns: function which makes a connection to postgresql """ - @pytest.fixture + @pytest.fixture(scope=scope) def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: - """Fixture factory for PostgreSQL. + """Fixture connection factory for PostgreSQL. :param request: fixture request object :returns: postgresql client @@ -85,3 +93,72 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: db_connection.close() return postgresql_factory + + +def postgresql_async( + process_fixture_name: str, + dbname: str | None = None, + isolation_level: "psycopg.IsolationLevel | None" = None, + scope: FixtureScopeT = "function", +) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]: + """Return async connection fixture factory for PostgreSQL. + + :param process_fixture_name: name of the process fixture + :param dbname: database name + :param isolation_level: optional postgresql isolation level + defaults to server's default + :param scope: fixture scope; by default "function" which is recommended. + :returns: function which makes an async connection to postgresql + """ + if pytest_asyncio is None: + + @pytest.fixture(scope=scope) + def postgresql_async_factory(request: FixtureRequest) -> None: + """Sync stub that raises ImportError when pytest-asyncio is absent.""" + raise ImportError( + "pytest-asyncio is required for async fixtures. Install it with: pip install pytest-postgresql[async]" + ) + + return postgresql_async_factory # type: ignore[return-value] + + @pytest_asyncio.fixture(scope=scope, loop_scope=scope) + async def postgresql_async_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]: + """Async connection fixture factory for PostgreSQL. + + :param request: fixture request object + :returns: postgresql async client + """ + proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name) + config = get_config(request) + + pg_host = proc_fixture.host + pg_port = proc_fixture.port + pg_user = proc_fixture.user + pg_password = proc_fixture.password + pg_options = proc_fixture.options + pg_db = dbname or proc_fixture.dbname + janitor = AsyncDatabaseJanitor( + user=pg_user, + host=pg_host, + port=pg_port, + dbname=pg_db, + template_dbname=proc_fixture.template_dbname, + version=proc_fixture.version, + password=pg_password, + isolation_level=isolation_level, + ) + if config.drop_test_database: + await janitor.drop() + async with janitor: + db_connection: AsyncConnection = await AsyncConnection.connect( + dbname=pg_db, + user=pg_user, + password=pg_password, + host=pg_host, + port=pg_port, + options=pg_options, + ) + yield db_connection + await db_connection.close() + + return postgresql_async_factory diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index 2d7f8b49..8af27c37 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -27,6 +27,7 @@ from pytest_postgresql.config import get_config from pytest_postgresql.executor_noop import NoopExecutor from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.types import FixtureScopeT def xdistify_dbname(dbname: str) -> str: @@ -46,6 +47,7 @@ def postgresql_noproc( options: str = "", load: list[Callable | str | Path] | None = None, depends_on: str | None = None, + scope: FixtureScopeT = "session", ) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]: """Postgresql noprocess factory. @@ -57,10 +59,11 @@ def postgresql_noproc( :param options: Postgresql connection options :param load: List of functions used to initialize database's template. :param depends_on: Optional name of the fixture to depend on. + :param scope: fixture scope; by default "session" which is recommended. :returns: function which makes a postgresql process """ - @pytest.fixture(scope="session") + @pytest.fixture(scope=scope) def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]: """Noop Process fixture for PostgreSQL. diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 27fab57f..cbfc0111 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -32,6 +32,7 @@ from pytest_postgresql.exceptions import ExecutableMissingException from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.types import FixtureScopeT PortType = port_for.PortType # mypy requires explicit export @@ -81,6 +82,7 @@ def postgresql_proc( unixsocketdir: str | None = None, postgres_options: str | None = None, load: list[Callable | str | Path] | None = None, + scope: FixtureScopeT = "session", ) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]: """Postgresql process factory. @@ -101,10 +103,11 @@ def postgresql_proc( :param unixsocketdir: directory to create postgresql's unixsockets :param postgres_options: Postgres executable options for use by pg_ctl :param load: List of functions used to initialize database's template. + :param scope: fixture scope; by default "session" which is recommended. :returns: function which makes a postgresql process """ - @pytest.fixture(scope="session") + @pytest.fixture(scope=scope) def postgresql_proc_fixture( request: FixtureRequest, tmp_path_factory: TempPathFactory ) -> Iterator[PostgreSQLExecutor]: diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index f602372e..146f4dd0 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -1,21 +1,24 @@ """Database Janitor.""" -from contextlib import contextmanager +import inspect +from contextlib import asynccontextmanager, contextmanager from pathlib import Path from types import TracebackType -from typing import Callable, Iterator, Type, TypeVar +from typing import AsyncIterator, Callable, Iterator, Type, TypeVar import psycopg +import psycopg.sql as sql from packaging.version import parse -from psycopg import Connection, Cursor +from psycopg import AsyncCursor, Connection, Cursor -from pytest_postgresql.loader import build_loader -from pytest_postgresql.retry import retry +from pytest_postgresql.loader import build_loader, build_loader_async +from pytest_postgresql.retry import retry, retry_async Version = type(parse("1")) DatabaseJanitorType = TypeVar("DatabaseJanitorType", bound="DatabaseJanitor") +AsyncDatabaseJanitorType = TypeVar("AsyncDatabaseJanitorType", bound="AsyncDatabaseJanitor") class DatabaseJanitor: @@ -67,18 +70,17 @@ def __init__( def init(self) -> None: """Create database in postgresql.""" with self.cursor() as cur: + query = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.dbname)) if self.template_dbname: # And make sure no-one is left connected to the template database. # Otherwise, Creating database from template will fail self._terminate_connection(cur, self.template_dbname) - query = f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}"' - else: - query = f'CREATE DATABASE "{self.dbname}"' + query = query + sql.SQL(" TEMPLATE {}").format(sql.Identifier(self.template_dbname)) if self.as_template: - query += " IS_TEMPLATE = true" + query = query + sql.SQL(" IS_TEMPLATE = true") - cur.execute(f"{query};") + cur.execute(query) def is_template(self) -> bool: """Determine whether the DatabaseJanitor maintains template or database.""" @@ -92,17 +94,17 @@ def drop(self) -> None: self._dont_datallowconn(cur, self.dbname) self._terminate_connection(cur, self.dbname) if self.as_template: - cur.execute(f'ALTER DATABASE "{self.dbname}" with is_template false;') - cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";') + cur.execute(sql.SQL("ALTER DATABASE {} WITH is_template false").format(sql.Identifier(self.dbname))) + cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(self.dbname))) @staticmethod def _dont_datallowconn(cur: Cursor, dbname: str) -> None: - cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;') + cur.execute(sql.SQL("ALTER DATABASE {} WITH allow_connections false").format(sql.Identifier(dbname))) @staticmethod def _terminate_connection(cur: Cursor, dbname: str) -> None: cur.execute( - "SELECT pg_terminate_backend(pg_stat_activity.pid)" + "SELECT pg_terminate_backend(pg_stat_activity.pid) " "FROM pg_stat_activity " "WHERE pg_stat_activity.datname = %s;", (dbname,), @@ -164,3 +166,153 @@ def __exit__( ) -> None: """Exit from Database janitor context cleaning after itself.""" self.drop() + + +class AsyncDatabaseJanitor: + """Manage database state asynchronously for specific tasks.""" + + def __init__( + self, + *, + user: str, + host: str, + port: str | int, + version: str | float | Version, # type: ignore[valid-type] + dbname: str, + template_dbname: str | None = None, + as_template: bool = False, + password: str | None = None, + isolation_level: "psycopg.IsolationLevel | None" = None, + connection_timeout: int = 60, + ) -> None: + """Initialize async janitor. + + :param user: postgresql username + :param host: postgresql host + :param port: postgresql port + :param dbname: database name + :param template_dbname: template database name to clone from + :param as_template: whether to mark the database as a template + :param version: postgresql version number + :param password: optional postgresql password + :param isolation_level: optional postgresql isolation level + defaults to server's default + :param connection_timeout: how long to retry connection before + raising a TimeoutError + """ + self.user = user + self.password = password + self.host = host + self.port = port + self.dbname = dbname + self.template_dbname = template_dbname + self.as_template = as_template + self._connection_timeout = connection_timeout + self.isolation_level = isolation_level + if not isinstance(version, Version): + self.version = parse(str(version)) + else: + self.version = version + + async def init(self) -> None: + """Create database in postgresql.""" + async with self.cursor() as cur: + query = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.dbname)) + if self.template_dbname: + # And make sure no-one is left connected to the template database. + # Otherwise, Creating database from template will fail + await self._terminate_connection(cur, self.template_dbname) + query = query + sql.SQL(" TEMPLATE {}").format(sql.Identifier(self.template_dbname)) + + if self.as_template: + query = query + sql.SQL(" IS_TEMPLATE = true") + + await cur.execute(query) + + def is_template(self) -> bool: + """Determine whether the AsyncDatabaseJanitor maintains template or database.""" + return self.as_template + + async def drop(self) -> None: + """Drop database in postgresql.""" + # We cannot drop the database while there are connections to it, so we + # terminate all connections first while not allowing new connections. + async with self.cursor() as cur: + await self._dont_datallowconn(cur, self.dbname) + await self._terminate_connection(cur, self.dbname) + if self.as_template: + await cur.execute( + sql.SQL("ALTER DATABASE {} WITH is_template false").format(sql.Identifier(self.dbname)) + ) + await cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(self.dbname))) + + @staticmethod + async def _dont_datallowconn(cur: AsyncCursor, dbname: str) -> None: # type: ignore[type-arg] + await cur.execute(sql.SQL("ALTER DATABASE {} WITH allow_connections false").format(sql.Identifier(dbname))) + + @staticmethod + async def _terminate_connection(cur: AsyncCursor, dbname: str) -> None: # type: ignore[type-arg] + await cur.execute( + "SELECT pg_terminate_backend(pg_stat_activity.pid) " + "FROM pg_stat_activity " + "WHERE pg_stat_activity.datname = %s;", + (dbname,), + ) + + async def load(self, load: Callable | str | Path) -> None: + """Load data into a database. + + Expects: + + * a Path to sql file, that'll be loaded + * an import path to import callable + * a callable that expects: host, port, user, dbname and password arguments. + + """ + _loader = build_loader_async(load) + result = _loader( + host=self.host, + port=self.port, + user=self.user, + dbname=self.dbname, + password=self.password, + ) + if inspect.isawaitable(result): + await result + + @asynccontextmanager + async def cursor(self, dbname: str = "postgres") -> AsyncIterator[AsyncCursor]: # type: ignore[type-arg] + """Return postgresql async cursor.""" + + async def connect() -> psycopg.AsyncConnection: + return await psycopg.AsyncConnection.connect( + dbname=dbname, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + ) + + conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError) + try: + await conn.set_isolation_level(self.isolation_level) + await conn.set_autocommit(True) + # We must not run a transaction since we create a database. + async with conn.cursor() as cur: + yield cur + finally: + await conn.close() + + async def __aenter__(self: AsyncDatabaseJanitorType) -> AsyncDatabaseJanitorType: + """Initialize Async Database Janitor.""" + await self.init() + return self + + async def __aexit__( + self: AsyncDatabaseJanitorType, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit from Async Database Janitor context cleaning after itself.""" + await self.drop() diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index c9b28cbd..63f025ba 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -1,5 +1,6 @@ """Loader helper functions.""" +import importlib import re from functools import partial from pathlib import Path @@ -7,6 +8,11 @@ import psycopg +try: + import aiofiles +except ImportError: + aiofiles = None # type: ignore[assignment] + def build_loader(load: Callable | str | Path) -> Callable: """Build a loader callable.""" @@ -16,7 +22,7 @@ def build_loader(load: Callable | str | Path) -> Callable: loader_parts = re.split("[.:]", load, maxsplit=2) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name]) + _temp_import = importlib.import_module(import_path) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: @@ -30,3 +36,35 @@ def sql(sql_filename: Path, **kwargs: Any) -> None: with db_connection.cursor() as cur: cur.execute(_fd.read()) db_connection.commit() + + +def build_loader_async(load: Callable | str | Path) -> Callable: + """Build an async loader callable.""" + if isinstance(load, Path): + return partial(sql_async, load) + elif isinstance(load, str): + loader_parts = re.split("[.:]", load, maxsplit=2) + import_path = ".".join(loader_parts[:-1]) + loader_name = loader_parts[-1] + _temp_import = importlib.import_module(import_path) + _loader: Callable = getattr(_temp_import, loader_name) + return _loader + else: + return load + + +async def sql_async(sql_filename: Path, **kwargs: Any) -> None: + """Async database loader for sql files. + + Requires the optional ``async`` extra: ``pip install pytest-postgresql[async]``. + """ + if aiofiles is None: + raise ImportError( + "aiofiles is required for async SQL loading. Install it with: pip install pytest-postgresql[async]" + ) + + async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection: + async with db_connection.cursor() as cur: + async with aiofiles.open(sql_filename, "r") as _fd: + await cur.execute(await _fd.read()) + await db_connection.commit() diff --git a/pytest_postgresql/plugin.py b/pytest_postgresql/plugin.py index 612e408a..5fa7b58c 100644 --- a/pytest_postgresql/plugin.py +++ b/pytest_postgresql/plugin.py @@ -135,3 +135,4 @@ def pytest_addoption(parser: Parser) -> None: postgresql_proc = factories.postgresql_proc() postgresql_noproc = factories.postgresql_noproc() postgresql = factories.postgresql("postgresql_proc") +postgresql_async = factories.postgresql_async("postgresql_proc") diff --git a/pytest_postgresql/retry.py b/pytest_postgresql/retry.py index ea25fa2e..078db5bc 100644 --- a/pytest_postgresql/retry.py +++ b/pytest_postgresql/retry.py @@ -1,9 +1,10 @@ """Small retry callable in case of specific error occurred.""" +import asyncio import datetime import sys from time import sleep -from typing import Callable, Type, TypeVar +from typing import Awaitable, Callable, Type, TypeVar T = TypeVar("T") @@ -29,11 +30,41 @@ def retry( i += 1 try: res = func() - return res except possible_exception as e: if time + timeout_diff < get_current_datetime(): raise TimeoutError(f"Failed after {i} attempts") from e sleep(1) + else: + return res + + +async def retry_async( + func: Callable[[], Awaitable[T]], + timeout: int = 60, + possible_exception: Type[Exception] = Exception, +) -> T: + """Attempt to retry the async function for timeout time. + + Most often used for connecting to postgresql database as, + especially on macos on github-actions, first few tries fails + with this message: + + ... :: + FATAL: the database system is starting up + """ + time: datetime.datetime = get_current_datetime() + timeout_diff: datetime.timedelta = datetime.timedelta(seconds=timeout) + i = 0 + while True: + i += 1 + try: + res = await func() + except possible_exception as e: + if time + timeout_diff < get_current_datetime(): + raise TimeoutError(f"Failed after {i} attempts") from e + await asyncio.sleep(1) + else: + return res def get_current_datetime() -> datetime.datetime: diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py new file mode 100644 index 00000000..e5f35043 --- /dev/null +++ b/pytest_postgresql/types.py @@ -0,0 +1,5 @@ +"""Pytest PostgreSQL types.""" + +from typing import Literal + +FixtureScopeT = Literal["session", "package", "module", "class", "function"] diff --git a/tests/conftest.py b/tests/conftest.py index 784b8905..483437af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,3 +17,5 @@ postgresql_proc2 = factories.postgresql_proc(port=None, load=[TEST_SQL_FILE, TEST_SQL_FILE2]) postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db") postgresql_load_1 = factories.postgresql("postgresql_proc2") +postgresql2_async = factories.postgresql_async("postgresql_proc2", dbname="test-db") +postgresql_load_1_async = factories.postgresql_async("postgresql_proc2") diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py index ae25307a..1d0fbb73 100644 --- a/tests/docker/test_noproc_docker.py +++ b/tests/docker/test_noproc_docker.py @@ -3,7 +3,7 @@ import pathlib import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection import pytest_postgresql.factories.client import pytest_postgresql.factories.noprocess @@ -14,12 +14,17 @@ ) postgres_with_schema = pytest_postgresql.factories.client.postgresql("postgresql_my_proc") +async_postgres_with_schema = pytest_postgresql.factories.client.postgresql_async("postgresql_my_proc") + postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc( dbname="stories_templated", load=[load_database] ) postgres_with_template = pytest_postgresql.factories.client.postgresql( "postgresql_my_proc_template", dbname="stories_templated" ) +async_postgres_with_template = pytest_postgresql.factories.client.postgresql_async( + "postgresql_my_proc_template", dbname="stories_templated" +) def test_postgres_docker_load(postgres_with_schema: Connection) -> None: @@ -32,6 +37,14 @@ def test_postgres_docker_load(postgres_with_schema: Connection) -> None: print(cur.fetchall()) +@pytest.mark.asyncio +async def test_postgres_docker_load_async(async_postgres_with_schema: AsyncConnection) -> None: + """Async check main postgres fixture.""" + async with async_postgres_with_schema.cursor() as cur: + await cur.execute("select * from public.tokens") + print(await cur.fetchall()) + + @pytest.mark.parametrize("_", range(5)) def test_template_database(postgres_with_template: Connection, _: int) -> None: """Check that the database structure gets recreated out of a template.""" @@ -43,3 +56,17 @@ def test_template_database(postgres_with_template: Connection, _: int) -> None: cur.execute("SELECT * FROM stories") res = cur.fetchall() assert len(res) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(5)) +async def test_template_database_async(async_postgres_with_template: AsyncConnection, _: int) -> None: + """Async check that the database structure gets recreated out of a template.""" + async with async_postgres_with_template.cursor() as cur: + await cur.execute("SELECT * FROM stories") + rows = await cur.fetchall() + assert len(rows) == 4 + await cur.execute("TRUNCATE stories") + await cur.execute("SELECT * FROM stories") + rows = await cur.fetchall() + assert len(rows) == 0 diff --git a/tests/test_factory_errors.py b/tests/test_factory_errors.py new file mode 100644 index 00000000..18216693 --- /dev/null +++ b/tests/test_factory_errors.py @@ -0,0 +1,33 @@ +"""Tests for factory error paths (missing optional dependencies).""" + +from unittest.mock import patch + +import pytest + +from pytest_postgresql.factories.client import postgresql_async + + +def test_postgresql_async_factory_creation_succeeds_without_pytest_asyncio() -> None: + """postgresql_async() must not raise at factory-creation time when pytest-asyncio is absent. + + The plugin registers ``postgresql_async`` at load time (plugin.py), so raising here + would break all users — including those who only use synchronous fixtures. + """ + with patch("pytest_postgresql.factories.client.pytest_asyncio", None): + fixture_func = postgresql_async("some_proc_fixture") + assert callable(fixture_func) + + +def test_postgresql_async_raises_on_use_without_pytest_asyncio() -> None: + """When pytest-asyncio is absent, the registered stub is synchronous and raises ImportError. + + A synchronous stub avoids the "coroutine was never awaited" warning that would + result from registering an async def with plain pytest.fixture. + """ + with patch("pytest_postgresql.factories.client.pytest_asyncio", None): + fixture_func = postgresql_async("some_proc_fixture") + # pytest 8+ wraps fixtures to prevent direct calls; unwrap first. + raw_func = getattr(fixture_func, "__wrapped__", fixture_func) + assert not hasattr(raw_func, "__await__"), "stub must be a sync function, not a coroutine" + with pytest.raises(ImportError, match="pytest-asyncio"): + raw_func(None) # type: ignore[arg-type] diff --git a/tests/test_janitor.py b/tests/test_janitor.py index fd1fca2a..9390d9e7 100644 --- a/tests/test_janitor.py +++ b/tests/test_janitor.py @@ -1,13 +1,16 @@ """Database Janitor tests.""" import sys -from typing import Any -from unittest.mock import MagicMock, patch +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator +from unittest.mock import AsyncMock, MagicMock, patch +import psycopg.sql as pgsql import pytest from packaging.version import parse +from psycopg import AsyncCursor -from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor VERSION = parse("10") @@ -19,6 +22,14 @@ def test_version_cast(version: Any) -> None: assert janitor.version == VERSION +@pytest.mark.parametrize("version", (VERSION, 10, "10")) +@pytest.mark.asyncio +async def test_version_cast_async(version: Any) -> None: + """Async test that version is cast to Version object.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version) + assert janitor.version == VERSION + + @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" @@ -27,6 +38,19 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234") +@pytest.mark.asyncio +async def test_cursor_selects_postgres_database_async() -> None: + """Async test that the cursor requests the postgres database.""" + conn_mock = _make_async_conn_mock() + connect_mock = AsyncMock(return_value=conn_mock) + with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock): + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10) + async with janitor.cursor(): + connect_mock.assert_called_once_with( + dbname="postgres", user="user", password=None, host="host", port="1234" + ) + + @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" @@ -36,7 +60,7 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: port="1234", dbname="database_name", version=10, - password="some_password", + password="some_password", # noqa: S106 ) with janitor.cursor(): connect_mock.assert_called_once_with( @@ -44,6 +68,39 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: ) +@pytest.mark.asyncio +async def test_cursor_connects_with_password_async() -> None: + """Async test that the cursor requests the postgres database with password.""" + conn_mock = _make_async_conn_mock() + connect_mock = AsyncMock(return_value=conn_mock) + with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock): + janitor = AsyncDatabaseJanitor( + user="user", + host="host", + port="1234", + dbname="database_name", + version=10, + password="some_password", # noqa: S106 + ) + async with janitor.cursor(): + connect_mock.assert_called_once_with( + dbname="postgres", user="user", password="some_password", host="host", port="1234" + ) + + +@pytest.mark.asyncio +async def test_cursor_custom_dbname_async() -> None: + """Test that a custom dbname is forwarded to the connection in AsyncDatabaseJanitor.cursor.""" + conn_mock = _make_async_conn_mock() + connect_mock = AsyncMock(return_value=conn_mock) + with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock): + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10) + async with janitor.cursor(dbname="custom_db"): + connect_mock.assert_called_once_with( + dbname="custom_db", user="user", password=None, host="host", port="1234" + ) + + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8") @pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database")) @patch("pytest_postgresql.janitor.psycopg.connect") @@ -57,9 +114,193 @@ def test_janitor_populate(connect_mock: MagicMock, load_database: str) -> None: "port": "1234", "user": "user", "dbname": "database_name", - "password": "some_password", + "password": "some_password", # noqa: S106 } janitor = DatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type] janitor.load(load_database) assert connect_mock.called assert connect_mock.call_args.kwargs == call_kwargs + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8") +@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database")) +@patch("tests.loader.psycopg.connect") +@pytest.mark.asyncio +async def test_janitor_populate_async(connect_mock: MagicMock, load_database: str) -> None: + """Async test that the cursor requests the postgres database and populates. + + load_database (synchronous) uses psycopg.connect, so we mock that. + """ + call_kwargs = { + "host": "host", + "port": "1234", + "user": "user", + "dbname": "database_name", + "password": "some_password", # noqa: S106 + } + janitor = AsyncDatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type] + await janitor.load(load_database) + assert connect_mock.called + assert connect_mock.call_args.kwargs == call_kwargs + + +# --------------------------------------------------------------------------- +# AsyncDatabaseJanitor -- init() / drop() / helper method tests +# --------------------------------------------------------------------------- + + +def _render_sql(obj: object) -> str: + """Render a psycopg.sql Composable to its SQL text form for test assertions.""" + if isinstance(obj, pgsql.Composed): + return "".join(_render_sql(part) for part in obj) + if isinstance(obj, pgsql.SQL): + return obj._obj # type: ignore[attr-defined] + if isinstance(obj, pgsql.Identifier): + parts: tuple[str, ...] = obj._obj # type: ignore[attr-defined] + return ".".join('"' + s.replace('"', '""') + '"' for s in parts) + return str(obj) + + +def _make_cursor_mock() -> MagicMock: + """Create a mock async cursor that records execute() calls.""" + cur = AsyncMock(spec=AsyncCursor) + return cur + + +def _make_cursor_context(cur: AsyncMock) -> Any: + """Return an async context manager that yields the given cursor mock.""" + + @asynccontextmanager + async def _ctx(dbname: str = "postgres") -> AsyncIterator[AsyncMock]: + yield cur + + return _ctx + + +@pytest.mark.asyncio +async def test_async_janitor_init_creates_database() -> None: + """init() executes CREATE DATABASE with the configured dbname.""" + cur = _make_cursor_mock() + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10) + with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)): + await janitor.init() + + executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list) + assert 'CREATE DATABASE "mydb"' in executed_sql + + +@pytest.mark.asyncio +async def test_async_janitor_init_with_template() -> None: + """init() uses TEMPLATE clause when template_dbname is set.""" + cur = _make_cursor_mock() + janitor = AsyncDatabaseJanitor( + user="user", host="host", port="1234", dbname="mydb", template_dbname="tmpl", version=10 + ) + with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)): + await janitor.init() + + executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list) + assert 'CREATE DATABASE "mydb" TEMPLATE "tmpl"' in executed_sql + + +@pytest.mark.asyncio +async def test_async_janitor_init_as_template() -> None: + """init() appends IS_TEMPLATE = true when as_template is True.""" + cur = _make_cursor_mock() + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10) + with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)): + await janitor.init() + + executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list) + assert "IS_TEMPLATE = true" in executed_sql + + +@pytest.mark.asyncio +async def test_async_janitor_drop_drops_database() -> None: + """drop() executes DROP DATABASE IF EXISTS for the configured dbname.""" + cur = _make_cursor_mock() + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10) + with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)): + await janitor.drop() + + executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list) + assert 'DROP DATABASE IF EXISTS "mydb"' in executed_sql + + +@pytest.mark.asyncio +async def test_async_janitor_drop_as_template() -> None: + """drop() resets is_template before dropping when as_template is True.""" + cur = _make_cursor_mock() + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10) + with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)): + await janitor.drop() + + executed_sql = [_render_sql(c.args[0]) for c in cur.execute.call_args_list] + assert any("is_template false" in s for s in executed_sql) + assert any('DROP DATABASE IF EXISTS "mydb"' in s for s in executed_sql) + # is_template false must come before DROP + template_idx = next(i for i, s in enumerate(executed_sql) if "is_template false" in s) + drop_idx = next(i for i, s in enumerate(executed_sql) if "DROP DATABASE" in s) + assert template_idx < drop_idx + + +def test_async_janitor_is_template_false() -> None: + """is_template() returns False when as_template is not set.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10) + assert janitor.is_template() is False + + +def test_async_janitor_is_template_true() -> None: + """is_template() returns True when as_template=True.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10) + assert janitor.is_template() is True + + +@pytest.mark.asyncio +async def test_async_janitor_context_manager_calls_init_and_drop() -> None: + """__aenter__ calls init() and __aexit__ calls drop().""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10) + init_mock = AsyncMock() + drop_mock = AsyncMock() + with patch.object(AsyncDatabaseJanitor, "init", init_mock), patch.object(AsyncDatabaseJanitor, "drop", drop_mock): + async with janitor: + init_mock.assert_called_once() + drop_mock.assert_not_called() + drop_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_janitor_terminate_connection_sql() -> None: + """_terminate_connection() executes pg_terminate_backend query with correct dbname.""" + cur = AsyncMock(spec=AsyncCursor) + await AsyncDatabaseJanitor._terminate_connection(cur, "target_db") + + cur.execute.assert_called_once() + sql_str, params = cur.execute.call_args.args + assert "pg_terminate_backend" in sql_str + assert params == ("target_db",) + + +@pytest.mark.asyncio +async def test_async_janitor_dont_datallowconn_sql() -> None: + """_dont_datallowconn() executes ALTER DATABASE allow_connections false for the dbname.""" + cur = AsyncMock(spec=AsyncCursor) + await AsyncDatabaseJanitor._dont_datallowconn(cur, "target_db") + + cur.execute.assert_called_once() + sql_str = _render_sql(cur.execute.call_args.args[0]) + assert "allow_connections false" in sql_str + assert '"target_db"' in sql_str + + +def _make_async_conn_mock() -> MagicMock: + """Create a MagicMock that behaves like a psycopg3 AsyncConnection.""" + conn = MagicMock() + conn.close = AsyncMock() + conn.set_isolation_level = AsyncMock() + conn.set_autocommit = AsyncMock() + cursor_mock = MagicMock() + cursor_mock.__aenter__ = AsyncMock(return_value=MagicMock()) + cursor_mock.__aexit__ = AsyncMock(return_value=False) + conn.cursor = MagicMock(return_value=cursor_mock) + return conn diff --git a/tests/test_loader.py b/tests/test_loader.py index c03f8a55..5d69cdbb 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,8 +1,11 @@ """Tests for the `build_loader` function.""" from pathlib import Path +from unittest.mock import patch -from pytest_postgresql.loader import build_loader, sql +import pytest + +from pytest_postgresql.loader import build_loader, build_loader_async, sql, sql_async from tests.loader import load_database @@ -12,9 +15,49 @@ def test_loader_callables() -> None: assert load_database == build_loader("tests.loader:load_database") +def test_loader_callables_dot_separator() -> None: + """Test dot-separated import path resolves the same callable as colon-separated.""" + assert build_loader("tests.loader.load_database") == load_database + + +@pytest.mark.asyncio +async def test_loader_callables_async() -> None: + """Async test handling callables in build_loader_async.""" + assert load_database == build_loader_async(load_database) + assert load_database == build_loader_async("tests.loader:load_database") + + async def afun(*_args: object, **_kwargs: object) -> int: + return 0 + + assert afun == build_loader_async(afun) + + +@pytest.mark.asyncio +async def test_loader_callables_async_dot_separator() -> None: + """Dot-separated import path is resolved identically by build_loader_async.""" + assert build_loader_async("tests.loader.load_database") == load_database + + def test_loader_sql() -> None: """Test returning partial running sql for the sql file path.""" sql_path = Path("test_sql/eidastats.sql") loader_func = build_loader(sql_path) assert loader_func.args == (sql_path,) # type: ignore assert loader_func.func == sql # type: ignore + + +@pytest.mark.asyncio +async def test_loader_sql_async() -> None: + """Async test returning partial running sql_async for the sql file path.""" + sql_path = Path("test_sql/eidastats.sql") + loader_func = build_loader_async(sql_path) + assert loader_func.args == (sql_path,) # type: ignore + assert loader_func.func == sql_async # type: ignore + + +@pytest.mark.asyncio +async def test_sql_async_raises_without_aiofiles() -> None: + """sql_async raises ImportError with a helpful message when aiofiles is not installed.""" + with patch("pytest_postgresql.loader.aiofiles", None): + with pytest.raises(ImportError, match="aiofiles"): + await sql_async(Path("dummy.sql"), host="h", port=5432, user="u", dbname="d") diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 1b86beaf..1461694d 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,11 +1,14 @@ """All tests for pytest-postgresql.""" +import decimal + import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection from psycopg.pq import ConnStatus from pytest_postgresql.executor import PostgreSQLExecutor -from pytest_postgresql.retry import retry +from pytest_postgresql.retry import retry, retry_async +from tests.conftest import POSTGRESQL_VERSION MAKE_Q = "CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);" SELECT_Q = "SELECT * FROM test_load;" @@ -66,3 +69,60 @@ def check_if_one_connection() -> None: assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) + + +@pytest.mark.asyncio +async def test_main_postgres_async(postgresql_async: AsyncConnection) -> None: + """Async check main postgresql fixture.""" + async with postgresql_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql_async.commit() + + +@pytest.mark.asyncio +async def test_two_postgreses_async(postgresql_async: AsyncConnection, postgresql2_async: AsyncConnection) -> None: + """Async check two postgresql fixtures on one test.""" + async with postgresql_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql_async.commit() + + async with postgresql2_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql2_async.commit() + + +@pytest.mark.asyncio +async def test_postgres_load_two_files_async(postgresql_load_1_async: AsyncConnection) -> None: + """Async check postgresql fixture can load two files.""" + async with postgresql_load_1_async.cursor() as cur: + await cur.execute(SELECT_Q) + results = await cur.fetchall() + assert len(results) == 2 + + +@pytest.mark.asyncio +async def test_rand_postgres_port_async(postgresql2_async: AsyncConnection) -> None: + """Async check if postgres fixture can be started on random port.""" + assert postgresql2_async.info.status == ConnStatus.OK + + +@pytest.mark.skipif( + decimal.Decimal(POSTGRESQL_VERSION) < 10, + reason="Test query not supported in those postgresql versions, and soon will not be supported.", +) +@pytest.mark.xdist_group(name="terminate_connection") +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(2)) +async def test_postgres_terminate_connection_async(postgresql2_async: AsyncConnection, _: int) -> None: + """Async test that connections are terminated between tests. + + And check that only one exists at a time. + """ + async with postgresql2_async.cursor() as cur: + + async def check_if_one_connection() -> None: + await cur.execute("SELECT * FROM pg_stat_activity WHERE backend_type = 'client backend';") + existing_connections = await cur.fetchall() + assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" + + await retry_async(check_if_one_connection, timeout=120, possible_exception=AssertionError) diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 00000000..8581331d --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,75 @@ +"""Unit tests for retry and retry_async.""" + +import datetime +from unittest.mock import AsyncMock, patch + +import pytest + +from pytest_postgresql.retry import retry_async + + +@pytest.mark.asyncio +async def test_retry_async_immediate_success() -> None: + """Test that retry_async returns immediately when function succeeds on first call.""" + + async def ok() -> int: + return 42 + + assert await retry_async(ok, timeout=5) == 42 + + +@pytest.mark.asyncio +async def test_retry_async_succeeds_after_failures() -> None: + """Test that retry_async retries on the expected exception and returns on success.""" + attempts = 0 + + async def flaky() -> str: + nonlocal attempts + attempts += 1 + if attempts < 3: + raise ConnectionError("transient") + return "ok" + + sleep_mock = AsyncMock() + with patch("pytest_postgresql.retry.asyncio.sleep", sleep_mock): + result = await retry_async(flaky, timeout=10, possible_exception=ConnectionError) + + assert result == "ok" + assert attempts == 3 + assert sleep_mock.call_count == 2 + + +@pytest.mark.asyncio +async def test_retry_async_timeout() -> None: + """Test that retry_async raises TimeoutError after the timeout elapses.""" + + async def always_fail() -> None: + raise ValueError("boom") + + sleep_mock = AsyncMock() + base = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc) + call_count = 0 + + def advancing_clock() -> datetime.datetime: + nonlocal call_count + call_count += 1 + # First call captures starting time; all subsequent calls report past the timeout. + return base if call_count == 1 else base + datetime.timedelta(seconds=10) + + with ( + patch("pytest_postgresql.retry.asyncio.sleep", sleep_mock), + patch("pytest_postgresql.retry.get_current_datetime", advancing_clock), + ): + with pytest.raises(TimeoutError, match="Failed after"): + await retry_async(always_fail, timeout=1, possible_exception=ValueError) + + +@pytest.mark.asyncio +async def test_retry_async_unmatched_exception_propagates() -> None: + """Test that an exception not matching possible_exception propagates immediately.""" + + async def wrong_exc() -> None: + raise TypeError("unexpected") + + with pytest.raises(TypeError, match="unexpected"): + await retry_async(wrong_exc, timeout=5, possible_exception=ValueError) diff --git a/tests/test_template_database.py b/tests/test_template_database.py index 64631779..fc64442e 100644 --- a/tests/test_template_database.py +++ b/tests/test_template_database.py @@ -1,9 +1,9 @@ """Template database tests.""" import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection -from pytest_postgresql.factories import postgresql, postgresql_proc +from pytest_postgresql.factories import postgresql, postgresql_async, postgresql_proc from tests.loader import load_database postgresql_proc_with_template = postgresql_proc( @@ -17,6 +17,11 @@ dbname="stories_templated", ) +async_postgresql_template = postgresql_async( + "postgresql_proc_with_template", + dbname="stories_templated", +) + @pytest.mark.xdist_group(name="template_database") @pytest.mark.parametrize("_", range(5)) @@ -30,3 +35,18 @@ def test_template_database(postgresql_template: Connection, _: int) -> None: cur.execute("SELECT * FROM stories") res = cur.fetchall() assert len(res) == 0 + + +@pytest.mark.xdist_group(name="template_database_async") +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(5)) +async def test_template_database_async(async_postgresql_template: AsyncConnection, _: int) -> None: + """Async check that the database structure gets recreated out of a template.""" + async with async_postgresql_template.cursor() as cur: + await cur.execute("SELECT * FROM stories") + res = await cur.fetchall() + assert len(res) == 4 + await cur.execute("TRUNCATE stories") + await cur.execute("SELECT * FROM stories") + res = await cur.fetchall() + assert len(res) == 0