From 8e05a1f7a90827ee2e1924724df859fe88955b4e Mon Sep 17 00:00:00 2001 From: Jian You Date: Mon, 20 Nov 2023 21:29:02 +0800 Subject: [PATCH] for blank skipping in sherpa-onnx test --- CMakeLists.txt | 2 +- sherpa-onnx/csrc/offline-model-config.h | 4 +- .../csrc/offline-recognizer-transducer-impl.h | 9 +- ...ffline-transducer-greedy-search-decoder.cc | 10 +- .../csrc/offline-transducer-model-config.cc | 36 ++- .../csrc/offline-transducer-model-config.h | 16 +- sherpa-onnx/csrc/offline-transducer-model.cc | 248 ++++++++++++++++++ sherpa-onnx/csrc/offline-transducer-model.h | 10 + sherpa-onnx/csrc/online-model-config.h | 2 +- .../csrc/online-recognizer-transducer-impl.h | 3 + ...online-transducer-greedy-search-decoder.cc | 4 +- .../csrc/online-transducer-model-config.cc | 19 +- .../csrc/online-transducer-model-config.h | 8 +- .../csrc/online-zipformer-transducer-model.cc | 51 +++- .../online-zipformer2-transducer-model.cc | 150 ++++++++++- .../csrc/online-zipformer2-transducer-model.h | 16 ++ sherpa-onnx/csrc/sherpa-onnx-microphone.cc | 1 + 17 files changed, 569 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ca1af9e78..968e15446 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ endif() option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF) -option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) +option(BUILD_SHARED_LIBS "Whether to build shared libraries" ON) option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 55a063f98..326e479e4 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -25,7 +25,7 @@ struct OfflineModelConfig { std::string tokens; int32_t num_threads = 2; - bool debug = false; + bool debug = true; std::string provider = "cpu"; // With the help of this field, we only need to load the model once @@ -37,7 +37,7 @@ struct OfflineModelConfig { // - nemo_ctc. It is a NeMo CTC model. // // All other values are invalid and lead to loading the model twice. - std::string model_type; + std::string model_type = "transducer"; OfflineModelConfig() = default; OfflineModelConfig(const OfflineTransducerModelConfig &transducer, diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 3f4e2b05e..0e9d42a0a 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -173,8 +173,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { -23.025850929940457f); auto t = model_->RunEncoder(std::move(x), std::move(x_length)); + + // Ort::Value encoder_out = t.first; + auto ctc_out = model_->RunCTC(Clone(model_->Allocator(), &(t.first))); + auto frame_reducer_out = model_->RunFrameReducer(std::move(t.first), std::move(t.second), std::move(ctc_out)); + + // auto results = + // decoder_->Decode(std::move(t.first), std::move(t.second), ss, n); auto results = - decoder_->Decode(std::move(t.first), std::move(t.second), ss, n); + decoder_->Decode(std::move(frame_reducer_out.first), std::move(frame_reducer_out.second), ss, n); int32_t frame_shift_ms = 10; for (int32_t i = 0; i != n; ++i) { diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc index 99ac33388..cc4d5b369 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc @@ -22,6 +22,8 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, PackedSequence packed_encoder_out = PackPaddedSequence( model_->Allocator(), &encoder_out, &encoder_out_length); + auto projected_encoder_out = model_->RunEncoderProj(Clone(model_->Allocator(), &(packed_encoder_out.data))); + int32_t batch_size = static_cast(packed_encoder_out.sorted_indexes.size()); @@ -38,11 +40,15 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, auto decoder_input = model_->BuildDecoderInput(ans, ans.size()); Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); + auto projected_decoder_out = model_->RunDecoderProj(std::move(decoder_out)); + int32_t start = 0; int32_t t = 0; for (auto n : packed_encoder_out.batch_sizes) { - Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n); - Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n); + // Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n); + Ort::Value cur_encoder_out = Slice(model_->Allocator(), &projected_encoder_out, start, start + n); + // Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n); + Ort::Value cur_decoder_out = Slice(model_->Allocator(), &projected_decoder_out, 0, n); start += n; Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), std::move(cur_decoder_out)); diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.cc b/sherpa-onnx/csrc/offline-transducer-model-config.cc index 05fcc9092..5bb14093d 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.cc +++ b/sherpa-onnx/csrc/offline-transducer-model-config.cc @@ -14,6 +14,10 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); + po->Register("ctc", &ctc, "Path to ctc.onnx"); + po->Register("frame_reducer", &frame_reducer, "Path to frame_reducer.onnx"); + po->Register("encoder_proj", &encoder_proj, "Path to encoder_proj.onnx"); + po->Register("decoder_proj", &decoder_proj, "Path to decoder_proj.onnx"); } bool OfflineTransducerModelConfig::Validate() const { @@ -35,6 +39,32 @@ bool OfflineTransducerModelConfig::Validate() const { return false; } + if (!ctc.empty()) + { + if (!FileExists(ctc)) { + SHERPA_ONNX_LOGE("ctc: %s does not exist", ctc.c_str()); + return false; + } + + if (!FileExists(frame_reducer)) { + SHERPA_ONNX_LOGE("frame_reducer: %s does not exist", frame_reducer.c_str()); + return false; + } + } + + if (!encoder_proj.empty()) + { + if (!FileExists(encoder_proj)) { + SHERPA_ONNX_LOGE("encoder_proj: %s does not exist", encoder_proj.c_str()); + return false; + } + + if (!FileExists(decoder_proj)) { + SHERPA_ONNX_LOGE("decoder_proj: %s does not exist", decoder_proj.c_str()); + return false; + } + } + return true; } @@ -44,7 +74,11 @@ std::string OfflineTransducerModelConfig::ToString() const { os << "OfflineTransducerModelConfig("; os << "encoder_filename=\"" << encoder_filename << "\", "; os << "decoder_filename=\"" << decoder_filename << "\", "; - os << "joiner_filename=\"" << joiner_filename << "\")"; + os << "joiner_filename=\"" << joiner_filename << "\", "; + os << "ctc=\"" << ctc << "\", "; + os << "frame_reducer=\"" << frame_reducer << "\", "; + os << "encoder_proj=\"" << encoder_proj << "\", "; + os << "decoder_proj=\"" << decoder_proj << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.h b/sherpa-onnx/csrc/offline-transducer-model-config.h index 1b51f104e..55920da07 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.h +++ b/sherpa-onnx/csrc/offline-transducer-model-config.h @@ -14,14 +14,26 @@ struct OfflineTransducerModelConfig { std::string encoder_filename; std::string decoder_filename; std::string joiner_filename; + std::string ctc; + std::string frame_reducer; + std::string encoder_proj; + std::string decoder_proj; OfflineTransducerModelConfig() = default; OfflineTransducerModelConfig(const std::string &encoder_filename, const std::string &decoder_filename, - const std::string &joiner_filename) + const std::string &joiner_filename, + const std::string &ctc="", + const std::string &frame_reducer="", + const std::string &encoder_proj="", + const std::string &decoder_proj="") : encoder_filename(encoder_filename), decoder_filename(decoder_filename), - joiner_filename(joiner_filename) {} + joiner_filename(joiner_filename), + ctc(ctc), + frame_reducer(frame_reducer), + encoder_proj(encoder_proj), + decoder_proj(decoder_proj) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc index fe32388fa..ccc63aa5d 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -36,6 +36,32 @@ class OfflineTransducerModel::Impl { auto buf = ReadFile(config.transducer.joiner_filename); InitJoiner(buf.data(), buf.size()); } + + if (!config.transducer.ctc.empty()) + { + { + auto buf = ReadFile(config.transducer.ctc); + InitCTC(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.frame_reducer); + InitFrameReducer(buf.data(), buf.size()); + } + } + + if (!config.transducer.encoder_proj.empty()) + { + { + auto buf = ReadFile(config.transducer.encoder_proj); + InitEncoderProj(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder_proj); + InitDecoderProj(buf.data(), buf.size()); + } + } } #if __ANDROID_API__ >= 9 @@ -92,6 +118,39 @@ class OfflineTransducerModel::Impl { return std::move(logit[0]); } + Ort::Value RunCTC(Ort::Value encoder_out) { + auto ctc_out = ctc_sess_->Run( + {}, ctc_input_names_ptr_.data(), &encoder_out, 1, + ctc_output_names_ptr_.data(), ctc_output_names_ptr_.size()); + return std::move(ctc_out[0]); + } + + std::pair RunFrameReducer(Ort::Value encoder_out, + Ort::Value encoder_out_lens, + Ort::Value ctc_out) { + std::array frame_reducer_inputs = {std::move(encoder_out), + std::move(encoder_out_lens), + std::move(ctc_out)}; + auto frame_reducer_out = frame_reducer_sess_->Run( + {}, frame_reducer_input_names_ptr_.data(), frame_reducer_inputs.data(), frame_reducer_inputs.size(), + frame_reducer_output_names_ptr_.data(), frame_reducer_output_names_ptr_.size()); + return {std::move(frame_reducer_out[0]), std::move(frame_reducer_out[1])}; + } + + Ort::Value RunEncoderProj(Ort::Value encoder_proj_input) { + auto encoder_proj_out = encoder_proj_sess_->Run( + {}, encoder_proj_input_names_ptr_.data(), &encoder_proj_input, 1, + encoder_proj_output_names_ptr_.data(), encoder_proj_output_names_ptr_.size()); + return std::move(encoder_proj_out[0]); + } + + Ort::Value RunDecoderProj(Ort::Value decoder_proj_input) { + auto decoder_proj_out = decoder_proj_sess_->Run( + {}, decoder_proj_input_names_ptr_.data(), &decoder_proj_input, 1, + decoder_proj_output_names_ptr_.data(), decoder_proj_output_names_ptr_.size()); + return std::move(decoder_proj_out[0]); + } + int32_t VocabSize() const { return vocab_size_; } int32_t ContextSize() const { return context_size_; } int32_t SubsamplingFactor() const { return 4; } @@ -209,6 +268,150 @@ class OfflineTransducerModel::Impl { } } + void InitCTC(void *model_data, size_t model_data_length) + { + ctc_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); + + GetInputNames(ctc_sess_.get(), &ctc_input_names_, + &ctc_input_names_ptr_); + + GetOutputNames(ctc_sess_.get(), &ctc_output_names_, + &ctc_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = ctc_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "\n---ctc---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "ctc input names:\n"); + for (const auto& n : ctc_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "ctc output names:\n"); + for (const auto& n : ctc_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); + } + } + + void InitFrameReducer(void *model_data, size_t model_data_length) + { + frame_reducer_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); + + GetInputNames(frame_reducer_sess_.get(), &frame_reducer_input_names_, + &frame_reducer_input_names_ptr_); + + GetOutputNames(frame_reducer_sess_.get(), &frame_reducer_output_names_, + &frame_reducer_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = frame_reducer_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "\n---frame_reducer---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "frame reducer input names:\n"); + for (const auto& n : frame_reducer_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "frame reducer output names:\n"); + for (const auto& n : frame_reducer_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); + } + } + + void InitEncoderProj(void *model_data, size_t model_data_length) + { + encoder_proj_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); + + GetInputNames(encoder_proj_sess_.get(), &encoder_proj_input_names_, + &encoder_proj_input_names_ptr_); + + GetOutputNames(encoder_proj_sess_.get(), &encoder_proj_output_names_, + &encoder_proj_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_proj_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "\n---encoder_proj---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "encoder_proj input names:\n"); + for (const auto& n : encoder_proj_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "encoder_proj output names:\n"); + for (const auto& n : encoder_proj_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); + } + } + + void InitDecoderProj(void *model_data, size_t model_data_length) + { + decoder_proj_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); + + GetInputNames(decoder_proj_sess_.get(), &decoder_proj_input_names_, + &decoder_proj_input_names_ptr_); + + GetOutputNames(decoder_proj_sess_.get(), &decoder_proj_output_names_, + &decoder_proj_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = decoder_proj_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "\n---decoder_proj---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "decoder_proj input names:\n"); + for (const auto& n : decoder_proj_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "decoder_proj output names:\n"); + for (const auto& n : decoder_proj_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); + } + } + private: OfflineModelConfig config_; Ort::Env env_; @@ -218,6 +421,10 @@ class OfflineTransducerModel::Impl { std::unique_ptr encoder_sess_; std::unique_ptr decoder_sess_; std::unique_ptr joiner_sess_; + std::unique_ptr ctc_sess_; + std::unique_ptr frame_reducer_sess_; + std::unique_ptr encoder_proj_sess_; + std::unique_ptr decoder_proj_sess_; std::vector encoder_input_names_; std::vector encoder_input_names_ptr_; @@ -237,6 +444,30 @@ class OfflineTransducerModel::Impl { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; + std::vector ctc_input_names_; + std::vector ctc_input_names_ptr_; + + std::vector ctc_output_names_; + std::vector ctc_output_names_ptr_; + + std::vector frame_reducer_input_names_; + std::vector frame_reducer_input_names_ptr_; + + std::vector frame_reducer_output_names_; + std::vector frame_reducer_output_names_ptr_; + + std::vector encoder_proj_input_names_; + std::vector encoder_proj_input_names_ptr_; + + std::vector encoder_proj_output_names_; + std::vector encoder_proj_output_names_ptr_; + + std::vector decoder_proj_input_names_; + std::vector decoder_proj_input_names_ptr_; + + std::vector decoder_proj_output_names_; + std::vector decoder_proj_output_names_ptr_; + int32_t vocab_size_ = 0; // initialized in InitDecoder int32_t context_size_ = 0; // initialized in InitDecoder }; @@ -266,6 +497,23 @@ Ort::Value OfflineTransducerModel::RunJoiner(Ort::Value encoder_out, return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); } +Ort::Value OfflineTransducerModel::RunCTC(Ort::Value encoder_out) { + return impl_->RunCTC(std::move(encoder_out)); +} + +std::pair OfflineTransducerModel::RunFrameReducer( + Ort::Value encoder_out, Ort::Value encoder_out_lens, Ort::Value ctc_out) { + return impl_->RunFrameReducer(std::move(encoder_out), std::move(encoder_out_lens), std::move(ctc_out)); +} + +Ort::Value OfflineTransducerModel::RunEncoderProj(Ort::Value encoder_proj_input) { + return impl_->RunEncoderProj(std::move(encoder_proj_input)); +} + +Ort::Value OfflineTransducerModel::RunDecoderProj(Ort::Value decoder_proj_input) { + return impl_->RunDecoderProj(std::move(decoder_proj_input)); +} + int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); } int32_t OfflineTransducerModel::ContextSize() const { diff --git a/sherpa-onnx/csrc/offline-transducer-model.h b/sherpa-onnx/csrc/offline-transducer-model.h index 31a238cb7..e2edfcacc 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.h +++ b/sherpa-onnx/csrc/offline-transducer-model.h @@ -70,6 +70,16 @@ class OfflineTransducerModel { */ Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out); + Ort::Value RunCTC(Ort::Value encoder_out); + + std::pair RunFrameReducer(Ort::Value encoder_out, + Ort::Value encoder_out_lens, + Ort::Value ctc_out); + + Ort::Value RunEncoderProj(Ort::Value encoder_proj_input); + + Ort::Value RunDecoderProj(Ort::Value decoder_proj_input); + /** Return the vocabulary size of the model */ int32_t VocabSize() const; diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 2afd66176..910fd145e 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -16,7 +16,7 @@ struct OnlineModelConfig { OnlineParaformerModelConfig paraformer; std::string tokens; int32_t num_threads = 1; - bool debug = false; + bool debug = true; std::string provider = "cpu"; // Valid values: diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index a9ba0a95e..d1dcb800b 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -198,6 +198,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); std::array x_shape{n, chunk_size, feature_dim}; + // std::array x_shape{chunk_size, n, feature_dim}; Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), features_vec.size(), x_shape.data(), @@ -212,6 +213,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { auto states = model_->StackStates(states_vec); + SHERPA_ONNX_LOGE("streams num: %d", n); + auto pair = model_->RunEncoder(std::move(x), std::move(states), std::move(processed_frames)); diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index e90426bdc..01ab420bd 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -79,8 +79,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( if (encoder_out_shape[0] != result->size()) { SHERPA_ONNX_LOGE( - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d", + "Size mismatch! encoder_out.size(0) %d, encoder_out.size(1) %d, encoder_out.size(2) %d, result.size(0): %d", static_cast(encoder_out_shape[0]), + static_cast(encoder_out_shape[1]), + static_cast(encoder_out_shape[2]), static_cast(result->size())); exit(-1); } diff --git a/sherpa-onnx/csrc/online-transducer-model-config.cc b/sherpa-onnx/csrc/online-transducer-model-config.cc index f7015f98d..3986713ee 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/csrc/online-transducer-model-config.cc @@ -14,6 +14,8 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) { po->Register("encoder", &encoder, "Path to encoder.onnx"); po->Register("decoder", &decoder, "Path to decoder.onnx"); po->Register("joiner", &joiner, "Path to joiner.onnx"); + po->Register("ctc", &ctc, "Path to ctc.onnx"); + po->Register("frame_reducer", &frame_reducer, "Path to frame_reducer.onnx"); } bool OnlineTransducerModelConfig::Validate() const { @@ -32,6 +34,19 @@ bool OnlineTransducerModelConfig::Validate() const { return false; } + if (!ctc.empty()) + { + if (!FileExists(ctc)) { + SHERPA_ONNX_LOGE("ctc: %s does not exist", ctc.c_str()); + return false; + } + + if (!FileExists(frame_reducer)) { + SHERPA_ONNX_LOGE("frame_reducer: %s does not exist", frame_reducer.c_str()); + return false; + } + } + return true; } @@ -41,7 +56,9 @@ std::string OnlineTransducerModelConfig::ToString() const { os << "OnlineTransducerModelConfig("; os << "encoder=\"" << encoder << "\", "; os << "decoder=\"" << decoder << "\", "; - os << "joiner=\"" << joiner << "\")"; + os << "joiner=\"" << joiner << "\", "; + os << "ctc=\"" << ctc << "\", "; + os << "frame_reducer=\"" << frame_reducer << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h index 5d79e25bf..b795571c9 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.h +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -14,12 +14,16 @@ struct OnlineTransducerModelConfig { std::string encoder; std::string decoder; std::string joiner; + std::string ctc; + std::string frame_reducer; OnlineTransducerModelConfig() = default; OnlineTransducerModelConfig(const std::string &encoder, const std::string &decoder, - const std::string &joiner) - : encoder(encoder), decoder(decoder), joiner(joiner) {} + const std::string &joiner, + const std::string &ctc="", + const std::string &frame_reducer="") + : encoder(encoder), decoder(decoder), joiner(joiner), ctc(ctc), frame_reducer(frame_reducer) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 31234ae74..b287b97a2 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -90,9 +90,24 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); if (config_.debug) { std::ostringstream os; - os << "---encoder---\n"; + os << "\n---encoder---\n"; PrintModelMetadata(os, meta_data); SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "encoder input names:\n"); + for (const auto& n : encoder_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "encoder output names:\n"); + for (const auto& n : encoder_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -138,9 +153,24 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data, Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); if (config_.debug) { std::ostringstream os; - os << "---decoder---\n"; + os << "\n---decoder---\n"; PrintModelMetadata(os, meta_data); SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "decoder input names:\n"); + for (const auto& n : decoder_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "decoder output names:\n"); + for (const auto& n : decoder_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -163,9 +193,24 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data, Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); if (config_.debug) { std::ostringstream os; - os << "---joiner---\n"; + os << "\n---joiner---\n"; PrintModelMetadata(os, meta_data); SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "joiner input names:\n"); + for (const auto& n : joiner_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "joiner output names:\n"); + for (const auto& n : joiner_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); } } diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index e818b0bc9..c00124e33 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -51,6 +51,19 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( auto buf = ReadFile(config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } + + if (!config.transducer.ctc.empty()) + { + { + auto buf = ReadFile(config.transducer.ctc); + InitCTC(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.frame_reducer); + InitFrameReducer(buf.data(), buf.size()); + } + } } #if __ANDROID_API__ >= 9 @@ -92,9 +105,24 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); if (config_.debug) { std::ostringstream os; - os << "---encoder---\n"; + os << "\n---encoder---\n"; PrintModelMetadata(os, meta_data); SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "encoder input names:\n"); + for (const auto& n : encoder_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "encoder output names:\n"); + for (const auto& n : encoder_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -144,9 +172,24 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); if (config_.debug) { std::ostringstream os; - os << "---decoder---\n"; + os << "\n---decoder---\n"; PrintModelMetadata(os, meta_data); SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "decoder input names:\n"); + for (const auto& n : decoder_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "decoder output names:\n"); + for (const auto& n : decoder_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -169,9 +212,96 @@ void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); if (config_.debug) { std::ostringstream os; - os << "---joiner---\n"; + os << "\n---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "joiner input names:\n"); + for (const auto& n : joiner_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "joiner output names:\n"); + for (const auto& n : joiner_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); + } +} + +void OnlineZipformer2TransducerModel::InitCTC(void *model_data, size_t model_data_length) +{ + ctc_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); + + GetInputNames(ctc_sess_.get(), &ctc_input_names_, + &ctc_input_names_ptr_); + + GetOutputNames(ctc_sess_.get(), &ctc_output_names_, + &ctc_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = ctc_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "\n---ctc---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "ctc input names:\n"); + for (const auto& n : ctc_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "ctc output names:\n"); + for (const auto& n : ctc_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); + } +} + +void OnlineZipformer2TransducerModel::InitFrameReducer(void *model_data, size_t model_data_length) +{ + frame_reducer_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); + + GetInputNames(frame_reducer_sess_.get(), &frame_reducer_input_names_, + &frame_reducer_input_names_ptr_); + + GetOutputNames(frame_reducer_sess_.get(), &frame_reducer_output_names_, + &frame_reducer_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = frame_reducer_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "\n---frame_reducer---\n"; PrintModelMetadata(os, meta_data); SHERPA_ONNX_LOGE("%s", os.str().c_str()); + fprintf(stderr, "\033[1;33m"); + fprintf(stderr, "frame reducer input names:\n"); + for (const auto& n : frame_reducer_input_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "\033[1;34m"); + fprintf(stderr, "frame reducer output names:\n"); + for (const auto& n : frame_reducer_output_names_) + { + fprintf(stderr, "-- %s\n", n.c_str()); + } + fprintf(stderr, "\033[0m"); + fprintf(stderr, "--------------------------------------\n"); } } @@ -463,4 +593,18 @@ Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out, return std::move(logit[0]); } +// Ort::Value OnlineZipformer2TransducerModel::RunCTC(Ort::Value encoder_out) { +// auto ctc_out = ctc_sess_->Run( +// {}, ctc_input_names_ptr_.data(), &encoder_out, 1, +// ctc_output_names_ptr_.data(), ctc_output_names_ptr_.size()); +// return std::move(ctc_out[0]); +// } + +// Ort::Value OnlineZipformer2TransducerModel::RunFrameReducer(Ort::Value encoder_out) { +// auto ctc_out = ctc_sess_->Run( +// {}, ctc_input_names_ptr_.data(), &encoder_out, 1, +// ctc_output_names_ptr_.data(), ctc_output_names_ptr_.size()); +// return std::move(ctc_out[0]); +// } + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h index 666ad1989..6e3ee3223 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h @@ -58,6 +58,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { void InitEncoder(void *model_data, size_t model_data_length); void InitDecoder(void *model_data, size_t model_data_length); void InitJoiner(void *model_data, size_t model_data_length); + void InitCTC(void *model_data, size_t model_data_length); + void InitFrameReducer(void *model_data, size_t model_data_length); private: Ort::Env env_; @@ -67,6 +69,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { std::unique_ptr encoder_sess_; std::unique_ptr decoder_sess_; std::unique_ptr joiner_sess_; + std::unique_ptr ctc_sess_; + std::unique_ptr frame_reducer_sess_; std::vector encoder_input_names_; std::vector encoder_input_names_ptr_; @@ -86,6 +90,18 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; + std::vector ctc_input_names_; + std::vector ctc_input_names_ptr_; + + std::vector ctc_output_names_; + std::vector ctc_output_names_ptr_; + + std::vector frame_reducer_input_names_; + std::vector frame_reducer_input_names_ptr_; + + std::vector frame_reducer_output_names_; + std::vector frame_reducer_output_names_ptr_; + OnlineModelConfig config_; std::vector encoder_dims_; diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc index bdb43a204..73cca681d 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc @@ -66,6 +66,7 @@ for a list of pre-trained models to download. exit(EXIT_FAILURE); } + config.model_config.debug = true; fprintf(stderr, "%s\n", config.ToString().c_str()); if (!config.Validate()) {