From 80eec9687fcb0b39007013ba6a376b50091f7981 Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:56:13 +0200 Subject: [PATCH] Add PosArgT typing to trio.run (#3022) * add PosArgT typing to run() * add type tests --- src/trio/_core/_run.py | 4 +-- src/trio/_core/_tests/test_run.py | 2 +- src/trio/_core/_tests/type_tests/run.py | 46 +++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 src/trio/_core/_tests/type_tests/run.py diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 1512fdf954..c6faabaf72 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -2184,8 +2184,8 @@ def setup_runner( def run( - async_fn: Callable[..., Awaitable[RetT]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]], + *args: Unpack[PosArgT], clock: Clock | None = None, instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index d54d9f1813..c4639c4342 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -76,7 +76,7 @@ async def trivial(x: T) -> T: with pytest.raises(TypeError): # Missing an argument - _core.run(trivial) + _core.run(trivial) # type: ignore[arg-type] with pytest.raises(TypeError): # Not an async function diff --git a/src/trio/_core/_tests/type_tests/run.py b/src/trio/_core/_tests/type_tests/run.py new file mode 100644 index 0000000000..c121ce6c7a --- /dev/null +++ b/src/trio/_core/_tests/type_tests/run.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Sequence, overload + +import trio +from typing_extensions import assert_type + + +async def sleep_sort(values: Sequence[float]) -> list[float]: + return [1] + + +async def has_optional(arg: int | None = None) -> int: + return 5 + + +@overload +async def foo_overloaded(arg: int) -> str: ... + + +@overload +async def foo_overloaded(arg: str) -> int: ... + + +async def foo_overloaded(arg: int | str) -> int | str: + if isinstance(arg, str): + return 5 + return "hello" + + +v = trio.run( + sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojump_threshold=0) +) +assert_type(v, "list[float]") +trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type] +trio.run(sleep_sort) # type: ignore[arg-type] + +r = trio.run(has_optional) +assert_type(r, int) +r = trio.run(has_optional, 5) +trio.run(has_optional, 7, 8) # type: ignore[arg-type] +trio.run(has_optional, "hello") # type: ignore[arg-type] + + +assert_type(trio.run(foo_overloaded, 5), str) +assert_type(trio.run(foo_overloaded, ""), int)