diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp index a1252edfee..49971cba8c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp @@ -54,8 +54,14 @@ VoiceActivityDetection::preprocess(std::span waveform) const { return frameBuffer; } +void VoiceActivityDetection::unload() noexcept { + std::scoped_lock lock(inference_mutex_); + BaseModel::unload(); +} + std::vector VoiceActivityDetection::generate(std::span waveform) const { + std::scoped_lock lock(inference_mutex_); auto windowedInput = preprocess(waveform); auto [chunksNumber, remainder] = std::div( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h index e692889305..d2f2b1f9e6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" @@ -23,7 +24,19 @@ class VoiceActivityDetection : public BaseModel { [[nodiscard("Registered non-void function")]] std::vector generate(std::span waveform) const; + /** + * @brief Thread-safe unload that waits for any in-flight inference to + * complete. + * + * Mirrors VisionModel::unload(). Without this, BaseModel::unload() can + * destroy module_ while generate() is still calling forward() on a worker + * thread, causing SIGILL / SIGSEGV crashes. + */ + void unload() noexcept; + private: + mutable std::mutex inference_mutex_; + std::vector> preprocess(std::span waveform) const; std::vector postprocess(const std::vector &scores,