Skip to content

Commit

Permalink
fix: refactor connection to better handle dropped data in stream
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuagruenstein committed Jul 5, 2024
1 parent 6b49820 commit db08254
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 109 deletions.
40 changes: 31 additions & 9 deletions examples/example.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
import asyncio

from umodbus.functions import ReadCoils
from umodbus.functions import WriteSingleCoil

from tcp_modbus_aio.client import TCPModbusClient
from tcp_modbus_aio.exceptions import ModbusCommunicationTimeoutError
from tcp_modbus_aio.typed_functions import ReadCoils

DIGITAL_IN_COILS = list(range(8))
DIGITAL_OUT_COILS = list(range(32, 32 + 12))

async def example() -> None:
example_message = ReadCoils()
example_message.starting_address = 0
example_message.quantity = 1

async with TCPModbusClient("192.168.250.204") as conn:
response = await conn.send_modbus_message(example_message)
async def example() -> None:

assert response is not None, "we expect a response from ReadCoils"
print(response.data) # noqa: T201
async with TCPModbusClient("192.168.250.207") as conn:
for _ in range(1000):
for digital_in_coil in DIGITAL_IN_COILS:
example_message = ReadCoils()
example_message.starting_address = digital_in_coil
example_message.quantity = 1

try:
response = await conn.send_modbus_message(
example_message, retries=0
)
assert response is not None, "we expect a response from ReadCoils"
print(response.data) # noqa: T201
except ModbusCommunicationTimeoutError as e:
print(f"{type(e).__name__}({e})")

for digital_out_coil in DIGITAL_OUT_COILS:
example_message = WriteSingleCoil()
example_message.address = digital_out_coil
example_message.value = False

try:
await conn.send_modbus_message(example_message, retries=0)
except ModbusCommunicationTimeoutError as e:
print(f"{type(e).__name__}({e})")


if __name__ == "__main__":
Expand Down
154 changes: 54 additions & 100 deletions tcp_modbus_aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from tcp_modbus_aio.exceptions import (
ModbusCommunicationFailureError,
ModbusCommunicationTimeoutError,
ModbusNotConnectedError,
)
from tcp_modbus_aio.ping import ping_ip
Expand All @@ -20,6 +21,7 @@
ReadCoils,
create_function_from_response_pdu,
)
from tcp_modbus_aio.utils import catchtime

if TYPE_CHECKING:
from types import TracebackType
Expand Down Expand Up @@ -409,38 +411,40 @@ async def send_modbus_message(

time_budget_remaining = timeout if timeout is not None else float("inf")

last_time = time.perf_counter()
try:
await asyncio.wait_for(self._comms_lock.acquire(), time_budget_remaining)
except asyncio.TimeoutError:
raise ModbusCommunicationFailureError(
f"Failed to acquire lock to send request {msg_str} to modbus device {self.host}"
)
time_budget_remaining -= time.perf_counter() - last_time
with catchtime() as t:
try:
await asyncio.wait_for(
self._comms_lock.acquire(), time_budget_remaining
)
except asyncio.TimeoutError:
raise ModbusCommunicationTimeoutError(
f"Failed to acquire lock to send request {msg_str} to modbus device {self.host}"
)
time_budget_remaining -= t()

try:
if self.logger is not None:
self.logger.debug(
f"[{self}][send_modbus_message] acquired lock to send {msg_str}"
)

last_time = time.perf_counter()
reader, writer = await self._get_tcp_connection(
timeout=time_budget_remaining
)
time_budget_remaining -= time.perf_counter() - last_time
with catchtime() as t:
reader, writer = await self._get_tcp_connection(
timeout=time_budget_remaining
)
time_budget_remaining -= t()

try:
writer.write(request_adu)

last_time = time.perf_counter()
await asyncio.wait_for(writer.drain(), time_budget_remaining)
time_budget_remaining -= time.perf_counter() - last_time
with catchtime() as t:
await asyncio.wait_for(writer.drain(), time_budget_remaining)
time_budget_remaining -= t()

if self.logger is not None:
self.logger.debug(f"[{self}][send_modbus_message] wrote {msg_str}")

except (asyncio.TimeoutError, OSError):
except (asyncio.TimeoutError, OSError) as e:
if retries > 0:
if self.logger is not None:
self.logger.warning(
Expand All @@ -459,79 +463,47 @@ async def send_modbus_message(
retries=retries - 1,
)

raise ModbusCommunicationFailureError(
f"Failed to write request {msg_str} to modbus device {self.host}"
)

expected_response_size = (
request_function.expected_response_pdu_size + MODBUS_MBAP_SIZE
)
raise (
ModbusCommunicationTimeoutError
if isinstance(e, asyncio.TimeoutError)
else ModbusCommunicationFailureError
)(f"Failed to write request {msg_str} to modbus device {self.host}")

try:
seen_response_transaction_ids = []
while True:
last_time = time.perf_counter()
response_adu = await asyncio.wait_for(
reader.read(expected_response_size),
expected_response_mbap_header = struct.pack(
MBAP_HEADER_STRUCT_FORMAT,
request_transaction_id,
0,
request_function.expected_response_pdu_size + 1,
self.slave_id,
)

with catchtime() as t:
response_up_to_mbap_header = await asyncio.wait_for(
reader.readuntil(expected_response_mbap_header),
timeout=time_budget_remaining,
)
time_budget_remaining -= time.perf_counter() - last_time

response_pdu = response_adu[MODBUS_MBAP_SIZE:]
response_mbap_header = response_adu[:MODBUS_MBAP_SIZE]

(
response_transaction_id,
_,
mbap_asserted_pdu_length_plus_one,
response_asserted_slave_id,
) = struct.unpack(MBAP_HEADER_STRUCT_FORMAT, response_mbap_header)

seen_response_transaction_ids.append(response_transaction_id)

if response_transaction_id in self._lost_transaction_ids:
self._lost_transaction_ids.pop(response_transaction_id)
if self.logger is not None:
self.logger.warning(
f"[{self}][send_modbus_message] Received response {response_transaction_id} for "
f"request {msg_str} that was previously lost, skipping"
)

continue

elif len(response_adu) != expected_response_size:
msg = (
f"[{self}][send_modbus_message] Received response {response_transaction_id} for "
f"request {msg_str} with unexpected size {len(response_adu)}, expected "
f"{expected_response_size}"
)

if self.logger is not None:
self.logger.error(msg)
time_budget_remaining -= t()

raise ModbusCommunicationFailureError(msg)

elif response_asserted_slave_id != self.slave_id:
raise ModbusCommunicationFailureError(
f"Response slave ID {response_asserted_slave_id} does not match expected "
f"{self.slave_id} on {self.host}"
)

elif mbap_asserted_pdu_length_plus_one != len(response_pdu) + 1:
raise ModbusCommunicationFailureError(
f"Response PDU length {len(response_pdu)} does not match expected "
f"{mbap_asserted_pdu_length_plus_one-1} on {self.host}"
if len(response_up_to_mbap_header) > MODBUS_MBAP_SIZE:
# TODO: consider introspecting the discarded traffic here for better introspection
if self.logger is not None:
self.logger.warning(
f"[{self}][send_modbus_message] got {response_up_to_mbap_header[:MODBUS_MBAP_SIZE]!r} "
"before mbap header, likely catching up stream after timeouts"
)

break
with catchtime() as t:
response_pdu = await asyncio.wait_for(
reader.readexactly(request_function.expected_response_pdu_size),
timeout=time_budget_remaining,
)
time_budget_remaining -= t()

except asyncio.TimeoutError:
self._lost_transaction_ids[request_transaction_id] = True

if error_on_no_response:
raise ModbusCommunicationFailureError(
f"Failed to read response to {msg_str} from modbus device {self.host} "
f"({seen_response_transaction_ids=})"
raise ModbusCommunicationTimeoutError(
f"Failed to read response to {msg_str} from modbus device {self.host}"
)

else:
Expand All @@ -544,36 +516,18 @@ async def send_modbus_message(
finally:
self._comms_lock.release()

mismatch = response_transaction_id != request_transaction_id

response_function = create_function_from_response_pdu(
response_pdu, request_function
)

response_msg = (
f"{response_function.__class__.__name__}[{response_transaction_id}]"
f"{response_function.__class__.__name__}[{request_transaction_id}]"
)

if self.logger is not None:
self.logger.debug(
f"[{self}][send_modbus_message] received response {response_msg} for request "
f"{msg_str} ({mismatch=} {response_function.data=} {response_pdu=} {len(response_pdu)=})"
)

if mismatch:
msg = (
f"Response transaction ID {response_transaction_id} does not match request "
f"{msg_str} on {self.host}"
f"{msg_str} {response_function.data=} {response_pdu=} {len(response_pdu)=})"
)

if self.logger is not None:
self.logger.error(
f"[{self}][send_modbus_message] {msg} {response_adu=}"
)

# THIS IS IMPORTANT SO MISMATCH ERRORS SELF CORRECT
self._lost_transaction_ids[request_transaction_id] = True

raise ModbusCommunicationFailureError(msg)

return response_function
6 changes: 6 additions & 0 deletions tcp_modbus_aio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ class ModbusCommunicationFailureError(ModbusError):
pass


class ModbusCommunicationTimeoutError(ModbusCommunicationFailureError):
"""Timeout in communicating with modbus device."""

pass


class ModbusNotConnectedError(ModbusCommunicationFailureError):
"""Modbus not connected error."""

Expand Down
10 changes: 10 additions & 0 deletions tcp_modbus_aio/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from contextlib import contextmanager
from time import perf_counter
from typing import Callable, Iterator


@contextmanager
def catchtime() -> Iterator[Callable[[], float]]:
t1 = t2 = perf_counter()
yield lambda: t2 - t1
t2 = perf_counter()

0 comments on commit db08254

Please sign in to comment.