From fc70a41b79fafb07e1101efc9ec77a2f8180d591 Mon Sep 17 00:00:00 2001 From: Andrej Krpic Date: Mon, 25 Mar 2024 11:45:21 +0100 Subject: [PATCH 1/2] add support for "unix" transport where socket module contains AF_UNIX --- src/paho/mqtt/client.py | 25 ++++++++++++++++++++----- tests/test_client.py | 29 ++++++++++++++++++++--------- tests/testsupport/broker.py | 36 ++++++++++++++++++++++++++---------- 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index cd607938..af7675ee 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -682,6 +682,10 @@ class Client: :param transport: use "websockets" to use WebSockets as the transport mechanism. Set to "tcp" to use raw TCP, which is the default. + Use "unix" to use Unix sockets as the transport mechanism; note that + this option is only available on platforms that support Unix sockets, + and the "host" argument is interpreted as the path to the Unix socket + file in this case. :param bool manual_ack: normally, when a message is received, the library automatically acknowledges after on_message callback returns. manual_ack=True allows the application to @@ -733,14 +737,16 @@ def __init__( clean_session: bool | None = None, userdata: Any = None, protocol: MQTTProtocolVersion = MQTTv311, - transport: Literal["tcp", "websockets"] = "tcp", + transport: Literal["tcp", "websockets", "unix"] = "tcp", reconnect_on_failure: bool = True, manual_ack: bool = False, ) -> None: transport = transport.lower() # type: ignore - if transport not in ("websockets", "tcp"): + if transport == "unix" and not hasattr(socket, "AF_UNIX"): + raise ValueError('"unix" transport not supported') + elif transport not in ("websockets", "tcp", "unix"): raise ValueError( - f'transport must be "websockets" or "tcp", not {transport}') + f'transport must be "websockets", "tcp" or "unix", not {transport}') self._manual_ack = manual_ack self._transport = transport @@ -931,7 +937,7 @@ def keepalive(self, value: int) -> None: self._keepalive = value @property - def transport(self) -> Literal["tcp", "websockets"]: + def transport(self) -> Literal["tcp", "websockets", "unix"]: """ Transport method used for the connection ("tcp" or "websockets"). @@ -4595,7 +4601,11 @@ def _get_proxy(self) -> dict[str, Any] | None: return None def _create_socket(self) -> SocketLike: - sock = self._create_socket_connection() + if self._transport == "unix": + sock = self._create_unix_socket_connection() + else: + sock = self._create_socket_connection() + if self._ssl: sock = self._ssl_wrap_socket(sock) @@ -4612,6 +4622,11 @@ def _create_socket(self) -> SocketLike: return sock + def _create_unix_socket_connection(self) -> _socket.socket: + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + unix_socket.connect(self._host) + return unix_socket + def _create_socket_connection(self) -> _socket.socket: proxy = self._get_proxy() addr = (self._host, self._port) diff --git a/tests/test_client.py b/tests/test_client.py index 33e22f56..4868c51a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,6 +31,7 @@ def test_01_con_discon_success(self, proto_ver, callback_version, fake_broker): callback_version, "01-con-discon-success", protocol=proto_ver, + transport=fake_broker.transport, ) def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None): @@ -70,7 +71,8 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None): def test_01_con_failure_rc(self, proto_ver, callback_version, fake_broker): mqttc = client.Client( - callback_version, "01-con-failure-rc", protocol=proto_ver) + callback_version, "01-con-failure-rc", + protocol=proto_ver, transport=fake_broker.transport) def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None): assert rc_or_reason_code > 0 @@ -107,7 +109,9 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None): mqttc.loop_stop() def test_connection_properties(self, proto_ver, callback_version, fake_broker): - mqttc = client.Client(CallbackAPIVersion.VERSION2, "client-id", protocol=proto_ver) + mqttc = client.Client( + CallbackAPIVersion.VERSION2, "client-id", + protocol=proto_ver, transport=fake_broker.transport) mqttc.enable_logger() is_connected = threading.Event() @@ -131,7 +135,7 @@ def on_disconnect(*args): mqttc.keepalive = 7 mqttc.max_inflight_messages = 7 mqttc.max_queued_messages = 7 - mqttc.transport = "tcp" + mqttc.transport = fake_broker.transport mqttc.username = "username" mqttc.password = "password" @@ -184,7 +188,7 @@ def on_disconnect(*args): mqttc.max_queued_messages = 7 with pytest.raises(RuntimeError): - mqttc.transport = "tcp" + mqttc.transport = fake_broker.transport with pytest.raises(RuntimeError): mqttc.username = "username" @@ -217,7 +221,9 @@ class Test_connect_v5: """ def test_01_broker_no_support(self, fake_broker): - mqttc = client.Client(CallbackAPIVersion.VERSION2, "01-broker-no-support", protocol=MQTTProtocolVersion.MQTTv5) + mqttc = client.Client( + CallbackAPIVersion.VERSION2, "01-broker-no-support", + protocol=MQTTProtocolVersion.MQTTv5, transport=fake_broker.transport) def on_connect(mqttc, obj, flags, reason, properties): assert reason == 132 @@ -261,6 +267,7 @@ def test_with_loop_start(self, fake_broker: FakeBroker): "test_with_loop_start", protocol=MQTTProtocolVersion.MQTTv311, reconnect_on_failure=False, + transport=fake_broker.transport ) on_connect_reached = threading.Event() @@ -311,6 +318,7 @@ def test_with_loop(self, fake_broker: FakeBroker): CallbackAPIVersion.VERSION1, "test_with_loop", clean_session=True, + transport=fake_broker.transport, ) on_connect_reached = threading.Event() @@ -367,6 +375,7 @@ def test_publish_before_connect(self, fake_broker: FakeBroker) -> None: mqttc = client.Client( CallbackAPIVersion.VERSION1, "test_publish_before_connect", + transport=fake_broker.transport, ) def on_connect(mqttc, obj, flags, rc): @@ -424,7 +433,7 @@ def on_connect(mqttc, obj, flags, rc): ]) class TestPublishBroker2Client: def test_invalid_utf8_topic(self, callback_version, fake_broker): - mqttc = client.Client(callback_version, "client-id") + mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport) def on_message(client, userdata, msg): with pytest.raises(UnicodeDecodeError): @@ -466,7 +475,7 @@ def on_message(client, userdata, msg): assert not packet_in # Check connection is closed def test_valid_utf8_topic_recv(self, callback_version, fake_broker): - mqttc = client.Client(callback_version, "client-id") + mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport) # It should be non-ascii multi-bytes character topic = unicodedata.lookup('SNOWMAN') @@ -512,7 +521,7 @@ def on_message(client, userdata, msg): assert not packet_in # Check connection is closed def test_valid_utf8_topic_publish(self, callback_version, fake_broker): - mqttc = client.Client(callback_version, "client-id") + mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport) # It should be non-ascii multi-bytes character topic = unicodedata.lookup('SNOWMAN') @@ -558,7 +567,7 @@ def test_valid_utf8_topic_publish(self, callback_version, fake_broker): assert not packet_in # Check connection is closed def test_message_callback(self, callback_version, fake_broker): - mqttc = client.Client(callback_version, "client-id") + mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport) userdata = { 'on_message': 0, 'callback1': 0, @@ -698,6 +707,7 @@ def test_callback_v1_mqtt3(self, fake_broker): CallbackAPIVersion.VERSION1, "client-id", userdata=callback_called, + transport=fake_broker.transport, ) def on_connect(cl, userdata, flags, rc): @@ -823,6 +833,7 @@ def test_callback_v2_mqtt3(self, fake_broker): CallbackAPIVersion.VERSION2, "client-id", userdata=callback_called, + transport=fake_broker.transport, ) def on_connect(cl, userdata, flags, reason, properties): diff --git a/tests/testsupport/broker.py b/tests/testsupport/broker.py index fb25f918..d81ddfb5 100644 --- a/tests/testsupport/broker.py +++ b/tests/testsupport/broker.py @@ -2,6 +2,7 @@ import socket import socketserver import threading +import os import pytest @@ -9,18 +10,27 @@ class FakeBroker: - def __init__(self): - # Bind to "localhost" for maximum performance, as described in: - # http://docs.python.org/howto/sockets.html#ipc - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + def __init__(self, transport): + if transport == "tcp": + # Bind to "localhost" for maximum performance, as described in: + # http://docs.python.org/howto/sockets.html#ipc + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", 0)) + self.port = sock.getsockname()[1] + elif transport == "unix": + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.bind("localhost") + self.port = 1883 + else: + raise ValueError(f"unsupported transport {transport}") + sock.settimeout(5) - sock.bind(("localhost", 0)) - self.port = sock.getsockname()[1] sock.listen(1) self._sock = sock self._conn = None + self.transport = transport def start(self): if self._sock is None: @@ -39,6 +49,12 @@ def finish(self): self._sock.close() self._sock = None + if self.transport == 'unix': + try: + os.unlink('localhost') + except OSError: + pass + def receive_packet(self, num_bytes): if self._conn is None: raise ValueError('Connection is not open') @@ -60,10 +76,10 @@ def expect_packet(self, name, packet): paho_test.expect_packet(self._conn, name, packet) -@pytest.fixture -def fake_broker(): +@pytest.fixture(params=["tcp"] + (["unix"] if hasattr(socket, 'AF_UNIX') else [])) +def fake_broker(request): # print('Setup broker') - broker = FakeBroker() + broker = FakeBroker(request.param) yield broker From a0554dd9cae15429f2cbb9eca3453040035df912 Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Mon, 29 Apr 2024 19:56:17 +0200 Subject: [PATCH 2/2] Fix linter report --- tests/testsupport/broker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testsupport/broker.py b/tests/testsupport/broker.py index d81ddfb5..e08cf73e 100644 --- a/tests/testsupport/broker.py +++ b/tests/testsupport/broker.py @@ -1,8 +1,8 @@ import contextlib +import os import socket import socketserver import threading -import os import pytest