diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index a7923d109e4..89c138bdae3 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -67,6 +67,9 @@ async def test_reconnect_when_missing_heartbeat(unused_tcp_port, monkeypatch): await mock_server.do_heartbeat() await client.send("stop", retries=1) + # the client should be disconnected + assert len(mock_server.dealers) == 0 + # when reconnection happens CONNECT message is sent again assert mock_server.messages.count("CONNECT") == 2 assert mock_server.messages.count("DISCONNECT") == 1 diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 7bdf868c1d6..2138adcd64d 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -81,6 +81,8 @@ def __init__(self, port, signal=0): self.server_task = None self.handler_task = None self.dealers = set() + self.no_dealers = asyncio.Event() + self.no_dealers.set() def start_event_loop(self): asyncio.set_event_loop(self.loop) @@ -104,6 +106,8 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): if not self.server_task.done(): + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.no_dealers.wait(), timeout=2.0) self.server_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self.server_task @@ -132,18 +136,21 @@ async def _handler(self): while True: try: dealer, __, frame = await self.router_socket.recv_multipart() + if ( + self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG} + ) or self.value == 3: + self.messages.append(frame.decode("utf-8")) if frame == CONNECT_MSG: - await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) self.dealers.add(dealer) - elif frame == DISCONNECT_MSG: + self.no_dealers.clear() await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) + elif frame == DISCONNECT_MSG: self.dealers.discard(dealer) + if not self.dealers: + self.no_dealers.set() + await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) elif self.value in {0, 3}: await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) - if ( - self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG} - ) or self.value == 3: - self.messages.append(frame.decode("utf-8")) except asyncio.CancelledError: break