Skip to content

Commit 3657cf3

Browse files
authored
Prevent task cancellation from propagating to ASH (#628)
* Do not allow task cancellation to propagate to ASH * Enter a failed state if we cannot send a frame * Support resolving multiple frames at once (we still limit TX_K=1)
1 parent 3875d11 commit 3657cf3

File tree

1 file changed

+46
-14
lines changed

1 file changed

+46
-14
lines changed

bellows/ash.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import abc
44
import asyncio
55
import binascii
6+
from collections.abc import Coroutine
67
import dataclasses
78
import enum
89
import logging
910
import sys
1011
import time
12+
import typing
1113

1214
if sys.version_info[:2] < (3, 11):
1315
from async_timeout import timeout as asyncio_timeout # pragma: no cover
@@ -55,7 +57,7 @@ class Reserved(enum.IntEnum):
5557

5658
# Maximum number of DATA frames the NCP can transmit without having received
5759
# acknowledgements
58-
TX_K = 1
60+
TX_K = 1 # TODO: investigate why this cannot be raised without causing a firmware crash
5961

6062
# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
6163
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
@@ -81,6 +83,23 @@ def generate_random_sequence(length: int) -> bytes:
8183
# Since the sequence is static for every frame, we only need to generate it once
8284
PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256)
8385

86+
if sys.version_info[:2] < (3, 12):
87+
create_eager_task = asyncio.create_task
88+
else:
89+
_T = typing.TypeVar("T")
90+
91+
def create_eager_task(
92+
coro: Coroutine[typing.Any, typing.Any, _T],
93+
*,
94+
name: str | None = None,
95+
loop: asyncio.AbstractEventLoop | None = None,
96+
) -> asyncio.Task[_T]:
97+
"""Create a task from a coroutine and schedule it to run immediately."""
98+
if loop is None:
99+
loop = asyncio.get_running_loop()
100+
101+
return asyncio.Task(coro, loop=loop, name=name, eager_start=True)
102+
84103

85104
class NcpState(enum.Enum):
86105
CONNECTED = "connected"
@@ -463,15 +482,14 @@ def data_received(self, data: bytes) -> None:
463482
def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
464483
# Note that ackNum is the number of the next frame the receiver expects and it
465484
# is one greater than the last frame received.
466-
ack_num = (frame.ack_num - 1) % 8
485+
for ack_num_offset in range(-TX_K, 0):
486+
ack_num = (frame.ack_num + ack_num_offset) % 8
487+
fut = self._pending_data_frames.get(ack_num)
467488

468-
fut = self._pending_data_frames.get(ack_num)
489+
if fut is None or fut.done():
490+
continue
469491

470-
if fut is None or fut.done():
471-
return
472-
473-
# _LOGGER.debug("Resolving frame %d", ack_num)
474-
self._pending_data_frames[ack_num].set_result(True)
492+
self._pending_data_frames[ack_num].set_result(True)
475493

476494
def frame_received(self, frame: AshFrame) -> None:
477495
_LOGGER.debug("Received frame %r", frame)
@@ -537,13 +555,16 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
537555
self._ncp_state = NcpState.FAILED
538556

539557
# Cancel all pending requests
540-
exc = NcpFailure(code=self._ncp_reset_code)
558+
self._enter_failed_state(self._ncp_reset_code)
559+
560+
def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
561+
exc = NcpFailure(code=reset_code)
541562

542563
for fut in self._pending_data_frames.values():
543564
if not fut.done():
544565
fut.set_exception(exc)
545566

546-
self._ezsp_protocol.reset_received(frame.reset_code)
567+
self._ezsp_protocol.reset_received(reset_code)
547568

548569
def _write_frame(
549570
self,
@@ -582,7 +603,7 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
582603
for attempt in range(ACK_TIMEOUTS):
583604
if self._ncp_state == NcpState.FAILED:
584605
_LOGGER.debug(
585-
"NCP is in a failed state, not re-sending: %r", frame
606+
"NCP is in a failed state, not sending: %r", frame
586607
)
587608
raise NcpFailure(
588609
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
@@ -618,6 +639,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
618639
self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta)
619640

620641
if attempt >= ACK_TIMEOUTS - 1:
642+
self._enter_failed_state(
643+
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
644+
)
621645
raise
622646
except NcpFailure:
623647
_LOGGER.debug(
@@ -635,6 +659,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
635659
self._change_ack_timeout(2 * self._t_rx_ack)
636660

637661
if attempt >= ACK_TIMEOUTS - 1:
662+
self._enter_failed_state(
663+
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
664+
)
638665
raise
639666
else:
640667
# Whenever an acknowledgement is received, t_rx_ack is set to
@@ -649,9 +676,14 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
649676
self._pending_data_frames.pop(frm_num)
650677

651678
async def send_data(self, data: bytes) -> None:
652-
await self._send_data_frame(
653-
# All of the other fields will be set during transmission/retries
654-
DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
679+
# Sending data is a critical operation and cannot really be cancelled
680+
await asyncio.shield(
681+
create_eager_task(
682+
self._send_data_frame(
683+
# All of the other fields will be set during transmission/retries
684+
DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
685+
)
686+
)
655687
)
656688

657689
def send_reset(self) -> None:

0 commit comments

Comments
 (0)