Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cancel_callback in dispatch #1948

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion channels/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ async def __call__(self, scope, receive, send):
self.channel_receive = functools.partial(
self.channel_layer.receive, self.channel_name
)
# Handler to call when dispatch task is cancelled
cancel_callback = None
try:
if callable(self.channel_layer.clean_channel):
cancel_callback = functools.partial(
self.channel_layer.clean_channel, self.channel_name
)
except AttributeError:
pass
# Store send function
if self._sync:
self.base_send = async_to_sync(send)
Expand All @@ -56,7 +65,9 @@ async def __call__(self, scope, receive, send):
try:
if self.channel_layer is not None:
await await_many_dispatch(
[receive, self.channel_receive], self.dispatch
[receive, self.channel_receive],
self.dispatch,
cancel_callback=cancel_callback,
)
else:
await await_many_dispatch([receive], self.dispatch)
Expand Down
5 changes: 3 additions & 2 deletions channels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def name_that_thing(thing):
return repr(thing)


async def await_many_dispatch(consumer_callables, dispatch):
async def await_many_dispatch(consumer_callables, dispatch, cancel_callback=None):
"""
Given a set of consumer callables, awaits on them all and passes results
from them to the dispatch awaitable as they come in.
Expand All @@ -56,4 +56,5 @@ async def await_many_dispatch(consumer_callables, dispatch):
try:
await task
except asyncio.CancelledError:
pass
if cancel_callback:
await cancel_callback()
20 changes: 20 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio
from unittest import mock

import async_timeout
import pytest

from channels.utils import await_many_dispatch


async def sleep_task(*args):
await asyncio.sleep(10)


@pytest.mark.asyncio
async def test_cancel_callback_called():
cancel_callback = mock.AsyncMock()
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(0):
await await_many_dispatch([sleep_task], sleep_task, cancel_callback)
assert cancel_callback.called
Loading