Skip to content

Commit

Permalink
#88: Add asyncio support
Browse files Browse the repository at this point in the history
  • Loading branch information
borissmidt committed Jul 6, 2023
1 parent f83c66d commit ca719d1
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
14 changes: 14 additions & 0 deletions tests/test_annotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,17 @@ def cmd(force: Annotated[bool, typer.Option("--force")] = False):
result = runner.invoke(app, ["--force"])
assert result.exit_code == 0, result.output
assert "Forcing operation" in result.output

def test_runner_can_use_an_async_method():
app = typer.Typer()
@app.command()
async def cmd(val: Annotated[int, typer.Argument()] = 0):
print(f"hello {val}")

result = runner.invoke(app)
assert result.exit_code == 0, result.output
assert "hello 0" in result.output

result = runner.invoke(app, ["42"])
assert result.exit_code == 0, result.output
assert "hello 42" in result.output
32 changes: 28 additions & 4 deletions typer/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import asyncio
import inspect
import os
import sys
import traceback
from datetime import datetime
from enum import Enum
from functools import update_wrapper
from functools import update_wrapper, wraps
from pathlib import Path
from traceback import FrameSummary, StackSummary
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union, ParamSpec
from uuid import UUID

import click
Expand Down Expand Up @@ -46,6 +47,11 @@
except ImportError: # pragma: nocover
rich = None # type: ignore

try:
from asyncio import Loop
except ImportError: # pragma: nocover
Loop = Any # type: ignore

_original_except_hook = sys.excepthook
_typer_developer_exception_attr_name = "__typer_developer_exception__"

Expand Down Expand Up @@ -139,6 +145,7 @@ def __init__(
pretty_exceptions_enable: bool = True,
pretty_exceptions_show_locals: bool = True,
pretty_exceptions_short: bool = True,
loop_factory: Optional[Callable[[], Loop]] = None,
):
self._add_completion = add_completion
self.rich_markup_mode: MarkupMode = rich_markup_mode
Expand Down Expand Up @@ -167,6 +174,7 @@ def __init__(
self.registered_groups: List[TyperInfo] = []
self.registered_commands: List[CommandInfo] = []
self.registered_callback: Optional[TyperInfo] = None
self.loop_factory = loop_factory

def callback(
self,
Expand Down Expand Up @@ -235,12 +243,28 @@ def command(
cls = TyperCommand

def decorator(f: CommandFunctionType) -> CommandFunctionType:
def add_runner(f: CommandFunctionType) -> CommandFunctionType:
@wraps(f)
def runner(*args, **kwargs) -> Any:
if sys.version_info >= (3, 11) and self.loop_factory:
with asyncio.Runner(loop_factory=self.loop_factory) as runner:
return runner.run(f(*args, **kwargs))
else:
return asyncio.run(f(*args, **kwargs))

return runner

if inspect.iscoroutinefunction(f):
callback = add_runner(f)
else:
callback = f

self.registered_commands.append(
CommandInfo(
name=name,
cls=cls,
context_settings=context_settings,
callback=f,
callback=callback,
help=help,
epilog=epilog,
short_help=short_help,
Expand All @@ -253,7 +277,7 @@ def decorator(f: CommandFunctionType) -> CommandFunctionType:
rich_help_panel=rich_help_panel,
)
)
return f


return decorator

Expand Down

0 comments on commit ca719d1

Please sign in to comment.