diff --git a/README.md b/README.md index 62b5ef1..0fbaa0f 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ Example of `alex_asr.conf` that should reside in `model_dir`: ``` --model_type=nnet2 # Supported model types are nnet2 (nnet2::AmNnet) and gmm (AmDiagGmm) +--feature_type=mfcc # Supported feature types are mfcc and fbank --model=final.mdl # Filename of the mdl file for the decoder. --hclg=HCLG.fst # Filename of the fst file with decoding HCLG fst. --words=words.txt # Filename with a list of words (each line contains: "" "). @@ -73,6 +74,7 @@ Example of `alex_asr.conf` that should reside in `model_dir`: --cfg_decoder=decoder.cfg --cfg_decodable=decodable.cfg --cfg_mfcc=mfcc.cfg +--cfg_fbank=fbank.cfg --cfg_cmvn=cmvn.cfg --cfg_splice=splice.cfg --cfg_endpoint=endpoint.cfg @@ -112,6 +114,16 @@ Example ``mfcc.cfg``: Details: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-mfcc.h#L63 +## FBANK configuration + +Example ``fbank.cfg``: +``` +--low-freq=128 +--high-freq=3800 +``` + +Details: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.h#L62 + ## Online CMVN configuration Online CMVN configuration is needed when you set ``--use_cmvn=true``. diff --git a/src/decoder.cc b/src/decoder.cc index e9d10b2..20cabd3 100644 --- a/src/decoder.cc +++ b/src/decoder.cc @@ -124,12 +124,12 @@ namespace alex_asr { bool Decoder::EndpointDetected() { return kaldi::EndpointDetected(config_->endpoint_config, *trans_model_, - config_->mfcc_opts.frame_opts.frame_shift_ms * 1.0e-03f, + config_->FrameShiftInSeconds(), *decoder_); } void Decoder::FrameIn(VectorBase *waveform_in) { - feature_pipeline_->AcceptWaveform(config_->mfcc_opts.frame_opts.samp_freq, *waveform_in); + feature_pipeline_->AcceptWaveform(config_->SamplingFrequency(), *waveform_in); } void Decoder::FrameIn(unsigned char *buffer, int32 buffer_length) { diff --git a/src/decoder_config.cc b/src/decoder_config.cc index fe4c94f..5be0ace 100644 --- a/src/decoder_config.cc +++ b/src/decoder_config.cc @@ -14,6 +14,7 @@ namespace alex_asr { cfg_decoder(""), cfg_decodable(""), cfg_mfcc(""), + cfg_fbank(""), cfg_cmvn(""), cfg_splice(""), cfg_endpoint(""), @@ -33,6 +34,7 @@ namespace alex_asr { void DecoderConfig::Register(ParseOptions *po) { po->Register("model_type", &model_type_str, "Type of model. GMM/NNET2"); + po->Register("feature_type", &feature_type_str, "Type of features. MFCC/FBANK"); po->Register("model", &model_rxfilename, "Accoustic model filename."); po->Register("hclg", &fst_rxfilename, "HCLG FST filename."); po->Register("words", &words_rxfilename, "Word to ID mapping filename."); @@ -47,6 +49,7 @@ namespace alex_asr { po->Register("cfg_decoder", &cfg_decoder, ""); po->Register("cfg_decodable", &cfg_decodable, ""); po->Register("cfg_mfcc", &cfg_mfcc, ""); + po->Register("cfg_fbank", &cfg_fbank, ""); po->Register("cfg_cmvn", &cfg_cmvn, ""); po->Register("cfg_splice", &cfg_splice, ""); po->Register("cfg_endpoint", &cfg_endpoint, ""); @@ -66,6 +69,7 @@ namespace alex_asr { LoadConfig(cfg_decoder, &decoder_opts); LoadConfig(cfg_decodable, &decodable_opts); LoadConfig(cfg_mfcc, &mfcc_opts); + LoadConfig(cfg_fbank, &fbank_opts); LoadConfig(cfg_cmvn, &cmvn_opts); LoadConfig(cfg_splice, &splice_opts); LoadConfig(cfg_endpoint, &endpoint_config); @@ -160,9 +164,17 @@ namespace alex_asr { res = false; KALDI_ERR << "You have to specify a valid model_type."; - } + if(feature_type_str == "mfcc" || feature_type_str == "") { + feature_type = MFCC; + } else if(feature_type_str == "fbank") { + feature_type = FBANK; + } else { + res = false; + + KALDI_ERR << "You have to specify a valid feature_type."; + } res &= OptionCheck(use_ivectors && cfg_ivector == "", "You have to specify --cfg_ivector if you want to use ivectors."); @@ -193,4 +205,26 @@ namespace alex_asr { } return true; } -} \ No newline at end of file + + BaseFloat DecoderConfig::FrameShiftInSeconds() const { + if(feature_type == DecoderConfig::MFCC) { + return mfcc_opts.frame_opts.frame_shift_ms * 1.0e-03; + } else if(feature_type == DecoderConfig::FBANK) { + return fbank_opts.frame_opts.frame_shift_ms * 1.0e-03; + } else { + KALDI_ERR << "You have to specify a valid feature_type."; + return 0.0; + } + } + + BaseFloat DecoderConfig::SamplingFrequency() const { + if(feature_type == DecoderConfig::MFCC) { + return mfcc_opts.frame_opts.samp_freq; + } else if(feature_type == DecoderConfig::FBANK) { + return fbank_opts.frame_opts.samp_freq; + } else { + KALDI_ERR << "You have to specify a valid feature_type."; + return 0.0; + } + } +} diff --git a/src/decoder_config.h b/src/decoder_config.h index fab647f..6698cfe 100644 --- a/src/decoder_config.h +++ b/src/decoder_config.h @@ -19,17 +19,21 @@ using namespace kaldi; namespace alex_asr { class DecoderConfig { public: - enum ModelType { None, GMM, NNET2 }; + enum ModelType { NoneModelType, GMM, NNET2 }; + enum FeatureType { NoneFeatureType, MFCC, FBANK }; DecoderConfig(); ~DecoderConfig(); void Register(ParseOptions *po); void LoadConfigs(const string cfg_file); bool InitAndCheck(); + BaseFloat FrameShiftInSeconds() const; + BaseFloat SamplingFrequency() const; LatticeFasterDecoderConfig decoder_opts; nnet2::DecodableNnet2OnlineOptions decodable_opts; MfccOptions mfcc_opts; + FbankOptions fbank_opts; OnlineCmvnOptions cmvn_opts; OnlineSpliceOptions splice_opts; OnlineEndpointConfig endpoint_config; @@ -42,6 +46,7 @@ namespace alex_asr { OnlineIvectorExtractionInfo *ivector_extraction_info; ModelType model_type; + FeatureType feature_type; int32 bits_per_sample; bool use_lda; @@ -53,6 +58,7 @@ namespace alex_asr { std::string cfg_decoder; std::string cfg_decodable; std::string cfg_mfcc; + std::string cfg_fbank; std::string cfg_cmvn; std::string cfg_splice; std::string cfg_endpoint; @@ -74,6 +80,7 @@ namespace alex_asr { bool OptionCheck(bool cond, std::string fail_text); string model_type_str; + string feature_type_str; }; } diff --git a/src/feature_pipeline.cc b/src/feature_pipeline.cc index a8bf588..14ce777 100644 --- a/src/feature_pipeline.cc +++ b/src/feature_pipeline.cc @@ -1,10 +1,12 @@ +#include "src/decoder_config.h" + #include "feature_pipeline.h" using namespace kaldi; namespace alex_asr { FeaturePipeline::FeaturePipeline(DecoderConfig &config) : - mfcc_(NULL), + base_feature_(NULL), cmvn_(NULL), cmvn_state_(NULL), splice_(NULL), @@ -19,11 +21,21 @@ namespace alex_asr { { OnlineFeatureInterface *prev_feature; - KALDI_VLOG(3) << "Feature MFCC " - << config.mfcc_opts.mel_opts.low_freq - << " " << config.mfcc_opts.mel_opts.high_freq; - prev_feature = mfcc_ = new OnlineMfcc(config.mfcc_opts); - KALDI_VLOG(3) << " -> dims: " << mfcc_->Dim(); + if(config.feature_type == DecoderConfig::MFCC) { + KALDI_VLOG(3) << "Feature MFCC " + << config.mfcc_opts.mel_opts.low_freq + << " " << config.mfcc_opts.mel_opts.high_freq; + prev_feature = base_feature_ = new OnlineMfcc(config.mfcc_opts); + KALDI_VLOG(3) << " -> dims: " << base_feature_->Dim(); + } else if(config.feature_type == DecoderConfig::FBANK) { + KALDI_VLOG(3) << "Feature FBANK " + << config.fbank_opts.mel_opts.low_freq + << " " << config.fbank_opts.mel_opts.high_freq; + prev_feature = base_feature_ = new OnlineFbank(config.fbank_opts); + KALDI_VLOG(3) << " -> dims: " << base_feature_->Dim(); + } else { + KALDI_ERR << "You have to specify a valid feature_type."; + } if (config.use_cmvn) { KALDI_VLOG(3) << "Feature CMVN"; @@ -50,7 +62,7 @@ namespace alex_asr { if (config.use_ivectors) { KALDI_VLOG(3) << "Feature IVectors"; - ivector_ = new OnlineIvectorFeature(*config.ivector_extraction_info, mfcc_); + ivector_ = new OnlineIvectorFeature(*config.ivector_extraction_info, base_feature_); prev_feature = ivector_append_ = new OnlineAppendFeature(prev_feature, ivector_); KALDI_VLOG(3) << " -> dims: " << prev_feature->Dim(); } @@ -59,7 +71,7 @@ namespace alex_asr { } FeaturePipeline::~FeaturePipeline() { - delete mfcc_; + delete base_feature_; delete cmvn_; delete cmvn_state_; delete splice_; @@ -77,14 +89,14 @@ namespace alex_asr { void FeaturePipeline::AcceptWaveform(BaseFloat sampling_rate, const VectorBase &waveform) { - mfcc_->AcceptWaveform(sampling_rate, waveform); + base_feature_->AcceptWaveform(sampling_rate, waveform); if(pitch_) { pitch_->AcceptWaveform(sampling_rate, waveform); } } void FeaturePipeline::InputFinished() { - mfcc_->InputFinished(); + base_feature_->InputFinished(); if(pitch_) { pitch_->InputFinished(); } diff --git a/src/feature_pipeline.h b/src/feature_pipeline.h index 81b527c..4174322 100644 --- a/src/feature_pipeline.h +++ b/src/feature_pipeline.h @@ -16,7 +16,7 @@ namespace alex_asr { void InputFinished(); OnlineIvectorFeature* GetIvectorFeature(); private: - OnlineMfcc *mfcc_; + OnlineBaseFeature *base_feature_; OnlineCmvn *cmvn_; OnlineCmvnState *cmvn_state_; OnlineSpliceFrames *splice_; diff --git a/test/asr_model_digits/alex_asr.conf b/test/asr_model_digits/alex_asr.conf index 9514107..ca09aa4 100644 --- a/test/asr_model_digits/alex_asr.conf +++ b/test/asr_model_digits/alex_asr.conf @@ -1,4 +1,5 @@ --model_type=gmm +--feature_type=mfcc --model=final.mdl --hclg=HCLG.fst --words=words.txt