diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 6436ad668e..5c6d91bc5d 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -57,3 +57,19 @@ def cmd(force: Annotated[bool, typer.Option("--force")] = False): result = runner.invoke(app, ["--force"]) assert result.exit_code == 0, result.output assert "Forcing operation" in result.output + + +def test_runner_can_use_an_async_method(): + app = typer.Typer() + + @app.command() + async def cmd(val: Annotated[int, typer.Argument()] = 0): + print(f"hello {val}") + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "hello 0" in result.output + + result = runner.invoke(app, ["42"]) + assert result.exit_code == 0, result.output + assert "hello 42" in result.output diff --git a/tests/test_completion/test_completion_install.py b/tests/test_completion/test_completion_install.py index 0856b9b565..1a5b5f3833 100644 --- a/tests/test_completion/test_completion_install.py +++ b/tests/test_completion/test_completion_install.py @@ -5,14 +5,12 @@ from unittest import mock import shellingham -import typer from typer.testing import CliRunner from docs_src.commands.index import tutorial001 as mod runner = CliRunner() -app = typer.Typer() -app.command()(mod.main) +app = mod.app def test_completion_install_no_shell(): @@ -145,11 +143,6 @@ def test_completion_install_fish(): assert "Completion will take effect once you restart the terminal" in result.stdout -runner = CliRunner() -app = typer.Typer() -app.command()(mod.main) - - def test_completion_install_powershell(): completion_path: Path = ( Path.home() / f".config/powershell/Microsoft.PowerShell_profile.ps1" diff --git a/typer/main.py b/typer/main.py index aa39e82849..851ed8a6be 100644 --- a/typer/main.py +++ b/typer/main.py @@ -1,10 +1,11 @@ +import asyncio import inspect import os import sys import traceback from datetime import datetime from enum import Enum -from functools import update_wrapper +from functools import update_wrapper, wraps from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType @@ -46,6 +47,7 @@ except ImportError: # pragma: nocover rich = None # type: ignore + _original_except_hook = sys.excepthook _typer_developer_exception_attr_name = "__typer_developer_exception__" @@ -235,12 +237,27 @@ def command( cls = TyperCommand def decorator(f: CommandFunctionType) -> CommandFunctionType: + def add_runner(f: CommandFunctionType) -> CommandFunctionType: + @wraps(f) + def run_wrapper(*args: Any, **kwargs: Any) -> Any: + if sys.version_info >= (3, 7): + return asyncio.run(f(*args, **kwargs)) + else: + asyncio.get_event_loop().run_until_complete(f(*args, **kwargs)) + + return run_wrapper # type: ignore + + if inspect.iscoroutinefunction(f): + callback = add_runner(f) + else: + callback = f + self.registered_commands.append( CommandInfo( name=name, cls=cls, context_settings=context_settings, - callback=f, + callback=callback, help=help, epilog=epilog, short_help=short_help,