diff --git a/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCall.java b/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCall.java index 325218dc3..0436bac62 100644 --- a/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCall.java +++ b/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCall.java @@ -20,19 +20,17 @@ import com.baidu.bifromq.basekv.store.proto.RWCoProcInput; import com.baidu.bifromq.basekv.store.proto.RWCoProcOutput; import com.baidu.bifromq.basescheduler.ICallTask; -import com.baidu.bifromq.retain.rpc.proto.BatchRetainRequest; -import com.baidu.bifromq.retain.rpc.proto.RetainMessage; -import com.baidu.bifromq.retain.rpc.proto.RetainParam; import com.baidu.bifromq.retain.rpc.proto.RetainReply; import com.baidu.bifromq.retain.rpc.proto.RetainRequest; import com.baidu.bifromq.retain.rpc.proto.RetainResult; import com.baidu.bifromq.retain.rpc.proto.RetainServiceRWCoProcInput; import java.time.Duration; -import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Queue; +import lombok.extern.slf4j.Slf4j; +@Slf4j public class BatchRetainCall extends BatchMutationCall { protected BatchRetainCall(KVRangeId rangeId, @@ -43,21 +41,9 @@ protected BatchRetainCall(KVRangeId rangeId, @Override protected RWCoProcInput makeBatch(Iterator retainRequestIterator) { - Map retainMsgPackBuilders = new HashMap<>(128); - retainRequestIterator.forEachRemaining(request -> - retainMsgPackBuilders.computeIfAbsent(request.getPublisher().getTenantId(), k -> RetainParam.newBuilder() - .putTopicMessages(request.getTopic(), RetainMessage.newBuilder() - .setMessage(request.getMessage().toBuilder().setIsRetained(true).build()) - .setPublisher(request.getPublisher()) - .build()))); - long reqId = System.nanoTime(); - BatchRetainRequest.Builder reqBuilder = BatchRetainRequest.newBuilder().setReqId(reqId); - retainMsgPackBuilders.forEach((tenantId, retainMsgPackBuilder) -> - reqBuilder.putParams(tenantId, retainMsgPackBuilder.build())); - return RWCoProcInput.newBuilder() .setRetainService(RetainServiceRWCoProcInput.newBuilder() - .setBatchRetain(reqBuilder.build()) + .setBatchRetain(BatchRetainCallHelper.makeBatch(retainRequestIterator)) .build()).build(); } @@ -68,12 +54,22 @@ protected void handleOutput(Queue resultMap = output.getRetainService() .getBatchRetain() - .getResultsMap() - .getOrDefault(task.call().getPublisher().getTenantId(), - RetainResult.getDefaultInstance()) - .getResultsOrDefault(task.call().getTopic(), RetainResult.Code.ERROR); + .getResultsMap(); + RetainResult topicMap = resultMap.get(task.call().getPublisher().getTenantId()); + if (topicMap == null) { + log.error("tenantId not found in result map, tenantId: {}", task.call().getPublisher().getTenantId()); + task.resultPromise().complete(replyBuilder.setResult(RetainReply.Result.ERROR).build()); + continue; + } + RetainResult.Code result = topicMap.getResultsMap().get(task.call().getTopic()); + if (result == null) { + log.error("topic not found in result map, tenantId: {}, topic: {}", + task.call().getPublisher().getTenantId(), task.call().getTopic()); + task.resultPromise().complete(replyBuilder.setResult(RetainReply.Result.ERROR).build()); + continue; + } switch (result) { case RETAINED -> replyBuilder.setResult(RetainReply.Result.RETAINED); case CLEARED -> replyBuilder.setResult(RetainReply.Result.CLEARED); diff --git a/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCallHelper.java b/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCallHelper.java new file mode 100644 index 000000000..3ea523f64 --- /dev/null +++ b/bifromq-retain/bifromq-retain-server/src/main/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCallHelper.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. + */ + +package com.baidu.bifromq.retain.server.scheduler; + +import com.baidu.bifromq.retain.rpc.proto.BatchRetainRequest; +import com.baidu.bifromq.retain.rpc.proto.RetainMessage; +import com.baidu.bifromq.retain.rpc.proto.RetainParam; +import com.baidu.bifromq.retain.rpc.proto.RetainRequest; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +class BatchRetainCallHelper { + static BatchRetainRequest makeBatch(Iterator retainRequestIterator) { + Map retainMsgPackBuilders = new HashMap<>(128); + retainRequestIterator.forEachRemaining(request -> + retainMsgPackBuilders.computeIfAbsent(request.getPublisher().getTenantId(), k -> RetainParam.newBuilder()) + .putTopicMessages(request.getTopic(), RetainMessage.newBuilder() + .setMessage(request.getMessage().toBuilder().setIsRetained(true).build()) + .setPublisher(request.getPublisher()) + .build())); + long reqId = System.nanoTime(); + BatchRetainRequest.Builder reqBuilder = BatchRetainRequest.newBuilder().setReqId(reqId); + retainMsgPackBuilders.forEach((tenantId, retainMsgPackBuilder) -> + reqBuilder.putParams(tenantId, retainMsgPackBuilder.build())); + return reqBuilder.build(); + } +} diff --git a/bifromq-retain/bifromq-retain-server/src/test/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCallHelperTest.java b/bifromq-retain/bifromq-retain-server/src/test/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCallHelperTest.java new file mode 100644 index 000000000..1464a8c34 --- /dev/null +++ b/bifromq-retain/bifromq-retain-server/src/test/java/com/baidu/bifromq/retain/server/scheduler/BatchRetainCallHelperTest.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. + */ + +package com.baidu.bifromq.retain.server.scheduler; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import com.baidu.bifromq.retain.rpc.proto.BatchRetainRequest; +import com.baidu.bifromq.retain.rpc.proto.RetainMessage; +import com.baidu.bifromq.retain.rpc.proto.RetainParam; +import com.baidu.bifromq.retain.rpc.proto.RetainRequest; +import com.baidu.bifromq.type.ClientInfo; +import com.baidu.bifromq.type.Message; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class BatchRetainCallHelperTest { + + private RetainRequest retainRequest1; + private RetainRequest retainRequest11; + private RetainRequest retainRequest2; + private Iterator retainRequestIterator; + private Message message1; + private Message message11; + private Message message2; + private ClientInfo publisher1; + private ClientInfo publisher11; + private ClientInfo publisher2; + + @BeforeMethod + public void setUp() { + message1 = Message.newBuilder().build(); + message11 = Message.newBuilder().build(); + message2 = Message.newBuilder().build(); + publisher1 = ClientInfo.newBuilder().setTenantId("tenant1").build(); + publisher11 = ClientInfo.newBuilder().setTenantId("tenant1").build(); + publisher2 = ClientInfo.newBuilder().setTenantId("tenant2").build(); + retainRequest1 = RetainRequest.newBuilder() + .setTopic("topic1") + .setMessage(message1) + .setPublisher(publisher1) + .build(); + retainRequest11 = RetainRequest.newBuilder() + .setTopic("topic11") + .setMessage(message11) + .setPublisher(publisher11) + .build(); + retainRequest2 = RetainRequest.newBuilder() + .setTopic("topic2") + .setMessage(message2) + .setPublisher(publisher2) + .build(); + + // Create the iterator + retainRequestIterator = Arrays.asList(retainRequest1, retainRequest11, retainRequest2).iterator(); + } + + @Test + public void testMakeBatch() { + BatchRetainRequest result = BatchRetainCallHelper.makeBatch(retainRequestIterator); + + assertNotNull(result); + assertEquals(result.getParamsCount(), 2); + + Map params = new HashMap<>(result.getParamsMap()); + assertTrue(params.containsKey("tenant1")); + assertTrue(params.containsKey("tenant2")); + + RetainParam tenant1Param = params.get("tenant1"); + RetainParam tenant2Param = params.get("tenant2"); + + assertTrue(tenant1Param.containsTopicMessages("topic1")); + assertTrue(tenant1Param.containsTopicMessages("topic11")); + assertTrue(tenant2Param.containsTopicMessages("topic2")); + + // Verify that the message was marked as retained + RetainMessage tenant1Message = tenant1Param.getTopicMessagesOrThrow("topic1"); + RetainMessage tenant11Message = tenant1Param.getTopicMessagesOrThrow("topic11"); + RetainMessage tenant2Message = tenant2Param.getTopicMessagesOrThrow("topic2"); + + assertTrue(tenant1Message.getMessage().getIsRetained()); + assertTrue(tenant11Message.getMessage().getIsRetained()); + assertTrue(tenant2Message.getMessage().getIsRetained()); + } +} \ No newline at end of file