From ca719d1d0ababbbbf300a615447bb6835d8ab235 Mon Sep 17 00:00:00 2001 From: Boris Smidt Date: Thu, 6 Jul 2023 11:28:17 +0200 Subject: [PATCH] #88: Add asyncio support --- tests/test_annotated.py | 14 ++++++++++++++ typer/main.py | 32 ++++++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 6436ad668e..b4a3780386 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -57,3 +57,17 @@ def cmd(force: Annotated[bool, typer.Option("--force")] = False): result = runner.invoke(app, ["--force"]) assert result.exit_code == 0, result.output assert "Forcing operation" in result.output + +def test_runner_can_use_an_async_method(): + app = typer.Typer() + @app.command() + async def cmd(val: Annotated[int, typer.Argument()] = 0): + print(f"hello {val}") + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "hello 0" in result.output + + result = runner.invoke(app, ["42"]) + assert result.exit_code == 0, result.output + assert "hello 42" in result.output \ No newline at end of file diff --git a/typer/main.py b/typer/main.py index aa39e82849..07174461da 100644 --- a/typer/main.py +++ b/typer/main.py @@ -1,14 +1,15 @@ +import asyncio import inspect import os import sys import traceback from datetime import datetime from enum import Enum -from functools import update_wrapper +from functools import update_wrapper, wraps from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union, ParamSpec from uuid import UUID import click @@ -46,6 +47,11 @@ except ImportError: # pragma: nocover rich = None # type: ignore +try: + from asyncio import Loop +except ImportError: # pragma: nocover + Loop = Any # type: ignore + _original_except_hook = sys.excepthook _typer_developer_exception_attr_name = "__typer_developer_exception__" @@ -139,6 +145,7 @@ def __init__( pretty_exceptions_enable: bool = True, pretty_exceptions_show_locals: bool = True, pretty_exceptions_short: bool = True, + loop_factory: Optional[Callable[[], Loop]] = None, ): self._add_completion = add_completion self.rich_markup_mode: MarkupMode = rich_markup_mode @@ -167,6 +174,7 @@ def __init__( self.registered_groups: List[TyperInfo] = [] self.registered_commands: List[CommandInfo] = [] self.registered_callback: Optional[TyperInfo] = None + self.loop_factory = loop_factory def callback( self, @@ -235,12 +243,28 @@ def command( cls = TyperCommand def decorator(f: CommandFunctionType) -> CommandFunctionType: + def add_runner(f: CommandFunctionType) -> CommandFunctionType: + @wraps(f) + def runner(*args, **kwargs) -> Any: + if sys.version_info >= (3, 11) and self.loop_factory: + with asyncio.Runner(loop_factory=self.loop_factory) as runner: + return runner.run(f(*args, **kwargs)) + else: + return asyncio.run(f(*args, **kwargs)) + + return runner + + if inspect.iscoroutinefunction(f): + callback = add_runner(f) + else: + callback = f + self.registered_commands.append( CommandInfo( name=name, cls=cls, context_settings=context_settings, - callback=f, + callback=callback, help=help, epilog=epilog, short_help=short_help, @@ -253,7 +277,7 @@ def decorator(f: CommandFunctionType) -> CommandFunctionType: rich_help_panel=rich_help_panel, ) ) - return f + return decorator