diff --git a/protos/kaldi_serve.proto b/protos/kaldi_serve.proto index a6b24db..ea4721d 100644 --- a/protos/kaldi_serve.proto +++ b/protos/kaldi_serve.proto @@ -40,11 +40,11 @@ message RecognitionConfig { bool punctuation = 5; repeated SpeechContext speech_contexts = 6; int32 audio_channel_count = 7; - bool enable_word_time_offsets = 8; // RecognitionMetadata metadata = 9; string model = 10; bool raw = 11; int32 data_bytes = 12; + bool word_level = 13; } // Either `content` or `uri` must be supplied. @@ -64,6 +64,14 @@ message SpeechRecognitionAlternative { float confidence = 2; float am_score = 3; float lm_score = 4; + repeated Word words = 5; +} + +message Word { + float start_time = 1; + float end_time = 2; + string word = 3; + float confidence = 4; } message SpeechContext { diff --git a/python/kaldi_serve/kaldi_serve_pb2.py b/python/kaldi_serve/kaldi_serve_pb2.py index eea0f0b..35993c5 100644 --- a/python/kaldi_serve/kaldi_serve_pb2.py +++ b/python/kaldi_serve/kaldi_serve_pb2.py @@ -20,7 +20,7 @@ package='kaldi_serve', syntax='proto3', serialized_options=None, - serialized_pb=_b('\n\x11kaldi_serve.proto\x12\x0bkaldi_serve\"~\n\x10RecognizeRequest\x12.\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\x1e.kaldi_serve.RecognitionConfig\x12,\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x1d.kaldi_serve.RecognitionAudio\x12\x0c\n\x04uuid\x18\x03 \x01(\t\"J\n\x11RecognizeResponse\x12\x35\n\x07results\x18\x01 \x03(\x0b\x32$.kaldi_serve.SpeechRecognitionResult\"\x9b\x03\n\x11RecognitionConfig\x12>\n\x08\x65ncoding\x18\x01 \x01(\x0e\x32,.kaldi_serve.RecognitionConfig.AudioEncoding\x12\x19\n\x11sample_rate_hertz\x18\x02 \x01(\x05\x12\x15\n\rlanguage_code\x18\x03 \x01(\t\x12\x18\n\x10max_alternatives\x18\x04 \x01(\x05\x12\x13\n\x0bpunctuation\x18\x05 \x01(\x08\x12\x33\n\x0fspeech_contexts\x18\x06 \x03(\x0b\x32\x1a.kaldi_serve.SpeechContext\x12\x1b\n\x13\x61udio_channel_count\x18\x07 \x01(\x05\x12 \n\x18\x65nable_word_time_offsets\x18\x08 \x01(\x08\x12\r\n\x05model\x18\n 
\x01(\t\x12\x0b\n\x03raw\x18\x0b \x01(\x08\x12\x12\n\ndata_bytes\x18\x0c \x01(\x05\"A\n\rAudioEncoding\x12\x18\n\x14\x45NCODING_UNSPECIFIED\x10\x00\x12\x0c\n\x08LINEAR16\x10\x01\x12\x08\n\x04\x46LAC\x10\x02\"D\n\x10RecognitionAudio\x12\x11\n\x07\x63ontent\x18\x01 \x01(\x0cH\x00\x12\r\n\x03uri\x18\x02 \x01(\tH\x00\x42\x0e\n\x0c\x61udio_source\"Z\n\x17SpeechRecognitionResult\x12?\n\x0c\x61lternatives\x18\x01 \x03(\x0b\x32).kaldi_serve.SpeechRecognitionAlternative\"\x8c\x01\n\x1cSpeechRecognitionAlternative\x12\x12\n\ntranscript\x18\x01 \x01(\t\x12\x12\n\nconfidence\x18\x02 \x01(\x02\x12\x10\n\x08\x61m_score\x18\x03 \x01(\x02\x12\x10\n\x08lm_score\x18\x04 \x01(\x02\x12 \n\x05words\x18\x05 \x03(\x0b\x32\x11.kaldi_serve.Word\"L\n\x04Word\x12\x11\n\tstartTime\x18\x01 \x01(\x02\x12\x0f\n\x07\x65ndTime\x18\x02 \x01(\x02\x12\x0c\n\x04word\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x02\".\n\rSpeechContext\x12\x0f\n\x07phrases\x18\x01 \x03(\t\x12\x0c\n\x04type\x18\x02 \x01(\t2\xb3\x01\n\nKaldiServe\x12L\n\tRecognize\x12\x1d.kaldi_serve.RecognizeRequest\x1a\x1e.kaldi_serve.RecognizeResponse\"\x00\x12W\n\x12StreamingRecognize\x12\x1d.kaldi_serve.RecognizeRequest\x1a\x1e.kaldi_serve.RecognizeResponse\"\x00(\x01\x62\x06proto3') + serialized_pb=_b('\n\x11kaldi_serve.proto\x12\x0bkaldi_serve\"~\n\x10RecognizeRequest\x12.\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\x1e.kaldi_serve.RecognitionConfig\x12,\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x1d.kaldi_serve.RecognitionAudio\x12\x0c\n\x04uuid\x18\x03 \x01(\t\"J\n\x11RecognizeResponse\x12\x35\n\x07results\x18\x01 \x03(\x0b\x32$.kaldi_serve.SpeechRecognitionResult\"\x8d\x03\n\x11RecognitionConfig\x12>\n\x08\x65ncoding\x18\x01 \x01(\x0e\x32,.kaldi_serve.RecognitionConfig.AudioEncoding\x12\x19\n\x11sample_rate_hertz\x18\x02 \x01(\x05\x12\x15\n\rlanguage_code\x18\x03 \x01(\t\x12\x18\n\x10max_alternatives\x18\x04 \x01(\x05\x12\x13\n\x0bpunctuation\x18\x05 \x01(\x08\x12\x33\n\x0fspeech_contexts\x18\x06 
\x03(\x0b\x32\x1a.kaldi_serve.SpeechContext\x12\x1b\n\x13\x61udio_channel_count\x18\x07 \x01(\x05\x12\r\n\x05model\x18\n \x01(\t\x12\x0b\n\x03raw\x18\x0b \x01(\x08\x12\x12\n\ndata_bytes\x18\x0c \x01(\x05\x12\x12\n\nword_level\x18\r \x01(\x08\"A\n\rAudioEncoding\x12\x18\n\x14\x45NCODING_UNSPECIFIED\x10\x00\x12\x0c\n\x08LINEAR16\x10\x01\x12\x08\n\x04\x46LAC\x10\x02\"D\n\x10RecognitionAudio\x12\x11\n\x07\x63ontent\x18\x01 \x01(\x0cH\x00\x12\r\n\x03uri\x18\x02 \x01(\tH\x00\x42\x0e\n\x0c\x61udio_source\"Z\n\x17SpeechRecognitionResult\x12?\n\x0c\x61lternatives\x18\x01 \x03(\x0b\x32).kaldi_serve.SpeechRecognitionAlternative\"\x8c\x01\n\x1cSpeechRecognitionAlternative\x12\x12\n\ntranscript\x18\x01 \x01(\t\x12\x12\n\nconfidence\x18\x02 \x01(\x02\x12\x10\n\x08\x61m_score\x18\x03 \x01(\x02\x12\x10\n\x08lm_score\x18\x04 \x01(\x02\x12 \n\x05words\x18\x05 \x03(\x0b\x32\x11.kaldi_serve.Word\"N\n\x04Word\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x12\x10\n\x08\x65nd_time\x18\x02 \x01(\x02\x12\x0c\n\x04word\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x02\".\n\rSpeechContext\x12\x0f\n\x07phrases\x18\x01 \x03(\t\x12\x0c\n\x04type\x18\x02 \x01(\t2\xb3\x01\n\nKaldiServe\x12L\n\tRecognize\x12\x1d.kaldi_serve.RecognizeRequest\x1a\x1e.kaldi_serve.RecognizeResponse\"\x00\x12W\n\x12StreamingRecognize\x12\x1d.kaldi_serve.RecognizeRequest\x1a\x1e.kaldi_serve.RecognizeResponse\"\x00(\x01\x62\x06proto3') ) @@ -46,8 +46,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=585, - serialized_end=650, + serialized_start=571, + serialized_end=636, ) _sym_db.RegisterEnumDescriptor(_RECOGNITIONCONFIG_AUDIOENCODING) @@ -185,33 +185,33 @@ is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='enable_word_time_offsets', full_name='kaldi_serve.RecognitionConfig.enable_word_time_offsets', index=7, - number=8, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, 
enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='model', full_name='kaldi_serve.RecognitionConfig.model', index=8, + name='model', full_name='kaldi_serve.RecognitionConfig.model', index=7, number=10, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='raw', full_name='kaldi_serve.RecognitionConfig.raw', index=9, + name='raw', full_name='kaldi_serve.RecognitionConfig.raw', index=8, number=11, type=8, cpp_type=7, label=1, has_default_value=False, default_value=False, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='data_bytes', full_name='kaldi_serve.RecognitionConfig.data_bytes', index=10, + name='data_bytes', full_name='kaldi_serve.RecognitionConfig.data_bytes', index=9, number=12, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='word_level', full_name='kaldi_serve.RecognitionConfig.word_level', index=10, + number=13, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -226,7 +226,7 @@ oneofs=[ ], serialized_start=239, - serialized_end=650, + serialized_end=636, ) @@ -266,8 +266,8 @@ name='audio_source', full_name='kaldi_serve.RecognitionAudio.audio_source', index=0, containing_type=None, fields=[]), ], - 
serialized_start=652, - serialized_end=720, + serialized_start=638, + serialized_end=706, ) @@ -297,8 +297,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=722, - serialized_end=812, + serialized_start=708, + serialized_end=798, ) @@ -356,8 +356,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=815, - serialized_end=955, + serialized_start=801, + serialized_end=941, ) @@ -369,14 +369,14 @@ containing_type=None, fields=[ _descriptor.FieldDescriptor( - name='startTime', full_name='kaldi_serve.Word.startTime', index=0, + name='start_time', full_name='kaldi_serve.Word.start_time', index=0, number=1, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='endTime', full_name='kaldi_serve.Word.endTime', index=1, + name='end_time', full_name='kaldi_serve.Word.end_time', index=1, number=2, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, @@ -408,8 +408,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=957, - serialized_end=1033, + serialized_start=943, + serialized_end=1021, ) @@ -446,8 +446,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1035, - serialized_end=1081, + serialized_start=1023, + serialized_end=1069, ) _RECOGNIZEREQUEST.fields_by_name['config'].message_type = _RECOGNITIONCONFIG @@ -538,8 +538,8 @@ file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=1084, - serialized_end=1263, + serialized_start=1072, + serialized_end=1251, methods=[ _descriptor.MethodDescriptor( name='Recognize', diff --git a/python/pyproject.toml b/python/pyproject.toml index 1cc778a..6da9aa5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "kaldi_serve" -version = "0.1.0" +version = "0.2.0" description = 
"Python bindings for kaldi streaming ASR" authors = [] diff --git a/python/scripts/example_client.py b/python/scripts/example_client.py index 665241e..e25e7ff 100644 --- a/python/scripts/example_client.py +++ b/python/scripts/example_client.py @@ -2,8 +2,8 @@ Script for testing out ASR server. Usage: - example_client.py mic [--n-secs=] [--model=] [--lang=] [--raw] [--pcm] - example_client.py ... [--model=] [--lang=] [--raw] [--pcm] + example_client.py mic [--n-secs=] [--model=] [--lang=] [--raw] [--pcm] [--word-level] + example_client.py ... [--model=] [--lang=] [--raw] [--pcm] [--word-level] Options: --n-secs= Number of seconds to records, ideally there should be a VAD here. [default: 3] @@ -11,6 +11,7 @@ --lang= Language code of the model [default: hi] --raw Flag that specifies whether to stream raw audio bytes to server. --pcm Flag for sending raw pcm bytes + --word-level Whether to get word level features from server. """ import random @@ -19,8 +20,8 @@ from pprint import pprint from typing import List -from pydub import AudioSegment from docopt import docopt +from pydub import AudioSegment from kaldi_serve import KaldiServeClient, RecognitionAudio, RecognitionConfig from kaldi_serve.utils import (chunks_from_file, chunks_from_mic, @@ -39,14 +40,23 @@ def parse_response(response): "transcript": alt.transcript, "confidence": alt.confidence, "am_score": alt.am_score, - "lm_score": alt.lm_score + "lm_score": alt.lm_score, + "words": [ + { + "start_time": word.start_time, + "end_time": word.end_time, + "word": word.word, + "confidence": word.confidence + } + for word in alt.words + ] } for alt in res.alternatives ]) return output -def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw: bool=False): +def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw: bool=False, word_level: bool=False): """ Transcribe the given audio chunks """ @@ -64,6 +74,7 @@ def transcribe_chunks(client, audio_chunks, model: str, 
language_code: str, raw: max_alternatives=10, model=model, raw=True, + word_level=word_level, data_bytes=chunk_len ) audio_params = [(config(len(chunk)), RecognitionAudio(content=chunk)) for chunk in audio_chunks] @@ -77,6 +88,7 @@ def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw: language_code=language_code, max_alternatives=10, model=model, + word_level=word_level ) response = client.streaming_recognize(config, audio, uuid="") except Exception as e: @@ -86,14 +98,14 @@ def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw: pprint(parse_response(response)) -def decode_files(client, audio_paths: List[str], model: str, language_code: str, raw: bool=False, pcm: bool=False): +def decode_files(client, audio_paths: List[str], model: str, language_code: str, raw: bool=False, pcm: bool=False, word_level: bool=False): """ Decode files using threaded requests """ chunked_audios = [chunks_from_file(x, chunk_size=random.randint(1, 3), raw=raw, pcm=pcm) for x in audio_paths] threads = [ - threading.Thread(target=transcribe_chunks, args=(client, chunks, model, language_code, raw)) + threading.Thread(target=transcribe_chunks, args=(client, chunks, model, language_code, raw, word_level)) for chunks in chunked_audios ] @@ -111,8 +123,9 @@ def decode_files(client, audio_paths: List[str], model: str, language_code: str, language_code = args["--lang"] raw = args['--raw'] pcm = args['--pcm'] + word_level = args["--word-level"] if args["mic"]: - transcribe_chunks(client, chunks_from_mic(int(args["--n-secs"]), SR, 1), model, language_code, raw) + transcribe_chunks(client, chunks_from_mic(int(args["--n-secs"]), SR, 1), model, language_code, raw, word_level) else: - decode_files(client, args[""], model, language_code, raw, pcm) + decode_files(client, args[""], model, language_code, raw, pcm, word_level) diff --git a/resources/model-spec.toml b/resources/model-spec.toml index 9bb2eb9..6799f29 100644 --- a/resources/model-spec.toml 
+++ b/resources/model-spec.toml @@ -39,6 +39,7 @@ frame_subsampling_factor = 3 # 3 # │   ├── final.ie # │   ├── final.mat # │   └── global_cmvn.stats +# ├── word_boundary.int (optional; needed only for word level confidence and timing information) # └── words.txt # The files above have the default kaldi chain model interpretation (with diff --git a/src/decoder.hpp b/src/decoder.hpp index 72fe118..80c3f89 100644 --- a/src/decoder.hpp +++ b/src/decoder.hpp @@ -21,6 +21,8 @@ #include "fstext/fstext-lib.h" #include "lat/kaldi-lattice.h" #include "lat/lattice-functions.h" +#include "lat/word-align-lattice.h" +#include "lat/sausages.h" #include "nnet3/nnet-utils.h" #include "online2/online-endpoint.h" #include "online2/online-nnet2-feature-pipeline.h" @@ -31,12 +33,23 @@ // local includes #include "utils.hpp" +struct Word { + float start_time, end_time, confidence; + std::string word; +}; + // An alternative defines a single hypothesis and certain details about the // parse (only scores for now). struct Alternative { std::string transcript; double confidence; float am_score, lm_score; + std::vector words; +}; + +// Options for decoder +struct DecoderOptions { + bool enable_word_level; }; // Result for one continuous utterance @@ -50,58 +63,6 @@ inline double calculate_confidence(const float &lm_score, const float &am_score, return std::max(0.0, std::min(1.0, -0.0001466488 * (2.388449 * lm_score + am_score) / (n_words + 1) + 0.956)); } -// Computes n-best alternative from lattice. Output symbols are converted to words -// based on word-syms. 
-void find_alternatives(const fst::SymbolTable *word_syms, - const kaldi::CompactLattice &clat, - const std::size_t &n_best, - utterance_results_t &results) noexcept { - if (clat.NumStates() == 0) { - KALDI_LOG << "Empty lattice."; - } - - kaldi::Lattice *lat = new kaldi::Lattice(); - fst::ConvertLattice(clat, lat); - - kaldi::Lattice nbest_lat; - std::vector nbest_lats; - fst::ShortestPath(*lat, &nbest_lat, n_best); - fst::ConvertNbestToVector(nbest_lat, &nbest_lats); - - if (nbest_lats.empty()) { - KALDI_WARN << "no N-best entries"; - return; - } - - // NOTE: Check why int32s specifically are used here - std::vector input_ids; - std::vector word_ids; - std::vector words; - std::string sentence; - - for (auto const &l : nbest_lats) { - kaldi::LatticeWeight weight; - fst::GetLinearSymbolSequence(l, &input_ids, &word_ids, &weight); - - for (auto const &wid : word_ids) { - words.push_back(word_syms->Find(wid)); - } - string_join(words, " ", sentence); - - Alternative alt; - alt.transcript = sentence; - alt.lm_score = float(weight.Value1()); - alt.am_score = float(weight.Value2()); - alt.confidence = calculate_confidence(alt.lm_score, alt.am_score, word_ids.size()); - results.push_back(alt); - - input_ids.clear(); - word_ids.clear(); - words.clear(); - sentence.clear(); - } -} - inline void print_wav_info(const kaldi::WaveInfo &wave_info) noexcept { std::cout << "sample freq: " << wave_info.SampFreq() << ENDL << "sample count: " << wave_info.SampleCount() << ENDL @@ -150,11 +111,13 @@ class Decoder final { private: std::unique_ptr word_syms_; + kaldi::WordBoundaryInfo* wb_info_; public: fst::Fst *const decode_fst_; mutable kaldi::nnet3::AmNnetSimple am_nnet_; // TODO: check why kaldi decodable_info needs a non-const ref of am_net model kaldi::TransitionModel trans_model_; + DecoderOptions options; std::unique_ptr feature_info_; @@ -166,6 +129,11 @@ class Decoder final { const kaldi::BaseFloat &, const std::size_t &, const std::string &, fst::Fst *const) noexcept; + 
void _find_alternatives(const kaldi::CompactLattice &clat, + const std::size_t &n_best, + utterance_results_t &results, + const bool &word_level) const noexcept; + // Decoding processes void _decode_wave(kaldi::OnlineNnet2FeaturePipeline &, kaldi::OnlineSilenceWeighting &, @@ -192,6 +160,7 @@ class Decoder final { void decode_wav_audio(std::istream &, const size_t &, utterance_results_t &, + const bool &, const kaldi::BaseFloat & = 1) const; // decodes an (independent) raw headerless wav audio stream @@ -200,13 +169,15 @@ class Decoder final { const size_t &, const size_t &, utterance_results_t &, + const bool &, const kaldi::BaseFloat & = 1) const; // get the final utterances based on the compact lattice void decode_stream_final(kaldi::OnlineNnet2FeaturePipeline &, kaldi::SingleUtteranceNnet3Decoder &, const std::size_t &, - utterance_results_t &) const; + utterance_results_t &, + const bool &) const; }; Decoder::Decoder(const kaldi::BaseFloat &beam, @@ -228,6 +199,7 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, decodable_opts_.frame_subsampling_factor = frame_subsampling_factor; std::string word_syms_filepath = join_path(model_dir, "words.txt"); + std::string word_boundary_filepath = join_path(model_dir, "word_boundary.int"); std::string model_filepath = join_path(model_dir, "final.mdl"); std::string conf_dir = join_path(model_dir, "conf"); std::string mfcc_conf_filepath = join_path(conf_dir, "mfcc.conf"); @@ -249,6 +221,16 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, KALDI_ERR << "Could not read symbol table from file " << word_syms_filepath; } + if (exists(word_boundary_filepath)) { + kaldi::WordBoundaryInfoNewOpts word_boundary_opts; + wb_info_ = new kaldi::WordBoundaryInfo(word_boundary_opts, word_boundary_filepath); + options.enable_word_level = true; + } else { + KALDI_WARN << "Word boundary file " << word_boundary_filepath + << " not found. 
Disabling word level features."; + options.enable_word_level = false; + } + feature_info_ = std::make_unique(); feature_info_->feature_type = "mfcc"; kaldi::ReadConfigFromFile(mfcc_conf_filepath, &(feature_info_->mfcc_opts)); @@ -273,6 +255,125 @@ Decoder::Decoder(const kaldi::BaseFloat &beam, } } + +// Computes n-best alternative from lattice. Output symbols are converted to words +// based on word-syms. +void Decoder::_find_alternatives(const kaldi::CompactLattice &clat, + const std::size_t &n_best, + utterance_results_t &results, + const bool &word_level) const noexcept { + if (clat.NumStates() == 0) { + KALDI_LOG << "Empty lattice."; + } + + kaldi::Lattice *lat = new kaldi::Lattice(); + fst::ConvertLattice(clat, lat); + + kaldi::Lattice nbest_lat; + std::vector nbest_lats; + fst::ShortestPath(*lat, &nbest_lat, n_best); + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + + if (nbest_lats.empty()) { + KALDI_WARN << "no N-best entries"; + return; + } + + // NOTE: Check why int32s specifically are used here + std::vector input_ids; + std::vector word_ids; + std::vector word_strings; + std::string sentence; + + for (auto const &l : nbest_lats) { + kaldi::LatticeWeight weight; + fst::GetLinearSymbolSequence(l, &input_ids, &word_ids, &weight); + + for (auto const &wid : word_ids) { + word_strings.push_back(word_syms_->Find(wid)); + } + string_join(word_strings, " ", sentence); + + Alternative alt; + alt.transcript = sentence; + alt.lm_score = float(weight.Value1()); + alt.am_score = float(weight.Value2()); + alt.confidence = calculate_confidence(alt.lm_score, alt.am_score, word_ids.size()); + results.push_back(alt); + + input_ids.clear(); + word_ids.clear(); + word_strings.clear(); + sentence.clear(); + } + + if (!(options.enable_word_level && word_level)) + return; + + kaldi::CompactLattice aligned_clat; + kaldi::BaseFloat max_expand = 0.0; + int32 max_states; + + if (max_expand > 0) + max_states = 1000 + max_expand * clat.NumStates(); + else + max_states = 0; + + 
bool ok = kaldi::WordAlignLattice(clat, trans_model_, *wb_info_, max_states, &aligned_clat); + + if (!ok) { + if (aligned_clat.Start() != fst::kNoStateId) { + KALDI_WARN << "Outputting partial lattice"; + kaldi::TopSortCompactLatticeIfNeeded(&aligned_clat); + ok = true; + } else { + KALDI_WARN << "Empty aligned lattice, producing no output."; + } + } else { + if (aligned_clat.Start() == fst::kNoStateId) { + KALDI_WARN << "Lattice was empty"; + ok = false; + } else { + kaldi::TopSortCompactLatticeIfNeeded(&aligned_clat); + } + } + + std::vector<Word> words; + + // compute confidences and times only if alignment was ok + if (ok) { + kaldi::BaseFloat frame_shift = 0.01; + kaldi::BaseFloat lm_scale = 1.0; + kaldi::MinimumBayesRiskOptions mbr_opts; + mbr_opts.decode_mbr = false; + + fst::ScaleLattice(fst::LatticeScale(lm_scale, decodable_opts_.acoustic_scale), &aligned_clat); + kaldi::MinimumBayesRisk *mbr = new kaldi::MinimumBayesRisk(aligned_clat, mbr_opts); + + const std::vector<kaldi::BaseFloat> &conf = mbr->GetOneBestConfidences(); + const std::vector<int32> &best_words = mbr->GetOneBest(); + const std::vector<std::pair<kaldi::BaseFloat, kaldi::BaseFloat>> &times = mbr->GetOneBestTimes(); + + KALDI_ASSERT(conf.size() == best_words.size() && best_words.size() == times.size()); + + for (size_t i = 0; i < best_words.size(); i++) { + KALDI_ASSERT(best_words[i] != 0 || mbr_opts.print_silence); // Should not have epsilons. 
+ + Word word; + word.start_time = frame_shift * times[i].first; + word.end_time = frame_shift * times[i].second; + word.word = word_syms_->Find(best_words[i]); // lookup word in SymbolTable + word.confidence = conf[i]; + + words.push_back(word); + } + } + + if (!results.empty() and !words.empty()) { + results[0].words = words; + } +} + void Decoder::_decode_wave(kaldi::OnlineNnet2FeaturePipeline &feature_pipeline, kaldi::OnlineSilenceWeighting &silence_weighting, kaldi::SingleUtteranceNnet3Decoder &decoder, @@ -329,6 +430,7 @@ void Decoder::decode_stream_raw_wav_chunk(kaldi::OnlineNnet2FeaturePipeline &fea void Decoder::decode_wav_audio(std::istream &wav_stream, const size_t &n_best, utterance_results_t &results, + const bool &word_level, const kaldi::BaseFloat &chunk_size) const { // decoder state variables need to be statically initialized kaldi::OnlineIvectorExtractorAdaptationState adaptation_state(feature_info_->ivector_extractor_info); @@ -373,13 +475,14 @@ void Decoder::decode_wav_audio(std::istream &wav_stream, samp_offset += num_samp; } - decode_stream_final(feature_pipeline, decoder, n_best, results); + decode_stream_final(feature_pipeline, decoder, n_best, results, word_level); } void Decoder::decode_raw_wav_audio(std::istream &wav_stream, const size_t &data_bytes, const size_t &n_best, utterance_results_t &results, + const bool &word_level, const kaldi::BaseFloat &chunk_size) const { // decoder state variables need to be statically initialized kaldi::OnlineIvectorExtractorAdaptationState adaptation_state(feature_info_->ivector_extractor_info); @@ -424,13 +527,14 @@ void Decoder::decode_raw_wav_audio(std::istream &wav_stream, samp_offset += num_samp; } - decode_stream_final(feature_pipeline, decoder, n_best, results); + decode_stream_final(feature_pipeline, decoder, n_best, results, word_level); } void Decoder::decode_stream_final(kaldi::OnlineNnet2FeaturePipeline &feature_pipeline, kaldi::SingleUtteranceNnet3Decoder &decoder, const std::size_t &n_best, 
- utterance_results_t &results) const { + utterance_results_t &results, + const bool &word_level) const { feature_pipeline.InputFinished(); decoder.FinalizeDecoding(); @@ -442,7 +546,7 @@ void Decoder::decode_stream_final(kaldi::OnlineNnet2FeaturePipeline &feature_pip kaldi::CompactLattice clat; try { decoder.GetLattice(true, &clat); - find_alternatives(word_syms_.get(), clat, n_best, results); + _find_alternatives(clat, n_best, results, word_level); } catch (std::exception &e) { KALDI_ERR << "unexpected error during decoding lattice :: " << e.what(); } diff --git a/src/server.hpp b/src/server.hpp index 62b8a26..ccbbd48 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -104,9 +104,9 @@ grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, // TODO: take chunk length (secs) as parameter in request config try { if (config.raw()) { - decoder_->decode_raw_wav_audio(input_stream, config.data_bytes(), n_best, k_results_); + decoder_->decode_raw_wav_audio(input_stream, config.data_bytes(), n_best, k_results_, config.word_level()); } else { - decoder_->decode_wav_audio(input_stream, n_best, k_results_); + decoder_->decode_wav_audio(input_stream, n_best, k_results_, config.word_level()); } } catch (kaldi::KaldiFatalError &e) { decoder_queue_map_[model_id]->release(decoder_); @@ -119,6 +119,7 @@ grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, kaldi_serve::SpeechRecognitionResult *sr_result = response->add_results(); kaldi_serve::SpeechRecognitionAlternative *alternative; + kaldi_serve::Word *word; // find alternatives on final `lattice` after all chunks have been processed for (auto const &res : k_results_) { @@ -128,6 +129,15 @@ grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, alternative->set_confidence(res.confidence); alternative->set_am_score(res.am_score); alternative->set_lm_score(res.lm_score); + if (config.word_level()) { + for (auto const &w: res.words) { + word = alternative->add_words(); + 
word->set_start_time(w.start_time); + word->set_end_time(w.end_time); + word->set_word(w.word); + word->set_confidence(w.confidence); + } + } } } @@ -217,9 +227,10 @@ grpc::Status KaldiServeImpl::StreamingRecognize(grpc::ServerContext *const conte kaldi_serve::SpeechRecognitionResult *sr_result = response->add_results(); kaldi_serve::SpeechRecognitionAlternative *alternative; + kaldi_serve::Word *word; utterance_results_t k_results_; - decoder_->decode_stream_final(feature_pipeline, decoder, n_best, k_results_); + decoder_->decode_stream_final(feature_pipeline, decoder, n_best, k_results_, config.word_level()); // find alternatives on final `lattice` after all chunks have been processed for (auto const &res : k_results_) { @@ -229,6 +240,15 @@ grpc::Status KaldiServeImpl::StreamingRecognize(grpc::ServerContext *const conte alternative->set_confidence(res.confidence); alternative->set_am_score(res.am_score); alternative->set_lm_score(res.lm_score); + if (config.word_level()) { + for (auto const &w: res.words) { + word = alternative->add_words(); + word->set_start_time(w.start_time); + word->set_end_time(w.end_time); + word->set_word(w.word); + word->set_confidence(w.confidence); + } + } } } diff --git a/src/utils.hpp b/src/utils.hpp index a6b879d..6d02c4f 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -27,6 +27,11 @@ std::string join_path(std::string a, std::string b) { return (fs_a / fs_b).string(); } +bool exists(std::string path) { + boost::filesystem::path fs_path(path); + return boost::filesystem::exists(fs_path); +} + // Fills a list of model specifications from the config void parse_model_specs(const std::string &toml_path, std::vector &model_specs) { auto config = cpptoml::parse_file(toml_path);