diff --git a/pyproject.toml b/pyproject.toml index e322a90..ffb6361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,4 +60,12 @@ ignore = [ "W503", "E203", "E501", "D202", "D103", "D102", "D101", # TODO: remove these once docstrings are added ] -per-file-ignores = ["tests/*:F811,F401,F403"] \ No newline at end of file +per-file-ignores = ["tests/*:F811,F401,F403"] + +[tool.coverage.report] +exclude_also = [ + "raise AssertionError", + "raise NotImplementedError", + "if (typing\\.)?TYPE_CHECKING:", + "@(abc\\.)?abstractmethod", +] diff --git a/tests/test_api.py b/tests/test_api.py index 2edfe10..048efbc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,6 +10,11 @@ import zigpy.config import zigpy.types as zigpy_t +if sys.version_info[:2] < (3, 11): + from async_timeout import timeout as asyncio_timeout +else: + from asyncio import timeout as asyncio_timeout + from zigpy_deconz import api as deconz_api, types as t, uart import zigpy_deconz.exception import zigpy_deconz.zigbee.application @@ -86,7 +91,7 @@ async def mock_connect(config, api): @pytest.fixture async def mock_command_rsp(gateway): - def inner(command_id, params, rsp, *, replace=False): + def inner(command_id, params, rsp, *, rsp_command=None, replace=False): if ( getattr(getattr(gateway.send, "side_effect", None), "_handlers", None) is None @@ -107,15 +112,18 @@ def receiver(data): kwargs, rest = t.deserialize_dict(command.payload, schema) - for params, mock in receiver._handlers[command.command_id]: + for params, rsp_command, mock in receiver._handlers[command.command_id]: + if rsp_command is None: + rsp_command = command.command_id + if all(kwargs[k] == v for k, v in params.items()): - _, rx_schema = deconz_api.COMMAND_SCHEMAS[command.command_id] + _, rx_schema = deconz_api.COMMAND_SCHEMAS[rsp_command] ret = mock(**kwargs) asyncio.get_running_loop().call_soon( gateway._api.data_received, deconz_api.Command( - command_id=command.command_id, + command_id=rsp_command, seq=command.seq, payload=t.serialize_dict(ret, rx_schema), ).serialize(), @@ -128,7 +136,9 @@ def receiver(data): gateway.send.side_effect._handlers[command_id].clear() mock = MagicMock(return_value=rsp) - gateway.send.side_effect._handlers[command_id].append((params, mock)) + gateway.send.side_effect._handlers[command_id].append( + (params, rsp_command, mock) + ) return mock @@ -993,3 +1003,31 @@ async def test_cb3_device_state_callback_bug(api, mock_command_rsp): await asyncio.sleep(0.01) assert api._device_state == device_state + + +async def test_firmware_responding_with_wrong_type_with_correct_seq( + api, mock_command_rsp, caplog +): + await api.connect() + + mock_command_rsp( + command_id=deconz_api.CommandId.aps_data_confirm, + params={}, + # Completely different response + rsp_command=deconz_api.CommandId.version, + rsp={ + "status": deconz_api.Status.SUCCESS, + "frame_length": t.uint16_t(9), + "version": deconz_api.FirmwareVersion(0x26450900), + }, + ) + + with caplog.at_level(logging.DEBUG): + with pytest.raises(asyncio.TimeoutError): + async with asyncio_timeout(0.5): + await api.send_command(deconz_api.CommandId.aps_data_confirm) + + assert ( + "Firmware responded incorrectly (Response is mismatched! Sent" + " , received ), retrying" + ) in caplog.text diff --git a/zigpy_deconz/api.py b/zigpy_deconz/api.py index 226bfe1..ccd092c 100644 --- a/zigpy_deconz/api.py +++ b/zigpy_deconz/api.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import collections import itertools import logging import sys @@ -25,7 +26,7 @@ ) from zigpy.zdo.types import SimpleDescriptor -from zigpy_deconz.exception import APIException, CommandError +from zigpy_deconz.exception import APIException, CommandError, MismatchedResponseError import zigpy_deconz.types as t import zigpy_deconz.uart from zigpy_deconz.utils import restart_forever @@ -415,7 +416,9 @@ class Deconz: def __init__(self, app: Callable, device_config: dict[str, Any]): """Init instance.""" self._app = app - self._awaiting = {} + + # [seq][cmd_id] = [fut1, fut2, ...] + self._awaiting = collections.defaultdict(lambda: collections.defaultdict(list)) self._command_lock = asyncio.Lock() self._config = device_config self._device_state = DeviceState( @@ -459,7 +462,7 @@ async def connect(self) -> None: await self.version() - device_state_rsp = await self._command(CommandId.device_state) + device_state_rsp = await self.send_command(CommandId.device_state) self._device_state = device_state_rsp["device_state"] self._data_poller_task = asyncio.create_task(self._data_poller()) @@ -486,6 +489,13 @@ def close(self): self._uart.close() self._uart = None + async def send_command(self, cmd, **kwargs) -> Any: + while True: + try: + return await self._command(cmd, **kwargs) + except MismatchedResponseError as exc: + LOGGER.debug("Firmware responded incorrectly (%s), retrying", exc) + async def _command(self, cmd, **kwargs): payload = [] tx_schema, _ = COMMAND_SCHEMAS[cmd] @@ -556,17 +566,16 @@ async def _command(self, cmd, **kwargs): self._seq = (self._seq % 255) + 1 fut = asyncio.Future() - self._awaiting[seq, cmd] = fut + self._awaiting[seq][cmd].append(fut) try: async with asyncio_timeout(COMMAND_TIMEOUT): return await fut except asyncio.TimeoutError: - LOGGER.warning( - "No response to '%s' command with seq id '0x%02x'", cmd, seq - ) - self._awaiting.pop((seq, cmd), None) + LOGGER.debug("No response to '%s' command with seq %d", cmd, seq) raise + finally: + self._awaiting[seq][cmd].remove(fut) def data_received(self, data: bytes) -> None: command, _ = Command.deserialize(data) @@ -577,7 +586,19 @@ def data_received(self, data: bytes) -> None: _, rx_schema = COMMAND_SCHEMAS[command.command_id] - fut = self._awaiting.pop((command.seq, command.command_id), None) + fut = None + wrong_fut_cmd_id = None + + try: + fut = self._awaiting[command.seq][command.command_id][0] + except IndexError: + # XXX: The firmware can sometimes respond with the wrong response. Find the + # future associated with it so we can throw an appropriate error. + for cmd_id, futs in self._awaiting[command.seq].items(): + if futs: + fut = futs[0] + wrong_fut_cmd_id = cmd_id + break try: params, rest = t.deserialize_dict(command.payload, rx_schema) @@ -614,7 +635,16 @@ def data_received(self, data: bytes) -> None: exc = None - if status != Status.SUCCESS: + if wrong_fut_cmd_id is not None: + exc = MismatchedResponseError( + command.command_id, + params, + ( + f"Response is mismatched! Sent {wrong_fut_cmd_id}," + f" received {command.command_id}" + ), + ) + elif status != Status.SUCCESS: exc = CommandError(status, f"{command.command_id}, status: {status}") if fut is not None: @@ -665,7 +695,9 @@ async def _data_poller(self): else: flags = t.DataIndicationFlags.Always_Use_NWK_Source_Addr - rsp = await self._command(CommandId.aps_data_indication, flags=flags) + rsp = await self.send_command( + CommandId.aps_data_indication, flags=flags + ) self._handle_device_state_changed( status=rsp["status"], device_state=rsp["device_state"] ) @@ -687,7 +719,7 @@ async def _data_poller(self): # Poll data confirm if DeviceStateFlags.APSDE_DATA_CONFIRM in self._device_state.device_state: - rsp = await self._command(CommandId.aps_data_confirm) + rsp = await self.send_command(CommandId.aps_data_confirm) self._app.handle_tx_confirm(rsp["request_id"], rsp["confirm_status"]) self._handle_device_state_changed( @@ -738,7 +770,7 @@ async def version(self): NetworkParameter.protocol_version ) - version_rsp = await self._command(CommandId.version, reserved=0) + version_rsp = await self.send_command(CommandId.version, reserved=0) self._firmware_version = version_rsp["version"] return self.firmware_version @@ -753,7 +785,7 @@ async def read_parameter( else: value = read_param_type(parameter).serialize() - rsp = await self._command( + rsp = await self.send_command( CommandId.read_parameter, parameter_id=parameter_id, parameter=value, @@ -770,7 +802,7 @@ async def write_parameter( self, parameter_id: NetworkParameter, parameter: Any ) -> None: read_param_type, write_param_type = NETWORK_PARAMETER_TYPES[parameter_id] - await self._command( + await self.send_command( CommandId.write_parameter, parameter_id=parameter_id, parameter=write_param_type(parameter).serialize(), @@ -803,7 +835,7 @@ async def aps_data_request( await self._free_slots_available_event.wait() try: - rsp = await self._command( + rsp = await self.send_command( CommandId.aps_data_request, request_id=req_id, flags=flags, @@ -830,17 +862,17 @@ async def aps_data_request( return async def get_device_state(self) -> DeviceState: - rsp = await self._command(CommandId.device_state) + rsp = await self.send_command(CommandId.device_state) return rsp["device_state"] async def change_network_state(self, new_state: NetworkState) -> None: - await self._command(CommandId.change_network_state, network_state=new_state) + await self.send_command(CommandId.change_network_state, network_state=new_state) async def add_neighbour( self, nwk: t.NWK, ieee: t.EUI64, mac_capability_flags: t.uint8_t ) -> None: - await self._command( + await self.send_command( CommandId.update_neighbor, action=UpdateNeighborAction.ADD, nwk=nwk, diff --git a/zigpy_deconz/exception.py b/zigpy_deconz/exception.py index db2d18d..82ab450 100644 --- a/zigpy_deconz/exception.py +++ b/zigpy_deconz/exception.py @@ -1,7 +1,14 @@ """Zigpy-deconz exceptions.""" +from __future__ import annotations + +import typing + from zigpy.exceptions import APIException +if typing.TYPE_CHECKING: + from zigpy_deconz.api import CommandId + class CommandError(APIException): def __init__(self, status=1, *args, **kwargs): @@ -12,3 +19,13 @@ def __init__(self, status=1, *args, **kwargs): @property def status(self): return self._status + + +class MismatchedResponseError(APIException): + def __init__( + self, command_id: CommandId, params: dict[str, typing.Any], *args, **kwargs + ) -> None: + """Initialize instance.""" + super().__init__(*args, **kwargs) + self.command_id = command_id + self.params = params