diff --git a/bellows/exception.py b/bellows/exception.py index bbc66203..aa43fddd 100644 --- a/bellows/exception.py +++ b/bellows/exception.py @@ -11,3 +11,7 @@ class InvalidCommandError(EzspError): class ControllerError(ControllerException): pass + + +class StackAlreadyRunning(EzspError): + pass diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index bff3d6e8..80021bf7 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -32,7 +32,7 @@ CONFIG_SCHEMA, SCHEMA_DEVICE, ) -from bellows.exception import ControllerError, EzspError +from bellows.exception import ControllerError, EzspError, StackAlreadyRunning import bellows.ezsp from bellows.ezsp.v8.types.named import EmberDeviceUpdate import bellows.multicast @@ -104,7 +104,7 @@ def multicast(self): async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None: """Add endpoint.""" - res = await self._ezsp.addEndpoint( + (status,) = await self._ezsp.addEndpoint( descriptor.endpoint, descriptor.profile, descriptor.device_type, @@ -114,7 +114,8 @@ async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None: descriptor.input_clusters, descriptor.output_clusters, ) - LOGGER.debug("Ezsp adding endpoint: %s", res) + if status != t.EmberStatus.SUCCESS: + raise StackAlreadyRunning() async def cleanup_tc_link_key(self, ieee: t.EmberEUI64) -> None: """Remove tc link_key for the given device.""" @@ -176,14 +177,20 @@ async def _ensure_network_running(self) -> bool: async def start_network(self): ezsp = self._ezsp + try: + await self.register_endpoints() + except StackAlreadyRunning: + # Endpoints can only be registered before the network is up + await self._reset() + await self.register_endpoints() + await self._ensure_network_running() if await repairs.fix_invalid_tclk_partner_ieee(ezsp): await self._reset() + await self.register_endpoints() await self._ensure_network_running() - await self.register_endpoints() - if self.config[zigpy.config.CONF_SOURCE_ROUTING]: await ezsp.set_source_routing() diff --git a/tests/test_application.py b/tests/test_application.py index 27150b4f..c8641bfc 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -115,7 +115,7 @@ def aps(): @patch("zigpy.device.Device._initialize", new=AsyncMock()) @patch("bellows.zigbee.application.ControllerApplication._watchdog", new=AsyncMock()) -async def _test_startup( +def _create_app_for_startup( app, nwk_type, ieee, @@ -150,8 +150,8 @@ async def mock_leave(*args, **kwargs): ezsp_mock.setConcentrator = AsyncMock() ezsp_mock.getTokenData = AsyncMock(return_value=[t.EmberStatus.ERR_FATAL, b""]) ezsp_mock._command = AsyncMock(return_value=t.EmberStatus.SUCCESS) - ezsp_mock.addEndpoint = AsyncMock(return_value=t.EmberStatus.SUCCESS) - ezsp_mock.setConfigurationValue = AsyncMock(return_value=t.EmberStatus.SUCCESS) + ezsp_mock.addEndpoint = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) + ezsp_mock.setConfigurationValue = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) ezsp_mock.networkInit = AsyncMock(return_value=[init]) ezsp_mock.networkInitExtended = AsyncMock(return_value=[init]) ezsp_mock.getNetworkParameters = AsyncMock(return_value=[0, nwk_type, nwk_params]) @@ -217,6 +217,23 @@ def form_network(): app.form_network = AsyncMock(side_effect=form_network) + return ezsp_mock + + +async def _test_startup( + app, + nwk_type, + ieee, + auto_form=False, + init=0, + ezsp_version=4, + board_info=True, + network_state=t.EmberNetworkStatus.JOINED_NETWORK, +): + ezsp_mock = _create_app_for_startup( + app, nwk_type, ieee, auto_form, init, ezsp_version, board_info, network_state + ) + p1 = patch("bellows.ezsp.EZSP", return_value=ezsp_mock) p2 = patch.object(bellows.multicast.Multicast, "startup") @@ -593,9 +610,9 @@ def test_sequence(app): assert seq < 256 -def test_permit_ncp(app): +async def test_permit_ncp(app): app._ezsp.permitJoining = AsyncMock() - app.permit_ncp(60) + await app.permit_ncp(60) assert app._ezsp.permitJoining.call_count == 1 @@ -1818,12 +1835,47 @@ async def test_connect_failure( async def test_repair_tclk_partner_ieee(app: ControllerApplication) -> None: """Test that EZSP is reset after repairing TCLK.""" app._ensure_network_running = AsyncMock() + app._reset = AsyncMock() app.load_network_info = AsyncMock() + with patch( + "bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee", + AsyncMock(return_value=False), + ): + await app.start_network() + + assert len(app._reset.mock_calls) == 0 + app._reset.reset_mock() + with patch( "bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee", AsyncMock(return_value=True), ): await app.start_network() - assert len(app._ensure_network_running.mock_calls) == 2 + assert len(app._reset.mock_calls) == 1 + + +async def test_startup_endpoint_register_already_running( + app: ControllerApplication, ieee: t.EmberEUI64 +) -> None: + """Test that the host is reset before endpoint registration if it is running.""" + + app._ezsp = _create_app_for_startup(app, t.EmberNodeType.COORDINATOR, ieee) + app._ezsp.addEndpoint = AsyncMock( + side_effect=[ + [t.EmberStatus.INVALID_CALL], # Fail the first time + [t.EmberStatus.SUCCESS], + [t.EmberStatus.SUCCESS], + [t.EmberStatus.SUCCESS], + ] + ) + + app._reset = AsyncMock() + + with patch.object(bellows.multicast.Multicast, "startup"): + await app.start_network() + + assert len(app._reset.mock_calls) == 1 + + assert app._ezsp.addEndpoint.call_count >= 3