Skip to content

Commit 52a4c65

Browse files
committed
Support more types in publish payload
1 parent e3d8f58 commit 52a4c65

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

ohmqtt/client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def shutdown(self) -> None:
122122
def publish(
123123
self,
124124
topic: str,
125-
payload: bytes,
125+
payload: bytes | bytearray | str,
126126
*,
127127
qos: int | MQTTQoS = MQTTQoS.Q0,
128128
retain: bool = False,
@@ -132,12 +132,16 @@ def publish(
132132
"""Publish a message to a topic.
133133
134134
:param topic: The topic to publish to.
135-
:param payload: The payload of the message.
135+
:param payload: The payload of the message. If a string is provided, it will be encoded as UTF-8.
136136
:param qos: The QoS level for the message (0, 1, or 2).
137137
:param retain: If True, the message will be retained by the broker.
138138
:param properties: Properties for the PUBLISH packet.
139139
:param alias_policy: The policy for using automatic topic aliases.
140140
"""
141+
if isinstance(payload, str):
142+
payload = payload.encode("utf-8")
143+
elif not isinstance(payload, bytes):
144+
payload = bytes(payload)
141145
if not isinstance(qos, MQTTQoS):
142146
qos = MQTTQoS(qos)
143147
properties = properties if properties is not None else None

tests/test_client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,16 @@ def test_client_shutdown(mocker: MockerFixture, mock_connection: Mock, mock_hand
7777
mock_connection.shutdown.assert_called_once()
7878

7979

80+
@pytest.mark.parametrize("payload", [b"test_payload", bytearray(b"test_payload"), "test_payload"])
8081
@pytest.mark.parametrize("qos", [0, 1, 2, MQTTQoS.Q0, MQTTQoS.Q1, MQTTQoS.Q2])
81-
def test_client_publish(qos: int | MQTTQoS, mocker: MockerFixture, mock_connection: Mock, mock_handlers: MagicMock,
82-
mock_session: Mock, mock_subscriptions: Mock) -> None:
82+
def test_client_publish(payload: bytes | bytearray | str, qos: int | MQTTQoS, mocker: MockerFixture,
83+
mock_connection: Mock, mock_handlers: MagicMock, mock_session: Mock, mock_subscriptions: Mock) -> None:
8384
client = Client()
8485

8586
mock_session.publish.return_value = mocker.Mock()
8687
publish_handle = client.publish(
8788
"test/topic",
88-
b"test_payload",
89+
payload,
8990
qos=qos,
9091
retain=True,
9192
properties=MQTTPublishProps(
@@ -98,9 +99,10 @@ def test_client_publish(qos: int | MQTTQoS, mocker: MockerFixture, mock_connecti
9899
)
99100
assert publish_handle == mock_session.publish.return_value
100101
expected_qos = MQTTQoS(qos) if not isinstance(qos, MQTTQoS) else qos
102+
expected_payload = payload.encode("utf-8") if isinstance(payload, str) else bytes(payload)
101103
mock_session.publish.assert_called_once_with(
102104
"test/topic",
103-
b"test_payload",
105+
expected_payload,
104106
qos=expected_qos,
105107
retain=True,
106108
properties=MQTTPublishProps(

0 commit comments

Comments
 (0)