diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java index 071df5c5fb5..f28329bcc76 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java @@ -276,7 +276,11 @@ void disconnect(boolean error, MqttMessage disconnect) { if (disconnect != null && disconnect.variableHeader() instanceof MqttReasonCodeAndPropertiesVariableHeader) { Integer sessionExpiryInterval = MQTTUtil.getProperty(Integer.class, ((MqttReasonCodeAndPropertiesVariableHeader)disconnect.variableHeader()).properties(), SESSION_EXPIRY_INTERVAL, null); if (sessionExpiryInterval != null) { - session.getState().setClientSessionExpiryInterval(sessionExpiryInterval); + try { + session.getState().setClientSessionExpiryInterval(sessionExpiryInterval); + } catch (Exception e) { + throw new RuntimeException(e); + } } } session.getConnectionManager().disconnect(error); diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java index 4f5fec62a56..33dfbd8807e 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java @@ -102,6 +102,10 @@ public MQTTPublishManager(MQTTSession session, boolean closeMqttConnectionOnPubl this.closeMqttConnectionOnPublishAuthorizationFailure = closeMqttConnectionOnPublishAuthorizationFailure; } + public static SimpleString getQoS2ManagementAddressName(SimpleString clientId) { + return SimpleString.of(MQTTUtil.QOS2_MANAGEMENT_QUEUE_PREFIX + clientId); + } + synchronized void start() { this.state = session.getState(); this.outboundStore = state.getOutboundStore(); @@ -315,7 +319,7 @@ void handlePubRec(int messageId) throws Exception { */ private void initQos2Resources() throws Exception { if (qos2ManagementAddress == null) { - qos2ManagementAddress = SimpleString.of(MQTTUtil.QOS2_MANAGEMENT_QUEUE_PREFIX + session.getState().getClientId()); + qos2ManagementAddress = MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(session.getState().getClientId())); } if (qos2ManagementQueue == null) { qos2ManagementQueue = session.getServer().createQueue(QueueConfiguration.of(qos2ManagementAddress) diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java index 9ec877b51f9..a089d26b1e3 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java @@ -107,7 +107,7 @@ public MQTTSessionState(String clientId) { *
  • byte: version *
  • int: subscription count * - * There may be 0 or more subscriptions. The subscription format is as follows. + * There may be 0 or more subscriptions. The subscription format is as follows. * + * After the subscriptions there is: + * * * @param message the message holding the MQTT session data */ @@ -139,8 +143,17 @@ public MQTTSessionState(CoreMessage message) { subscriptions.put(topicName, new SubscriptionItem(new MqttTopicSubscription(topicName, new MqttSubscriptionOption(qos, nolocal, retainAsPublished, retainedHandlingPolicy)), subscriptionId)); } + if (buf.readable()) { + clientSessionExpiryInterval = buf.readNullableInt(); + } else { + // this is for old records where we don't know the session expiry interval and can't risk removing a subscription illegitimately + clientSessionExpiryInterval = session.getProtocolManager().getDefaultMqttSessionExpiryInterval(); + } + disconnectedTime = System.currentTimeMillis(); } + // TODO: create a way to send a message in the old format in order to test upgrade functionality + public MQTTSession getSession() { return session; } @@ -266,7 +279,10 @@ public int getClientSessionExpiryInterval() { return clientSessionExpiryInterval; } - public void setClientSessionExpiryInterval(int sessionExpiryInterval) { + public void setClientSessionExpiryInterval(int sessionExpiryInterval) throws Exception { + if (session != null && (sessionExpiryInterval != 0 || this.clientSessionExpiryInterval != 0)) { + session.getStateManager().storeDurableSessionState(this); + } this.clientSessionExpiryInterval = sessionExpiryInterval; } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java index 700d6f49782..55b50a25f35 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java @@ -31,6 +31,7 @@ import org.apache.activemq.artemis.api.core.Message; import org.apache.activemq.artemis.api.core.QueueConfiguration; import org.apache.activemq.artemis.api.core.RoutingType; +import org.apache.activemq.artemis.api.core.SimpleString; import org.apache.activemq.artemis.core.filter.impl.FilterImpl; import org.apache.activemq.artemis.core.message.impl.CoreMessage; import org.apache.activemq.artemis.core.persistence.StorageManager; @@ -112,7 +113,7 @@ public void scanSessions() { MQTTSessionState state = entry.getValue(); logger.debug("Inspecting session: {}", state); int sessionExpiryInterval = state.getClientSessionExpiryInterval(); - if (!state.isAttached() && sessionExpiryInterval > 0 && state.getDisconnectedTime() + (sessionExpiryInterval * 1000) < System.currentTimeMillis()) { + if (!state.isAttached() && (sessionExpiryInterval == 0 || (sessionExpiryInterval > 0 && state.getDisconnectedTime() + (sessionExpiryInterval * 1000) < System.currentTimeMillis()))) { toRemove.add(entry.getKey()); } if (state.isWill() && !state.isAttached() && state.isFailed() && state.getWillDelayInterval() > 0 && state.getDisconnectedTime() + (state.getWillDelayInterval() * 1000) < System.currentTimeMillis()) { @@ -121,13 +122,26 @@ public void scanSessions() { } for (String key : toRemove) { + logger.info("Removing expired session: {}", key); try { MQTTSessionState state = removeSessionState(key); if (state != null) { if (state.isWill() && !state.isAttached() && state.isFailed()) { state.getSession().sendWillMessage(); } - state.getSession().clean(false); + MQTTSession session = state.getSession(); + if (session != null) { + session.clean(false); + } else { + // if the in-memory session doesn't exist, then we need to ensure that any other state is cleaned up + for (MqttTopicSubscription mqttTopicSubscription : state.getSubscriptions()) { + MQTTSubscriptionManager.cleanSubscriptionQueue(mqttTopicSubscription.topicFilter(), state.getClientId(), server, (q) -> server.destroyQueue(q, null, true, false, true)); + } + Queue qos2ManagementQueue = server.locateQueue(MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(state.getClientId()))); + if (qos2ManagementQueue != null) { + qos2ManagementQueue.deleteQueue(); + } + } } } catch (Exception e) { MQTTLogger.LOGGER.failedToRemoveSessionState(key, e); @@ -154,15 +168,15 @@ public MQTTSessionState removeSessionState(String clientId) throws Exception { } MQTTSessionState removed = sessionStates.remove(clientId); if (removed != null && removed.getSubscriptions().size() > 0) { - removeDurableSubscriptionState(clientId); + removeDurableSessionState(clientId); } return removed; } - public void removeDurableSubscriptionState(String clientId) throws Exception { + public void removeDurableSessionState(String clientId) throws Exception { if (subscriptionPersistenceEnabled) { int deletedCount = sessionStore.deleteMatchingReferences(FilterImpl.createFilter(new StringBuilder(Message.HDR_LAST_VALUE_NAME).append(" = '").append(clientId).append("'").toString())); - logger.debug("Removed {} durable MQTT subscription record(s) for: {}", deletedCount, clientId); + logger.debug("Removed {} durable MQTT session record(s) for: {}", deletedCount, clientId); } } @@ -175,14 +189,22 @@ public String toString() { return "MQTTSessionStateManager@" + Integer.toHexString(System.identityHashCode(this)); } - public void storeDurableSubscriptionState(MQTTSessionState state) throws Exception { + public void storeDurableSessionState(MQTTSessionState state) throws Exception { if (subscriptionPersistenceEnabled) { - logger.debug("Adding durable MQTT subscription record for: {}", state.getClientId()); + logger.debug("Adding durable MQTT session record for: {}", state.getClientId()); StorageManager storageManager = server.getStorageManager(); MQTTUtil.sendMessageDirectlyToQueue(storageManager, server.getPostOffice(), serializeState(state, storageManager.generateID()), sessionStore, null); } } + public long getDurableSessionStateCount() { + if (subscriptionPersistenceEnabled) { + return sessionStore.getMessageCount(); + } else { + return 0; + } + } + public static CoreMessage serializeState(MQTTSessionState state, long messageID) { CoreMessage message = new CoreMessage().initBuffer(50).setMessageID(messageID); message.setAddress(MQTTUtil.MQTT_SESSION_STORE); @@ -209,6 +231,8 @@ public static CoreMessage serializeState(MQTTSessionState state, long messageID) buf.writeNullableInt(item.getId()); } + buf.writeNullableInt(state.getClientSessionExpiryInterval()); + return message; } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java index 16a188d47bd..6b133dfd9e5 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java @@ -33,6 +33,7 @@ import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.api.core.SimpleString; import org.apache.activemq.artemis.core.server.ActiveMQMessageBundle; +import org.apache.activemq.artemis.core.server.ActiveMQServer; import org.apache.activemq.artemis.core.server.BindingQueryResult; import org.apache.activemq.artemis.core.server.Queue; import org.apache.activemq.artemis.core.server.ServerConsumer; @@ -265,15 +266,7 @@ short[] removeSubscriptions(List topics, boolean enforceSecurity) throws consumerQoSLevels.remove(removed.getID()); } - SimpleString internalQueueName = SimpleString.of(MQTTUtil.getCoreQueueFromMqttTopic(topics.get(i), state.getClientId(), session.getServer().getConfiguration().getWildcardConfiguration())); - Queue queue = session.getServer().locateQueue(internalQueueName); - if (queue != null) { - if (queue.isConfigurationManaged()) { - queue.deleteAllReferences(); - } else if (!MQTTUtil.isSharedSubscription(topics.get(i)) || (MQTTUtil.isSharedSubscription(topics.get(i)) && queue.getConsumerCount() == 0)) { - session.getServerSession().deleteQueue(internalQueueName, enforceSecurity); - } - } + cleanSubscriptionQueue(topics.get(i), state.getClientId(), session.getServer(), (q) -> session.getServerSession().deleteQueue(q, enforceSecurity)); } catch (Exception e) { MQTTLogger.LOGGER.errorRemovingSubscription(e); reasonCode = MQTTReasonCodes.UNSPECIFIED_ERROR; @@ -285,16 +278,28 @@ short[] removeSubscriptions(List topics, boolean enforceSecurity) throws // deal with durable state after *all* requested subscriptions have been removed in memory if (state.getSubscriptions().size() > 0) { // if there are some subscriptions left then update the state - stateManager.storeDurableSubscriptionState(state); + stateManager.storeDurableSessionState(state); } else { // if there are no subscriptions left then remove the state entirely - stateManager.removeDurableSubscriptionState(state.getClientId()); + stateManager.removeDurableSessionState(state.getClientId()); } } return reasonCodes; } + public static void cleanSubscriptionQueue(String topic, String clientId, ActiveMQServer server, SubscriptionQueueDeleter deleter) throws Exception { + SimpleString internalQueueName = SimpleString.of(MQTTUtil.getCoreQueueFromMqttTopic(topic, clientId, server.getConfiguration().getWildcardConfiguration())); + Queue queue = server.locateQueue(internalQueueName); + if (queue != null) { + if (queue.isConfigurationManaged()) { + queue.deleteAllReferences(); + } else if (!MQTTUtil.isSharedSubscription(topic) || (MQTTUtil.isSharedSubscription(topic) && queue.getConsumerCount() == 0)) { + deleter.delete(internalQueueName); + } + } + } + /** * As per MQTT Spec. Subscribes this client to a number of MQTT topics. * @@ -338,7 +343,7 @@ int[] addSubscriptions(List subscriptions, Integer subscr } // store state after *all* requested subscriptions have been created in memory - stateManager.storeDurableSubscriptionState(state); + stateManager.storeDurableSessionState(state); return qos; } @@ -355,4 +360,9 @@ void clean(boolean enforceSecurity) throws Exception { } removeSubscriptions(topics, enforceSecurity); } + + @FunctionalInterface + public interface SubscriptionQueueDeleter { + void delete(SimpleString q) throws Exception; + } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/StateSerDeTest.java b/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/StateSerDeTest.java index 1d5630e99a2..7b4c7763749 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/StateSerDeTest.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/StateSerDeTest.java @@ -35,8 +35,7 @@ public class StateSerDeTest { @Timeout(30) public void testSerDe() throws Exception { for (int i = 0; i < 500; i++) { - String clientId = RandomUtil.randomUUIDString(); - MQTTSessionState unserialized = new MQTTSessionState(clientId); + MQTTSessionState unserialized = new MQTTSessionState(RandomUtil.randomUUIDString()); Integer subscriptionIdentifier = RandomUtil.randomPositiveIntOrNull(); for (int j = 0; j < RandomUtil.randomInterval(1, 50); j++) { MqttTopicSubscription sub = new MqttTopicSubscription(RandomUtil.randomUUIDString(), @@ -47,6 +46,8 @@ public void testSerDe() throws Exception { unserialized.addSubscription(sub, MQTTUtil.MQTT_WILDCARD, subscriptionIdentifier); } + unserialized.setClientSessionExpiryInterval(RandomUtil.randomInt()); + CoreMessage serializedState = MQTTStateManager.serializeState(unserialized, 0); MQTTSessionState deserialized = new MQTTSessionState(serializedState); @@ -61,6 +62,7 @@ public void testSerDe() throws Exception { assertTrue(compareSubs(unserializedSub, deserializedSub)); assertEquals(unserializedSubId, deserializedSubId); } + assertEquals(unserialized.getClientSessionExpiryInterval(), deserialized.getClientSessionExpiryInterval()); } } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java index 83e70929578..3afcbd17c58 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java @@ -41,6 +41,7 @@ import org.apache.activemq.artemis.core.postoffice.impl.PostOfficeTestAccessor; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTInterceptor; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTProtocolManager; +import org.apache.activemq.artemis.core.protocol.mqtt.MQTTPublishManager; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTReasonCodes; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSessionAccessor; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSessionState; @@ -364,20 +365,53 @@ public void testWillFlagFalseWithSessionExpiryDelay() throws Exception { @Test @Timeout(DEFAULT_TIMEOUT_SEC) - public void testQueueCleanOnRestart() throws Exception { + public void testResourceCleanUpOnRestartWithNonZeroSessionExpiryInterval() throws Exception { + testResourceCleanUpOnRestartWithSessionExpiryInterval(2); + } + + @Test + @Timeout(DEFAULT_TIMEOUT_SEC) + public void testResourceCleanUpOnRestartWithZeroSessionExpiryInterval() throws Exception { + testResourceCleanUpOnRestartWithSessionExpiryInterval(0); + } + + private void testResourceCleanUpOnRestartWithSessionExpiryInterval(long sessionExpiryInterval) throws Exception { String topic = RandomUtil.randomUUIDString(); String clientId = RandomUtil.randomUUIDString(); + CountDownLatch latch = new CountDownLatch(1); MqttClient client = createPahoClient(clientId); + client.setCallback(new LatchedMqttCallback(latch)); MqttConnectionOptions options = new MqttConnectionOptionsBuilder() - .sessionExpiryInterval(999L) + .sessionExpiryInterval(sessionExpiryInterval) .cleanStart(true) .build(); client.connect(options); - client.subscribe(topic, AT_LEAST_ONCE); + client.subscribe(topic, EXACTLY_ONCE); + client.publish(topic, new byte[0], EXACTLY_ONCE, true); + assertTrue(latch.await(2, TimeUnit.SECONDS)); + assertNotNull(server.locateQueue(MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(clientId)))); + assertEquals(1, getProtocolManager().getStateManager().getDurableSessionStateCount()); server.stop(); + try { + client.disconnect(); + } catch (MqttException e) { + // ignore + } + client.close(); server.start(); - org.apache.activemq.artemis.tests.util.Wait.assertTrue(() -> getSubscriptionQueue(topic, clientId) != null, 3000, 10); + scanSessions(); + if (sessionExpiryInterval > 0) { + assertNotNull(getSubscriptionQueue(topic, clientId)); + Wait.assertNull(() -> { + scanSessions(); + return getSubscriptionQueue(topic, clientId); + }, sessionExpiryInterval * 2 * 1000, 25); + } else { + assertNull(getSubscriptionQueue(topic, clientId)); + } + assertNull(server.locateQueue(MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(clientId)))); + assertEquals(0, getProtocolManager().getStateManager().getDurableSessionStateCount()); } @Test