4949from AWSIoTPythonSDK .core .protocol .connection .cores import SecuredWebSocketCore
5050from AWSIoTPythonSDK .core .protocol .connection .alpn import SSLContextBuilder
5151
52- VERSION_MAJOR = 1
53- VERSION_MINOR = 0
54- VERSION_REVISION = 0
55- VERSION_NUMBER = (VERSION_MAJOR * 1000000 + VERSION_MINOR * 1000 + VERSION_REVISION )
56-
5752MQTTv31 = 3
5853MQTTv311 = 4
5954
@@ -497,6 +492,7 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT
497492 self ._msgtime_mutex = threading .Lock ()
498493 self ._out_message_mutex = threading .Lock ()
499494 self ._in_message_mutex = threading .Lock ()
495+ self ._mid_generate_mutex = threading .Lock ()
500496 self ._thread = None
501497 self ._thread_terminate = False
502498 self ._ssl = None
@@ -515,7 +511,8 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT
515511 self ._alpn_protocols = None
516512
517513 def __del__ (self ):
518- pass
514+ # Closes socket in client destructor to avoid FD leak.
515+ self ._reset_sockets ()
519516
520517
521518 def setBackoffTiming (self , srcBaseReconnectTimeSecond , srcMaximumReconnectTimeSecond , srcMinimumConnectTimeSecond ):
@@ -547,7 +544,8 @@ def config_alpn_protocols(self, alpn_protocols):
547544 """
548545 self ._alpn_protocols = alpn_protocols
549546
550- def reinitialise (self , client_id = "" , clean_session = True , userdata = None ):
547+ # Closes socket in client destructor to avoid FD leak.
548+ def _reset_sockets (self ):
551549 if self ._ssl :
552550 self ._ssl .close ()
553551 self ._ssl = None
@@ -562,6 +560,9 @@ def reinitialise(self, client_id="", clean_session=True, userdata=None):
562560 self ._sockpairW .close ()
563561 self ._sockpairW = None
564562
563+ # Closes socket in client destructor to avoid FD leak.
564+ def reinitialise (self , client_id = "" , clean_session = True , userdata = None ):
565+ self ._reset_sockets ()
565566 self .__init__ (client_id , clean_session , userdata )
566567
567568 def tls_set (self , ca_certs , certfile = None , keyfile = None , cert_reqs = cert_reqs , tls_version = tls_version , ciphers = None ):
@@ -831,24 +832,14 @@ def reconnect(self):
831832 verify_hostname = False # Since check_hostname in SSLContext is already set to True, no need to verify it again
832833 self ._ssl .do_handshake ()
833834 else :
834- if force_ssl_context :
835- ssl_context = ssl .SSLContext (self ._tls_version )
836- ssl_context .load_cert_chain (self ._tls_certfile , self ._tls_keyfile )
837- ssl_context .load_verify_locations (self ._tls_ca_certs )
838- ssl_context .verify_mode = self ._tls_cert_reqs
839- if self ._tls_ciphers is not None :
840- ssl_context .set_ciphers (self ._tls_ciphers )
841-
842- self ._ssl = ssl_context .wrap_socket (sock )
843- else :
844- self ._ssl = ssl .wrap_socket (
845- sock ,
846- certfile = self ._tls_certfile ,
847- keyfile = self ._tls_keyfile ,
848- ca_certs = self ._tls_ca_certs ,
849- cert_reqs = self ._tls_cert_reqs ,
850- ssl_version = self ._tls_version ,
851- ciphers = self ._tls_ciphers )
835+ # ssl.wrap_socket is deprecated in Python 3.7+. Use SSLContext instead.
836+ ssl_context = ssl .SSLContext (self ._tls_version )
837+ ssl_context .load_cert_chain (self ._tls_certfile , self ._tls_keyfile )
838+ ssl_context .load_verify_locations (self ._tls_ca_certs )
839+ ssl_context .verify_mode = self ._tls_cert_reqs
840+ if self ._tls_ciphers is not None :
841+ ssl_context .set_ciphers (self ._tls_ciphers )
842+ self ._ssl = ssl_context .wrap_socket (sock )
852843
853844 if verify_hostname :
854845 if sys .version_info [0 ] < 3 or (sys .version_info [0 ] == 3 and sys .version_info [1 ] < 5 ): # No IP host match before 3.5.x
@@ -924,6 +915,9 @@ def loop(self, timeout=1.0, max_packets=1):
924915 # Can occur if we just reconnected but rlist/wlist contain a -1 for
925916 # some reason.
926917 return MQTT_ERR_CONN_LOST
918+ except KeyboardInterrupt :
919+ # Allow ^C to interrupt
920+ raise
927921 except :
928922 return MQTT_ERR_UNKNOWN
929923
@@ -981,8 +975,9 @@ def publish(self, topic, payload=None, qos=0, retain=False):
981975 raise ValueError ('Invalid QoS level.' )
982976 if isinstance (payload , str ) or isinstance (payload , bytearray ):
983977 local_payload = payload
984- elif sys .version_info [0 ] < 3 and isinstance (payload , unicode ):
985- local_payload = payload
978+ # Client.publish() now accepts bytes() payloads on Python 3.
979+ elif sys .version_info [0 ] == 3 and isinstance (payload , bytes ):
980+ local_payload = bytearray (payload )
986981 elif isinstance (payload , int ) or isinstance (payload , float ):
987982 local_payload = str (payload )
988983 elif payload is None :
@@ -1047,10 +1042,15 @@ def username_pw_set(self, username, password=None):
10471042 Requires a broker that supports MQTT v3.1.
10481043
10491044 username: The username to authenticate with. Need have no relationship to the client id.
1045+ [MQTT-3.1.3-11].
1046+ Set to None to reset client back to not using username/password for broker authentication.
10501047 password: The password to authenticate with. Optional, set to None if not required.
10511048 """
1052- self ._username = username .encode ('utf-8' )
1049+ # [MQTT-3.1.3-11] User name must be UTF-8 encoded string
1050+ self ._username = None if username is None else username .encode ('utf-8' )
10531051 self ._password = password
1052+ if isinstance (self ._password , str ):
1053+ self ._password = self ._password .encode ('utf-8' )
10541054
10551055 def socket_factory_set (self , socket_factory ):
10561056 """Set a socket factory to custom configure a different socket type for
@@ -1117,7 +1117,7 @@ def subscribe(self, topic, qos=0):
11171117 zero string length, or if topic is not a string, tuple or list.
11181118 """
11191119 topic_qos_list = None
1120- if isinstance (topic , str ):
1120+ if isinstance (topic , str ) :
11211121 if qos < 0 or qos > 2 :
11221122 raise ValueError ('Invalid QoS level.' )
11231123 if topic is None or len (topic ) == 0 :
@@ -1165,7 +1165,7 @@ def unsubscribe(self, topic):
11651165 topic_list = None
11661166 if topic is None :
11671167 raise ValueError ('Invalid topic.' )
1168- if isinstance (topic , str ):
1168+ if isinstance (topic , str ) :
11691169 if len (topic ) == 0 :
11701170 raise ValueError ('Invalid topic.' )
11711171 topic_list = [topic .encode ('utf-8' )]
@@ -1453,8 +1453,10 @@ def loop_stop(self, force=False):
14531453 return MQTT_ERR_INVAL
14541454
14551455 self ._thread_terminate = True
1456- self ._thread .join ()
1457- self ._thread = None
1456+ # Don't attempt to join() own thread.
1457+ if threading .current_thread () != self ._thread :
1458+ self ._thread .join ()
1459+ self ._thread = None
14581460
14591461 def message_callback_add (self , sub , callback ):
14601462 """Register a message callback for a specific topic.
@@ -1704,6 +1706,10 @@ def _easy_log(self, level, buf):
17041706 self .on_log (self , self ._userdata , level , buf )
17051707
17061708 def _check_keepalive (self ):
1709+ # Fix for keepalive=0 causing an infinite disconnect/reconnect loop.
1710+ if self ._keepalive == 0 :
1711+ return MQTT_ERR_SUCCESS
1712+
17071713 now = time .time ()
17081714 self ._msgtime_mutex .acquire ()
17091715 last_msg_out = self ._last_msg_out
@@ -1736,10 +1742,12 @@ def _check_keepalive(self):
17361742 self ._callback_mutex .release ()
17371743
17381744 def _mid_generate (self ):
1739- self ._last_mid = self ._last_mid + 1
1740- if self ._last_mid == 65536 :
1741- self ._last_mid = 1
1742- return self ._last_mid
1745+ # Make sure mid generation that was thread-safe.
1746+ with self ._mid_generate_mutex :
1747+ self ._last_mid += 1
1748+ if self ._last_mid == 65536 :
1749+ self ._last_mid = 1
1750+ return self ._last_mid
17431751
17441752 def _topic_wildcard_len_check (self , topic ):
17451753 # Search for + or # in a topic. Return MQTT_ERR_INVAL if found.
@@ -1903,11 +1911,11 @@ def _send_connect(self, keepalive, clean_session):
19031911 connect_flags = connect_flags | 0x04 | ((self ._will_qos & 0x03 ) << 3 ) | ((self ._will_retain & 0x01 ) << 5 )
19041912
19051913 if self ._username :
1906- remaining_length = remaining_length + 2 + len (self ._username )
1914+ remaining_length += 2 + len (self ._username )
19071915 connect_flags = connect_flags | 0x80
19081916 if self ._password :
19091917 connect_flags = connect_flags | 0x40
1910- remaining_length = remaining_length + 2 + len (self ._password )
1918+ remaining_length += 2 + len (self ._password )
19111919
19121920 command = CONNECT
19131921 packet = bytearray ()
0 commit comments