Skip to content

Commit 650b7a5

Browse files
committed
feat(autorun): add auto_await to AutorunOptions so that one can define an autorun/view as a decorator of a function without automatically awaiting its result, when auto_await is set to False, which activates the new behavior, the decorated function passes asyncio.iscoroutinefunction test, useful for certain libraries like quart
1 parent 4527ef3 commit 650b7a5

File tree

8 files changed

+137
-22
lines changed

8 files changed

+137
-22
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## Version 0.17.2
4+
5+
- feat(autorun): add `auto_await` to `AutorunOptions` so that one can define an autorun/view as a decorator of a function without automatically awaiting its result, when `auto_await` is set to `False`, which activates the new behavior, the decorated function passes `asyncio.iscoroutinefunction` test, useful for certain libraries like quart
6+
37
## Version 0.17.1
48

59
- refactor(core): allow `None` type for state, action and event types in `ReducerResult` and `CompleteReducerResult`

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "python-redux"
3-
version = "0.17.1"
3+
version = "0.17.2"
44
description = "Redux implementation for Python"
55
authors = ["Sassan Haradji <[email protected]>"]
66
license = "Apache-2.0"

redux/autorun.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
11
# ruff: noqa: D100, D101, D102, D103, D104, D105, D107
22
from __future__ import annotations
33

4+
import asyncio
45
import functools
56
import inspect
67
import weakref
78
from asyncio import Future, Task, iscoroutine, iscoroutinefunction
8-
from typing import TYPE_CHECKING, Any, Callable, Concatenate, Generic, cast
9+
from typing import (
10+
TYPE_CHECKING,
11+
Any,
12+
Callable,
13+
Concatenate,
14+
Coroutine,
15+
Generator,
16+
Generic,
17+
TypeVar,
18+
cast,
19+
)
920

1021
from redux.basic_types import (
1122
Action,
@@ -22,6 +33,25 @@
2233
from redux.main import Store
2334

2435

36+
T = TypeVar('T')
37+
38+
39+
class AwaitableWrapper(Generic[T]):
40+
def __init__(self, coro: Coroutine[None, None, T]) -> None:
41+
self.coro = coro
42+
self.awaited = False
43+
44+
def __await__(self) -> Generator[None, None, T]:
45+
self.awaited = True
46+
return self.coro.__await__()
47+
48+
def close(self) -> None:
49+
self.coro.close()
50+
51+
def __repr__(self) -> str:
52+
return f'AwaitableWrapper({self.coro}, awaited={self.awaited})'
53+
54+
2555
class Autorun(
2656
Generic[
2757
State,
@@ -45,6 +75,7 @@ def __init__(
4575
],
4676
options: AutorunOptions[AutorunOriginalReturnType],
4777
) -> None:
78+
self.__name__ = func.__name__
4879
self._store = store
4980
self._selector = selector
5081
self._comparator = comparator
@@ -55,6 +86,11 @@ def __init__(
5586
self._func = weakref.WeakMethod(func, self.unsubscribe)
5687
else:
5788
self._func = weakref.ref(func, self.unsubscribe)
89+
self._is_coroutine = (
90+
asyncio.coroutines._is_coroutine # pyright: ignore [reportAttributeAccessIssue] # noqa: SLF001
91+
if asyncio.iscoroutinefunction(func) and not options.auto_await
92+
else None
93+
)
5894
self._options = options
5995

6096
self._last_selector_result: SelectorOutput | None = None
@@ -120,11 +156,11 @@ def _task_callback(
120156
],
121157
task: Task,
122158
*,
123-
future: Future | None,
159+
future: Future,
124160
) -> None:
125161
task.add_done_callback(
126162
lambda result: (
127-
future.set_result(result.result()) if future else None,
163+
future.set_result(result.result()),
128164
self.inform_subscribers(),
129165
),
130166
)
@@ -184,15 +220,27 @@ def _call(
184220
)
185221
create_task = self._store._create_task # noqa: SLF001
186222
if iscoroutine(value) and create_task:
187-
future = Future()
188-
self._latest_value = cast(AutorunOriginalReturnType, future)
189-
create_task(
190-
value,
191-
callback=functools.partial(
192-
self._task_callback,
193-
future=future,
194-
),
195-
)
223+
if self._options.auto_await:
224+
future = Future()
225+
self._latest_value = cast(AutorunOriginalReturnType, future)
226+
create_task(
227+
value,
228+
callback=functools.partial(
229+
self._task_callback,
230+
future=future,
231+
),
232+
)
233+
else:
234+
if (
235+
self._latest_value is not None
236+
and isinstance(self._latest_value, AwaitableWrapper)
237+
and not self._latest_value.awaited
238+
):
239+
self._latest_value.close()
240+
self._latest_value = cast(
241+
AutorunOriginalReturnType,
242+
AwaitableWrapper(value),
243+
)
196244
else:
197245
self._latest_value = value
198246
self.inform_subscribers()

redux/basic_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class CreateStoreOptions(Immutable, Generic[Action, Event]):
133133

134134
class AutorunOptions(Immutable, Generic[AutorunOriginalReturnType]):
135135
default_value: AutorunOriginalReturnType | None = None
136+
auto_await: bool = True
136137
initial_call: bool = True
137138
reactive: bool = True
138139
keep_ref: bool = True
@@ -167,6 +168,8 @@ def subscribe(
167168

168169
def unsubscribe(self: AutorunReturnType) -> None: ...
169170

171+
__name__: str
172+
170173

171174
class AutorunDecorator(
172175
Protocol,

redux/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def decorator(
388388
func=cast(Callable, func),
389389
options=AutorunOptions(
390390
default_value=_options.default_value,
391+
auto_await=False,
391392
initial_call=False,
392393
reactive=False,
393394
keep_ref=_options.keep_ref,

redux_pytest/fixtures/snapshot.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ def monitor(self: StoreSnapshot[State], selector: Callable[[State], Any]) -> Non
128128
"""Monitor the state of the store and take snapshots."""
129129

130130
@self.store.autorun(selector=selector)
131-
def _(state: object | None) -> None:
132-
if state is None:
133-
return
131+
def _(state: object) -> None:
134132
self.take(selector=lambda _: state)
135133

136134
def close(self: StoreSnapshot[State]) -> None:

tests/test_async.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
if TYPE_CHECKING:
2525
from redux_pytest.fixtures.event_loop import LoopThread
2626

27-
INCREMENTS = 2
27+
INCREMENTS = 20
2828

2929

3030
class StateType(Immutable):
@@ -91,16 +91,18 @@ def test_autorun(
9191
event_loop: LoopThread,
9292
) -> None:
9393
@store.autorun(lambda state: state.value)
94-
async def _(value: int) -> int:
94+
async def sync_mirror(value: int) -> int:
9595
await asyncio.sleep(value / 10)
9696
store.dispatch(SetMirroredValueAction(value=value))
9797
return value
9898

99+
assert not asyncio.iscoroutinefunction(sync_mirror)
100+
99101
@store.autorun(
100102
lambda state: state.mirrored_value,
101103
lambda state: state.mirrored_value >= INCREMENTS,
102104
)
103-
async def _(mirrored_value: int) -> None:
105+
def _(mirrored_value: int) -> None:
104106
if mirrored_value < INCREMENTS:
105107
return
106108
event_loop.stop()
@@ -109,6 +111,36 @@ async def _(mirrored_value: int) -> None:
109111
dispatch_actions(store)
110112

111113

114+
def test_autorun_autoawait(
115+
store: StoreType,
116+
event_loop: LoopThread,
117+
) -> None:
118+
@store.autorun(lambda state: state.value, options=AutorunOptions(auto_await=False))
119+
async def sync_mirror(value: int) -> int:
120+
store.dispatch(SetMirroredValueAction(value=value))
121+
return value * 2
122+
123+
assert asyncio.iscoroutinefunction(sync_mirror)
124+
125+
@store.autorun(lambda state: (state.value, state.mirrored_value))
126+
async def _(values: tuple[int, int]) -> None:
127+
value, mirrored_value = values
128+
if mirrored_value != value:
129+
assert 'awaited=False' in str(sync_mirror())
130+
await sync_mirror()
131+
assert 'awaited=True' in str(sync_mirror())
132+
with pytest.raises(
133+
RuntimeError,
134+
match=r'^cannot reuse already awaited coroutine$',
135+
):
136+
await sync_mirror()
137+
elif value < INCREMENTS:
138+
store.dispatch(IncrementAction())
139+
else:
140+
event_loop.stop()
141+
store.dispatch(FinishAction())
142+
143+
112144
def test_autorun_default_value(
113145
store: StoreType,
114146
event_loop: LoopThread,
@@ -122,7 +154,7 @@ async def _(value: int) -> int:
122154
lambda state: state.mirrored_value,
123155
lambda state: state.mirrored_value >= INCREMENTS,
124156
)
125-
async def _(mirrored_value: int) -> None:
157+
def _(mirrored_value: int) -> None:
126158
if mirrored_value < INCREMENTS:
127159
return
128160
event_loop.stop()
@@ -145,7 +177,10 @@ async def doubled(value: int) -> int:
145177
@store.autorun(lambda state: state.value)
146178
async def _(value: int) -> None:
147179
assert await doubled() == value * 2
148-
for _ in range(10):
180+
with pytest.raises(
181+
RuntimeError,
182+
match=r'^cannot reuse already awaited coroutine$',
183+
):
149184
await doubled()
150185
if value < INCREMENTS:
151186
store.dispatch(IncrementAction())
@@ -155,6 +190,30 @@ async def _(value: int) -> None:
155190
assert calls == list(range(INCREMENTS + 1))
156191

157192

193+
def test_view_await(store: StoreType, event_loop: LoopThread) -> None:
194+
calls = []
195+
196+
@store.view(lambda state: state.value)
197+
async def doubled(value: int) -> int:
198+
calls.append(value)
199+
return value * 2
200+
201+
assert asyncio.iscoroutinefunction(doubled)
202+
203+
@store.autorun(lambda state: state.value)
204+
async def _(value: int) -> None:
205+
calls_length = len(calls)
206+
assert await doubled() == value * 2
207+
assert len(calls) == calls_length + 1
208+
209+
if value < INCREMENTS:
210+
store.dispatch(IncrementAction())
211+
else:
212+
event_loop.stop()
213+
store.dispatch(FinishAction())
214+
assert calls == list(range(INCREMENTS + 1))
215+
216+
158217
def test_view_with_args(
159218
store: StoreType,
160219
event_loop: LoopThread,

tests/test_autorun.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ def store() -> Generator[StoreType, None, None]:
8282

8383
def test_general(store_snapshot: StoreSnapshot, store: StoreType) -> None:
8484
@store.autorun(lambda state: state.value)
85-
def _(value: int) -> int:
85+
def decorated(value: int) -> int:
8686
store_snapshot.take()
8787
return value
8888

89+
assert decorated.__name__ == 'decorated'
90+
8991

9092
def test_ignore_attribute_error_in_selector(store: StoreType) -> None:
9193
@store.autorun(lambda state: cast(Any, state).non_existing)

0 commit comments

Comments
 (0)