Skip to content

Commit

Permalink
Allow typer to execute async commands by wrapping them in asyncio.run()
Browse files Browse the repository at this point in the history
  • Loading branch information
paulo-raca committed Jul 13, 2023
1 parent be20b5d commit 1f3d7d1
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 5 deletions.
40 changes: 40 additions & 0 deletions docs/tutorial/async-cmd.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Typer allows you to use [async](https://docs.python.org/3/library/asyncio.html) functions.

```Python
{!../docs_src/async_cmd/async001.py!}
```

<div class="termy">

```console
$ python main.py

Hello Async World
```

</div>

It also works with commands, and you can mix regular and async commands:

```Python
{!../docs_src/async_cmd/async002.py!}
```

<div class="termy">

```console
$ python main.py sync

Hello Sync World

$ python main.py async

Hello Async World
```
</div>

!!! info
Under the hood, Typer is running your async functions with [asyncio.run()](https://docs.python.org/3/library/asyncio-runner.html#asyncio.run)

!!! warning
Typer only supports async functions on Python 3.7+
14 changes: 14 additions & 0 deletions docs_src/async_cmd/async001.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio

import typer

app = typer.Typer()


async def main():
await asyncio.sleep(0)
print("Hello Async World")


if __name__ == "__main__":
typer.run(main)
20 changes: 20 additions & 0 deletions docs_src/async_cmd/async002.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio

import typer

app = typer.Typer()


@app.command("sync")
def command_sync():
print("Hello Sync World")


@app.command("async")
async def command_async():
await asyncio.sleep(0)
print("Hello Async World")


if __name__ == "__main__":
app()
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ nav:
- Typer Callback: tutorial/commands/callback.md
- One or Multiple Commands: tutorial/commands/one-or-multiple.md
- Using the Context: tutorial/commands/context.md
- Async functions: tutorial/async-cmd.md
- CLI Option autocompletion: tutorial/options-autocompletion.md
- CLI Parameter Types:
- CLI Parameter Types Intro: tutorial/parameter-types/index.md
Expand Down
Empty file.
45 changes: 45 additions & 0 deletions tests/test_tutorial/test_async_cmd/test_async001.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import subprocess
import sys

import pytest
import typer
from typer.testing import CliRunner

from docs_src.async_cmd import async001 as mod

runner = CliRunner()


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="typer support for async functions requires python3.7 or higher",
)
def test_cli():
app = typer.Typer()
app.command()(mod.main)
result = runner.invoke(app, [])
assert result.output == "Hello Async World\n"


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="typer support for async functions requires python3.7 or higher",
)
def test_execute():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert result.stdout == "Hello Async World\n"


def test_script():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "Usage" in result.stdout
59 changes: 59 additions & 0 deletions tests/test_tutorial/test_async_cmd/test_async002.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import subprocess
import sys

import pytest
from typer.testing import CliRunner

from docs_src.async_cmd import async002 as mod

app = mod.app

runner = CliRunner()


def test_command_sync():
result = runner.invoke(app, ["sync"])
assert result.output == "Hello Sync World\n"


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="typer support for async functions requires python3.7 or higher",
)
def test_command_async():
result = runner.invoke(app, ["async"])
assert result.output == "Hello Async World\n"


def test_execute_sync():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, "sync"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert result.stdout == "Hello Sync World\n"


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="typer support for async functions requires python3.7 or higher",
)
def test_execute_async():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, "async"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert result.stdout == "Hello Async World\n"


def test_script():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "Usage" in result.stdout
17 changes: 15 additions & 2 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@
from pathlib import Path
from traceback import FrameSummary, StackSummary
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from uuid import UUID

import click
Expand All @@ -34,7 +45,7 @@
Required,
TyperInfo,
)
from .utils import get_params_from_function
from .utils import ensure_sync, get_params_from_function

try:
import rich
Expand Down Expand Up @@ -235,6 +246,8 @@ def command(
cls = TyperCommand

def decorator(f: CommandFunctionType) -> CommandFunctionType:
f = cast(CommandFunctionType, ensure_sync(f))

self.registered_commands.append(
CommandInfo(
name=name,
Expand Down
39 changes: 36 additions & 3 deletions typer/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import inspect
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints

from typing_extensions import Annotated
from functools import wraps
from typing import (
Any,
Callable,
Coroutine,
Dict,
List,
Tuple,
Type,
TypeVar,
Union,
cast,
get_type_hints,
)

from typing_extensions import Annotated, ParamSpec

from ._typing import get_args, get_origin
from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta
Expand Down Expand Up @@ -185,3 +198,23 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
name=param.name, default=default, annotation=annotation
)
return params


P = ParamSpec("P")
R = TypeVar("R")


def ensure_sync(f: Callable[P, Union[R, Coroutine[Any, Any, R]]]) -> Callable[P, R]:
# If `f` is an async function, wrap it into asyncio.run(f)
if not inspect.iscoroutinefunction(f):
f_sync = cast(Callable[P, R], f)
return f_sync

@wraps(f)
def run_f(*args: P.args, **kwargs: P.kwargs) -> R:
import asyncio

f_async = cast(Callable[P, Coroutine[Any, Any, R]], f)
return asyncio.run(f_async(*args, **kwargs))

return run_f

0 comments on commit 1f3d7d1

Please sign in to comment.