diff --git a/tcp_modbus_aio/client.py b/tcp_modbus_aio/client.py index 7e46eef..1eac4cc 100644 --- a/tcp_modbus_aio/client.py +++ b/tcp_modbus_aio/client.py @@ -61,13 +61,18 @@ def __init__( host: str, port: int = 502, slave_id: int = 1, + *, logger: logging.Logger | None = None, + enforce_pingable: bool = True, ) -> None: self.host = host self.port = port self.slave_id = slave_id self.logger = logger + # If True, will throw an exception if attempting to send a request and the device is not pingable + self.enforce_pingable = enforce_pingable + # Unique identifier for this client (used only for logging) self._id = uuid.uuid4() @@ -377,6 +382,11 @@ async def send_modbus_message( if self._ping_loop is None: raise RuntimeError("Cannot send modbus message on closed TCPModbusClient") + if self.enforce_pingable and not await self.is_pingable(): + raise ModbusNotConnectedError( + f"Cannot send modbus message to {self.host} because it is not pingable" + ) + request_transaction_id = self._next_transaction_id self._next_transaction_id = (self._next_transaction_id + 1) % MAX_TRANSACTION_ID @@ -397,17 +407,35 @@ async def send_modbus_message( f"[{self}][send_modbus_message] sending request {msg_str}: {request_adu=}" ) - async with self._comms_lock: + time_budget_remaining = timeout if timeout is not None else float("inf") + + last_time = time.perf_counter() + try: + await asyncio.wait_for(self._comms_lock.acquire(), time_budget_remaining) + except asyncio.TimeoutError: + raise ModbusCommunicationFailureError( + f"Failed to acquire lock to send request {msg_str} to modbus device {self.host}" + ) + time_budget_remaining -= time.perf_counter() - last_time + + try: if self.logger is not None: self.logger.debug( f"[{self}][send_modbus_message] acquired lock to send {msg_str}" ) - reader, writer = await self._get_tcp_connection(timeout=timeout) + last_time = time.perf_counter() + reader, writer = await self._get_tcp_connection( + timeout=time_budget_remaining + ) + time_budget_remaining -= time.perf_counter() - last_time try: writer.write(request_adu) - await asyncio.wait_for(writer.drain(), timeout) + + last_time = time.perf_counter() + await asyncio.wait_for(writer.drain(), time_budget_remaining) + time_budget_remaining -= time.perf_counter() - last_time if self.logger is not None: self.logger.debug(f"[{self}][send_modbus_message] wrote {msg_str}") @@ -422,6 +450,9 @@ async def send_modbus_message( await self.clear_tcp_connection() + # release the lock before retrying (so we can re-get it) + self._comms_lock.release() + return await self.send_modbus_message( request_function, timeout=timeout, @@ -439,9 +470,12 @@ async def send_modbus_message( try: seen_response_transaction_ids = [] while True: + last_time = time.perf_counter() response_adu = await asyncio.wait_for( - reader.read(expected_response_size), timeout=timeout + reader.read(expected_response_size), + timeout=time_budget_remaining, ) + time_budget_remaining -= time.perf_counter() - last_time response_pdu = response_adu[MODBUS_MBAP_SIZE:] response_mbap_header = response_adu[:MODBUS_MBAP_SIZE] @@ -507,6 +541,8 @@ async def send_modbus_message( ) return None + finally: + self._comms_lock.release() mismatch = response_transaction_id != request_transaction_id