diff --git a/newsfragments/1235.feature.rst b/newsfragments/1235.feature.rst
new file mode 100644
index 00000000..6c047099
--- /dev/null
+++ b/newsfragments/1235.feature.rst
@@ -0,0 +1,3 @@
+Added async PostgreSQL fixture support via ``postgresql_async`` factory and ``AsyncDatabaseJanitor``.
+Added configurable fixture ``scope`` parameter to ``postgresql``, ``postgresql_async``, ``postgresql_proc``, and ``postgresql_noproc`` factories (defaults preserved: ``"function"`` for client fixtures, ``"session"`` for process fixtures).
+Added optional ``async`` extra (``pip install pytest-postgresql[async]``) providing ``pytest-asyncio`` and ``aiofiles`` dependencies.
diff --git a/pyproject.toml b/pyproject.toml
index a9cb59a1..06e162b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,12 @@ dependencies = [
]
requires-python = ">= 3.10"
+[project.optional-dependencies]
+async = [
+ "pytest-asyncio >= 0.21",
+ "aiofiles >= 23.0"
+]
+
[project.urls]
"Source" = "https://github.com/dbfixtures/pytest-postgresql"
"Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues"
diff --git a/pytest_postgresql/factories/__init__.py b/pytest_postgresql/factories/__init__.py
index d6bd2f64..002304cb 100644
--- a/pytest_postgresql/factories/__init__.py
+++ b/pytest_postgresql/factories/__init__.py
@@ -17,8 +17,8 @@
# along with pytest-postgresql. If not, see .
"""Fixture factories for postgresql fixtures."""
-from pytest_postgresql.factories.client import postgresql
+from pytest_postgresql.factories.client import postgresql, postgresql_async
from pytest_postgresql.factories.noprocess import postgresql_noproc
from pytest_postgresql.factories.process import PortType, postgresql_proc
-__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "PortType")
+__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "postgresql_async", "PortType")
diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py
index 5fb5a5be..26a701d1 100644
--- a/pytest_postgresql/factories/client.py
+++ b/pytest_postgresql/factories/client.py
@@ -17,23 +17,30 @@
# along with pytest-postgresql. If not, see .
"""Fixture factory for postgresql client."""
-from typing import Callable, Iterator
+from typing import AsyncIterator, Callable, Iterator
import psycopg
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
from pytest import FixtureRequest
+try:
+ import pytest_asyncio
+except ImportError:
+ pytest_asyncio = None # type: ignore[assignment]
+
from pytest_postgresql.config import get_config
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.executor_noop import NoopExecutor
-from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor
+from pytest_postgresql.types import FixtureScopeT
def postgresql(
process_fixture_name: str,
dbname: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
+ scope: FixtureScopeT = "function",
) -> Callable[[FixtureRequest], Iterator[Connection]]:
"""Return connection fixture factory for PostgreSQL.
@@ -41,12 +48,13 @@ def postgresql(
:param dbname: database name
:param isolation_level: optional postgresql isolation level
defaults to server's default
+ :param scope: fixture scope; by default "function" which is recommended.
:returns: function which makes a connection to postgresql
"""
- @pytest.fixture
+ @pytest.fixture(scope=scope)
def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
- """Fixture factory for PostgreSQL.
+ """Fixture connection factory for PostgreSQL.
:param request: fixture request object
:returns: postgresql client
@@ -85,3 +93,72 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
db_connection.close()
return postgresql_factory
+
+
+def postgresql_async(
+ process_fixture_name: str,
+ dbname: str | None = None,
+ isolation_level: "psycopg.IsolationLevel | None" = None,
+ scope: FixtureScopeT = "function",
+) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]:
+ """Return async connection fixture factory for PostgreSQL.
+
+ :param process_fixture_name: name of the process fixture
+ :param dbname: database name
+ :param isolation_level: optional postgresql isolation level
+ defaults to server's default
+ :param scope: fixture scope; by default "function" which is recommended.
+ :returns: function which makes an async connection to postgresql
+ """
+ if pytest_asyncio is None:
+
+ @pytest.fixture(scope=scope)
+ def postgresql_async_factory(request: FixtureRequest) -> None:
+ """Sync stub that raises ImportError when pytest-asyncio is absent."""
+ raise ImportError(
+ "pytest-asyncio is required for async fixtures. Install it with: pip install pytest-postgresql[async]"
+ )
+
+ return postgresql_async_factory # type: ignore[return-value]
+
+ @pytest_asyncio.fixture(scope=scope, loop_scope=scope)
+ async def postgresql_async_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]:
+ """Async connection fixture factory for PostgreSQL.
+
+ :param request: fixture request object
+ :returns: postgresql async client
+ """
+ proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name)
+ config = get_config(request)
+
+ pg_host = proc_fixture.host
+ pg_port = proc_fixture.port
+ pg_user = proc_fixture.user
+ pg_password = proc_fixture.password
+ pg_options = proc_fixture.options
+ pg_db = dbname or proc_fixture.dbname
+ janitor = AsyncDatabaseJanitor(
+ user=pg_user,
+ host=pg_host,
+ port=pg_port,
+ dbname=pg_db,
+ template_dbname=proc_fixture.template_dbname,
+ version=proc_fixture.version,
+ password=pg_password,
+ isolation_level=isolation_level,
+ )
+ if config.drop_test_database:
+ await janitor.drop()
+ async with janitor:
+ db_connection: AsyncConnection = await AsyncConnection.connect(
+ dbname=pg_db,
+ user=pg_user,
+ password=pg_password,
+ host=pg_host,
+ port=pg_port,
+ options=pg_options,
+ )
+ yield db_connection
+ await db_connection.close()
+
+ return postgresql_async_factory
diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py
index 2d7f8b49..8af27c37 100644
--- a/pytest_postgresql/factories/noprocess.py
+++ b/pytest_postgresql/factories/noprocess.py
@@ -27,6 +27,7 @@
from pytest_postgresql.config import get_config
from pytest_postgresql.executor_noop import NoopExecutor
from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.types import FixtureScopeT
def xdistify_dbname(dbname: str) -> str:
@@ -46,6 +47,7 @@ def postgresql_noproc(
options: str = "",
load: list[Callable | str | Path] | None = None,
depends_on: str | None = None,
+ scope: FixtureScopeT = "session",
) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]:
"""Postgresql noprocess factory.
@@ -57,10 +59,11 @@ def postgresql_noproc(
:param options: Postgresql connection options
:param load: List of functions used to initialize database's template.
:param depends_on: Optional name of the fixture to depend on.
+ :param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""
- @pytest.fixture(scope="session")
+ @pytest.fixture(scope=scope)
def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]:
"""Noop Process fixture for PostgreSQL.
diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py
index 27fab57f..cbfc0111 100644
--- a/pytest_postgresql/factories/process.py
+++ b/pytest_postgresql/factories/process.py
@@ -32,6 +32,7 @@
from pytest_postgresql.exceptions import ExecutableMissingException
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.types import FixtureScopeT
PortType = port_for.PortType # mypy requires explicit export
@@ -81,6 +82,7 @@ def postgresql_proc(
unixsocketdir: str | None = None,
postgres_options: str | None = None,
load: list[Callable | str | Path] | None = None,
+ scope: FixtureScopeT = "session",
) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]:
"""Postgresql process factory.
@@ -101,10 +103,11 @@ def postgresql_proc(
:param unixsocketdir: directory to create postgresql's unixsockets
:param postgres_options: Postgres executable options for use by pg_ctl
:param load: List of functions used to initialize database's template.
+ :param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""
- @pytest.fixture(scope="session")
+ @pytest.fixture(scope=scope)
def postgresql_proc_fixture(
request: FixtureRequest, tmp_path_factory: TempPathFactory
) -> Iterator[PostgreSQLExecutor]:
diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py
index f602372e..146f4dd0 100644
--- a/pytest_postgresql/janitor.py
+++ b/pytest_postgresql/janitor.py
@@ -1,21 +1,24 @@
"""Database Janitor."""
-from contextlib import contextmanager
+import inspect
+from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
from types import TracebackType
-from typing import Callable, Iterator, Type, TypeVar
+from typing import AsyncIterator, Callable, Iterator, Type, TypeVar
import psycopg
+import psycopg.sql as sql
from packaging.version import parse
-from psycopg import Connection, Cursor
+from psycopg import AsyncCursor, Connection, Cursor
-from pytest_postgresql.loader import build_loader
-from pytest_postgresql.retry import retry
+from pytest_postgresql.loader import build_loader, build_loader_async
+from pytest_postgresql.retry import retry, retry_async
Version = type(parse("1"))
DatabaseJanitorType = TypeVar("DatabaseJanitorType", bound="DatabaseJanitor")
+AsyncDatabaseJanitorType = TypeVar("AsyncDatabaseJanitorType", bound="AsyncDatabaseJanitor")
class DatabaseJanitor:
@@ -67,18 +70,17 @@ def __init__(
def init(self) -> None:
"""Create database in postgresql."""
with self.cursor() as cur:
+ query = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.dbname))
if self.template_dbname:
# And make sure no-one is left connected to the template database.
# Otherwise, Creating database from template will fail
self._terminate_connection(cur, self.template_dbname)
- query = f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}"'
- else:
- query = f'CREATE DATABASE "{self.dbname}"'
+ query = query + sql.SQL(" TEMPLATE {}").format(sql.Identifier(self.template_dbname))
if self.as_template:
- query += " IS_TEMPLATE = true"
+ query = query + sql.SQL(" IS_TEMPLATE = true")
- cur.execute(f"{query};")
+ cur.execute(query)
def is_template(self) -> bool:
"""Determine whether the DatabaseJanitor maintains template or database."""
@@ -92,17 +94,17 @@ def drop(self) -> None:
self._dont_datallowconn(cur, self.dbname)
self._terminate_connection(cur, self.dbname)
if self.as_template:
- cur.execute(f'ALTER DATABASE "{self.dbname}" with is_template false;')
- cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";')
+ cur.execute(sql.SQL("ALTER DATABASE {} WITH is_template false").format(sql.Identifier(self.dbname)))
+ cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(self.dbname)))
@staticmethod
def _dont_datallowconn(cur: Cursor, dbname: str) -> None:
- cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;')
+ cur.execute(sql.SQL("ALTER DATABASE {} WITH allow_connections false").format(sql.Identifier(dbname)))
@staticmethod
def _terminate_connection(cur: Cursor, dbname: str) -> None:
cur.execute(
- "SELECT pg_terminate_backend(pg_stat_activity.pid)"
+ "SELECT pg_terminate_backend(pg_stat_activity.pid) "
"FROM pg_stat_activity "
"WHERE pg_stat_activity.datname = %s;",
(dbname,),
@@ -164,3 +166,153 @@ def __exit__(
) -> None:
"""Exit from Database janitor context cleaning after itself."""
self.drop()
+
+
+class AsyncDatabaseJanitor:
+ """Manage database state asynchronously for specific tasks."""
+
+ def __init__(
+ self,
+ *,
+ user: str,
+ host: str,
+ port: str | int,
+ version: str | float | Version, # type: ignore[valid-type]
+ dbname: str,
+ template_dbname: str | None = None,
+ as_template: bool = False,
+ password: str | None = None,
+ isolation_level: "psycopg.IsolationLevel | None" = None,
+ connection_timeout: int = 60,
+ ) -> None:
+ """Initialize async janitor.
+
+ :param user: postgresql username
+ :param host: postgresql host
+ :param port: postgresql port
+ :param dbname: database name
+ :param template_dbname: template database name to clone from
+ :param as_template: whether to mark the database as a template
+ :param version: postgresql version number
+ :param password: optional postgresql password
+ :param isolation_level: optional postgresql isolation level
+ defaults to server's default
+ :param connection_timeout: how long to retry connection before
+ raising a TimeoutError
+ """
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+ self.dbname = dbname
+ self.template_dbname = template_dbname
+ self.as_template = as_template
+ self._connection_timeout = connection_timeout
+ self.isolation_level = isolation_level
+ if not isinstance(version, Version):
+ self.version = parse(str(version))
+ else:
+ self.version = version
+
+ async def init(self) -> None:
+ """Create database in postgresql."""
+ async with self.cursor() as cur:
+ query = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.dbname))
+ if self.template_dbname:
+ # And make sure no-one is left connected to the template database.
+ # Otherwise, Creating database from template will fail
+ await self._terminate_connection(cur, self.template_dbname)
+ query = query + sql.SQL(" TEMPLATE {}").format(sql.Identifier(self.template_dbname))
+
+ if self.as_template:
+ query = query + sql.SQL(" IS_TEMPLATE = true")
+
+ await cur.execute(query)
+
+ def is_template(self) -> bool:
+ """Determine whether the AsyncDatabaseJanitor maintains template or database."""
+ return self.as_template
+
+ async def drop(self) -> None:
+ """Drop database in postgresql."""
+ # We cannot drop the database while there are connections to it, so we
+ # terminate all connections first while not allowing new connections.
+ async with self.cursor() as cur:
+ await self._dont_datallowconn(cur, self.dbname)
+ await self._terminate_connection(cur, self.dbname)
+ if self.as_template:
+ await cur.execute(
+ sql.SQL("ALTER DATABASE {} WITH is_template false").format(sql.Identifier(self.dbname))
+ )
+ await cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(self.dbname)))
+
+ @staticmethod
+ async def _dont_datallowconn(cur: AsyncCursor, dbname: str) -> None: # type: ignore[type-arg]
+ await cur.execute(sql.SQL("ALTER DATABASE {} WITH allow_connections false").format(sql.Identifier(dbname)))
+
+ @staticmethod
+ async def _terminate_connection(cur: AsyncCursor, dbname: str) -> None: # type: ignore[type-arg]
+ await cur.execute(
+ "SELECT pg_terminate_backend(pg_stat_activity.pid) "
+ "FROM pg_stat_activity "
+ "WHERE pg_stat_activity.datname = %s;",
+ (dbname,),
+ )
+
+ async def load(self, load: Callable | str | Path) -> None:
+ """Load data into a database.
+
+ Expects:
+
+ * a Path to sql file, that'll be loaded
+ * an import path to import callable
+ * a callable that expects: host, port, user, dbname and password arguments.
+
+ """
+ _loader = build_loader_async(load)
+ result = _loader(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ dbname=self.dbname,
+ password=self.password,
+ )
+ if inspect.isawaitable(result):
+ await result
+
+ @asynccontextmanager
+ async def cursor(self, dbname: str = "postgres") -> AsyncIterator[AsyncCursor]: # type: ignore[type-arg]
+ """Return postgresql async cursor."""
+
+ async def connect() -> psycopg.AsyncConnection:
+ return await psycopg.AsyncConnection.connect(
+ dbname=dbname,
+ user=self.user,
+ password=self.password,
+ host=self.host,
+ port=self.port,
+ )
+
+ conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError)
+ try:
+ await conn.set_isolation_level(self.isolation_level)
+ await conn.set_autocommit(True)
+ # We must not run a transaction since we create a database.
+ async with conn.cursor() as cur:
+ yield cur
+ finally:
+ await conn.close()
+
+ async def __aenter__(self: AsyncDatabaseJanitorType) -> AsyncDatabaseJanitorType:
+ """Initialize Async Database Janitor."""
+ await self.init()
+ return self
+
+ async def __aexit__(
+ self: AsyncDatabaseJanitorType,
+ exc_type: Type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ """Exit from Async Database Janitor context cleaning after itself."""
+ await self.drop()
diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py
index c9b28cbd..63f025ba 100644
--- a/pytest_postgresql/loader.py
+++ b/pytest_postgresql/loader.py
@@ -1,5 +1,6 @@
"""Loader helper functions."""
+import importlib
import re
from functools import partial
from pathlib import Path
@@ -7,6 +8,11 @@
import psycopg
+try:
+ import aiofiles
+except ImportError:
+ aiofiles = None # type: ignore[assignment]
+
def build_loader(load: Callable | str | Path) -> Callable:
"""Build a loader callable."""
@@ -16,7 +22,7 @@ def build_loader(load: Callable | str | Path) -> Callable:
loader_parts = re.split("[.:]", load, maxsplit=2)
import_path = ".".join(loader_parts[:-1])
loader_name = loader_parts[-1]
- _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
+ _temp_import = importlib.import_module(import_path)
_loader: Callable = getattr(_temp_import, loader_name)
return _loader
else:
@@ -30,3 +36,35 @@ def sql(sql_filename: Path, **kwargs: Any) -> None:
with db_connection.cursor() as cur:
cur.execute(_fd.read())
db_connection.commit()
+
+
+def build_loader_async(load: Callable | str | Path) -> Callable:
+ """Build an async loader callable."""
+ if isinstance(load, Path):
+ return partial(sql_async, load)
+ elif isinstance(load, str):
+ loader_parts = re.split("[.:]", load, maxsplit=2)
+ import_path = ".".join(loader_parts[:-1])
+ loader_name = loader_parts[-1]
+ _temp_import = importlib.import_module(import_path)
+ _loader: Callable = getattr(_temp_import, loader_name)
+ return _loader
+ else:
+ return load
+
+
+async def sql_async(sql_filename: Path, **kwargs: Any) -> None:
+ """Async database loader for sql files.
+
+ Requires the optional ``async`` extra: ``pip install pytest-postgresql[async]``.
+ """
+ if aiofiles is None:
+ raise ImportError(
+ "aiofiles is required for async SQL loading. Install it with: pip install pytest-postgresql[async]"
+ )
+
+ async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection:
+ async with db_connection.cursor() as cur:
+ async with aiofiles.open(sql_filename, "r") as _fd:
+ await cur.execute(await _fd.read())
+ await db_connection.commit()
diff --git a/pytest_postgresql/plugin.py b/pytest_postgresql/plugin.py
index 612e408a..5fa7b58c 100644
--- a/pytest_postgresql/plugin.py
+++ b/pytest_postgresql/plugin.py
@@ -135,3 +135,4 @@ def pytest_addoption(parser: Parser) -> None:
postgresql_proc = factories.postgresql_proc()
postgresql_noproc = factories.postgresql_noproc()
postgresql = factories.postgresql("postgresql_proc")
+postgresql_async = factories.postgresql_async("postgresql_proc")
diff --git a/pytest_postgresql/retry.py b/pytest_postgresql/retry.py
index ea25fa2e..078db5bc 100644
--- a/pytest_postgresql/retry.py
+++ b/pytest_postgresql/retry.py
@@ -1,9 +1,10 @@
"""Small retry callable in case of specific error occurred."""
+import asyncio
import datetime
import sys
from time import sleep
-from typing import Callable, Type, TypeVar
+from typing import Awaitable, Callable, Type, TypeVar
T = TypeVar("T")
@@ -29,11 +30,41 @@ def retry(
i += 1
try:
res = func()
- return res
except possible_exception as e:
if time + timeout_diff < get_current_datetime():
raise TimeoutError(f"Failed after {i} attempts") from e
sleep(1)
+ else:
+ return res
+
+
+async def retry_async(
+ func: Callable[[], Awaitable[T]],
+ timeout: int = 60,
+ possible_exception: Type[Exception] = Exception,
+) -> T:
+ """Attempt to retry the async function for timeout time.
+
+ Most often used for connecting to postgresql database as,
+ especially on macos on github-actions, first few tries fails
+ with this message:
+
+ ... ::
+ FATAL: the database system is starting up
+ """
+ time: datetime.datetime = get_current_datetime()
+ timeout_diff: datetime.timedelta = datetime.timedelta(seconds=timeout)
+ i = 0
+ while True:
+ i += 1
+ try:
+ res = await func()
+ except possible_exception as e:
+ if time + timeout_diff < get_current_datetime():
+ raise TimeoutError(f"Failed after {i} attempts") from e
+ await asyncio.sleep(1)
+ else:
+ return res
def get_current_datetime() -> datetime.datetime:
diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py
new file mode 100644
index 00000000..e5f35043
--- /dev/null
+++ b/pytest_postgresql/types.py
@@ -0,0 +1,5 @@
+"""Pytest PostgreSQL types."""
+
+from typing import Literal
+
+FixtureScopeT = Literal["session", "package", "module", "class", "function"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 784b8905..483437af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,3 +17,5 @@
postgresql_proc2 = factories.postgresql_proc(port=None, load=[TEST_SQL_FILE, TEST_SQL_FILE2])
postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db")
postgresql_load_1 = factories.postgresql("postgresql_proc2")
+postgresql2_async = factories.postgresql_async("postgresql_proc2", dbname="test-db")
+postgresql_load_1_async = factories.postgresql_async("postgresql_proc2")
diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py
index ae25307a..1d0fbb73 100644
--- a/tests/docker/test_noproc_docker.py
+++ b/tests/docker/test_noproc_docker.py
@@ -3,7 +3,7 @@
import pathlib
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
import pytest_postgresql.factories.client
import pytest_postgresql.factories.noprocess
@@ -14,12 +14,17 @@
)
postgres_with_schema = pytest_postgresql.factories.client.postgresql("postgresql_my_proc")
+async_postgres_with_schema = pytest_postgresql.factories.client.postgresql_async("postgresql_my_proc")
+
postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc(
dbname="stories_templated", load=[load_database]
)
postgres_with_template = pytest_postgresql.factories.client.postgresql(
"postgresql_my_proc_template", dbname="stories_templated"
)
+async_postgres_with_template = pytest_postgresql.factories.client.postgresql_async(
+ "postgresql_my_proc_template", dbname="stories_templated"
+)
def test_postgres_docker_load(postgres_with_schema: Connection) -> None:
@@ -32,6 +37,14 @@ def test_postgres_docker_load(postgres_with_schema: Connection) -> None:
print(cur.fetchall())
+@pytest.mark.asyncio
+async def test_postgres_docker_load_async(async_postgres_with_schema: AsyncConnection) -> None:
+ """Async check main postgres fixture."""
+ async with async_postgres_with_schema.cursor() as cur:
+ await cur.execute("select * from public.tokens")
+ print(await cur.fetchall())
+
+
@pytest.mark.parametrize("_", range(5))
def test_template_database(postgres_with_template: Connection, _: int) -> None:
"""Check that the database structure gets recreated out of a template."""
@@ -43,3 +56,17 @@ def test_template_database(postgres_with_template: Connection, _: int) -> None:
cur.execute("SELECT * FROM stories")
res = cur.fetchall()
assert len(res) == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("_", range(5))
+async def test_template_database_async(async_postgres_with_template: AsyncConnection, _: int) -> None:
+ """Async check that the database structure gets recreated out of a template."""
+ async with async_postgres_with_template.cursor() as cur:
+ await cur.execute("SELECT * FROM stories")
+ rows = await cur.fetchall()
+ assert len(rows) == 4
+ await cur.execute("TRUNCATE stories")
+ await cur.execute("SELECT * FROM stories")
+ rows = await cur.fetchall()
+ assert len(rows) == 0
diff --git a/tests/test_factory_errors.py b/tests/test_factory_errors.py
new file mode 100644
index 00000000..18216693
--- /dev/null
+++ b/tests/test_factory_errors.py
@@ -0,0 +1,33 @@
+"""Tests for factory error paths (missing optional dependencies)."""
+
+from unittest.mock import patch
+
+import pytest
+
+from pytest_postgresql.factories.client import postgresql_async
+
+
+def test_postgresql_async_factory_creation_succeeds_without_pytest_asyncio() -> None:
+ """postgresql_async() must not raise at factory-creation time when pytest-asyncio is absent.
+
+ The plugin registers ``postgresql_async`` at load time (plugin.py), so raising here
+ would break all users — including those who only use synchronous fixtures.
+ """
+ with patch("pytest_postgresql.factories.client.pytest_asyncio", None):
+ fixture_func = postgresql_async("some_proc_fixture")
+ assert callable(fixture_func)
+
+
+def test_postgresql_async_raises_on_use_without_pytest_asyncio() -> None:
+ """When pytest-asyncio is absent, the registered stub is synchronous and raises ImportError.
+
+ A synchronous stub avoids the "coroutine was never awaited" warning that would
+ result from registering an async def with plain pytest.fixture.
+ """
+ with patch("pytest_postgresql.factories.client.pytest_asyncio", None):
+ fixture_func = postgresql_async("some_proc_fixture")
+ # pytest 8+ wraps fixtures to prevent direct calls; unwrap first.
+ raw_func = getattr(fixture_func, "__wrapped__", fixture_func)
+ assert not hasattr(raw_func, "__await__"), "stub must be a sync function, not a coroutine"
+ with pytest.raises(ImportError, match="pytest-asyncio"):
+ raw_func(None) # type: ignore[arg-type]
diff --git a/tests/test_janitor.py b/tests/test_janitor.py
index fd1fca2a..9390d9e7 100644
--- a/tests/test_janitor.py
+++ b/tests/test_janitor.py
@@ -1,13 +1,16 @@
"""Database Janitor tests."""
import sys
-from typing import Any
-from unittest.mock import MagicMock, patch
+from contextlib import asynccontextmanager
+from typing import Any, AsyncIterator
+from unittest.mock import AsyncMock, MagicMock, patch
+import psycopg.sql as pgsql
import pytest
from packaging.version import parse
+from psycopg import AsyncCursor
-from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor
VERSION = parse("10")
@@ -19,6 +22,14 @@ def test_version_cast(version: Any) -> None:
assert janitor.version == VERSION
+@pytest.mark.parametrize("version", (VERSION, 10, "10"))
+@pytest.mark.asyncio
+async def test_version_cast_async(version: Any) -> None:
+ """Async test that version is cast to Version object."""
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version)
+ assert janitor.version == VERSION
+
+
@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
@@ -27,6 +38,19 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234")
+@pytest.mark.asyncio
+async def test_cursor_selects_postgres_database_async() -> None:
+ """Async test that the cursor requests the postgres database."""
+ conn_mock = _make_async_conn_mock()
+ connect_mock = AsyncMock(return_value=conn_mock)
+ with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock):
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10)
+ async with janitor.cursor():
+ connect_mock.assert_called_once_with(
+ dbname="postgres", user="user", password=None, host="host", port="1234"
+ )
+
+
@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
@@ -36,7 +60,7 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
port="1234",
dbname="database_name",
version=10,
- password="some_password",
+ password="some_password", # noqa: S106
)
with janitor.cursor():
connect_mock.assert_called_once_with(
@@ -44,6 +68,39 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
)
+@pytest.mark.asyncio
+async def test_cursor_connects_with_password_async() -> None:
+ """Async test that the cursor requests the postgres database with password."""
+ conn_mock = _make_async_conn_mock()
+ connect_mock = AsyncMock(return_value=conn_mock)
+ with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock):
+ janitor = AsyncDatabaseJanitor(
+ user="user",
+ host="host",
+ port="1234",
+ dbname="database_name",
+ version=10,
+ password="some_password", # noqa: S106
+ )
+ async with janitor.cursor():
+ connect_mock.assert_called_once_with(
+ dbname="postgres", user="user", password="some_password", host="host", port="1234"
+ )
+
+
+@pytest.mark.asyncio
+async def test_cursor_custom_dbname_async() -> None:
+ """Test that a custom dbname is forwarded to the connection in AsyncDatabaseJanitor.cursor."""
+ conn_mock = _make_async_conn_mock()
+ connect_mock = AsyncMock(return_value=conn_mock)
+ with patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect", connect_mock):
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10)
+ async with janitor.cursor(dbname="custom_db"):
+ connect_mock.assert_called_once_with(
+ dbname="custom_db", user="user", password=None, host="host", port="1234"
+ )
+
+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8")
@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database"))
@patch("pytest_postgresql.janitor.psycopg.connect")
@@ -57,9 +114,193 @@ def test_janitor_populate(connect_mock: MagicMock, load_database: str) -> None:
"port": "1234",
"user": "user",
"dbname": "database_name",
- "password": "some_password",
+ "password": "some_password", # noqa: S106
}
janitor = DatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type]
janitor.load(load_database)
assert connect_mock.called
assert connect_mock.call_args.kwargs == call_kwargs
+
+
+@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8")
+@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database"))
+@patch("tests.loader.psycopg.connect")
+@pytest.mark.asyncio
+async def test_janitor_populate_async(connect_mock: MagicMock, load_database: str) -> None:
+ """Async test that the cursor requests the postgres database and populates.
+
+ load_database (synchronous) uses psycopg.connect, so we mock that.
+ """
+ call_kwargs = {
+ "host": "host",
+ "port": "1234",
+ "user": "user",
+ "dbname": "database_name",
+ "password": "some_password", # noqa: S106
+ }
+ janitor = AsyncDatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type]
+ await janitor.load(load_database)
+ assert connect_mock.called
+ assert connect_mock.call_args.kwargs == call_kwargs
+
+
+# ---------------------------------------------------------------------------
+# AsyncDatabaseJanitor -- init() / drop() / helper method tests
+# ---------------------------------------------------------------------------
+
+
+def _render_sql(obj: object) -> str:
+ """Render a psycopg.sql Composable to its SQL text form for test assertions."""
+ if isinstance(obj, pgsql.Composed):
+ return "".join(_render_sql(part) for part in obj)
+ if isinstance(obj, pgsql.SQL):
+ return obj._obj # type: ignore[attr-defined]
+ if isinstance(obj, pgsql.Identifier):
+ parts: tuple[str, ...] = obj._obj # type: ignore[attr-defined]
+ return ".".join('"' + s.replace('"', '""') + '"' for s in parts)
+ return str(obj)
+
+
+def _make_cursor_mock() -> MagicMock:
+ """Create a mock async cursor that records execute() calls."""
+ cur = AsyncMock(spec=AsyncCursor)
+ return cur
+
+
+def _make_cursor_context(cur: AsyncMock) -> Any:
+ """Return an async context manager that yields the given cursor mock."""
+
+ @asynccontextmanager
+ async def _ctx(dbname: str = "postgres") -> AsyncIterator[AsyncMock]:
+ yield cur
+
+ return _ctx
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_init_creates_database() -> None:
+ """init() executes CREATE DATABASE with the configured dbname."""
+ cur = _make_cursor_mock()
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10)
+ with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)):
+ await janitor.init()
+
+ executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list)
+ assert 'CREATE DATABASE "mydb"' in executed_sql
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_init_with_template() -> None:
+ """init() uses TEMPLATE clause when template_dbname is set."""
+ cur = _make_cursor_mock()
+ janitor = AsyncDatabaseJanitor(
+ user="user", host="host", port="1234", dbname="mydb", template_dbname="tmpl", version=10
+ )
+ with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)):
+ await janitor.init()
+
+ executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list)
+ assert 'CREATE DATABASE "mydb" TEMPLATE "tmpl"' in executed_sql
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_init_as_template() -> None:
+ """init() appends IS_TEMPLATE = true when as_template is True."""
+ cur = _make_cursor_mock()
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10)
+ with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)):
+ await janitor.init()
+
+ executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list)
+ assert "IS_TEMPLATE = true" in executed_sql
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_drop_drops_database() -> None:
+ """drop() executes DROP DATABASE IF EXISTS for the configured dbname."""
+ cur = _make_cursor_mock()
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10)
+ with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)):
+ await janitor.drop()
+
+ executed_sql = " ".join(_render_sql(c.args[0]) for c in cur.execute.call_args_list)
+ assert 'DROP DATABASE IF EXISTS "mydb"' in executed_sql
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_drop_as_template() -> None:
+ """drop() resets is_template before dropping when as_template is True."""
+ cur = _make_cursor_mock()
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10)
+ with patch.object(AsyncDatabaseJanitor, "cursor", _make_cursor_context(cur)):
+ await janitor.drop()
+
+ executed_sql = [_render_sql(c.args[0]) for c in cur.execute.call_args_list]
+ assert any("is_template false" in s for s in executed_sql)
+ assert any('DROP DATABASE IF EXISTS "mydb"' in s for s in executed_sql)
+ # is_template false must come before DROP
+ template_idx = next(i for i, s in enumerate(executed_sql) if "is_template false" in s)
+ drop_idx = next(i for i, s in enumerate(executed_sql) if "DROP DATABASE" in s)
+ assert template_idx < drop_idx
+
+
+def test_async_janitor_is_template_false() -> None:
+ """is_template() returns False when as_template is not set."""
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10)
+ assert janitor.is_template() is False
+
+
+def test_async_janitor_is_template_true() -> None:
+ """is_template() returns True when as_template=True."""
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", as_template=True, version=10)
+ assert janitor.is_template() is True
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_context_manager_calls_init_and_drop() -> None:
+ """__aenter__ calls init() and __aexit__ calls drop()."""
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="mydb", version=10)
+ init_mock = AsyncMock()
+ drop_mock = AsyncMock()
+ with patch.object(AsyncDatabaseJanitor, "init", init_mock), patch.object(AsyncDatabaseJanitor, "drop", drop_mock):
+ async with janitor:
+ init_mock.assert_called_once()
+ drop_mock.assert_not_called()
+ drop_mock.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_terminate_connection_sql() -> None:
+ """_terminate_connection() executes pg_terminate_backend query with correct dbname."""
+ cur = AsyncMock(spec=AsyncCursor)
+ await AsyncDatabaseJanitor._terminate_connection(cur, "target_db")
+
+ cur.execute.assert_called_once()
+ sql_str, params = cur.execute.call_args.args
+ assert "pg_terminate_backend" in sql_str
+ assert params == ("target_db",)
+
+
+@pytest.mark.asyncio
+async def test_async_janitor_dont_datallowconn_sql() -> None:
+ """_dont_datallowconn() executes ALTER DATABASE allow_connections false for the dbname."""
+ cur = AsyncMock(spec=AsyncCursor)
+ await AsyncDatabaseJanitor._dont_datallowconn(cur, "target_db")
+
+ cur.execute.assert_called_once()
+ sql_str = _render_sql(cur.execute.call_args.args[0])
+ assert "allow_connections false" in sql_str
+ assert '"target_db"' in sql_str
+
+
+def _make_async_conn_mock() -> MagicMock:
+ """Create a MagicMock that behaves like a psycopg3 AsyncConnection."""
+ conn = MagicMock()
+ conn.close = AsyncMock()
+ conn.set_isolation_level = AsyncMock()
+ conn.set_autocommit = AsyncMock()
+ cursor_mock = MagicMock()
+ cursor_mock.__aenter__ = AsyncMock(return_value=MagicMock())
+ cursor_mock.__aexit__ = AsyncMock(return_value=False)
+ conn.cursor = MagicMock(return_value=cursor_mock)
+ return conn
diff --git a/tests/test_loader.py b/tests/test_loader.py
index c03f8a55..5d69cdbb 100644
--- a/tests/test_loader.py
+++ b/tests/test_loader.py
@@ -1,8 +1,11 @@
"""Tests for the `build_loader` function."""
from pathlib import Path
+from unittest.mock import patch
-from pytest_postgresql.loader import build_loader, sql
+import pytest
+
+from pytest_postgresql.loader import build_loader, build_loader_async, sql, sql_async
from tests.loader import load_database
@@ -12,9 +15,49 @@ def test_loader_callables() -> None:
assert load_database == build_loader("tests.loader:load_database")
+def test_loader_callables_dot_separator() -> None:
+ """Test dot-separated import path resolves the same callable as colon-separated."""
+ assert build_loader("tests.loader.load_database") == load_database
+
+
+@pytest.mark.asyncio
+async def test_loader_callables_async() -> None:
+ """Async test handling callables in build_loader_async."""
+ assert load_database == build_loader_async(load_database)
+ assert load_database == build_loader_async("tests.loader:load_database")
+
+ async def afun(*_args: object, **_kwargs: object) -> int:
+ return 0
+
+ assert afun == build_loader_async(afun)
+
+
+@pytest.mark.asyncio
+async def test_loader_callables_async_dot_separator() -> None:
+ """Dot-separated import path is resolved identically by build_loader_async."""
+ assert build_loader_async("tests.loader.load_database") == load_database
+
+
def test_loader_sql() -> None:
"""Test returning partial running sql for the sql file path."""
sql_path = Path("test_sql/eidastats.sql")
loader_func = build_loader(sql_path)
assert loader_func.args == (sql_path,) # type: ignore
assert loader_func.func == sql # type: ignore
+
+
+@pytest.mark.asyncio
+async def test_loader_sql_async() -> None:
+ """Async test returning partial running sql_async for the sql file path."""
+ sql_path = Path("test_sql/eidastats.sql")
+ loader_func = build_loader_async(sql_path)
+ assert loader_func.args == (sql_path,) # type: ignore
+ assert loader_func.func == sql_async # type: ignore
+
+
+@pytest.mark.asyncio
+async def test_sql_async_raises_without_aiofiles() -> None:
+ """sql_async raises ImportError with a helpful message when aiofiles is not installed."""
+ with patch("pytest_postgresql.loader.aiofiles", None):
+ with pytest.raises(ImportError, match="aiofiles"):
+ await sql_async(Path("dummy.sql"), host="h", port=5432, user="u", dbname="d")
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 1b86beaf..1461694d 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,11 +1,14 @@
"""All tests for pytest-postgresql."""
+import decimal
+
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
from psycopg.pq import ConnStatus
from pytest_postgresql.executor import PostgreSQLExecutor
-from pytest_postgresql.retry import retry
+from pytest_postgresql.retry import retry, retry_async
+from tests.conftest import POSTGRESQL_VERSION
MAKE_Q = "CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);"
SELECT_Q = "SELECT * FROM test_load;"
@@ -66,3 +69,60 @@ def check_if_one_connection() -> None:
assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}"
retry(check_if_one_connection, timeout=120, possible_exception=AssertionError)
+
+
+@pytest.mark.asyncio
+async def test_main_postgres_async(postgresql_async: AsyncConnection) -> None:
+ """Async check main postgresql fixture."""
+ async with postgresql_async.cursor() as cur:
+ await cur.execute(MAKE_Q)
+ await postgresql_async.commit()
+
+
+@pytest.mark.asyncio
+async def test_two_postgreses_async(postgresql_async: AsyncConnection, postgresql2_async: AsyncConnection) -> None:
+ """Async check two postgresql fixtures on one test."""
+ async with postgresql_async.cursor() as cur:
+ await cur.execute(MAKE_Q)
+ await postgresql_async.commit()
+
+ async with postgresql2_async.cursor() as cur:
+ await cur.execute(MAKE_Q)
+ await postgresql2_async.commit()
+
+
+@pytest.mark.asyncio
+async def test_postgres_load_two_files_async(postgresql_load_1_async: AsyncConnection) -> None:
+ """Async check postgresql fixture can load two files."""
+ async with postgresql_load_1_async.cursor() as cur:
+ await cur.execute(SELECT_Q)
+ results = await cur.fetchall()
+ assert len(results) == 2
+
+
+@pytest.mark.asyncio
+async def test_rand_postgres_port_async(postgresql2_async: AsyncConnection) -> None:
+ """Async check if postgres fixture can be started on random port."""
+ assert postgresql2_async.info.status == ConnStatus.OK
+
+
+@pytest.mark.skipif(
+ decimal.Decimal(POSTGRESQL_VERSION) < 10,
+ reason="Test query not supported in those postgresql versions, and soon will not be supported.",
+)
+@pytest.mark.xdist_group(name="terminate_connection")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("_", range(2))
+async def test_postgres_terminate_connection_async(postgresql2_async: AsyncConnection, _: int) -> None:
+ """Async test that connections are terminated between tests.
+
+ And check that only one exists at a time.
+ """
+ async with postgresql2_async.cursor() as cur:
+
+ async def check_if_one_connection() -> None:
+ await cur.execute("SELECT * FROM pg_stat_activity WHERE backend_type = 'client backend';")
+ existing_connections = await cur.fetchall()
+ assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}"
+
+ await retry_async(check_if_one_connection, timeout=120, possible_exception=AssertionError)
diff --git a/tests/test_retry.py b/tests/test_retry.py
new file mode 100644
index 00000000..8581331d
--- /dev/null
+++ b/tests/test_retry.py
@@ -0,0 +1,75 @@
+"""Unit tests for retry and retry_async."""
+
+import datetime
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from pytest_postgresql.retry import retry_async
+
+
+@pytest.mark.asyncio
+async def test_retry_async_immediate_success() -> None:
+ """Test that retry_async returns immediately when function succeeds on first call."""
+
+ async def ok() -> int:
+ return 42
+
+ assert await retry_async(ok, timeout=5) == 42
+
+
+@pytest.mark.asyncio
+async def test_retry_async_succeeds_after_failures() -> None:
+ """Test that retry_async retries on the expected exception and returns on success."""
+ attempts = 0
+
+ async def flaky() -> str:
+ nonlocal attempts
+ attempts += 1
+ if attempts < 3:
+ raise ConnectionError("transient")
+ return "ok"
+
+ sleep_mock = AsyncMock()
+ with patch("pytest_postgresql.retry.asyncio.sleep", sleep_mock):
+ result = await retry_async(flaky, timeout=10, possible_exception=ConnectionError)
+
+ assert result == "ok"
+ assert attempts == 3
+ assert sleep_mock.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_retry_async_timeout() -> None:
+ """Test that retry_async raises TimeoutError after the timeout elapses."""
+
+ async def always_fail() -> None:
+ raise ValueError("boom")
+
+ sleep_mock = AsyncMock()
+ base = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc)
+ call_count = 0
+
+ def advancing_clock() -> datetime.datetime:
+ nonlocal call_count
+ call_count += 1
+ # First call captures starting time; all subsequent calls report past the timeout.
+ return base if call_count == 1 else base + datetime.timedelta(seconds=10)
+
+ with (
+ patch("pytest_postgresql.retry.asyncio.sleep", sleep_mock),
+ patch("pytest_postgresql.retry.get_current_datetime", advancing_clock),
+ ):
+ with pytest.raises(TimeoutError, match="Failed after"):
+ await retry_async(always_fail, timeout=1, possible_exception=ValueError)
+
+
+@pytest.mark.asyncio
+async def test_retry_async_unmatched_exception_propagates() -> None:
+ """Test that an exception not matching possible_exception propagates immediately."""
+
+ async def wrong_exc() -> None:
+ raise TypeError("unexpected")
+
+ with pytest.raises(TypeError, match="unexpected"):
+ await retry_async(wrong_exc, timeout=5, possible_exception=ValueError)
diff --git a/tests/test_template_database.py b/tests/test_template_database.py
index 64631779..fc64442e 100644
--- a/tests/test_template_database.py
+++ b/tests/test_template_database.py
@@ -1,9 +1,9 @@
"""Template database tests."""
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
-from pytest_postgresql.factories import postgresql, postgresql_proc
+from pytest_postgresql.factories import postgresql, postgresql_async, postgresql_proc
from tests.loader import load_database
postgresql_proc_with_template = postgresql_proc(
@@ -17,6 +17,11 @@
dbname="stories_templated",
)
+async_postgresql_template = postgresql_async(
+ "postgresql_proc_with_template",
+ dbname="stories_templated",
+)
+
@pytest.mark.xdist_group(name="template_database")
@pytest.mark.parametrize("_", range(5))
@@ -30,3 +35,18 @@ def test_template_database(postgresql_template: Connection, _: int) -> None:
cur.execute("SELECT * FROM stories")
res = cur.fetchall()
assert len(res) == 0
+
+
+@pytest.mark.xdist_group(name="template_database_async")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("_", range(5))
+async def test_template_database_async(async_postgresql_template: AsyncConnection, _: int) -> None:
+ """Async check that the database structure gets recreated out of a template."""
+ async with async_postgresql_template.cursor() as cur:
+ await cur.execute("SELECT * FROM stories")
+ res = await cur.fetchall()
+ assert len(res) == 4
+ await cur.execute("TRUNCATE stories")
+ await cur.execute("SELECT * FROM stories")
+ res = await cur.fetchall()
+ assert len(res) == 0