1 change: 1 addition & 0 deletions src/core/search/CMakeLists.txt
@@ -27,6 +27,7 @@ cxx_test(range_tree_test dfly_search_core absl::random_random LABELS DFLY)
cxx_test(rax_tree_test redis_test_lib LABELS DFLY)
cxx_test(search_parser_test dfly_search_core LABELS DFLY)
cxx_test(search_test redis_test_lib dfly_search_core LABELS DFLY)
cxx_test(mrmw_mutex_test redis_test_lib dfly_search_core fibers2 LABELS DFLY)

if(WITH_SIMSIMD)
target_link_libraries(search_test TRDP::simsimd)
9 changes: 7 additions & 2 deletions src/core/search/hnsw_index.cc
@@ -5,13 +5,13 @@
#include "core/search/hnsw_index.h"

#include <absl/strings/match.h>
#include <absl/synchronization/mutex.h>
#include <hnswlib/hnswlib.h>
#include <hnswlib/space_ip.h>
#include <hnswlib/space_l2.h>

#include "base/logging.h"
#include "core/search/hnsw_alg.h"
#include "core/search/mrmw_mutex.h"
#include "core/search/vector_utils.h"

namespace dfly::search {
@@ -70,7 +70,8 @@ struct HnswlibAdapter {
void Add(const float* data, GlobalDocId id) {
while (true) {
try {
absl::ReaderMutexLock lock(&resize_mutex_);
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
absl::ReaderMutexLock resize_lock(&resize_mutex_);
world_.addPoint(data, id);
return;
} catch (const std::exception& e) {
@@ -86,6 +87,7 @@

void Remove(GlobalDocId id) {
try {
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
world_.markDelete(id);
} catch (const std::exception& e) {
LOG(WARNING) << "HnswlibAdapter::Remove exception: " << e.what();
@@ -94,6 +96,7 @@

vector<pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
world_.setEf(ef.value_or(kDefaultEfRuntime));
Consider moving the world_.setEf(...) call inside the MRMWMutexLock scope; setEf mutates shared state and, as written, can race with concurrent reads/writes. (Also applies to the filtered Knn overload below.)
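A minimal sketch of that suggestion for the unfiltered overload (hypothetical, not part of this diff). Taking the read lock before calling setEf() serializes the mutation against concurrent writers, though two readers passing different ef values could still race with each other:

vector<pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
  // Acquire the read lock first so the ef update happens inside the lock scope.
  MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
  world_.setEf(ef.value_or(kDefaultEfRuntime));
  return QueueToVec(world_.searchKnn(target, k));
}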


MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
return QueueToVec(world_.searchKnn(target, k));
}

@@ -111,6 +114,7 @@

world_.setEf(ef.value_or(kDefaultEfRuntime));
BinsearchFilter filter{&allowed};
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
return QueueToVec(world_.searchKnn(target, k, &filter));
}

@@ -153,6 +157,7 @@ struct HnswlibAdapter {
HnswSpace space_;
HierarchicalNSW<float> world_;
absl::Mutex resize_mutex_;
mutable MRMWMutex mrmw_mutex_;
};

HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
89 changes: 89 additions & 0 deletions src/core/search/mrmw_mutex.h
@@ -0,0 +1,89 @@
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "base/logging.h"
#include "util/fibers/synchronization.h"

namespace dfly::search {

// A simple multi-reader/multi-writer mutex.
// MRMWMutex allows either concurrent readers or concurrent writers, but never
// a mix of readers and writers at the same time.

class MRMWMutex {
public:
enum class LockMode : uint8_t { kReadLock, kWriteLock };

MRMWMutex() : lock_mode_(LockMode::kReadLock) {
}

void Lock(LockMode mode) {
std::unique_lock lk(mutex_);

// If there are active runners, wait until the lock mode matches the requested one
if (active_runners_) {
auto& waiters = GetWaiters(mode);
waiters++;
GetCondVar(mode).wait(lk, [&] { return lock_mode_ == mode; });
waiters--;
} else {
// No active runners, so just switch to the requested lock mode
lock_mode_ = mode;
}
active_runners_++;
}

void Unlock(LockMode mode) {
std::unique_lock lk(mutex_);
LockMode inverse_mode = GetInverseMode(mode);
active_runners_--;
// If this was the last runner and there are waiters for the inverse mode,
// hand the lock over to them
if (!active_runners_ && GetWaiters(inverse_mode) > 0) {
lock_mode_ = inverse_mode;
GetCondVar(inverse_mode).notify_all();
}
}

private:
inline size_t& GetWaiters(LockMode target_mode) {
return target_mode == LockMode::kReadLock ? reader_waiters_ : writer_waiters_;
}

inline util::fb2::CondVar& GetCondVar(LockMode target_mode) {
return target_mode == LockMode::kReadLock ? reader_cond_var_ : writer_cond_var_;
}

static inline LockMode GetInverseMode(LockMode mode) {
return mode == LockMode::kReadLock ? LockMode::kWriteLock : LockMode::kReadLock;
}

util::fb2::Mutex mutex_;
nit: you also contend on mutex_, though here maybe absl::SpinLock would fit better as you take it for a small period of time. you could use CondVarAny to be able to use locks with a different type.

not the most important thing though.
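A rough sketch of the direction this nit points at. Assumptions: absl::SpinLock lives in an internal Abseil header, so a hand-rolled atomic spinlock stands in for it here, and a CondVarAny-style condition variable (e.g. std::condition_variable_any) can wait on any BasicLockable:

#include <atomic>

// Minimal BasicLockable spinlock: lowercase lock()/unlock() let it be passed
// to a *_any condition variable, which Abseil's Lock()/Unlock() naming would not.
class SpinLock {
 public:
  void lock() {
    // Busy-wait until our exchange observes the flag flip from false to true.
    while (flag_.exchange(true, std::memory_order_acquire)) {
    }
  }
  void unlock() {
    flag_.store(false, std::memory_order_release);
  }

 private:
  std::atomic<bool> flag_{false};
};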

util::fb2::CondVar reader_cond_var_, writer_cond_var_;
size_t writer_waiters_ = 0, reader_waiters_ = 0;
size_t active_runners_ = 0;
LockMode lock_mode_;
};

class MRMWMutexLock {
public:
explicit MRMWMutexLock(MRMWMutex* mutex, MRMWMutex::LockMode mode)
: mutex_(mutex), lock_mode_(mode) {
mutex->Lock(lock_mode_);
}

~MRMWMutexLock() {
mutex_->Unlock(lock_mode_);
}

MRMWMutexLock(const MRMWMutexLock&) = delete;
MRMWMutexLock(MRMWMutexLock&&) = delete;
MRMWMutexLock& operator=(const MRMWMutexLock&) = delete;
MRMWMutexLock& operator=(MRMWMutexLock&&) = delete;

private:
MRMWMutex* const mutex_;
MRMWMutex::LockMode lock_mode_;
};

} // namespace dfly::search
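For context, a minimal usage sketch of the new primitive (hypothetical caller code, not part of this diff). The difference from absl::Mutex or std::shared_mutex is that kWriteLock admits any number of concurrent writers; only a mix of readers and writers is excluded:

MRMWMutex mu;

void WriterFiber() {
  // Runs concurrently with other writers; blocks while readers are active.
  MRMWMutexLock lock(&mu, MRMWMutex::LockMode::kWriteLock);
  // ... mutate state that is itself safe under concurrent writes ...
}

void ReaderFiber() {
  // Runs concurrently with other readers; blocks while writers are active.
  MRMWMutexLock lock(&mu, MRMWMutex::LockMode::kReadLock);
  // ... read shared state ...
}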
224 changes: 224 additions & 0 deletions src/core/search/mrmw_mutex_test.cc
@@ -0,0 +1,224 @@
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/search/mrmw_mutex.h"

#include <random>
#include <thread>

#include "absl/flags/flag.h"
#include "base/gtest.h"
#include "base/logging.h"
#include "util/fibers/pool.h"

ABSL_FLAG(bool, force_epoll, false, "If true, uses the epoll API instead of io_uring to run tests");

namespace dfly::search {

namespace {

// Helper function that simulates a read operation
void ReadTask(MRMWMutex* mutex, std::atomic<size_t>& read_count, size_t sleep_time) {
read_count.fetch_add(1, std::memory_order_relaxed);
MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kReadLock);
util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
read_count.fetch_sub(1, std::memory_order_relaxed);
}

// Helper function that simulates a write operation
void WriteTask(MRMWMutex* mutex, std::atomic<size_t>& write_count, size_t sleep_time) {
write_count.fetch_add(1, std::memory_order_relaxed);
MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kWriteLock);
util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
write_count.fetch_sub(1, std::memory_order_relaxed);
}

constexpr size_t kReadTaskSleepTime = 50;
constexpr size_t kWriteTaskSleepTime = 100;

} // namespace

class MRMWMutexTest : public ::testing::Test {
protected:
MRMWMutex mutex_;
std::mt19937 generator_;
void SetUp() override {
#ifdef __linux__
if (absl::GetFlag(FLAGS_force_epoll)) {
pp_.reset(util::fb2::Pool::Epoll(2));
} else {
pp_.reset(util::fb2::Pool::IOUring(16, 2));
}
#else
pp_.reset(util::fb2::Pool::Epoll(2));
#endif
pp_->Run();
}
void TearDown() override {
pp_->Stop();
pp_.reset();
}
std::unique_ptr<util::ProactorPool> pp_;
};

// Test 1: Multiple readers can lock concurrently
TEST_F(MRMWMutexTest, MultipleReadersConcurrently) {
std::atomic<size_t> read_count(0);
const int num_readers = 5;

std::vector<util::fb2::Fiber> readers;
readers.reserve(num_readers);

for (int i = 0; i < num_readers; ++i) {
readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
}));
}

// Wait for all reader fibers to finish
for (auto& t : readers) {
t.Join();
}

// All readers should have been able to lock the mutex concurrently
EXPECT_EQ(read_count.load(), 0);
}

// Test 2: Active readers block a writer; the writer acquires the lock only after they finish
TEST_F(MRMWMutexTest, ReadersBlockWriters) {
std::atomic<size_t> read_count(0);
std::atomic<size_t> write_count(0);

const int num_readers = 10;

// Start multiple readers
std::vector<util::fb2::Fiber> readers;
readers.reserve(num_readers);

for (int i = 0; i < num_readers; ++i) {
readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
}));
}

// Give readers time to acquire the lock
util::ThisFiber::SleepFor(std::chrono::milliseconds(10));

pp_->at(1)
->LaunchFiber(util::fb2::Launch::post,
[&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
.Join();

// Wait for all reader fibers to finish
for (auto& t : readers) {
t.Join();
}

EXPECT_EQ(read_count.load(), 0);
EXPECT_EQ(write_count.load(), 0);
}

// Test 3: Unlock transitions correctly and wakes up waiting threads
TEST_F(MRMWMutexTest, ReaderAfterWriter) {
std::atomic<size_t> write_count(0);
std::atomic<size_t> read_count(0);

// Start a writer fiber
auto writer = pp_->at(1)->LaunchFiber(util::fb2::Launch::post, [&] {
WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
});

// Give writer time to acquire the lock
util::ThisFiber::SleepFor(std::chrono::milliseconds(10));

// Now start a reader task that will block until the writer is done
pp_->at(0)
->LaunchFiber(util::fb2::Launch::post,
[&] { ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime); })
.Join();

// Ensure that writer has completed
writer.Join();

EXPECT_EQ(read_count.load(), 0);
EXPECT_EQ(write_count.load(), 0);
}

// Test 4: Ensure writer gets the lock after readers finish
TEST_F(MRMWMutexTest, WriterAfterReaders) {
std::atomic<size_t> read_count(0);
std::atomic<size_t> write_count(0);

// Start multiple readers
const int num_readers = 10;
std::vector<util::fb2::Fiber> readers;
readers.reserve(num_readers);

for (int i = 0; i < num_readers; ++i) {
readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
}));
}

// Wait for all readers to acquire and release the lock
for (auto& t : readers) {
t.Join();
}

// Start the writer after all readers are done
pp_->at(1)
->LaunchFiber(util::fb2::Launch::post,
[&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
.Join();

EXPECT_EQ(read_count.load(), 0);
EXPECT_EQ(write_count.load(), 0);
}

// Test 5: Mix of readers and writers
TEST_F(MRMWMutexTest, MixWritersReaders) {
std::atomic<size_t> read_count(0);
std::atomic<size_t> write_count(0);

// Start multiple readers and writers
const int num_threads = 100;
std::vector<util::fb2::Fiber> threads;
threads.reserve(num_threads + 1);

// Add long read task that will block all write tasks
threads.emplace_back(
pp_->at(0)->LaunchFiber([&] { ReadTask(&mutex_, std::ref(read_count), 2000); }));

// Give the long reader time to acquire the lock
util::ThisFiber::SleepFor(std::chrono::milliseconds(100));

size_t write_threads = 0;
for (int i = 0; i < num_threads; ++i) {
size_t fiber_id = rand() % 2;
if (rand() % 3) {
threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
}));
} else {
write_threads++;
threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
}));
}
}

// All short tasks should be done by now; only the long reader remains
util::ThisFiber::SleepFor(std::chrono::milliseconds(500));

EXPECT_EQ(read_count.load(), 1);

EXPECT_EQ(write_count.load(), write_threads);

// Wait for all remaining fibers to finish
for (auto& t : threads) {
t.Join();
}
}

} // namespace dfly::search