Skip to content

Commit

Permalink
fix: better stream reconstruction logic to handle backpressure and co…
Browse files Browse the repository at this point in the history
…nnectionreseterrors
  • Loading branch information
joshuagruenstein committed Jul 5, 2024
1 parent 38476d4 commit 1830b13
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 99 deletions.
21 changes: 7 additions & 14 deletions examples/example.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio

from umodbus.functions import WriteSingleCoil

from tcp_modbus_aio.client import TCPModbusClient
from tcp_modbus_aio.exceptions import ModbusCommunicationTimeoutError
from tcp_modbus_aio.exceptions import (
ModbusCommunicationFailureError,
ModbusCommunicationTimeoutError,
)
from tcp_modbus_aio.typed_functions import ReadCoils

DIGITAL_IN_COILS = list(range(8))
Expand All @@ -12,7 +13,7 @@

async def example() -> None:

async with TCPModbusClient("192.168.250.207") as conn:
async with TCPModbusClient("192.168.250.207", enforce_pingable=False) as conn:
for _ in range(1000):
for digital_in_coil in DIGITAL_IN_COILS:
example_message = ReadCoils()
Expand All @@ -21,21 +22,13 @@ async def example() -> None:

try:
response = await conn.send_modbus_message(
example_message, retries=0
example_message, timeout=0.02
)
assert response is not None, "we expect a response from ReadCoils"
print(response.data) # noqa: T201
except ModbusCommunicationTimeoutError as e:
print(f"{type(e).__name__}({e})")

for digital_out_coil in DIGITAL_OUT_COILS:
example_message = WriteSingleCoil()
example_message.address = digital_out_coil
example_message.value = False

try:
await conn.send_modbus_message(example_message, retries=0)
except ModbusCommunicationTimeoutError as e:
except ModbusCommunicationFailureError as e:
print(f"{type(e).__name__}({e})")


Expand Down
187 changes: 102 additions & 85 deletions tcp_modbus_aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class TCPModbusClient:
KEEPALIVE_MAX_FAILS: ClassVar = 5

PING_LOOP_PERIOD: ClassVar = 1
CONSECUTIVE_TIMEOUTS_TO_RECONNECT: ClassVar = 5

def __init__(
self,
Expand All @@ -66,11 +67,13 @@ def __init__(
*,
logger: logging.Logger | None = None,
enforce_pingable: bool = True,
ping_timeout: float = 0.5,
) -> None:
self.host = host
self.port = port
self.slave_id = slave_id
self.logger = logger
self.ping_timeout = ping_timeout

# If True, will throw an exception if attempting to send a request and the device is not pingable
self.enforce_pingable = enforce_pingable
Expand All @@ -82,6 +85,9 @@ def __init__(
self._reader: asyncio.StreamReader | None = None
self._writer: asyncio.StreamWriter | None = None

# Number of current consecutive modbus calls that resulted in a timeout
self._consecutive_timeouts = 0

# Last ping time in seconds from ping loop, or None if the last ping failed
self._last_ping: float | None = None

Expand Down Expand Up @@ -132,7 +138,7 @@ def __repr__(self) -> str:

async def _ping_loop_task(self) -> None:
while True:
self._last_ping = await ping_ip(self.host)
self._last_ping = await ping_ip(self.host, timeout=self.ping_timeout)

if self.logger is not None:
self.logger.debug(f"[{self}][_ping_loop_task] ping ping ping")
Expand All @@ -143,67 +149,74 @@ async def _ping_loop_task(self) -> None:
async def _get_tcp_connection(
self, timeout: float | None = DEFAULT_MODBUS_TIMEOUT_SEC
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
if self._reader is None or self._writer is None:
self._lifetime_tcp_connection_num += 1
if self._reader is not None and self._writer is not None:
return self._reader, self._writer

if self.logger is not None:
self.logger.info(
f"[{self}][_get_tcp_connection] creating new TCP connection (#{self._lifetime_tcp_connection_num})"
)
self._lifetime_tcp_connection_num += 1

try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(host=self.host, port=self.port), timeout
)
if self.logger is not None:
self.logger.info(
f"[{self}][_get_tcp_connection] creating new TCP connection (#{self._lifetime_tcp_connection_num})"
)

sock: socket.socket = writer.get_extra_info("socket")

# Receive and send buffers set to 900 bytes (recommended by MODBUS implementation guide: this is
# becuase the max request size is 256 bytes + the header size of 7 bytes = 263 bytes, and the
# max response size is 256 bytes + the header size of 7 bytes = 263 bytes, so a 900 byte buffer
# can store 3 frames of buffering, which is apparently the suggestion).
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 900)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 900)

# Reuse address (perf optimization, recommended by MODBUS implementation guide)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Enable TCP_NODELAY (prevent small packet buffering, recommended by MODBUS implementation guide)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# Enable TCP keepalive (otherwise the Adam connection will terminate after 720 (1000?) seconds
# with an open idle connection: this is also recommended by the MODBUS implementation guide)
#
# In most cases this is not necessary because Adam commands are short lived and we
# close the connection after each command. However, if we want to keep a connection
# open for a long time we would need to enable keepalive.

sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
# Only available on Linux so this makes typing work cross platform
sock.setsockopt(
socket.IPPROTO_TCP,
socket.TCP_KEEPIDLE,
self.KEEPALIVE_AFTER_IDLE_SEC,
)
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(host=self.host, port=self.port), timeout
)
except asyncio.TimeoutError:
msg = (
f"Timed out connecting to TCP modbus device at {self.host}:{self.port}"
)
if self.logger is not None:
self.logger.warning(f"[{self}][_get_tcp_connection] {msg}")
raise ModbusCommunicationTimeoutError(msg)
except OSError:
msg = f"Cannot connect to TCP modbus device at {self.host}:{self.port}"
if self.logger is not None:
self.logger.warning(f"[{self}][_get_tcp_connection] {msg}")
raise ModbusNotConnectedError(msg)

sock: socket.socket = writer.get_extra_info("socket")

# Receive and send buffers set to 900 bytes (recommended by MODBUS implementation guide: this is
# becuase the max request size is 256 bytes + the header size of 7 bytes = 263 bytes, and the
# max response size is 256 bytes + the header size of 7 bytes = 263 bytes, so a 900 byte buffer
# can store 3 frames of buffering, which is apparently the suggestion).
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 900)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 900)

# Reuse address (perf optimization, recommended by MODBUS implementation guide)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Enable TCP_NODELAY (prevent small packet buffering, recommended by MODBUS implementation guide)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# Enable TCP keepalive (otherwise the Adam connection will terminate after 720 (1000?) seconds
# with an open idle connection: this is also recommended by the MODBUS implementation guide)
#
# In most cases this is not necessary because Adam commands are short lived and we
# close the connection after each command. However, if we want to keep a connection
# open for a long time we would need to enable keepalive.

sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
# Only available on Linux so this makes typing work cross platform
sock.setsockopt(
socket.IPPROTO_TCP,
socket.TCP_KEEPIDLE,
self.KEEPALIVE_AFTER_IDLE_SEC,
)

sock.setsockopt(
socket.IPPROTO_TCP,
socket.TCP_KEEPINTVL,
self.KEEPALIVE_INTERVAL_SEC,
)
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_KEEPCNT, self.KEEPALIVE_MAX_FAILS
)
sock.setsockopt(
socket.IPPROTO_TCP,
socket.TCP_KEEPINTVL,
self.KEEPALIVE_INTERVAL_SEC,
)
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_KEEPCNT, self.KEEPALIVE_MAX_FAILS
)

self._reader, self._writer = reader, writer
except (asyncio.TimeoutError, OSError):
msg = f"Cannot connect to TCP modbus device at {self.host}:{self.port}"
if self.logger is not None:
self.logger.warning(f"[{self}][_get_tcp_connection] {msg}")
raise ModbusNotConnectedError(msg)
else:
reader, writer = self._reader, self._writer
self._reader, self._writer = reader, writer

return reader, writer

Expand All @@ -226,7 +239,7 @@ async def close(self) -> None:
if self._ping_loop is None:
return

await self.clear_tcp_connection()
self.clear_tcp_connection()

if self._ping_loop is not None:
if self.logger is not None:
Expand Down Expand Up @@ -302,12 +315,14 @@ async def _watch_loop() -> None:
_watch_loop(), name=f"TCPModbusClient{log_prefix}"
)

async def clear_tcp_connection(self) -> None:
def clear_tcp_connection(self) -> None:
"""
Closes the current TCP connection and clears the reader and writer objects.
On the next send_modbus_message call, a new connection will be created.
"""

self._consecutive_timeouts = 0

if self._ping_loop is None:
raise RuntimeError("Cannot clear TCP connection on closed TCPModbusClient")

Expand All @@ -319,17 +334,6 @@ async def clear_tcp_connection(self) -> None:

self._writer.close()

try:
await self._writer.wait_closed()
except (TimeoutError, ConnectionResetError, OSError) as e:
if self.logger is not None:
self.logger.warning(
f"[{self}][clear_tcp_connection] {type(e).__name__}({e}) error on connection close, "
"continuing anyway"
)

pass

self._reader = None
self._writer = None

Expand Down Expand Up @@ -434,6 +438,7 @@ async def send_modbus_message(
reader, writer = await self._get_tcp_connection(
timeout=time_budget_remaining
)

time_budget_remaining -= conn_t()

# STEP THREE: WRITE OUR REQUEST
Expand All @@ -447,9 +452,9 @@ async def send_modbus_message(
if self.logger is not None:
self.logger.debug(f"[{self}][send_modbus_message] wrote {msg_str}")

except (asyncio.TimeoutError, OSError, ConnectionResetError):
except OSError: # this includes timeout errors
# Clear connection no matter what if we fail on the write
# TODO: consider revisiting this to only do it on OSError and ConnectionResetError
# TODO: consider revisiting this to not do it on a timeouterror
# (but Gru is scared about partial writes)

if self.logger is not None:
Expand All @@ -458,7 +463,7 @@ async def send_modbus_message(
f"request {msg_str}, clearing connection"
)

await self.clear_tcp_connection()
self.clear_tcp_connection()

if retries > 0:
if self.logger is not None:
Expand Down Expand Up @@ -516,25 +521,37 @@ async def send_modbus_message(
return None

raise
except asyncio.TimeoutError as e:
self._consecutive_timeouts += 1
if self._consecutive_timeouts >= self.CONSECUTIVE_TIMEOUTS_TO_RECONNECT:
if self.logger is not None:
self.logger.warning(
f"[{self}][send_modbus_message] {self._consecutive_timeouts} consecutive timeouts, "
"clearing connection"
)
self.clear_tcp_connection()

except (asyncio.TimeoutError, OSError, ConnectionResetError) as e:
# We clear the connection if the connection was reset by peer or was an OS error
if isinstance(e, (OSError, ConnectionResetError)):
print("CLEARING TCP ON GENERAL FAIL")
await self.clear_tcp_connection()

raise (
ModbusCommunicationTimeoutError
if isinstance(e, asyncio.TimeoutError)
else ModbusCommunicationFailureError
)(
f"Request {msg_str} failed to {self.host}:{self.port} ({type(e).__name__}({e}))"
raise ModbusCommunicationTimeoutError(
f"Request {msg_str} timed out to {self.host}:{self.port}"
) from e
except OSError as e:
if self.logger is not None:
self.logger.warning(
f"[{self}][send_modbus_message] OSError{type(e).__name__}({e}) while sending request {msg_str}, "
"clearing connection"
)

self.clear_tcp_connection()

raise ModbusCommunicationFailureError(
f"Request {msg_str} failed to {self.host}:{self.port} ({type(e).__name__}({e}))"
) from e
finally:
if self._comms_lock.locked():
self._comms_lock.release()

self._consecutive_timeouts = 0

if self.logger is not None:
self.logger.debug(
f"[{self}][send_modbus_message] executed request/response with timing "
Expand Down

0 comments on commit 1830b13

Please sign in to comment.