Skip to content

Commit c87a47c

Browse files
committed
Merge branch 'release/1.2.1'
2 parents 7ac0adf + 5d623c8 commit c87a47c

File tree

6 files changed

+254
-7
lines changed

6 files changed

+254
-7
lines changed

.flake8

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ ignore =
8888
N802,
8989
; Do not perform function calls in argument defaults.
9090
B008,
91+
; Found except `BaseException`
92+
WPS424,
9193

9294
; all init files
9395
__init__.py:

README.md

+51
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,54 @@ with graph.sync_ctx() as ctx:
125125
```
126126

127127
The ParamInfo has the information about name and parameters signature. It's useful if you want to create a dependency that changes based on parameter name, or signature.
128+
129+
130+
## Exception propagation
131+
132+
By default if error happens within the context, we send this error to the dependency,
133+
so you can close it properly. You can disable this functionality by setting `exception_propagation` parameter to `False`.
134+
135+
Let's imagine that you want to get a database session from pool and commit after the function is done.
136+
137+
138+
```python
139+
async def get_session():
140+
session = sessionmaker()
141+
142+
yield session
143+
144+
await session.commit()
145+
146+
```
147+
148+
But what if the error happened when the dependant function was called? In this case you want to rollback, instead of commit.
149+
To solve this problem, you can just wrap the `yield` statement in `try except` to handle the error.
150+
151+
```python
152+
async def get_session():
153+
session = sessionmaker()
154+
155+
try:
156+
yield session
157+
except Exception:
158+
await session.rollback()
159+
return
160+
161+
await session.commit()
162+
163+
```
164+
165+
**Also, as a library developer, you can disable exception propagation**. If you do so, then no exception will ever be propagated to dependencies and no such `try except` expression will ever work.
166+
167+
168+
Example of disabled propogation.
169+
170+
```python
171+
172+
graph = DependencyGraph(target_func)
173+
174+
with graph.sync_ctx(exception_propagation=False) as ctx:
175+
print(ctx.resolve_kwargs())
176+
177+
178+
```

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "taskiq-dependencies"
3-
version = "1.2.0"
3+
version = "1.2.1"
44
description = "FastAPI like dependency injection implementation"
55
authors = ["Pavel Kirilin <[email protected]>"]
66
readme = "README.md"

taskiq_dependencies/ctx.py

+55-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import inspect
33
from copy import copy
4+
from logging import getLogger
45
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
56

67
from taskiq_dependencies.utils import ParamInfo
@@ -9,18 +10,23 @@
910
from taskiq_dependencies.graph import DependencyGraph # pragma: no cover
1011

1112

13+
logger = getLogger("taskiq.dependencies.ctx")
14+
15+
1216
class BaseResolveContext:
1317
"""Base resolver context."""
1418

1519
def __init__(
1620
self,
1721
graph: "DependencyGraph",
1822
initial_cache: Optional[Dict[Any, Any]] = None,
23+
exception_propagation: bool = True,
1924
) -> None:
2025
self.graph = graph
2126
self.opened_dependencies: List[Any] = []
2227
self.sub_contexts: "List[Any]" = []
2328
self.initial_cache = initial_cache or {}
29+
self.propagate_excs = exception_propagation
2430

2531
def traverse_deps( # noqa: C901, WPS210
2632
self,
@@ -116,18 +122,34 @@ def __enter__(self) -> "SyncResolveContext":
116122
return self
117123

118124
def __exit__(self, *args: Any) -> None:
119-
self.close()
125+
self.close(*args)
120126

121-
def close(self) -> None:
127+
def close(self, *args: Any) -> None: # noqa: C901
122128
"""
123129
Close all opened dependencies.
124130
125131
This function runs teardown of all dependencies.
132+
133+
:param args: exception info if any.
126134
"""
135+
exception_found = False
136+
if args[1] is not None and self.propagate_excs:
137+
exception_found = True
127138
for ctx in self.sub_contexts:
128-
ctx.close()
139+
ctx.close(*args)
129140
for dep in reversed(self.opened_dependencies):
130141
if inspect.isgenerator(dep):
142+
if exception_found:
143+
try:
144+
dep.throw(*args)
145+
except BaseException as exc:
146+
logger.warning(
147+
"Exception found on dependency teardown %s",
148+
exc,
149+
exc_info=True,
150+
)
151+
continue
152+
continue
131153
for _ in dep: # noqa: WPS328
132154
pass # noqa: WPS420
133155

@@ -201,21 +223,48 @@ async def __aenter__(self) -> "AsyncResolveContext":
201223
return self
202224

203225
async def __aexit__(self, *args: Any) -> None:
204-
await self.close()
226+
await self.close(*args)
205227

206-
async def close(self) -> None: # noqa: C901
228+
async def close(self, *args: Any) -> None: # noqa: C901
207229
"""
208230
Close all opened dependencies.
209231
210232
This function runs teardown of all dependencies.
233+
234+
:param args: exception info if any.
211235
"""
236+
exception_found = False
237+
if args[1] is not None and self.propagate_excs:
238+
exception_found = True
212239
for ctx in self.sub_contexts:
213-
await ctx.close() # type: ignore
240+
await ctx.close(*args) # type: ignore
214241
for dep in reversed(self.opened_dependencies):
215242
if inspect.isgenerator(dep):
243+
if exception_found:
244+
try:
245+
dep.throw(*args)
246+
except BaseException as exc:
247+
logger.warning(
248+
"Exception found on dependency teardown %s",
249+
exc,
250+
exc_info=True,
251+
)
252+
continue
253+
continue
216254
for _ in dep: # noqa: WPS328
217255
pass # noqa: WPS420
218256
elif inspect.isasyncgen(dep):
257+
if exception_found:
258+
try:
259+
await dep.athrow(*args)
260+
except BaseException as exc:
261+
logger.warning(
262+
"Exception found on dependency teardown %s",
263+
exc,
264+
exc_info=True,
265+
)
266+
continue
267+
continue
219268
async for _ in dep: # noqa: WPS328
220269
pass # noqa: WPS420
221270

taskiq_dependencies/graph.py

+8
Original file line numberDiff line numberDiff line change
@@ -41,35 +41,43 @@ def is_empty(self) -> bool:
4141
def async_ctx(
4242
self,
4343
initial_cache: Optional[Dict[Any, Any]] = None,
44+
exception_propagation: bool = True,
4445
) -> AsyncResolveContext:
4546
"""
4647
Create dependency resolver context.
4748
4849
This context is used to actually resolve dependencies.
4950
5051
:param initial_cache: initial cache dict.
52+
:param exception_propagation: If true, all found errors within
53+
context will be propagated to dependencies.
5154
:return: new resolver context.
5255
"""
5356
return AsyncResolveContext(
5457
self,
5558
initial_cache,
59+
exception_propagation,
5660
)
5761

5862
def sync_ctx(
5963
self,
6064
initial_cache: Optional[Dict[Any, Any]] = None,
65+
exception_propagation: bool = True,
6166
) -> SyncResolveContext:
6267
"""
6368
Create dependency resolver context.
6469
6570
This context is used to actually resolve dependencies.
6671
6772
:param initial_cache: initial cache dict.
73+
:param exception_propagation: If true, all found errors within
74+
context will be propagated to dependencies.
6875
:return: new resolver context.
6976
"""
7077
return SyncResolveContext(
7178
self,
7279
initial_cache,
80+
exception_propagation,
7381
)
7482

7583
def _build_graph(self) -> None: # noqa: C901, WPS210

tests/test_graph.py

+137
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,140 @@ def target(class_val: str = Depends(TeClass("tval"))) -> None:
310310

311311
info: str = kwargs["class_val"]
312312
assert info == "tval"
313+
314+
315+
def test_exception_generators() -> None:
316+
317+
errors_found = 0
318+
319+
def my_generator() -> Generator[int, None, None]:
320+
nonlocal errors_found
321+
try:
322+
yield 1
323+
except ValueError:
324+
errors_found += 1
325+
326+
def target(_: int = Depends(my_generator)) -> None:
327+
raise ValueError()
328+
329+
with pytest.raises(ValueError):
330+
with DependencyGraph(target=target).sync_ctx() as g:
331+
target(**g.resolve_kwargs())
332+
333+
assert errors_found == 1
334+
335+
336+
@pytest.mark.anyio
337+
async def test_async_exception_generators() -> None:
338+
339+
errors_found = 0
340+
341+
async def my_generator() -> AsyncGenerator[int, None]:
342+
nonlocal errors_found
343+
try:
344+
yield 1
345+
except ValueError:
346+
errors_found += 1
347+
348+
def target(_: int = Depends(my_generator)) -> None:
349+
raise ValueError()
350+
351+
with pytest.raises(ValueError):
352+
async with DependencyGraph(target=target).async_ctx() as g:
353+
target(**(await g.resolve_kwargs()))
354+
355+
assert errors_found == 1
356+
357+
358+
@pytest.mark.anyio
359+
async def test_async_exception_generators_multiple() -> None:
360+
361+
errors_found = 0
362+
363+
async def my_generator() -> AsyncGenerator[int, None]:
364+
nonlocal errors_found
365+
try:
366+
yield 1
367+
except ValueError:
368+
errors_found += 1
369+
370+
def target(
371+
_a: int = Depends(my_generator, use_cache=False),
372+
_b: int = Depends(my_generator, use_cache=False),
373+
_c: int = Depends(my_generator, use_cache=False),
374+
) -> None:
375+
raise ValueError()
376+
377+
with pytest.raises(ValueError):
378+
async with DependencyGraph(target=target).async_ctx() as g:
379+
target(**(await g.resolve_kwargs()))
380+
381+
assert errors_found == 3
382+
383+
384+
@pytest.mark.anyio
385+
async def test_async_exception_in_teardown() -> None:
386+
387+
errors_found = 0
388+
389+
async def my_generator() -> AsyncGenerator[int, None]:
390+
nonlocal errors_found
391+
try:
392+
yield 1
393+
except ValueError:
394+
errors_found += 1
395+
raise Exception()
396+
397+
def target(_: int = Depends(my_generator)) -> None:
398+
raise ValueError()
399+
400+
with pytest.raises(ValueError):
401+
async with DependencyGraph(target=target).async_ctx() as g:
402+
target(**(await g.resolve_kwargs()))
403+
404+
405+
@pytest.mark.anyio
406+
async def test_async_propagation_disabled() -> None:
407+
408+
errors_found = 0
409+
410+
async def my_generator() -> AsyncGenerator[int, None]:
411+
nonlocal errors_found
412+
try:
413+
yield 1
414+
except ValueError:
415+
errors_found += 1
416+
raise Exception()
417+
418+
def target(_: int = Depends(my_generator)) -> None:
419+
raise ValueError()
420+
421+
with pytest.raises(ValueError):
422+
async with DependencyGraph(target=target).async_ctx(
423+
exception_propagation=False,
424+
) as g:
425+
target(**(await g.resolve_kwargs()))
426+
427+
assert errors_found == 0
428+
429+
430+
def test_sync_propagation_disabled() -> None:
431+
432+
errors_found = 0
433+
434+
def my_generator() -> Generator[int, None, None]:
435+
nonlocal errors_found
436+
try:
437+
yield 1
438+
except ValueError:
439+
errors_found += 1
440+
raise Exception()
441+
442+
def target(_: int = Depends(my_generator)) -> None:
443+
raise ValueError()
444+
445+
with pytest.raises(ValueError):
446+
with DependencyGraph(target=target).sync_ctx(exception_propagation=False) as g:
447+
target(**(g.resolve_kwargs()))
448+
449+
assert errors_found == 0

0 commit comments

Comments
 (0)