diff --git a/tests/test_completion/test_completion_install.py b/tests/test_completion/test_completion_install.py index 0856b9b565..830f5901c9 100644 --- a/tests/test_completion/test_completion_install.py +++ b/tests/test_completion/test_completion_install.py @@ -5,14 +5,12 @@ from unittest import mock import shellingham -import typer -from typer.testing import CliRunner from docs_src.commands.index import tutorial001 as mod +from typer.testing import CliRunner runner = CliRunner() -app = typer.Typer() -app.command()(mod.main) +app = mod.app def test_completion_install_no_shell(): @@ -144,12 +142,6 @@ def test_completion_install_fish(): assert "completion installed in" in result.stdout assert "Completion will take effect once you restart the terminal" in result.stdout - -runner = CliRunner() -app = typer.Typer() -app.command()(mod.main) - - def test_completion_install_powershell(): completion_path: Path = ( Path.home() / f".config/powershell/Microsoft.PowerShell_profile.ps1" diff --git a/typer/main.py b/typer/main.py index 2e644ccc4e..a685149e1f 100644 --- a/typer/main.py +++ b/typer/main.py @@ -47,10 +47,7 @@ except ImportError: # pragma: nocover rich = None # type: ignore -try: - from asyncio import Loop -except ImportError: # pragma: nocover - Loop = Any # type: ignore +from asyncio import AbstractEventLoop _original_except_hook = sys.excepthook _typer_developer_exception_attr_name = "__typer_developer_exception__" @@ -145,7 +142,7 @@ def __init__( pretty_exceptions_enable: bool = True, pretty_exceptions_show_locals: bool = True, pretty_exceptions_short: bool = True, - loop_factory: Optional[Callable[[], Loop]] = None, + loop_factory: Optional[Callable[[], AbstractEventLoop]] = None, ): self._add_completion = add_completion self.rich_markup_mode: MarkupMode = rich_markup_mode @@ -245,7 +242,7 @@ def command( def decorator(f: CommandFunctionType) -> CommandFunctionType: def add_runner(f: CommandFunctionType) -> CommandFunctionType: @wraps(f) - def runner(*args, **kwargs) -> Any: + def run_wrapper(*args, **kwargs) -> Any: if sys.version_info >= (3, 11) and self.loop_factory: with asyncio.Runner(loop_factory=self.loop_factory) as runner: return runner.run(f(*args, **kwargs)) @@ -254,7 +251,7 @@ def runner(*args, **kwargs) -> Any: else: asyncio.get_event_loop().run_until_complete(asyncio.wait(f(*args, **kwargs))) - return runner + return run_wrapper if inspect.iscoroutinefunction(f): callback = add_runner(f) @@ -279,6 +276,7 @@ def runner(*args, **kwargs) -> Any: rich_help_panel=rich_help_panel, ) ) + return f return decorator