diff --git a/src/core/search/CMakeLists.txt b/src/core/search/CMakeLists.txt
index 8f47c08e2718..a3ec5e6389fc 100644
--- a/src/core/search/CMakeLists.txt
+++ b/src/core/search/CMakeLists.txt
@@ -27,6 +27,7 @@ cxx_test(range_tree_test dfly_search_core absl::random_random LABELS DFLY)
 cxx_test(rax_tree_test redis_test_lib LABELS DFLY)
 cxx_test(search_parser_test dfly_search_core LABELS DFLY)
 cxx_test(search_test redis_test_lib dfly_search_core LABELS DFLY)
+cxx_test(mrmw_mutex_test redis_test_lib dfly_search_core fibers2 LABELS DFLY)
 
 if(WITH_SIMSIMD)
   target_link_libraries(search_test TRDP::simsimd)
diff --git a/src/core/search/hnsw_index.cc b/src/core/search/hnsw_index.cc
index aea9ba6b18d7..795d18abfea9 100644
--- a/src/core/search/hnsw_index.cc
+++ b/src/core/search/hnsw_index.cc
@@ -5,13 +5,13 @@
 #include "core/search/hnsw_index.h"
 
 #include
-#include
 
 #include
 #include
 
 #include "base/logging.h"
 #include "core/search/hnsw_alg.h"
+#include "core/search/mrmw_mutex.h"
 #include "core/search/vector_utils.h"
 
 namespace dfly::search {
@@ -70,7 +70,8 @@ struct HnswlibAdapter {
   void Add(const float* data, GlobalDocId id) {
     while (true) {
       try {
-        absl::ReaderMutexLock lock(&resize_mutex_);
+        MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
+        absl::ReaderMutexLock resize_lock(&resize_mutex_);
         world_.addPoint(data, id);
         return;
       } catch (const std::exception& e) {
@@ -86,6 +87,7 @@
   void Remove(GlobalDocId id) {
     try {
+      MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
       world_.markDelete(id);
     } catch (const std::exception& e) {
       LOG(WARNING) << "HnswlibAdapter::Remove exception: " << e.what();
@@ -94,6 +96,7 @@
   vector> Knn(float* target, size_t k, std::optional ef) {
     world_.setEf(ef.value_or(kDefaultEfRuntime));
+    MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
     return QueueToVec(world_.searchKnn(target, k));
   }
@@ -111,6 +114,7 @@
     world_.setEf(ef.value_or(kDefaultEfRuntime));
     BinsearchFilter filter{&allowed};
 
+    MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
     return QueueToVec(world_.searchKnn(target, k, &filter));
   }
@@ -153,6 +157,7 @@
   HnswSpace space_;
   HierarchicalNSW world_;
   absl::Mutex resize_mutex_;
+  mutable MRMWMutex mrmw_mutex_;
 };
 
 HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
diff --git a/src/core/search/mrmw_mutex.h b/src/core/search/mrmw_mutex.h
new file mode 100644
index 000000000000..5f144c0373a4
--- /dev/null
+++ b/src/core/search/mrmw_mutex.h
@@ -0,0 +1,89 @@
+// Copyright 2025, DragonflyDB authors.  All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "base/logging.h"
+#include "util/fibers/synchronization.h"
+
+namespace dfly::search {
+
+// Simple implementation of a multi-Reader multi-Writer Mutex.
+// MRMWMutex supports concurrent reads or concurrent writes, but not a mix of
+// concurrent reads and writes at the same time.
+
+class MRMWMutex {
+ public:
+  enum class LockMode : uint8_t { kReadLock, kWriteLock };
+
+  MRMWMutex() : lock_mode_(LockMode::kReadLock) {
+  }
+
+  void Lock(LockMode mode) {
+    std::unique_lock lk(mutex_);
+
+    // If we have any active runners we need to check the lock mode
+    if (active_runners_) {
+      auto& waiters = GetWaiters(mode);
+      waiters++;
+      GetCondVar(mode).wait(lk, [&] { return lock_mode_ == mode; });
+      waiters--;
+    } else {
+      // No active runners, so just update to the requested lock mode
+      lock_mode_ = mode;
+    }
+    active_runners_++;
+  }
+
+  void Unlock(LockMode mode) {
+    std::unique_lock lk(mutex_);
+    LockMode inverse_mode = GetInverseMode(mode);
+    active_runners_--;
+    // If this was the last runner and there are waiters on the inverse mode
+    if (!active_runners_ && GetWaiters(inverse_mode) > 0) {
+      lock_mode_ = inverse_mode;
+      GetCondVar(inverse_mode).notify_all();
+    }
+  }
+
+ private:
+  inline size_t& GetWaiters(LockMode target_mode) {
+    return target_mode == LockMode::kReadLock ? reader_waiters_ : writer_waiters_;
+  };
+
+  inline util::fb2::CondVar& GetCondVar(LockMode target_mode) {
+    return target_mode == LockMode::kReadLock ? reader_cond_var_ : writer_cond_var_;
+  };
+
+  static inline LockMode GetInverseMode(LockMode mode) {
+    return mode == LockMode::kReadLock ? LockMode::kWriteLock : LockMode::kReadLock;
+  }
+
+  util::fb2::Mutex mutex_;
+  util::fb2::CondVar reader_cond_var_, writer_cond_var_;
+  size_t writer_waiters_ = 0, reader_waiters_ = 0;
+  size_t active_runners_ = 0;
+  LockMode lock_mode_;
+};
+
+class MRMWMutexLock {
+ public:
+  explicit MRMWMutexLock(MRMWMutex* mutex, MRMWMutex::LockMode mode)
+      : mutex_(mutex), lock_mode_(mode) {
+    mutex->Lock(lock_mode_);
+  }
+
+  ~MRMWMutexLock() {
+    mutex_->Unlock(lock_mode_);
+  }
+
+  MRMWMutexLock(const MRMWMutexLock&) = delete;
+  MRMWMutexLock(MRMWMutexLock&&) = delete;
+  MRMWMutexLock& operator=(const MRMWMutexLock&) = delete;
+  MRMWMutexLock& operator=(MRMWMutexLock&&) = delete;
+
+ private:
+  MRMWMutex* const mutex_;
+  MRMWMutex::LockMode lock_mode_;
+};
+
+}  // namespace dfly::search
diff --git a/src/core/search/mrmw_mutex_test.cc b/src/core/search/mrmw_mutex_test.cc
new file mode 100644
index 000000000000..32de87f701b7
--- /dev/null
+++ b/src/core/search/mrmw_mutex_test.cc
@@ -0,0 +1,224 @@
+// Copyright 2025, DragonflyDB authors.  All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "core/search/mrmw_mutex.h"
+
+#include <atomic>
+#include <random>
+
+#include "absl/flags/flag.h"
+#include "base/gtest.h"
+#include "base/logging.h"
+#include "util/fibers/pool.h"
+
+ABSL_FLAG(bool, force_epoll, false, "If true, uses the epoll API instead of iouring to run tests");
+
+namespace dfly::search {
+
+namespace {
+
+// Helper function to simulate a reading operation
+void ReadTask(MRMWMutex* mutex, std::atomic<size_t>& read_count, size_t sleep_time) {
+  read_count.fetch_add(1, std::memory_order_relaxed);
+  MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kReadLock);
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
+  read_count.fetch_sub(1, std::memory_order_relaxed);
+}
+
+// Helper function to simulate a writing operation
+void WriteTask(MRMWMutex* mutex, std::atomic<size_t>& write_count, size_t sleep_time) {
+  write_count.fetch_add(1, std::memory_order_relaxed);
+  MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kWriteLock);
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
+  write_count.fetch_sub(1, std::memory_order_relaxed);
+}
+
+constexpr size_t kReadTaskSleepTime = 50;
+constexpr size_t kWriteTaskSleepTime = 100;
+
+}  // namespace
+
+class MRMWMutexTest : public ::testing::Test {
+ protected:
+  MRMWMutex mutex_;
+  std::mt19937 generator_;
+  void SetUp() override {
+#ifdef __linux__
+    if (absl::GetFlag(FLAGS_force_epoll)) {
+      pp_.reset(util::fb2::Pool::Epoll(2));
+    } else {
+      pp_.reset(util::fb2::Pool::IOUring(16, 2));
+    }
+#else
+    pp_.reset(util::fb2::Pool::Epoll(2));
+#endif
+    pp_->Run();
+  }
+  void TearDown() override {
+    pp_->Stop();
+    pp_.reset();
+  }
+  std::unique_ptr<util::ProactorPool> pp_;
+};
+
+// Test 1: Multiple readers can lock concurrently
+TEST_F(MRMWMutexTest, MultipleReadersConcurrently) {
+  std::atomic<size_t> read_count(0);
+  const int num_readers = 5;
+
+  std::vector<util::fb2::Fiber> readers;
+  readers.reserve(num_readers);
+
+  for (int i = 0; i < num_readers; ++i) {
+    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
+      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+    }));
+  }
+
+  // Wait for all reader fibers to finish
+  for (auto& t : readers) {
+    t.Join();
+  }
+
+  // All readers should have been able to lock the mutex concurrently
+  EXPECT_EQ(read_count.load(), 0);
+}
+
+// Test 2: Readers block the writer; the writer gets the lock only after the readers finish
+TEST_F(MRMWMutexTest, ReadersBlockWriters) {
+  std::atomic<size_t> read_count(0);
+  std::atomic<size_t> write_count(0);
+
+  const int num_readers = 10;
+
+  // Start multiple readers
+  std::vector<util::fb2::Fiber> readers;
+  readers.reserve(num_readers);
+
+  for (int i = 0; i < num_readers; ++i) {
+    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
+      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+    }));
+  }
+
+  // Give the readers time to acquire the lock
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(10));
+
+  pp_->at(1)
+      ->LaunchFiber(util::fb2::Launch::post,
+                    [&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
+      .Join();
+
+  // Wait for all reader fibers to finish
+  for (auto& t : readers) {
+    t.Join();
+  }
+
+  EXPECT_EQ(read_count.load(), 0);
+  EXPECT_EQ(write_count.load(), 0);
+}
+
+// Test 3: Unlock transitions correctly and wakes up waiting fibers
+TEST_F(MRMWMutexTest, ReaderAfterWriter) {
+  std::atomic<size_t> write_count(0);
+  std::atomic<size_t> read_count(0);
+
+  // Start a writer fiber
+  auto writer = pp_->at(1)->LaunchFiber(util::fb2::Launch::post, [&] {
+    WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
+  });
+
+  // Give the writer time to acquire the lock
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(10));
+
+  // Now start a reader task that will block until the writer is done
+  pp_->at(0)
+      ->LaunchFiber(util::fb2::Launch::post,
+                    [&] { ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime); })
+      .Join();
+
+  // Ensure that the writer has completed
+  writer.Join();
+
+  EXPECT_EQ(read_count.load(), 0);
+  EXPECT_EQ(write_count.load(), 0);
+}
+
+// Test 4: Ensure the writer gets the lock after the readers finish
+TEST_F(MRMWMutexTest, WriterAfterReaders) {
+  std::atomic<size_t> read_count(0);
+  std::atomic<size_t> write_count(0);
+
+  // Start multiple readers
+  const int num_readers = 10;
+  std::vector<util::fb2::Fiber> readers;
+  readers.reserve(num_readers);
+
+  for (int i = 0; i < num_readers; ++i) {
+    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
+      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+    }));
+  }
+
+  // Wait for all readers to acquire and release the lock
+  for (auto& t : readers) {
+    t.Join();
+  }
+
+  // Start the writer after all readers are done
+  pp_->at(1)
+      ->LaunchFiber(util::fb2::Launch::post,
+                    [&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
+      .Join();
+
+  EXPECT_EQ(read_count.load(), 0);
+  EXPECT_EQ(write_count.load(), 0);
+}
+
+// Test 5: Mix of readers and writers
+TEST_F(MRMWMutexTest, MixWritersReaders) {
+  std::atomic<size_t> read_count(0);
+  std::atomic<size_t> write_count(0);
+
+  // Start multiple readers and writers
+  const int num_threads = 100;
+  std::vector<util::fb2::Fiber> threads;
+  threads.reserve(num_threads + 1);
+
+  // Add a long read task that will block all write tasks
+  threads.emplace_back(
+      pp_->at(0)->LaunchFiber([&] { ReadTask(&mutex_, std::ref(read_count), 2000); }));
+
+  // Give the long reader time to acquire the lock
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(100));
+
+  size_t write_threads = 0;
+  for (int i = 0; i < num_threads; ++i) {
+    size_t fiber_id = rand() % 2;
+    if (rand() % 3) {
+      threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
+        ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+      }));
+    } else {
+      write_threads++;
+      threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
+        WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
+      }));
+    }
+  }
+
+  // All shorter fibers should be done; only the long reader remains,
+  // while every writer is still blocked behind it
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(500));
+
+  EXPECT_EQ(read_count.load(), 1);
+
+  EXPECT_EQ(write_count.load(), write_threads);
+
+  // Wait for all fibers to finish
+  for (auto& t : threads) {
+    t.Join();
+  }
+}
+
+}  // namespace dfly::search
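
Usage sketch (illustrative, not part of the change): the snippet below shows how the new MRMWMutex / MRMWMutexLock API is intended to be used from fibers. The proactor-pool setup and the LaunchFiber/Join calls mirror the test fixture above; the pool sizes and the guarded counter are assumptions for the example. Note that kWriteLock still admits multiple concurrent writers and only excludes readers, so the protected structure must itself tolerate concurrent writes, as the hnswlib index does.

#include <atomic>
#include <memory>

#include "core/search/mrmw_mutex.h"
#include "util/fibers/pool.h"

int main() {
  using dfly::search::MRMWMutex;
  using dfly::search::MRMWMutexLock;

  // Pool setup follows MRMWMutexTest::SetUp; the sizes are arbitrary for this example.
  std::unique_ptr<util::ProactorPool> pp(util::fb2::Pool::IOUring(16, 2));
  pp->Run();

  MRMWMutex mutex;
  std::atomic<int> counter{0};  // hypothetical stand-in for a write-concurrent structure

  // Writer fiber: kWriteLock excludes readers, but other writers may enter concurrently.
  auto writer = pp->at(0)->LaunchFiber([&] {
    MRMWMutexLock lock(&mutex, MRMWMutex::LockMode::kWriteLock);
    counter.fetch_add(1, std::memory_order_relaxed);
  });

  // Reader fiber: kReadLock excludes writers, but other readers may enter concurrently.
  auto reader = pp->at(1)->LaunchFiber([&] {
    MRMWMutexLock lock(&mutex, MRMWMutex::LockMode::kReadLock);
    int seen = counter.load(std::memory_order_relaxed);
    (void)seen;
  });

  writer.Join();
  reader.Join();
  pp->Stop();
  return 0;
}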