Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,11 @@ void disconnect(boolean error, MqttMessage disconnect) {
if (disconnect != null && disconnect.variableHeader() instanceof MqttReasonCodeAndPropertiesVariableHeader) {
Integer sessionExpiryInterval = MQTTUtil.getProperty(Integer.class, ((MqttReasonCodeAndPropertiesVariableHeader)disconnect.variableHeader()).properties(), SESSION_EXPIRY_INTERVAL, null);
if (sessionExpiryInterval != null) {
session.getState().setClientSessionExpiryInterval(sessionExpiryInterval);
try {
session.getState().setClientSessionExpiryInterval(sessionExpiryInterval);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
session.getConnectionManager().disconnect(error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ public MQTTPublishManager(MQTTSession session, boolean closeMqttConnectionOnPubl
this.closeMqttConnectionOnPublishAuthorizationFailure = closeMqttConnectionOnPublishAuthorizationFailure;
}

public static SimpleString getQoS2ManagementAddressName(SimpleString clientId) {
return SimpleString.of(MQTTUtil.QOS2_MANAGEMENT_QUEUE_PREFIX + clientId);
}

synchronized void start() {
this.state = session.getState();
this.outboundStore = state.getOutboundStore();
Expand Down Expand Up @@ -315,7 +319,7 @@ void handlePubRec(int messageId) throws Exception {
*/
private void initQos2Resources() throws Exception {
if (qos2ManagementAddress == null) {
qos2ManagementAddress = SimpleString.of(MQTTUtil.QOS2_MANAGEMENT_QUEUE_PREFIX + session.getState().getClientId());
qos2ManagementAddress = MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(session.getState().getClientId()));
}
if (qos2ManagementQueue == null) {
qos2ManagementQueue = session.getServer().createQueue(QueueConfiguration.of(qos2ManagementAddress)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public MQTTSessionState(String clientId) {
* <li>byte: version
* <li>int: subscription count
* </ul>
* There may be 0 or more subscriptions. The subscription format is as follows.
* There may be 0 or more subscriptions. The subscription format is as follows.
* <ul>
* <li>String: topic name
* <li>int: QoS
Expand All @@ -116,6 +116,10 @@ public MQTTSessionState(String clientId) {
* <li>int: retain handling
* <li>int (nullable): subscription identifier
* </ul>
* After the subscriptions there is:
* <ul>
* <li>int (nullable): session expiry interval
* </ul>
*
* @param message the message holding the MQTT session data
*/
Expand All @@ -139,8 +143,17 @@ public MQTTSessionState(CoreMessage message) {

subscriptions.put(topicName, new SubscriptionItem(new MqttTopicSubscription(topicName, new MqttSubscriptionOption(qos, nolocal, retainAsPublished, retainedHandlingPolicy)), subscriptionId));
}
if (buf.readable()) {
clientSessionExpiryInterval = buf.readNullableInt();
} else {
// this is for old records where we don't know the session expiry interval and can't risk removing a subscription illegitimately
clientSessionExpiryInterval = session.getProtocolManager().getDefaultMqttSessionExpiryInterval();
}
disconnectedTime = System.currentTimeMillis();
}

// TODO: create a way to send a message in the old format in order to test upgrade functionality

public MQTTSession getSession() {
return session;
}
Expand Down Expand Up @@ -266,7 +279,10 @@ public int getClientSessionExpiryInterval() {
return clientSessionExpiryInterval;
}

public void setClientSessionExpiryInterval(int sessionExpiryInterval) {
public void setClientSessionExpiryInterval(int sessionExpiryInterval) throws Exception {
if (session != null && (sessionExpiryInterval != 0 || this.clientSessionExpiryInterval != 0)) {
session.getStateManager().storeDurableSessionState(this);
}
this.clientSessionExpiryInterval = sessionExpiryInterval;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.activemq.artemis.api.core.Message;
import org.apache.activemq.artemis.api.core.QueueConfiguration;
import org.apache.activemq.artemis.api.core.RoutingType;
import org.apache.activemq.artemis.api.core.SimpleString;
import org.apache.activemq.artemis.core.filter.impl.FilterImpl;
import org.apache.activemq.artemis.core.message.impl.CoreMessage;
import org.apache.activemq.artemis.core.persistence.StorageManager;
Expand Down Expand Up @@ -112,7 +113,7 @@ public void scanSessions() {
MQTTSessionState state = entry.getValue();
logger.debug("Inspecting session: {}", state);
int sessionExpiryInterval = state.getClientSessionExpiryInterval();
if (!state.isAttached() && sessionExpiryInterval > 0 && state.getDisconnectedTime() + (sessionExpiryInterval * 1000) < System.currentTimeMillis()) {
if (!state.isAttached() && (sessionExpiryInterval == 0 || (sessionExpiryInterval > 0 && state.getDisconnectedTime() + (sessionExpiryInterval * 1000) < System.currentTimeMillis()))) {
toRemove.add(entry.getKey());
}
if (state.isWill() && !state.isAttached() && state.isFailed() && state.getWillDelayInterval() > 0 && state.getDisconnectedTime() + (state.getWillDelayInterval() * 1000) < System.currentTimeMillis()) {
Expand All @@ -121,13 +122,26 @@ public void scanSessions() {
}

for (String key : toRemove) {
logger.info("Removing expired session: {}", key);
try {
MQTTSessionState state = removeSessionState(key);
if (state != null) {
if (state.isWill() && !state.isAttached() && state.isFailed()) {
state.getSession().sendWillMessage();
}
state.getSession().clean(false);
MQTTSession session = state.getSession();
if (session != null) {
session.clean(false);
} else {
// if the in-memory session doesn't exist, then we need to ensure that any other state is cleaned up
for (MqttTopicSubscription mqttTopicSubscription : state.getSubscriptions()) {
MQTTSubscriptionManager.cleanSubscriptionQueue(mqttTopicSubscription.topicFilter(), state.getClientId(), server, (q) -> server.destroyQueue(q, null, true, false, true));
}
Queue qos2ManagementQueue = server.locateQueue(MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(state.getClientId())));
if (qos2ManagementQueue != null) {
qos2ManagementQueue.deleteQueue();
}
}
}
} catch (Exception e) {
MQTTLogger.LOGGER.failedToRemoveSessionState(key, e);
Expand All @@ -154,15 +168,15 @@ public MQTTSessionState removeSessionState(String clientId) throws Exception {
}
MQTTSessionState removed = sessionStates.remove(clientId);
if (removed != null && removed.getSubscriptions().size() > 0) {
removeDurableSubscriptionState(clientId);
removeDurableSessionState(clientId);
}
return removed;
}

public void removeDurableSubscriptionState(String clientId) throws Exception {
public void removeDurableSessionState(String clientId) throws Exception {
if (subscriptionPersistenceEnabled) {
int deletedCount = sessionStore.deleteMatchingReferences(FilterImpl.createFilter(new StringBuilder(Message.HDR_LAST_VALUE_NAME).append(" = '").append(clientId).append("'").toString()));
logger.debug("Removed {} durable MQTT subscription record(s) for: {}", deletedCount, clientId);
logger.debug("Removed {} durable MQTT session record(s) for: {}", deletedCount, clientId);
}
}

Expand All @@ -175,14 +189,22 @@ public String toString() {
return "MQTTSessionStateManager@" + Integer.toHexString(System.identityHashCode(this));
}

public void storeDurableSubscriptionState(MQTTSessionState state) throws Exception {
public void storeDurableSessionState(MQTTSessionState state) throws Exception {
if (subscriptionPersistenceEnabled) {
logger.debug("Adding durable MQTT subscription record for: {}", state.getClientId());
logger.debug("Adding durable MQTT session record for: {}", state.getClientId());
StorageManager storageManager = server.getStorageManager();
MQTTUtil.sendMessageDirectlyToQueue(storageManager, server.getPostOffice(), serializeState(state, storageManager.generateID()), sessionStore, null);
}
}

public long getDurableSessionStateCount() {
if (subscriptionPersistenceEnabled) {
return sessionStore.getMessageCount();
} else {
return 0;
}
}

public static CoreMessage serializeState(MQTTSessionState state, long messageID) {
CoreMessage message = new CoreMessage().initBuffer(50).setMessageID(messageID);
message.setAddress(MQTTUtil.MQTT_SESSION_STORE);
Expand All @@ -209,6 +231,8 @@ public static CoreMessage serializeState(MQTTSessionState state, long messageID)
buf.writeNullableInt(item.getId());
}

buf.writeNullableInt(state.getClientSessionExpiryInterval());

return message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.activemq.artemis.api.core.RoutingType;
import org.apache.activemq.artemis.api.core.SimpleString;
import org.apache.activemq.artemis.core.server.ActiveMQMessageBundle;
import org.apache.activemq.artemis.core.server.ActiveMQServer;
import org.apache.activemq.artemis.core.server.BindingQueryResult;
import org.apache.activemq.artemis.core.server.Queue;
import org.apache.activemq.artemis.core.server.ServerConsumer;
Expand Down Expand Up @@ -265,15 +266,7 @@ short[] removeSubscriptions(List<String> topics, boolean enforceSecurity) throws
consumerQoSLevels.remove(removed.getID());
}

SimpleString internalQueueName = SimpleString.of(MQTTUtil.getCoreQueueFromMqttTopic(topics.get(i), state.getClientId(), session.getServer().getConfiguration().getWildcardConfiguration()));
Queue queue = session.getServer().locateQueue(internalQueueName);
if (queue != null) {
if (queue.isConfigurationManaged()) {
queue.deleteAllReferences();
} else if (!MQTTUtil.isSharedSubscription(topics.get(i)) || (MQTTUtil.isSharedSubscription(topics.get(i)) && queue.getConsumerCount() == 0)) {
session.getServerSession().deleteQueue(internalQueueName, enforceSecurity);
}
}
cleanSubscriptionQueue(topics.get(i), state.getClientId(), session.getServer(), (q) -> session.getServerSession().deleteQueue(q, enforceSecurity));
} catch (Exception e) {
MQTTLogger.LOGGER.errorRemovingSubscription(e);
reasonCode = MQTTReasonCodes.UNSPECIFIED_ERROR;
Expand All @@ -285,16 +278,28 @@ short[] removeSubscriptions(List<String> topics, boolean enforceSecurity) throws
// deal with durable state after *all* requested subscriptions have been removed in memory
if (state.getSubscriptions().size() > 0) {
// if there are some subscriptions left then update the state
stateManager.storeDurableSubscriptionState(state);
stateManager.storeDurableSessionState(state);
} else {
// if there are no subscriptions left then remove the state entirely
stateManager.removeDurableSubscriptionState(state.getClientId());
stateManager.removeDurableSessionState(state.getClientId());
}
}

return reasonCodes;
}

public static void cleanSubscriptionQueue(String topic, String clientId, ActiveMQServer server, SubscriptionQueueDeleter<SimpleString> deleter) throws Exception {
SimpleString internalQueueName = SimpleString.of(MQTTUtil.getCoreQueueFromMqttTopic(topic, clientId, server.getConfiguration().getWildcardConfiguration()));
Queue queue = server.locateQueue(internalQueueName);
if (queue != null) {
if (queue.isConfigurationManaged()) {
queue.deleteAllReferences();
} else if (!MQTTUtil.isSharedSubscription(topic) || (MQTTUtil.isSharedSubscription(topic) && queue.getConsumerCount() == 0)) {
deleter.delete(internalQueueName);
}
}
}

/**
* As per MQTT Spec. Subscribes this client to a number of MQTT topics.
*
Expand Down Expand Up @@ -338,7 +343,7 @@ int[] addSubscriptions(List<MqttTopicSubscription> subscriptions, Integer subscr
}

// store state after *all* requested subscriptions have been created in memory
stateManager.storeDurableSubscriptionState(state);
stateManager.storeDurableSessionState(state);

return qos;
}
Expand All @@ -355,4 +360,9 @@ void clean(boolean enforceSecurity) throws Exception {
}
removeSubscriptions(topics, enforceSecurity);
}

@FunctionalInterface
public interface SubscriptionQueueDeleter<SimpleString> {
void delete(SimpleString q) throws Exception;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ public class StateSerDeTest {
@Timeout(30)
public void testSerDe() throws Exception {
for (int i = 0; i < 500; i++) {
String clientId = RandomUtil.randomUUIDString();
MQTTSessionState unserialized = new MQTTSessionState(clientId);
MQTTSessionState unserialized = new MQTTSessionState(RandomUtil.randomUUIDString());
Integer subscriptionIdentifier = RandomUtil.randomPositiveIntOrNull();
for (int j = 0; j < RandomUtil.randomInterval(1, 50); j++) {
MqttTopicSubscription sub = new MqttTopicSubscription(RandomUtil.randomUUIDString(),
Expand All @@ -47,6 +46,8 @@ public void testSerDe() throws Exception {
unserialized.addSubscription(sub, MQTTUtil.MQTT_WILDCARD, subscriptionIdentifier);
}

unserialized.setClientSessionExpiryInterval(RandomUtil.randomInt());

CoreMessage serializedState = MQTTStateManager.serializeState(unserialized, 0);
MQTTSessionState deserialized = new MQTTSessionState(serializedState);

Expand All @@ -61,6 +62,7 @@ public void testSerDe() throws Exception {
assertTrue(compareSubs(unserializedSub, deserializedSub));
assertEquals(unserializedSubId, deserializedSubId);
}
assertEquals(unserialized.getClientSessionExpiryInterval(), deserialized.getClientSessionExpiryInterval());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.activemq.artemis.core.postoffice.impl.PostOfficeTestAccessor;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTInterceptor;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTProtocolManager;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTPublishManager;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTReasonCodes;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSessionAccessor;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSessionState;
Expand Down Expand Up @@ -364,20 +365,53 @@ public void testWillFlagFalseWithSessionExpiryDelay() throws Exception {

@Test
@Timeout(DEFAULT_TIMEOUT_SEC)
public void testQueueCleanOnRestart() throws Exception {
public void testResourceCleanUpOnRestartWithNonZeroSessionExpiryInterval() throws Exception {
testResourceCleanUpOnRestartWithSessionExpiryInterval(2);
}

@Test
@Timeout(DEFAULT_TIMEOUT_SEC)
public void testResourceCleanUpOnRestartWithZeroSessionExpiryInterval() throws Exception {
testResourceCleanUpOnRestartWithSessionExpiryInterval(0);
}

private void testResourceCleanUpOnRestartWithSessionExpiryInterval(long sessionExpiryInterval) throws Exception {
String topic = RandomUtil.randomUUIDString();
String clientId = RandomUtil.randomUUIDString();
CountDownLatch latch = new CountDownLatch(1);

MqttClient client = createPahoClient(clientId);
client.setCallback(new LatchedMqttCallback(latch));
MqttConnectionOptions options = new MqttConnectionOptionsBuilder()
.sessionExpiryInterval(999L)
.sessionExpiryInterval(sessionExpiryInterval)
.cleanStart(true)
.build();
client.connect(options);
client.subscribe(topic, AT_LEAST_ONCE);
client.subscribe(topic, EXACTLY_ONCE);
client.publish(topic, new byte[0], EXACTLY_ONCE, true);
assertTrue(latch.await(2, TimeUnit.SECONDS));
assertNotNull(server.locateQueue(MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(clientId))));
assertEquals(1, getProtocolManager().getStateManager().getDurableSessionStateCount());
server.stop();
try {
client.disconnect();
} catch (MqttException e) {
// ignore
}
client.close();
server.start();
org.apache.activemq.artemis.tests.util.Wait.assertTrue(() -> getSubscriptionQueue(topic, clientId) != null, 3000, 10);
scanSessions();
if (sessionExpiryInterval > 0) {
assertNotNull(getSubscriptionQueue(topic, clientId));
Wait.assertNull(() -> {
scanSessions();
return getSubscriptionQueue(topic, clientId);
}, sessionExpiryInterval * 2 * 1000, 25);
} else {
assertNull(getSubscriptionQueue(topic, clientId));
}
assertNull(server.locateQueue(MQTTPublishManager.getQoS2ManagementAddressName(SimpleString.of(clientId))));
assertEquals(0, getProtocolManager().getStateManager().getDurableSessionStateCount());
}

@Test
Expand Down