Skip to content

Commit

Permalink
Move EZSP send lock from EZSP to individual protocol handlers (#649)
Browse files Browse the repository at this point in the history
* Log frames that are not ACKed

* Move command locking and prioritization into the protocol handler

* Rename `delay_time` to `send_time`

* Cancel all pending futures when the connection is lost

* Increase ACK_TIMEOUTS from 4 to 5

* Increase the EZSP command timeout to 10s

* Do not count the ASH send time in the EZSP command timeout

* Set the NCP state to `FAILED` when we soft fail

* Always handle ACK information, even if the frame is invalid

* Remove stale constants from `Gateway`

* Guard to make sure we can't send data while the transport is closing

* Fix unit tests

* Send a NAK frame on any parsing error

* Reset the random seed every ASH test invocation

* Remove unnecessary `asyncio.get_running_loop()`

* Add a few more unit tests for coverage

* Null out the transport when we are done with it

* Fix typo when setting ncp_state

* Fix typo with buffer truncation

* Fix unit test to account for retries after NCP failure
  • Loading branch information
puddly authored Sep 13, 2024
1 parent e160be2 commit 1a0b8a7
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 81 deletions.
68 changes: 45 additions & 23 deletions bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import binascii
from collections.abc import Coroutine
import contextlib
import dataclasses
import enum
import logging
Expand Down Expand Up @@ -62,7 +63,7 @@ class Reserved(enum.IntEnum):
# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
# due to timeouts.
ACK_TIMEOUTS = 4
ACK_TIMEOUTS = 5


def generate_random_sequence(length: int) -> bytes:
Expand Down Expand Up @@ -368,14 +369,26 @@ def connection_made(self, transport):
self._ezsp_protocol.connection_made(self)

def connection_lost(self, exc):
self._transport = None
self._cancel_pending_data_frames()
self._ezsp_protocol.connection_lost(exc)

def eof_received(self):
self._ezsp_protocol.eof_received()

def _cancel_pending_data_frames(
self, exc: BaseException = RuntimeError("Connection has been closed")
):
for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

def close(self):
self._cancel_pending_data_frames()

if self._transport is not None:
self._transport.close()
self._transport = None

@staticmethod
def _stuff_bytes(data: bytes) -> bytes:
Expand All @@ -399,7 +412,9 @@ def _unstuff_bytes(data: bytes) -> bytes:
for c in data:
if escaped:
byte = c ^ 0b00100000
assert byte in RESERVED_BYTES
if byte not in RESERVED_BYTES:
raise ParsingError(f"Invalid escaped byte: 0x{byte:02X}")

out.append(byte)
escaped = False
elif c == Reserved.ESCAPE:
Expand All @@ -417,7 +432,7 @@ def data_received(self, data: bytes) -> None:
_LOGGER.debug(
"Truncating buffer to %s bytes, it is growing too fast", MAX_BUFFER_SIZE
)
self._buffer = self._buffer[:MAX_BUFFER_SIZE]
self._buffer = self._buffer[-MAX_BUFFER_SIZE:]

while self._buffer:
if self._discarding_until_next_flag:
Expand Down Expand Up @@ -447,14 +462,19 @@ def data_received(self, data: bytes) -> None:
if not frame_bytes:
continue

data = self._unstuff_bytes(frame_bytes)

try:
data = self._unstuff_bytes(frame_bytes)
frame = parse_frame(data)
except Exception:
_LOGGER.debug(
"Failed to parse frame %r", frame_bytes, exc_info=True
)

with contextlib.suppress(NcpFailure):
self._write_frame(
NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq),
prefix=(Reserved.CANCEL,),
)
else:
self.frame_received(frame)
elif reserved_byte == Reserved.CANCEL:
Expand All @@ -479,7 +499,7 @@ def data_received(self, data: bytes) -> None:
f"Unexpected reserved byte found: 0x{reserved_byte:02X}"
) # pragma: no cover

def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
def _handle_ack(self, frame: DataFrame | AckFrame | NakFrame) -> None:
# Note that ackNum is the number of the next frame the receiver expects and it
# is one greater than the last frame received.
for ack_num_offset in range(-TX_K, 0):
Expand All @@ -494,14 +514,19 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
def frame_received(self, frame: AshFrame) -> None:
_LOGGER.debug("Received frame %r", frame)

# If a frame has ACK information (DATA, ACK, or NAK), it should be used even if
# the frame is out of sequence or invalid
if isinstance(frame, DataFrame):
self._handle_ack(frame)
self.data_frame_received(frame)
elif isinstance(frame, RStackFrame):
self.rstack_frame_received(frame)
elif isinstance(frame, AckFrame):
self._handle_ack(frame)
self.ack_frame_received(frame)
elif isinstance(frame, NakFrame):
self._handle_ack(frame)
self.nak_frame_received(frame)
elif isinstance(frame, RStackFrame):
self.rstack_frame_received(frame)
elif isinstance(frame, RstFrame):
self.rst_frame_received(frame)
elif isinstance(frame, ErrorFrame):
Expand All @@ -513,7 +538,6 @@ def data_frame_received(self, frame: DataFrame) -> None:
# The Host may not piggyback acknowledgments and should promptly send an ACK
# frame when it receives a DATA frame.
if frame.frm_num == self._rx_seq:
self._handle_ack(frame)
self._rx_seq = (frame.frm_num + 1) % 8
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

Expand All @@ -536,14 +560,10 @@ def rstack_frame_received(self, frame: RStackFrame) -> None:
self._ezsp_protocol.reset_received(frame.reset_code)

def ack_frame_received(self, frame: AckFrame) -> None:
self._handle_ack(frame)
pass

def nak_frame_received(self, frame: NakFrame) -> None:
err = NotAcked(frame=frame)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(err)
self._cancel_pending_data_frames(NotAcked(frame=frame))

def rst_frame_received(self, frame: RstFrame) -> None:
self._ncp_reset_code = None
Expand All @@ -558,12 +578,8 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
self._enter_failed_state(self._ncp_reset_code)

def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
exc = NcpFailure(code=reset_code)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

self._ncp_state = NcpState.FAILED
self._cancel_pending_data_frames(NcpFailure(code=reset_code))
self._ezsp_protocol.reset_received(reset_code)

def _write_frame(
Expand All @@ -573,6 +589,9 @@ def _write_frame(
prefix: tuple[Reserved] = (),
suffix: tuple[Reserved] = (Reserved.FLAG,),
) -> None:
if self._transport is None or self._transport.is_closing():
raise NcpFailure("Transport is closed, cannot send frame")

if _LOGGER.isEnabledFor(logging.DEBUG):
prefix_str = "".join([f"{r.name} + " for r in prefix])
suffix_str = "".join([f" + {r.name}" for r in suffix])
Expand Down Expand Up @@ -631,7 +650,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
await ack_future
except NotAcked:
_LOGGER.debug(
"NCP responded with NAK. Retrying (attempt %d)", attempt + 1
"NCP responded with NAK to %r. Retrying (attempt %d)",
frame,
attempt + 1,
)

# For timing purposes, NAK can be treated as an ACK
Expand All @@ -650,9 +671,10 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
raise
except asyncio.TimeoutError:
_LOGGER.debug(
"No ACK received in %0.2fs (attempt %d)",
"No ACK received in %0.2fs (attempt %d) for %r",
self._t_rx_ack,
attempt + 1,
frame,
)
# If a DATA frame acknowledgement is not received within the
# current timeout value, then t_rx_ack is doubled.
Expand Down
23 changes: 1 addition & 22 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from typing import Any, Callable, Generator
import urllib.parse

from zigpy.datastructures import PriorityDynamicBoundedSemaphore

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
else:
Expand Down Expand Up @@ -41,8 +39,6 @@
NETWORK_OPS_TIMEOUT = 10
NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1

MAX_COMMAND_CONCURRENCY = 1


class EZSP:
_BY_VERSION = {
Expand All @@ -66,7 +62,6 @@ def __init__(self, device_config: dict):
self._ezsp_version = v4.EZSPv4.VERSION
self._gw = None
self._protocol = None
self._send_sem = PriorityDynamicBoundedSemaphore(value=MAX_COMMAND_CONCURRENCY)

self._stack_status_listeners: collections.defaultdict[
t.sl_Status, list[asyncio.Future]
Expand Down Expand Up @@ -190,21 +185,6 @@ def close(self):
self._gw.close()
self._gw = None

def _get_command_priority(self, name: str) -> int:
return {
# Deprioritize any commands that send packets
"set_source_route": -1,
"setExtendedTimeout": -1,
"send_unicast": -1,
"send_multicast": -1,
"send_broadcast": -1,
# Prioritize watchdog commands
"nop": 999,
"readCounters": 999,
"readAndClearCounters": 999,
"getValue": 999,
}.get(name, 0)

async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
command = getattr(self._protocol, name)

Expand All @@ -217,8 +197,7 @@ async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
)
raise EzspError("EZSP is not running")

async with self._send_sem(priority=self._get_command_priority(name)):
return await command(*args, **kwargs)
return await command(*args, **kwargs)

async def _list_command(
self, name, item_frames, completion_frame, spos, *args, **kwargs
Expand Down
70 changes: 60 additions & 10 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import logging
import sys
import time
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable

import zigpy.state
Expand All @@ -15,6 +16,8 @@
else:
from asyncio import timeout as asyncio_timeout # pragma: no cover

from zigpy.datastructures import PriorityDynamicBoundedSemaphore

from bellows.config import CONF_EZSP_POLICIES
from bellows.exception import InvalidCommandError
import bellows.types as t
Expand All @@ -23,7 +26,9 @@
from bellows.uart import Gateway

LOGGER = logging.getLogger(__name__)
EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2

EZSP_CMD_TIMEOUT = 10
MAX_COMMAND_CONCURRENCY = 1


class ProtocolHandler(abc.ABC):
Expand All @@ -42,6 +47,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
for name, (cmd_id, tx_schema, rx_schema) in self.COMMANDS.items()
}
self.tc_policy = 0
self._send_semaphore = PriorityDynamicBoundedSemaphore(
value=MAX_COMMAND_CONCURRENCY
)

# Cached by `set_extended_timeout` so subsequent calls are a little faster
self._address_table_size: int | None = None
Expand All @@ -65,18 +73,60 @@ def _ezsp_frame_rx(self, data: bytes) -> tuple[int, int, bytes]:
def _ezsp_frame_tx(self, name: str) -> bytes:
"""Serialize the named frame."""

def _get_command_priority(self, name: str) -> int:
return {
# Deprioritize any commands that send packets
"setSourceRoute": -1,
"setExtendedTimeout": -1,
"sendUnicast": -1,
"sendMulticast": -1,
"sendBroadcast": -1,
# Prioritize watchdog commands
"nop": 999,
"readCounters": 999,
"readAndClearCounters": 999,
"getValue": 999,
}.get(name, 0)

async def command(self, name, *args, **kwargs) -> Any:
"""Serialize command and send it."""
LOGGER.debug("Sending command %s: %s %s", name, args, kwargs)
data = self._ezsp_frame(name, *args, **kwargs)
cmd_id, _, rx_schema = self.COMMANDS[name]
future = asyncio.get_running_loop().create_future()
self._awaiting[self._seq] = (cmd_id, rx_schema, future)
self._seq = (self._seq + 1) % 256

async with asyncio_timeout(EZSP_CMD_TIMEOUT):
delayed = False
send_time = None

if self._send_semaphore.locked():
delayed = True
send_time = time.monotonic()

LOGGER.debug(
"Send semaphore is locked, delaying before sending %s(%r, %r)",
name,
args,
kwargs,
)

async with self._send_semaphore(priority=self._get_command_priority(name)):
if delayed:
LOGGER.debug(
"Sending command %s: %s %s after %0.2fs delay",
name,
args,
kwargs,
time.monotonic() - send_time,
)
else:
LOGGER.debug("Sending command %s: %s %s", name, args, kwargs)

data = self._ezsp_frame(name, *args, **kwargs)
cmd_id, _, rx_schema = self.COMMANDS[name]

future = asyncio.get_running_loop().create_future()
self._awaiting[self._seq] = (cmd_id, rx_schema, future)
self._seq = (self._seq + 1) % 256

await self._gw.send_data(data)
return await future

async with asyncio_timeout(EZSP_CMD_TIMEOUT):
return await future

async def update_policies(self, policy_config: dict) -> None:
"""Set up the policies for what the NCP should do."""
Expand Down
15 changes: 0 additions & 15 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,6 @@


class Gateway(asyncio.Protocol):
FLAG = b"\x7E" # Marks end of frame
ESCAPE = b"\x7D"
XON = b"\x11" # Resume transmission
XOFF = b"\x13" # Stop transmission
SUBSTITUTE = b"\x18"
CANCEL = b"\x1A" # Terminates a frame in progress
STUFF = 0x20
RANDOMIZE_START = 0x42
RANDOMIZE_SEQ = 0xB8

RESERVED = FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL

class Terminator:
pass

def __init__(self, application, connected_future=None, connection_done_future=None):
self._application = application

Expand Down
Loading

0 comments on commit 1a0b8a7

Please sign in to comment.