From e5a3cb5c298d0215935af3af2abece7389089d24 Mon Sep 17 00:00:00 2001 From: Boris Smidt Date: Thu, 6 Jul 2023 11:28:17 +0200 Subject: [PATCH] #88: Add asyncio support --- typer/main.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/typer/main.py b/typer/main.py index aa39e82849..894923513d 100644 --- a/typer/main.py +++ b/typer/main.py @@ -1,14 +1,15 @@ +import asyncio import inspect import os import sys import traceback from datetime import datetime from enum import Enum -from functools import update_wrapper +from functools import update_wrapper, wraps from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union, ParamSpec from uuid import UUID import click @@ -46,6 +47,11 @@ except ImportError: # pragma: nocover rich = None # type: ignore +if sys.version_info >= (3, 11): + from asyncio import Loop +else: + Loop = Any # type: ignore + _original_except_hook = sys.excepthook _typer_developer_exception_attr_name = "__typer_developer_exception__" @@ -139,6 +145,7 @@ def __init__( pretty_exceptions_enable: bool = True, pretty_exceptions_show_locals: bool = True, pretty_exceptions_short: bool = True, + loop_factory: Optional[Callable[[], Loop]] = None, ): self._add_completion = add_completion self.rich_markup_mode: MarkupMode = rich_markup_mode @@ -167,6 +174,7 @@ def __init__( self.registered_groups: List[TyperInfo] = [] self.registered_commands: List[CommandInfo] = [] self.registered_callback: Optional[TyperInfo] = None + self.loop_factory = loop_factory def callback( self, @@ -235,12 +243,30 @@ def command( cls = TyperCommand def decorator(f: CommandFunctionType) -> CommandFunctionType: + if inspect.iscoroutinefunction(f): + return f + def add_runner(f: CommandFunctionType) -> CommandFunctionType: + @wraps(f) + def runner(*args, **kwargs): + if sys.version_info >= (3, 11) and self.loop_factory: + with asyncio.Runner(loop_factory=self.loop_factory) as runner: + runner.run(f(*args, **kwargs)) + else: + asyncio.run(f(*args, **kwargs)) + + return runner + + if inspect.iscoroutinefunction(f): + callback = add_runner(f) + else: + callback = f + self.registered_commands.append( CommandInfo( name=name, cls=cls, context_settings=context_settings, - callback=f, + callback=callback, help=help, epilog=epilog, short_help=short_help, @@ -253,7 +279,7 @@ def decorator(f: CommandFunctionType) -> CommandFunctionType: rich_help_panel=rich_help_panel, ) ) - return f + return decorator