Skip to content

Commit

Permalink
Detect when the firmware responds with the wrong response (#241)
Browse files Browse the repository at this point in the history
* Handle mismatched `seq` commands

* Globally retry all commands in case of `MismatchedResponseError`

* Add a unit test and make sure command futures are always removed

* Exclude `if TYPE_CHECKING` from coverage
  • Loading branch information
puddly authored Dec 18, 2023
1 parent cc1f018 commit bc7e960
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 25 deletions.
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,12 @@ ignore = [
"W503", "E203", "E501", "D202",
"D103", "D102", "D101", # TODO: remove these once docstrings are added
]
per-file-ignores = ["tests/*:F811,F401,F403"]
per-file-ignores = ["tests/*:F811,F401,F403"]

[tool.coverage.report]
exclude_also = [
"raise AssertionError",
"raise NotImplementedError",
"if (typing\\.)?TYPE_CHECKING:",
"@(abc\\.)?abstractmethod",
]
48 changes: 43 additions & 5 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
import zigpy.config
import zigpy.types as zigpy_t

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout
else:
from asyncio import timeout as asyncio_timeout

from zigpy_deconz import api as deconz_api, types as t, uart
import zigpy_deconz.exception
import zigpy_deconz.zigbee.application
Expand Down Expand Up @@ -86,7 +91,7 @@ async def mock_connect(config, api):

@pytest.fixture
async def mock_command_rsp(gateway):
def inner(command_id, params, rsp, *, replace=False):
def inner(command_id, params, rsp, *, rsp_command=None, replace=False):
if (
getattr(getattr(gateway.send, "side_effect", None), "_handlers", None)
is None
Expand All @@ -107,15 +112,18 @@ def receiver(data):

kwargs, rest = t.deserialize_dict(command.payload, schema)

for params, mock in receiver._handlers[command.command_id]:
for params, rsp_command, mock in receiver._handlers[command.command_id]:
if rsp_command is None:
rsp_command = command.command_id

if all(kwargs[k] == v for k, v in params.items()):
_, rx_schema = deconz_api.COMMAND_SCHEMAS[command.command_id]
_, rx_schema = deconz_api.COMMAND_SCHEMAS[rsp_command]
ret = mock(**kwargs)

asyncio.get_running_loop().call_soon(
gateway._api.data_received,
deconz_api.Command(
command_id=command.command_id,
command_id=rsp_command,
seq=command.seq,
payload=t.serialize_dict(ret, rx_schema),
).serialize(),
Expand All @@ -128,7 +136,9 @@ def receiver(data):
gateway.send.side_effect._handlers[command_id].clear()

mock = MagicMock(return_value=rsp)
gateway.send.side_effect._handlers[command_id].append((params, mock))
gateway.send.side_effect._handlers[command_id].append(
(params, rsp_command, mock)
)

return mock

Expand Down Expand Up @@ -993,3 +1003,31 @@ async def test_cb3_device_state_callback_bug(api, mock_command_rsp):
await asyncio.sleep(0.01)

assert api._device_state == device_state


async def test_firmware_responding_with_wrong_type_with_correct_seq(
api, mock_command_rsp, caplog
):
await api.connect()

mock_command_rsp(
command_id=deconz_api.CommandId.aps_data_confirm,
params={},
# Completely different response
rsp_command=deconz_api.CommandId.version,
rsp={
"status": deconz_api.Status.SUCCESS,
"frame_length": t.uint16_t(9),
"version": deconz_api.FirmwareVersion(0x26450900),
},
)

with caplog.at_level(logging.DEBUG):
with pytest.raises(asyncio.TimeoutError):
async with asyncio_timeout(0.5):
await api.send_command(deconz_api.CommandId.aps_data_confirm)

assert (
"Firmware responded incorrectly (Response is mismatched! Sent"
" <CommandId.aps_data_confirm: 4>, received <CommandId.version: 13>), retrying"
) in caplog.text
70 changes: 51 additions & 19 deletions zigpy_deconz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import collections
import itertools
import logging
import sys
Expand All @@ -25,7 +26,7 @@
)
from zigpy.zdo.types import SimpleDescriptor

from zigpy_deconz.exception import APIException, CommandError
from zigpy_deconz.exception import APIException, CommandError, MismatchedResponseError
import zigpy_deconz.types as t
import zigpy_deconz.uart
from zigpy_deconz.utils import restart_forever
Expand Down Expand Up @@ -415,7 +416,9 @@ class Deconz:
def __init__(self, app: Callable, device_config: dict[str, Any]):
"""Init instance."""
self._app = app
self._awaiting = {}

# [seq][cmd_id] = [fut1, fut2, ...]
self._awaiting = collections.defaultdict(lambda: collections.defaultdict(list))
self._command_lock = asyncio.Lock()
self._config = device_config
self._device_state = DeviceState(
Expand Down Expand Up @@ -459,7 +462,7 @@ async def connect(self) -> None:

await self.version()

device_state_rsp = await self._command(CommandId.device_state)
device_state_rsp = await self.send_command(CommandId.device_state)
self._device_state = device_state_rsp["device_state"]

self._data_poller_task = asyncio.create_task(self._data_poller())
Expand All @@ -486,6 +489,13 @@ def close(self):
self._uart.close()
self._uart = None

async def send_command(self, cmd, **kwargs) -> Any:
while True:
try:
return await self._command(cmd, **kwargs)
except MismatchedResponseError as exc:
LOGGER.debug("Firmware responded incorrectly (%s), retrying", exc)

async def _command(self, cmd, **kwargs):
payload = []
tx_schema, _ = COMMAND_SCHEMAS[cmd]
Expand Down Expand Up @@ -556,17 +566,16 @@ async def _command(self, cmd, **kwargs):
self._seq = (self._seq % 255) + 1

fut = asyncio.Future()
self._awaiting[seq, cmd] = fut
self._awaiting[seq][cmd].append(fut)

try:
async with asyncio_timeout(COMMAND_TIMEOUT):
return await fut
except asyncio.TimeoutError:
LOGGER.warning(
"No response to '%s' command with seq id '0x%02x'", cmd, seq
)
self._awaiting.pop((seq, cmd), None)
LOGGER.debug("No response to '%s' command with seq %d", cmd, seq)
raise
finally:
self._awaiting[seq][cmd].remove(fut)

def data_received(self, data: bytes) -> None:
command, _ = Command.deserialize(data)
Expand All @@ -577,7 +586,19 @@ def data_received(self, data: bytes) -> None:

_, rx_schema = COMMAND_SCHEMAS[command.command_id]

fut = self._awaiting.pop((command.seq, command.command_id), None)
fut = None
wrong_fut_cmd_id = None

try:
fut = self._awaiting[command.seq][command.command_id][0]
except IndexError:
# XXX: The firmware can sometimes respond with the wrong response. Find the
# future associated with it so we can throw an appropriate error.
for cmd_id, futs in self._awaiting[command.seq].items():
if futs:
fut = futs[0]
wrong_fut_cmd_id = cmd_id
break

try:
params, rest = t.deserialize_dict(command.payload, rx_schema)
Expand Down Expand Up @@ -614,7 +635,16 @@ def data_received(self, data: bytes) -> None:

exc = None

if status != Status.SUCCESS:
if wrong_fut_cmd_id is not None:
exc = MismatchedResponseError(
command.command_id,
params,
(
f"Response is mismatched! Sent {wrong_fut_cmd_id},"
f" received {command.command_id}"
),
)
elif status != Status.SUCCESS:
exc = CommandError(status, f"{command.command_id}, status: {status}")

if fut is not None:
Expand Down Expand Up @@ -665,7 +695,9 @@ async def _data_poller(self):
else:
flags = t.DataIndicationFlags.Always_Use_NWK_Source_Addr

rsp = await self._command(CommandId.aps_data_indication, flags=flags)
rsp = await self.send_command(
CommandId.aps_data_indication, flags=flags
)
self._handle_device_state_changed(
status=rsp["status"], device_state=rsp["device_state"]
)
Expand All @@ -687,7 +719,7 @@ async def _data_poller(self):

# Poll data confirm
if DeviceStateFlags.APSDE_DATA_CONFIRM in self._device_state.device_state:
rsp = await self._command(CommandId.aps_data_confirm)
rsp = await self.send_command(CommandId.aps_data_confirm)

self._app.handle_tx_confirm(rsp["request_id"], rsp["confirm_status"])
self._handle_device_state_changed(
Expand Down Expand Up @@ -738,7 +770,7 @@ async def version(self):
NetworkParameter.protocol_version
)

version_rsp = await self._command(CommandId.version, reserved=0)
version_rsp = await self.send_command(CommandId.version, reserved=0)
self._firmware_version = version_rsp["version"]

return self.firmware_version
Expand All @@ -753,7 +785,7 @@ async def read_parameter(
else:
value = read_param_type(parameter).serialize()

rsp = await self._command(
rsp = await self.send_command(
CommandId.read_parameter,
parameter_id=parameter_id,
parameter=value,
Expand All @@ -770,7 +802,7 @@ async def write_parameter(
self, parameter_id: NetworkParameter, parameter: Any
) -> None:
read_param_type, write_param_type = NETWORK_PARAMETER_TYPES[parameter_id]
await self._command(
await self.send_command(
CommandId.write_parameter,
parameter_id=parameter_id,
parameter=write_param_type(parameter).serialize(),
Expand Down Expand Up @@ -803,7 +835,7 @@ async def aps_data_request(
await self._free_slots_available_event.wait()

try:
rsp = await self._command(
rsp = await self.send_command(
CommandId.aps_data_request,
request_id=req_id,
flags=flags,
Expand All @@ -830,17 +862,17 @@ async def aps_data_request(
return

async def get_device_state(self) -> DeviceState:
rsp = await self._command(CommandId.device_state)
rsp = await self.send_command(CommandId.device_state)

return rsp["device_state"]

async def change_network_state(self, new_state: NetworkState) -> None:
await self._command(CommandId.change_network_state, network_state=new_state)
await self.send_command(CommandId.change_network_state, network_state=new_state)

async def add_neighbour(
self, nwk: t.NWK, ieee: t.EUI64, mac_capability_flags: t.uint8_t
) -> None:
await self._command(
await self.send_command(
CommandId.update_neighbor,
action=UpdateNeighborAction.ADD,
nwk=nwk,
Expand Down
17 changes: 17 additions & 0 deletions zigpy_deconz/exception.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
"""Zigpy-deconz exceptions."""

from __future__ import annotations

import typing

from zigpy.exceptions import APIException

if typing.TYPE_CHECKING:
from zigpy_deconz.api import CommandId


class CommandError(APIException):
def __init__(self, status=1, *args, **kwargs):
Expand All @@ -12,3 +19,13 @@ def __init__(self, status=1, *args, **kwargs):
@property
def status(self):
return self._status


class MismatchedResponseError(APIException):
def __init__(
self, command_id: CommandId, params: dict[str, typing.Any], *args, **kwargs
) -> None:
"""Initialize instance."""
super().__init__(*args, **kwargs)
self.command_id = command_id
self.params = params

0 comments on commit bc7e960

Please sign in to comment.