Skip to content

Commit bb17e32

Browse files
committed
fix: rework thread safety (2)
1 parent 306dcd7 commit bb17e32

File tree

1 file changed

+74
-99
lines changed

1 file changed

+74
-99
lines changed

src/callback.cpp

Lines changed: 74 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
using namespace asmjit;
77

8-
static JitRuntime rt;
9-
108
struct SimpleErrorHandler : ErrorHandler {
119
Error error{kErrorOk};
1210
const char* code{};
@@ -17,6 +15,11 @@ struct SimpleErrorHandler : ErrorHandler {
1715
}
1816
};
1917

18+
using namespace PLH;
19+
20+
static JitRuntime rt;
21+
static thread_local std::map<const Callback*, std::array<plg::any, Globals::kMaxFuncArgs>> storage;
22+
2023
struct ArgRegSlot {
2124
explicit ArgRegSlot(uint32_t idx) {
2225
argIdx = idx;
@@ -29,7 +32,7 @@ struct ArgRegSlot {
2932
bool useHighReg;
3033
};
3134

32-
bool PLH::Callback::hasHiArgSlot(const x86::Compiler& compiler, const TypeId typeId) noexcept {
35+
bool Callback::hasHiArgSlot(const x86::Compiler& compiler, const TypeId typeId) noexcept {
3336
// 64bit width regs can fit wider args
3437
if (compiler.is64Bit()) {
3538
return false;
@@ -49,7 +52,7 @@ constexpr TypeId getTypeIdx() noexcept {
4952
return static_cast<TypeId>(TypeUtils::TypeIdOfT<T>::kTypeId);
5053
}
5154

52-
TypeId PLH::Callback::getTypeId(const DataType type) noexcept {
55+
TypeId Callback::getTypeId(const DataType type) noexcept {
5356
switch (type) {
5457
case DataType::Void:
5558
return getTypeIdx<void>();
@@ -82,7 +85,7 @@ TypeId PLH::Callback::getTypeId(const DataType type) noexcept {
8285
return TypeId::kVoid;
8386
}
8487

85-
uint64_t PLH::Callback::getJitFunc(const FuncSignature& sig, const CallbackEntry pre, const CallbackEntry post) {
88+
uint64_t Callback::getJitFunc(const FuncSignature& sig, const CallbackEntry pre, const CallbackEntry post) {
8689
if (m_functionPtr) {
8790
return m_functionPtr;
8891
}
@@ -116,8 +119,7 @@ uint64_t PLH::Callback::getJitFunc(const FuncSignature& sig, const CallbackEntry
116119
Label noPost = cc.newLabel();
117120

118121
// map argument slots to registers, following abi.
119-
std::vector<ArgRegSlot> argRegSlots;
120-
argRegSlots.reserve(sig.argCount());
122+
std::inplace_vector<ArgRegSlot, Globals::kMaxFuncArgs> argRegSlots;
121123

122124
for (uint32_t argIdx = 0; argIdx < sig.argCount(); ++argIdx) {
123125
const auto& argType = sig.args()[argIdx];
@@ -143,7 +145,7 @@ uint64_t PLH::Callback::getJitFunc(const FuncSignature& sig, const CallbackEntry
143145
func->setArg(argSlot.argIdx, 1, argSlot.high);
144146
}
145147

146-
argRegSlots.emplace_back(std::move(argSlot));
148+
argRegSlots.push_back(std::move(argSlot));
147149
}
148150

149151
const uint32_t alignment = 16;
@@ -379,156 +381,129 @@ uint64_t PLH::Callback::getJitFunc(const FuncSignature& sig, const CallbackEntry
379381
return m_functionPtr;
380382
}
381383

382-
uint64_t PLH::Callback::getJitFunc(const DataType retType, std::span<const DataType> paramTypes, const CallbackEntry pre, const CallbackEntry post, uint8_t vaIndex) {
384+
uint64_t Callback::getJitFunc(const DataType retType, std::span<const DataType> paramTypes, const CallbackEntry pre, const CallbackEntry post, uint8_t vaIndex) {
383385
FuncSignature sig(CallConvId::kCDecl, vaIndex, getTypeId(retType));
384386
for (const DataType& type : paramTypes) {
385387
sig.addArg(getTypeId(type));
386388
}
387389
return getJitFunc(sig, pre, post);
388390
}
389391

390-
bool PLH::Callback::addCallback(const CallbackType type, const CallbackHandler callback, int priority) {
391-
if (!callback)
392+
bool Callback::addCallback(const CallbackType type, const CallbackHandler handler, int priority) {
393+
if (!handler)
392394
return false;
393395

394-
std::scoped_lock lock(m_mutex);
396+
std::unique_lock lock(m_mutex);
395397

396-
auto& callbacks = m_callbacks[static_cast<size_t>(type)];
398+
auto& [handlers, priorities] = m_callbacks[static_cast<size_t>(type)];
397399

398-
if (std::any_of(callbacks.begin(), callbacks.end(),
399-
[&](auto& x){ return x.first == callback; }))
400+
if (std::any_of(handlers.begin(), handlers.end(),
401+
[&](const auto& h){ return h == handler; }))
400402
return false;
401403

402-
auto& used = m_used[static_cast<size_t>(type)];
403-
if (used.load(std::memory_order_relaxed) == 0) {
404-
callbacks.emplace_back(callback, priority);
405-
} else {
406-
auto& appends = m_appends[static_cast<size_t>(type)];
407-
appends.emplace_back(callback, priority);
408-
}
404+
auto it = std::upper_bound(priorities.begin(), priorities.end(), priority,
405+
[](int p, int cur){ return p > cur; }); // descending order
406+
407+
auto index = std::distance(priorities.begin(), it);
408+
409+
handlers.insert(handlers.begin() + index, handler);
410+
priorities.insert(priorities.begin() + index, priority);
409411

410412
return true;
411413
}
412414

413-
bool PLH::Callback::removeCallback(const CallbackType type, const CallbackHandler callback) {
414-
if (!callback)
415+
bool Callback::removeCallback(const CallbackType type, const CallbackHandler handler) {
416+
if (!handler)
415417
return false;
416418

417-
std::scoped_lock lock(m_mutex);
419+
std::unique_lock lock(m_mutex);
418420

419-
auto& callbacks = m_callbacks[static_cast<size_t>(type)];
421+
auto& [handlers, priorities] = m_callbacks[static_cast<size_t>(type)];
420422

421-
auto it = std::find_if(callbacks.begin(), callbacks.end(),
422-
[&](auto& x){ return x.first == callback; });
423-
if (it == callbacks.end())
424-
return false;
423+
auto it = std::find(handlers.begin(), handlers.end(), handler);
424+
if (it == handlers.end())
425+
return false;
425426

426-
auto& used = m_used[static_cast<size_t>(type)];
427-
if (used.load(std::memory_order_relaxed) == 0) {
428-
callbacks.erase(it);
429-
} else {
430-
auto& removals = m_removals[static_cast<size_t>(type)];
431-
removals.emplace_back(callback, -1);
432-
}
427+
auto index = std::distance(handlers.begin(), it);
428+
handlers.erase(handlers.begin() + index);
429+
priorities.erase(priorities.begin() + index);
433430

434431
return true;
435432
}
436433

437-
bool PLH::Callback::isCallbackRegistered(const CallbackType type, const CallbackHandler callback) const noexcept {
438-
if (!callback)
434+
bool Callback::isCallbackRegistered(const CallbackType type, const CallbackHandler handler) const noexcept {
435+
if (!handler)
439436
return false;
440437

441-
auto& callbacks = m_callbacks[static_cast<size_t>(type)];
442-
return std::any_of(callbacks.begin(), callbacks.end(), [&](auto& x){ return x.first == callback; });
438+
std::shared_lock lock(m_mutex);
439+
440+
const auto& [handlers, priorities] = m_callbacks[static_cast<size_t>(type)];
441+
442+
return std::any_of(handlers.begin(), handlers.end(), [&](const auto& x){ return x == handler; });
443443
}
444444

445-
bool PLH::Callback::areCallbacksRegistered(const CallbackType type) const noexcept {
446-
return !m_callbacks[static_cast<size_t>(type)].empty();
445+
bool Callback::areCallbacksRegistered(const CallbackType type) const noexcept {
446+
std::shared_lock lock(m_mutex);
447+
448+
const auto& [handlers, priorities] = m_callbacks[static_cast<size_t>(type)];
449+
450+
return !handlers.empty();
447451
}
448452

449-
bool PLH::Callback::areCallbacksRegistered() const noexcept {
453+
bool Callback::areCallbacksRegistered() const noexcept {
450454
return areCallbacksRegistered(CallbackType::Pre) || areCallbacksRegistered(CallbackType::Post);
451455
}
452456

453-
PLH::Callback::View PLH::Callback::getCallbacks(const CallbackType type) noexcept {
454-
return { *this, type };
457+
plg::hybrid_vector<Callback::CallbackHandler, Callback::kMaxFuncStack> Callback::getCallbacks(const CallbackType type) noexcept {
458+
std::shared_lock lock(m_mutex);
459+
460+
const auto& [handlers, priorities] = m_callbacks[static_cast<size_t>(type)];
461+
462+
return handlers;
455463
}
456464

457-
uint64_t* PLH::Callback::getTrampolineHolder() noexcept {
465+
uint64_t* Callback::getTrampolineHolder() noexcept {
458466
return &m_trampolinePtr;
459467
}
460468

461-
uint64_t* PLH::Callback::getFunctionHolder() noexcept {
469+
uint64_t* Callback::getFunctionHolder() noexcept {
462470
return &m_functionPtr;
463471
}
464472

465-
std::string_view PLH::Callback::getError() const noexcept {
473+
std::string_view Callback::getError() const noexcept {
466474
return !m_functionPtr && m_errorCode ? m_errorCode : "";
467475
}
468476

469-
plg::any& PLH::Callback::createStorage(const size_t idx, const plg::any& any) {
470-
std::scoped_lock lock(m_mutex);
471-
return m_storage[std::this_thread::get_id()][idx] = any;
477+
plg::any& Callback::setStorage(size_t idx, const plg::any& any) const {
478+
if (idx == -1) {
479+
idx = 32;
480+
}
481+
storage[this][idx] = any;
482+
return storage[this][idx];
472483
}
473484

474-
plg::any& PLH::Callback::getStorage(const size_t idx) {
475-
return m_storage[std::this_thread::get_id()][idx];
485+
plg::any& Callback::getStorage(size_t idx) const {
486+
if (idx == -1) {
487+
idx = 32;
488+
}
489+
return storage[this][idx];
476490
}
477491

478-
PLH::DataType PLH::Callback::getReturnType() const {
492+
DataType Callback::getReturnType() const {
479493
return m_returnType;
480494
}
481495

482-
PLH::DataType PLH::Callback::getArgumentType(size_t idx) const {
496+
DataType Callback::getArgumentType(size_t idx) const {
483497
return m_arguments[idx];
484498
}
485499

486-
void PLH::Callback::cleanupStorage() {
487-
if (m_storage.empty()) {
488-
return;
489-
}
490-
491-
std::scoped_lock lock(m_mutex);
492-
m_storage.erase(std::this_thread::get_id());
493-
}
494-
495-
void PLH::Callback::processDelayed(CallbackType type) {
496-
auto& removals = m_appends[static_cast<size_t>(type)];
497-
auto& appends = m_appends[static_cast<size_t>(type)];
498-
if (removals.empty() && appends.empty()) {
499-
return;
500-
}
501-
502-
std::scoped_lock lock(m_mutex);
503-
504-
auto& callbacks = m_callbacks[static_cast<size_t>(type)];
505-
506-
for (const auto& [callback, priority] : appends) {
507-
if (std::any_of(callbacks.begin(), callbacks.end(),
508-
[&](auto& x){ return x.first == callback; })) {
509-
continue;
510-
}
511-
callbacks.emplace_back(callback, priority);
512-
}
513-
514-
for (const auto& [callback, priority] : removals) {
515-
auto it = std::find_if(callbacks.begin(), callbacks.end(),
516-
[&](auto& x){ return x.first == callback; });
517-
if (it == callbacks.end()) {
518-
continue;
519-
}
520-
callbacks.erase(it);
521-
}
522-
523-
appends.clear();
524-
removals.clear();
525-
}
526-
527-
PLH::Callback::Callback(DataType returnType, std::span<const DataType> arguments) : m_returnType(returnType), m_arguments(arguments.begin(), arguments.end()) {
500+
Callback::Callback(DataType returnType, std::span<const DataType> arguments) : m_returnType(returnType), m_arguments(arguments.begin(), arguments.end()) {
501+
storage[this] = {};
528502
}
529503

530-
PLH::Callback::~Callback() {
504+
Callback::~Callback() {
531505
if (m_functionPtr) {
532506
rt.release(m_functionPtr);
533507
}
508+
storage.erase(this);
534509
}

0 commit comments

Comments
 (0)