Skip to content

Commit 52cfb0b

Browse files
authored
Paho Client Update (#350)
* Update Paho Client to v1.2
1 parent 057fda5 commit 52cfb0b

File tree

1 file changed

+46
-38
lines changed

1 file changed

+46
-38
lines changed

AWSIoTPythonSDK/core/protocol/paho/client.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@
4949
from AWSIoTPythonSDK.core.protocol.connection.cores import SecuredWebSocketCore
5050
from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder
5151

52-
VERSION_MAJOR=1
53-
VERSION_MINOR=0
54-
VERSION_REVISION=0
55-
VERSION_NUMBER=(VERSION_MAJOR*1000000+VERSION_MINOR*1000+VERSION_REVISION)
56-
5752
MQTTv31 = 3
5853
MQTTv311 = 4
5954

@@ -497,6 +492,7 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT
497492
self._msgtime_mutex = threading.Lock()
498493
self._out_message_mutex = threading.Lock()
499494
self._in_message_mutex = threading.Lock()
495+
self._mid_generate_mutex = threading.Lock()
500496
self._thread = None
501497
self._thread_terminate = False
502498
self._ssl = None
@@ -515,7 +511,8 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT
515511
self._alpn_protocols = None
516512

517513
def __del__(self):
518-
pass
514+
# Closes socket in client destructor to avoid FD leak.
515+
self._reset_sockets()
519516

520517

521518
def setBackoffTiming(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond):
@@ -547,7 +544,8 @@ def config_alpn_protocols(self, alpn_protocols):
547544
"""
548545
self._alpn_protocols = alpn_protocols
549546

550-
def reinitialise(self, client_id="", clean_session=True, userdata=None):
547+
# Closes socket in client destructor to avoid FD leak.
548+
def _reset_sockets(self):
551549
if self._ssl:
552550
self._ssl.close()
553551
self._ssl = None
@@ -562,6 +560,9 @@ def reinitialise(self, client_id="", clean_session=True, userdata=None):
562560
self._sockpairW.close()
563561
self._sockpairW = None
564562

563+
# Closes socket in client destructor to avoid FD leak.
564+
def reinitialise(self, client_id="", clean_session=True, userdata=None):
565+
self._reset_sockets()
565566
self.__init__(client_id, clean_session, userdata)
566567

567568
def tls_set(self, ca_certs, certfile=None, keyfile=None, cert_reqs=cert_reqs, tls_version=tls_version, ciphers=None):
@@ -831,24 +832,14 @@ def reconnect(self):
831832
verify_hostname = False # Since check_hostname in SSLContext is already set to True, no need to verify it again
832833
self._ssl.do_handshake()
833834
else:
834-
if force_ssl_context:
835-
ssl_context = ssl.SSLContext(self._tls_version)
836-
ssl_context.load_cert_chain(self._tls_certfile, self._tls_keyfile)
837-
ssl_context.load_verify_locations(self._tls_ca_certs)
838-
ssl_context.verify_mode = self._tls_cert_reqs
839-
if self._tls_ciphers is not None:
840-
ssl_context.set_ciphers(self._tls_ciphers)
841-
842-
self._ssl = ssl_context.wrap_socket(sock)
843-
else:
844-
self._ssl = ssl.wrap_socket(
845-
sock,
846-
certfile=self._tls_certfile,
847-
keyfile=self._tls_keyfile,
848-
ca_certs=self._tls_ca_certs,
849-
cert_reqs=self._tls_cert_reqs,
850-
ssl_version=self._tls_version,
851-
ciphers=self._tls_ciphers)
835+
# ssl.wrap_socket is deprecated in Python 3.7+. Use SSLContext instead.
836+
ssl_context = ssl.SSLContext(self._tls_version)
837+
ssl_context.load_cert_chain(self._tls_certfile, self._tls_keyfile)
838+
ssl_context.load_verify_locations(self._tls_ca_certs)
839+
ssl_context.verify_mode = self._tls_cert_reqs
840+
if self._tls_ciphers is not None:
841+
ssl_context.set_ciphers(self._tls_ciphers)
842+
self._ssl = ssl_context.wrap_socket(sock)
852843

853844
if verify_hostname:
854845
if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 5): # No IP host match before 3.5.x
@@ -924,6 +915,9 @@ def loop(self, timeout=1.0, max_packets=1):
924915
# Can occur if we just reconnected but rlist/wlist contain a -1 for
925916
# some reason.
926917
return MQTT_ERR_CONN_LOST
918+
except KeyboardInterrupt:
919+
# Allow ^C to interrupt
920+
raise
927921
except:
928922
return MQTT_ERR_UNKNOWN
929923

@@ -981,8 +975,9 @@ def publish(self, topic, payload=None, qos=0, retain=False):
981975
raise ValueError('Invalid QoS level.')
982976
if isinstance(payload, str) or isinstance(payload, bytearray):
983977
local_payload = payload
984-
elif sys.version_info[0] < 3 and isinstance(payload, unicode):
985-
local_payload = payload
978+
# Client.publish() now accepts bytes() payloads on Python 3.
979+
elif sys.version_info[0] == 3 and isinstance(payload, bytes):
980+
local_payload = bytearray(payload)
986981
elif isinstance(payload, int) or isinstance(payload, float):
987982
local_payload = str(payload)
988983
elif payload is None:
@@ -1047,10 +1042,15 @@ def username_pw_set(self, username, password=None):
10471042
Requires a broker that supports MQTT v3.1.
10481043
10491044
username: The username to authenticate with. Need have no relationship to the client id.
1045+
[MQTT-3.1.3-11].
1046+
Set to None to reset client back to not using username/password for broker authentication.
10501047
password: The password to authenticate with. Optional, set to None if not required.
10511048
"""
1052-
self._username = username.encode('utf-8')
1049+
# [MQTT-3.1.3-11] User name must be UTF-8 encoded string
1050+
self._username = None if username is None else username.encode('utf-8')
10531051
self._password = password
1052+
if isinstance(self._password, str):
1053+
self._password = self._password.encode('utf-8')
10541054

10551055
def socket_factory_set(self, socket_factory):
10561056
"""Set a socket factory to custom configure a different socket type for
@@ -1117,7 +1117,7 @@ def subscribe(self, topic, qos=0):
11171117
zero string length, or if topic is not a string, tuple or list.
11181118
"""
11191119
topic_qos_list = None
1120-
if isinstance(topic, str):
1120+
if isinstance(topic, str) :
11211121
if qos<0 or qos>2:
11221122
raise ValueError('Invalid QoS level.')
11231123
if topic is None or len(topic) == 0:
@@ -1165,7 +1165,7 @@ def unsubscribe(self, topic):
11651165
topic_list = None
11661166
if topic is None:
11671167
raise ValueError('Invalid topic.')
1168-
if isinstance(topic, str):
1168+
if isinstance(topic, str) :
11691169
if len(topic) == 0:
11701170
raise ValueError('Invalid topic.')
11711171
topic_list = [topic.encode('utf-8')]
@@ -1453,8 +1453,10 @@ def loop_stop(self, force=False):
14531453
return MQTT_ERR_INVAL
14541454

14551455
self._thread_terminate = True
1456-
self._thread.join()
1457-
self._thread = None
1456+
# Don't attempt to join() own thread.
1457+
if threading.current_thread() != self._thread:
1458+
self._thread.join()
1459+
self._thread = None
14581460

14591461
def message_callback_add(self, sub, callback):
14601462
"""Register a message callback for a specific topic.
@@ -1704,6 +1706,10 @@ def _easy_log(self, level, buf):
17041706
self.on_log(self, self._userdata, level, buf)
17051707

17061708
def _check_keepalive(self):
1709+
# Fix for keepalive=0 causing an infinite disconnect/reconnect loop.
1710+
if self._keepalive == 0:
1711+
return MQTT_ERR_SUCCESS
1712+
17071713
now = time.time()
17081714
self._msgtime_mutex.acquire()
17091715
last_msg_out = self._last_msg_out
@@ -1736,10 +1742,12 @@ def _check_keepalive(self):
17361742
self._callback_mutex.release()
17371743

17381744
def _mid_generate(self):
1739-
self._last_mid = self._last_mid + 1
1740-
if self._last_mid == 65536:
1741-
self._last_mid = 1
1742-
return self._last_mid
1745+
# Make sure mid generation that was thread-safe.
1746+
with self._mid_generate_mutex:
1747+
self._last_mid += 1
1748+
if self._last_mid == 65536:
1749+
self._last_mid = 1
1750+
return self._last_mid
17431751

17441752
def _topic_wildcard_len_check(self, topic):
17451753
# Search for + or # in a topic. Return MQTT_ERR_INVAL if found.
@@ -1903,11 +1911,11 @@ def _send_connect(self, keepalive, clean_session):
19031911
connect_flags = connect_flags | 0x04 | ((self._will_qos&0x03) << 3) | ((self._will_retain&0x01) << 5)
19041912

19051913
if self._username:
1906-
remaining_length = remaining_length + 2+len(self._username)
1914+
remaining_length += 2+len(self._username)
19071915
connect_flags = connect_flags | 0x80
19081916
if self._password:
19091917
connect_flags = connect_flags | 0x40
1910-
remaining_length = remaining_length + 2+len(self._password)
1918+
remaining_length += 2+len(self._password)
19111919

19121920
command = CONNECT
19131921
packet = bytearray()

0 commit comments

Comments
 (0)