diff --git a/CHANGES.rst b/CHANGES.rst index 77f4b4b4b..ed82a6351 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,11 @@ Version 8.4.0 Unreleased +- Command callbacks may be ``async def`` functions. Click runs the returned + awaitable to completion using :func:`asyncio.run`, including when the + coroutine function is only reached through wrappers such as + :func:`pass_context`. If an event loop is already running, Click raises + :exc:`RuntimeError` instead of scheduling the coroutine. :issue:`2033` - :class:`ParamType` typing improvements. :pr:`3371` - :class:`ParamType` is now a generic abstract base class, diff --git a/docs/quickstart.md b/docs/quickstart.md index 523a8072d..da5b21bcc 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -71,6 +71,31 @@ And the corresponding help page: invoke(hello, args=['--help'], prog_name='python hello.py') ``` +## Async command callbacks + +```{versionadded} 8.4 +``` + +Command bodies may be defined with `async def`. Click detects the awaitable +returned from the callback (including when you combine `async def` with +decorators such as {func}`pass_context`) and runs it to completion with +{func}`asyncio.run` before returning to the caller. + +```{eval-rst} +.. click:example:: + import asyncio + import click + + @click.command() + async def hello(): + await asyncio.sleep(0) + click.echo("Hello World!") +``` + +If an asyncio event loop is already running (for example when embedding a +Click CLI inside `asyncio.run` or another async framework), Click cannot use +`asyncio.run` again and raises {exc}`RuntimeError` instead. + ## Echoing Why does this example use {func}`echo` instead of the regular {func}`print` function? The answer to this question is diff --git a/src/click/core.py b/src/click/core.py index c5cab15c0..be1b7e602 100644 --- a/src/click/core.py +++ b/src/click/core.py @@ -54,6 +54,36 @@ V = t.TypeVar("V") +async def _await_any(awaitable: t.Awaitable[t.Any]) -> t.Any: + return await awaitable + + +def _invoke_command_callback( + callback: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any +) -> t.Any: + """Run *callback* and drive it with :func:`asyncio.run` if it returns an + awaitable (for example when the user defined an ``async def`` command + body, including when that coroutine function is wrapped by decorators such + as :func:`pass_context`). See :issue:`2033`. + """ + import asyncio + + result = callback(*args, **kwargs) + if inspect.isawaitable(result): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(_await_any(t.cast("t.Awaitable[t.Any]", result))) + if inspect.iscoroutine(result): + result.close() + raise RuntimeError( + "Click cannot run this command because an asyncio event loop is" + " already running. Use a synchronous callback, or invoke the CLI" + " from a context without an active event loop." + ) from None + return result + + def _complete_visible_commands( ctx: Context, incomplete: str ) -> cabc.Iterator[tuple[str, Command]]: @@ -816,6 +846,12 @@ def invoke( .. versionchanged:: 3.2 A new context is created, and missing arguments use default values. + + .. versionchanged:: 8.4 + If the callable returns an awaitable (for example an ``async def`` + callback, including when only the wrapped function is async, as with + :func:`pass_context`), Click runs it to completion using + :func:`asyncio.run`. """ if isinstance(callback, Command): other_cmd = callback @@ -851,7 +887,7 @@ def invoke( with augment_usage_errors(self): with ctx: - return callback(*args, **kwargs) + return _invoke_command_callback(callback, *args, **kwargs) def forward(self, cmd: Command, /, *args: t.Any, **kwargs: t.Any) -> t.Any: """Similar to :meth:`invoke` but fills in default keyword diff --git a/tests/test_async_commands.py b/tests/test_async_commands.py new file mode 100644 index 000000000..03d81f710 --- /dev/null +++ b/tests/test_async_commands.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import asyncio + +import pytest + +import click +from click.testing import CliRunner + + +def test_async_command_callback(runner: CliRunner) -> None: + @click.command() + async def cli() -> None: + await asyncio.sleep(0) + click.echo("done") + + result = runner.invoke(cli, []) + assert result.exit_code == 0 + assert result.output == "done\n" + + +def test_async_command_with_pass_context(runner: CliRunner) -> None: + @click.command() + @click.pass_context + async def cli(ctx: click.Context) -> None: + assert ctx.info_name == "cli" + click.echo("ok") + + result = runner.invoke(cli, []) + assert result.exit_code == 0 + assert result.output == "ok\n" + + +def test_async_subcommand(runner: CliRunner) -> None: + @click.group() + def grp() -> None: + pass + + @grp.command() + async def sub() -> None: + click.echo("sub") + + result = runner.invoke(grp, ["sub"]) + assert result.exit_code == 0 + assert result.output == "sub\n" + + +def test_async_command_return_value(runner: CliRunner) -> None: + @click.command() + async def cli() -> int: + await asyncio.sleep(0) + return 42 + + result = runner.invoke(cli, [], standalone_mode=False) + assert result.exit_code == 0 + assert result.return_value == 42 + + +def test_context_invoke_async_callback(runner: CliRunner) -> None: + @click.command() + @click.pass_context + def cli(ctx: click.Context) -> int: + async def helper() -> int: + return 99 + + rv = ctx.invoke(helper) + assert isinstance(rv, int) + return rv + + result = runner.invoke(cli, [], standalone_mode=False) + assert result.exit_code == 0 + assert result.return_value == 99 + + +def test_async_group_callback(runner: CliRunner) -> None: + @click.group() + @click.pass_context + async def grp(ctx: click.Context) -> None: + click.echo("grp") + + @grp.command() + def sub() -> None: + click.echo("sub") + + result = runner.invoke(grp, ["sub"]) + assert result.exit_code == 0 + assert result.output.splitlines() == ["grp", "sub"] + + +def test_async_command_rejects_when_loop_already_running( + runner: CliRunner, +) -> None: + @click.command() + async def cli() -> None: + pass + + async def invoke_inside_loop() -> None: + runner.invoke(cli, catch_exceptions=False) + + with pytest.raises(RuntimeError, match="asyncio event loop"): + asyncio.run(invoke_inside_loop())