Skip to content

Commit

Permalink
Add no_dealers Event to mockzmqserver
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Jan 24, 2025
1 parent 3acf8e7 commit 23e000e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ async def test_reconnect_when_missing_heartbeat(unused_tcp_port, monkeypatch):
await mock_server.do_heartbeat()
await client.send("stop", retries=1)

# the client should be disconnected
assert len(mock_server.dealers) == 0

# when reconnection happens CONNECT message is sent again
assert mock_server.messages.count("CONNECT") == 2
assert mock_server.messages.count("DISCONNECT") == 1
Expand Down
19 changes: 13 additions & 6 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __init__(self, port, signal=0):
self.server_task = None
self.handler_task = None
self.dealers = set()
self.no_dealers = asyncio.Event()
self.no_dealers.set()

def start_event_loop(self):
asyncio.set_event_loop(self.loop)
Expand All @@ -104,6 +106,8 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_value, traceback):
if not self.server_task.done():
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(self.no_dealers.wait(), timeout=2.0)
self.server_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.server_task
Expand Down Expand Up @@ -132,18 +136,21 @@ async def _handler(self):
while True:
try:
dealer, __, frame = await self.router_socket.recv_multipart()
if (
self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG}
) or self.value == 3:
self.messages.append(frame.decode("utf-8"))
if frame == CONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
self.dealers.add(dealer)
elif frame == DISCONNECT_MSG:
self.no_dealers.clear()
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
elif frame == DISCONNECT_MSG:
self.dealers.discard(dealer)
if not self.dealers:
self.no_dealers.set()
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
elif self.value in {0, 3}:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
if (
self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG}
) or self.value == 3:
self.messages.append(frame.decode("utf-8"))
except asyncio.CancelledError:
break

Expand Down

0 comments on commit 23e000e

Please sign in to comment.