Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ljcui committed Jan 19, 2025
1 parent 335dd91 commit 6eb8b73
Show file tree
Hide file tree
Showing 12 changed files with 276 additions and 44 deletions.
3 changes: 2 additions & 1 deletion include/lgraph/lgraph_exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ X(ReachMaximumEid, "Edge eid exceeds the limit.") \
X(ReachMaximumCompositeIndexField, "The size of composite index fields exceeds the limit.") \
X(PluginDisabled, "Plugin disabled!") \
X(BoltDataException, "Bolt data exception") \
X(VectorIndexException, "Vector index exception")
X(VectorIndexException, "Vector index exception") \
X(ReplicateTimeout, "Raft replication Timeout")

enum class ErrorCode {
#define X(code, msg) code,
Expand Down
2 changes: 2 additions & 0 deletions src/BuildBoltLib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ set(BOLT_SRC
bolt_ha/logger.cpp
bolt_ha/raft_log_store.cpp
bolt_ha/raft_driver.cpp
bolt_ha/bolt_ha.pb.cc
${LGRAPH_ROOT_DIR}/deps/etcd-raft-cpp/raftpb/raft.pb.cc
lgraph_api/lgraph_exceptions.cpp)

add_library(${TARGET_BOLT_LIB} STATIC ${BOLT_SRC})
Expand Down
3 changes: 2 additions & 1 deletion src/bolt/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ class BoltConnection
public:
BoltConnection(boost::asio::io_service& io_service,
std::function<void(BoltConnection& conn, BoltMsg msg,
const std::vector<std::any>& fields)> handle)
const std::vector<std::any>& fields,
std::vector<uint8_t > raw_data)> handle)
: Connection(io_service),
handle_(std::move(handle)) {}
void Start() override;
Expand Down
22 changes: 22 additions & 0 deletions src/bolt_ha/bolt_ha.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/**
* Copyright 2024 AntGroup CO., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

syntax = "proto3";
package bolt_ha;

message RaftRequest {
uint64 id = 1;
string user = 2;
bytes raw_data = 3;
}
23 changes: 6 additions & 17 deletions src/bolt_ha/raft_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
using boost::asio::async_write;
using boost::asio::ip::tcp;

extern std::shared_mutex promise_mutex;
extern std::unordered_map<uint64_t, std::promise<std::string>> pending_promise;
namespace bolt_ha {
void NodeClient::reconnect() {
if (has_closed_) {
Expand Down Expand Up @@ -305,7 +303,7 @@ eraft::Error RaftDriver::ProposeConfChange(const raftpb::ConfChange& cc) {
return PostMessage(std::move(msg));
}

eraft::Error RaftDriver::Proposal(std::string data) {
eraft::Error RaftDriver::Propose(std::string data) {
raftpb::Message msg;
auto entry = msg.add_entries();
entry->set_type(raftpb::EntryType::EntryNormal);
Expand Down Expand Up @@ -375,17 +373,7 @@ void RaftDriver::on_ready(eraft::Ready ready) {
if (entry.data().empty()) {
continue;
}
auto propose = nlohmann::json::parse(entry.data());
auto uid = propose["uid"].get<uint64_t>();
auto data = propose["data"].get<std::string>();
apply_(entry.index(), data);
{
std::shared_lock lock(promise_mutex);
auto iter = pending_promise.find(uid);
if (iter != pending_promise.end()) {
iter->second.set_value(data);
}
}
apply_(entry.index(), entry.data());
break;
}
case raftpb::EntryConfChange: {
Expand Down Expand Up @@ -437,13 +425,13 @@ void RaftDriver::on_ready(eraft::Ready ready) {
storage_->SetNodesInfo(nodes_info(), wb);
storage_->SetApplyIndex(entry.index(), wb);
storage_->WriteBatch(wb);
if (cc.id() > 0) {
/*if (cc.id() > 0) {
std::shared_lock lock(promise_mutex);
auto iter = pending_promise.find(cc.id());
if (iter != pending_promise.end()) {
iter->second.set_value(cc.context());
}
}
}*/
break;
}
default: {
Expand All @@ -452,5 +440,6 @@ void RaftDriver::on_ready(eraft::Ready ready) {
}
}
}
std::shared_ptr<RaftDriver> raft_driver;
std::shared_ptr<RaftDriver> g_raft_driver;
std::shared_ptr<Generator> g_id_generator;
}
33 changes: 31 additions & 2 deletions src/bolt_ha/raft_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class RaftDriver {
std::string log_path);
eraft::Error Run();
void Message(raftpb::Message msg);
eraft::Error Proposal(std::string data);
eraft::Error Propose(std::string data);
eraft::Error ProposeConfChange(const raftpb::ConfChange& cc);

private:
Expand Down Expand Up @@ -87,5 +87,34 @@ class RaftDriver {
bool advance_ = false;
std::unordered_map<uint64_t, std::shared_ptr<NodeClient>> node_clients_;
};
extern std::shared_ptr<RaftDriver> raft_driver;
extern std::shared_ptr<RaftDriver> g_raft_driver;

struct Generator {
static const int tsLen = 5 * 8;
static const int cntLen = 8;
static const int suffixLen = tsLen + cntLen;

Generator(uint64_t id, uint64_t time) {
prefix = id << suffixLen;
suffix = lowbit(time, tsLen) << cntLen;
}
uint64_t prefix = 0;
std::atomic<uint64_t> suffix{0};
uint64_t lowbit(uint64_t x, int n) {
return x & (std::numeric_limits<uint64_t>::max() >> (64 - n));
}
uint64_t Next() {
auto suf = suffix.fetch_add(1) + 1;
auto id = prefix | lowbit(suf, suffixLen);
return id;
}
};
extern std::shared_ptr<Generator> g_id_generator;

struct ApplyContext {
uint64_t index = 0;
std::promise<void> start;
std::promise<void> end;
};

}
6 changes: 6 additions & 0 deletions src/db/galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,12 @@ void lgraph::Galaxy::CheckTuGraphVersion(KvTransaction& txn) {
}
}

void lgraph::Galaxy::UpdateBoltRaftApplyIndex(uint64_t index) {
auto txn = store_->CreateWriteTxn(false);
db_info_table_->SetValue(*txn, Value::ConstRef("bolt_raft_apply_index"), Value::ConstRef(index));
txn->Commit();
}

void lgraph::Galaxy::BootstrapRaftLogIndex(int64_t log_id) {
SetRaftLogIndexBeforeWrite(log_id);
// need to write something so that raft log id can be written to db
Expand Down
2 changes: 2 additions & 0 deletions src/db/galaxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ class Galaxy {
bool AddUserRoles(const std::string& current_user, const std::string& user,
const std::vector<std::string>& roles);

void UpdateBoltRaftApplyIndex(uint64_t index);

private:
// load config from db
// overwrites content of global_config_ if it is not null
Expand Down
109 changes: 108 additions & 1 deletion src/server/bolt_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include "server/bolt_server.h"
#include "server/bolt_session.h"
#include "db/galaxy.h"

#include "bolt_ha/raft_driver.h"
#include "bolt_ha/bolt_ha.pb.h"

using namespace lgraph_api;
namespace bolt {
extern boost::asio::io_service workers;
Expand Down Expand Up @@ -107,6 +111,76 @@ parser::Expression ConvertParameters(std::any data) {
return ret;
}

using namespace std::chrono;
std::shared_mutex promise_mutex;
std::unordered_map<uint64_t, std::shared_ptr<bolt_ha::ApplyContext>> pending_promise;

void g_apply(uint64_t index, const std::string& log) {
bolt_ha::RaftRequest request;
auto ret = request.ParseFromString(log);
assert(ret);
std::shared_ptr<bolt_ha::ApplyContext> context;
{
std::unique_lock lock(promise_mutex);
auto iter = pending_promise.find(request.id());
if (iter != pending_promise.end()) {
context = iter->second;
pending_promise.erase(iter);
}
}
if (context) {
context->index = index;
context->start.set_value();
context->end.get_future().get();
} else {
Unpacker unpacker;
unpacker.Reset(std::string_view(request.raw_data().data(), request.raw_data().size()));
unpacker.Next();
auto len = unpacker.Len();
auto tag = static_cast<BoltMsg>(unpacker.StructTag());
FMA_ASSERT(tag == BoltMsg::Run);
std::vector<std::any> fields;
for (uint32_t i = 0; i < len; i++) {
unpacker.Next();
fields.push_back(bolt::ServerHydrator(unpacker));
}

auto& cypher = std::any_cast<const std::string&>(fields[0]);
auto& extra = std::any_cast<
const std::unordered_map<std::string, std::any>&>(fields[2]);
std::string graph;
auto db_iter = extra.find("db");
if (db_iter != extra.end()) {
graph = std::any_cast<const std::string&>(db_iter->second);
}
auto& field1 = std::any_cast<
std::unordered_map<std::string, std::any>&>(fields[1]);
auto sm = BoltServer::Instance().StateMachine();
cypher::RTContext ctx(sm, sm->GetGalaxy(), request.user(), graph,
sm->IsCypherV2());
if (ctx.is_cypher_v2_) {
ctx.bolt_parameters_v2_ = std::make_shared<std::unordered_map<
std::string, geax::frontend::Expr*>>();
} else {
ctx.bolt_parameters_ = std::make_shared<std::unordered_map<
std::string, parser::Expression>>();
}
for (auto& pair : field1) {
if (ctx.is_cypher_v2_) {
ctx.bolt_parameters_v2_->emplace("$" + pair.first,
ConvertParameters(ctx.obj_alloc_, std::move(pair.second)));
} else {
ctx.bolt_parameters_->emplace("$" + pair.first,
ConvertParameters(std::move(pair.second)));
}
}
cypher::ElapsedTime elapsed;
sm->GetCypherScheduler()->Eval(&ctx, lgraph_api::GraphQueryType::CYPHER,
cypher, elapsed);
sm->GetGalaxy()->UpdateBoltRaftApplyIndex(index);
}
}

void BoltFSM(std::shared_ptr<BoltConnection> conn) {
pthread_setname_np(pthread_self(), "bolt_fsm");
auto conn_id = conn->conn_id();
Expand Down Expand Up @@ -223,11 +297,42 @@ void BoltFSM(std::shared_ptr<BoltConnection> conn) {
}
}
session->streaming_msg.reset();
std::shared_ptr<bolt_ha::ApplyContext> apply_context;
{
std::string plugin_name, plugin_type;
auto ret = cypher::Scheduler::DetermineReadOnly(&ctx, GraphQueryType::CYPHER, cypher, plugin_name, plugin_type);
if (!ret) {
auto uid = bolt_ha::g_id_generator->Next();
apply_context = std::make_shared<bolt_ha::ApplyContext>();
auto future = apply_context->start.get_future();
{
std::unique_lock lock(promise_mutex);
pending_promise.emplace(uid, apply_context);
}
bolt_ha::RaftRequest request;
request.set_id(uid);
request.set_user(session->user);
request.set_raw_data((const char*)msg.value().raw_data.data(), msg.value().raw_data.size());
auto err = bolt_ha::g_raft_driver->Propose(request.SerializeAsString());
if (err != nullptr) {
LOG_ERROR() << FMA_FMT("Failed to propose, err: {}", err.String());
}
if (future.wait_for(std::chrono::milliseconds(1000)) == std::future_status::ready) {
future.get();
} else {
THROW_CODE(ReplicateTimeout);
}
}
}
cypher::ElapsedTime elapsed;
LOG_DEBUG() << "Bolt run " << cypher;
sm->GetCypherScheduler()->Eval(&ctx, lgraph_api::GraphQueryType::CYPHER,
cypher, elapsed);
LOG_DEBUG() << "Cypher execution completed";
if (apply_context) {
sm->GetGalaxy()->UpdateBoltRaftApplyIndex(apply_context->index);
apply_context->end.set_value();
}
} catch (const lgraph_api::LgraphException& e) {
LOG_ERROR() << e.what();
RespondFailure(e.code(), e.msg());
Expand All @@ -241,7 +346,9 @@ void BoltFSM(std::shared_ptr<BoltConnection> conn) {
LOG_DEBUG() << FMA_FMT("bolt fsm thread[conn_id:{}] exit.", conn_id);
}

std::function BoltHandler =
std::function<void(bolt::BoltConnection &conn, bolt::BoltMsg msg,
std::vector<std::any> fields, std::vector<uint8_t> raw_data)>
BoltHandler =
[](BoltConnection& conn, BoltMsg msg, std::vector<std::any> fields, std::vector<uint8_t> raw_data) {
if (msg == BoltMsg::Hello) {
if (fields.size() != 1) {
Expand Down
Loading

0 comments on commit 6eb8b73

Please sign in to comment.