Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion changelog.d/1164.added.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
Added the ``pytest_asyncio_loop_factories`` hook to parametrize asyncio tests with custom event loop factories.

The hook now returns a mapping of factory names to loop factories, and ``pytest.mark.asyncio(loop_factories=[...])`` can be used to select a subset of configured factories per test.
The hook returns a mapping of factory names to loop factories, and ``pytest.mark.asyncio(loop_factories=[...])`` selects a subset of configured factories per test. When a single factory is configured, test names are unchanged on pytest 8.4+.

Synchronous ``@pytest_asyncio.fixture`` functions now see the correct event loop when custom loop factories are configured, even when test code disrupts the current event loop (e.g., via ``asyncio.run()`` or ``asyncio.set_event_loop(None)``).
93 changes: 88 additions & 5 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,50 @@ def _fixture_synchronizer(
return _wrap_asyncgen_fixture(fixture_function, runner, request) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(fixturedef.func):
return _wrap_async_fixture(fixture_function, runner, request) # type: ignore[arg-type]
elif inspect.isgeneratorfunction(fixturedef.func):
return _wrap_syncgen_fixture(fixture_function, runner) # type: ignore[arg-type]
else:
return fixturedef.func
return _wrap_sync_fixture(fixture_function, runner) # type: ignore[arg-type]


SyncGenFixtureParams = ParamSpec("SyncGenFixtureParams")
SyncGenFixtureYieldType = TypeVar("SyncGenFixtureYieldType")


def _wrap_syncgen_fixture(
fixture_function: Callable[
SyncGenFixtureParams, Generator[SyncGenFixtureYieldType]
],
runner: Runner,
) -> Callable[SyncGenFixtureParams, Generator[SyncGenFixtureYieldType]]:
@functools.wraps(fixture_function)
def _syncgen_fixture_wrapper(
*args: SyncGenFixtureParams.args,
**kwargs: SyncGenFixtureParams.kwargs,
) -> Generator[SyncGenFixtureYieldType]:
with _temporary_event_loop(runner.get_loop()):
yield from fixture_function(*args, **kwargs)

return _syncgen_fixture_wrapper


SyncFixtureParams = ParamSpec("SyncFixtureParams")
SyncFixtureReturnType = TypeVar("SyncFixtureReturnType")


def _wrap_sync_fixture(
fixture_function: Callable[SyncFixtureParams, SyncFixtureReturnType],
runner: Runner,
) -> Callable[SyncFixtureParams, SyncFixtureReturnType]:
@functools.wraps(fixture_function)
def _sync_fixture_wrapper(
*args: SyncFixtureParams.args,
**kwargs: SyncFixtureParams.kwargs,
) -> SyncFixtureReturnType:
with _temporary_event_loop(runner.get_loop()):
return fixture_function(*args, **kwargs)

return _sync_fixture_wrapper


AsyncGenFixtureParams = ParamSpec("AsyncGenFixtureParams")
Expand Down Expand Up @@ -500,6 +542,12 @@ def setup(self) -> None:
runner_fixture_id = f"_{self._loop_scope}_scoped_runner"
if runner_fixture_id not in self.fixturenames:
self.fixturenames.append(runner_fixture_id)
# When loop factories are configured, resolve the loop factory
# fixture early so that a factory variant change cascades cache
# invalidation before any async fixture checks its cache.
hook_caller = self.config.hook.pytest_asyncio_loop_factories
if hook_caller.get_hookimpls():
_ = self._request.getfixturevalue(_asyncio_loop_factory.__name__)
return super().setup()

def runtest(self) -> None:
Expand Down Expand Up @@ -712,22 +760,47 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
metafunc.fixturenames.append(_asyncio_loop_factory.__name__)
default_loop_scope = _get_default_test_loop_scope(metafunc.config)
loop_scope = marker_loop_scope or default_loop_scope
# pytest.HIDDEN_PARAM was added in pytest 8.4
hide_id = len(effective_factories) == 1 and hasattr(pytest, "HIDDEN_PARAM")
metafunc.parametrize(
_asyncio_loop_factory.__name__,
effective_factories.values(),
ids=effective_factories.keys(),
ids=(pytest.HIDDEN_PARAM,) if hide_id else effective_factories.keys(),
indirect=True,
scope=loop_scope,
)


@contextlib.contextmanager
def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[None]:
old_loop_policy = _get_event_loop_policy()
def _temporary_event_loop(loop: AbstractEventLoop) -> Iterator[None]:
try:
old_loop = _get_event_loop_no_warn()
except RuntimeError:
old_loop = None
if old_loop is loop:
yield
return
_set_event_loop(loop)
try:
yield
finally:
_set_event_loop(old_loop)


@contextlib.contextmanager
def _temporary_event_loop_policy(
policy: AbstractEventLoopPolicy,
*,
has_custom_factory: bool,
) -> Iterator[None]:
old_loop_policy = _get_event_loop_policy()
if has_custom_factory:
old_loop = None
else:
try:
old_loop = _get_event_loop_no_warn()
except RuntimeError:
old_loop = None
_set_event_loop_policy(policy)
try:
yield
Expand Down Expand Up @@ -846,6 +919,11 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
)
runner_fixture_id = f"_{loop_scope}_scoped_runner"
runner = request.getfixturevalue(runner_fixture_id)
# Prevent the runner closing before the fixture's async teardown.
runner_fixturedef = request._get_active_fixturedef(runner_fixture_id)
runner_fixturedef.addfinalizer(
functools.partial(fixturedef.finish, request=request)
)
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
_make_asyncio_fixture_function(synchronizer, loop_scope)
with MonkeyPatch.context() as c:
Expand Down Expand Up @@ -935,11 +1013,16 @@ def _scoped_runner(
) -> Iterator[Runner]:
new_loop_policy = event_loop_policy
debug_mode = _get_asyncio_debug(request.config)
with _temporary_event_loop_policy(new_loop_policy):
with _temporary_event_loop_policy(
new_loop_policy,
has_custom_factory=_asyncio_loop_factory is not None,
):
runner = Runner(
debug=debug_mode,
loop_factory=_asyncio_loop_factory,
).__enter__()
if _asyncio_loop_factory is not None:
_set_event_loop(runner.get_loop())
try:
yield runner
except Exception as e:
Expand Down
Loading
Loading