Skip to content

Support message router #265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pulsar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"""

import logging
from typing import List, Tuple, Optional, Union
from typing import Callable, List, Tuple, Optional, Union

import _pulsar

Expand All @@ -54,6 +54,7 @@
from pulsar.__about__ import __version__

from pulsar.exceptions import *
from pulsar.schema.schema import BytesSchema
from pulsar.tableview import TableView

from pulsar.functions.function import Function
Expand Down Expand Up @@ -246,6 +247,7 @@ def schema_version(self):
@staticmethod
def _wrap(_message):
    """
    Build a `Message` around an existing underlying message object.

    The new instance is given a `BytesSchema` as its `_schema` and keeps
    `_message` as its `_message` attribute.
    """
    wrapped = Message()
    wrapped._schema = BytesSchema()
    wrapped._message = _message
    return wrapped

Expand Down Expand Up @@ -696,6 +698,7 @@ def create_producer(self, topic,
encryption_key=None,
crypto_key_reader: Union[None, CryptoKeyReader] = None,
access_mode: ProducerAccessMode = ProducerAccessMode.Shared,
message_router: Callable[[Message, int], int]=None,
):
"""
Create a new producer on a given topic.
Expand Down Expand Up @@ -811,6 +814,10 @@ def create_producer(self, topic,
* WaitForExclusive: Producer creation is pending until it can acquire exclusive access.
* ExclusiveWithFencing: Acquire exclusive access for the producer.
Any existing producer will be removed and invalidated immediately.
message_router: optional
A custom message router function that takes a `Message` and the number of partitions
and returns the partition index to which the message should be routed. If not provided,
the default routing policy defined by `message_routing_mode` will be used.
"""
_check_type(str, topic, 'topic')
_check_type_or_none(str, producer_name, 'producer_name')
Expand Down Expand Up @@ -848,6 +855,10 @@ def create_producer(self, topic,
conf.chunking_enabled(chunking_enabled)
conf.lazy_start_partitioned_producers(lazy_start_partitioned_producers)
conf.access_mode(access_mode)
if message_router is not None:
underlying_router = lambda msg, num_partitions: int(message_router(Message._wrap(msg), num_partitions))
conf.message_router(underlying_router)

if producer_name:
conf.producer_name(producer_name)
if initial_sequence_id:
Expand Down
19 changes: 18 additions & 1 deletion src/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <functional>
#include <memory>

namespace py = pybind11;
Expand Down Expand Up @@ -104,6 +105,19 @@ class HIDDEN LoggerWrapperFactory : public LoggerFactory, public CaptivePythonOb
}
};

/// Callable used for custom routing: it is handed the message and the number of partitions
/// reported by the topic metadata, and its int result is returned from getPartition().
using MessageRouterFunc = std::function<int(const Message&, int)>;

/// Adapts a MessageRouterFunc to the pulsar::MessageRoutingPolicy interface.
class HIDDEN MessageRouter : public pulsar::MessageRoutingPolicy {
   public:
    /// Takes ownership of the routing callable by moving it into the adapter.
    explicit MessageRouter(MessageRouterFunc router) : func_(std::move(router)) {}

    /// Forwards the message and the topic's partition count to the stored callable
    /// and returns whatever partition index the callable produces.
    int getPartition(const Message& msg, const TopicMetadata& topicMetadata) final {
        const auto numPartitions = topicMetadata.getNumPartitions();
        return func_(msg, numPartitions);
    }

   private:
    MessageRouterFunc func_;
};

static ClientConfiguration& ClientConfiguration_setLogger(ClientConfiguration& conf, py::object logger) {
conf.setLogger(new LoggerWrapperFactory(logger));
return conf;
Expand Down Expand Up @@ -235,7 +249,10 @@ void export_config(py::module_& m) {
.def("encryption_key", &ProducerConfiguration::addEncryptionKey, return_value_policy::reference)
.def("crypto_key_reader", &ProducerConfiguration::setCryptoKeyReader, return_value_policy::reference)
.def("access_mode", &ProducerConfiguration::setAccessMode, return_value_policy::reference)
.def("access_mode", &ProducerConfiguration::getAccessMode, return_value_policy::copy);
.def("access_mode", &ProducerConfiguration::getAccessMode, return_value_policy::copy)
.def("message_router", [](ProducerConfiguration& config, MessageRouterFunc func) {
config.setMessageRouter(std::make_shared<MessageRouter>(std::move(func)));
});

class_<BatchReceivePolicy>(m, "BatchReceivePolicy")
.def(init<int, int, long>())
Expand Down
32 changes: 32 additions & 0 deletions tests/pulsar_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2019,5 +2019,37 @@ def test_deserialize_msg_id_with_topic(self):
self.assertEqual(msg.value(), b'msg-3')
client.close()

def test_message_router(self):
    """
    Verify that a custom `message_router` decides the partition of each message.

    A 5-partition topic is created, four messages are published through a router
    that maps "hello-*" payloads to `10 % num_partitions` and all other payloads to
    `11 % num_partitions`, and the received messages are grouped by partition.
    """
    topic_name = "public/default/test_message_router" + str(time.time())
    partitions_url = self.adminUrl + "/admin/v2/persistent/" + topic_name + "/partitions"
    doHttpPut(partitions_url, "5")
    client = Client(self.serviceUrl)

    def router(msg: pulsar.Message, num_partitions: int):
        payload = msg.value().decode('utf-8')
        base = 10 if payload.startswith("hello-") else 11
        return base % num_partitions

    producer = client.create_producer(topic_name, message_router=router)
    for payload in (b"hello-0", b"hello-1", b"world-0", b"world-1"):
        producer.send(payload)

    consumer = client.subscribe(topic_name, 'sub',
                                initial_position=InitialPosition.Earliest)

    # Group the decoded payloads by the partition they were delivered from.
    partition_to_values = {}
    for _ in range(4):
        msg = consumer.receive(TM)
        partition_to_values.setdefault(msg.message_id().partition(), []).append(
            msg.value().decode('utf-8'))

    self.assertEqual(partition_to_values[0], ["hello-0", "hello-1"])
    self.assertEqual(partition_to_values[1], ["world-0", "world-1"])

    client.close()


if __name__ == "__main__":
main()