diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImpl.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImpl.java index bee90b17fd9c..034ee74892f2 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImpl.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImpl.java @@ -492,7 +492,7 @@ ShareFetchCollector build( public Set subscription() { acquireAndEnsureOpen(); try { - return subscriptions.subscription(); + return Collections.unmodifiableSet(subscriptions.subscription()); } finally { release(); } @@ -594,7 +594,6 @@ public synchronized ConsumerRecords poll(final Duration timeout) { return ConsumerRecords.empty(); } finally { kafkaShareConsumerMetrics.recordPollEnd(timer.currentTimeMs()); - wakeupTrigger.clearTask(); release(); } } @@ -612,6 +611,8 @@ private ShareFetch pollForFetches(final Timer timer) { // Wait a bit - this is where we will fetch records Timer pollTimer = time.timer(pollTimeout); + wakeupTrigger.setShareFetchAction(fetchBuffer); + try { fetchBuffer.awaitNotEmpty(pollTimer); } catch (InterruptException e) { @@ -619,6 +620,7 @@ private ShareFetch pollForFetches(final Timer timer) { throw e; } finally { timer.update(pollTimer.currentTimeMs()); + wakeupTrigger.clearTask(); } return collect(Collections.emptyMap()); diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/WakeupTrigger.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/WakeupTrigger.java index 4c70797dbf84..7893cf29f23b 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/WakeupTrigger.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/WakeupTrigger.java @@ -53,6 +53,10 @@ public void wakeup() { FetchAction fetchAction = (FetchAction) task; fetchAction.fetchBuffer().wakeup(); return new WakeupFuture(); + } else if (task instanceof ShareFetchAction) { + ShareFetchAction shareFetchAction = (ShareFetchAction) task; + shareFetchAction.fetchBuffer().wakeup(); + return new WakeupFuture(); } else { return task; } @@ -60,11 +64,10 @@ public void wakeup() { } /** - * If there is no pending task, set the pending task active. - * If wakeup was called before setting an active task, the current task will complete exceptionally with - * WakeupException right - * away. - * if there is an active task, throw exception. + * If there is no pending task, set the pending task active. + * If wakeup was called before setting an active task, the current task will complete exceptionally with + * WakeupException right away. + * If there is an active task, throw exception. * @param currentTask * @param * @return @@ -105,6 +108,25 @@ public void setFetchAction(final FetchBuffer fetchBuffer) { } } + public void setShareFetchAction(final ShareFetchBuffer fetchBuffer) { + final AtomicBoolean throwWakeupException = new AtomicBoolean(false); + pendingTask.getAndUpdate(task -> { + if (task == null) { + return new ShareFetchAction(fetchBuffer); + } else if (task instanceof WakeupFuture) { + throwWakeupException.set(true); + return null; + } else if (task instanceof DisabledWakeups) { + return task; + } + // last active state is still active + throw new IllegalStateException("Last active task is still active"); + }); + if (throwWakeupException.get()) { + throw new WakeupException(); + } + } + public void disableWakeups() { pendingTask.set(new DisabledWakeups()); } @@ -113,7 +135,7 @@ public void clearTask() { pendingTask.getAndUpdate(task -> { if (task == null) { return null; - } else if (task instanceof ActiveFuture || task instanceof FetchAction) { + } else if (task instanceof ActiveFuture || task instanceof FetchAction || task instanceof ShareFetchAction) { return null; } return task; @@ -172,4 +194,17 @@ public FetchBuffer fetchBuffer() { return fetchBuffer; } } + + static class ShareFetchAction implements Wakeupable { + + private final ShareFetchBuffer fetchBuffer; + + public ShareFetchAction(ShareFetchBuffer fetchBuffer) { + this.fetchBuffer = fetchBuffer; + } + + public ShareFetchBuffer fetchBuffer() { + return fetchBuffer; + } + } } diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImplTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImplTest.java index fc42b3fe5bd9..b8cd5f88b1aa 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImplTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumerImplTest.java @@ -191,6 +191,45 @@ public void testWakeupBeforeCallingPoll() { assertDoesNotThrow(() -> consumer.poll(Duration.ZERO)); } + @Test + public void testWakeupAfterEmptyFetch() { + consumer = newConsumer(); + final String topicName = "foo"; + final int partition = 3; + doAnswer(invocation -> { + consumer.wakeup(); + return ShareFetch.empty(); + }).doAnswer(invocation -> ShareFetch.empty()).when(fetchCollector).collect(any(ShareFetchBuffer.class)); + + consumer.subscribe(singletonList(topicName)); + + assertThrows(WakeupException.class, () -> consumer.poll(Duration.ofMinutes(1))); + assertDoesNotThrow(() -> consumer.poll(Duration.ZERO)); + } + + @Test + public void testWakeupAfterNonEmptyFetch() { + consumer = newConsumer(); + final String topicName = "foo"; + final int partition = 3; + final TopicIdPartition tip = new TopicIdPartition(Uuid.randomUuid(), partition, topicName); + final ShareInFlightBatch batch = new ShareInFlightBatch<>(tip); + batch.addRecord(new ConsumerRecord<>(topicName, partition, 2, "key1", "value1")); + doAnswer(invocation -> { + consumer.wakeup(); + final ShareFetch fetch = ShareFetch.empty(); + fetch.add(tip, batch); + return fetch; + }).when(fetchCollector).collect(Mockito.any(ShareFetchBuffer.class)); + + consumer.subscribe(singletonList(topicName)); + + // since wakeup() is called when the non-empty fetch is returned the wakeup should be ignored + assertDoesNotThrow(() -> consumer.poll(Duration.ofMinutes(1))); + // the previously ignored wake-up should not be ignored in the next call + assertThrows(WakeupException.class, () -> consumer.poll(Duration.ZERO)); + } + @Test public void testFailOnClosedConsumer() { consumer = newConsumer(); diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/WakeupTriggerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/WakeupTriggerTest.java index 5b63badba73b..518f1cc6978d 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/WakeupTriggerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/WakeupTriggerTest.java @@ -110,6 +110,19 @@ public void testWakeupFromFetchAction() { } } + @Test + public void testWakeupFromShareFetchAction() { + try (final ShareFetchBuffer fetchBuffer = mock(ShareFetchBuffer.class)) { + wakeupTrigger.setShareFetchAction(fetchBuffer); + + wakeupTrigger.wakeup(); + + verify(fetchBuffer).wakeup(); + final WakeupTrigger.Wakeupable wakeupable = wakeupTrigger.getPendingTask(); + assertInstanceOf(WakeupTrigger.WakeupFuture.class, wakeupable); + } + } + @Test public void testManualTriggerWhenWakeupCalled() { wakeupTrigger.wakeup();