diff --git a/sherpa-onnx/csrc/offline-ctc-decoder.h b/sherpa-onnx/csrc/offline-ctc-decoder.h index c9d1b36ff..575ea884c 100644 --- a/sherpa-onnx/csrc/offline-ctc-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-decoder.h @@ -15,17 +15,26 @@ struct OfflineCtcDecoderResult { /// The decoded token IDs std::vector tokens; + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + /// Note: The index is after subsampling + /// + /// tokens.size() == timestamps.size() + std::vector timestamps; + + /// stop_timestamps[i] - timestamps[i] is the duration of the i-th token + /// in terms of number of output frames + /// + /// tokens.size() == stop_timestamps.size() + std::vector stop_timestamps; + /// The decoded word IDs /// Note: tokens.size() is usually not equal to words.size() /// words is empty for greedy search decoding. /// it is not empty when an HLG graph or an HLG graph is used. std::vector words; - /// timestamps[i] contains the output frame index where tokens[i] is decoded. - /// Note: The index is after subsampling - /// - /// tokens.size() == timestamps.size() - std::vector timestamps; + /// word_start_timestamps.size() == words.size() + std::vector word_start_timestamps; }; class OfflineCtcDecoder { diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc index 6c9df3fd3..95af8b8d0 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc @@ -88,13 +88,19 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, auto cur_state = decoded.Start(); int32_t blank_id = 0; + int32_t prev = -1; + int32_t t = 0; - for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) { + for (; decoded.NumArcs(cur_state) == 1; ++t) { fst::ArcIterator> iter(decoded, cur_state); const auto &arc = iter.Value(); cur_state = arc.nextstate; + if (prev != -1 && prev != 0 && prev != blank_id + 1 && arc.ilabel != prev) { + r.stop_timestamps.push_back(t); + } + if (arc.ilabel == prev) { continue; } @@ -110,12 +116,21 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, r.tokens.push_back(arc.ilabel - 1); if (arc.olabel != 0) { r.words.push_back(arc.olabel); + r.word_start_timestamps.push_back(t); } r.timestamps.push_back(t); prev = arc.ilabel; } + if (r.timestamps.size() != r.stop_timestamps.size()) { + r.stop_timestamps.push_back(t); + } + + if (r.timestamps.size() != r.stop_timestamps.size()) { + SHERPA_ONNX_LOGE("something bad happened"); + } + return r; } diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc index 59d16f5d3..07f2ff457 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc @@ -29,9 +29,10 @@ std::vector OfflineCtcGreedySearchDecoder::Decode( log_probs.GetTensorData() + b * num_frames * vocab_size; OfflineCtcDecoderResult r; - int64_t prev_id = -1; + int32_t prev_id = -1; + int32_t t = 0; - for (int32_t t = 0; t != static_cast(p_log_probs_length[b]); ++t) { + for (; t != static_cast(p_log_probs_length[b]); ++t) { auto y = static_cast(std::distance( static_cast(p_log_probs), std::max_element( @@ -39,6 +40,10 @@ std::vector OfflineCtcGreedySearchDecoder::Decode( static_cast(p_log_probs) + vocab_size))); p_log_probs += vocab_size; + if (prev_id != -1 && prev_id != blank_id_ && y != prev_id) { + r.stop_timestamps.push_back(t); + } + if (y != blank_id_ && y != prev_id) { r.tokens.push_back(y); r.timestamps.push_back(t); @@ -46,6 +51,14 @@ std::vector OfflineCtcGreedySearchDecoder::Decode( prev_id = y; } // for (int32_t t = 0; ...) + if (r.timestamps.size() != r.stop_timestamps.size()) { + r.stop_timestamps.push_back(t); + } + + if (r.timestamps.size() != r.stop_timestamps.size()) { + SHERPA_ONNX_LOGE("something bad happened"); + } + ans.push_back(std::move(r)); } return ans; diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index c64da12af..93d9ae41d 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -34,6 +34,9 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.timestamps.size()); + r.stop_timestamps.reserve(src.stop_timestamps.size()); + + r.word_start_timestamps.reserve(src.word_start_timestamps.size()); std::string text; @@ -65,6 +68,16 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, r.timestamps.push_back(time); } + for (auto t : src.stop_timestamps) { + float time = frame_shift_s * t; + r.stop_timestamps.push_back(time); + } + + for (auto t : src.word_start_timestamps) { + float time = frame_shift_s * t; + r.word_start_timestamps.push_back(time); + } + r.words = std::move(src.words); return r; diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 6e72a4a1f..5cb217033 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -14,6 +14,7 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/resample.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -305,18 +306,8 @@ std::string OfflineRecognitionResult::AsJsonString() const { os << "\"" << text << "\"" << ", "; - os << "\"" - << "timestamps" - << "\"" - << ": "; - os << "["; - - std::string sep = ""; - for (auto t : timestamps) { - os << sep << std::fixed << std::setprecision(2) << t; - sep = ", "; - } - os << "], "; + os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; + os << "\"stop_timestamps\": " << VecToString(stop_timestamps, 2) << ", "; os << "\"" << "tokens" @@ -324,7 +315,7 @@ std::string OfflineRecognitionResult::AsJsonString() const { << ":"; os << "["; - sep = ""; + std::string sep = ""; auto oldFlags = os.flags(); for (const auto &t : tokens) { if (t.size() == 1 && static_cast(t[0]) > 0x7f) { @@ -341,19 +332,10 @@ std::string OfflineRecognitionResult::AsJsonString() const { } os << "], "; - sep = ""; + os << "\"words\": " << VecToString(words, 0) << ", "; - os << "\"" - << "words" - << "\"" - << ": "; - os << "["; - for (int32_t w : words) { - os << sep << w; - sep = ", "; - } + os << "\"word_start_timestamps\": " << VecToString(word_start_timestamps, 2); - os << "]"; os << "}"; return os.str(); diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 9df46d04e..81d44cb51 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -28,10 +28,29 @@ struct OfflineRecognitionResult { /// timestamps.size() == tokens.size() /// timestamps[i] records the time in seconds when tokens[i] is decoded. + /// + /// Note: It is the start time stamp of a token. + /// + /// It is empty if the model does not support time stamp information. std::vector timestamps; + /// It is not empty for CTC models. + /// It is empty for non-CTC models. + /// If it is not empty, then stop_timestamps.size() == timestamps.size() + std::vector stop_timestamps; + + /// It is not empty for CTC models with a HL or HLG decoding graph + /// It is empty for non-CTC models. + /// + /// If not empty, it contains word IDs. You have to use words.txt + /// to map word IDs to word symbols. std::vector words; + /// If not empty, word_start_timestamps[i] is the start time of words[i]. + /// + /// words.size() == word_start_timestamps.size() + std::vector word_start_timestamps; + std::string AsJsonString() const; }; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index fcb9169ef..ff263f5d7 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -15,41 +15,10 @@ #include #include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { -/// Helper for `OnlineRecognizerResult::AsJsonString()` -template -std::string VecToString(const std::vector &vec, int32_t precision = 6) { - std::ostringstream oss; - if (precision != 0) { - oss << std::fixed << std::setprecision(precision); - } - oss << "["; - std::string sep = ""; - for (const auto &item : vec) { - oss << sep << item; - sep = ", "; - } - oss << "]"; - return oss.str(); -} - -/// Helper for `OnlineRecognizerResult::AsJsonString()` -template <> // explicit specialization for T = std::string -std::string VecToString(const std::vector &vec, - int32_t) { // ignore 2nd arg - std::ostringstream oss; - oss << "["; - std::string sep = ""; - for (const auto &item : vec) { - oss << sep << "\"" << item << "\""; - sep = ", "; - } - oss << "]"; - return oss.str(); -} - std::string OnlineRecognizerResult::AsJsonString() const { std::ostringstream os; os << "{ "; diff --git a/sherpa-onnx/csrc/text-utils.cc b/sherpa-onnx/csrc/text-utils.cc index 04586dd8c..42fbf1557 100644 --- a/sherpa-onnx/csrc/text-utils.cc +++ b/sherpa-onnx/csrc/text-utils.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -397,4 +398,38 @@ void ToLowerCase(std::string *in_out) { [](unsigned char c) { return std::tolower(c); }); } +template +std::string VecToString(const std::vector &vec, int32_t precision /*= 6*/) { + std::ostringstream os; + if (precision != 0) { + os << std::fixed << std::setprecision(precision); + } + os << "["; + std::string sep = ""; + for (const auto &item : vec) { + os << sep << item; + sep = ", "; + } + os << "]"; + return os.str(); +} + +template std::string VecToString(const std::vector &vec, + int32_t precision /*= 6*/); + +template std::string VecToString(const std::vector &vec, + int32_t precision /*= 6*/); + +std::string VecToString(const std::vector &vec) { + std::ostringstream os; + os << "["; + std::string sep = ""; + for (const auto &item : vec) { + os << sep << "\"" << item << "\""; + sep = ", "; + } + os << "]"; + return os.str(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/text-utils.h b/sherpa-onnx/csrc/text-utils.h index a0b968d8a..2f1d0653a 100644 --- a/sherpa-onnx/csrc/text-utils.h +++ b/sherpa-onnx/csrc/text-utils.h @@ -124,6 +124,11 @@ std::vector SplitUtf8(const std::string &text); std::string ToLowerCase(const std::string &s); void ToLowerCase(std::string *in_out); +template +std::string VecToString(const std::vector &vec, int32_t precision = 6); + +std::string VecToString(const std::vector &vec); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ diff --git a/sherpa-onnx/python/csrc/offline-stream.cc b/sherpa-onnx/python/csrc/offline-stream.cc index 3c1cf3486..28fe5e3f8 100644 --- a/sherpa-onnx/python/csrc/offline-stream.cc +++ b/sherpa-onnx/python/csrc/offline-stream.cc @@ -37,7 +37,13 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT .def_property_readonly("words", [](const PyClass &self) { return self.words; }) .def_property_readonly( - "timestamps", [](const PyClass &self) { return self.timestamps; }); + "word_start_timestamps", + [](const PyClass &self) { return self.word_start_timestamps; }) + .def_property_readonly( + "timestamps", [](const PyClass &self) { return self.timestamps; }) + .def_property_readonly("stop_timestamps", [](const PyClass &self) { + return self.stop_timestamps; + }); } void PybindOfflineStream(py::module *m) {