Skip to content

Commit 77fd330

Browse files
Support message router (#265)
1 parent 4253d63 commit 77fd330

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

pulsar/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
"""
4444

4545
import logging
46-
from typing import List, Tuple, Optional, Union
46+
from typing import Callable, List, Tuple, Optional, Union
4747

4848
import _pulsar
4949

@@ -54,6 +54,7 @@
5454
from pulsar.__about__ import __version__
5555

5656
from pulsar.exceptions import *
57+
from pulsar.schema.schema import BytesSchema
5758
from pulsar.tableview import TableView
5859

5960
from pulsar.functions.function import Function
@@ -246,6 +247,7 @@ def schema_version(self):
246247
# Builds a Python-level Message around an already-existing underlying message object.
# A fresh Message() is created, its _schema is set to a BytesSchema() instance,
# its _message is set to the object passed in, and the wrapper is returned.
# NOTE(review): presumably BytesSchema makes value() hand back the raw payload
# bytes for wrapped messages - confirm against Message.value().
@staticmethod
247248
def _wrap(_message):
248249
self = Message()
250+
# Added by this commit: give wrapped messages a schema object up front.
self._schema = BytesSchema()
249251
self._message = _message
250252
return self
251253

@@ -696,6 +698,7 @@ def create_producer(self, topic,
696698
encryption_key=None,
697699
crypto_key_reader: Union[None, CryptoKeyReader] = None,
698700
access_mode: ProducerAccessMode = ProducerAccessMode.Shared,
701+
message_router: Callable[[Message, int], int]=None,
699702
):
700703
"""
701704
Create a new producer on a given topic.
@@ -811,6 +814,10 @@ def create_producer(self, topic,
811814
* WaitForExclusive: Producer creation is pending until it can acquire exclusive access.
812815
* ExclusiveWithFencing: Acquire exclusive access for the producer.
813816
Any existing producer will be removed and invalidated immediately.
817+
message_router: optional
818+
A custom message router function that takes a `Message` and the number of partitions
819+
and returns the partition index to which the message should be routed. If not provided,
820+
the default routing policy defined by `message_routing_mode` will be used.
814821
"""
815822
_check_type(str, topic, 'topic')
816823
_check_type_or_none(str, producer_name, 'producer_name')
@@ -848,6 +855,10 @@ def create_producer(self, topic,
848855
conf.chunking_enabled(chunking_enabled)
849856
conf.lazy_start_partitioned_producers(lazy_start_partitioned_producers)
850857
conf.access_mode(access_mode)
858+
if message_router is not None:
859+
underlying_router = lambda msg, num_partitions: int(message_router(Message._wrap(msg), num_partitions))
860+
conf.message_router(underlying_router)
861+
851862
if producer_name:
852863
conf.producer_name(producer_name)
853864
if initial_sequence_id:

src/config.cc

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
#include <pybind11/functional.h>
2727
#include <pybind11/pybind11.h>
2828
#include <pybind11/stl.h>
29+
#include <functional>
2930
#include <memory>
3031

3132
namespace py = pybind11;
@@ -104,6 +105,19 @@ class HIDDEN LoggerWrapperFactory : public LoggerFactory, public CaptivePythonOb
104105
}
105106
};
106107

108+
// Signature of a routing callback: it receives the message and the number of
// partitions, and returns an int that is used as the partition to route to.
using MessageRouterFunc = std::function<int(const Message&, int)>;
109+
// MessageRoutingPolicy implementation that delegates every routing decision
// to a stored MessageRouterFunc.
class HIDDEN MessageRouter : public pulsar::MessageRoutingPolicy {
110+
public:
111+
// Takes the callback by value and moves it into func_.
explicit MessageRouter(MessageRouterFunc func) : func_(std::move(func)) {}
112+
113+
// Calls the stored callback with the message and
// topicMetadata.getNumPartitions(), and returns its result unchanged.
// No range check on the returned index is performed here.
int getPartition(const Message& msg, const TopicMetadata& topicMetadata) final {
114+
return func_(msg, topicMetadata.getNumPartitions());
115+
}
116+
117+
private:
118+
// The routing callback invoked by getPartition().
MessageRouterFunc func_;
119+
};
120+
107121
static ClientConfiguration& ClientConfiguration_setLogger(ClientConfiguration& conf, py::object logger) {
108122
conf.setLogger(new LoggerWrapperFactory(logger));
109123
return conf;
@@ -235,7 +249,10 @@ void export_config(py::module_& m) {
235249
.def("encryption_key", &ProducerConfiguration::addEncryptionKey, return_value_policy::reference)
236250
.def("crypto_key_reader", &ProducerConfiguration::setCryptoKeyReader, return_value_policy::reference)
237251
.def("access_mode", &ProducerConfiguration::setAccessMode, return_value_policy::reference)
238-
.def("access_mode", &ProducerConfiguration::getAccessMode, return_value_policy::copy);
252+
.def("access_mode", &ProducerConfiguration::getAccessMode, return_value_policy::copy)
253+
.def("message_router", [](ProducerConfiguration& config, MessageRouterFunc func) {
254+
config.setMessageRouter(std::make_shared<MessageRouter>(std::move(func)));
255+
});
239256

240257
class_<BatchReceivePolicy>(m, "BatchReceivePolicy")
241258
.def(init<int, int, long>())

tests/pulsar_test.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2019,5 +2019,37 @@ def test_deserialize_msg_id_with_topic(self):
20192019
self.assertEqual(msg.value(), b'msg-3')
20202020
client.close()
20212021

2022+
# Verifies that a custom `message_router` passed to create_producer decides
# which partition each sent message is delivered to.
def test_message_router(self):
2023+
# The timestamp suffix gives each run its own topic name.
topic_name = "public/default/test_message_router" + str(time.time())
2024+
url1 = self.adminUrl + "/admin/v2/persistent/" + topic_name + "/partitions"
2025+
# PUT "5" to the topic's /partitions admin endpoint.
# NOTE(review): presumably this creates the topic with 5 partitions - confirm.
doHttpPut(url1, "5")
2026+
client = Client(self.serviceUrl)
2027+
# Router under test: payloads starting with "hello-" map to 10 % num_partitions,
# everything else maps to 11 % num_partitions (0 and 1 when num_partitions is 5).
def router(msg: pulsar.Message, num_partitions: int):
2028+
s = msg.value().decode('utf-8')
2029+
if s.startswith("hello-"):
2030+
return 10 % num_partitions
2031+
else:
2032+
return 11 % num_partitions
2033+
producer = client.create_producer(topic_name, message_router=router)
2034+
producer.send(b"hello-0")
2035+
producer.send(b"hello-1")
2036+
producer.send(b"world-0")
2037+
producer.send(b"world-1")
2038+
# Subscribe from the earliest position so the four messages sent above are received.
consumer = client.subscribe(topic_name, 'sub',
2039+
initial_position=InitialPosition.Earliest)
2040+
# Maps partition index -> list of decoded payloads, in the order they were received.
partition_to_values = dict()
2041+
for _ in range(4):
2042+
msg = consumer.receive(TM)
2043+
partition = msg.message_id().partition()
2044+
if partition in partition_to_values:
2045+
partition_to_values[partition].append(msg.value().decode('utf-8'))
2046+
else:
2047+
partition_to_values[partition] = [msg.value().decode('utf-8')]
2048+
# "hello-*" payloads are expected on partition 0 and "world-*" payloads on partition 1.
self.assertEqual(partition_to_values[0], ["hello-0", "hello-1"])
2049+
self.assertEqual(partition_to_values[1], ["world-0", "world-1"])
2050+
2051+
client.close()
2052+
2053+
20222054
if __name__ == "__main__":
20232055
main()

0 commit comments

Comments
 (0)