Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 48 additions & 36 deletions winloop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio as __asyncio
import collections.abc as _collections_abc
import typing as _typing
import sys as _sys
import warnings as _warnings
Expand All @@ -9,7 +8,7 @@
from ._version import __version__ # NOQA


__all__: tuple[str, ...] = ("new_event_loop", "run")
__all__: _typing.Tuple[str, ...] = ('new_event_loop', 'run')
_AbstractEventLoop = __asyncio.AbstractEventLoop


Expand All @@ -26,16 +25,16 @@ def new_event_loop() -> Loop:


if _typing.TYPE_CHECKING:

def run(
main: _collections_abc.Coroutine[_typing.Any, _typing.Any, _T],
main: _typing.Coroutine[_typing.Any, _typing.Any, _T],
*,
loop_factory: _collections_abc.Callable[[], Loop] | None = new_event_loop,
debug: bool | None = None,
loop_factory: _typing.Optional[
_typing.Callable[[], Loop]
] = new_event_loop,
debug: _typing.Optional[bool]=None,
) -> _T:
"""The preferred way of running a coroutine with winloop."""
else:

def run(main, *, loop_factory=new_event_loop, debug=None, **run_kwargs):
"""The preferred way of running a coroutine with winloop."""

Expand All @@ -45,7 +44,7 @@ async def wrapper():
# is using `winloop.run()` intentionally.
loop = __asyncio._get_running_loop()
if not isinstance(loop, Loop):
raise TypeError("winloop.run() uses a non-winloop event loop")
raise TypeError('winloop.run() uses a non-winloop event loop')
return await main

vi = _sys.version_info[:2]
Expand All @@ -55,11 +54,12 @@ async def wrapper():

if __asyncio._get_running_loop() is not None:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)
"asyncio.run() cannot be called from a running event loop")

if not __asyncio.iscoroutine(main):
raise ValueError("a coroutine was expected, got {!r}".format(main))
raise ValueError(
"a coroutine was expected, got {!r}".format(main)
)

loop = loop_factory()
try:
Expand All @@ -71,27 +71,33 @@ async def wrapper():
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, "shutdown_default_executor"):
loop.run_until_complete(loop.shutdown_default_executor())
if hasattr(loop, 'shutdown_default_executor'):
loop.run_until_complete(
loop.shutdown_default_executor()
)
finally:
__asyncio.set_event_loop(None)
loop.close()

elif vi == (3, 11):
if __asyncio._get_running_loop() is not None:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)
"asyncio.run() cannot be called from a running event loop")

with __asyncio.Runner(
loop_factory=loop_factory, debug=debug, **run_kwargs
loop_factory=loop_factory,
debug=debug,
**run_kwargs
) as runner:
return runner.run(wrapper())

else:
assert vi >= (3, 12)
return __asyncio.run(
wrapper(), loop_factory=loop_factory, debug=debug, **run_kwargs
wrapper(),
loop_factory=loop_factory,
debug=debug,
**run_kwargs
)


Expand All @@ -105,22 +111,22 @@ def _cancel_all_tasks(loop: _AbstractEventLoop) -> None:
for task in to_cancel:
task.cancel()

loop.run_until_complete(__asyncio.gather(*to_cancel, return_exceptions=True))
loop.run_until_complete(
__asyncio.gather(*to_cancel, return_exceptions=True)
)

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})


_deprecated_names = ("install", "EventLoopPolicy")
_deprecated_names = ('install', 'EventLoopPolicy')


if _sys.version_info[:2] < (3, 16):
Expand All @@ -146,16 +152,16 @@ def install() -> None:
"""
if _sys.version_info[:2] >= (3, 12):
_warnings.warn(
"winloop.install() is deprecated in favor of winloop.run() "
"starting with Python 3.12.",
'winloop.install() is deprecated in favor of winloop.run() '
'starting with Python 3.12.',
DeprecationWarning,
stacklevel=1,
)
__asyncio.set_event_loop_policy(EventLoopPolicy())

class EventLoopPolicy(
# This is to avoid a mypy error about AbstractEventLoopPolicy
getattr(__asyncio, "AbstractEventLoopPolicy") # type: ignore[misc]
getattr(__asyncio, 'AbstractEventLoopPolicy') # type: ignore[misc]
):
"""Event loop policy for winloop.

Expand All @@ -177,12 +183,16 @@ def _loop_factory(self) -> Loop:
# marked as abstract in typeshed, we have to put them in so mypy
# thinks the base methods are overridden. This is the same approach
# taken for the Windows event loop policy classes in typeshed.
def get_child_watcher(self) -> _typing.NoReturn: ...
def get_child_watcher(self) -> _typing.NoReturn:
...

def set_child_watcher(self, watcher: _typing.Any) -> _typing.NoReturn: ...
def set_child_watcher(
self, watcher: _typing.Any
) -> _typing.NoReturn:
...

class _Local(threading.local):
_loop: _AbstractEventLoop | None = None
_loop: _typing.Optional[_AbstractEventLoop] = None

def __init__(self) -> None:
self._local = self._Local()
Expand All @@ -194,13 +204,15 @@ def get_event_loop(self) -> _AbstractEventLoop:
"""
if self._local._loop is None:
raise RuntimeError(
"There is no current event loop in thread %r."
'There is no current event loop in thread %r.'
% threading.current_thread().name
)

return self._local._loop

def set_event_loop(self, loop: _AbstractEventLoop | None) -> None:
def set_event_loop(
self, loop: _typing.Optional[_AbstractEventLoop]
) -> None:
"""Set the event loop."""
if loop is not None and not isinstance(loop, _AbstractEventLoop):
raise TypeError(
Expand All @@ -216,6 +228,6 @@ def new_event_loop(self) -> Loop:
"""
return self._loop_factory()

globals()["install"] = install
globals()["EventLoopPolicy"] = EventLoopPolicy
globals()['install'] = install
globals()['EventLoopPolicy'] = EventLoopPolicy
return globals()[name]
Loading
Loading