Skip to content

Commit

Permalink
Cleanly shut down the serial port on disconnect (#259)
Browse files Browse the repository at this point in the history
* Cleanly shut down the serial port on disconnect

* Send `connection_lost` even if we do not have an open serial connection

* Call `super().close()` in `SerialProtocol`

* Use `self._transport.write` instead of `send_data`

* Let zigpy handle flow control

* Bump minimum zigpy version

* Fix unit tests

* Make `api` an async fixture to grab reference to loop early

* Set default pytest-asyncio fixture loop scope

* Fix unit test failing due to event loop caching issue in pytest-asyncio

* Bring test coverage up
  • Loading branch information
puddly authored Oct 27, 2024
1 parent aa26bbd commit 70910bc
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 66 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"}
requires-python = ">=3.8"
dependencies = [
"voluptuous",
"zigpy>=0.68.0",
"zigpy>=0.70.0",
'async-timeout; python_version<"3.11"',
]

Expand Down Expand Up @@ -47,6 +47,7 @@ ignore_errors = true

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.flake8]
exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"]
Expand Down
44 changes: 34 additions & 10 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,23 @@


@pytest.fixture
def gateway():
async def gateway():
return uart.Gateway(api=None)


@pytest.fixture
def api(gateway, mock_command_rsp):
async def api(gateway, mock_command_rsp):
loop = asyncio.get_running_loop()

async def mock_connect(config, api):
transport = MagicMock()
transport.close = MagicMock(
side_effect=lambda: loop.call_soon(gateway.connection_lost, None)
)

gateway._api = api
gateway.connection_made(MagicMock())
gateway.connection_made(transport)

return gateway

with patch("zigpy_deconz.uart.connect", side_effect=mock_connect):
Expand Down Expand Up @@ -178,15 +186,33 @@ async def test_connect(api, mock_command_rsp):
await api.connect()


async def test_connect_failure(api, mock_command_rsp):
transport = None

def mock_version(*args, **kwargs):
nonlocal transport
transport = api._uart._transport

raise asyncio.TimeoutError()

with patch.object(api, "version", side_effect=mock_version):
# We connect but fail to probe
with pytest.raises(asyncio.TimeoutError):
await api.connect()

assert api._uart is None
assert len(transport.close.mock_calls) == 1


async def test_close(api):
await api.connect()

uart = api._uart
uart.close = MagicMock(wraps=uart.close)
uart.disconnect = AsyncMock()

api.close()
await api.disconnect()
assert api._uart is None
assert uart.close.call_count == 1
assert uart.disconnect.call_count == 1


def test_commands():
Expand Down Expand Up @@ -898,11 +924,9 @@ async def test_data_poller(api, mock_command_rsp):

# The task is cancelled on close
task = api._data_poller_task
api.close()
await api.disconnect()
assert api._data_poller_task is None

if sys.version_info >= (3, 11):
assert task.cancelling()
assert task.done()


async def test_get_device_state(api, mock_command_rsp):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ async def test_connect_failure(app):
with patch.object(application, "Deconz") as api_mock:
api = api_mock.return_value = MagicMock()
api.connect = AsyncMock(side_effect=RuntimeError("Broken"))
api.disconnect = AsyncMock()

app._api = None

Expand All @@ -195,16 +196,16 @@ async def test_connect_failure(app):

assert app._api is None
api.connect.assert_called_once()
api.close.assert_called_once()
api.disconnect.assert_called_once()


async def test_disconnect(app):
api_close = app._api.close = MagicMock()
api_disconnect = app._api.disconnect = AsyncMock()

await app.disconnect()

assert app._api is None
assert api_close.call_count == 1
assert api_disconnect.call_count == 1


async def test_disconnect_no_api(app):
Expand Down
13 changes: 11 additions & 2 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from unittest import mock

import pytest
from zigpy.config import CONF_DEVICE_BAUDRATE, CONF_DEVICE_PATH
from zigpy.config import (
CONF_DEVICE_BAUDRATE,
CONF_DEVICE_FLOW_CONTROL,
CONF_DEVICE_PATH,
)
import zigpy.serial

from zigpy_deconz import uart
Expand All @@ -28,7 +32,12 @@ async def mock_conn(loop, protocol_factory, **kwargs):
monkeypatch.setattr(zigpy.serial, "create_serial_connection", mock_conn)

await uart.connect(
{CONF_DEVICE_PATH: "/dev/null", CONF_DEVICE_BAUDRATE: 115200}, api
{
CONF_DEVICE_PATH: "/dev/null",
CONF_DEVICE_BAUDRATE: 115200,
CONF_DEVICE_FLOW_CONTROL: None,
},
api,
)


Expand Down
27 changes: 13 additions & 14 deletions zigpy_deconz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
else:
from asyncio import timeout as asyncio_timeout # pragma: no cover

from zigpy.config import CONF_DEVICE_PATH
from zigpy.datastructures import PriorityLock
from zigpy.types import (
APSStatus,
Expand Down Expand Up @@ -461,37 +460,37 @@ def protocol_version(self) -> int:

async def connect(self) -> None:
assert self._uart is None

self._uart = await zigpy_deconz.uart.connect(self._config, self)

await self.version()
try:
await self.version()
device_state_rsp = await self.send_command(CommandId.device_state)
except Exception:
await self.disconnect()
self._uart = None
raise

device_state_rsp = await self.send_command(CommandId.device_state)
self._device_state = device_state_rsp["device_state"]

self._data_poller_task = asyncio.create_task(self._data_poller())

def connection_lost(self, exc: Exception) -> None:
def connection_lost(self, exc: Exception | None) -> None:
"""Lost serial connection."""
LOGGER.debug(
"Serial %r connection lost unexpectedly: %r",
self._config[CONF_DEVICE_PATH],
exc,
)

if self._app is not None:
self._app.connection_lost(exc)

def close(self):
self._app = None

async def disconnect(self):
if self._data_poller_task is not None:
self._data_poller_task.cancel()
self._data_poller_task = None

if self._uart is not None:
self._uart.close()
await self._uart.disconnect()
self._uart = None

self._app = None

def _get_command_priority(self, command: Command) -> int:
return {
# The watchdog is fed using `write_parameter` and `get_device_state` so they
Expand Down
55 changes: 21 additions & 34 deletions zigpy_deconz/uart.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,50 @@
"""Uart module."""

from __future__ import annotations

import asyncio
import binascii
import logging
from typing import Callable, Dict
from typing import Any, Callable

import zigpy.config
import zigpy.serial

LOGGER = logging.getLogger(__name__)


class Gateway(asyncio.Protocol):
class Gateway(zigpy.serial.SerialProtocol):
END = b"\xC0"
ESC = b"\xDB"
ESC_END = b"\xDC"
ESC_ESC = b"\xDD"

def __init__(self, api, connected_future=None):
def __init__(self, api):
"""Initialize instance of the UART gateway."""

super().__init__()
self._api = api
self._buffer = b""
self._connected_future = connected_future
self._transport = None

def connection_lost(self, exc) -> None:
def connection_lost(self, exc: Exception | None) -> None:
"""Port was closed expectedly or unexpectedly."""
super().connection_lost(exc)

if exc is not None:
LOGGER.warning("Lost connection: %r", exc, exc_info=exc)

self._api.connection_lost(exc)

def connection_made(self, transport):
"""Call this when the uart connection is established."""

LOGGER.debug("Connection made")
self._transport = transport
if self._connected_future and not self._connected_future.done():
self._connected_future.set_result(True)
if self._api is not None:
self._api.connection_lost(exc)

def close(self):
self._transport.close()
super().close()
self._api = None

def send(self, data):
def send(self, data: bytes) -> None:
"""Send data, taking care of escaping and framing."""
LOGGER.debug("Send: %s", binascii.hexlify(data).decode())
checksum = bytes(self._checksum(data))
frame = self._escape(data + checksum)
self._transport.write(self.END + frame + self.END)

def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""Handle data received from the uart."""
self._buffer += data
super().data_received(data)

while self._buffer:
end = self._buffer.find(self.END)
if end < 0:
Expand Down Expand Up @@ -121,23 +112,19 @@ def _checksum(self, data):
return bytes(ret)


async def connect(config: Dict[str, any], api: Callable) -> Gateway:
loop = asyncio.get_running_loop()
connected_future = loop.create_future()
protocol = Gateway(api, connected_future)
async def connect(config: dict[str, Any], api: Callable) -> Gateway:
protocol = Gateway(api)

LOGGER.debug("Connecting to %s", config[zigpy.config.CONF_DEVICE_PATH])

_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
loop=asyncio.get_running_loop(),
protocol_factory=lambda: protocol,
url=config[zigpy.config.CONF_DEVICE_PATH],
baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE],
xonxoff=False,
flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
)

await connected_future

LOGGER.debug("Connected to %s", config[zigpy.config.CONF_DEVICE_PATH])
await protocol.wait_until_connected()

return protocol
4 changes: 2 additions & 2 deletions zigpy_deconz/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def connect(self):
try:
await api.connect()
except Exception:
api.close()
await api.disconnect()
raise

self._api = api
Expand All @@ -109,7 +109,7 @@ async def disconnect(self):
self._delayed_neighbor_scan_task = None

if self._api is not None:
self._api.close()
await self._api.disconnect()
self._api = None

async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60):
Expand Down

0 comments on commit 70910bc

Please sign in to comment.