Skip to content

Commit

Permalink
Merge pull request #122 from simonsobs/util
Browse files Browse the repository at this point in the history
Add a util func `until_true()`
  • Loading branch information
TaiSakuma authored Nov 5, 2024
2 parents 182c9ef + 21cad50 commit 831a941
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 0 deletions.
3 changes: 3 additions & 0 deletions nextline/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
'ExcThread',
'ThreadTaskIdComposer',
'Timer',
'UntilNotNoneTimeout',
'until_true',
'is_timezone_aware',
'utc_timestamp',
]
Expand All @@ -45,4 +47,5 @@
from .thread_exception import ExcThread
from .thread_task_id import ThreadTaskIdComposer
from .timer import Timer
from .until import UntilNotNoneTimeout, until_true
from .utc import is_timezone_aware, utc_timestamp
106 changes: 106 additions & 0 deletions nextline/utils/until.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import asyncio
from collections.abc import Awaitable, Callable
from inspect import isawaitable
from typing import Optional


class UntilNotNoneTimeout(Exception):
pass


async def until_true(
func: Callable[[], bool] | Callable[[], Awaitable[bool]],
/,
*,
timeout: Optional[float] = None,
interval: float = 0,
) -> None:
'''Return when `func` returns `True` or a truthy value.
Parameters:
-----------
func
A callable that returns either a boolean or an awaitable that returns a
boolean.
timeout
The maximum number of seconds to wait for `func` to return `True`.
If `None`, wait indefinitely.
interval
The number of seconds to wait before checking `func` again.
Examples
--------
The `func` returns `True` when the third time it is called:
>>> def gen():
... print('Once')
... yield False
... print('Twice')
... yield False
... print('Thrice')
... yield True
... print('Never reached')
>>> g = gen()
>>> func = g.__next__
>>> asyncio.run(until_true(func))
Once
Twice
Thrice
The `afunc` is an async version of `func`:
>>> async def agen():
... print('Once')
... yield False
... print('Twice')
... yield False
... print('Thrice')
... yield True
... print('Never reached')
>>> g = agen()
>>> afunc = g.__anext__
>>> asyncio.run(until_true(afunc))
Once
Twice
Thrice
An exception will be raised if `timeout` has passed before `True` is
returned:
>>> async def gen_none():
... while True:
... yield False
>>> g = gen_none()
>>> afunc = g.__anext__
>>> asyncio.run(until_true(afunc, timeout=0.001)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
UntilNotNoneTimeout: Timed out after 0.001 seconds.
'''

async def call_func() -> bool:
maybe_awaitable = func()
if isawaitable(maybe_awaitable):
return await maybe_awaitable
return maybe_awaitable

async def _until_true() -> None:
while not await call_func():
await asyncio.sleep(interval)
return

# NOTE: For Python 3.11+, `asyncio.timeout` can be used.

try:
return await asyncio.wait_for(_until_true(), timeout)
except asyncio.TimeoutError:
raise UntilNotNoneTimeout(
f'Timed out after {timeout} seconds. '
f'The function has not returned a non-None value: {func!r}'
)
78 changes: 78 additions & 0 deletions tests/utils/test_until.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import asyncio
from collections.abc import Awaitable, Callable
from inspect import iscoroutinefunction
from typing import NoReturn, TypeGuard, cast
from unittest.mock import Mock

import pytest
from hypothesis import given
from hypothesis import strategies as st

from nextline.utils import UntilNotNoneTimeout, until_true


def func_factory(
counts: int, sync: bool = False
) -> Callable[[], bool] | Callable[[], Awaitable[bool]]:
assert counts

def func() -> bool:
nonlocal counts
counts -= 1
return counts == 0

async def afunc() -> bool:
return func()

return func if sync else afunc


def is_async_func(
f: Callable[[], bool] | Callable[[], Awaitable[bool]],
) -> TypeGuard[Callable[[], Awaitable[bool]]]:
return iscoroutinefunction(f)


@given(counts=st.integers(min_value=1, max_value=10))
def test_func_factory_sync(counts: int) -> None:
func = func_factory(counts, sync=True)
for _ in range(counts - 1):
assert not func()
assert func()


@given(counts=st.integers(min_value=1, max_value=10))
async def test_func_factory_async(counts: int) -> None:
func = func_factory(counts, sync=False)
assert is_async_func(func)
for _ in range(counts - 1):
assert not await func()
assert await func()


@given(counts=st.integers(min_value=1, max_value=10), sync=st.booleans())
async def test_counts(counts: int, sync: bool) -> None:
wrapped = func_factory(counts, sync=sync)
func = Mock(wraps=wrapped)
await until_true(func)
assert func.call_count == counts


@given(sync=st.booleans())
async def test_timeout(sync: bool) -> None:
counts = cast(int, float('inf'))
assert counts == counts - 1
wrapped = func_factory(counts, sync=sync)
func = Mock(wraps=wrapped)
with pytest.raises(UntilNotNoneTimeout):
await until_true(func, timeout=0.001)


@pytest.mark.timeout(5)
async def test_timeout_never_return() -> None:
async def func() -> NoReturn:
while True:
await asyncio.sleep(0)

with pytest.raises(UntilNotNoneTimeout):
await until_true(func, timeout=0.001)

0 comments on commit 831a941

Please sign in to comment.