Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect when the firmware responds with the wrong response #241

Merged
merged 4 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,12 @@ ignore = [
"W503", "E203", "E501", "D202",
"D103", "D102", "D101", # TODO: remove these once docstrings are added
]
per-file-ignores = ["tests/*:F811,F401,F403"]
per-file-ignores = ["tests/*:F811,F401,F403"]

[tool.coverage.report]
exclude_also = [
"raise AssertionError",
"raise NotImplementedError",
"if (typing\\.)?TYPE_CHECKING:",
"@(abc\\.)?abstractmethod",
]
48 changes: 43 additions & 5 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
import zigpy.config
import zigpy.types as zigpy_t

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout
else:
from asyncio import timeout as asyncio_timeout

from zigpy_deconz import api as deconz_api, types as t, uart
import zigpy_deconz.exception
import zigpy_deconz.zigbee.application
Expand Down Expand Up @@ -86,7 +91,7 @@ async def mock_connect(config, api):

@pytest.fixture
async def mock_command_rsp(gateway):
def inner(command_id, params, rsp, *, replace=False):
def inner(command_id, params, rsp, *, rsp_command=None, replace=False):
if (
getattr(getattr(gateway.send, "side_effect", None), "_handlers", None)
is None
Expand All @@ -107,15 +112,18 @@ def receiver(data):

kwargs, rest = t.deserialize_dict(command.payload, schema)

for params, mock in receiver._handlers[command.command_id]:
for params, rsp_command, mock in receiver._handlers[command.command_id]:
if rsp_command is None:
rsp_command = command.command_id

if all(kwargs[k] == v for k, v in params.items()):
_, rx_schema = deconz_api.COMMAND_SCHEMAS[command.command_id]
_, rx_schema = deconz_api.COMMAND_SCHEMAS[rsp_command]
ret = mock(**kwargs)

asyncio.get_running_loop().call_soon(
gateway._api.data_received,
deconz_api.Command(
command_id=command.command_id,
command_id=rsp_command,
seq=command.seq,
payload=t.serialize_dict(ret, rx_schema),
).serialize(),
Expand All @@ -128,7 +136,9 @@ def receiver(data):
gateway.send.side_effect._handlers[command_id].clear()

mock = MagicMock(return_value=rsp)
gateway.send.side_effect._handlers[command_id].append((params, mock))
gateway.send.side_effect._handlers[command_id].append(
(params, rsp_command, mock)
)

return mock

Expand Down Expand Up @@ -993,3 +1003,31 @@ async def test_cb3_device_state_callback_bug(api, mock_command_rsp):
await asyncio.sleep(0.01)

assert api._device_state == device_state


async def test_firmware_responding_with_wrong_type_with_correct_seq(
api, mock_command_rsp, caplog
):
await api.connect()

mock_command_rsp(
command_id=deconz_api.CommandId.aps_data_confirm,
params={},
# Completely different response
rsp_command=deconz_api.CommandId.version,
rsp={
"status": deconz_api.Status.SUCCESS,
"frame_length": t.uint16_t(9),
"version": deconz_api.FirmwareVersion(0x26450900),
},
)

with caplog.at_level(logging.DEBUG):
with pytest.raises(asyncio.TimeoutError):
async with asyncio_timeout(0.5):
await api.send_command(deconz_api.CommandId.aps_data_confirm)

assert (
"Firmware responded incorrectly (Response is mismatched! Sent"
" <CommandId.aps_data_confirm: 4>, received <CommandId.version: 13>), retrying"
) in caplog.text
70 changes: 51 additions & 19 deletions zigpy_deconz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import collections
import itertools
import logging
import sys
Expand All @@ -25,7 +26,7 @@
)
from zigpy.zdo.types import SimpleDescriptor

from zigpy_deconz.exception import APIException, CommandError
from zigpy_deconz.exception import APIException, CommandError, MismatchedResponseError
import zigpy_deconz.types as t
import zigpy_deconz.uart
from zigpy_deconz.utils import restart_forever
Expand Down Expand Up @@ -415,7 +416,9 @@ class Deconz:
def __init__(self, app: Callable, device_config: dict[str, Any]):
"""Init instance."""
self._app = app
self._awaiting = {}

# [seq][cmd_id] = [fut1, fut2, ...]
self._awaiting = collections.defaultdict(lambda: collections.defaultdict(list))
self._command_lock = asyncio.Lock()
self._config = device_config
self._device_state = DeviceState(
Expand Down Expand Up @@ -459,7 +462,7 @@ async def connect(self) -> None:

await self.version()

device_state_rsp = await self._command(CommandId.device_state)
device_state_rsp = await self.send_command(CommandId.device_state)
self._device_state = device_state_rsp["device_state"]

self._data_poller_task = asyncio.create_task(self._data_poller())
Expand All @@ -486,6 +489,13 @@ def close(self):
self._uart.close()
self._uart = None

async def send_command(self, cmd, **kwargs) -> Any:
while True:
try:
return await self._command(cmd, **kwargs)
except MismatchedResponseError as exc:
LOGGER.debug("Firmware responded incorrectly (%s), retrying", exc)

async def _command(self, cmd, **kwargs):
payload = []
tx_schema, _ = COMMAND_SCHEMAS[cmd]
Expand Down Expand Up @@ -556,17 +566,16 @@ async def _command(self, cmd, **kwargs):
self._seq = (self._seq % 255) + 1

fut = asyncio.Future()
self._awaiting[seq, cmd] = fut
self._awaiting[seq][cmd].append(fut)

try:
async with asyncio_timeout(COMMAND_TIMEOUT):
return await fut
except asyncio.TimeoutError:
LOGGER.warning(
"No response to '%s' command with seq id '0x%02x'", cmd, seq
)
self._awaiting.pop((seq, cmd), None)
LOGGER.debug("No response to '%s' command with seq %d", cmd, seq)
raise
finally:
self._awaiting[seq][cmd].remove(fut)

def data_received(self, data: bytes) -> None:
command, _ = Command.deserialize(data)
Expand All @@ -577,7 +586,19 @@ def data_received(self, data: bytes) -> None:

_, rx_schema = COMMAND_SCHEMAS[command.command_id]

fut = self._awaiting.pop((command.seq, command.command_id), None)
fut = None
wrong_fut_cmd_id = None

try:
fut = self._awaiting[command.seq][command.command_id][0]
except IndexError:
# XXX: The firmware can sometimes respond with the wrong response. Find the
# future associated with it so we can throw an appropriate error.
for cmd_id, futs in self._awaiting[command.seq].items():
if futs:
fut = futs[0]
wrong_fut_cmd_id = cmd_id
break

try:
params, rest = t.deserialize_dict(command.payload, rx_schema)
Expand Down Expand Up @@ -614,7 +635,16 @@ def data_received(self, data: bytes) -> None:

exc = None

if status != Status.SUCCESS:
if wrong_fut_cmd_id is not None:
exc = MismatchedResponseError(
command.command_id,
params,
(
f"Response is mismatched! Sent {wrong_fut_cmd_id},"
f" received {command.command_id}"
),
)
elif status != Status.SUCCESS:
exc = CommandError(status, f"{command.command_id}, status: {status}")

if fut is not None:
Expand Down Expand Up @@ -665,7 +695,9 @@ async def _data_poller(self):
else:
flags = t.DataIndicationFlags.Always_Use_NWK_Source_Addr

rsp = await self._command(CommandId.aps_data_indication, flags=flags)
rsp = await self.send_command(
CommandId.aps_data_indication, flags=flags
)
self._handle_device_state_changed(
status=rsp["status"], device_state=rsp["device_state"]
)
Expand All @@ -687,7 +719,7 @@ async def _data_poller(self):

# Poll data confirm
if DeviceStateFlags.APSDE_DATA_CONFIRM in self._device_state.device_state:
rsp = await self._command(CommandId.aps_data_confirm)
rsp = await self.send_command(CommandId.aps_data_confirm)

self._app.handle_tx_confirm(rsp["request_id"], rsp["confirm_status"])
self._handle_device_state_changed(
Expand Down Expand Up @@ -738,7 +770,7 @@ async def version(self):
NetworkParameter.protocol_version
)

version_rsp = await self._command(CommandId.version, reserved=0)
version_rsp = await self.send_command(CommandId.version, reserved=0)
self._firmware_version = version_rsp["version"]

return self.firmware_version
Expand All @@ -753,7 +785,7 @@ async def read_parameter(
else:
value = read_param_type(parameter).serialize()

rsp = await self._command(
rsp = await self.send_command(
CommandId.read_parameter,
parameter_id=parameter_id,
parameter=value,
Expand All @@ -770,7 +802,7 @@ async def write_parameter(
self, parameter_id: NetworkParameter, parameter: Any
) -> None:
read_param_type, write_param_type = NETWORK_PARAMETER_TYPES[parameter_id]
await self._command(
await self.send_command(
CommandId.write_parameter,
parameter_id=parameter_id,
parameter=write_param_type(parameter).serialize(),
Expand Down Expand Up @@ -803,7 +835,7 @@ async def aps_data_request(
await self._free_slots_available_event.wait()

try:
rsp = await self._command(
rsp = await self.send_command(
CommandId.aps_data_request,
request_id=req_id,
flags=flags,
Expand All @@ -830,17 +862,17 @@ async def aps_data_request(
return

async def get_device_state(self) -> DeviceState:
rsp = await self._command(CommandId.device_state)
rsp = await self.send_command(CommandId.device_state)

return rsp["device_state"]

async def change_network_state(self, new_state: NetworkState) -> None:
await self._command(CommandId.change_network_state, network_state=new_state)
await self.send_command(CommandId.change_network_state, network_state=new_state)

async def add_neighbour(
self, nwk: t.NWK, ieee: t.EUI64, mac_capability_flags: t.uint8_t
) -> None:
await self._command(
await self.send_command(
CommandId.update_neighbor,
action=UpdateNeighborAction.ADD,
nwk=nwk,
Expand Down
17 changes: 17 additions & 0 deletions zigpy_deconz/exception.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
"""Zigpy-deconz exceptions."""

from __future__ import annotations

import typing

from zigpy.exceptions import APIException

if typing.TYPE_CHECKING:
from zigpy_deconz.api import CommandId

Check warning on line 10 in zigpy_deconz/exception.py

View check run for this annotation

Codecov / codecov/patch

zigpy_deconz/exception.py#L10

Added line #L10 was not covered by tests


class CommandError(APIException):
def __init__(self, status=1, *args, **kwargs):
Expand All @@ -12,3 +19,13 @@
@property
def status(self):
return self._status


class MismatchedResponseError(APIException):
def __init__(
self, command_id: CommandId, params: dict[str, typing.Any], *args, **kwargs
) -> None:
"""Initialize instance."""
super().__init__(*args, **kwargs)
self.command_id = command_id
self.params = params
Loading