diff --git a/src/include/qmdnsengine/cache.h b/src/include/qmdnsengine/cache.h index 0e6a969..6e8b983 100644 --- a/src/include/qmdnsengine/cache.h +++ b/src/include/qmdnsengine/cache.h @@ -78,8 +78,8 @@ class QMDNSENGINE_EXPORT Cache : public QObject * @param record add this record to the cache * * The TTL for the record will be added to the current time to calculate - * when the record expires. Existing records of the same name and type - * will be replaced, resetting their expiration. + * when the record expires. Call invalidateRecord() before addRecord() + * to ensure any superseded records are removed. */ void addRecord(const Record &record); @@ -105,6 +105,15 @@ class QMDNSENGINE_EXPORT Cache : public QObject */ bool lookupRecords(const QByteArray &name, quint16 type, QList &records) const; + /** + * @brief Invalidates the specified record in the cache + * @param record invalidate this record in the cache + * + * This must be called for all records in a message prior to adding + * any records to the cache to ensure the case of multiple records with + * the same type and 'flush cache' set is handled properly. + */ + void invalidateRecord(const Record &record); Q_SIGNALS: /** diff --git a/src/src/browser.cpp b/src/src/browser.cpp index 2f12012..1748b62 100644 --- a/src/src/browser.cpp +++ b/src/src/browser.cpp @@ -117,6 +117,13 @@ void BrowserPrivate::onMessageReceived(const Message &message) return; } + // Invalidate each record in the cache first. This ensures + // that we properly handle the case where we have multiple + // records of the same type with 'flush cache' set. + foreach (Record record, message.records()) { + cache->invalidateRecord(record); + } + // Use a set to track all services that are updated in the message to // prevent unnecessary queries for SRV and TXT records QSet updateNames; diff --git a/src/src/cache.cpp b/src/src/cache.cpp index 18d3128..c8d4d6a 100644 --- a/src/src/cache.cpp +++ b/src/src/cache.cpp @@ -92,28 +92,9 @@ Cache::Cache(QObject *parent) void Cache::addRecord(const Record &record) { - // If a record exists that matches, remove it from the cache; if the TTL - // is nonzero, it will be added back to the cache with updated times - for (auto i = d->entries.begin(); i != d->entries.end();) { - if ((record.flushCache() && - (*i).record.name() == record.name() && - (*i).record.type() == record.type()) || - (*i).record == record) { - - // If the TTL is set to 0, indicate that the record was removed - if (record.ttl() == 0) { - emit recordExpired((*i).record); - } - - i = d->entries.erase(i); - - // No need to continue further if the TTL was set to 0 - if (record.ttl() == 0) { - return; - } - } else { - ++i; - } + // No need to add anything if the TTL was set to 0 + if (record.ttl() == 0) { + return; } // Use the current time to calculate the triggers and add a random offset @@ -139,6 +120,27 @@ void Cache::addRecord(const Record &record) } } +void Cache::invalidateRecord(const Record &record) +{ + // If a record exists that matches, remove it from the cache + for (auto i = d->entries.begin(); i != d->entries.end();) { + if ((record.flushCache() && + (*i).record.name() == record.name() && + (*i).record.type() == record.type()) || + (*i).record == record) { + + // If the TTL is set to 0, indicate that the record was removed + if (record.ttl() == 0) { + emit recordExpired((*i).record); + } + + i = d->entries.erase(i); + } else { + ++i; + } + } +} + bool Cache::lookupRecord(const QByteArray &name, quint16 type, Record &record) const { QList records; diff --git a/src/src/resolver.cpp b/src/src/resolver.cpp index fe8ce0a..13e2b5e 100644 --- a/src/src/resolver.cpp +++ b/src/src/resolver.cpp @@ -88,6 +88,14 @@ void ResolverPrivate::onMessageReceived(const Message &message) if (!message.isResponse()) { return; } + + // Invalidate each record in the cache first. This ensures + // that we properly handle the case where we have multiple + // records of the same type with 'flush cache' set. + foreach (Record record, message.records()) { + cache->invalidateRecord(record); + } + foreach (Record record, message.records()) { if (record.name() == name && (record.type() == A || record.type() == AAAA)) { cache->addRecord(record); diff --git a/tests/TestCache.cpp b/tests/TestCache.cpp index 5f63a3c..7db53fa 100644 --- a/tests/TestCache.cpp +++ b/tests/TestCache.cpp @@ -94,6 +94,7 @@ void TestCache::testRemoval() // Purge the record from the cache by setting its TTL to 0 record.setTtl(0); + cache.invalidateRecord(record); cache.addRecord(record); // Verify that the record is gone @@ -115,6 +116,7 @@ void TestCache::testCacheFlush() // Insert a new record with the cache clear bit set QMdnsEngine::Record record = createRecord(); record.setFlushCache(true); + cache.invalidateRecord(record); cache.addRecord(record); // Confirm that only a single record exists diff --git a/tests/TestResolver.cpp b/tests/TestResolver.cpp index a396878..eec4e9b 100644 --- a/tests/TestResolver.cpp +++ b/tests/TestResolver.cpp @@ -38,6 +38,7 @@ Q_DECLARE_METATYPE(QHostAddress) const QByteArray Name = "test.localhost."; const QHostAddress Address("127.0.0.1"); +const QHostAddress Address2("127.0.0.2"); class TestResolver : public QObject { @@ -47,6 +48,7 @@ private Q_SLOTS: void initTestCase(); void testResolver(); + void testResolverCacheFlush(); }; void TestResolver::initTestCase() @@ -80,5 +82,37 @@ void TestResolver::testResolver() QCOMPARE(resolvedSpy.at(0).at(0).value(), Address); } +void TestResolver::testResolverCacheFlush() +{ + TestServer server; + QMdnsEngine::Resolver resolver(&server, Name); + QSignalSpy resolvedSpy(&resolver, SIGNAL(resolved(QHostAddress))); + + // Ensure two queries were dispatched + QTRY_VERIFY(queryReceived(&server, Name, QMdnsEngine::A)); + QVERIFY(queryReceived(&server, Name, QMdnsEngine::AAAA)); + + // Send a response with 2 flush cache records + QMdnsEngine::Message message; + message.setResponse(true); + + QMdnsEngine::Record record; + record.setName(Name); + record.setType(QMdnsEngine::A); + record.setFlushCache(true); + record.setAddress(Address); + message.addRecord(record); + + record.setAddress(Address2); + message.addRecord(record); + + server.deliverMessage(message); + + // Ensure resolved() was emitted with both addresses + QTRY_COMPARE(resolvedSpy.count(), 2); + QCOMPARE(resolvedSpy.at(0).at(0).value(), Address); + QCOMPARE(resolvedSpy.at(1).at(0).value(), Address2); +} + QTEST_MAIN(TestResolver) #include "TestResolver.moc"