diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index abfdb40dfc..04b0accd34 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -375,7 +375,9 @@ template class ModelHostObject : public JsiHostObject { // We need to dispatch a thread if we want the function to be // asynchronous. In this thread all accesses to jsi::Runtime need to // be done via the callInvoker. - threads::GlobalThreadPool::detach([this, promise, + threads::GlobalThreadPool::detach([model = this->model, + callInvoker = this->callInvoker, + promise, argsConverted = std::move(argsConverted)]() { try { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp index d645d6afa3..f0f4108543 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp @@ -35,8 +35,14 @@ TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) { return {.inputIds = inputIds64, .attentionMask = attentionMask}; } +void TextEmbeddings::unload() noexcept { + std::scoped_lock lock(inference_mutex_); + BaseModel::unload(); +} + std::shared_ptr TextEmbeddings::generate(const std::string input) { + std::scoped_lock lock(inference_mutex_); auto preprocessed = preprocess(input); std::vector tokenIdsShape = { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h index 28dacca365..93d0988c04 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h @@ -1,6 +1,7 @@ #pragma once #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include <mutex> #include #include @@ -20,8 +21,10 @@ class TextEmbeddings final : public BaseEmbeddings { [[nodiscard( "Registered non-void function")]] std::shared_ptr generate(const std::string input); + void unload() noexcept; private: + mutable std::mutex inference_mutex_; std::vector> inputShapes; TokenIdsWithAttentionMask preprocess(const std::string &input); std::unique_ptr tokenizer; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp index e8de58b708..22ad6f2ad8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp @@ -58,6 +58,7 @@ std::shared_ptr TextToImage::generate(std::string input, int32_t imageSize, size_t numInferenceSteps, int32_t seed, std::shared_ptr callback) { + std::scoped_lock lock(inference_mutex_); setImageSize(imageSize); setSeed(seed); @@ -137,6 +138,7 @@ size_t TextToImage::getMemoryLowerBound() const noexcept { } void TextToImage::unload() noexcept { + std::scoped_lock lock(inference_mutex_); encoder->unload(); unet->unload(); decoder->unload(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h index 18316217cd..e071a0c2ee 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h @@ -1,6 +1,7 @@ #pragma
once #include +#include <mutex> #include #include @@ -49,6 +50,7 @@ class TextToImage final { static constexpr float guidanceScale = 7.5f; static constexpr float latentsScale = 0.18215f; bool interrupted = false; + mutable std::mutex inference_mutex_; std::shared_ptr callInvoker; std::unique_ptr scheduler;