Skip to content

Commit

Permalink
Properly handle NAK frames and implement retries (#610)
Browse files Browse the repository at this point in the history
* Retry transmits

* Handle retransmitted NCP frames

* Use named constants

* Drop log level

* Use `ConnectionResetError` instead of `RuntimeError`

* Implement dynamic ACK timeout

* Fix _rec_seq in tests

* Reduce UART concurrency to 1

* Add unit tests

---------

Co-authored-by: David Mulcahey <[email protected]>
  • Loading branch information
puddly and dmulcahey authored Feb 5, 2024
1 parent 1c43536 commit 2fe6e86
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
NETWORK_OPS_TIMEOUT = 10
NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1

MAX_COMMAND_CONCURRENCY = 4
MAX_COMMAND_CONCURRENCY = 1


class EZSP:
Expand Down
2 changes: 1 addition & 1 deletion bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from bellows.typing import GatewayType

LOGGER = logging.getLogger(__name__)
EZSP_CMD_TIMEOUT = 5
EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2


class ProtocolHandler(abc.ABC):
Expand Down
88 changes: 79 additions & 9 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import binascii
import logging
import sys
import time

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
Expand All @@ -17,6 +18,12 @@
LOGGER = logging.getLogger(__name__)
RESET_TIMEOUT = 5

# Number of retransmissions attempted after the initial transmit of a data
# frame; the send loop makes ASH_ACK_RETRIES + 1 attempts in total and reports
# the connection as lost (ConnectionResetError) once they are exhausted.
ASH_ACK_RETRIES = 4

# ACK timeout parameters, in seconds. The gateway starts with ASH_RX_ACK_INIT
# and adapts the timeout after every transmit attempt: it is doubled when an
# attempt times out and otherwise recomputed from the measured response time.
# The result is always clamped to [ASH_RX_ACK_MIN, ASH_RX_ACK_MAX].
ASH_RX_ACK_INIT = 1.6
ASH_RX_ACK_MIN = 0.4
ASH_RX_ACK_MAX = 3.2


class Gateway(asyncio.Protocol):
FLAG = b"\x7E" # Marks end of frame
Expand Down Expand Up @@ -47,6 +54,7 @@ def __init__(self, application, connected_future=None, connection_done_future=No
self._connection_done_future = connection_done_future

self._send_task = None
self._ack_timeout = ASH_RX_ACK_INIT

def connection_made(self, transport):
"""Callback when the uart is connected"""
Expand Down Expand Up @@ -118,10 +126,18 @@ def data_frame_received(self, data):
"""Data frame receive handler"""
LOGGER.debug("Data frame: %s", binascii.hexlify(data))
seq = (data[0] & 0b01110000) >> 4
self._rec_seq = (seq + 1) % 8
self.write(self._ack_frame())
self._handle_ack(data[0])
self._application.frame_received(self._randomize(data[1:-3]))
re_tx = (data[0] & 0b00001000) >> 3

if seq == self._rec_seq:
self._rec_seq = (seq + 1) % 8
self.write(self._ack_frame())

self._handle_ack(data[0])
self._application.frame_received(self._randomize(data[1:-3]))
elif re_tx:
self.write(self._ack_frame())
else:
self.write(self._nak_frame())

def ack_frame_received(self, data):
"""Acknowledgement frame receive handler"""
Expand Down Expand Up @@ -268,13 +284,67 @@ async def _send_loop(self):
if item is self.Terminator:
break
data, seq = item
success = False
rxmit = 0
while not success:

for attempt in range(ASH_ACK_RETRIES + 1):
self._pending = (seq, asyncio.get_event_loop().create_future())

send_time = time.monotonic()
rxmit = attempt > 0
self.write(self._data_frame(data, seq, rxmit))
rxmit = 1
success = await self._pending[1]

try:
async with asyncio_timeout(self._ack_timeout):
success = await self._pending[1]
except asyncio.TimeoutError:
success = None
LOGGER.debug(
"Frame %s (seq %s) timed out on attempt %d, retrying",
data,
seq,
attempt,
)
else:
if success:
break

LOGGER.debug(
"Frame %s (seq %s) failed to transmit on attempt %d, retrying",
data,
seq,
attempt,
)
finally:
delta = time.monotonic() - send_time

if success is not None:
new_ack_timeout = max(
ASH_RX_ACK_MIN,
min(
ASH_RX_ACK_MAX,
(7 / 8) * self._ack_timeout + 0.5 * delta,
),
)
else:
new_ack_timeout = max(
ASH_RX_ACK_MIN, min(ASH_RX_ACK_MAX, 2 * self._ack_timeout)
)

if abs(self._ack_timeout - new_ack_timeout) > 0.01:
LOGGER.debug(
"Adjusting ACK timeout from %.2f to %.2f",
self._ack_timeout,
new_ack_timeout,
)

self._ack_timeout = new_ack_timeout
self._pending = (-1, None)
else:
self.connection_lost(
ConnectionResetError(
f"Failed to transmit ASH frame after {ASH_ACK_RETRIES} retries"
)
)
return

def _handle_ack(self, control):
"""Handle an acknowledgement frame"""
Expand Down
111 changes: 111 additions & 0 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def test_substitute_received(gw):

def test_partial_data_received(gw):
gw.write = MagicMock()
gw._rec_seq = 5
gw.data_received(b"\x54\x79\xa1\xb0")
gw.data_received(b"\x50\xf2\x6e\x7e")
assert gw.write.call_count == 1
Expand All @@ -209,6 +210,7 @@ def test_crc_error(gw):

def test_crc_error_and_valid_frame(gw):
gw.write = MagicMock()
gw._rec_seq = 5
gw.data_received(
b"L\xa1\x8e\x03\xcd\x07\xb9Y\xfbG%\xae\xbd~\x54\x79\xa1\xb0\x50\xf2\x6e\x7e"
)
Expand All @@ -218,6 +220,7 @@ def test_crc_error_and_valid_frame(gw):

def test_data_frame_received(gw):
gw.write = MagicMock()
gw._rec_seq = 5
gw.data_received(b"\x54\x79\xa1\xb0\x50\xf2\x6e\x7e")
assert gw.write.call_count == 1
assert gw._application.frame_received.call_count == 1
Expand Down Expand Up @@ -416,3 +419,111 @@ async def test_wait_for_startup_reset_failure(gw):
await asyncio.wait_for(gw.wait_for_startup_reset(), 0.01)

assert gw._startup_reset_future is None


# Base ACK timeout (seconds) for the retry tests below. The tests patch the
# gateway's MIN / INIT / MAX ACK timeouts to 1x, 4x and 8x this value so that
# retries happen quickly.
ASH_ACK_MIN = 0.01


@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0)
@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2)
@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3)
async def test_retry_success():
    """A frame that misses its ACK window is retransmitted, then acknowledged.

    The ACK timeouts are patched down to multiples of ``ASH_ACK_MIN`` so the
    test only has to sleep for a few hundredths of a second.
    """
    serial = MagicMock()
    gateway = uart.Gateway(
        MagicMock(), connected_future=asyncio.get_running_loop().create_future()
    )
    gateway.connection_made(serial)

    def frames_written():
        # Every call recorded on the transport's ``write`` mock
        return len(serial.write.mock_calls)

    initial_timeout = gateway._ack_timeout
    gateway.data(b"TX 1")
    await asyncio.sleep(0)

    # Only the initial transmission has gone out so far
    assert frames_written() == 1

    # Stay silent for longer than the initial ACK timeout (4x ASH_ACK_MIN)
    await asyncio.sleep(ASH_ACK_MIN * 5)

    # The missed ACK triggered exactly one retransmission
    assert frames_written() == 2

    # ash.DataFrame(frm_num=0, re_tx=0, ack_num=1, ezsp_frame=b"RX 1").to_bytes()
    late_reply = bytes.fromhex("01107988654851")
    gateway.frame_received(late_reply)

    # The ack_num carried by the data frame acknowledges the pending frame
    await asyncio.sleep(0)
    assert gateway._pending == (-1, None)

    # The earlier timeout made the gateway back off
    assert gateway._ack_timeout > initial_timeout

    gateway.close()


@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0)
@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2)
@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3)
async def test_retry_nak_then_success():
    """A NAK inside the ACK window forces an immediate retransmit, which is ACKed.

    The ACK timeouts are patched down to multiples of ``ASH_ACK_MIN`` so the
    test only has to sleep for a few hundredths of a second.
    """
    serial = MagicMock()
    gateway = uart.Gateway(
        MagicMock(), connected_future=asyncio.get_running_loop().create_future()
    )
    gateway.connection_made(serial)

    def frames_written():
        # Every call recorded on the transport's ``write`` mock
        return len(serial.write.mock_calls)

    initial_timeout = gateway._ack_timeout
    gateway.data(b"TX 1")
    await asyncio.sleep(0)
    assert frames_written() == 1

    # Reply before the ACK timeout (4x ASH_ACK_MIN) can expire, with a NAK
    await asyncio.sleep(ASH_ACK_MIN)

    # ash.NakFrame(res=0, ncp_ready=0, ack_num=0).to_bytes()
    nak = bytes.fromhex("a0541a")
    gateway.frame_received(nak)

    # The retransmission happens right away, without waiting for a timeout
    await asyncio.sleep(0)
    assert frames_written() == 2

    # ash.AckFrame(res=0, ncp_ready=0, ack_num=1).to_bytes()
    ack = bytes.fromhex("816059")
    gateway.frame_received(ack)

    # The pending frame is now acknowledged and, with prompt replies, the
    # adaptive timeout has dropped below its starting value
    await asyncio.sleep(0)
    assert gateway._pending == (-1, None)
    assert gateway._ack_timeout < initial_timeout

    gateway.close()


@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0)
@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2)
@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3)
async def test_retry_failure():
    """A frame that is never acknowledged is retried until attempts run out.

    The ACK timeouts are patched down to multiples of ``ASH_ACK_MIN`` so the
    test only has to sleep for a fraction of a second.
    """
    serial = MagicMock()
    gateway = uart.Gateway(
        MagicMock(), connected_future=asyncio.get_running_loop().create_future()
    )
    gateway.connection_made(serial)

    def frames_written():
        # Every call recorded on the transport's ``write`` mock
        return len(serial.write.mock_calls)

    initial_timeout = gateway._ack_timeout
    gateway.data(b"TX 1")
    await asyncio.sleep(0)

    # Only the initial transmission has gone out so far
    assert frames_written() == 1

    # Never reply: wait long enough for every attempt to time out
    await asyncio.sleep(ASH_ACK_MIN * 40)

    # One initial transmission plus four retransmissions
    assert frames_written() == 5

    # The send loop gave up on the frame, and repeated timeouts pushed the
    # adaptive ACK timeout all the way up to its (patched) maximum
    assert gateway._pending == (-1, None)
    assert gateway._ack_timeout > initial_timeout
    assert gateway._ack_timeout == ASH_ACK_MIN * 2**3  # max timeout

    gateway.close()

0 comments on commit 2fe6e86

Please sign in to comment.