From f5f02d1316e814db3c5c1e15f5b48b14918dde4f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Mar 2019 21:12:20 -0500 Subject: [PATCH 001/163] [src] Lots of changes: first stab at kaldi10 (non-compatible version of Kaldi) --- src/Makefile | 34 +- src/bin/Makefile | 2 +- src/bin/acc-lda.cc | 4 +- src/bin/acc-tree-stats.cc | 4 +- src/bin/add-self-loops.cc | 4 +- src/bin/ali-to-pdf.cc | 4 +- src/bin/ali-to-phones.cc | 4 +- src/bin/ali-to-post.cc | 2 +- src/bin/align-compiled-mapped.cc | 4 +- src/bin/align-equal-compiled.cc | 2 +- src/bin/align-equal.cc | 4 +- src/bin/align-mapped.cc | 4 +- src/bin/am-info.cc | 4 +- src/bin/build-pfile-from-ali.cc | 4 +- src/bin/build-tree-two-level.cc | 4 +- src/bin/build-tree.cc | 4 +- src/bin/compile-graph.cc | 4 +- src/bin/compile-questions.cc | 6 +- src/bin/compile-train-graphs-fsts.cc | 4 +- src/bin/compile-train-graphs.cc | 4 +- src/bin/convert-ali.cc | 8 +- src/bin/copy-gselect.cc | 2 +- src/bin/copy-transition-model.cc | 4 +- src/bin/copy-tree.cc | 2 +- src/{nnetbin => bin}/cuda-gpu-available.cc | 0 src/bin/decode-faster-mapped.cc | 4 +- src/bin/decode-faster.cc | 2 +- src/bin/est-mllt.cc | 2 +- src/bin/get-post-on-ali.cc | 2 +- src/bin/hmm-info.cc | 4 +- src/bin/latgen-faster-mapped-parallel.cc | 4 +- src/bin/latgen-faster-mapped.cc | 4 +- src/bin/logprob-to-post.cc | 2 +- src/bin/make-h-transducer.cc | 4 +- src/bin/make-ilabel-transducer.cc | 4 +- src/bin/make-pdf-to-tid-transducer.cc | 4 +- src/bin/phones-to-prons.cc | 2 +- src/bin/post-to-pdf-post.cc | 4 +- src/bin/post-to-phone-post.cc | 4 +- src/bin/post-to-tacc.cc | 4 +- src/bin/prob-to-post.cc | 2 +- src/bin/prons-to-wordali.cc | 2 +- src/bin/show-alignments.cc | 4 +- src/bin/show-transitions.cc | 4 +- src/bin/tree-info.cc | 2 +- src/bin/weight-silence-post.cc | 4 +- src/chain/chain-den-graph.cc | 4 +- src/chain/chain-den-graph.h | 6 +- src/chain/chain-denominator.h | 2 +- src/chain/chain-generic-numerator.h | 2 +- src/chain/chain-numerator.h | 2 +- src/chain/chain-supervision-test.cc | 10 +- src/chain/chain-supervision.cc | 8 +- src/chain/chain-supervision.h | 18 +- src/chain/chain-training.h | 2 +- src/chainbin/chain-get-supervision.cc | 4 +- src/chainbin/chain-make-den-fst.cc | 2 +- src/chainbin/nnet3-chain-acc-lda-stats.cc | 2 +- src/chainbin/nnet3-chain-copy-egs.cc | 2 +- src/chainbin/nnet3-chain-e2e-get-egs.cc | 6 +- src/chainbin/nnet3-chain-get-egs.cc | 8 +- src/chainbin/nnet3-chain-merge-egs.cc | 2 +- src/chainbin/nnet3-chain-normalize-egs.cc | 2 +- src/chainbin/nnet3-chain-shuffle-egs.cc | 2 +- src/decoder/decodable-matrix.cc | 8 +- src/decoder/decodable-matrix.h | 22 +- src/decoder/decoder-wrappers.cc | 10 +- src/decoder/decoder-wrappers.h | 8 +- src/decoder/training-graph-compiler.cc | 2 +- src/decoder/training-graph-compiler.h | 6 +- src/feat/Makefile | 8 +- src/feat/feature-common-inl.h | 7 +- src/feat/feature-common.h | 8 +- src/feat/feature-fbank.cc | 36 +- src/feat/feature-fbank.h | 48 +- src/feat/feature-mfcc-test.cc | 40 +- src/feat/feature-mfcc.cc | 43 +- src/feat/feature-mfcc.h | 45 +- src/feat/feature-plp-test.cc | 177 - src/feat/feature-plp.cc | 191 - src/feat/feature-plp.h | 176 - src/feat/feature-window.cc | 147 +- src/feat/feature-window.h | 61 +- src/feat/mel-computations.cc | 95 +- src/feat/mel-computations.h | 43 +- src/feat/online-feature-test.cc | 81 +- src/feat/online-feature.cc | 18 +- src/feat/online-feature.h | 12 +- src/feat/pitch-functions-test.cc | 1 - src/feat/wave-reader.cc | 14 +- src/feat/wave-reader.h | 8 +- src/featbin/Makefile | 4 +- src/featbin/compute-plp-feats.cc | 184 - src/featbin/compute-spectrogram-feats.cc | 158 - src/fgmmbin/fgmm-global-info.cc | 2 +- src/fgmmbin/fgmm-gselect.cc | 2 +- src/gmm/Makefile | 6 +- src/gmm/decodable-am-diag-gmm.h | 36 +- src/gmmbin/Makefile | 17 +- src/gmmbin/gmm-acc-mllt-global.cc | 2 +- src/gmmbin/gmm-acc-mllt.cc | 4 +- src/gmmbin/gmm-acc-stats-ali.cc | 4 +- src/gmmbin/gmm-acc-stats-twofeats.cc | 4 +- src/gmmbin/gmm-acc-stats.cc | 4 +- src/gmmbin/gmm-acc-stats2.cc | 4 +- src/gmmbin/gmm-adapt-map.cc | 4 +- src/gmmbin/gmm-align-compiled.cc | 4 +- src/gmmbin/gmm-align.cc | 4 +- src/gmmbin/gmm-basis-fmllr-accs-gpost.cc | 6 +- src/gmmbin/gmm-basis-fmllr-accs.cc | 6 +- src/gmmbin/gmm-basis-fmllr-training.cc | 4 +- src/gmmbin/gmm-boost-silence.cc | 4 +- src/gmmbin/gmm-compute-likes.cc | 4 +- src/gmmbin/gmm-copy.cc | 4 +- src/gmmbin/gmm-decode-biglm-faster.cc | 4 +- src/gmmbin/gmm-decode-faster-regtree-fmllr.cc | 290 -- src/gmmbin/gmm-decode-faster-regtree-mllr.cc | 267 - src/gmmbin/gmm-decode-faster.cc | 4 +- src/gmmbin/gmm-decode-simple.cc | 4 +- src/gmmbin/gmm-est-basis-fmllr-gpost.cc | 6 +- src/gmmbin/gmm-est-basis-fmllr.cc | 6 +- src/gmmbin/gmm-est-fmllr-global.cc | 2 +- src/gmmbin/gmm-est-fmllr-gpost.cc | 6 +- src/gmmbin/gmm-est-fmllr-raw-gpost.cc | 198 - src/gmmbin/gmm-est-fmllr-raw.cc | 199 - src/gmmbin/gmm-est-fmllr.cc | 6 +- src/gmmbin/gmm-est-gaussians-ebw.cc | 4 +- src/gmmbin/gmm-est-lvtln-trans.cc | 4 +- src/gmmbin/gmm-est-map.cc | 4 +- src/gmmbin/gmm-est-regtree-fmllr-ali.cc | 202 - src/gmmbin/gmm-est-regtree-fmllr.cc | 216 - src/gmmbin/gmm-est-regtree-mllr.cc | 215 - src/gmmbin/gmm-est-rescale.cc | 4 +- src/gmmbin/gmm-est-weights-ebw.cc | 4 +- src/gmmbin/gmm-est.cc | 4 +- src/gmmbin/gmm-fmpe-acc-stats.cc | 4 +- src/gmmbin/gmm-get-stats-deriv.cc | 4 +- src/gmmbin/gmm-global-est-fmllr.cc | 2 +- src/gmmbin/gmm-global-est-lvtln-trans.cc | 2 +- src/gmmbin/gmm-global-info.cc | 2 +- src/gmmbin/gmm-gselect.cc | 2 +- src/gmmbin/gmm-info.cc | 4 +- src/gmmbin/gmm-init-biphone.cc | 8 +- src/gmmbin/gmm-init-model-flat.cc | 6 +- src/gmmbin/gmm-init-model.cc | 10 +- src/gmmbin/gmm-init-mono.cc | 8 +- src/gmmbin/gmm-ismooth-stats.cc | 4 +- src/gmmbin/gmm-latgen-biglm-faster.cc | 6 +- src/gmmbin/gmm-latgen-faster-parallel.cc | 4 +- src/gmmbin/gmm-latgen-faster-regtree-fmllr.cc | 218 - src/gmmbin/gmm-latgen-faster.cc | 4 +- src/gmmbin/gmm-latgen-map.cc | 4 +- src/gmmbin/gmm-latgen-simple.cc | 4 +- src/gmmbin/gmm-make-regtree.cc | 107 - src/gmmbin/gmm-mixup.cc | 4 +- src/gmmbin/gmm-post-to-gpost.cc | 4 +- src/gmmbin/gmm-rescore-lattice.cc | 4 +- src/gmmbin/gmm-sum-accs.cc | 2 +- src/gmmbin/gmm-transform-means-global.cc | 2 +- src/gmmbin/gmm-transform-means.cc | 4 +- .../gst-online-gmm-decode-faster.cc | 2 +- src/gst-plugin/gst-online-gmm-decode-faster.h | 2 +- src/hmm/Makefile | 7 +- src/hmm/hmm-test-utils.cc | 26 +- src/hmm/hmm-test-utils.h | 24 +- src/hmm/hmm-topology-test.cc | 12 +- src/hmm/hmm-topology.h | 194 - src/hmm/hmm-utils-test.cc | 6 +- src/hmm/hmm-utils.cc | 60 +- src/hmm/hmm-utils.h | 34 +- src/hmm/posterior.cc | 20 +- src/hmm/posterior.h | 16 +- src/hmm/{hmm-topology.cc => topology.cc} | 50 +- src/hmm/topology.h | 138 + src/hmm/transition-model.h | 371 -- ...tion-model-test.cc => transitions-test.cc} | 10 +- .../{transition-model.cc => transitions.cc} | 139 +- src/hmm/transitions.h | 263 + src/hmm/tree-accu.cc | 2 +- src/hmm/tree-accu.h | 4 +- src/itf/context-dep-itf.h | 4 +- src/lat/determinize-lattice-pruned.cc | 56 +- src/lat/determinize-lattice-pruned.h | 10 +- src/lat/lattice-functions.cc | 68 +- src/lat/lattice-functions.h | 24 +- src/lat/minimize-lattice.cc | 2 +- src/lat/phone-align-lattice.cc | 122 +- src/lat/phone-align-lattice.h | 4 +- src/lat/push-lattice.cc | 2 +- src/lat/word-align-lattice-lexicon-test.cc | 2 +- src/lat/word-align-lattice-lexicon.cc | 29 +- src/lat/word-align-lattice-lexicon.h | 6 +- src/lat/word-align-lattice.cc | 563 +-- src/lat/word-align-lattice.h | 33 +- src/latbin/lattice-add-trans-probs.cc | 4 +- src/latbin/lattice-align-phones.cc | 2 +- src/latbin/lattice-align-words-lexicon.cc | 2 +- src/latbin/lattice-align-words.cc | 2 +- src/latbin/lattice-arc-post.cc | 6 +- src/latbin/lattice-boost-ali.cc | 2 +- ...ttice-determinize-phone-pruned-parallel.cc | 8 +- .../lattice-determinize-phone-pruned.cc | 4 +- src/latbin/lattice-rescore-mapped.cc | 6 +- src/latbin/lattice-to-mpe-post.cc | 4 +- src/latbin/lattice-to-phone-lattice.cc | 4 +- src/latbin/lattice-to-smbr-post.cc | 4 +- src/latbin/nbest-to-prons.cc | 2 +- src/nnet/Makefile | 22 - src/nnet/nnet-activation.h | 373 -- src/nnet/nnet-affine-transform.h | 247 - src/nnet/nnet-average-pooling-2d-component.h | 209 - src/nnet/nnet-average-pooling-component.h | 169 - src/nnet/nnet-blstm-projected.h | 1206 ----- src/nnet/nnet-component-test.cc | 451 -- src/nnet/nnet-component.cc | 288 -- src/nnet/nnet-component.h | 358 -- src/nnet/nnet-convolutional-2d-component.h | 495 -- src/nnet/nnet-convolutional-component.h | 482 -- src/nnet/nnet-frame-pooling-component.h | 290 -- src/nnet/nnet-kl-hmm.h | 155 - src/nnet/nnet-linear-transform.h | 212 - src/nnet/nnet-loss.cc | 460 -- src/nnet/nnet-loss.h | 251 - src/nnet/nnet-lstm-projected.h | 737 --- src/nnet/nnet-matrix-buffer.h | 233 - src/nnet/nnet-max-pooling-2d-component.h | 225 - src/nnet/nnet-max-pooling-component.h | 176 - src/nnet/nnet-multibasis-component.h | 456 -- src/nnet/nnet-nnet.cc | 520 -- src/nnet/nnet-nnet.h | 186 - src/nnet/nnet-parallel-component.h | 361 -- src/nnet/nnet-parametric-relu.h | 213 - src/nnet/nnet-pdf-prior.cc | 90 - src/nnet/nnet-pdf-prior.h | 77 - src/nnet/nnet-randomizer-test.cc | 240 - src/nnet/nnet-randomizer.cc | 234 - src/nnet/nnet-randomizer.h | 274 - src/nnet/nnet-rbm.h | 433 -- src/nnet/nnet-recurrent.h | 346 -- src/nnet/nnet-sentence-averaging-component.h | 314 -- src/nnet/nnet-trnopts.h | 118 - src/nnet/nnet-utils.h | 317 -- src/nnet/nnet-various.h | 518 -- src/nnet2/Makefile | 33 - src/nnet2/am-nnet-test.cc | 88 - src/nnet2/am-nnet.cc | 83 - src/nnet2/am-nnet.h | 86 - src/nnet2/combine-nnet-a.cc | 230 - src/nnet2/combine-nnet-a.h | 85 - src/nnet2/combine-nnet-fast.cc | 443 -- src/nnet2/combine-nnet-fast.h | 112 - src/nnet2/combine-nnet.cc | 253 - src/nnet2/combine-nnet.h | 74 - src/nnet2/decodable-am-nnet.h | 187 - src/nnet2/get-feature-transform.cc | 203 - src/nnet2/get-feature-transform.h | 180 - src/nnet2/mixup-nnet.cc | 222 - src/nnet2/mixup-nnet.h | 69 - src/nnet2/nnet-component-test.cc | 915 ---- src/nnet2/nnet-component.cc | 4390 ----------------- src/nnet2/nnet-component.h | 1816 ------- .../nnet-compute-discriminative-parallel.cc | 222 - .../nnet-compute-discriminative-parallel.h | 49 - src/nnet2/nnet-compute-discriminative.cc | 416 -- src/nnet2/nnet-compute-discriminative.h | 115 - src/nnet2/nnet-compute-online.cc | 215 - src/nnet2/nnet-compute-online.h | 110 - src/nnet2/nnet-compute-test.cc | 134 - src/nnet2/nnet-compute.cc | 224 - src/nnet2/nnet-compute.h | 85 - src/nnet2/nnet-example-functions-test.cc | 69 - src/nnet2/nnet-example-functions.cc | 997 ---- src/nnet2/nnet-example-functions.h | 300 -- src/nnet2/nnet-example.cc | 309 -- src/nnet2/nnet-example.h | 191 - src/nnet2/nnet-fix.cc | 111 - src/nnet2/nnet-fix.h | 74 - src/nnet2/nnet-functions.cc | 78 - src/nnet2/nnet-functions.h | 70 - src/nnet2/nnet-limit-rank.cc | 112 - src/nnet2/nnet-limit-rank.h | 62 - src/nnet2/nnet-nnet-test.cc | 57 - src/nnet2/nnet-nnet.cc | 846 ---- src/nnet2/nnet-nnet.h | 306 -- src/nnet2/nnet-precondition-online-test.cc | 342 -- src/nnet2/nnet-precondition-online.cc | 641 --- src/nnet2/nnet-precondition-online.h | 574 --- src/nnet2/nnet-precondition-test.cc | 67 - src/nnet2/nnet-precondition.cc | 352 -- src/nnet2/nnet-precondition.h | 88 - src/nnet2/nnet-stats.cc | 122 - src/nnet2/nnet-stats.h | 97 - src/nnet2/nnet-update-parallel.cc | 271 - src/nnet2/nnet-update-parallel.h | 88 - src/nnet2/nnet-update.cc | 361 -- src/nnet2/nnet-update.h | 191 - src/nnet2/online-nnet2-decodable-test.cc | 114 - src/nnet2/online-nnet2-decodable.cc | 145 - src/nnet2/online-nnet2-decodable.h | 122 - src/nnet2/rescale-nnet.cc | 227 - src/nnet2/rescale-nnet.h | 80 - src/nnet2/shrink-nnet.cc | 112 - src/nnet2/shrink-nnet.h | 59 - src/nnet2/train-nnet-ensemble.cc | 141 - src/nnet2/train-nnet-ensemble.h | 105 - src/nnet2/train-nnet.cc | 206 - src/nnet2/train-nnet.h | 64 - src/nnet2/widen-nnet.cc | 100 - src/nnet2/widen-nnet.h | 65 - src/nnet2bin/Makefile | 44 - src/nnet2bin/cuda-compiled.cc | 36 - src/nnet2bin/nnet-adjust-priors.cc | 144 - src/nnet2bin/nnet-align-compiled.cc | 159 - src/nnet2bin/nnet-am-average.cc | 259 - src/nnet2bin/nnet-am-compute.cc | 163 - src/nnet2bin/nnet-am-copy.cc | 214 - src/nnet2bin/nnet-am-fix.cc | 88 - src/nnet2bin/nnet-am-info.cc | 87 - src/nnet2bin/nnet-am-init.cc | 110 - src/nnet2bin/nnet-am-mixup.cc | 81 - src/nnet2bin/nnet-am-reinitialize.cc | 87 - .../nnet-am-switch-preconditioning.cc | 97 - src/nnet2bin/nnet-am-widen.cc | 83 - .../nnet-combine-egs-discriminative.cc | 115 - src/nnet2bin/nnet-combine-fast.cc | 133 - src/nnet2bin/nnet-combine.cc | 124 - .../nnet-compare-hash-discriminative.cc | 138 - src/nnet2bin/nnet-compute-from-egs.cc | 99 - src/nnet2bin/nnet-compute-prob.cc | 104 - src/nnet2bin/nnet-compute.cc | 105 - src/nnet2bin/nnet-copy-egs-discriminative.cc | 158 - src/nnet2bin/nnet-copy-egs.cc | 179 - src/nnet2bin/nnet-get-egs-discriminative.cc | 151 - src/nnet2bin/nnet-get-egs.cc | 184 - .../nnet-get-feature-transform-multi.cc | 94 - src/nnet2bin/nnet-get-feature-transform.cc | 87 - src/nnet2bin/nnet-get-weighted-egs.cc | 232 - src/nnet2bin/nnet-init.cc | 76 - src/nnet2bin/nnet-insert.cc | 138 - src/nnet2bin/nnet-latgen-faster-parallel.cc | 207 - src/nnet2bin/nnet-latgen-faster.cc | 196 - src/nnet2bin/nnet-modify-learning-rates.cc | 211 - src/nnet2bin/nnet-normalize-stddev.cc | 174 - src/nnet2bin/nnet-relabel-egs.cc | 168 - src/nnet2bin/nnet-replace-last-layers.cc | 97 - src/nnet2bin/nnet-show-progress.cc | 164 - .../nnet-shuffle-egs-discriminative.cc | 114 - src/nnet2bin/nnet-shuffle-egs.cc | 113 - src/nnet2bin/nnet-subset-egs.cc | 102 - src/nnet2bin/nnet-to-raw-nnet.cc | 83 - .../nnet-train-discriminative-parallel.cc | 95 - .../nnet-train-discriminative-simple.cc | 116 - src/nnet2bin/nnet-train-ensemble.cc | 145 - src/nnet2bin/nnet-train-parallel.cc | 112 - src/nnet2bin/nnet-train-simple.cc | 117 - src/nnet2bin/nnet-train-transitions.cc | 147 - src/nnet2bin/nnet1-to-raw-nnet.cc | 222 - src/nnet2bin/raw-nnet-concat.cc | 75 - src/nnet2bin/raw-nnet-copy.cc | 107 - src/nnet2bin/raw-nnet-info.cc | 63 - src/nnet2bin/raw-nnet-init | 1 - src/nnet3/decodable-online-looped.h | 8 +- src/nnet3/decodable-simple-looped.cc | 2 +- src/nnet3/decodable-simple-looped.h | 6 +- src/nnet3/discriminative-supervision.cc | 2 +- src/nnet3/discriminative-supervision.h | 6 +- src/nnet3/discriminative-training.cc | 8 +- src/nnet3/discriminative-training.h | 4 +- src/nnet3/nnet-am-decodable-simple.cc | 4 +- src/nnet3/nnet-am-decodable-simple.h | 10 +- src/nnet3/nnet-batch-compute.cc | 2 +- src/nnet3/nnet-batch-compute.h | 6 +- src/nnet3/nnet-discriminative-diagnostics.cc | 2 +- src/nnet3/nnet-discriminative-diagnostics.h | 4 +- src/nnet3/nnet-discriminative-example.h | 2 +- src/nnet3/nnet-discriminative-training.cc | 2 +- src/nnet3/nnet-discriminative-training.h | 4 +- src/nnet3/nnet-nnet.cc | 6 +- src/nnet3bin/Makefile | 4 +- src/nnet3bin/nnet3-acc-lda-stats.cc | 2 +- src/nnet3bin/nnet3-align-compiled.cc | 4 +- src/nnet3bin/nnet3-am-adjust-priors.cc | 4 +- src/nnet3bin/nnet3-am-copy.cc | 4 +- src/nnet3bin/nnet3-am-info.cc | 4 +- src/nnet3bin/nnet3-am-init.cc | 10 +- src/nnet3bin/nnet3-am-train-transitions.cc | 147 - src/nnet3bin/nnet3-average.cc | 2 +- src/nnet3bin/nnet3-compute-batch.cc | 2 +- src/nnet3bin/nnet3-compute-from-egs.cc | 2 +- src/nnet3bin/nnet3-compute.cc | 2 +- src/nnet3bin/nnet3-copy-egs.cc | 2 +- src/nnet3bin/nnet3-copy.cc | 2 +- .../nnet3-discriminative-compute-from-egs.cc | 2 +- .../nnet3-discriminative-compute-objf.cc | 2 +- src/nnet3bin/nnet3-discriminative-copy-egs.cc | 2 +- src/nnet3bin/nnet3-discriminative-get-egs.cc | 6 +- .../nnet3-discriminative-merge-egs.cc | 2 +- .../nnet3-discriminative-shuffle-egs.cc | 2 +- src/nnet3bin/nnet3-discriminative-train.cc | 2 +- src/nnet3bin/nnet3-egs-augment-image.cc | 2 +- src/nnet3bin/nnet3-get-egs-dense-targets.cc | 2 +- src/nnet3bin/nnet3-get-egs-simple.cc | 2 +- src/nnet3bin/nnet3-get-egs.cc | 2 +- src/nnet3bin/nnet3-init.cc | 2 +- src/nnet3bin/nnet3-latgen-faster-batch.cc | 4 +- src/nnet3bin/nnet3-latgen-faster-looped.cc | 4 +- src/nnet3bin/nnet3-latgen-faster-parallel.cc | 4 +- src/nnet3bin/nnet3-latgen-faster.cc | 4 +- src/nnet3bin/nnet3-latgen-grammar.cc | 4 +- src/nnet3bin/nnet3-merge-egs.cc | 2 +- src/nnet3bin/nnet3-show-progress.cc | 2 +- src/nnet3bin/nnet3-shuffle-egs.cc | 2 +- src/nnetbin/Makefile | 30 - src/nnetbin/cmvn-to-nnet.cc | 121 - src/nnetbin/feat-to-post.cc | 80 - src/nnetbin/nnet-concat.cc | 90 - src/nnetbin/nnet-copy.cc | 151 - src/nnetbin/nnet-forward.cc | 208 - src/nnetbin/nnet-info.cc | 65 - src/nnetbin/nnet-initialize.cc | 71 - src/nnetbin/nnet-set-learnrate.cc | 104 - src/nnetbin/nnet-train-frmshuff.cc | 424 -- src/nnetbin/nnet-train-mmi-sequential.cc | 481 -- src/nnetbin/nnet-train-mpe-sequential.cc | 412 -- src/nnetbin/nnet-train-multistream-perutt.cc | 363 -- src/nnetbin/nnet-train-multistream.cc | 460 -- src/nnetbin/nnet-train-perutt.cc | 310 -- src/nnetbin/paste-post.cc | 168 - src/nnetbin/rbm-convert-to-nnet.cc | 77 - src/nnetbin/rbm-train-cd1-frmshuff.cc | 287 -- src/nnetbin/train-transitions.cc | 101 - src/nnetbin/transf-to-nnet.cc | 79 - src/online/online-decodable.cc | 2 +- src/online/online-decodable.h | 4 +- src/online/online-faster-decoder.h | 6 +- src/online2/Makefile | 6 +- src/online2/online-endpoint.cc | 8 +- src/online2/online-endpoint.h | 6 +- src/online2/online-gmm-decodable.cc | 2 +- src/online2/online-gmm-decodable.h | 6 +- src/online2/online-gmm-decoding.cc | 16 +- src/online2/online-gmm-decoding.h | 6 +- src/online2/online-ivector-feature.cc | 2 +- src/online2/online-ivector-feature.h | 4 +- src/online2/online-nnet2-decoding-threaded.cc | 652 --- src/online2/online-nnet2-decoding-threaded.h | 6 +- src/online2/online-nnet2-decoding.cc | 81 - src/online2/online-nnet2-decoding.h | 6 +- src/online2/online-nnet3-decoding.cc | 2 +- src/online2/online-nnet3-decoding.h | 6 +- ...ipeline.cc => online2-feature-pipeline.cc} | 0 .../online2-wav-nnet2-am-compute.cc | 2 +- .../online2-wav-nnet2-latgen-faster.cc | 2 +- .../online2-wav-nnet2-latgen-threaded.cc | 2 +- .../online2-wav-nnet3-latgen-faster.cc | 2 +- .../online2-wav-nnet3-latgen-grammar.cc | 2 +- .../online-audio-server-decode-faster.cc | 2 +- src/onlinebin/online-gmm-decode-faster.cc | 2 +- .../online-server-gmm-decode-faster.cc | 2 +- src/onlinebin/online-wav-gmm-decode-faster.cc | 2 +- src/sgmm2/Makefile | 19 - src/sgmm2/am-sgmm2-project.cc | 265 - src/sgmm2/am-sgmm2-project.h | 86 - src/sgmm2/am-sgmm2-test.cc | 285 -- src/sgmm2/am-sgmm2.cc | 1493 ------ src/sgmm2/am-sgmm2.h | 586 --- src/sgmm2/decodable-am-sgmm2.cc | 54 - src/sgmm2/decodable-am-sgmm2.h | 138 - src/sgmm2/estimate-am-sgmm2-ebw.cc | 736 --- src/sgmm2/estimate-am-sgmm2-ebw.h | 242 - src/sgmm2/estimate-am-sgmm2-test.cc | 167 - src/sgmm2/estimate-am-sgmm2.cc | 1952 -------- src/sgmm2/estimate-am-sgmm2.h | 478 -- src/sgmm2/fmllr-sgmm2-test.cc | 243 - src/sgmm2/fmllr-sgmm2.cc | 555 --- src/sgmm2/fmllr-sgmm2.h | 193 - src/sgmm2bin/Makefile | 26 - src/sgmm2bin/init-ubm.cc | 95 - src/sgmm2bin/sgmm2-acc-stats-gpost.cc | 181 - src/sgmm2bin/sgmm2-acc-stats.cc | 223 - src/sgmm2bin/sgmm2-acc-stats2.cc | 240 - src/sgmm2bin/sgmm2-align-compiled.cc | 183 - src/sgmm2bin/sgmm2-comp-prexform.cc | 84 - src/sgmm2bin/sgmm2-copy.cc | 74 - src/sgmm2bin/sgmm2-est-ebw.cc | 118 - src/sgmm2bin/sgmm2-est-fmllr.cc | 302 -- src/sgmm2bin/sgmm2-est-spkvecs-gpost.cc | 218 - src/sgmm2bin/sgmm2-est-spkvecs.cc | 259 - src/sgmm2bin/sgmm2-est.cc | 166 - src/sgmm2bin/sgmm2-gselect.cc | 110 - src/sgmm2bin/sgmm2-info.cc | 115 - src/sgmm2bin/sgmm2-init.cc | 132 - src/sgmm2bin/sgmm2-latgen-faster-parallel.cc | 291 -- src/sgmm2bin/sgmm2-latgen-faster.cc | 268 - src/sgmm2bin/sgmm2-post-to-gpost.cc | 186 - src/sgmm2bin/sgmm2-project.cc | 116 - src/sgmm2bin/sgmm2-rescore-lattice.cc | 166 - src/sgmm2bin/sgmm2-sum-accs.cc | 94 - src/transform/Makefile | 14 +- .../decodable-am-diag-gmm-regtree.cc | 234 - src/transform/decodable-am-diag-gmm-regtree.h | 141 - src/transform/fmllr-raw-test.cc | 123 - src/transform/fmllr-raw.cc | 546 -- src/transform/fmllr-raw.h | 206 - src/transform/fmpe-test.cc | 177 - src/transform/fmpe.cc | 691 --- src/transform/fmpe.h | 271 - src/transform/regtree-fmllr-diag-gmm-test.cc | 320 -- src/transform/regtree-fmllr-diag-gmm.cc | 407 -- src/transform/regtree-fmllr-diag-gmm.h | 204 - src/transform/regtree-mllr-diag-gmm-test.cc | 194 - src/transform/regtree-mllr-diag-gmm.cc | 398 -- src/transform/regtree-mllr-diag-gmm.h | 164 - src/tree/build-tree.h | 4 +- src/tree/context-dep.h | 6 +- 515 files changed, 1652 insertions(+), 63007 deletions(-) rename src/{nnetbin => bin}/cuda-gpu-available.cc (100%) delete mode 100644 src/feat/feature-plp-test.cc delete mode 100644 src/feat/feature-plp.cc delete mode 100644 src/feat/feature-plp.h delete mode 100644 src/featbin/compute-plp-feats.cc delete mode 100644 src/featbin/compute-spectrogram-feats.cc delete mode 100644 src/gmmbin/gmm-decode-faster-regtree-fmllr.cc delete mode 100644 src/gmmbin/gmm-decode-faster-regtree-mllr.cc delete mode 100644 src/gmmbin/gmm-est-fmllr-raw-gpost.cc delete mode 100644 src/gmmbin/gmm-est-fmllr-raw.cc delete mode 100644 src/gmmbin/gmm-est-regtree-fmllr-ali.cc delete mode 100644 src/gmmbin/gmm-est-regtree-fmllr.cc delete mode 100644 src/gmmbin/gmm-est-regtree-mllr.cc delete mode 100644 src/gmmbin/gmm-latgen-faster-regtree-fmllr.cc delete mode 100644 src/gmmbin/gmm-make-regtree.cc delete mode 100644 src/hmm/hmm-topology.h rename src/hmm/{hmm-topology.cc => topology.cc} (89%) create mode 100644 src/hmm/topology.h delete mode 100644 src/hmm/transition-model.h rename src/hmm/{transition-model-test.cc => transitions-test.cc} (87%) rename src/hmm/{transition-model.cc => transitions.cc} (88%) create mode 100644 src/hmm/transitions.h delete mode 100644 src/nnet/Makefile delete mode 100644 src/nnet/nnet-activation.h delete mode 100644 src/nnet/nnet-affine-transform.h delete mode 100644 src/nnet/nnet-average-pooling-2d-component.h delete mode 100644 src/nnet/nnet-average-pooling-component.h delete mode 100644 src/nnet/nnet-blstm-projected.h delete mode 100644 src/nnet/nnet-component-test.cc delete mode 100644 src/nnet/nnet-component.cc delete mode 100644 src/nnet/nnet-component.h delete mode 100644 src/nnet/nnet-convolutional-2d-component.h delete mode 100644 src/nnet/nnet-convolutional-component.h delete mode 100644 src/nnet/nnet-frame-pooling-component.h delete mode 100644 src/nnet/nnet-kl-hmm.h delete mode 100644 src/nnet/nnet-linear-transform.h delete mode 100644 src/nnet/nnet-loss.cc delete mode 100644 src/nnet/nnet-loss.h delete mode 100644 src/nnet/nnet-lstm-projected.h delete mode 100644 src/nnet/nnet-matrix-buffer.h delete mode 100644 src/nnet/nnet-max-pooling-2d-component.h delete mode 100644 src/nnet/nnet-max-pooling-component.h delete mode 100644 src/nnet/nnet-multibasis-component.h delete mode 100644 src/nnet/nnet-nnet.cc delete mode 100644 src/nnet/nnet-nnet.h delete mode 100644 src/nnet/nnet-parallel-component.h delete mode 100644 src/nnet/nnet-parametric-relu.h delete mode 100644 src/nnet/nnet-pdf-prior.cc delete mode 100644 src/nnet/nnet-pdf-prior.h delete mode 100644 src/nnet/nnet-randomizer-test.cc delete mode 100644 src/nnet/nnet-randomizer.cc delete mode 100644 src/nnet/nnet-randomizer.h delete mode 100644 src/nnet/nnet-rbm.h delete mode 100644 src/nnet/nnet-recurrent.h delete mode 100644 src/nnet/nnet-sentence-averaging-component.h delete mode 100644 src/nnet/nnet-trnopts.h delete mode 100644 src/nnet/nnet-utils.h delete mode 100644 src/nnet/nnet-various.h delete mode 100644 src/nnet2/Makefile delete mode 100644 src/nnet2/am-nnet-test.cc delete mode 100644 src/nnet2/am-nnet.cc delete mode 100644 src/nnet2/am-nnet.h delete mode 100644 src/nnet2/combine-nnet-a.cc delete mode 100644 src/nnet2/combine-nnet-a.h delete mode 100644 src/nnet2/combine-nnet-fast.cc delete mode 100644 src/nnet2/combine-nnet-fast.h delete mode 100644 src/nnet2/combine-nnet.cc delete mode 100644 src/nnet2/combine-nnet.h delete mode 100644 src/nnet2/decodable-am-nnet.h delete mode 100644 src/nnet2/get-feature-transform.cc delete mode 100644 src/nnet2/get-feature-transform.h delete mode 100644 src/nnet2/mixup-nnet.cc delete mode 100644 src/nnet2/mixup-nnet.h delete mode 100644 src/nnet2/nnet-component-test.cc delete mode 100644 src/nnet2/nnet-component.cc delete mode 100644 src/nnet2/nnet-component.h delete mode 100644 src/nnet2/nnet-compute-discriminative-parallel.cc delete mode 100644 src/nnet2/nnet-compute-discriminative-parallel.h delete mode 100644 src/nnet2/nnet-compute-discriminative.cc delete mode 100644 src/nnet2/nnet-compute-discriminative.h delete mode 100644 src/nnet2/nnet-compute-online.cc delete mode 100644 src/nnet2/nnet-compute-online.h delete mode 100644 src/nnet2/nnet-compute-test.cc delete mode 100644 src/nnet2/nnet-compute.cc delete mode 100644 src/nnet2/nnet-compute.h delete mode 100644 src/nnet2/nnet-example-functions-test.cc delete mode 100644 src/nnet2/nnet-example-functions.cc delete mode 100644 src/nnet2/nnet-example-functions.h delete mode 100644 src/nnet2/nnet-example.cc delete mode 100644 src/nnet2/nnet-example.h delete mode 100644 src/nnet2/nnet-fix.cc delete mode 100644 src/nnet2/nnet-fix.h delete mode 100644 src/nnet2/nnet-functions.cc delete mode 100644 src/nnet2/nnet-functions.h delete mode 100644 src/nnet2/nnet-limit-rank.cc delete mode 100644 src/nnet2/nnet-limit-rank.h delete mode 100644 src/nnet2/nnet-nnet-test.cc delete mode 100644 src/nnet2/nnet-nnet.cc delete mode 100644 src/nnet2/nnet-nnet.h delete mode 100644 src/nnet2/nnet-precondition-online-test.cc delete mode 100644 src/nnet2/nnet-precondition-online.cc delete mode 100644 src/nnet2/nnet-precondition-online.h delete mode 100644 src/nnet2/nnet-precondition-test.cc delete mode 100644 src/nnet2/nnet-precondition.cc delete mode 100644 src/nnet2/nnet-precondition.h delete mode 100644 src/nnet2/nnet-stats.cc delete mode 100644 src/nnet2/nnet-stats.h delete mode 100644 src/nnet2/nnet-update-parallel.cc delete mode 100644 src/nnet2/nnet-update-parallel.h delete mode 100644 src/nnet2/nnet-update.cc delete mode 100644 src/nnet2/nnet-update.h delete mode 100644 src/nnet2/online-nnet2-decodable-test.cc delete mode 100644 src/nnet2/online-nnet2-decodable.cc delete mode 100644 src/nnet2/online-nnet2-decodable.h delete mode 100644 src/nnet2/rescale-nnet.cc delete mode 100644 src/nnet2/rescale-nnet.h delete mode 100644 src/nnet2/shrink-nnet.cc delete mode 100644 src/nnet2/shrink-nnet.h delete mode 100644 src/nnet2/train-nnet-ensemble.cc delete mode 100644 src/nnet2/train-nnet-ensemble.h delete mode 100644 src/nnet2/train-nnet.cc delete mode 100644 src/nnet2/train-nnet.h delete mode 100644 src/nnet2/widen-nnet.cc delete mode 100644 src/nnet2/widen-nnet.h delete mode 100644 src/nnet2bin/Makefile delete mode 100644 src/nnet2bin/cuda-compiled.cc delete mode 100644 src/nnet2bin/nnet-adjust-priors.cc delete mode 100644 src/nnet2bin/nnet-align-compiled.cc delete mode 100644 src/nnet2bin/nnet-am-average.cc delete mode 100644 src/nnet2bin/nnet-am-compute.cc delete mode 100644 src/nnet2bin/nnet-am-copy.cc delete mode 100644 src/nnet2bin/nnet-am-fix.cc delete mode 100644 src/nnet2bin/nnet-am-info.cc delete mode 100644 src/nnet2bin/nnet-am-init.cc delete mode 100644 src/nnet2bin/nnet-am-mixup.cc delete mode 100644 src/nnet2bin/nnet-am-reinitialize.cc delete mode 100644 src/nnet2bin/nnet-am-switch-preconditioning.cc delete mode 100644 src/nnet2bin/nnet-am-widen.cc delete mode 100644 src/nnet2bin/nnet-combine-egs-discriminative.cc delete mode 100644 src/nnet2bin/nnet-combine-fast.cc delete mode 100644 src/nnet2bin/nnet-combine.cc delete mode 100644 src/nnet2bin/nnet-compare-hash-discriminative.cc delete mode 100644 src/nnet2bin/nnet-compute-from-egs.cc delete mode 100644 src/nnet2bin/nnet-compute-prob.cc delete mode 100644 src/nnet2bin/nnet-compute.cc delete mode 100644 src/nnet2bin/nnet-copy-egs-discriminative.cc delete mode 100644 src/nnet2bin/nnet-copy-egs.cc delete mode 100644 src/nnet2bin/nnet-get-egs-discriminative.cc delete mode 100644 src/nnet2bin/nnet-get-egs.cc delete mode 100644 src/nnet2bin/nnet-get-feature-transform-multi.cc delete mode 100644 src/nnet2bin/nnet-get-feature-transform.cc delete mode 100644 src/nnet2bin/nnet-get-weighted-egs.cc delete mode 100644 src/nnet2bin/nnet-init.cc delete mode 100644 src/nnet2bin/nnet-insert.cc delete mode 100644 src/nnet2bin/nnet-latgen-faster-parallel.cc delete mode 100644 src/nnet2bin/nnet-latgen-faster.cc delete mode 100644 src/nnet2bin/nnet-modify-learning-rates.cc delete mode 100644 src/nnet2bin/nnet-normalize-stddev.cc delete mode 100644 src/nnet2bin/nnet-relabel-egs.cc delete mode 100644 src/nnet2bin/nnet-replace-last-layers.cc delete mode 100644 src/nnet2bin/nnet-show-progress.cc delete mode 100644 src/nnet2bin/nnet-shuffle-egs-discriminative.cc delete mode 100644 src/nnet2bin/nnet-shuffle-egs.cc delete mode 100644 src/nnet2bin/nnet-subset-egs.cc delete mode 100644 src/nnet2bin/nnet-to-raw-nnet.cc delete mode 100644 src/nnet2bin/nnet-train-discriminative-parallel.cc delete mode 100644 src/nnet2bin/nnet-train-discriminative-simple.cc delete mode 100644 src/nnet2bin/nnet-train-ensemble.cc delete mode 100644 src/nnet2bin/nnet-train-parallel.cc delete mode 100644 src/nnet2bin/nnet-train-simple.cc delete mode 100644 src/nnet2bin/nnet-train-transitions.cc delete mode 100644 src/nnet2bin/nnet1-to-raw-nnet.cc delete mode 100644 src/nnet2bin/raw-nnet-concat.cc delete mode 100644 src/nnet2bin/raw-nnet-copy.cc delete mode 100644 src/nnet2bin/raw-nnet-info.cc delete mode 120000 src/nnet2bin/raw-nnet-init delete mode 100644 src/nnet3bin/nnet3-am-train-transitions.cc delete mode 100644 src/nnetbin/Makefile delete mode 100644 src/nnetbin/cmvn-to-nnet.cc delete mode 100644 src/nnetbin/feat-to-post.cc delete mode 100644 src/nnetbin/nnet-concat.cc delete mode 100644 src/nnetbin/nnet-copy.cc delete mode 100644 src/nnetbin/nnet-forward.cc delete mode 100644 src/nnetbin/nnet-info.cc delete mode 100644 src/nnetbin/nnet-initialize.cc delete mode 100644 src/nnetbin/nnet-set-learnrate.cc delete mode 100644 src/nnetbin/nnet-train-frmshuff.cc delete mode 100644 src/nnetbin/nnet-train-mmi-sequential.cc delete mode 100644 src/nnetbin/nnet-train-mpe-sequential.cc delete mode 100644 src/nnetbin/nnet-train-multistream-perutt.cc delete mode 100644 src/nnetbin/nnet-train-multistream.cc delete mode 100644 src/nnetbin/nnet-train-perutt.cc delete mode 100644 src/nnetbin/paste-post.cc delete mode 100644 src/nnetbin/rbm-convert-to-nnet.cc delete mode 100644 src/nnetbin/rbm-train-cd1-frmshuff.cc delete mode 100644 src/nnetbin/train-transitions.cc delete mode 100644 src/nnetbin/transf-to-nnet.cc delete mode 100644 src/online2/online-nnet2-decoding-threaded.cc delete mode 100644 src/online2/online-nnet2-decoding.cc rename src/online2/{online-nnet2-feature-pipeline.cc => online2-feature-pipeline.cc} (100%) delete mode 100644 src/sgmm2/Makefile delete mode 100644 src/sgmm2/am-sgmm2-project.cc delete mode 100644 src/sgmm2/am-sgmm2-project.h delete mode 100644 src/sgmm2/am-sgmm2-test.cc delete mode 100644 src/sgmm2/am-sgmm2.cc delete mode 100644 src/sgmm2/am-sgmm2.h delete mode 100644 src/sgmm2/decodable-am-sgmm2.cc delete mode 100644 src/sgmm2/decodable-am-sgmm2.h delete mode 100644 src/sgmm2/estimate-am-sgmm2-ebw.cc delete mode 100644 src/sgmm2/estimate-am-sgmm2-ebw.h delete mode 100644 src/sgmm2/estimate-am-sgmm2-test.cc delete mode 100644 src/sgmm2/estimate-am-sgmm2.cc delete mode 100644 src/sgmm2/estimate-am-sgmm2.h delete mode 100644 src/sgmm2/fmllr-sgmm2-test.cc delete mode 100644 src/sgmm2/fmllr-sgmm2.cc delete mode 100644 src/sgmm2/fmllr-sgmm2.h delete mode 100644 src/sgmm2bin/Makefile delete mode 100644 src/sgmm2bin/init-ubm.cc delete mode 100644 src/sgmm2bin/sgmm2-acc-stats-gpost.cc delete mode 100644 src/sgmm2bin/sgmm2-acc-stats.cc delete mode 100644 src/sgmm2bin/sgmm2-acc-stats2.cc delete mode 100644 src/sgmm2bin/sgmm2-align-compiled.cc delete mode 100644 src/sgmm2bin/sgmm2-comp-prexform.cc delete mode 100644 src/sgmm2bin/sgmm2-copy.cc delete mode 100644 src/sgmm2bin/sgmm2-est-ebw.cc delete mode 100644 src/sgmm2bin/sgmm2-est-fmllr.cc delete mode 100644 src/sgmm2bin/sgmm2-est-spkvecs-gpost.cc delete mode 100644 src/sgmm2bin/sgmm2-est-spkvecs.cc delete mode 100644 src/sgmm2bin/sgmm2-est.cc delete mode 100644 src/sgmm2bin/sgmm2-gselect.cc delete mode 100644 src/sgmm2bin/sgmm2-info.cc delete mode 100644 src/sgmm2bin/sgmm2-init.cc delete mode 100644 src/sgmm2bin/sgmm2-latgen-faster-parallel.cc delete mode 100644 src/sgmm2bin/sgmm2-latgen-faster.cc delete mode 100644 src/sgmm2bin/sgmm2-post-to-gpost.cc delete mode 100644 src/sgmm2bin/sgmm2-project.cc delete mode 100644 src/sgmm2bin/sgmm2-rescore-lattice.cc delete mode 100644 src/sgmm2bin/sgmm2-sum-accs.cc delete mode 100644 src/transform/decodable-am-diag-gmm-regtree.cc delete mode 100644 src/transform/decodable-am-diag-gmm-regtree.h delete mode 100644 src/transform/fmllr-raw-test.cc delete mode 100644 src/transform/fmllr-raw.cc delete mode 100644 src/transform/fmllr-raw.h delete mode 100644 src/transform/fmpe-test.cc delete mode 100644 src/transform/fmpe.cc delete mode 100644 src/transform/fmpe.h delete mode 100644 src/transform/regtree-fmllr-diag-gmm-test.cc delete mode 100644 src/transform/regtree-fmllr-diag-gmm.cc delete mode 100644 src/transform/regtree-fmllr-diag-gmm.h delete mode 100644 src/transform/regtree-mllr-diag-gmm-test.cc delete mode 100644 src/transform/regtree-mllr-diag-gmm.cc delete mode 100644 src/transform/regtree-mllr-diag-gmm.h diff --git a/src/Makefile b/src/Makefile index a49c912c6ed..88da5ed1e55 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,15 +6,15 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree gmm transform \ - fstext hmm lm decoder lat kws cudamatrix nnet \ + fstext hmm lm decoder lat kws cudamatrix \ bin fstbin gmmbin fgmmbin featbin \ - nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 rnnlm chain nnet3bin nnet2bin kwsbin \ + latbin nnet3 rnnlm chain nnet3bin kwsbin \ ivector ivectorbin online2 online2bin lmbin chainbin rnnlmbin MEMTESTDIRS = base matrix util feat tree gmm transform \ - fstext hmm lm decoder lat nnet kws chain \ + fstext hmm lm decoder lat kws chain \ bin fstbin gmmbin fgmmbin featbin \ - nnetbin latbin sgmm2 nnet2 nnet3 rnnlm nnet2bin nnet3bin sgmm2bin kwsbin \ + latbin nnet3 rnnlm nnet3bin kwsbin \ ivector ivectorbin online2 online2bin lmbin CUDAMEMTESTDIR = cudamatrix @@ -23,9 +23,6 @@ SUBDIRS_LIB = $(filter-out %bin, $(SUBDIRS)) KALDI_SONAME ?= libkaldi.so -# Optional subdirectories -EXT_SUBDIRS = online onlinebin # python-kaldi-decoding -EXT_SUBDIRS_LIB = $(filter-out %bin, $(EXT_SUBDIRS)) include kaldi.mk @@ -72,19 +69,6 @@ endif endif endif -biglibext: $(EXT_SUBDIRS_LIB) -ifeq ($(KALDI_FLAVOR), dynamic) -ifeq ($(shell uname), Darwin) - $(CXX) -dynamiclib -o $(KALDILIBDIR)/libkaldi_ext.dylib -install_name @rpath/libkaldi_ext.dylib -framework Accelerate $(LDFLAGS) $(EXT_SUBDIRS_LIB:=/*.dylib) -else -ifeq ($(shell uname), Linux) - #$(warning The following command will probably fail, in that case add -fPIC to your CXXFLAGS and remake all.) - $(CXX) -shared -o $(KALDILIBDIR)/libkaldi_ext.so -Wl,-soname=libkaldi_ext.so,--whole-archive $(EXT_SUBDIRS_LIB:=/kaldi-*.a) -Wl,--no-whole-archive -else - $(error Dynamic libraries not supported on this platform. Run configure with --static flag. ) -endif -endif -endif kaldi.mk: @[ -f kaldi.mk ] || { echo "kaldi.mk does not exist; you have to run ./configure"; exit 1; } @@ -143,9 +127,9 @@ $(EXT_SUBDIRS) : checkversion kaldi.mk mklibdir ext_depend ### Dependency list ### # this is necessary for correct parallel compilation #1)The tools depend on all the libraries -bin fstbin gmmbin fgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin rnnlmbin: \ +bin fstbin gmmbin fgmmbin sgmm2bin featbin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin rnnlmbin: \ base matrix util feat tree gmm transform sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 rnnlm + lm decoder lat cudamatrix nnet3 ivector chain kws online2 rnnlm #2)The libraries have inter-dependencies base: base/.depend.mk @@ -162,15 +146,13 @@ lm: base util matrix fstext decoder: base util matrix gmm hmm tree transform lat lat: base util hmm tree matrix cudamatrix: base util matrix -nnet: base util hmm tree matrix cudamatrix -nnet2: base util matrix lat gmm hmm tree transform cudamatrix nnet3: base util matrix lat gmm hmm tree transform cudamatrix chain fstext rnnlm: base util matrix cudamatrix nnet3 lm hmm chain: lat hmm tree fstext matrix cudamatrix util base ivector: base util matrix transform tree gmm #3)Dependencies for optional parts of Kaldi -onlinebin: base matrix util feat tree gmm transform sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online +onlinebin: base matrix util feat tree gmm transform sgmm2 fstext hmm lm decoder lat cudamatrix online # python-kaldi-decoding: base matrix util feat tree gmm transform sgmm2 fstext hmm decoder lat online online: decoder gmm transform feat matrix util base lat hmm tree -online2: decoder gmm transform feat matrix util base lat hmm tree ivector cudamatrix nnet2 nnet3 chain +online2: decoder gmm transform feat matrix util base lat hmm tree ivector cudamatrix nnet3 chain kws: base util hmm tree matrix lat diff --git a/src/bin/Makefile b/src/bin/Makefile index 7cb01b50120..f8f0564743c 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat compile-graph \ - compare-int-vector + compare-int-vector cuda-gpu-available OBJFILES = diff --git a/src/bin/acc-lda.cc b/src/bin/acc-lda.cc index b664135bdc7..a0451218513 100644 --- a/src/bin/acc-lda.cc +++ b/src/bin/acc-lda.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/posterior.h" #include "transform/lda-estimate.h" @@ -57,7 +57,7 @@ int main(int argc, char *argv[]) { std::string posteriors_rspecifier = po.GetArg(3); std::string acc_wxfilename = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_rxfilename, &binary_read); diff --git a/src/bin/acc-tree-stats.cc b/src/bin/acc-tree-stats.cc index 8b9ce9065b4..c0eb31f6064 100644 --- a/src/bin/acc-tree-stats.cc +++ b/src/bin/acc-tree-stats.cc @@ -22,7 +22,7 @@ #include "util/common-utils.h" #include "tree/context-dep.h" #include "tree/build-tree-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/tree-accu.h" /** @brief Accumulate tree statistics for decision tree training. The @@ -62,7 +62,7 @@ int main(int argc, char *argv[]) { AccumulateTreeStatsInfo acc_tree_stats_info(opts); - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/bin/add-self-loops.cc b/src/bin/add-self-loops.cc index b223dfe317d..562b0977a69 100644 --- a/src/bin/add-self-loops.cc +++ b/src/bin/add-self-loops.cc @@ -18,7 +18,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "tree/context-dep.h" #include "util/common-utils.h" @@ -88,7 +88,7 @@ int main(int argc, char *argv[]) { "standard input" : disambig_in_filename); } - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); diff --git a/src/bin/ali-to-pdf.cc b/src/bin/ali-to-pdf.cc index 61b5138cf31..1706f5aa371 100644 --- a/src/bin/ali-to-pdf.cc +++ b/src/bin/ali-to-pdf.cc @@ -21,7 +21,7 @@ */ #include "base/kaldi-common.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" #include "fst/fstlib.h" @@ -48,7 +48,7 @@ int main(int argc, char *argv[]) { alignments_rspecifier = po.GetArg(2), pdfs_wspecifier = po.GetArg(3); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_filename, &trans_model); SequentialInt32VectorReader reader(alignments_rspecifier); diff --git a/src/bin/ali-to-phones.cc b/src/bin/ali-to-phones.cc index 602e32e9768..5def11ffc79 100644 --- a/src/bin/ali-to-phones.cc +++ b/src/bin/ali-to-phones.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" #include "fst/fstlib.h" @@ -68,7 +68,7 @@ int main(int argc, char *argv[]) { std::string model_filename = po.GetArg(1), alignments_rspecifier = po.GetArg(2); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_filename, &trans_model); SequentialInt32VectorReader reader(alignments_rspecifier); diff --git a/src/bin/ali-to-post.cc b/src/bin/ali-to-post.cc index ac87d676c06..00c026c0692 100644 --- a/src/bin/ali-to-post.cc +++ b/src/bin/ali-to-post.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/posterior.h" diff --git a/src/bin/align-compiled-mapped.cc b/src/bin/align-compiled-mapped.cc index 98ffebd6eaa..ab7425c1a32 100644 --- a/src/bin/align-compiled-mapped.cc +++ b/src/bin/align-compiled-mapped.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" @@ -74,7 +74,7 @@ int main(int argc, char *argv[]) { std::string alignment_wspecifier = po.GetArg(4); std::string scores_wspecifier = po.GetOptArg(5); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); SequentialBaseFloatMatrixReader loglikes_reader(feature_rspecifier); diff --git a/src/bin/align-equal-compiled.cc b/src/bin/align-equal-compiled.cc index c4ab9d4205a..f5900727aef 100644 --- a/src/bin/align-equal-compiled.cc +++ b/src/bin/align-equal-compiled.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/training-graph-compiler.h" diff --git a/src/bin/align-equal.cc b/src/bin/align-equal.cc index a3bc40dc236..671c515f33e 100644 --- a/src/bin/align-equal.cc +++ b/src/bin/align-equal.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/training-graph-compiler.h" @@ -65,7 +65,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; ReadKaldiObject(tree_in_filename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); // need VectorFst because we will change it by adding subseq symbol. diff --git a/src/bin/align-mapped.cc b/src/bin/align-mapped.cc index c78401fffdd..e8249c4a123 100644 --- a/src/bin/align-mapped.cc +++ b/src/bin/align-mapped.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "decoder/training-graph-compiler.h" @@ -72,7 +72,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; ReadKaldiObject(tree_in_filename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); VectorFst *lex_fst = fst::ReadFstKaldi(lex_in_filename); diff --git a/src/bin/am-info.cc b/src/bin/am-info.cc index 6afb0c5014e..dd59047c35c 100644 --- a/src/bin/am-info.cc +++ b/src/bin/am-info.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { @@ -45,7 +45,7 @@ int main(int argc, char *argv[]) { std::string model_in_filename = po.GetArg(1); - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/bin/build-pfile-from-ali.cc b/src/bin/build-pfile-from-ali.cc index fadb873825f..fb82fe27eaa 100644 --- a/src/bin/build-pfile-from-ali.cc +++ b/src/bin/build-pfile-from-ali.cc @@ -25,7 +25,7 @@ using std::vector; #include "base/kaldi-common.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" @@ -64,7 +64,7 @@ int main(int argc, char *argv[]) { feature_rspecifier = po.GetArg(3), pfile_wspecifier = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/bin/build-tree-two-level.cc b/src/bin/build-tree-two-level.cc index c7cd553484e..005c5d80532 100644 --- a/src/bin/build-tree-two-level.cc +++ b/src/bin/build-tree-two-level.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "tree/context-dep.h" #include "tree/build-tree.h" #include "tree/build-tree-utils.h" @@ -112,7 +112,7 @@ int main(int argc, char *argv[]) { ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root); } - HmmTopology topo; + Topology topo; ReadKaldiObject(topo_filename, &topo); BuildTreeStatsType stats; diff --git a/src/bin/build-tree.cc b/src/bin/build-tree.cc index 72774900d61..b37c9c7d184 100644 --- a/src/bin/build-tree.cc +++ b/src/bin/build-tree.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "tree/context-dep.h" #include "tree/build-tree.h" #include "tree/build-tree-utils.h" @@ -91,7 +91,7 @@ int main(int argc, char *argv[]) { ReadRootsFile(ki.Stream(), &phone_sets, &is_shared_root, &is_split_root); } - HmmTopology topo; + Topology topo; ReadKaldiObject(topo_filename, &topo); BuildTreeStatsType stats; diff --git a/src/bin/compile-graph.cc b/src/bin/compile-graph.cc index 7174fdf8113..2dae81fa702 100644 --- a/src/bin/compile-graph.cc +++ b/src/bin/compile-graph.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "fstext/fstext-lib.h" #include "fstext/push-special.h" @@ -81,7 +81,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; // the tree. ReadKaldiObject(tree_rxfilename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_rxfilename, &trans_model); VectorFst *lex_fst = fst::ReadFstKaldi(lex_rxfilename), diff --git a/src/bin/compile-questions.cc b/src/bin/compile-questions.cc index f9694140ae8..1c8565e032d 100644 --- a/src/bin/compile-questions.cc +++ b/src/bin/compile-questions.cc @@ -19,12 +19,12 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "tree/build-tree-questions.h" namespace kaldi { -int32 ProcessTopo(const HmmTopology &topo, const std::vector > &questions) { +int32 ProcessTopo(const Topology &topo, const std::vector > &questions) { std::vector seen_phones; // ids of phones seen in questions. for (size_t i = 0; i < questions.size(); i++) for (size_t j= 0; j < questions[i].size(); j++) seen_phones.push_back(questions[i][j]); @@ -93,7 +93,7 @@ int main(int argc, char *argv[]) { questions_rxfilename = po.GetArg(2), questions_out_filename = po.GetArg(3); - HmmTopology topo; // just needed for checking, and to get the + Topology topo; // just needed for checking, and to get the // largest number of pdf-classes for any phone. ReadKaldiObject(topo_filename, &topo); diff --git a/src/bin/compile-train-graphs-fsts.cc b/src/bin/compile-train-graphs-fsts.cc index 00ec1038943..473887538ae 100644 --- a/src/bin/compile-train-graphs-fsts.cc +++ b/src/bin/compile-train-graphs-fsts.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/training-graph-compiler.h" @@ -80,7 +80,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; // the tree. ReadKaldiObject(tree_rxfilename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_rxfilename, &trans_model); // need VectorFst because we will change it by adding subseq symbol. diff --git a/src/bin/compile-train-graphs.cc b/src/bin/compile-train-graphs.cc index 874d079376e..a0722c920b4 100644 --- a/src/bin/compile-train-graphs.cc +++ b/src/bin/compile-train-graphs.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/training-graph-compiler.h" @@ -74,7 +74,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; // the tree. ReadKaldiObject(tree_rxfilename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_rxfilename, &trans_model); // need VectorFst because we will change it by adding subseq symbol. diff --git a/src/bin/convert-ali.cc b/src/bin/convert-ali.cc index 89fe838638c..7daeb40ca53 100644 --- a/src/bin/convert-ali.cc +++ b/src/bin/convert-ali.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/tree-accu.h" // for ReadPhoneMap @@ -48,7 +48,7 @@ int main(int argc, char *argv[]) { "old-integer-id new-integer-id)"); po.Register("reorder", &reorder, "True if you want the converted alignments to be 'reordered' " - "versus the way they appear in the HmmTopology object"); + "versus the way they appear in the Topology object"); po.Register("repeat-frames", &repeat_frames, "Only relevant when frame-subsampling-factor != 1. If true, " "repeat frames of alignment by 'frame-subsampling-factor' " @@ -79,10 +79,10 @@ int main(int argc, char *argv[]) { SequentialInt32VectorReader alignment_reader(old_alignments_rspecifier); Int32VectorWriter alignment_writer(new_alignments_wspecifier); - TransitionModel old_trans_model; + Transitions old_trans_model; ReadKaldiObject(old_model_filename, &old_trans_model); - TransitionModel new_trans_model; + Transitions new_trans_model; ReadKaldiObject(new_model_filename, &new_trans_model); if (!(old_trans_model.GetTopo() == new_trans_model.GetTopo())) diff --git a/src/bin/copy-gselect.cc b/src/bin/copy-gselect.cc index e6c92013b58..ee427d59b8e 100644 --- a/src/bin/copy-gselect.cc +++ b/src/bin/copy-gselect.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { diff --git a/src/bin/copy-transition-model.cc b/src/bin/copy-transition-model.cc index 62a5d0c51dd..b05c64d28bf 100644 --- a/src/bin/copy-transition-model.cc +++ b/src/bin/copy-transition-model.cc @@ -17,7 +17,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fst/fstlib.h" #include "util/common-utils.h" @@ -54,7 +54,7 @@ int main(int argc, char *argv[]) { transition_model_wxfilename = po.GetArg(2); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(transition_model_rxfilename, &trans_model); WriteKaldiObject(trans_model, transition_model_wxfilename, binary); diff --git a/src/bin/copy-tree.cc b/src/bin/copy-tree.cc index c412366b151..69ab0c309ad 100644 --- a/src/bin/copy-tree.cc +++ b/src/bin/copy-tree.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "tree/context-dep.h" #include "tree/clusterable-classes.h" #include "util/text-utils.h" diff --git a/src/nnetbin/cuda-gpu-available.cc b/src/bin/cuda-gpu-available.cc similarity index 100% rename from src/nnetbin/cuda-gpu-available.cc rename to src/bin/cuda-gpu-available.cc diff --git a/src/bin/decode-faster-mapped.cc b/src/bin/decode-faster-mapped.cc index c7411592504..4606933411f 100644 --- a/src/bin/decode-faster-mapped.cc +++ b/src/bin/decode-faster-mapped.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/faster-decoder.h" #include "decoder/decodable-matrix.h" @@ -67,7 +67,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetArg(4), alignment_wspecifier = po.GetOptArg(5); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); Int32VectorWriter words_writer(words_wspecifier); diff --git a/src/bin/decode-faster.cc b/src/bin/decode-faster.cc index cbcdb771d56..a1e112b129f 100644 --- a/src/bin/decode-faster.cc +++ b/src/bin/decode-faster.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/faster-decoder.h" #include "decoder/decodable-matrix.h" diff --git a/src/bin/est-mllt.cc b/src/bin/est-mllt.cc index 48021304b80..2a01f0dbb78 100644 --- a/src/bin/est-mllt.cc +++ b/src/bin/est-mllt.cc @@ -20,7 +20,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/mllt.h" int main(int argc, char *argv[]) { diff --git a/src/bin/get-post-on-ali.cc b/src/bin/get-post-on-ali.cc index 6d6dfd0d3df..471bbfbfff2 100644 --- a/src/bin/get-post-on-ali.cc +++ b/src/bin/get-post-on-ali.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/posterior.h" diff --git a/src/bin/hmm-info.cc b/src/bin/hmm-info.cc index 4ece5e88171..30d6f999c8e 100644 --- a/src/bin/hmm-info.cc +++ b/src/bin/hmm-info.cc @@ -19,7 +19,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { @@ -43,7 +43,7 @@ int main(int argc, char *argv[]) { std::string model_in_filename = po.GetArg(1); - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/bin/latgen-faster-mapped-parallel.cc b/src/bin/latgen-faster-mapped-parallel.cc index 4479ec8b73e..415fd1a3584 100644 --- a/src/bin/latgen-faster-mapped-parallel.cc +++ b/src/bin/latgen-faster-mapped-parallel.cc @@ -24,7 +24,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "decoder/decodable-matrix.h" @@ -74,7 +74,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(5), alignment_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); bool determinize = config.determinize_lattice; diff --git a/src/bin/latgen-faster-mapped.cc b/src/bin/latgen-faster-mapped.cc index 610d9aa6d7d..3a65d78be04 100644 --- a/src/bin/latgen-faster-mapped.cc +++ b/src/bin/latgen-faster-mapped.cc @@ -23,7 +23,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "decoder/decodable-matrix.h" @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(5), alignment_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_in_filename, &trans_model); bool determinize = config.determinize_lattice; diff --git a/src/bin/logprob-to-post.cc b/src/bin/logprob-to-post.cc index f221580a484..0edfba0189d 100644 --- a/src/bin/logprob-to-post.cc +++ b/src/bin/logprob-to-post.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/posterior.h" diff --git a/src/bin/make-h-transducer.cc b/src/bin/make-h-transducer.cc index c54b9250cf7..777cab0f94d 100644 --- a/src/bin/make-h-transducer.cc +++ b/src/bin/make-h-transducer.cc @@ -16,7 +16,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "tree/context-dep.h" #include "util/common-utils.h" @@ -71,7 +71,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; ReadKaldiObject(tree_filename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_filename, &trans_model); std::vector disambig_syms_out; diff --git a/src/bin/make-ilabel-transducer.cc b/src/bin/make-ilabel-transducer.cc index a78cefafd3a..70a5d6d4e18 100644 --- a/src/bin/make-ilabel-transducer.cc +++ b/src/bin/make-ilabel-transducer.cc @@ -16,7 +16,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "tree/context-dep.h" #include "util/common-utils.h" @@ -71,7 +71,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; ReadKaldiObject(tree_filename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_filename, &trans_model); diff --git a/src/bin/make-pdf-to-tid-transducer.cc b/src/bin/make-pdf-to-tid-transducer.cc index 907380a974d..ad9c627e558 100644 --- a/src/bin/make-pdf-to-tid-transducer.cc +++ b/src/bin/make-pdf-to-tid-transducer.cc @@ -16,7 +16,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" #include "fst/fstlib.h" @@ -47,7 +47,7 @@ int main(int argc, char *argv[]) { std::string trans_model_filename = po.GetArg(1); std::string fst_out_filename = po.GetOptArg(2); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(trans_model_filename, &trans_model); fst::VectorFst *fst = GetPdfToTransitionIdTransducer(trans_model); diff --git a/src/bin/phones-to-prons.cc b/src/bin/phones-to-prons.cc index 0d7ab12c232..23c17a58385 100644 --- a/src/bin/phones-to-prons.cc +++ b/src/bin/phones-to-prons.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" #include "fst/fstlib.h" diff --git a/src/bin/post-to-pdf-post.cc b/src/bin/post-to-pdf-post.cc index 99aa5770aa5..6c2227806b4 100644 --- a/src/bin/post-to-pdf-post.cc +++ b/src/bin/post-to-pdf-post.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/posterior.h" @@ -50,7 +50,7 @@ int main(int argc, char *argv[]) { posteriors_rspecifier = po.GetArg(2), posteriors_wspecifier = po.GetArg(3); - TransitionModel trans_model; + Transitions trans_model; { bool binary_in; Input ki(model_rxfilename, &binary_in); diff --git a/src/bin/post-to-phone-post.cc b/src/bin/post-to-phone-post.cc index 92f67514a0f..d6ba0991924 100644 --- a/src/bin/post-to-phone-post.cc +++ b/src/bin/post-to-phone-post.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/posterior.h" int main(int argc, char *argv[]) { @@ -51,7 +51,7 @@ int main(int argc, char *argv[]) { kaldi::SequentialPosteriorReader posterior_reader(post_rspecifier); kaldi::PosteriorWriter posterior_writer(phone_post_wspecifier); - TransitionModel trans_model; + Transitions trans_model; { bool binary_in; Input ki(model_rxfilename, &binary_in); diff --git a/src/bin/post-to-tacc.cc b/src/bin/post-to-tacc.cc index afa5315d6b4..7867e9f5697 100644 --- a/src/bin/post-to-tacc.cc +++ b/src/bin/post-to-tacc.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/posterior.h" int main(int argc, char *argv[]) { @@ -61,7 +61,7 @@ int main(int argc, char *argv[]) { bool binary_in; Input ki(model_rxfilename, &binary_in); - TransitionModel trans_model; + Transitions trans_model; trans_model.Read(ki.Stream(), binary_in); num_transition_ids = trans_model.NumTransitionIds(); diff --git a/src/bin/prob-to-post.cc b/src/bin/prob-to-post.cc index 4266d34ca47..7bdff6f1e78 100644 --- a/src/bin/prob-to-post.cc +++ b/src/bin/prob-to-post.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/posterior.h" diff --git a/src/bin/prons-to-wordali.cc b/src/bin/prons-to-wordali.cc index a6331043500..8579c79ea02 100644 --- a/src/bin/prons-to-wordali.cc +++ b/src/bin/prons-to-wordali.cc @@ -19,7 +19,7 @@ #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" #include "fst/fstlib.h" diff --git a/src/bin/show-alignments.cc b/src/bin/show-alignments.cc index 06bc907005f..beadf1b590c 100644 --- a/src/bin/show-alignments.cc +++ b/src/bin/show-alignments.cc @@ -19,7 +19,7 @@ #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "util/common-utils.h" #include "fst/fstlib.h" @@ -47,7 +47,7 @@ int main(int argc, char *argv[]) { model_filename = po.GetArg(2), alignments_rspecifier = po.GetArg(3); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_filename, &trans_model); fst::SymbolTable *phones_symtab = NULL; diff --git a/src/bin/show-transitions.cc b/src/bin/show-transitions.cc index bdc780b060a..db72d47f988 100644 --- a/src/bin/show-transitions.cc +++ b/src/bin/show-transitions.cc @@ -18,7 +18,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fst/fstlib.h" #include "util/common-utils.h" @@ -59,7 +59,7 @@ int main(int argc, char *argv[]) { for (size_t i = 0; i < syms->NumSymbols(); i++) names[i] = syms->Find(i); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(transition_model_filename, &trans_model); Vector occs; diff --git a/src/bin/tree-info.cc b/src/bin/tree-info.cc index ce3c5c9cfc1..a1f4f21e983 100644 --- a/src/bin/tree-info.cc +++ b/src/bin/tree-info.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "tree/context-dep.h" int main(int argc, char *argv[]) { diff --git a/src/bin/weight-silence-post.cc b/src/bin/weight-silence-post.cc index dba935d1cd3..3c8478752c8 100644 --- a/src/bin/weight-silence-post.cc +++ b/src/bin/weight-silence-post.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "hmm/posterior.h" @@ -71,7 +71,7 @@ int main(int argc, char *argv[]) { KALDI_WARN <<"No silence phones, this will have no effect"; ConstIntegerSet silence_set(silence_phones); // faster lookup. - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(model_rxfilename, &trans_model); int32 num_posteriors = 0; diff --git a/src/chain/chain-den-graph.cc b/src/chain/chain-den-graph.cc index 11c851091bd..920ade49348 100644 --- a/src/chain/chain-den-graph.cc +++ b/src/chain/chain-den-graph.cc @@ -162,7 +162,7 @@ void DenominatorGraph::GetNormalizationFst(const fst::StdVectorFst &ifst, } -void MapFstToPdfIdsPlusOne(const TransitionModel &trans_model, +void MapFstToPdfIdsPlusOne(const Transitions &trans_model, fst::StdVectorFst *fst) { int32 num_states = fst->NumStates(); for (int32 s = 0; s < num_states; s++) { @@ -295,7 +295,7 @@ static void CheckDenominatorFst(int32 num_pdfs, } void CreateDenominatorFst(const ContextDependency &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::StdVectorFst &phone_lm_in, fst::StdVectorFst *den_fst) { using fst::StdVectorFst; diff --git a/src/chain/chain-den-graph.h b/src/chain/chain-den-graph.h index b2510651f39..baf5ac2c6f1 100644 --- a/src/chain/chain-den-graph.h +++ b/src/chain/chain-den-graph.h @@ -32,7 +32,7 @@ #include "lat/kaldi-lattice.h" #include "matrix/kaldi-matrix.h" #include "chain/chain-datastruct.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "cudamatrix/cu-matrix.h" #include "cudamatrix/cu-vector.h" #include "cudamatrix/cu-array.h" @@ -149,7 +149,7 @@ void MinimizeAcceptorNoPush(fst::StdVectorFst *fst); // transition-ids to pdf-ids plus one. Assumes 'fst' // is an acceptor, but does not check this (only looks at its // ilabels). -void MapFstToPdfIdsPlusOne(const TransitionModel &trans_model, +void MapFstToPdfIdsPlusOne(const Transitions &trans_model, fst::StdVectorFst *fst); // Starting from an acceptor on phones that represents some kind of compiled @@ -157,7 +157,7 @@ void MapFstToPdfIdsPlusOne(const TransitionModel &trans_model, // denominator-graph. Note: there is similar code in chain-supervision.cc, when // creating the supervision graph. void CreateDenominatorFst(const ContextDependency &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::StdVectorFst &phone_lm, fst::StdVectorFst *den_graph); diff --git a/src/chain/chain-denominator.h b/src/chain/chain-denominator.h index d76e4244ae2..9960dfede0b 100644 --- a/src/chain/chain-denominator.h +++ b/src/chain/chain-denominator.h @@ -31,7 +31,7 @@ #include "tree/context-dep.h" #include "lat/kaldi-lattice.h" #include "matrix/kaldi-matrix.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "cudamatrix/cu-matrix.h" #include "cudamatrix/cu-array.h" #include "chain/chain-den-graph.h" diff --git a/src/chain/chain-generic-numerator.h b/src/chain/chain-generic-numerator.h index fc5e00b2c63..8c542d6049c 100644 --- a/src/chain/chain-generic-numerator.h +++ b/src/chain/chain-generic-numerator.h @@ -32,7 +32,7 @@ #include "tree/context-dep.h" #include "lat/kaldi-lattice.h" #include "matrix/kaldi-matrix.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "chain/chain-supervision.h" #include "cudamatrix/cu-matrix.h" #include "cudamatrix/cu-array.h" diff --git a/src/chain/chain-numerator.h b/src/chain/chain-numerator.h index 15cb31e0571..c4ea4774b53 100644 --- a/src/chain/chain-numerator.h +++ b/src/chain/chain-numerator.h @@ -31,7 +31,7 @@ #include "tree/context-dep.h" #include "lat/kaldi-lattice.h" #include "matrix/kaldi-matrix.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "chain/chain-supervision.h" #include "cudamatrix/cu-matrix.h" #include "cudamatrix/cu-array.h" diff --git a/src/chain/chain-supervision-test.cc b/src/chain/chain-supervision-test.cc index 7ee5ee117b0..10385e2c4f2 100644 --- a/src/chain/chain-supervision-test.cc +++ b/src/chain/chain-supervision-test.cc @@ -57,7 +57,7 @@ void ComputeExamplePhoneLanguageModel(const std::vector &phones, void ComputeExampleDenFst(const ContextDependency &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, fst::StdVectorFst *den_graph) { using fst::StdVectorFst; using fst::StdArc; @@ -151,7 +151,7 @@ void TestSupervisionNumerator(const Supervision &supervision) { } -void TestSupervisionAppend(const TransitionModel &trans_model, +void TestSupervisionAppend(const Transitions &trans_model, const Supervision &supervision) { int32 num_append = RandInt(1,5); std::vector input(num_append); @@ -180,7 +180,7 @@ void TestSupervisionAppend(const TransitionModel &trans_model, output.Check(trans_model); } -void TestSupervisionReattached(const TransitionModel &trans_model, +void TestSupervisionReattached(const Transitions &trans_model, const Supervision &supervision, const Supervision &reattached_supervision) { using namespace fst; @@ -333,7 +333,7 @@ void ChainTrainingTest(const DenominatorGraph &den_graph, } void TestSupervisionSplitting(const ContextDependency &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const Supervision &supervision) { fst::StdVectorFst den_fst, normalization_fst; ComputeExampleDenFst(ctx_dep, trans_model, &den_fst); @@ -456,7 +456,7 @@ void ChainDenominatorTest(const DenominatorGraph &den_graph) { void ChainSupervisionTest() { ContextDependency *ctx_dep; - TransitionModel *trans_model = GenRandTransitionModel(&ctx_dep); + Transitions *trans_model = GenRandTransitionModel(&ctx_dep); const std::vector &phones = trans_model->GetPhones(); int32 subsample_factor = RandInt(1, 3); diff --git a/src/chain/chain-supervision.cc b/src/chain/chain-supervision.cc index f8a2c1d11cc..af28ef85a33 100644 --- a/src/chain/chain-supervision.cc +++ b/src/chain/chain-supervision.cc @@ -255,7 +255,7 @@ bool TimeEnforcerFst::GetArc(StateId s, Label ilabel, fst::StdArc* oarc) { bool TrainingGraphToSupervisionE2e( const fst::StdVectorFst &training_graph, - const TransitionModel &trans_model, + const Transitions &trans_model, int32 num_frames, Supervision *supervision) { using fst::VectorFst; @@ -292,7 +292,7 @@ bool TrainingGraphToSupervisionE2e( bool ProtoSupervisionToSupervision( const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const ProtoSupervision &proto_supervision, bool convert_to_pdfs, Supervision *supervision) { @@ -906,7 +906,7 @@ bool Supervision::operator == (const Supervision &other) const { label_dim == other.label_dim && fst::Equal(fst, other.fst); } -void Supervision::Check(const TransitionModel &trans_mdl) const { +void Supervision::Check(const Transitions &trans_mdl) const { if (weight <= 0.0) KALDI_ERR << "Weight should be positive."; if (frames_per_sequence <= 0) @@ -970,7 +970,7 @@ void GetWeightsForRanges(int32 range_length, } bool ConvertSupervisionToUnconstrained( - const TransitionModel &trans_mdl, + const Transitions &trans_mdl, Supervision *supervision) { KALDI_ASSERT(supervision->label_dim == trans_mdl.NumTransitionIds() && supervision->fst.NumStates() > 0 && diff --git a/src/chain/chain-supervision.h b/src/chain/chain-supervision.h index f1a796dc2f8..0b8a760f1e6 100644 --- a/src/chain/chain-supervision.h +++ b/src/chain/chain-supervision.h @@ -29,7 +29,7 @@ #include "util/common-utils.h" #include "lat/kaldi-lattice.h" #include "fstext/deterministic-fst.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" namespace kaldi { namespace chain { @@ -181,7 +181,7 @@ class TimeEnforcerFst: typedef fst::StdArc::StateId StateId; typedef fst::StdArc::Label Label; - TimeEnforcerFst(const TransitionModel &trans_model, + TimeEnforcerFst(const Transitions &trans_model, bool convert_to_pdfs, const std::vector > &allowed_phones): trans_model_(trans_model), @@ -204,7 +204,7 @@ class TimeEnforcerFst: virtual bool GetArc(StateId s, Label ilabel, fst::StdArc* oarc); private: - const TransitionModel &trans_model_; + const Transitions &trans_model_; // if convert_to_pdfs_ is true, this FST will map from transition-id (on the // input side) to pdf-id plus one (on the output side); if false, both sides' // labels will be transition-id. @@ -234,10 +234,10 @@ struct Supervision { // the maximum possible value of the labels in 'fst' (which go from 1 to // label_dim). For fully-processed examples this will equal the NumPdfs() in the - // TransitionModel object, but for newer-style "unconstrained" examples + // Transitions object, but for newer-style "unconstrained" examples // that have been output by chain-get-supervision but not yet processed // by nnet3-chain-get-egs, it will be the NumTransitionIds() of the - // TransitionModel object. + // Transitions object. int32 label_dim; // This is an epsilon-free unweighted acceptor that is sorted in increasing @@ -297,7 +297,7 @@ struct Supervision { // This function checks that this supervision object satifsies some // of the properties we expect of it, and calls KALDI_ERR if not. - void Check(const TransitionModel &trans_model) const; + void Check(const Transitions &trans_model) const; void Write(std::ostream &os, bool binary) const; void Read(std::istream &is, bool binary); @@ -317,7 +317,7 @@ struct Supervision { */ bool ProtoSupervisionToSupervision( const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const ProtoSupervision &proto_supervision, bool convert_to_pdfs, Supervision *supervision); @@ -333,7 +333,7 @@ bool ProtoSupervisionToSupervision( */ bool TrainingGraphToSupervisionE2e( const fst::StdVectorFst& training_graph, - const TransitionModel& trans_model, + const Transitions& trans_model, int32 num_frames, Supervision *supervision); @@ -484,7 +484,7 @@ void GetWeightsForRanges(int32 range_length, /// It returns true on success, and false if some kind of error happened /// (this is not expected). bool ConvertSupervisionToUnconstrained( - const TransitionModel &trans_mdl, + const Transitions &trans_mdl, Supervision *supervision); diff --git a/src/chain/chain-training.h b/src/chain/chain-training.h index 6ea70b5ca41..7dbc1a058c2 100644 --- a/src/chain/chain-training.h +++ b/src/chain/chain-training.h @@ -31,7 +31,7 @@ #include "tree/context-dep.h" #include "lat/kaldi-lattice.h" #include "matrix/kaldi-matrix.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "chain/chain-den-graph.h" #include "chain/chain-supervision.h" diff --git a/src/chainbin/chain-get-supervision.cc b/src/chainbin/chain-get-supervision.cc index 1ac89d4630b..8a4904843be 100644 --- a/src/chainbin/chain-get-supervision.cc +++ b/src/chainbin/chain-get-supervision.cc @@ -30,7 +30,7 @@ namespace chain { // This wrapper function does all the job of processing the features and // lattice into ChainSupervision objects, and writing them out. -static bool ProcessSupervision(const TransitionModel &trans_model, +static bool ProcessSupervision(const Transitions &trans_model, const ContextDependencyInterface &ctx_dep, const ProtoSupervision &proto_sup, const std::string &key, @@ -97,7 +97,7 @@ int main(int argc, char *argv[]) { phone_durs_or_lat_rspecifier = po.GetArg(3), supervision_wspecifier = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(trans_model_rxfilename, &trans_model); ContextDependency ctx_dep; diff --git a/src/chainbin/chain-make-den-fst.cc b/src/chainbin/chain-make-den-fst.cc index 0d8d249242b..dc2b41a369d 100644 --- a/src/chainbin/chain-make-den-fst.cc +++ b/src/chainbin/chain-make-den-fst.cc @@ -56,7 +56,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; - TransitionModel trans_model; + Transitions trans_model; fst::StdVectorFst phone_lm; ReadKaldiObject(tree_rxfilename, &ctx_dep); diff --git a/src/chainbin/nnet3-chain-acc-lda-stats.cc b/src/chainbin/nnet3-chain-acc-lda-stats.cc index 693eb2dad86..0cf2d449d76 100644 --- a/src/chainbin/nnet3-chain-acc-lda-stats.cc +++ b/src/chainbin/nnet3-chain-acc-lda-stats.cc @@ -19,7 +19,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "lat/lattice-functions.h" #include "nnet3/nnet-nnet.h" #include "nnet3/nnet-chain-example.h" diff --git a/src/chainbin/nnet3-chain-copy-egs.cc b/src/chainbin/nnet3-chain-copy-egs.cc index 0117fe2200f..46744b239d0 100644 --- a/src/chainbin/nnet3-chain-copy-egs.cc +++ b/src/chainbin/nnet3-chain-copy-egs.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "nnet3/nnet-chain-example.h" namespace kaldi { diff --git a/src/chainbin/nnet3-chain-e2e-get-egs.cc b/src/chainbin/nnet3-chain-e2e-get-egs.cc index 8cdda8deb32..31b14cb7b0f 100644 --- a/src/chainbin/nnet3-chain-e2e-get-egs.cc +++ b/src/chainbin/nnet3-chain-e2e-get-egs.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "hmm/posterior.h" #include "nnet3/nnet-example.h" @@ -74,7 +74,7 @@ static int32 FindMinimumLengthPath( */ static bool ProcessFile(const ExampleGenerationConfig &opts, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::StdVectorFst &normalization_fst, const MatrixBase &feats, const MatrixBase *ivector_feats, @@ -285,7 +285,7 @@ int main(int argc, char *argv[]) { KALDI_ASSERT(normalization_fst.NumStates() > 0); } - TransitionModel trans_model; + Transitions trans_model; ReadKaldiObject(trans_model_rxfilename, &trans_model); RandomAccessBaseFloatMatrixReader feat_reader(feature_rspecifier); diff --git a/src/chainbin/nnet3-chain-get-egs.cc b/src/chainbin/nnet3-chain-get-egs.cc index 1032b7e2125..2c506c5b460 100644 --- a/src/chainbin/nnet3-chain-get-egs.cc +++ b/src/chainbin/nnet3-chain-get-egs.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/posterior.h" #include "nnet3/nnet-example.h" #include "nnet3/nnet-chain-example.h" @@ -86,7 +86,7 @@ namespace nnet3 { **/ -static bool ProcessFile(const TransitionModel *trans_mdl, +static bool ProcessFile(const Transitions *trans_mdl, const fst::StdVectorFst &normalization_fst, const GeneralMatrix &feats, const MatrixBase *ivector_feats, @@ -345,8 +345,8 @@ int main(int argc, char *argv[]) { UtteranceSplitter utt_splitter(eg_config); - const TransitionModel *trans_mdl_ptr = NULL; - TransitionModel trans_mdl; + const Transitions *trans_mdl_ptr = NULL; + Transitions trans_mdl; if (!trans_mdl_rxfilename.empty()) { ReadKaldiObject(trans_mdl_rxfilename, &trans_mdl); diff --git a/src/chainbin/nnet3-chain-merge-egs.cc b/src/chainbin/nnet3-chain-merge-egs.cc index a3686d2fc30..14bdbe55115 100644 --- a/src/chainbin/nnet3-chain-merge-egs.cc +++ b/src/chainbin/nnet3-chain-merge-egs.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "nnet3/nnet-chain-example.h" diff --git a/src/chainbin/nnet3-chain-normalize-egs.cc b/src/chainbin/nnet3-chain-normalize-egs.cc index a97797e3246..70f6852e963 100644 --- a/src/chainbin/nnet3-chain-normalize-egs.cc +++ b/src/chainbin/nnet3-chain-normalize-egs.cc @@ -19,7 +19,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "nnet3/nnet-chain-example.h" #include "chain/chain-supervision.h" diff --git a/src/chainbin/nnet3-chain-shuffle-egs.cc b/src/chainbin/nnet3-chain-shuffle-egs.cc index 7ab6e28f607..94ba30799b0 100644 --- a/src/chainbin/nnet3-chain-shuffle-egs.cc +++ b/src/chainbin/nnet3-chain-shuffle-egs.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "nnet3/nnet-chain-example.h" int main(int argc, char *argv[]) { diff --git a/src/decoder/decodable-matrix.cc b/src/decoder/decodable-matrix.cc index 3cc7b87f2d7..98cd75d1ede 100644 --- a/src/decoder/decodable-matrix.cc +++ b/src/decoder/decodable-matrix.cc @@ -22,7 +22,7 @@ namespace kaldi { DecodableMatrixMapped::DecodableMatrixMapped( - const TransitionModel &tm, + const Transitions &tm, const MatrixBase &likes, int32 frame_offset): trans_model_(tm), likes_(&likes), likes_to_delete_(NULL), @@ -32,12 +32,12 @@ DecodableMatrixMapped::DecodableMatrixMapped( if (likes.NumCols() != tm.NumPdfs()) KALDI_ERR << "Mismatch, matrix has " - << likes.NumCols() << " rows but transition-model has " + << likes.NumCols() << " rows but transitions.has " << tm.NumPdfs() << " pdf-ids."; } DecodableMatrixMapped::DecodableMatrixMapped( - const TransitionModel &tm, const Matrix *likes, + const Transitions &tm, const Matrix *likes, int32 frame_offset): trans_model_(tm), likes_(likes), likes_to_delete_(likes), frame_offset_(frame_offset) { @@ -45,7 +45,7 @@ DecodableMatrixMapped::DecodableMatrixMapped( raw_data_ = likes->Data() - (stride_ * frame_offset_); if (likes->NumCols() != tm.NumPdfs()) KALDI_ERR << "Mismatch, matrix has " - << likes->NumCols() << " rows but transition-model has " + << likes->NumCols() << " rows but transitions.has " << tm.NumPdfs() << " pdf-ids."; } diff --git a/src/decoder/decodable-matrix.h b/src/decoder/decodable-matrix.h index 475638a35af..5e9642ee6b9 100644 --- a/src/decoder/decodable-matrix.h +++ b/src/decoder/decodable-matrix.h @@ -24,7 +24,7 @@ #include #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "itf/decodable-itf.h" #include "matrix/kaldi-matrix.h" @@ -34,26 +34,26 @@ namespace kaldi { class DecodableMatrixScaledMapped: public DecodableInterface { public: // This constructor creates an object that will not delete "likes" when done. - DecodableMatrixScaledMapped(const TransitionModel &tm, + DecodableMatrixScaledMapped(const Transitions &tm, const Matrix &likes, BaseFloat scale): trans_model_(tm), likes_(&likes), scale_(scale), delete_likes_(false) { if (likes.NumCols() != tm.NumPdfs()) KALDI_ERR << "DecodableMatrixScaledMapped: mismatch, matrix has " - << likes.NumCols() << " rows but transition-model has " + << likes.NumCols() << " rows but transitions.has " << tm.NumPdfs() << " pdf-ids."; } // This constructor creates an object that will delete "likes" // when done. - DecodableMatrixScaledMapped(const TransitionModel &tm, + DecodableMatrixScaledMapped(const Transitions &tm, BaseFloat scale, const Matrix *likes): trans_model_(tm), likes_(likes), scale_(scale), delete_likes_(true) { if (likes->NumCols() != tm.NumPdfs()) KALDI_ERR << "DecodableMatrixScaledMapped: mismatch, matrix has " - << likes->NumCols() << " rows but transition-model has " + << likes->NumCols() << " rows but transitions.has " << tm.NumPdfs() << " pdf-ids."; } @@ -76,7 +76,7 @@ class DecodableMatrixScaledMapped: public DecodableInterface { if (delete_likes_) delete likes_; } private: - const TransitionModel &trans_model_; // for tid to pdf mapping + const Transitions &trans_model_; // for tid to pdf mapping const Matrix *likes_; BaseFloat scale_; bool delete_likes_; @@ -100,13 +100,13 @@ class DecodableMatrixMapped: public DecodableInterface { // This constructor creates an object that will not delete "likes" when done. // the frame_offset is the frame the row 0 of 'likes' corresponds to, would be // greater than one if this is not the first chunk of likelihoods. - DecodableMatrixMapped(const TransitionModel &tm, + DecodableMatrixMapped(const Transitions &tm, const MatrixBase &likes, int32 frame_offset = 0); // This constructor creates an object that will delete "likes" // when done. - DecodableMatrixMapped(const TransitionModel &tm, + DecodableMatrixMapped(const Transitions &tm, const Matrix *likes, int32 frame_offset = 0); @@ -122,7 +122,7 @@ class DecodableMatrixMapped: public DecodableInterface { virtual ~DecodableMatrixMapped(); private: - const TransitionModel &trans_model_; // for tid to pdf mapping + const Transitions &trans_model_; // for tid to pdf mapping const MatrixBase *likes_; const Matrix *likes_to_delete_; int32 frame_offset_; @@ -151,7 +151,7 @@ class DecodableMatrixMapped: public DecodableInterface { */ class DecodableMatrixMappedOffset: public DecodableInterface { public: - DecodableMatrixMappedOffset(const TransitionModel &tm): + DecodableMatrixMappedOffset(const Transitions &tm): trans_model_(tm), frame_offset_(0), input_is_finished_(false) { } virtual int32 NumFramesReady() { return frame_offset_ + loglikes_.NumRows(); } @@ -194,7 +194,7 @@ class DecodableMatrixMappedOffset: public DecodableInterface { // nothing special to do in destructor. virtual ~DecodableMatrixMappedOffset() { } private: - const TransitionModel &trans_model_; // for tid to pdf mapping + const Transitions &trans_model_; // for tid to pdf mapping Matrix loglikes_; int32 frame_offset_; bool input_is_finished_; diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index ff573c74d15..71799a5b700 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -32,7 +32,7 @@ namespace kaldi { DecodeUtteranceLatticeFasterClass::DecodeUtteranceLatticeFasterClass( LatticeFasterDecoder *decoder, DecodableInterface *decodable, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, BaseFloat acoustic_scale, @@ -201,7 +201,7 @@ template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, @@ -299,7 +299,7 @@ bool DecodeUtteranceLatticeFaster( template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl > &decoder, DecodableInterface &decodable, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, @@ -314,7 +314,7 @@ template bool DecodeUtteranceLatticeFaster( template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl &decoder, DecodableInterface &decodable, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, @@ -331,7 +331,7 @@ template bool DecodeUtteranceLatticeFaster( bool DecodeUtteranceLatticeSimple( LatticeSimpleDecoder &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h index fc81137f356..c2c357c2629 100644 --- a/src/decoder/decoder-wrappers.h +++ b/src/decoder/decoder-wrappers.h @@ -103,7 +103,7 @@ template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, @@ -129,7 +129,7 @@ class DecodeUtteranceLatticeFasterClass { DecodeUtteranceLatticeFasterClass( LatticeFasterDecoder *decoder, DecodableInterface *decodable, - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, BaseFloat acoustic_scale, @@ -150,7 +150,7 @@ class DecodeUtteranceLatticeFasterClass { // The following variables correspond to inputs: LatticeFasterDecoder *decoder_; DecodableInterface *decodable_; - const TransitionModel *trans_model_; + const Transitions *trans_model_; const fst::SymbolTable *word_syms_; std::string utt_; BaseFloat acoustic_scale_; @@ -183,7 +183,7 @@ class DecodeUtteranceLatticeFasterClass { bool DecodeUtteranceLatticeSimple( LatticeSimpleDecoder &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, diff --git a/src/decoder/training-graph-compiler.cc b/src/decoder/training-graph-compiler.cc index 191d02f1720..db1a75f7a25 100644 --- a/src/decoder/training-graph-compiler.cc +++ b/src/decoder/training-graph-compiler.cc @@ -23,7 +23,7 @@ namespace kaldi { -TrainingGraphCompiler::TrainingGraphCompiler(const TransitionModel &trans_model, +TrainingGraphCompiler::TrainingGraphCompiler(const Transitions &trans_model, const ContextDependency &ctx_dep, // Does not maintain reference to this. fst::VectorFst *lex_fst, const std::vector &disambig_syms, diff --git a/src/decoder/training-graph-compiler.h b/src/decoder/training-graph-compiler.h index ee56c6dfb3d..600844b8b8a 100644 --- a/src/decoder/training-graph-compiler.h +++ b/src/decoder/training-graph-compiler.h @@ -21,7 +21,7 @@ #define KALDI_DECODER_TRAINING_GRAPH_COMPILER_H_ #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fst/fstlib.h" #include "fstext/fstext-lib.h" #include "tree/context-dep.h" @@ -58,7 +58,7 @@ struct TrainingGraphCompilerOptions { class TrainingGraphCompiler { public: - TrainingGraphCompiler(const TransitionModel &trans_model, // Maintains reference to this object. + TrainingGraphCompiler(const Transitions &trans_model, // Maintains reference to this object. const ContextDependency &ctx_dep, // And this. fst::VectorFst *lex_fst, // Takes ownership of this object. // It should not contain disambiguation symbols or subsequential symbol, @@ -93,7 +93,7 @@ class TrainingGraphCompiler { ~TrainingGraphCompiler() { delete lex_fst_; } private: - const TransitionModel &trans_model_; + const Transitions &trans_model_; const ContextDependency &ctx_dep_; fst::VectorFst *lex_fst_; // lexicon FST (an input; we take // ownership as we need to modify it). diff --git a/src/feat/Makefile b/src/feat/Makefile index dcd029f7f94..9850e578d9a 100644 --- a/src/feat/Makefile +++ b/src/feat/Makefile @@ -4,12 +4,12 @@ all: include ../kaldi.mk -TESTFILES = feature-mfcc-test feature-plp-test feature-fbank-test \ +TESTFILES = feature-mfcc-test feature-fbank-test \ feature-functions-test pitch-functions-test feature-sdc-test \ resample-test online-feature-test signal-test wave-reader-test -OBJFILES = feature-functions.o feature-mfcc.o feature-plp.o feature-fbank.o \ - feature-spectrogram.o mel-computations.o wave-reader.o \ +OBJFILES = feature-functions.o feature-mfcc.o feature-fbank.o \ + mel-computations.o wave-reader.o \ pitch-functions.o resample.o online-feature.o signal.o \ feature-window.o @@ -17,6 +17,6 @@ LIBNAME = kaldi-feat ADDLIBS = ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/feat/feature-common-inl.h b/src/feat/feature-common-inl.h index 26127a4dc4d..10bfe5cdfd1 100644 --- a/src/feat/feature-common-inl.h +++ b/src/feat/feature-common-inl.h @@ -70,15 +70,12 @@ void OfflineFeatureTpl::Compute( } output->Resize(rows_out, cols_out); Vector window; // windowed waveform. - bool use_raw_log_energy = computer_.NeedRawLogEnergy(); for (int32 r = 0; r < rows_out; r++) { // r is frame index. - BaseFloat raw_log_energy = 0.0; ExtractWindow(0, wave, r, computer_.GetFrameOptions(), - feature_window_function_, &window, - (use_raw_log_energy ? &raw_log_energy : NULL)); + feature_window_function_, &window); SubVector output_row(*output, r); - computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row); + computer_.Compute(vtln_warp, &window, &output_row); } } diff --git a/src/feat/feature-common.h b/src/feat/feature-common.h index 45911cef585..664806beb49 100644 --- a/src/feat/feature-common.h +++ b/src/feat/feature-common.h @@ -115,8 +115,10 @@ class OfflineFeatureTpl { // Note: feature_window_function_ is the windowing function, which initialized // using the options class, that we cache at this level. OfflineFeatureTpl(const Options &opts): - computer_(opts), - feature_window_function_(computer_.GetFrameOptions()) { } + computer_(opts) { + InitFeatureWindowFunction(computer_.GetFrameOptions(), + &feature_window_function_); + } // Internal (and back-compatibility) interface for computing features, which // requires that the user has already checked that the sampling frequency @@ -164,7 +166,7 @@ class OfflineFeatureTpl { OfflineFeatureTpl &operator =(const OfflineFeatureTpl &other); F computer_; - FeatureWindowFunction feature_window_function_; + Vector feature_window_function_; }; /// @} End of "addtogroup feat" diff --git a/src/feat/feature-fbank.cc b/src/feat/feature-fbank.cc index 10f7e67d607..8becf6a8141 100644 --- a/src/feat/feature-fbank.cc +++ b/src/feat/feature-fbank.cc @@ -25,8 +25,7 @@ namespace kaldi { FbankComputer::FbankComputer(const FbankOptions &opts): opts_(opts), srfft_(NULL) { - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); + KALDI_ASSERT(opts.energy_floor > 0.0 && "Nonzero energy floor is required."); int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... @@ -38,7 +37,7 @@ FbankComputer::FbankComputer(const FbankOptions &opts): } FbankComputer::FbankComputer(const FbankComputer &other): - opts_(other.opts_), log_energy_floor_(other.log_energy_floor_), + opts_(other.opts_), mel_banks_(other.mel_banks_), srfft_(NULL) { for (std::map::iterator iter = mel_banks_.begin(); iter != mel_banks_.end(); @@ -69,8 +68,7 @@ const MelBanks* FbankComputer::GetMelBanks(BaseFloat vtln_warp) { return this_mel_banks; } -void FbankComputer::Compute(BaseFloat signal_log_energy, - BaseFloat vtln_warp, +void FbankComputer::Compute(BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature) { @@ -80,10 +78,10 @@ void FbankComputer::Compute(BaseFloat signal_log_energy, feature->Dim() == this->Dim()); - // Compute energy after window function (not the raw one). - if (opts_.use_energy && !opts_.raw_energy) + BaseFloat signal_log_energy; + if (opts_.use_energy) signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); + opts_.energy_floor)); if (srfft_ != NULL) // Compute FFT using split-radix algorithm. srfft_->Compute(signal_frame->Data(), true); @@ -95,30 +93,20 @@ void FbankComputer::Compute(BaseFloat signal_log_energy, SubVector power_spectrum(*signal_frame, 0, signal_frame->Dim() / 2 + 1); - // Use magnitude instead of power if requested. - if (!opts_.use_power) - power_spectrum.ApplyPow(0.5); - - int32 mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); + int32 mel_offset = (opts_.use_energy ? 1 : 0); SubVector mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins); // Sum with mel fiterbanks over the power spectrum mel_banks.Compute(power_spectrum, &mel_energies); - if (opts_.use_log_fbank) { - // Avoid log of zero (which should be prevented anyway by dithering). - mel_energies.ApplyFloor(std::numeric_limits::epsilon()); - mel_energies.ApplyLog(); // take the log. - } - // Copy energy as first value (or the last, if htk_compat == true). + mel_energies.ApplyFloor(opts_.energy_floor); + mel_energies.ApplyLog(); // take the log. + + // Copy energy as first value if (opts_.use_energy) { - if (opts_.energy_floor > 0.0 && signal_log_energy < log_energy_floor_) { - signal_log_energy = log_energy_floor_; - } - int32 energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0; - (*feature)(energy_index) = signal_log_energy; + (*feature)(0) = signal_log_energy; } } diff --git a/src/feat/feature-fbank.h b/src/feat/feature-fbank.h index 724d7d148dc..04421c506b6 100644 --- a/src/feat/feature-fbank.h +++ b/src/feat/feature-fbank.h @@ -42,41 +42,24 @@ struct FbankOptions { FrameExtractionOptions frame_opts; MelBanksOptions mel_opts; bool use_energy; // append an extra dimension with energy to the filter banks - BaseFloat energy_floor; - bool raw_energy; // If true, compute energy before preemphasis and windowing - bool htk_compat; // If true, put energy last (if using energy) - bool use_log_fbank; // if true (default), produce log-filterbank, else linear - bool use_power; // if true (default), use power in filterbank analysis, else magnitude. + BaseFloat energy_floor; // Floor on energy, to avoid log(0.0). The floor of + // 1e-10 may be interpreted as (approximately) + // 0.1 * 2**-30. The smallest nonzero value in a 16-bit + // waveform would be 1^-15, and 1^-30 is its square. FbankOptions(): mel_opts(23), - // defaults the #mel-banks to 23 for the FBANK computations. - // this seems to be common for 16khz-sampled data, - // but for 8khz-sampled data, 15 may be better. - use_energy(false), - energy_floor(0.0), - raw_energy(true), - htk_compat(false), - use_log_fbank(true), - use_power(true) {} + use_energy(false), + energy_floor(1.0e-10) { } void Register(OptionsItf *opts) { frame_opts.Register(opts); mel_opts.Register(opts); opts->Register("use-energy", &use_energy, - "Add an extra dimension with energy to the FBANK output."); + "Add an extra dimension with energy to the filterbank " + "output."); opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in FBANK computation. " - "Only makes a difference if --use-energy=true; only necessary if " - "--dither=0.0. Suggested values: 0.1 or 1.0"); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - opts->Register("htk-compat", &htk_compat, "If true, put energy last. " - "Warning: not sufficient to get HTK compatible features (need " - "to change other parameters)."); - opts->Register("use-log-fbank", &use_log_fbank, - "If true, produce log-filterbank, else produce linear."); - opts->Register("use-power", &use_power, - "If true, use power, else use magnitude."); + "Floor on energy (absolute, not relative) in filterbank " + "computation."); } }; @@ -94,8 +77,6 @@ class FbankComputer { return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); } - bool NeedRawLogEnergy() { return opts_.use_energy && opts_.raw_energy; } - const FrameExtractionOptions &GetFrameOptions() const { return opts_.frame_opts; } @@ -104,11 +85,6 @@ class FbankComputer { Function that computes one frame of features from one frame of signal. - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). @param [in] vtln_warp The VTLN warping factor that the user wants to be applied when computing features for this utterance. Will normally be 1.0, meaning no warping is to be done. The value will @@ -121,8 +97,7 @@ class FbankComputer { @param [out] feature Pointer to a vector of size this->Dim(), to which the computed feature will be written. */ - void Compute(BaseFloat signal_log_energy, - BaseFloat vtln_warp, + void Compute(BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature); @@ -133,7 +108,6 @@ class FbankComputer { FbankOptions opts_; - BaseFloat log_energy_floor_; std::map mel_banks_; // BaseFloat is VTLN coefficient. SplitRadixRealFft *srfft_; // Disallow assignment. diff --git a/src/feat/feature-mfcc-test.cc b/src/feat/feature-mfcc-test.cc index c4367139707..305ac5abe50 100644 --- a/src/feat/feature-mfcc-test.cc +++ b/src/feat/feature-mfcc-test.cc @@ -88,14 +88,10 @@ static void UnitTestSimple() { // the parametrization object MfccOptions op; // trying to have same opts as baseline. - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; op.frame_opts.window_type = "rectangular"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; op.mel_opts.low_freq = 0.0; - op.mel_opts.htk_mode = true; - op.htk_compat = true; Mfcc mfcc(op); // use default parameters @@ -129,14 +125,10 @@ static void UnitTestHTKCompare1() { // use mfcc with default configuration... MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; op.frame_opts.window_type = "hamming"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; op.mel_opts.low_freq = 0.0; - op.mel_opts.htk_mode = true; - op.htk_compat = true; op.use_energy = false; // C0 not energy. Mfcc mfcc(op); @@ -188,7 +180,7 @@ static void UnitTestHTKCompare1() { } std::cout << "Test passed :)\n\n"; - + unlink("tmp.test.wav.fea_kaldi.1"); } @@ -213,14 +205,10 @@ static void UnitTestHTKCompare2() { // use mfcc with default configuration... MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; op.frame_opts.window_type = "hamming"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; op.mel_opts.low_freq = 0.0; - op.mel_opts.htk_mode = true; - op.htk_compat = true; op.use_energy = true; // Use energy. Mfcc mfcc(op); @@ -272,7 +260,7 @@ static void UnitTestHTKCompare2() { } std::cout << "Test passed :)\n\n"; - + unlink("tmp.test.wav.fea_kaldi.2"); } @@ -297,16 +285,11 @@ static void UnitTestHTKCompare3() { // use mfcc with default configuration... MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; op.frame_opts.window_type = "hamming"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; - op.htk_compat = true; op.use_energy = true; // Use energy. op.mel_opts.low_freq = 20.0; - //op.mel_opts.debug_mel = true; - op.mel_opts.htk_mode = true; Mfcc mfcc(op); @@ -357,7 +340,7 @@ static void UnitTestHTKCompare3() { } std::cout << "Test passed :)\n\n"; - + unlink("tmp.test.wav.fea_kaldi.3"); } @@ -382,14 +365,11 @@ static void UnitTestHTKCompare4() { // use mfcc with default configuration... MfccOptions op; - op.frame_opts.dither = 0.0; op.frame_opts.window_type = "hamming"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; op.mel_opts.low_freq = 0.0; - op.htk_compat = true; op.use_energy = true; // Use energy. - op.mel_opts.htk_mode = true; Mfcc mfcc(op); @@ -440,7 +420,7 @@ static void UnitTestHTKCompare4() { } std::cout << "Test passed :)\n\n"; - + unlink("tmp.test.wav.fea_kaldi.4"); } @@ -465,16 +445,13 @@ static void UnitTestHTKCompare5() { // use mfcc with default configuration... MfccOptions op; - op.frame_opts.dither = 0.0; op.frame_opts.window_type = "hamming"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; - op.htk_compat = true; op.use_energy = true; // Use energy. op.mel_opts.low_freq = 0.0; op.mel_opts.vtln_low = 100.0; op.mel_opts.vtln_high = 7500.0; - op.mel_opts.htk_mode = true; BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1, // differs slightly for higher mel bins if warp_factor <0.9 @@ -528,7 +505,7 @@ static void UnitTestHTKCompare5() { } std::cout << "Test passed :)\n\n"; - + unlink("tmp.test.wav.fea_kaldi.5"); } @@ -553,15 +530,12 @@ static void UnitTestHTKCompare6() { // use mfcc with default configuration... MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.97; op.frame_opts.window_type = "hamming"; op.frame_opts.remove_dc_offset = false; op.frame_opts.round_to_power_of_two = true; op.mel_opts.num_bins = 24; op.mel_opts.low_freq = 125.0; op.mel_opts.high_freq = 7800.0; - op.htk_compat = true; op.use_energy = false; // C0 not energy. Mfcc mfcc(op); @@ -613,7 +587,7 @@ static void UnitTestHTKCompare6() { } std::cout << "Test passed :)\n\n"; - + unlink("tmp.test.wav.fea_kaldi.6"); } @@ -682,5 +656,3 @@ int main() { return 1; } } - - diff --git a/src/feat/feature-mfcc.cc b/src/feat/feature-mfcc.cc index 899988c2822..ffa3b5450b5 100644 --- a/src/feat/feature-mfcc.cc +++ b/src/feat/feature-mfcc.cc @@ -25,18 +25,18 @@ namespace kaldi { -void MfccComputer::Compute(BaseFloat signal_log_energy, - BaseFloat vtln_warp, +void MfccComputer::Compute(BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature) { KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && feature->Dim() == this->Dim()); - const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); - - if (opts_.use_energy && !opts_.raw_energy) + BaseFloat signal_log_energy; + if (opts_.use_energy) signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); + opts_.energy_floor)); + + const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); if (srfft_ != NULL) // Compute FFT using the split-radix algorithm. srfft_->Compute(signal_frame->Data(), true); @@ -50,33 +50,15 @@ void MfccComputer::Compute(BaseFloat signal_log_energy, mel_banks.Compute(power_spectrum, &mel_energies_); - // avoid log of zero (which should be prevented anyway by dithering). - mel_energies_.ApplyFloor(std::numeric_limits::epsilon()); - mel_energies_.ApplyLog(); // take the log. + mel_energies_.ApplyFloor(opts_.energy_floor); + mel_energies_.ApplyLog(); feature->SetZero(); // in case there were NaNs. // feature = dct_matrix_ * mel_energies [which now have log] feature->AddMatVec(1.0, dct_matrix_, kNoTrans, mel_energies_, 0.0); - if (opts_.cepstral_lifter != 0.0) - feature->MulElements(lifter_coeffs_); - - if (opts_.use_energy) { - if (opts_.energy_floor > 0.0 && signal_log_energy < log_energy_floor_) - signal_log_energy = log_energy_floor_; + if (opts_.use_energy) (*feature)(0) = signal_log_energy; - } - - if (opts_.htk_compat) { - BaseFloat energy = (*feature)(0); - for (int32 i = 0; i < opts_.num_ceps - 1; i++) - (*feature)(i) = (*feature)(i+1); - if (!opts_.use_energy) - energy *= M_SQRT2; // scale on C0 (actually removing a scale - // we previously added that's part of one common definition of - // the cosine transform.) - (*feature)(opts_.num_ceps - 1) = energy; - } } MfccComputer::MfccComputer(const MfccOptions &opts): @@ -98,12 +80,6 @@ MfccComputer::MfccComputer(const MfccOptions &opts): SubMatrix dct_rows(dct_matrix, 0, opts.num_ceps, 0, num_bins); dct_matrix_.Resize(opts.num_ceps, num_bins); dct_matrix_.CopyFromMat(dct_rows); // subset of rows. - if (opts.cepstral_lifter != 0.0) { - lifter_coeffs_.Resize(opts.num_ceps); - ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_); - } - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... @@ -117,7 +93,6 @@ MfccComputer::MfccComputer(const MfccOptions &opts): MfccComputer::MfccComputer(const MfccComputer &other): opts_(other.opts_), lifter_coeffs_(other.lifter_coeffs_), dct_matrix_(other.dct_matrix_), - log_energy_floor_(other.log_energy_floor_), mel_banks_(other.mel_banks_), srfft_(NULL), mel_energies_(other.mel_energies_.Dim(), kUndefined) { diff --git a/src/feat/feature-mfcc.h b/src/feat/feature-mfcc.h index 66c52e89821..83aea3fb9bb 100644 --- a/src/feat/feature-mfcc.h +++ b/src/feat/feature-mfcc.h @@ -1,7 +1,7 @@ // feat/feature-mfcc.h // Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University -// 2014-2016 Johns Hopkins University (author: Daniel Povey) +// 2014-2019 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -39,25 +39,14 @@ struct MfccOptions { FrameExtractionOptions frame_opts; MelBanksOptions mel_opts; int32 num_ceps; // e.g. 13: num cepstral coeffs, counting zero. - bool use_energy; // use energy; else C0 - BaseFloat energy_floor; // 0 by default; set to a value like 1.0 or 0.1 if - // you disable dithering. - bool raw_energy; // If true, compute energy before preemphasis and windowing - BaseFloat cepstral_lifter; // Scaling factor on cepstra for HTK compatibility. - // if 0.0, no liftering is done. - bool htk_compat; // if true, put energy/C0 last and introduce a factor of - // sqrt(2) on C0 to be the same as HTK. + bool use_energy; // if true, use energy; else C0 + BaseFloat energy_floor; MfccOptions() : mel_opts(23), - // defaults the #mel-banks to 23 for the MFCC computations. - // this seems to be common for 16khz-sampled data, - // but for 8khz-sampled data, 15 may be better. num_ceps(13), use_energy(true), - energy_floor(0.0), - raw_energy(true), - cepstral_lifter(22.0), - htk_compat(false) {} + energy_floor(1.0e-10) { } + void Register(OptionsItf *opts) { frame_opts.Register(opts); @@ -67,17 +56,8 @@ struct MfccOptions { opts->Register("use-energy", &use_energy, "Use energy (not C0) in MFCC computation"); opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in MFCC computation. " - "Only makes a difference if --use-energy=true; only necessary if " - "--dither=0.0. Suggested values: 0.1 or 1.0"); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - opts->Register("cepstral-lifter", &cepstral_lifter, - "Constant that controls scaling of MFCCs"); - opts->Register("htk-compat", &htk_compat, - "If true, put energy or C0 last and use a factor of sqrt(2) on " - "C0. Warning: not sufficient to get HTK compatible features " - "(need to change other parameters)."); + "Floor on energy (absolute, not relative) of mel bins etc. " + "in MFCC computation. "); } }; @@ -96,17 +76,10 @@ class MfccComputer { int32 Dim() const { return opts_.num_ceps; } - bool NeedRawLogEnergy() { return opts_.use_energy && opts_.raw_energy; } - /** Function that computes one frame of features from one frame of signal. - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). @param [in] vtln_warp The VTLN warping factor that the user wants to be applied when computing features for this utterance. Will normally be 1.0, meaning no warping is to be done. The value will @@ -119,8 +92,7 @@ class MfccComputer { @param [out] feature Pointer to a vector of size this->Dim(), to which the computed feature will be written. */ - void Compute(BaseFloat signal_log_energy, - BaseFloat vtln_warp, + void Compute(BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature); @@ -134,7 +106,6 @@ class MfccComputer { MfccOptions opts_; Vector lifter_coeffs_; Matrix dct_matrix_; // matrix we left-multiply by to perform DCT. - BaseFloat log_energy_floor_; std::map mel_banks_; // BaseFloat is VTLN coefficient. SplitRadixRealFft *srfft_; diff --git a/src/feat/feature-plp-test.cc b/src/feat/feature-plp-test.cc deleted file mode 100644 index ad872cffcd0..00000000000 --- a/src/feat/feature-plp-test.cc +++ /dev/null @@ -1,177 +0,0 @@ -// feat/feature-plp-test.cc - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#include - -#include "feat/feature-plp.h" -#include "base/kaldi-math.h" -#include "matrix/kaldi-matrix-inl.h" -#include "feat/wave-reader.h" - -using namespace kaldi; - - - - - -/** - */ -static void UnitTestSimple() { - std::cout << "=== UnitTestSimple() ===\n"; - - Vector v(100000); - Matrix m; - - // init with noise - for (int32 i = 0; i < v.Dim(); i++) { - v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2); - } - - std::cout << "<<<=== Just make sure it runs... Nothing is compared\n"; - // the parametrization object - PlpOptions op; - // trying to have same opts as baseline. - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "rectangular"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.low_freq = 0.0; -// op.htk_compat = true; - - Plp plp(op); - // use default parameters - - // compute mfccs. - plp.Compute(v, 1.0, &m); - - // possibly dump - // std::cout << "== Output features == \n" << m; - std::cout << "Test passed :)\n\n"; -} - - -static void UnitTestHTKCompare1() { - std::cout << "=== UnitTestHTKCompare1() ===\n"; - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - Matrix htk_features; - { - std::ifstream is("test_data/test.wav.plp_htk.1", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use plp with default configuration... - PlpOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.low_freq = 0.0; - op.htk_compat = true; - op.use_energy = false; // C0 not energy. - op.cepstral_scale = 1.0; - - Plp plp(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - plp.Compute(waveform, 1.0, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 0.10) { //<< TOLERANCE TO DIFFERENCES!!!!! - // print the non-matching data only once per-line - if (i_old != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021413 // PLP_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.plp_kaldi.1", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.plp_kaldi.1"); -} - - - - -static void UnitTestFeat() { - UnitTestSimple(); - UnitTestHTKCompare1(); -} - - - - -int main() { - try { - for (int i = 0; i < 5; i++) - UnitTestFeat(); - std::cout << "Tests succeeded.\n"; - return 0; - } catch (const std::exception &e) { - std::cerr << e.what(); - return 1; - } -} - - diff --git a/src/feat/feature-plp.cc b/src/feat/feature-plp.cc deleted file mode 100644 index 8f4a7d66161..00000000000 --- a/src/feat/feature-plp.cc +++ /dev/null @@ -1,191 +0,0 @@ -// feat/feature-plp.cc - -// Copyright 2009-2011 Petr Motlicek; Karel Vesely -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#include "feat/feature-plp.h" - -namespace kaldi { - -PlpComputer::PlpComputer(const PlpOptions &opts): - opts_(opts), srfft_(NULL), - mel_energies_duplicated_(opts_.mel_opts.num_bins + 2, kUndefined), - autocorr_coeffs_(opts_.lpc_order + 1, kUndefined), - lpc_coeffs_(opts_.lpc_order, kUndefined), - raw_cepstrum_(opts_.lpc_order, kUndefined) { - - if (opts.cepstral_lifter != 0.0) { - lifter_coeffs_.Resize(opts.num_ceps); - ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_); - } - InitIdftBases(opts_.lpc_order + 1, opts_.mel_opts.num_bins + 2, - &idft_bases_); - - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); - - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); - - // We'll definitely need the filterbanks info for VTLN warping factor 1.0. - // [note: this call caches it.] - GetMelBanks(1.0); -} - -PlpComputer::PlpComputer(const PlpComputer &other): - opts_(other.opts_), lifter_coeffs_(other.lifter_coeffs_), - idft_bases_(other.idft_bases_), log_energy_floor_(other.log_energy_floor_), - mel_banks_(other.mel_banks_), equal_loudness_(other.equal_loudness_), - srfft_(NULL), - mel_energies_duplicated_(opts_.mel_opts.num_bins + 2, kUndefined), - autocorr_coeffs_(opts_.lpc_order + 1, kUndefined), - lpc_coeffs_(opts_.lpc_order, kUndefined), - raw_cepstrum_(opts_.lpc_order, kUndefined) { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); ++iter) - iter->second = new MelBanks(*(iter->second)); - for (std::map*>::iterator - iter = equal_loudness_.begin(); - iter != equal_loudness_.end(); ++iter) - iter->second = new Vector(*(iter->second)); - if (other.srfft_ != NULL) - srfft_ = new SplitRadixRealFft(*(other.srfft_)); -} - -PlpComputer::~PlpComputer() { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); ++iter) - delete iter->second; - for (std::map* >::iterator - iter = equal_loudness_.begin(); - iter != equal_loudness_.end(); ++iter) - delete iter->second; - delete srfft_; -} - -const MelBanks *PlpComputer::GetMelBanks(BaseFloat vtln_warp) { - MelBanks *this_mel_banks = NULL; - std::map::iterator iter = mel_banks_.find(vtln_warp); - if (iter == mel_banks_.end()) { - this_mel_banks = new MelBanks(opts_.mel_opts, - opts_.frame_opts, - vtln_warp); - mel_banks_[vtln_warp] = this_mel_banks; - } else { - this_mel_banks = iter->second; - } - return this_mel_banks; -} - -const Vector *PlpComputer::GetEqualLoudness(BaseFloat vtln_warp) { - const MelBanks *this_mel_banks = GetMelBanks(vtln_warp); - Vector *ans = NULL; - std::map*>::iterator iter - = equal_loudness_.find(vtln_warp); - if (iter == equal_loudness_.end()) { - ans = new Vector; - GetEqualLoudnessVector(*this_mel_banks, ans); - equal_loudness_[vtln_warp] = ans; - } else { - ans = iter->second; - } - return ans; -} - -void PlpComputer::Compute(BaseFloat signal_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && - feature->Dim() == this->Dim()); - - const MelBanks &mel_banks = *GetMelBanks(vtln_warp); - const Vector &equal_loudness = *GetEqualLoudness(vtln_warp); - - - KALDI_ASSERT(opts_.num_ceps <= opts_.lpc_order+1); // our num-ceps includes C0. - - - if (opts_.use_energy && !opts_.raw_energy) - signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); - - if (srfft_ != NULL) // Compute FFT using split-radix algorithm. - srfft_->Compute(signal_frame->Data(), true); - else // An alternative algorithm that works for non-powers-of-two. - RealFft(signal_frame, true); - - // Convert the FFT into a power spectrum. - ComputePowerSpectrum(signal_frame); // elements 0 ... signal_frame->Dim()/2 - - SubVector power_spectrum(*signal_frame, - 0, signal_frame->Dim() / 2 + 1); - - int32 num_mel_bins = opts_.mel_opts.num_bins; - - SubVector mel_energies(mel_energies_duplicated_, 1, num_mel_bins); - - mel_banks.Compute(power_spectrum, &mel_energies); - - mel_energies.MulElements(equal_loudness); - - mel_energies.ApplyPow(opts_.compress_factor); - - // duplicate first and last elements - mel_energies_duplicated_(0) = mel_energies_duplicated_(1); - mel_energies_duplicated_(num_mel_bins + 1) = - mel_energies_duplicated_(num_mel_bins); - - autocorr_coeffs_.SetZero(); // In case of NaNs or infs - autocorr_coeffs_.AddMatVec(1.0, idft_bases_, kNoTrans, - mel_energies_duplicated_, 0.0); - - BaseFloat residual_log_energy = ComputeLpc(autocorr_coeffs_, &lpc_coeffs_); - - residual_log_energy = std::max(residual_log_energy, - std::numeric_limits::min()); - - Lpc2Cepstrum(opts_.lpc_order, lpc_coeffs_.Data(), raw_cepstrum_.Data()); - feature->Range(1, opts_.num_ceps - 1).CopyFromVec( - raw_cepstrum_.Range(0, opts_.num_ceps - 1)); - (*feature)(0) = residual_log_energy; - - if (opts_.cepstral_lifter != 0.0) - feature->MulElements(lifter_coeffs_); - - if (opts_.cepstral_scale != 1.0) - feature->Scale(opts_.cepstral_scale); - - if (opts_.use_energy) { - if (opts_.energy_floor > 0.0 && signal_log_energy < log_energy_floor_) - signal_log_energy = log_energy_floor_; - (*feature)(0) = signal_log_energy; - } - - if (opts_.htk_compat) { // reorder the features. - BaseFloat log_energy = (*feature)(0); - for (int32 i = 0; i < opts_.num_ceps-1; i++) - (*feature)(i) = (*feature)(i+1); - (*feature)(opts_.num_ceps-1) = log_energy; - } -} - - -} // namespace kaldi diff --git a/src/feat/feature-plp.h b/src/feat/feature-plp.h deleted file mode 100644 index 958c5706e89..00000000000 --- a/src/feat/feature-plp.h +++ /dev/null @@ -1,176 +0,0 @@ -// feat/feature-plp.h - -// Copyright 2009-2011 Petr Motlicek; Karel Vesely - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_PLP_H_ -#define KALDI_FEAT_FEATURE_PLP_H_ - -#include -#include - -#include "feat/feature-common.h" -#include "feat/feature-functions.h" -#include "feat/feature-window.h" -#include "feat/mel-computations.h" -#include "itf/options-itf.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - - -/// PlpOptions contains basic options for computing PLP features. -/// It only includes things that can be done in a "stateless" way, i.e. -/// it does not include energy max-normalization. -/// It does not include delta computation. -struct PlpOptions { - FrameExtractionOptions frame_opts; - MelBanksOptions mel_opts; - int32 lpc_order; - int32 num_ceps; // num cepstra including zero - bool use_energy; // use energy; else C0 - BaseFloat energy_floor; - bool raw_energy; // If true, compute energy before preemphasis and windowing - BaseFloat compress_factor; - int32 cepstral_lifter; - BaseFloat cepstral_scale; - - bool htk_compat; // if true, put energy/C0 last and introduce a factor of - // sqrt(2) on C0 to be the same as HTK. - - PlpOptions() : mel_opts(23), - // default number of mel-banks for the PLP computation; this - // seems to be common for 16kHz-sampled data. For 8kHz-sampled - // data, 15 may be better. - lpc_order(12), - num_ceps(13), - use_energy(true), - energy_floor(0.0), - raw_energy(true), - compress_factor(0.33333), - cepstral_lifter(22), - cepstral_scale(1.0), - htk_compat(false) {} - - void Register(OptionsItf *opts) { - frame_opts.Register(opts); - mel_opts.Register(opts); - opts->Register("lpc-order", &lpc_order, - "Order of LPC analysis in PLP computation"); - opts->Register("num-ceps", &num_ceps, - "Number of cepstra in PLP computation (including C0)"); - opts->Register("use-energy", &use_energy, - "Use energy (not C0) for zeroth PLP feature"); - opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in PLP computation. " - "Only makes a difference if --use-energy=true; only necessary if " - "--dither=0.0. Suggested values: 0.1 or 1.0"); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - opts->Register("compress-factor", &compress_factor, - "Compression factor in PLP computation"); - opts->Register("cepstral-lifter", &cepstral_lifter, - "Constant that controls scaling of PLPs"); - opts->Register("cepstral-scale", &cepstral_scale, - "Scaling constant in PLP computation"); - opts->Register("htk-compat", &htk_compat, - "If true, put energy or C0 last. Warning: not sufficient " - "to get HTK compatible features (need to change other " - "parameters)."); - } -}; - - -/// This is the new-style interface to the PLP computation. -class PlpComputer { - public: - typedef PlpOptions Options; - explicit PlpComputer(const PlpOptions &opts); - PlpComputer(const PlpComputer &other); - - const FrameExtractionOptions &GetFrameOptions() const { - return opts_.frame_opts; - } - - int32 Dim() const { return opts_.num_ceps; } - - bool NeedRawLogEnergy() { return opts_.use_energy && opts_.raw_energy; } - - /** - Function that computes one frame of features from - one frame of signal. - - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). - @param [in] vtln_warp The VTLN warping factor that the user wants - to be applied when computing features for this utterance. Will - normally be 1.0, meaning no warping is to be done. The value will - be ignored for feature types that don't support VLTN, such as - spectrogram features. - @param [in] signal_frame One frame of the signal, - as extracted using the function ExtractWindow() using the options - returned by this->GetFrameOptions(). The function will use the - vector as a workspace, which is why it's a non-const pointer. - @param [out] feature Pointer to a vector of size this->Dim(), to which - the computed feature will be written. - */ - void Compute(BaseFloat signal_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature); - - ~PlpComputer(); - private: - - const MelBanks *GetMelBanks(BaseFloat vtln_warp); - - const Vector *GetEqualLoudness(BaseFloat vtln_warp); - - PlpOptions opts_; - Vector lifter_coeffs_; - Matrix idft_bases_; - BaseFloat log_energy_floor_; - std::map mel_banks_; // BaseFloat is VTLN coefficient. - std::map* > equal_loudness_; - SplitRadixRealFft *srfft_; - - // temporary vector used inside Compute; size is opts_.mel_opts.num_bins + 2 - Vector mel_energies_duplicated_; - // temporary vector used inside Compute; size is opts_.lpc_order + 1 - Vector autocorr_coeffs_; - // temporary vector used inside Compute; size is opts_.lpc_order - Vector lpc_coeffs_; - // temporary vector used inside Compute; size is opts_.lpc_order - Vector raw_cepstrum_; - - // Disallow assignment. - PlpComputer &operator =(const PlpComputer &other); -}; - -typedef OfflineFeatureTpl Plp; - -/// @} End of "addtogroup feat" - -} // namespace kaldi - - -#endif // KALDI_FEAT_FEATURE_PLP_H_ diff --git a/src/feat/feature-window.cc b/src/feat/feature-window.cc index c5d4cc29831..b68b8854128 100644 --- a/src/feat/feature-window.cc +++ b/src/feat/feature-window.cc @@ -1,7 +1,7 @@ // feat/feature-window.cc // Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation -// 2013-2016 Johns Hopkins University (author: Daniel Povey) +// 2013-2019 Johns Hopkins University (author: Daniel Povey) // 2014 IMSL, PKU-HKUST (author: Wei Shi) // See ../../COPYING for clarification regarding multiple authors @@ -30,13 +30,9 @@ namespace kaldi { int64 FirstSampleOfFrame(int32 frame, const FrameExtractionOptions &opts) { int64 frame_shift = opts.WindowShift(); - if (opts.snip_edges) { - return frame * frame_shift; - } else { - int64 midpoint_of_frame = frame_shift * frame + frame_shift / 2, - beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; - return beginning_of_frame; - } + int64 midpoint_of_frame = frame_shift * frame + frame_shift / 2, + beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; + return beginning_of_frame; } int32 NumFrames(int64 num_samples, @@ -44,85 +40,54 @@ int32 NumFrames(int64 num_samples, bool flush) { int64 frame_shift = opts.WindowShift(); int64 frame_length = opts.WindowSize(); - if (opts.snip_edges) { - // with --snip-edges=true (the default), we use a HTK-like approach to - // determining the number of frames-- all frames have to fit completely into - // the waveform, and the first frame begins at sample zero. - if (num_samples < frame_length) - return 0; - else - return (1 + ((num_samples - frame_length) / frame_shift)); - // You can understand the expression above as follows: 'num_samples - - // frame_length' is how much room we have to shift the frame within the - // waveform; 'frame_shift' is how much we shift it each time; and the ratio - // is how many times we can shift it (integer arithmetic rounds down). - } else { - // if --snip-edges=false, the number of frames is determined by rounding the - // (file-length / frame-shift) to the nearest integer. The point of this - // formula is to make the number of frames an obvious and predictable - // function of the frame shift and signal length, which makes many - // segmentation-related questions simpler. - // - // Because integer division in C++ rounds toward zero, we add (half the - // frame-shift minus epsilon) before dividing, to have the effect of - // rounding towards the closest integer. - int32 num_frames = (num_samples + (frame_shift / 2)) / frame_shift; - - if (flush) - return num_frames; - - // note: 'end' always means the last plus one, i.e. one past the last. - int64 end_sample_of_last_frame = FirstSampleOfFrame(num_frames - 1, opts) - + frame_length; - - // the following code is optimized more for clarity than efficiency. - // If flush == false, we can't output frames that extend past the end - // of the signal. - while (num_frames > 0 && end_sample_of_last_frame > num_samples) { - num_frames--; - end_sample_of_last_frame -= frame_shift; - } + + // The number of frames is determined by rounding the + // (file-length / frame-shift) to the nearest integer. The point of this + // formula is to make the number of frames an obvious and predictable + // function of the frame shift and signal length, which makes many + // segmentation-related questions simpler. + // + // Because integer division in C++ rounds toward zero, we add (half the + // frame-shift minus epsilon) before dividing, to have the effect of + // rounding towards the closest integer. + int32 num_frames = (num_samples + (frame_shift / 2)) / frame_shift; + + if (flush) return num_frames; - } -} + // note: 'end' always means the last plus one, i.e. one past the last. + int64 end_sample_of_last_frame = FirstSampleOfFrame(num_frames - 1, opts) + + frame_length; -void Dither(VectorBase *waveform, BaseFloat dither_value) { - if (dither_value == 0.0) - return; - int32 dim = waveform->Dim(); - BaseFloat *data = waveform->Data(); - RandomState rstate; - for (int32 i = 0; i < dim; i++) - data[i] += RandGauss(&rstate) * dither_value; + // the following code is optimized more for clarity than efficiency. + // If flush == false, we can't output frames that extend past the end + // of the signal. + while (num_frames > 0 && end_sample_of_last_frame > num_samples) { + num_frames--; + end_sample_of_last_frame -= frame_shift; + } + return num_frames; } -void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff) { - if (preemph_coeff == 0.0) return; - KALDI_ASSERT(preemph_coeff >= 0.0 && preemph_coeff <= 1.0); - for (int32 i = waveform->Dim()-1; i > 0; i--) - (*waveform)(i) -= preemph_coeff * (*waveform)(i-1); - (*waveform)(0) -= preemph_coeff * (*waveform)(0); -} - -FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) { +void InitFeatureWindowFunction(const FrameExtractionOptions &opts, + Vector *window_function) { int32 frame_length = opts.WindowSize(); KALDI_ASSERT(frame_length > 0); - window.Resize(frame_length); + window_function->Resize(frame_length); double a = M_2PI / (frame_length-1); for (int32 i = 0; i < frame_length; i++) { double i_fl = static_cast(i); if (opts.window_type == "hanning") { - window(i) = 0.5 - 0.5*cos(a * i_fl); + (*window_function)(i) = 0.5 - 0.5*cos(a * i_fl); } else if (opts.window_type == "hamming") { - window(i) = 0.54 - 0.46*cos(a * i_fl); + (*window_function)(i) = 0.54 - 0.46*cos(a * i_fl); } else if (opts.window_type == "povey") { // like hamming but goes to zero at edges. - window(i) = pow(0.5 - 0.5*cos(a * i_fl), 0.85); + (*window_function)(i) = pow(0.5 - 0.5*cos(a * i_fl), 0.85); } else if (opts.window_type == "rectangular") { - window(i) = 1.0; + (*window_function)(i) = 1.0; } else if (opts.window_type == "blackman") { - window(i) = opts.blackman_coeff - 0.5*cos(a * i_fl) + + (*window_function)(i) = opts.blackman_coeff - 0.5*cos(a * i_fl) + (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl); } else { KALDI_ERR << "Invalid window type " << opts.window_type; @@ -131,54 +96,32 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) } void ProcessWindow(const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - VectorBase *window, - BaseFloat *log_energy_pre_window) { + const VectorBase *window_function, + VectorBase *window) { int32 frame_length = opts.WindowSize(); KALDI_ASSERT(window->Dim() == frame_length); - if (opts.dither != 0.0) - Dither(window, opts.dither); - if (opts.remove_dc_offset) window->Add(-window->Sum() / frame_length); - if (log_energy_pre_window != NULL) { - BaseFloat energy = std::max(VecVec(*window, *window), - std::numeric_limits::epsilon()); - *log_energy_pre_window = Log(energy); - } - - if (opts.preemph_coeff != 0.0) - Preemphasize(window, opts.preemph_coeff); - - window->MulElements(window_function.window); + window->MulElements(*window_function); } // ExtractWindow extracts a windowed frame of waveform with a power-of-two, -// padded size. It does mean subtraction, pre-emphasis and dithering as -// requested. +// padded size. It does mean subtraction if requested. void ExtractWindow(int64 sample_offset, const VectorBase &wave, int32 f, // with 0 <= f < NumFrames(feats, opts) const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - Vector *window, - BaseFloat *log_energy_pre_window) { + const Vector &window_function, + Vector *window) { KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0); int32 frame_length = opts.WindowSize(), frame_length_padded = opts.PaddedWindowSize(); - int64 num_samples = sample_offset + wave.Dim(), - start_sample = FirstSampleOfFrame(f, opts), - end_sample = start_sample + frame_length; + int64 start_sample = FirstSampleOfFrame(f, opts); - if (opts.snip_edges) { - KALDI_ASSERT(start_sample >= sample_offset && - end_sample <= num_samples); - } else { - KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset); - } + KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset); if (window->Dim() != frame_length_padded) window->Resize(frame_length_padded, kUndefined); @@ -216,7 +159,7 @@ void ExtractWindow(int64 sample_offset, SubVector frame(*window, 0, frame_length); - ProcessWindow(opts, window_function, &frame, log_energy_pre_window); + ProcessWindow(opts, window_function, &frame); } } // namespace kaldi diff --git a/src/feat/feature-window.h b/src/feat/feature-window.h index 2fccaefb9a1..ccbc1cd2d9b 100644 --- a/src/feat/feature-window.h +++ b/src/feat/feature-window.h @@ -36,8 +36,6 @@ struct FrameExtractionOptions { BaseFloat samp_freq; BaseFloat frame_shift_ms; // in milliseconds. BaseFloat frame_length_ms; // in milliseconds. - BaseFloat dither; // Amount of dithering, 0.0 means no dither. - BaseFloat preemph_coeff; // Preemphasis coefficient. bool remove_dc_offset; // Subtract mean of wave before FFT. std::string window_type; // e.g. Hamming window // May be "hamming", "rectangular", "povey", "hanning", "blackman" @@ -46,7 +44,6 @@ struct FrameExtractionOptions { // I just don't think the Hamming window makes sense as a windowing function. bool round_to_power_of_two; BaseFloat blackman_coeff; - bool snip_edges; bool allow_downsample; bool allow_upsample; int max_feature_vectors; @@ -54,16 +51,14 @@ struct FrameExtractionOptions { samp_freq(16000), frame_shift_ms(10.0), frame_length_ms(25.0), - dither(1.0), - preemph_coeff(0.97), remove_dc_offset(true), window_type("povey"), round_to_power_of_two(true), blackman_coeff(0.42), - snip_edges(true), allow_downsample(false), - max_feature_vectors(-1), - allow_upsample(false) { } + allow_upsample(false), + max_feature_vectors(-1) { } + void Register(OptionsItf *opts) { opts->Register("sample-frequency", &samp_freq, @@ -71,13 +66,8 @@ struct FrameExtractionOptions { "if specified there)"); opts->Register("frame-length", &frame_length_ms, "Frame length in milliseconds"); opts->Register("frame-shift", &frame_shift_ms, "Frame shift in milliseconds"); - opts->Register("preemphasis-coefficient", &preemph_coeff, - "Coefficient for use in signal preemphasis"); opts->Register("remove-dc-offset", &remove_dc_offset, "Subtract mean from waveform on each frame"); - opts->Register("dither", &dither, "Dithering constant (0.0 means no dither). " - "If you turn this off, you should set the --energy-floor " - "option, e.g. to 1.0 or 0.1"); opts->Register("window-type", &window_type, "Type of window " "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\"" "|\"blackmann\")"); @@ -86,11 +76,6 @@ struct FrameExtractionOptions { opts->Register("round-to-power-of-two", &round_to_power_of_two, "If true, round window size to power of two by zero-padding " "input to FFT."); - opts->Register("snip-edges", &snip_edges, - "If true, end effects will be handled by outputting only frames that " - "completely fit in the file, and the number of frames depends on the " - "frame-length. If false, the number of frames depends only on the " - "frame-shift, and we reflect the data at the ends."); opts->Register("allow-downsample", &allow_downsample, "If true, allow the input waveform to have a higher frequency than " "the specified --sample-frequency (and we'll downsample)."); @@ -115,13 +100,11 @@ struct FrameExtractionOptions { }; -struct FeatureWindowFunction { - FeatureWindowFunction() {} - explicit FeatureWindowFunction(const FrameExtractionOptions &opts); - FeatureWindowFunction(const FeatureWindowFunction &other): - window(other.window) { } - Vector window; -}; +// Sets up the feature window function (e.g. Hamming) as specified by the +// options. +void InitFeatureWindowFunction( + const FrameExtractionOptions &opts, + Vector *window_function); /** @@ -157,13 +140,12 @@ int64 FirstSampleOfFrame(int32 frame, void Dither(VectorBase *waveform, BaseFloat dither_value); -void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff); /** - This function does all the windowing steps after actually - extracting the windowed signal: depeding on the - configuration, it does dithering, dc offset removal, - preemphasis, and multiplication by the windowing function. + This function does all the windowing steps after actually extracting the + windowed signal: depeding on the configuration, it dc offset removal and + multiplication by the windowing function. + @param [in] opts The options class to be used @param [in] window_function The windowing function-- should have been initialized using 'opts'. @@ -172,14 +154,10 @@ void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff); opts.PaddedWindowSize(), with the remaining samples zero, as the FFT code is more efficient if it operates on data with power-of-two size. - @param [out] log_energy_pre_window If non-NULL, then after dithering and - DC offset removal, this function will write to this pointer the log of - the total energy (i.e. sum-squared) of the frame. */ void ProcessWindow(const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - VectorBase *window, - BaseFloat *log_energy_pre_window = NULL); + const VectorBase &window_function, + VectorBase *window); /* @@ -201,18 +179,15 @@ void ProcessWindow(const FrameExtractionOptions &opts, @param [in] window_function The windowing function, as derived from the options class. @param [out] window The windowed, possibly-padded waveform to be - extracted. Will be resized as needed. - @param [out] log_energy_pre_window If non-NULL, the log-energy of - the signal prior to pre-emphasis and multiplying by - the windowing function will be written to here. + extracted. Will be resized as needed. */ void ExtractWindow(int64 sample_offset, const VectorBase &wave, int32 f, const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - Vector *window, - BaseFloat *log_energy_pre_window = NULL); + const VectorBase &window_function, + Vector *window); + /// @} End of "addtogroup feat" diff --git a/src/feat/mel-computations.cc b/src/feat/mel-computations.cc index 810b6247e93..1772caadf4a 100644 --- a/src/feat/mel-computations.cc +++ b/src/feat/mel-computations.cc @@ -32,8 +32,7 @@ namespace kaldi { MelBanks::MelBanks(const MelBanksOptions &opts, const FrameExtractionOptions &frame_opts, - BaseFloat vtln_warp_factor): - htk_mode_(opts.htk_mode) { + BaseFloat vtln_warp_factor) { int32 num_bins = opts.num_bins; if (num_bins < 3) KALDI_ERR << "Must have at least 3 mel bins"; BaseFloat sample_freq = frame_opts.samp_freq; @@ -128,10 +127,6 @@ MelBanks::MelBanks(const MelBanksOptions &opts, bins_[bin].second.Resize(size); bins_[bin].second.CopyFromVec(this_bin.Range(first_index, size)); - // Replicate a bug in HTK, for testing purposes. - if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0) - bins_[bin].second(0) = 0.0; - } if (debug_) { for (size_t i = 0; i < bins_.size(); i++) { @@ -144,8 +139,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts, MelBanks::MelBanks(const MelBanks &other): center_freqs_(other.center_freqs_), bins_(other.bins_), - debug_(other.debug_), - htk_mode_(other.htk_mode_) { } + debug_(other.debug_) { } BaseFloat MelBanks::VtlnWarpFreq(BaseFloat vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. BaseFloat vtln_high_cutoff, @@ -232,8 +226,6 @@ void MelBanks::Compute(const VectorBase &power_spectrum, int32 offset = bins_[i].first; const Vector &v(bins_[i].second); BaseFloat energy = VecVec(v, power_spectrum.Range(offset, v.Dim())); - // HTK-like flooring- for testing purposes (we prefer dither) - if (htk_mode_ && energy < 1.0) energy = 1.0; (*mel_energies_out)(i) = energy; // The following assert was added due to a problem with OpenBlas that @@ -250,91 +242,8 @@ void MelBanks::Compute(const VectorBase &power_spectrum, } } -void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs) { - // Compute liftering coefficients (scaling on cepstral coeffs) - // coeffs are numbered slightly differently from HTK: the zeroth - // index is C0, which is not affected. - for (int32 i = 0; i < coeffs->Dim(); i++) - (*coeffs)(i) = 1.0 + 0.5 * Q * sin (M_PI * i / Q); -} - - -// Durbin's recursion - converts autocorrelation coefficients to the LPC -// pTmp - temporal place [n] -// pAC - autocorrelation coefficients [n + 1] -// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i] * s[n-i]}}) -// F(z) = 1 / (1 - A(z)), 1 is not stored in the demoninator -BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp) { - BaseFloat ki; // reflection coefficient - int i; - int j; - - BaseFloat E = pAC[0]; - - for (i = 0; i < n; i++) { - // next reflection coefficient - ki = pAC[i + 1]; - for (j = 0; j < i; j++) - ki += pLP[j] * pAC[i - j]; - ki = ki / E; - - // new error - BaseFloat c = 1 - ki * ki; - if (c < 1.0e-5) // remove NaNs for constan signal - c = 1.0e-5; - E *= c; - - // new LP coefficients - pTmp[i] = -ki; - for (j = 0; j < i; j++) - pTmp[j] = pLP[j] - ki * pLP[i - j - 1]; - - for (j = 0; j <= i; j++) - pLP[j] = pTmp[j]; - } - - return E; -} - - -void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst) { - for (int32 i = 0; i < n; i++) { - double sum = 0.0; - int j; - for (j = 0; j < i; j++) { - sum += static_cast(i - j) * pLPC[j] * pCepst[i - j - 1]; - } - pCepst[i] = -pLPC[i] - sum / static_cast(i + 1); - } -} - -void GetEqualLoudnessVector(const MelBanks &mel_banks, - Vector *ans) { - int32 n = mel_banks.NumBins(); - // Central frequency of each mel bin. - const Vector &f0 = mel_banks.GetCenterFreqs(); - ans->Resize(n); - for (int32 i = 0; i < n; i++) { - BaseFloat fsq = f0(i) * f0(i); - BaseFloat fsub = fsq / (fsq + 1.6e5); - (*ans)(i) = fsub * fsub * ((fsq + 1.44e6) / (fsq + 9.61e6)); - } -} -// Compute LP coefficients from autocorrelation coefficients. -BaseFloat ComputeLpc(const VectorBase &autocorr_in, - Vector *lpc_out) { - int32 n = autocorr_in.Dim() - 1; - KALDI_ASSERT(lpc_out->Dim() == n); - Vector tmp(n); - BaseFloat ans = Durbin(n, autocorr_in.Data(), - lpc_out->Data(), - tmp.Data()); - if (ans <= 0.0) - KALDI_WARN << "Zero energy in LPC computation"; - return -Log(1.0 / ans); // forms the C0 value -} } // namespace kaldi diff --git a/src/feat/mel-computations.h b/src/feat/mel-computations.h index 7053da54f3a..6c56a9ab83d 100644 --- a/src/feat/mel-computations.h +++ b/src/feat/mel-computations.h @@ -1,7 +1,7 @@ // feat/mel-computations.h // Copyright 2009-2011 Phonexia s.r.o.; Microsoft Corporation -// 2016 Johns Hopkins University (author: Daniel Povey) +// 2016-2019 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -44,18 +44,14 @@ struct MelBanksOptions { int32 num_bins; // e.g. 25; number of triangular bins BaseFloat low_freq; // e.g. 20; lower frequency cutoff BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative - // ->added to the Nyquist frequency to get the cutoff. + // ->added to the Nyquist frequency to get the cutoff. BaseFloat vtln_low; // vtln lower cutoff of warping function. BaseFloat vtln_high; // vtln upper cutoff of warping function: if negative, added // to the Nyquist frequency to get the cutoff. bool debug_mel; - // htk_mode is a "hidden" config, it does not show up on command line. - // Enables more exact compatibibility with HTK, for testing purposes. Affects - // mel-energy flooring and reproduces a bug in HTK. - bool htk_mode; explicit MelBanksOptions(int num_bins = 25) : num_bins(num_bins), low_freq(20), high_freq(0), vtln_low(100), - vtln_high(-500), debug_mel(false), htk_mode(false) {} + vtln_high(-500), debug_mel(false) { } void Register(OptionsItf *opts) { opts->Register("num-mel-bins", &num_bins, @@ -87,10 +83,9 @@ class MelBanks { } static BaseFloat VtlnWarpFreq(BaseFloat vtln_low_cutoff, - BaseFloat vtln_high_cutoff, // discontinuities in warp func + BaseFloat vtln_high_cutoff, BaseFloat low_freq, - BaseFloat high_freq, // upper+lower frequency cutoffs in - // the mel computation + BaseFloat high_freq, BaseFloat vtln_warp_factor, BaseFloat freq); @@ -106,7 +101,7 @@ class MelBanks { const FrameExtractionOptions &frame_opts, BaseFloat vtln_warp_factor); - /// Compute Mel energies (note: not log enerties). + /// Compute Mel energies (note: not log energies). /// At input, "fft_energies" contains the FFT energies (not log). void Compute(const VectorBase &fft_energies, VectorBase *mel_energies_out) const; @@ -131,35 +126,9 @@ class MelBanks { std::vector > > bins_; bool debug_; - bool htk_mode_; }; -// Compute liftering coefficients (scaling on cepstral coeffs) -// coeffs are numbered slightly differently from HTK: the zeroth -// index is C0, which is not affected. -void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs); - - -// Durbin's recursion - converts autocorrelation coefficients to the LPC -// pTmp - temporal place [n] -// pAC - autocorrelation coefficients [n + 1] -// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i] * s[n-i]}}) -// F(z) = 1 / (1 - A(z)), 1 is not stored in the demoninator -// Returns log energy of residual (I think) -BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp); - -// Compute LP coefficients from autocorrelation coefficients. -// Returns log energy of residual (I think) -BaseFloat ComputeLpc(const VectorBase &autocorr_in, - Vector *lpc_out); - -void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst); - - - -void GetEqualLoudnessVector(const MelBanks &mel_banks, - Vector *ans); /// @} End of "addtogroup feat" } // namespace kaldi diff --git a/src/feat/online-feature-test.cc b/src/feat/online-feature-test.cc index 7ba6c7c32be..c5a2ae44ec7 100644 --- a/src/feat/online-feature-test.cc +++ b/src/feat/online-feature-test.cc @@ -195,55 +195,6 @@ void TestOnlineMfcc() { } } -void TestOnlinePlp() { - std::ifstream is("../feat/test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // the parametrization object - PlpOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.frame_opts.samp_freq = wave.SampFreq(); - op.mel_opts.low_freq = 0.0; - op.htk_compat = false; - op.use_energy = false; // C0 not energy. - Plp plp(op); - - // compute plp offline - Matrix plp_feats; - plp.Compute(waveform, 1.0, &plp_feats); // vtln not supported - - // compare - // The test waveform is about 1.44s long, so - // we try to break it into from 5 pieces to 9(not essential to do so) - for (int32 num_piece = 5; num_piece < 10; num_piece++) { - OnlinePlp online_plp(op); - std::vector piece_length(num_piece); - bool ret = RandomSplit(waveform.Dim(), &piece_length, num_piece); - KALDI_ASSERT(ret); - - int32 offset_start = 0; - for (int32 i = 0; i < num_piece; i++) { - Vector wave_piece( - waveform.Range(offset_start, piece_length[i])); - online_plp.AcceptWaveform(wave.SampFreq(), wave_piece); - offset_start += piece_length[i]; - } - online_plp.InputFinished(); - - Matrix online_plp_feats; - GetOutput(&online_plp, &online_plp_feats); - - AssertEqual(plp_feats, online_plp_feats); - } -} - void TestOnlineTransform() { std::ifstream is("../feat/test_data/test.wav", std::ios_base::binary); WaveData wave; @@ -332,9 +283,9 @@ void TestOnlineAppendFeature() { // The test waveform is about 1.44s long, so // we try to break it into from 5 pieces to 9(not essential to do so) for (int32 num_piece = 5; num_piece < 10; num_piece++) { - OnlineMfcc online_mfcc(mfcc_op); - OnlinePlp online_plp(plp_op); - OnlineAppendFeature online_mfcc_plp(&online_mfcc, &online_plp); + OnlineMfcc online_mfcc(mfcc_op), + online_mfcc2(mfcc_op); + OnlineAppendFeature online_mfcc_doubled(&online_mfcc, &online_mfcc2); std::vector piece_length(num_piece); bool ret = RandomSplit(waveform.Dim(), &piece_length, num_piece); @@ -344,32 +295,32 @@ void TestOnlineAppendFeature() { Vector wave_piece( waveform.Range(offset_start, piece_length[i])); online_mfcc.AcceptWaveform(wave.SampFreq(), wave_piece); - online_plp.AcceptWaveform(wave.SampFreq(), wave_piece); + online_mfcc2.AcceptWaveform(wave.SampFreq(), wave_piece); offset_start += piece_length[i]; } online_mfcc.InputFinished(); - online_plp.InputFinished(); + online_mfcc2.InputFinished(); - Matrix online_mfcc_plp_feats; - GetOutput(&online_mfcc_plp, &online_mfcc_plp_feats); + Matrix online_mfcc_doubled_feats; + GetOutput(&online_mfcc_doubled, &online_mfcc_doubled_feats); - // compare mfcc_feats & plp_features with online_mfcc_plp_feats - KALDI_ASSERT(mfcc_feats.NumRows() == online_mfcc_plp_feats.NumRows() - && plp_feats.NumRows() == online_mfcc_plp_feats.NumRows() + // compare mfcc_feats & plp_features with online_mfcc_doubled_feats + KALDI_ASSERT(mfcc_feats.NumRows() == online_mfcc_doubled_feats.NumRows() + && plp_feats.NumRows() == online_mfcc_doubled_feats.NumRows() && mfcc_feats.NumCols() + plp_feats.NumCols() - == online_mfcc_plp_feats.NumCols()); - for (MatrixIndexT i = 0; i < online_mfcc_plp_feats.NumRows(); i++) { + == online_mfcc_doubled_feats.NumCols()); + for (MatrixIndexT i = 0; i < online_mfcc_doubled_feats.NumRows(); i++) { for (MatrixIndexT j = 0; j < mfcc_feats.NumCols(); j++) { - KALDI_ASSERT(std::abs(mfcc_feats(i, j) - online_mfcc_plp_feats(i, j)) + KALDI_ASSERT(std::abs(mfcc_feats(i, j) - online_mfcc_doubled_feats(i, j)) < 0.0001*std::max(1.0, static_cast(std::abs(mfcc_feats(i, j)) - + std::abs(online_mfcc_plp_feats(i, j))))); + + std::abs(online_mfcc_doubled_feats(i, j))))); } for (MatrixIndexT k = 0; k < plp_feats.NumCols(); k++) { KALDI_ASSERT( std::abs(plp_feats(i, k) - - online_mfcc_plp_feats(i, mfcc_feats.NumCols() + k)) + online_mfcc_doubled_feats(i, mfcc_feats.NumCols() + k)) < 0.0001*std::max(1.0, static_cast(std::abs(plp_feats(i, k)) - +std::abs(online_mfcc_plp_feats(i, mfcc_feats.NumCols() + k))))); + +std::abs(online_mfcc_doubled_feats(i, mfcc_feats.NumCols() + k))))); } } } diff --git a/src/feat/online-feature.cc b/src/feat/online-feature.cc index 813e7b16f0c..138dabe2236 100644 --- a/src/feat/online-feature.cc +++ b/src/feat/online-feature.cc @@ -69,9 +69,13 @@ void OnlineGenericBaseFeature::GetFrame(int32 frame, template OnlineGenericBaseFeature::OnlineGenericBaseFeature( const typename C::Options &opts): - computer_(opts), window_function_(computer_.GetFrameOptions()), - input_finished_(false), waveform_offset_(0), - features_(opts.frame_opts.max_feature_vectors) { } + computer_(opts), + features_(opts.frame_opts.max_feature_vectors), + input_finished_(false), + waveform_offset_(0) { + InitFeatureWindowFunction(computer_.GetFrameOptions(), + &window_function_); +} template void OnlineGenericBaseFeature::AcceptWaveform(BaseFloat sampling_rate, @@ -105,17 +109,14 @@ void OnlineGenericBaseFeature::ComputeFeatures() { KALDI_ASSERT(num_frames_new >= num_frames_old); Vector window; - bool need_raw_log_energy = computer_.NeedRawLogEnergy(); for (int32 frame = num_frames_old; frame < num_frames_new; frame++) { - BaseFloat raw_log_energy = 0.0; ExtractWindow(waveform_offset_, waveform_remainder_, frame, - frame_opts, window_function_, &window, - need_raw_log_energy ? &raw_log_energy : NULL); + frame_opts, window_function_, &window); Vector *this_feature = new Vector(computer_.Dim(), kUndefined); // note: this online feature-extraction code does not support VTLN. BaseFloat vtln_warp = 1.0; - computer_.Compute(raw_log_energy, vtln_warp, &window, this_feature); + computer_.Compute(vtln_warp, &window, this_feature); features_.PushBack(this_feature); } // OK, we will now discard any portion of the signal that will not be @@ -142,7 +143,6 @@ void OnlineGenericBaseFeature::ComputeFeatures() { // instantiate the templates defined here for MFCC, PLP and filterbank classes. template class OnlineGenericBaseFeature; -template class OnlineGenericBaseFeature; template class OnlineGenericBaseFeature; diff --git a/src/feat/online-feature.h b/src/feat/online-feature.h index d47a6b13e9b..0ddc2601dec 100644 --- a/src/feat/online-feature.h +++ b/src/feat/online-feature.h @@ -32,7 +32,6 @@ #include "base/kaldi-error.h" #include "feat/feature-functions.h" #include "feat/feature-mfcc.h" -#include "feat/feature-plp.h" #include "feat/feature-fbank.h" #include "itf/online-feature-itf.h" @@ -72,7 +71,7 @@ class RecyclingVector { /// This is a templated class for online feature extraction; -/// it's templated on a class like MfccComputer or PlpComputer +/// it's templated on a class like MfccComputer /// that does the basic feature extraction. template class OnlineGenericBaseFeature: public OnlineBaseFeature { @@ -127,11 +126,11 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature { // waveform_remainder_ while incrementing waveform_offset_ by the same amount. void ComputeFeatures(); - C computer_; // class that does the MFCC or PLP or filterbank computation + C computer_; // class that does the MFCC or filterbank computation - FeatureWindowFunction window_function_; + Vector window_function_; - // features_ is the Mfcc or Plp or Fbank features that we have already computed. + // features_ is the Mfcc or Fbank features that we have already computed. RecyclingVector features_; @@ -153,7 +152,6 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature { }; typedef OnlineGenericBaseFeature OnlineMfcc; -typedef OnlineGenericBaseFeature OnlinePlp; typedef OnlineGenericBaseFeature OnlineFbank; @@ -594,7 +592,7 @@ class OnlineCacheFeature: public OnlineFeatureInterface { /// This online-feature class implements combination of two feature -/// streams (such as pitch, plp) into one stream. +/// streams (such as pitch) into one stream. class OnlineAppendFeature: public OnlineFeatureInterface { public: virtual int32 Dim() const { return src1_->Dim() + src2_->Dim(); } diff --git a/src/feat/pitch-functions-test.cc b/src/feat/pitch-functions-test.cc index 0e481c18674..e3953acb884 100644 --- a/src/feat/pitch-functions-test.cc +++ b/src/feat/pitch-functions-test.cc @@ -25,7 +25,6 @@ #include #include "base/kaldi-math.h" -#include "feat/feature-plp.h" #include "feat/pitch-functions.h" #include "feat/wave-reader.h" #include "sys/stat.h" diff --git a/src/feat/wave-reader.cc b/src/feat/wave-reader.cc index f8259a3a82e..bd35b1cff43 100644 --- a/src/feat/wave-reader.cc +++ b/src/feat/wave-reader.cc @@ -308,7 +308,11 @@ void WaveData::Read(std::istream &is) { uint16 *data_ptr = reinterpret_cast(&buffer[0]); - // The matrix is arranged row per channel, column per sample. + // Scale the wave data to the range [-1, 1]. Prior to kaldi-10, + // it was in the range [-327680.0, 32768.0]. + const BaseFloat scale = 1.0 / 32768.0; + + // The row-indexes are channels; column-indexes are samples. data_.Resize(header.NumChannels(), buffer.size() / header.BlockAlign()); for (uint32 i = 0; i < data_.NumCols(); ++i) { @@ -316,7 +320,7 @@ void WaveData::Read(std::istream &is) { int16 k = *data_ptr++; if (header.ReverseBytes()) KALDI_SWAP2(k); - data_(j, i) = k; + data_(j, i) = k * scale; } } } @@ -358,9 +362,13 @@ void WaveData::Write(std::ostream &os) const { int32 stride = data_.Stride(); int num_clipped = 0; + + // This scaling factor is because we are writing 16-bit data. + const BaseFloat scale = 32768.0; + for (int32 i = 0; i < num_samp; i++) { for (int32 j = 0; j < num_chan; j++) { - int32 elem = static_cast(trunc(data_ptr[j * stride + i])); + int32 elem = static_cast(trunc(data_ptr[j * stride + i] * scale)); int16 elem_16 = static_cast(elem); if (elem < std::numeric_limits::min()) { elem_16 = std::numeric_limits::min(); diff --git a/src/feat/wave-reader.h b/src/feat/wave-reader.h index 7ba981c2c24..6c7fb5b5258 100644 --- a/src/feat/wave-reader.h +++ b/src/feat/wave-reader.h @@ -2,7 +2,7 @@ // Copyright 2009-2011 Karel Vesely; Microsoft Corporation // 2013 Florent Masson -// 2013 Johns Hopkins University (author: Daniel Povey) +// 2013-2019 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -57,10 +57,6 @@ namespace kaldi { -/// For historical reasons, we scale waveforms to the range -/// (2^15-1)*[-1, 1], not the usual default DSP range [-1, 1]. -const BaseFloat kWaveSampleMax = 32768.0; - /// This class reads and hold wave file header information. class WaveInfo { public: @@ -121,6 +117,8 @@ class WaveData { // This function returns the wave data-- it's in a matrix // becase there may be multiple channels. In the normal case // there's just one channel so Data() will have one row. + // This data will be in the range [-1, 1]. This is a difference + // from pre-kaldi10. const Matrix &Data() const { return data_; } BaseFloat SampFreq() const { return samp_freq_; } diff --git a/src/featbin/Makefile b/src/featbin/Makefile index 861ba3f7a93..1067244b2db 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -8,7 +8,7 @@ BINFILES = add-deltas add-deltas-sdc append-post-to-feats \ compose-transforms compute-and-process-kaldi-pitch-feats \ compute-cmvn-stats compute-cmvn-stats-two-channel \ compute-fbank-feats compute-kaldi-pitch-feats compute-mfcc-feats \ - compute-plp-feats compute-spectrogram-feats concat-feats copy-feats \ + concat-feats copy-feats \ copy-feats-to-htk copy-feats-to-sphinx extend-transform-dim \ extract-feature-segments extract-segments feat-to-dim \ feat-to-len fmpe-acc-stats fmpe-apply-transform fmpe-est \ @@ -26,6 +26,6 @@ TESTFILES = ADDLIBS = ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/featbin/compute-plp-feats.cc b/src/featbin/compute-plp-feats.cc deleted file mode 100644 index 3e9fe9d7423..00000000000 --- a/src/featbin/compute-plp-feats.cc +++ /dev/null @@ -1,184 +0,0 @@ -// featbin/compute-plp-feats.cc - -// Copyright 2009-2012 Microsoft Corporation -// Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "feat/feature-plp.h" -#include "feat/wave-reader.h" - - -int main(int argc, char *argv[]) { - try { - using namespace kaldi; - const char *usage = - "Create PLP feature files.\n" - "Usage: compute-plp-feats [options...] \n"; - - // construct all the global objects - ParseOptions po(usage); - PlpOptions plp_opts; - bool subtract_mean = false; - BaseFloat vtln_warp = 1.0; - std::string vtln_map_rspecifier; - std::string utt2spk_rspecifier; - int32 channel = -1; - BaseFloat min_duration = 0.0; - // Define defaults for gobal options - std::string output_format = "kaldi"; - - // Register the options - po.Register("output-format", &output_format, "Format of the output " - "files [kaldi, htk]"); - po.Register("subtract-mean", &subtract_mean, "Subtract mean of each " - "feature file [CMS]. "); - po.Register("vtln-warp", &vtln_warp, "Vtln warp factor (only applicable " - "if vtln-map not specified)"); - po.Register("vtln-map", &vtln_map_rspecifier, "Map from utterance or " - "speaker-id to vtln warp factor (rspecifier)"); - po.Register("utt2spk", &utt2spk_rspecifier, "Utterance to speaker-id " - "map (if doing VTLN and you have warps per speaker)"); - po.Register("channel", &channel, "Channel to extract (-1 -> expect mono, " - "0 -> left, 1 -> right)"); - po.Register("min-duration", &min_duration, "Minimum duration of segments " - "to process (in seconds)."); - - plp_opts.Register(&po); - - po.Read(argc, argv); - - if (po.NumArgs() != 2) { - po.PrintUsage(); - exit(1); - } - - std::string wav_rspecifier = po.GetArg(1); - - std::string output_wspecifier = po.GetArg(2); - - Plp plp(plp_opts); - - SequentialTableReader reader(wav_rspecifier); - BaseFloatMatrixWriter kaldi_writer; // typedef to TableWriter. - TableWriter htk_writer; - - if (utt2spk_rspecifier != "") - KALDI_ASSERT(vtln_map_rspecifier != "" && "the utt2spk option is only " - "needed if the vtln-map option is used."); - RandomAccessBaseFloatReaderMapped vtln_map_reader(vtln_map_rspecifier, - utt2spk_rspecifier); - - if (output_format == "kaldi") { - if (!kaldi_writer.Open(output_wspecifier)) - KALDI_ERR << "Could not initialize output with wspecifier " - << output_wspecifier; - } else if (output_format == "htk") { - if (!htk_writer.Open(output_wspecifier)) - KALDI_ERR << "Could not initialize output with wspecifier " - << output_wspecifier; - } else { - KALDI_ERR << "Invalid output_format string " << output_format; - } - - int32 num_utts = 0, num_success = 0; - for (; !reader.Done(); reader.Next()) { - num_utts++; - std::string utt = reader.Key(); - const WaveData &wave_data = reader.Value(); - if (wave_data.Duration() < min_duration) { - KALDI_WARN << "File: " << utt << " is too short (" - << wave_data.Duration() << " sec): producing no output."; - continue; - } - int32 num_chan = wave_data.Data().NumRows(), this_chan = channel; - { // This block works out the channel (0=left, 1=right...) - KALDI_ASSERT(num_chan > 0); // should have been caught in - // reading code if no channels. - if (channel == -1) { - this_chan = 0; - if (num_chan != 1) - KALDI_WARN << "Channel not specified but you have data with " - << num_chan << " channels; defaulting to zero"; - } else { - if (this_chan >= num_chan) { - KALDI_WARN << "File with id " << utt << " has " - << num_chan << " channels but you specified channel " - << channel << ", producing no output."; - continue; - } - } - } - BaseFloat vtln_warp_local; // Work out VTLN warp factor. - if (vtln_map_rspecifier != "") { - if (!vtln_map_reader.HasKey(utt)) { - KALDI_WARN << "No vtln-map entry for utterance-id (or speaker-id) " - << utt; - continue; - } - vtln_warp_local = vtln_map_reader.Value(utt); - } else { - vtln_warp_local = vtln_warp; - } - - SubVector waveform(wave_data.Data(), this_chan); - Matrix features; - try { - plp.ComputeFeatures(waveform, wave_data.SampFreq(), vtln_warp_local, &features); - } catch (...) { - KALDI_WARN << "Failed to compute features for utterance " - << utt; - continue; - } - if (subtract_mean) { - Vector mean(features.NumCols()); - mean.AddRowSumMat(1.0, features); - mean.Scale(1.0 / features.NumRows()); - for (size_t i = 0; i < features.NumRows(); i++) - features.Row(i).AddVec(-1.0, mean); - } - if (output_format == "kaldi") { - kaldi_writer.Write(utt, features); - } else { - std::pair, HtkHeader> p; - p.first.Resize(features.NumRows(), features.NumCols()); - p.first.CopyFromMat(features); - HtkHeader header = { - features.NumRows(), - 100000, // 10ms shift - static_cast(sizeof(float)*features.NumCols()), - 013 | // PLP - 020000 // C0 [no option currently to use energy in PLP. - }; - p.second = header; - htk_writer.Write(utt, p); - } - if (num_utts % 10 == 0) - KALDI_LOG << "Processed " << num_utts << " utterances"; - KALDI_VLOG(2) << "Processed features for key " << utt; - num_success++; - } - KALDI_LOG << " Done " << num_success << " out of " << num_utts - << " utterances."; - return (num_success != 0 ? 0 : 1); - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - diff --git a/src/featbin/compute-spectrogram-feats.cc b/src/featbin/compute-spectrogram-feats.cc deleted file mode 100644 index 3b40a6fa5c7..00000000000 --- a/src/featbin/compute-spectrogram-feats.cc +++ /dev/null @@ -1,158 +0,0 @@ -// featbin/compute-spectrogram-feats.cc - -// Copyright 2009-2011 Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "feat/feature-spectrogram.h" -#include "feat/wave-reader.h" - - -int main(int argc, char *argv[]) { - try { - using namespace kaldi; - const char *usage = - "Create spectrogram feature files.\n" - "Usage: compute-spectrogram-feats [options...] \n"; - - // construct all the global objects - ParseOptions po(usage); - SpectrogramOptions spec_opts; - bool subtract_mean = false; - int32 channel = -1; - BaseFloat min_duration = 0.0; - // Define defaults for gobal options - std::string output_format = "kaldi"; - - // Register the option struct - spec_opts.Register(&po); - // Register the options - po.Register("output-format", &output_format, "Format of the output files [kaldi, htk]"); - po.Register("subtract-mean", &subtract_mean, "Subtract mean of each feature file [CMS]; not recommended to do it this way. "); - po.Register("channel", &channel, "Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right)"); - po.Register("min-duration", &min_duration, "Minimum duration of segments to process (in seconds)."); - - // OPTION PARSING .......................................................... - // - - // parse options (+filling the registered variables) - po.Read(argc, argv); - - if (po.NumArgs() != 2) { - po.PrintUsage(); - exit(1); - } - - std::string wav_rspecifier = po.GetArg(1); - - std::string output_wspecifier = po.GetArg(2); - - Spectrogram spec(spec_opts); - - SequentialTableReader reader(wav_rspecifier); - BaseFloatMatrixWriter kaldi_writer; // typedef to TableWriter. - TableWriter htk_writer; - - if (output_format == "kaldi") { - if (!kaldi_writer.Open(output_wspecifier)) - KALDI_ERR << "Could not initialize output with wspecifier " - << output_wspecifier; - } else if (output_format == "htk") { - if (!htk_writer.Open(output_wspecifier)) - KALDI_ERR << "Could not initialize output with wspecifier " - << output_wspecifier; - } else { - KALDI_ERR << "Invalid output_format string " << output_format; - } - - int32 num_utts = 0, num_success = 0; - for (; !reader.Done(); reader.Next()) { - num_utts++; - std::string utt = reader.Key(); - const WaveData &wave_data = reader.Value(); - if (wave_data.Duration() < min_duration) { - KALDI_WARN << "File: " << utt << " is too short (" - << wave_data.Duration() << " sec): producing no output."; - continue; - } - int32 num_chan = wave_data.Data().NumRows(), this_chan = channel; - { // This block works out the channel (0=left, 1=right...) - KALDI_ASSERT(num_chan > 0); // should have been caught in - // reading code if no channels. - if (channel == -1) { - this_chan = 0; - if (num_chan != 1) - KALDI_WARN << "Channel not specified but you have data with " - << num_chan << " channels; defaulting to zero"; - } else { - if (this_chan >= num_chan) { - KALDI_WARN << "File with id " << utt << " has " - << num_chan << " channels but you specified channel " - << channel << ", producing no output."; - continue; - } - } - } - - SubVector waveform(wave_data.Data(), this_chan); - Matrix features; - try { - spec.ComputeFeatures(waveform, wave_data.SampFreq(), 1.0, &features); - } catch (...) { - KALDI_WARN << "Failed to compute features for utterance " - << utt; - continue; - } - if (subtract_mean) { - Vector mean(features.NumCols()); - mean.AddRowSumMat(1.0, features); - mean.Scale(1.0 / features.NumRows()); - for (int32 i = 0; i < features.NumRows(); i++) - features.Row(i).AddVec(-1.0, mean); - } - if (output_format == "kaldi") { - kaldi_writer.Write(utt, features); - } else { - std::pair, HtkHeader> p; - p.first.Resize(features.NumRows(), features.NumCols()); - p.first.CopyFromMat(features); - int32 frame_shift = spec_opts.frame_opts.frame_shift_ms * 10000; - HtkHeader header = { - features.NumRows(), - frame_shift, - static_cast(sizeof(float)*features.NumCols()), - 007 | 020000 - }; - p.second = header; - htk_writer.Write(utt, p); - } - if(num_utts % 10 == 0) - KALDI_LOG << "Processed " << num_utts << " utterances"; - KALDI_VLOG(2) << "Processed features for key " << utt; - num_success++; - } - KALDI_LOG << " Done " << num_success << " out of " << num_utts - << " utterances."; - return (num_success != 0 ? 0 : 1); - } catch(const std::exception& e) { - std::cerr << e.what(); - return -1; - } - return 0; -} - diff --git a/src/fgmmbin/fgmm-global-info.cc b/src/fgmmbin/fgmm-global-info.cc index e00384fe13f..867db3bdc50 100644 --- a/src/fgmmbin/fgmm-global-info.cc +++ b/src/fgmmbin/fgmm-global-info.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/full-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { diff --git a/src/fgmmbin/fgmm-gselect.cc b/src/fgmmbin/fgmm-gselect.cc index ab36af74275..3d962972127 100644 --- a/src/fgmmbin/fgmm-gselect.cc +++ b/src/fgmmbin/fgmm-gselect.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/full-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { diff --git a/src/gmm/Makefile b/src/gmm/Makefile index caee6734afe..9b770bb4845 100644 --- a/src/gmm/Makefile +++ b/src/gmm/Makefile @@ -9,13 +9,13 @@ TESTFILES = diag-gmm-test mle-diag-gmm-test full-gmm-test mle-full-gmm-test \ OBJFILES = diag-gmm.o diag-gmm-normal.o mle-diag-gmm.o am-diag-gmm.o \ mle-am-diag-gmm.o full-gmm.o full-gmm-normal.o mle-full-gmm.o \ - model-common.o decodable-am-diag-gmm.o model-test-common.o \ - ebw-diag-gmm.o indirect-diff-diag-gmm.o + model-common.o decodable-am-diag-gmm.o model-test-common.o \ + ebw-diag-gmm.o indirect-diff-diag-gmm.o LIBNAME = kaldi-gmm ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a diff --git a/src/gmm/decodable-am-diag-gmm.h b/src/gmm/decodable-am-diag-gmm.h index 745b4f61b14..f2e03005708 100644 --- a/src/gmm/decodable-am-diag-gmm.h +++ b/src/gmm/decodable-am-diag-gmm.h @@ -26,11 +26,9 @@ #include "base/kaldi-common.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "itf/decodable-itf.h" -#include "transform/regression-tree.h" -#include "transform/regtree-fmllr-diag-gmm.h" -#include "transform/regtree-mllr-diag-gmm.h" + namespace kaldi { @@ -46,13 +44,13 @@ class DecodableAmDiagGmmUnmapped : public DecodableInterface { public: /// If you set log_sum_exp_prune to a value greater than 0 it will prune /// in the LogSumExp operation (larger = more exact); I suggest 5. - /// This is advisable if it's spending a long time doing exp - /// operations. + /// This is advisable if it's spending a long time doing exp + /// operations. DecodableAmDiagGmmUnmapped(const AmDiagGmm &am, const Matrix &feats, BaseFloat log_sum_exp_prune = -1.0): acoustic_model_(am), feature_matrix_(feats), - previous_frame_(-1), log_sum_exp_prune_(log_sum_exp_prune), + previous_frame_(-1), log_sum_exp_prune_(log_sum_exp_prune), data_squared_(feats.NumCols()) { ResetLogLikeCache(); } @@ -63,7 +61,7 @@ class DecodableAmDiagGmmUnmapped : public DecodableInterface { return LogLikelihoodZeroBased(frame, state_index - 1); } virtual int32 NumFramesReady() const { return feature_matrix_.NumRows(); } - + // Indices are one-based! This is for compatibility with OpenFst. virtual int32 NumIndices() const { return acoustic_model_.NumPdfs(); } @@ -98,7 +96,7 @@ class DecodableAmDiagGmmUnmapped : public DecodableInterface { class DecodableAmDiagGmm: public DecodableAmDiagGmmUnmapped { public: DecodableAmDiagGmm(const AmDiagGmm &am, - const TransitionModel &tm, + const Transitions &tm, const Matrix &feats, BaseFloat log_sum_exp_prune = -1.0) : DecodableAmDiagGmmUnmapped(am, feats, log_sum_exp_prune), @@ -107,21 +105,21 @@ class DecodableAmDiagGmm: public DecodableAmDiagGmmUnmapped { // Note, frames are numbered from zero. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { return LogLikelihoodZeroBased(frame, - trans_model_.TransitionIdToPdf(tid)); + trans_model_.TransitionIdToPdfFast(tid)); } // Indices are one-based! This is for compatibility with OpenFst. virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } - const TransitionModel *TransModel() { return &trans_model_; } + const Transitions *TransModel() { return &trans_model_; } private: // want to access public to have pdf id information - const TransitionModel &trans_model_; // for tid to pdf mapping + const Transitions &trans_model_; // for tid to pdf mapping KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmDiagGmm); }; class DecodableAmDiagGmmScaled: public DecodableAmDiagGmmUnmapped { public: DecodableAmDiagGmmScaled(const AmDiagGmm &am, - const TransitionModel &tm, + const Transitions &tm, const Matrix &feats, BaseFloat scale, BaseFloat log_sum_exp_prune = -1.0): @@ -131,7 +129,7 @@ class DecodableAmDiagGmmScaled: public DecodableAmDiagGmmUnmapped { // This version of the initializer takes ownership of the pointer // "feats" and will delete it when this class is destroyed. DecodableAmDiagGmmScaled(const AmDiagGmm &am, - const TransitionModel &tm, + const Transitions &tm, BaseFloat scale, BaseFloat log_sum_exp_prune, Matrix *feats): @@ -140,20 +138,20 @@ class DecodableAmDiagGmmScaled: public DecodableAmDiagGmmUnmapped { // Note, frames are numbered from zero but transition-ids from one. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - return scale_*LogLikelihoodZeroBased(frame, - trans_model_.TransitionIdToPdf(tid)); + return scale_ * LogLikelihoodZeroBased( + frame, trans_model_.TransitionIdToPdfFast(tid)); } // Indices are one-based! This is for compatibility with OpenFst. virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } - const TransitionModel *TransModel() { return &trans_model_; } + const Transitions *TransModel() { return &trans_model_; } virtual ~DecodableAmDiagGmmScaled() { delete delete_feats_; } - + private: // want to access it public to have pdf id information - const TransitionModel &trans_model_; // for transition-id to pdf mapping + const Transitions &trans_model_; // for transition-id to pdf mapping BaseFloat scale_; Matrix *delete_feats_; KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmDiagGmmScaled); diff --git a/src/gmmbin/Makefile b/src/gmmbin/Makefile index 82d10abe9ce..1e926e88432 100644 --- a/src/gmmbin/Makefile +++ b/src/gmmbin/Makefile @@ -6,25 +6,24 @@ include ../kaldi.mk BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \ gmm-decode-faster gmm-decode-simple gmm-align-compiled \ - gmm-sum-accs gmm-est-regtree-fmllr gmm-acc-stats-twofeats \ + gmm-sum-accs gmm-acc-stats-twofeats \ gmm-acc-stats gmm-init-lvtln gmm-est-lvtln-trans gmm-train-lvtln-special \ gmm-acc-mllt gmm-mixup gmm-init-model gmm-transform-means \ - gmm-make-regtree gmm-decode-faster-regtree-fmllr gmm-post-to-gpost \ - gmm-est-fmllr-gpost gmm-est-fmllr gmm-est-regtree-fmllr-ali \ - gmm-est-regtree-mllr gmm-compute-likes \ - gmm-decode-faster-regtree-mllr gmm-latgen-simple \ + gmm-post-to-gpost \ + gmm-est-fmllr-gpost gmm-est-fmllr gmm-compute-likes \ + gmm-latgen-simple \ gmm-rescore-lattice gmm-decode-biglm-faster \ gmm-est-gaussians-ebw gmm-est-weights-ebw gmm-latgen-faster gmm-copy \ gmm-global-acc-stats gmm-global-est gmm-global-sum-accs gmm-gselect \ gmm-latgen-biglm-faster gmm-ismooth-stats gmm-global-get-frame-likes \ gmm-global-est-fmllr gmm-global-to-fgmm gmm-global-acc-stats-twofeats \ - gmm-global-copy gmm-fmpe-acc-stats gmm-acc-stats2 gmm-init-model-flat gmm-info \ + gmm-global-copy gmm-acc-stats2 gmm-init-model-flat gmm-info \ gmm-get-stats-deriv gmm-est-rescale gmm-boost-silence \ gmm-basis-fmllr-accs gmm-basis-fmllr-training gmm-est-basis-fmllr \ gmm-est-map gmm-adapt-map gmm-latgen-map gmm-basis-fmllr-accs-gpost \ gmm-est-basis-fmllr-gpost gmm-latgen-faster-parallel \ - gmm-est-fmllr-raw gmm-est-fmllr-raw-gpost gmm-global-init-from-feats \ - gmm-global-info gmm-latgen-faster-regtree-fmllr gmm-est-fmllr-global \ + gmm-global-init-from-feats \ + gmm-global-info gmm-est-fmllr-global \ gmm-acc-mllt-global gmm-transform-means-global gmm-global-get-post \ gmm-global-gselect-to-post gmm-global-est-lvtln-trans gmm-init-biphone @@ -38,7 +37,7 @@ ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/gmmbin/gmm-acc-mllt-global.cc b/src/gmmbin/gmm-acc-mllt-global.cc index bed91c053d3..ac3ec2237c9 100644 --- a/src/gmmbin/gmm-acc-mllt-global.cc +++ b/src/gmmbin/gmm-acc-mllt-global.cc @@ -23,7 +23,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/mllt.h" #include "hmm/posterior.h" diff --git a/src/gmmbin/gmm-acc-mllt.cc b/src/gmmbin/gmm-acc-mllt.cc index 6e57f082a62..be0d501b3f5 100644 --- a/src/gmmbin/gmm-acc-mllt.cc +++ b/src/gmmbin/gmm-acc-mllt.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/mllt.h" #include "hmm/posterior.h" @@ -58,7 +58,7 @@ int main(int argc, char *argv[]) { typedef kaldi::int32 int32; AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/gmmbin/gmm-acc-stats-ali.cc b/src/gmmbin/gmm-acc-stats-ali.cc index 5552d45738e..b20212b4771 100644 --- a/src/gmmbin/gmm-acc-stats-ali.cc +++ b/src/gmmbin/gmm-acc-stats-ali.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" @@ -53,7 +53,7 @@ int main(int argc, char *argv[]) { accs_wxfilename = po.GetArg(4); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/gmmbin/gmm-acc-stats-twofeats.cc b/src/gmmbin/gmm-acc-stats-twofeats.cc index 05f94ff5ef6..3bae910233b 100644 --- a/src/gmmbin/gmm-acc-stats-twofeats.cc +++ b/src/gmmbin/gmm-acc-stats-twofeats.cc @@ -23,7 +23,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" #include "hmm/posterior.h" @@ -59,7 +59,7 @@ int main(int argc, char *argv[]) { typedef kaldi::int32 int32; AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/gmmbin/gmm-acc-stats.cc b/src/gmmbin/gmm-acc-stats.cc index e213fffdeff..beeee8ec758 100644 --- a/src/gmmbin/gmm-acc-stats.cc +++ b/src/gmmbin/gmm-acc-stats.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" #include "hmm/posterior.h" @@ -59,7 +59,7 @@ int main(int argc, char *argv[]) { AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/gmmbin/gmm-acc-stats2.cc b/src/gmmbin/gmm-acc-stats2.cc index 70730c8ca7d..30f3ff80e10 100644 --- a/src/gmmbin/gmm-acc-stats2.cc +++ b/src/gmmbin/gmm-acc-stats2.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" #include "hmm/posterior.h" @@ -62,7 +62,7 @@ int main(int argc, char *argv[]) { AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_rxfilename, &binary); diff --git a/src/gmmbin/gmm-adapt-map.cc b/src/gmmbin/gmm-adapt-map.cc index ec3eb8cea9b..30fbc1e8d73 100644 --- a/src/gmmbin/gmm-adapt-map.cc +++ b/src/gmmbin/gmm-adapt-map.cc @@ -25,7 +25,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" #include "hmm/posterior.h" @@ -72,7 +72,7 @@ int main(int argc, char *argv[]) { MapAmDiagGmmWriter map_am_writer(map_am_wspecifier); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input is(model_filename, &binary); diff --git a/src/gmmbin/gmm-align-compiled.cc b/src/gmmbin/gmm-align-compiled.cc index 36349774773..02beb372b60 100644 --- a/src/gmmbin/gmm-align-compiled.cc +++ b/src/gmmbin/gmm-align-compiled.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-utils.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" @@ -77,7 +77,7 @@ int main(int argc, char *argv[]) { alignment_wspecifier = po.GetArg(4), scores_wspecifier = po.GetOptArg(5); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-align.cc b/src/gmmbin/gmm-align.cc index c9c2fde11f6..e84a90cdb9a 100644 --- a/src/gmmbin/gmm-align.cc +++ b/src/gmmbin/gmm-align.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-utils.h" #include "decoder/decoder-wrappers.h" #include "decoder/training-graph-compiler.h" @@ -73,7 +73,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; ReadKaldiObject(tree_in_filename, &ctx_dep); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-basis-fmllr-accs-gpost.cc b/src/gmmbin/gmm-basis-fmllr-accs-gpost.cc index f8f7b5d3433..9001b64ae82 100644 --- a/src/gmmbin/gmm-basis-fmllr-accs-gpost.cc +++ b/src/gmmbin/gmm-basis-fmllr-accs-gpost.cc @@ -26,7 +26,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "transform/basis-fmllr-diag-gmm.h" #include "hmm/posterior.h" @@ -34,7 +34,7 @@ using std::vector; namespace kaldi { void AccumulateForUtterance(const Matrix &feats, const GaussPost &gpost, - const TransitionModel &trans_model, + const Transitions &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats) { for (size_t i = 0; i < gpost.size(); i++) { @@ -81,7 +81,7 @@ int main(int argc, char *argv[]) { gpost_rspecifier = po.GetArg(3), accs_wspecifier = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-basis-fmllr-accs.cc b/src/gmmbin/gmm-basis-fmllr-accs.cc index 58b365318f0..d78d652dfc5 100644 --- a/src/gmmbin/gmm-basis-fmllr-accs.cc +++ b/src/gmmbin/gmm-basis-fmllr-accs.cc @@ -26,7 +26,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "transform/basis-fmllr-diag-gmm.h" #include "hmm/posterior.h" @@ -34,7 +34,7 @@ using std::vector; namespace kaldi { void AccumulateForUtterance(const Matrix &feats, const Posterior &post, - const TransitionModel &trans_model, + const Transitions &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats) { Posterior pdf_post; @@ -82,7 +82,7 @@ int main(int argc, char *argv[]) { post_rspecifier = po.GetArg(3), accs_wspecifier = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-basis-fmllr-training.cc b/src/gmmbin/gmm-basis-fmllr-training.cc index 3d93c3ca877..d433f6903f6 100644 --- a/src/gmmbin/gmm-basis-fmllr-training.cc +++ b/src/gmmbin/gmm-basis-fmllr-training.cc @@ -25,7 +25,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "transform/basis-fmllr-diag-gmm.h" @@ -53,7 +53,7 @@ int main(int argc, char *argv[]) { model_rxfilename = po.GetArg(1), basis_wspecifier = po.GetArg(2); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-boost-silence.cc b/src/gmmbin/gmm-boost-silence.cc index 7c9e4c82806..ef57f1190cb 100644 --- a/src/gmmbin/gmm-boost-silence.cc +++ b/src/gmmbin/gmm-boost-silence.cc @@ -19,7 +19,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/am-diag-gmm.h" int main(int argc, char *argv[]) { @@ -67,7 +67,7 @@ int main(int argc, char *argv[]) { } AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_rxfilename, &binary_read); diff --git a/src/gmmbin/gmm-compute-likes.cc b/src/gmmbin/gmm-compute-likes.cc index 78c813e1c3b..c7101f1a9ae 100644 --- a/src/gmmbin/gmm-compute-likes.cc +++ b/src/gmmbin/gmm-compute-likes.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "base/timer.h" @@ -55,7 +55,7 @@ int main(int argc, char *argv[]) { AmDiagGmm am_gmm; { bool binary; - TransitionModel trans_model; // not needed. + Transitions trans_model; // not needed. Input ki(model_in_filename, &binary); trans_model.Read(ki.Stream(), binary); am_gmm.Read(ki.Stream(), binary); diff --git a/src/gmmbin/gmm-copy.cc b/src/gmmbin/gmm-copy.cc index 0b33bc6d67f..bd42aeb2a25 100644 --- a/src/gmmbin/gmm-copy.cc +++ b/src/gmmbin/gmm-copy.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { @@ -54,7 +54,7 @@ int main(int argc, char *argv[]) { model_out_filename = po.GetArg(2); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-decode-biglm-faster.cc b/src/gmmbin/gmm-decode-biglm-faster.cc index 6e47d68de3c..9e7845e7849 100644 --- a/src/gmmbin/gmm-decode-biglm-faster.cc +++ b/src/gmmbin/gmm-decode-biglm-faster.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/biglm-faster-decoder.h" #include "gmm/decodable-am-diag-gmm.h" @@ -111,7 +111,7 @@ int main(int argc, char *argv[]) alignment_wspecifier = po.GetOptArg(7), lattice_wspecifier = po.GetOptArg(8); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-decode-faster-regtree-fmllr.cc b/src/gmmbin/gmm-decode-faster-regtree-fmllr.cc deleted file mode 100644 index ca39cbe8cb7..00000000000 --- a/src/gmmbin/gmm-decode-faster-regtree-fmllr.cc +++ /dev/null @@ -1,290 +0,0 @@ -// gmmbin/gmm-decode-faster-regtree-fmllr.cc - -// Copyright 2009-2012 Microsoft Corporation; Saarland University; -// Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "transform/regression-tree.h" -#include "transform/regtree-fmllr-diag-gmm.h" -#include "transform/fmllr-diag-gmm.h" -#include "fstext/fstext-lib.h" -#include "decoder/faster-decoder.h" -#include "transform/decodable-am-diag-gmm-regtree.h" -#include "base/timer.h" -#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc - -using fst::SymbolTable; -using fst::VectorFst; -using fst::StdArc; -using kaldi::BaseFloat; -using std::string; -using std::vector; -using kaldi::LatticeWeight; -using kaldi::LatticeArc; - -struct DecodeInfo { - public: - DecodeInfo(const kaldi::AmDiagGmm &am, - const kaldi::TransitionModel &tm, kaldi::FasterDecoder *decoder, - BaseFloat scale, bool allow_partial, - const kaldi::Int32VectorWriter &wwriter, - const kaldi::Int32VectorWriter &awriter, fst::SymbolTable *wsyms) - : acoustic_model(am), trans_model(tm), decoder(decoder), - acoustic_scale(scale), allow_partial(allow_partial), words_writer(wwriter), - alignment_writer(awriter), word_syms(wsyms) {} - - const kaldi::AmDiagGmm &acoustic_model; - const kaldi::TransitionModel &trans_model; - kaldi::FasterDecoder *decoder; - BaseFloat acoustic_scale; - bool allow_partial; - const kaldi::Int32VectorWriter &words_writer; - const kaldi::Int32VectorWriter &alignment_writer; - fst::SymbolTable *word_syms; - - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(DecodeInfo); -}; - -bool DecodeUtterance(kaldi::FasterDecoder *decoder, - kaldi::DecodableInterface *decodable, - DecodeInfo *info, - const string &uttid, - int32 num_frames, - BaseFloat *total_like) { - decoder->Decode(decodable); - KALDI_LOG << "Length of file is " << num_frames; - - VectorFst decoded; // linear FST. - if ( (info->allow_partial || decoder->ReachedFinal()) - && decoder->GetBestPath(&decoded) ) { - if (!decoder->ReachedFinal()) - KALDI_WARN << "Decoder did not reach end-state, outputting partial " - "traceback."; - - vector alignment, words; - LatticeWeight weight; - GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - - info->words_writer.Write(uttid, words); - if (info->alignment_writer.IsOpen()) - info->alignment_writer.Write(uttid, alignment); - if (info->word_syms != NULL) { - std::ostringstream ss; - ss << uttid << ' '; - for (size_t i = 0; i < words.size(); i++) { - string s = info->word_syms->Find(words[i]); - if (s == "") - KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; - ss << s << ' '; - } - ss << '\n'; - KALDI_LOG << ss.str(); - } - - BaseFloat like = -weight.Value1() -weight.Value2(); - KALDI_LOG << "Log-like per frame = " << (like/num_frames); - (*total_like) += like; - return true; - } else { - KALDI_WARN << "Did not successfully decode utterance, length = " - << num_frames; - return false; - } -} - -int main(int argc, char *argv[]) { - try { - using namespace kaldi; - typedef kaldi::int32 int32; - - const char *usage = "Decode features using GMM-based model.\n" - "Usage: gmm-decode-faster-regtree-fmllr [options] model-in fst-in " - "regtree-in features-rspecifier transforms-rspecifier " - "words-wspecifier [alignments-wspecifier]\n"; - ParseOptions po(usage); - bool binary = true; - bool allow_partial = true; - BaseFloat acoustic_scale = 0.1; - - std::string word_syms_filename, utt2spk_rspecifier; - FasterDecoderOptions decoder_opts; - decoder_opts.Register(&po, true); // true == include obscure settings. - po.Register("utt2spk", &utt2spk_rspecifier, "rspecifier for utterance to " - "speaker map"); - po.Register("binary", &binary, "Write output in binary mode"); - po.Register("acoustic-scale", &acoustic_scale, - "Scaling factor for acoustic likelihoods"); - po.Register("word-symbol-table", &word_syms_filename, - "Symbol table for words [for debug output]"); - po.Register("allow-partial", &allow_partial, - "Produce output even when final state was not reached"); - po.Read(argc, argv); - - if (po.NumArgs() < 6 || po.NumArgs() > 7) { - po.PrintUsage(); - exit(1); - } - - std::string model_in_filename = po.GetArg(1), - fst_in_filename = po.GetArg(2), - regtree_filename = po.GetArg(3), - feature_rspecifier = po.GetArg(4), - xforms_rspecifier = po.GetArg(5), - words_wspecifier = po.GetArg(6), - alignment_wspecifier = po.GetOptArg(7); - - TransitionModel trans_model; - AmDiagGmm am_gmm; - { - bool binary_read; - Input ki(model_in_filename, &binary_read); - trans_model.Read(ki.Stream(), binary_read); - am_gmm.Read(ki.Stream(), binary_read); - } - - VectorFst *decode_fst = fst::ReadFstKaldi(fst_in_filename); - - RegressionTree regtree; - { - bool binary_read; - Input in(regtree_filename, &binary_read); - regtree.Read(in.Stream(), binary_read, am_gmm); - } - - RandomAccessRegtreeFmllrDiagGmmReaderMapped fmllr_reader(xforms_rspecifier, - utt2spk_rspecifier); - - Int32VectorWriter words_writer(words_wspecifier); - - Int32VectorWriter alignment_writer(alignment_wspecifier); - - fst::SymbolTable *word_syms = NULL; - if (word_syms_filename != "") { - word_syms = fst::SymbolTable::ReadText(word_syms_filename); - if (!word_syms) { - KALDI_ERR << "Could not read symbol table from file " - << word_syms_filename; - } - } - - BaseFloat tot_like = 0.0; - kaldi::int64 frame_count = 0; - int num_success = 0, num_fail = 0; - FasterDecoder decoder(*decode_fst, decoder_opts); - - Timer timer; - - DecodeInfo decode_info(am_gmm, trans_model, &decoder, acoustic_scale, - allow_partial, words_writer, alignment_writer, - word_syms); - - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - - Matrix features(feature_reader.Value()); - feature_reader.FreeCurrent(); - if (features.NumRows() == 0) { - KALDI_WARN << "Zero-length utterance: " << utt; - num_fail++; - continue; - } - - if (!fmllr_reader.HasKey(utt)) { // Decode without FMLLR if none found - KALDI_WARN << "No FMLLR transform for key " << utt << - ", decoding without fMLLR."; - kaldi::DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, - features, - acoustic_scale); - if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info, - utt, features.NumRows(), &tot_like)) { - frame_count += gmm_decodable.NumFramesReady(); - num_success++; - } else { - num_fail++; - } - continue; - } - - // If found, load the transforms for the current utterance. - RegtreeFmllrDiagGmm fmllr(fmllr_reader.Value(utt)); - if (fmllr.NumRegClasses() == 1) { - Matrix xformed_features(features); - Matrix fmllr_matrix; - fmllr.GetXformMatrix(0, &fmllr_matrix); - for (int32 i = 0; i < xformed_features.NumRows(); i++) { - SubVector row(xformed_features, i); - ApplyAffineTransform(fmllr_matrix, &row); - } - kaldi::DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, - xformed_features, - acoustic_scale); - - if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info, - utt, xformed_features.NumRows(), &tot_like)) { - frame_count += gmm_decodable.NumFramesReady(); - num_success++; - } else { - num_fail++; - } - } else { - kaldi::DecodableAmDiagGmmRegtreeFmllr gmm_decodable(am_gmm, trans_model, - features, fmllr, - regtree, - acoustic_scale); - if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info, - utt, features.NumRows(), &tot_like)) { - frame_count += gmm_decodable.NumFramesReady(); - num_success++; - } else { - num_fail++; - } - } - } // end looping over all utterances - - KALDI_LOG << "Average log-likelihood per frame is " << (tot_like - / frame_count) << " over " << frame_count << " frames."; - - double elapsed = timer.Elapsed(); - KALDI_LOG << "Time taken [excluding initialization] " << elapsed - << "s: real-time factor assuming 100 frames/sec is " - << (elapsed * 100.0 / frame_count); - KALDI_LOG << "Done " << num_success << " utterances, failed for " - << num_fail; - - delete word_syms; - delete decode_fst; - if (num_success != 0) - return 0; - else - return 1; - } - catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - - diff --git a/src/gmmbin/gmm-decode-faster-regtree-mllr.cc b/src/gmmbin/gmm-decode-faster-regtree-mllr.cc deleted file mode 100644 index 9a5d9486b9f..00000000000 --- a/src/gmmbin/gmm-decode-faster-regtree-mllr.cc +++ /dev/null @@ -1,267 +0,0 @@ -// gmmbin/gmm-decode-faster-regtree-mllr.cc - -// Copyright 2009-2013 Microsoft Corporation; Saarland University; -// Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "transform/regression-tree.h" -#include "transform/regtree-mllr-diag-gmm.h" -#include "fstext/fstext-lib.h" -#include "decoder/faster-decoder.h" -#include "transform/decodable-am-diag-gmm-regtree.h" -#include "base/timer.h" -#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc - -using fst::SymbolTable; -using fst::VectorFst; -using fst::StdArc; -using kaldi::BaseFloat; -using std::string; -using std::vector; -using kaldi::LatticeWeight; -using kaldi::LatticeArc; - -struct DecodeInfo { - public: - DecodeInfo(const kaldi::AmDiagGmm &am, - const kaldi::TransitionModel &tm, kaldi::FasterDecoder *decoder, - BaseFloat scale, bool allow_partial, - const kaldi::Int32VectorWriter &wwriter, - const kaldi::Int32VectorWriter &awriter, fst::SymbolTable *wsyms) - : acoustic_model(am), trans_model(tm), decoder(decoder), - acoustic_scale(scale), allow_partial(allow_partial), words_writer(wwriter), - alignment_writer(awriter), word_syms(wsyms) {} - - const kaldi::AmDiagGmm &acoustic_model; - const kaldi::TransitionModel &trans_model; - kaldi::FasterDecoder *decoder; - BaseFloat acoustic_scale; - bool allow_partial; - const kaldi::Int32VectorWriter &words_writer; - const kaldi::Int32VectorWriter &alignment_writer; - fst::SymbolTable *word_syms; - - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(DecodeInfo); -}; - -bool DecodeUtterance(kaldi::FasterDecoder *decoder, - kaldi::DecodableInterface *decodable, - DecodeInfo *info, - const string &uttid, - int32 num_frames, - BaseFloat *total_like) { - decoder->Decode(decodable); - KALDI_LOG << "Length of file is " << num_frames;; - - VectorFst decoded; // linear FST. - if ( (info->allow_partial || decoder->ReachedFinal()) - && decoder->GetBestPath(&decoded) ) { - if (!decoder->ReachedFinal()) - KALDI_WARN << "Decoder did not reach end-state, outputting partial " - "traceback."; - - vector alignment, words; - LatticeWeight weight; - GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - - info->words_writer.Write(uttid, words); - if (info->alignment_writer.IsOpen()) - info->alignment_writer.Write(uttid, alignment); - if (info->word_syms != NULL) { - std::ostringstream ss; - ss << uttid << ' '; - for (size_t i = 0; i < words.size(); i++) { - string s = info->word_syms->Find(words[i]); - if (s == "") - KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; - ss << s << ' '; - } - ss << '\n'; - KALDI_LOG << ss.str(); - } - - BaseFloat like = -weight.Value1() -weight.Value2(); - KALDI_LOG << "Log-like per frame = " << (like/num_frames); - (*total_like) += like; - return true; - } else { - KALDI_WARN << "Did not successfully decode utterance, length = " - << num_frames; - return false; - } -} - -int main(int argc, char *argv[]) { - try { - using namespace kaldi; - typedef kaldi::int32 int32; - - const char *usage = "Decode features using GMM-based model.\n" - "Usage: gmm-decode-faster-regtree-mllr [options] model-in fst-in " - "regtree-in features-rspecifier transforms-rspecifier " - "words-wspecifier [alignments-wspecifier]\n"; - ParseOptions po(usage); - bool binary = true; - bool allow_partial = true; - BaseFloat acoustic_scale = 0.1; - - std::string word_syms_filename, utt2spk_rspecifier; - FasterDecoderOptions decoder_opts; - decoder_opts.Register(&po, true); // true == include obscure settings. - po.Register("utt2spk", &utt2spk_rspecifier, "rspecifier for utterance to " - "speaker map"); - po.Register("binary", &binary, "Write output in binary mode"); - po.Register("acoustic-scale", &acoustic_scale, - "Scaling factor for acoustic likelihoods"); - po.Register("word-symbol-table", &word_syms_filename, - "Symbol table for words [for debug output]"); - po.Register("allow-partial", &allow_partial, - "Produce output even when final state was not reached"); - po.Read(argc, argv); - - if (po.NumArgs() < 6 || po.NumArgs() > 7) { - po.PrintUsage(); - exit(1); - } - - std::string model_in_filename = po.GetArg(1), - fst_in_filename = po.GetArg(2), - regtree_filename = po.GetArg(3), - feature_rspecifier = po.GetArg(4), - xforms_rspecifier = po.GetArg(5), - words_wspecifier = po.GetArg(6), - alignment_wspecifier = po.GetOptArg(7); - - TransitionModel trans_model; - AmDiagGmm am_gmm; - { - bool binary_read; - Input ki(model_in_filename, &binary_read); - trans_model.Read(ki.Stream(), binary_read); - am_gmm.Read(ki.Stream(), binary_read); - } - - VectorFst *decode_fst = fst::ReadFstKaldi(fst_in_filename); - - RegressionTree regtree; - { - bool binary_read; - Input in(regtree_filename, &binary_read); - regtree.Read(in.Stream(), binary_read, am_gmm); - } - - RandomAccessRegtreeMllrDiagGmmReaderMapped mllr_reader(xforms_rspecifier, - utt2spk_rspecifier); - - Int32VectorWriter words_writer(words_wspecifier); - - Int32VectorWriter alignment_writer(alignment_wspecifier); - - fst::SymbolTable *word_syms = NULL; - if (word_syms_filename != "") { - word_syms = fst::SymbolTable::ReadText(word_syms_filename); - if (!word_syms) { - KALDI_ERR << "Could not read symbol table from file " - << word_syms_filename; - } - } - - BaseFloat tot_like = 0.0; - kaldi::int64 frame_count = 0; - int num_success = 0, num_fail = 0; - FasterDecoder decoder(*decode_fst, decoder_opts); - - Timer timer; - - DecodeInfo decode_info(am_gmm, trans_model, &decoder, acoustic_scale, - allow_partial, words_writer, alignment_writer, - word_syms); - - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - - Matrix features(feature_reader.Value()); - feature_reader.FreeCurrent(); - if (features.NumRows() == 0) { - KALDI_WARN << "Zero-length utterance: " << utt; - num_fail++; - continue; - } - - if (!mllr_reader.HasKey(utt)) { // Decode without MLLR if none found - KALDI_WARN << "No MLLR transform for key " << utt << - ", decoding without MLLR."; - kaldi::DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, - features, - acoustic_scale); - if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info, - utt, features.NumRows(), &tot_like)) { - frame_count += gmm_decodable.NumFramesReady(); - num_success++; - } else { - num_fail++; - } - continue; - } - - // If found, load the transforms for the current utterance. - const RegtreeMllrDiagGmm &mllr = mllr_reader.Value(utt); - kaldi::DecodableAmDiagGmmRegtreeMllr gmm_decodable(am_gmm, trans_model, - features, mllr, - regtree, - acoustic_scale); - if (DecodeUtterance(&decoder, &gmm_decodable, &decode_info, - utt, features.NumRows(), &tot_like)) { - frame_count += gmm_decodable.NumFramesReady(); - num_success++; - } else { - num_fail++; - } - } // end looping over all utterances - - double elapsed = timer.Elapsed(); - KALDI_LOG << "Time taken [excluding initialization] " << elapsed - << "s: real-time factor assuming 100 frames/sec is " - << (elapsed * 100.0 / frame_count); - KALDI_LOG << "Done " << num_success << " utterances, failed for " - << num_fail; - KALDI_LOG << "Overall log-likelihood per frame is " - << (tot_like / frame_count) << " over " << frame_count - << " frames."; - - delete decode_fst; - if (num_success != 0) - return 0; - else - return 1; - } - catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - - diff --git a/src/gmmbin/gmm-decode-faster.cc b/src/gmmbin/gmm-decode-faster.cc index 34c4ff2c37e..438e3d9c9d1 100644 --- a/src/gmmbin/gmm-decode-faster.cc +++ b/src/gmmbin/gmm-decode-faster.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/faster-decoder.h" #include "gmm/decodable-am-diag-gmm.h" @@ -101,7 +101,7 @@ int main(int argc, char *argv[]) { alignment_wspecifier = po.GetOptArg(5), lattice_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-decode-simple.cc b/src/gmmbin/gmm-decode-simple.cc index 5ef35552dc0..ef87585cc1e 100644 --- a/src/gmmbin/gmm-decode-simple.cc +++ b/src/gmmbin/gmm-decode-simple.cc @@ -23,7 +23,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/simple-decoder.h" #include "gmm/decodable-am-diag-gmm.h" @@ -78,7 +78,7 @@ int main(int argc, char *argv[]) { alignment_wspecifier = po.GetOptArg(5), lattice_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-est-basis-fmllr-gpost.cc b/src/gmmbin/gmm-est-basis-fmllr-gpost.cc index 54b92d8aa61..3d864c88086 100644 --- a/src/gmmbin/gmm-est-basis-fmllr-gpost.cc +++ b/src/gmmbin/gmm-est-basis-fmllr-gpost.cc @@ -26,7 +26,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "transform/basis-fmllr-diag-gmm.h" #include "hmm/posterior.h" @@ -34,7 +34,7 @@ using std::vector; namespace kaldi { void AccumulateForUtterance(const Matrix &feats, const GaussPost &gpost, - const TransitionModel &trans_model, + const Transitions &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats) { for (size_t i = 0; i < gpost.size(); i++) { @@ -87,7 +87,7 @@ int main(int argc, char *argv[]) { gpost_rspecifier = po.GetArg(4), trans_wspecifier = po.GetArg(5); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-est-basis-fmllr.cc b/src/gmmbin/gmm-est-basis-fmllr.cc index 0d163169ce2..fe64a1b2166 100644 --- a/src/gmmbin/gmm-est-basis-fmllr.cc +++ b/src/gmmbin/gmm-est-basis-fmllr.cc @@ -26,7 +26,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "transform/basis-fmllr-diag-gmm.h" #include "hmm/posterior.h" @@ -34,7 +34,7 @@ using std::vector; namespace kaldi { void AccumulateForUtterance(const Matrix &feats, const Posterior &post, - const TransitionModel &trans_model, + const Transitions &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats) { Posterior pdf_post; @@ -89,7 +89,7 @@ int main(int argc, char *argv[]) { post_rspecifier = po.GetArg(4), trans_wspecifier = po.GetArg(5); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-est-fmllr-global.cc b/src/gmmbin/gmm-est-fmllr-global.cc index b3af0780aa5..d167ba25890 100644 --- a/src/gmmbin/gmm-est-fmllr-global.cc +++ b/src/gmmbin/gmm-est-fmllr-global.cc @@ -27,7 +27,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "hmm/posterior.h" diff --git a/src/gmmbin/gmm-est-fmllr-gpost.cc b/src/gmmbin/gmm-est-fmllr-gpost.cc index d1cae0d7f48..9d830737718 100644 --- a/src/gmmbin/gmm-est-fmllr-gpost.cc +++ b/src/gmmbin/gmm-est-fmllr-gpost.cc @@ -27,14 +27,14 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "hmm/posterior.h" namespace kaldi { void AccumulateForUtterance(const Matrix &feats, const GaussPost &gpost, - const TransitionModel &trans_model, + const Transitions &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats) { for (size_t i = 0; i < gpost.size(); i++) { @@ -81,7 +81,7 @@ int main(int argc, char *argv[]) { gpost_rspecifier = po.GetArg(3), trans_wspecifier = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-est-fmllr-raw-gpost.cc b/src/gmmbin/gmm-est-fmllr-raw-gpost.cc deleted file mode 100644 index 1f5a09f233b..00000000000 --- a/src/gmmbin/gmm-est-fmllr-raw-gpost.cc +++ /dev/null @@ -1,198 +0,0 @@ -// gmmbin/gmm-est-fmllr-raw-gpost.cc - -// Copyright 2013 Johns Hopkins University (author: Daniel Povey) -// 2014 Guoguo Chen - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "transform/fmllr-raw.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "util/common-utils.h" -#include "hmm/posterior.h" - -namespace kaldi { - - -void AccStatsForUtterance(const TransitionModel &trans_model, - const AmDiagGmm &am_gmm, - const GaussPost &gpost, - const Matrix &feats, - FmllrRawAccs *accs) { - for (size_t t = 0; t < gpost.size(); t++) { - for (size_t i = 0; i < gpost[t].size(); i++) { - int32 pdf = gpost[t][i].first; - const Vector &posterior(gpost[t][i].second); - accs->AccumulateFromPosteriors(am_gmm.GetPdf(pdf), - feats.Row(t), posterior); - } - } -} - - -} - -int main(int argc, char *argv[]) { - try { - typedef kaldi::int32 int32; - using namespace kaldi; - const char *usage = - "Estimate fMLLR transforms in the space before splicing and linear transforms\n" - "such as LDA+MLLT, but using models in the space transformed by these transforms\n" - "Requires the original spliced features, and the full LDA+MLLT (or similar) matrix\n" - "including the 'rejected' rows (see the program get-full-lda-mat). Reads in\n" - "Gaussian-level posteriors.\n" - "Usage: gmm-est-fmllr-raw-gpost [options] " - " \n"; - - - int32 raw_feat_dim = 13; - ParseOptions po(usage); - FmllrRawOptions opts; - std::string spk2utt_rspecifier; - po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to " - "utterance-list map"); - po.Register("raw-feat-dim", &raw_feat_dim, "Dimension of raw features " - "prior to splicing"); - opts.Register(&po); - - po.Read(argc, argv); - - if (po.NumArgs() != 5) { - po.PrintUsage(); - exit(1); - } - - std::string model_rxfilename = po.GetArg(1), - full_lda_mat_rxfilename = po.GetArg(2), - feature_rspecifier = po.GetArg(3), - gpost_rspecifier = po.GetArg(4), - transform_wspecifier = po.GetArg(5); - - AmDiagGmm am_gmm; - TransitionModel trans_model; - { - bool binary; - Input ki(model_rxfilename, &binary); - trans_model.Read(ki.Stream(), binary); - am_gmm.Read(ki.Stream(), binary); - } - - Matrix full_lda_mat; - ReadKaldiObject(full_lda_mat_rxfilename, &full_lda_mat); - - RandomAccessGaussPostReader gpost_reader(gpost_rspecifier); - BaseFloatMatrixWriter transform_writer(transform_wspecifier); - - double tot_auxf_impr = 0.0, tot_count = 0.0; - - int32 num_done = 0, num_err = 0; - if (!spk2utt_rspecifier.empty()) { // Adapting per speaker - SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); - RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); - - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { - FmllrRawAccs accs(raw_feat_dim, am_gmm.Dim(), full_lda_mat); - std::string spk = spk2utt_reader.Key(); - const std::vector &uttlist = spk2utt_reader.Value(); - for (size_t i = 0; i < uttlist.size(); i++) { - std::string utt = uttlist[i]; - if (!feature_reader.HasKey(utt)) { - KALDI_WARN << "Features not found for utterance " << utt; - num_err++; - continue; - } - if (!gpost_reader.HasKey(utt)) { - KALDI_WARN << "Gaussian-level posteriors not found for utterance " << utt; - num_err++; - continue; - } - const Matrix &feats = feature_reader.Value(utt); - const GaussPost &gpost = gpost_reader.Value(utt); - if (static_cast(gpost.size()) != feats.NumRows()) { - KALDI_WARN << "Size mismatch between gposteriors " << gpost.size() - << " and features " << feats.NumRows(); - num_err++; - continue; - } - - AccStatsForUtterance(trans_model, am_gmm, gpost, feats, &accs); - num_done++; - } - - BaseFloat auxf_impr, count; - { - Matrix transform(raw_feat_dim, raw_feat_dim + 1); - transform.SetUnit(); - accs.Update(opts, &transform, &auxf_impr, &count); - transform_writer.Write(spk, transform); - } - KALDI_LOG << "For speaker " << spk << ", auxf-impr from raw fMLLR is " - << (auxf_impr/count) << " over " << count << " frames."; - tot_auxf_impr += auxf_impr; - tot_count += count; - } - } else { // --spk2utt option not given -> adapt per utterance. - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - std::string utt = feature_reader.Key(); - if (!gpost_reader.HasKey(utt)) { - KALDI_WARN << "Gaussian-level posteriors not found for utterance " << utt; - num_err++; - continue; - } - const Matrix &feats = feature_reader.Value(); - const GaussPost &gpost = gpost_reader.Value(utt); - - if (static_cast(gpost.size()) != feats.NumRows()) { - KALDI_WARN << "Size mismatch between posteriors " << gpost.size() - << " and features " << feats.NumRows(); - num_err++; - continue; - } - - FmllrRawAccs accs(raw_feat_dim, am_gmm.Dim(), full_lda_mat); - - AccStatsForUtterance(trans_model, am_gmm, gpost, feats, &accs); - - BaseFloat auxf_impr, count; - { - Matrix transform(raw_feat_dim, raw_feat_dim + 1); - transform.SetUnit(); - accs.Update(opts, &transform, &auxf_impr, &count); - transform_writer.Write(utt, transform); - } - KALDI_LOG << "For utterance " << utt << ", auxf-impr from raw fMLLR is " - << (auxf_impr/count) << " over " << count << " frames."; - tot_auxf_impr += auxf_impr; - tot_count += count; - num_done++; - } - } - - KALDI_LOG << "Processed " << num_done << " utterances, " - << num_err << " had errors."; - KALDI_LOG << "Overall raw-fMLLR auxf impr per frame is " - << (tot_auxf_impr / tot_count) << " over " << tot_count - << " frames."; - return (num_done != 0 ? 0 : 1); - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - diff --git a/src/gmmbin/gmm-est-fmllr-raw.cc b/src/gmmbin/gmm-est-fmllr-raw.cc deleted file mode 100644 index 5e83bfb1fb3..00000000000 --- a/src/gmmbin/gmm-est-fmllr-raw.cc +++ /dev/null @@ -1,199 +0,0 @@ -// gmmbin/gmm-est-fmllr-raw.cc - -// Copyright 2013 Johns Hopkins University (author: Daniel Povey) -// 2014 Guoguo Chen - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "transform/fmllr-raw.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "util/common-utils.h" -#include "hmm/posterior.h" - -namespace kaldi { - - -void AccStatsForUtterance(const TransitionModel &trans_model, - const AmDiagGmm &am_gmm, - const Posterior &post, - const Matrix &feats, - FmllrRawAccs *accs) { - Posterior pdf_post; - ConvertPosteriorToPdfs(trans_model, post, &pdf_post); - for (size_t t = 0; t < post.size(); t++) { - for (size_t i = 0; i < pdf_post[t].size(); i++) { - int32 pdf = pdf_post[t][i].first; - BaseFloat weight = pdf_post[t][i].second; - accs->AccumulateForGmm(am_gmm.GetPdf(pdf), - feats.Row(t), weight); - } - } -} - - -} - -int main(int argc, char *argv[]) { - try { - typedef kaldi::int32 int32; - using namespace kaldi; - const char *usage = - "Estimate fMLLR transforms in the space before splicing and linear transforms\n" - "such as LDA+MLLT, but using models in the space transformed by these transforms\n" - "Requires the original spliced features, and the full LDA+MLLT (or similar) matrix\n" - "including the 'rejected' rows (see the program get-full-lda-mat)\n" - "Usage: gmm-est-fmllr-raw [options] " - " \n"; - - - int32 raw_feat_dim = 13; - ParseOptions po(usage); - FmllrRawOptions opts; - std::string spk2utt_rspecifier; - po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to " - "utterance-list map"); - po.Register("raw-feat-dim", &raw_feat_dim, "Dimension of raw features " - "prior to splicing"); - opts.Register(&po); - - po.Read(argc, argv); - - if (po.NumArgs() != 5) { - po.PrintUsage(); - exit(1); - } - - std::string model_rxfilename = po.GetArg(1), - full_lda_mat_rxfilename = po.GetArg(2), - feature_rspecifier = po.GetArg(3), - post_rspecifier = po.GetArg(4), - transform_wspecifier = po.GetArg(5); - - AmDiagGmm am_gmm; - TransitionModel trans_model; - { - bool binary; - Input ki(model_rxfilename, &binary); - trans_model.Read(ki.Stream(), binary); - am_gmm.Read(ki.Stream(), binary); - } - - Matrix full_lda_mat; - ReadKaldiObject(full_lda_mat_rxfilename, &full_lda_mat); - - RandomAccessPosteriorReader post_reader(post_rspecifier); - BaseFloatMatrixWriter transform_writer(transform_wspecifier); - - double tot_auxf_impr = 0.0, tot_count = 0.0; - - int32 num_done = 0, num_err = 0; - if (!spk2utt_rspecifier.empty()) { // Adapting per speaker - SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); - RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); - - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { - FmllrRawAccs accs(raw_feat_dim, am_gmm.Dim(), full_lda_mat); - std::string spk = spk2utt_reader.Key(); - const std::vector &uttlist = spk2utt_reader.Value(); - for (size_t i = 0; i < uttlist.size(); i++) { - std::string utt = uttlist[i]; - if (!feature_reader.HasKey(utt)) { - KALDI_WARN << "Features not found for utterance " << utt; - num_err++; - continue; - } - if (!post_reader.HasKey(utt)) { - KALDI_WARN << "Posteriors not found for utterance " << utt; - num_err++; - continue; - } - const Matrix &feats = feature_reader.Value(utt); - const Posterior &post = post_reader.Value(utt); - if (static_cast(post.size()) != feats.NumRows()) { - KALDI_WARN << "Size mismatch between posteriors " << post.size() - << " and features " << feats.NumRows(); - num_err++; - continue; - } - - AccStatsForUtterance(trans_model, am_gmm, post, feats, &accs); - num_done++; - } - - BaseFloat auxf_impr, count; - { - Matrix transform(raw_feat_dim, raw_feat_dim + 1); - transform.SetUnit(); - accs.Update(opts, &transform, &auxf_impr, &count); - transform_writer.Write(spk, transform); - } - KALDI_LOG << "For speaker " << spk << ", auxf-impr from raw fMLLR is " - << (auxf_impr/count) << " over " << count << " frames."; - tot_auxf_impr += auxf_impr; - tot_count += count; - } - } else { // --spk2utt option not given -> adapt per utterance. - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - std::string utt = feature_reader.Key(); - if (!post_reader.HasKey(utt)) { - KALDI_WARN << "Posteriors not found for utterance " << utt; - num_err++; - continue; - } - const Matrix &feats = feature_reader.Value(); - const Posterior &post = post_reader.Value(utt); - - if (static_cast(post.size()) != feats.NumRows()) { - KALDI_WARN << "Size mismatch between posteriors " << post.size() - << " and features " << feats.NumRows(); - num_err++; - continue; - } - - FmllrRawAccs accs(raw_feat_dim, am_gmm.Dim(), full_lda_mat); - - AccStatsForUtterance(trans_model, am_gmm, post, feats, &accs); - - BaseFloat auxf_impr, count; - { - Matrix transform(raw_feat_dim, raw_feat_dim + 1); - transform.SetUnit(); - accs.Update(opts, &transform, &auxf_impr, &count); - transform_writer.Write(utt, transform); - } - KALDI_LOG << "For utterance " << utt << ", auxf-impr from raw fMLLR is " - << (auxf_impr/count) << " over " << count << " frames."; - tot_auxf_impr += auxf_impr; - tot_count += count; - num_done++; - } - } - - KALDI_LOG << "Processed " << num_done << " utterances, " - << num_err << " had errors."; - KALDI_LOG << "Overall raw-fMLLR auxf impr per frame is " - << (tot_auxf_impr / tot_count) << " over " << tot_count - << " frames."; - return (num_done != 0 ? 0 : 1); - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - diff --git a/src/gmmbin/gmm-est-fmllr.cc b/src/gmmbin/gmm-est-fmllr.cc index 9f8dfd89143..c44a284b2f8 100644 --- a/src/gmmbin/gmm-est-fmllr.cc +++ b/src/gmmbin/gmm-est-fmllr.cc @@ -27,14 +27,14 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "hmm/posterior.h" namespace kaldi { void AccumulateForUtterance(const Matrix &feats, const Posterior &post, - const TransitionModel &trans_model, + const Transitions &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats) { Posterior pdf_post; @@ -83,7 +83,7 @@ int main(int argc, char *argv[]) { post_rspecifier = po.GetArg(3), trans_wspecifier = po.GetArg(4); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-est-gaussians-ebw.cc b/src/gmmbin/gmm-est-gaussians-ebw.cc index bbd53c2bec0..cfbb8ece02d 100644 --- a/src/gmmbin/gmm-est-gaussians-ebw.cc +++ b/src/gmmbin/gmm-est-gaussians-ebw.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/ebw-diag-gmm.h" int main(int argc, char *argv[]) { @@ -62,7 +62,7 @@ int main(int argc, char *argv[]) { model_out_filename = po.GetArg(4); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-est-lvtln-trans.cc b/src/gmmbin/gmm-est-lvtln-trans.cc index abfc24a6585..849560dd437 100644 --- a/src/gmmbin/gmm-est-lvtln-trans.cc +++ b/src/gmmbin/gmm-est-lvtln-trans.cc @@ -26,7 +26,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/lvtln.h" #include "hmm/posterior.h" @@ -86,7 +86,7 @@ int main(int argc, char *argv[]) { { bool binary; Input ki(model_rxfilename, &binary); - TransitionModel trans_model; + Transitions trans_model; trans_model.Read(ki.Stream(), binary); am_gmm.Read(ki.Stream(), binary); } diff --git a/src/gmmbin/gmm-est-map.cc b/src/gmmbin/gmm-est-map.cc index 22ea8acda51..eb2b44d5961 100644 --- a/src/gmmbin/gmm-est-map.cc +++ b/src/gmmbin/gmm-est-map.cc @@ -22,7 +22,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" int main(int argc, char *argv[]) { @@ -65,7 +65,7 @@ int main(int argc, char *argv[]) { model_out_filename = po.GetArg(3); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-est-regtree-fmllr-ali.cc b/src/gmmbin/gmm-est-regtree-fmllr-ali.cc deleted file mode 100644 index 0158bae8298..00000000000 --- a/src/gmmbin/gmm-est-regtree-fmllr-ali.cc +++ /dev/null @@ -1,202 +0,0 @@ -// gmmbin/gmm-est-regtree-fmllr-ali.cc - -// Copyright 2009-2011 Saarland University; Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -using std::string; -#include -using std::vector; - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "transform/regtree-fmllr-diag-gmm.h" - -int main(int argc, char *argv[]) { - try { - typedef kaldi::int32 int32; - using namespace kaldi; - const char *usage = - "Compute FMLLR transforms per-utterance (default) or per-speaker for " - "the supplied set of speakers (spk2utt option). Note: writes RegtreeFmllrDiagGmm objects\n" - "Usage: gmm-est-regtree-fmllr-ali [options] " - " \n"; - - ParseOptions po(usage); - string spk2utt_rspecifier; - bool binary = true; - po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to " - "utterance-list map"); - po.Register("binary", &binary, "Write output in binary mode"); - // register other modules - RegtreeFmllrOptions opts; - opts.Register(&po); - - po.Read(argc, argv); - - if (po.NumArgs() != 5) { - po.PrintUsage(); - exit(1); - } - - string model_filename = po.GetArg(1), - feature_rspecifier = po.GetArg(2), - alignments_rspecifier = po.GetArg(3), - regtree_filename = po.GetArg(4), - xforms_wspecifier = po.GetArg(5); - - RandomAccessInt32VectorReader alignments_reader(alignments_rspecifier); - RegtreeFmllrDiagGmmWriter fmllr_writer(xforms_wspecifier); - - AmDiagGmm am_gmm; - TransitionModel trans_model; - { - bool binary; - Input ki(model_filename, &binary); - trans_model.Read(ki.Stream(), binary); - am_gmm.Read(ki.Stream(), binary); - } - RegressionTree regtree; - { - bool binary; - Input in(regtree_filename, &binary); - regtree.Read(in.Stream(), binary, am_gmm); - } - - RegtreeFmllrDiagGmm fmllr_xforms; - RegtreeFmllrDiagGmmAccs fmllr_accs; - fmllr_accs.Init(regtree.NumBaseclasses(), am_gmm.Dim()); - - double tot_like = 0.0; - kaldi::int64 tot_t = 0; - - int32 num_done = 0, num_no_alignment = 0, num_other_error = 0; - double tot_objf_impr = 0.0, tot_t_objf = 0.0; - if (spk2utt_rspecifier != "") { // per-speaker adaptation - SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); - RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { - string spk = spk2utt_reader.Key(); - fmllr_accs.SetZero(); - const vector &uttlist = spk2utt_reader.Value(); - for (vector::const_iterator utt_itr = uttlist.begin(), - itr_end = uttlist.end(); utt_itr != itr_end; ++utt_itr) { - if (!feature_reader.HasKey(*utt_itr)) { - KALDI_WARN << "Did not find features for utterance " << *utt_itr; - continue; - } - if (!alignments_reader.HasKey(*utt_itr)) { - KALDI_WARN << "Did not find aligned transcription for utterance " - << *utt_itr; - num_no_alignment++; - continue; - } - const Matrix &feats = feature_reader.Value(*utt_itr); - const vector &alignment = alignments_reader.Value(*utt_itr); - if (static_cast(alignment.size()) != feats.NumRows()) { - KALDI_WARN << "Alignments has wrong size " << (alignment.size()) - << " vs. " << (feats.NumRows()); - num_other_error++; - continue; - } - - BaseFloat file_like = 0.0; - for (size_t i = 0; i < alignment.size(); i++) { - int32 pdf_id = trans_model.TransitionIdToPdf(alignment[i]); - file_like += fmllr_accs.AccumulateForGmm(regtree, am_gmm, - feats.Row(i), pdf_id, 1.0); - } - KALDI_VLOG(2) << "Average like for this file is " << (file_like - / alignment.size()) << " over " << alignment.size() - << " frames.\n"; - tot_like += file_like; - tot_t += alignment.size(); - num_done++; - if (num_done % 10 == 0) KALDI_VLOG(1) - << "Avg like per frame so far is " << (tot_like / tot_t) << '\n'; - } // end looping over all utterances of the current speaker - BaseFloat objf_impr, t; - fmllr_accs.Update(regtree, opts, &fmllr_xforms, &objf_impr, &t); - KALDI_LOG << "fMLLR objf improvement for speaker " << spk << " is " - << (objf_impr/(t+1.0e-10)) << " per frame over " << t - << " frames."; - tot_objf_impr += objf_impr; - tot_t_objf += t; - fmllr_writer.Write(spk, fmllr_xforms); - } // end looping over speakers - } else { // per-utterance adaptation - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - string key = feature_reader.Key(); - if (!alignments_reader.HasKey(key)) { - KALDI_WARN << "Did not find aligned transcription for utterance " - << key; - num_no_alignment++; - continue; - } - const Matrix &feats = feature_reader.Value(); - const vector &alignment = alignments_reader.Value(key); - - if (static_cast(alignment.size()) != feats.NumRows()) { - KALDI_WARN << "Alignments has wrong size " << (alignment.size()) - << " vs. " << (feats.NumRows()); - num_other_error++; - continue; - } - - num_done++; - BaseFloat file_like = 0.0; - fmllr_accs.SetZero(); - for (size_t i = 0; i < alignment.size(); i++) { - int32 pdf_id = trans_model.TransitionIdToPdf(alignment[i]); - file_like += fmllr_accs.AccumulateForGmm(regtree, am_gmm, - feats.Row(i), pdf_id, 1.0); - } - KALDI_VLOG(2) << "Average like for this file is " << (file_like - / alignment.size()) << " over " << alignment.size() << " frames."; - tot_like += file_like; - tot_t += alignment.size(); - if (num_done % 10 == 0) KALDI_VLOG(1) - << "Avg like per frame so far is " << (tot_like / tot_t); - BaseFloat objf_impr, t; - fmllr_accs.Update(regtree, opts, &fmllr_xforms, &objf_impr, &t); - KALDI_LOG << "fMLLR objf improvement for utterance " << key << " is " - << (objf_impr/(t+1.0e-10)) << " per frame over " << t - << " frames."; - tot_objf_impr += objf_impr; - tot_t_objf += t; - fmllr_writer.Write(feature_reader.Key(), fmllr_xforms); - } - } - - KALDI_LOG << "Overall objf improvement from fMLLR is " - << (tot_objf_impr/tot_t_objf) - << " per frame over " << tot_t_objf << " frames."; - KALDI_LOG << "Done " << num_done << " files, " << num_no_alignment - << " with no alignments, " << num_other_error - << " with other errors."; - KALDI_LOG << "Overall acoustic like per frame = " << (tot_like / tot_t) - << " over " << tot_t << " frames."; - return 0; - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - diff --git a/src/gmmbin/gmm-est-regtree-fmllr.cc b/src/gmmbin/gmm-est-regtree-fmllr.cc deleted file mode 100644 index ca807f07fd4..00000000000 --- a/src/gmmbin/gmm-est-regtree-fmllr.cc +++ /dev/null @@ -1,216 +0,0 @@ -// gmmbin/gmm-est-regtree-fmllr.cc - -// Copyright 2009-2011 Saarland University; Microsoft Corporation -// 2014 Guoguo Chen - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -using std::string; -#include -using std::vector; - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "hmm/posterior.h" -#include "transform/regtree-fmllr-diag-gmm.h" - -int main(int argc, char *argv[]) { - try { - typedef kaldi::int32 int32; - using namespace kaldi; - const char *usage = - "Compute FMLLR transforms per-utterance (default) or per-speaker for " - "the supplied set of speakers (spk2utt option). Note: writes RegtreeFmllrDiagGmm objects\n" - "Usage: gmm-est-regtree-fmllr [options] " - " \n"; - - ParseOptions po(usage); - string spk2utt_rspecifier; - bool binary = true; - po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to " - "utterance-list map"); - po.Register("binary", &binary, "Write output in binary mode"); - // register other modules - RegtreeFmllrOptions opts; - opts.Register(&po); - - po.Read(argc, argv); - - if (po.NumArgs() != 5) { - po.PrintUsage(); - exit(1); - } - - string model_filename = po.GetArg(1), - feature_rspecifier = po.GetArg(2), - posteriors_rspecifier = po.GetArg(3), - regtree_filename = po.GetArg(4), - xforms_wspecifier = po.GetArg(5); - - RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier); - RegtreeFmllrDiagGmmWriter fmllr_writer(xforms_wspecifier); - - AmDiagGmm am_gmm; - TransitionModel trans_model; - { - bool binary; - Input ki(model_filename, &binary); - trans_model.Read(ki.Stream(), binary); - am_gmm.Read(ki.Stream(), binary); - } - RegressionTree regtree; - { - bool binary; - Input in(regtree_filename, &binary); - regtree.Read(in.Stream(), binary, am_gmm); - } - - RegtreeFmllrDiagGmm fmllr_xforms; - RegtreeFmllrDiagGmmAccs fmllr_accs; - fmllr_accs.Init(regtree.NumBaseclasses(), am_gmm.Dim()); - - double tot_like = 0.0, tot_t = 0; - - int32 num_done = 0, num_no_posterior = 0, num_other_error = 0; - double tot_objf_impr = 0.0, tot_t_objf = 0.0; - if (spk2utt_rspecifier != "") { // per-speaker adaptation - SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); - RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { - string spk = spk2utt_reader.Key(); - fmllr_accs.SetZero(); - const vector &uttlist = spk2utt_reader.Value(); - for (vector::const_iterator utt_itr = uttlist.begin(), - itr_end = uttlist.end(); utt_itr != itr_end; ++utt_itr) { - if (!feature_reader.HasKey(*utt_itr)) { - KALDI_WARN << "Did not find features for utterance " << *utt_itr; - continue; - } - if (!posteriors_reader.HasKey(*utt_itr)) { - KALDI_WARN << "Did not find posteriors for utterance " - << *utt_itr; - num_no_posterior++; - continue; - } - const Matrix &feats = feature_reader.Value(*utt_itr); - const Posterior &posterior = posteriors_reader.Value(*utt_itr); - if (static_cast(posterior.size()) != feats.NumRows()) { - KALDI_WARN << "Posteriors has wrong size " << (posterior.size()) - << " vs. " << (feats.NumRows()); - num_other_error++; - continue; - } - - BaseFloat file_like = 0.0, file_t = 0.0; - Posterior pdf_posterior; - ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior); - for (size_t i = 0; i < posterior.size(); i++) { - for (size_t j = 0; j < pdf_posterior[i].size(); j++) { - int32 pdf_id = pdf_posterior[i][j].first; - BaseFloat prob = pdf_posterior[i][j].second; - file_like += fmllr_accs.AccumulateForGmm(regtree, am_gmm, - feats.Row(i), pdf_id, - prob); - file_t += prob; - } - } - KALDI_VLOG(2) << "Average like for this file is " << (file_like/file_t) - << " over " << file_t << " frames."; - tot_like += file_like; - tot_t += file_t; - num_done++; - if (num_done % 10 == 0) - KALDI_VLOG(1) << "Avg like per frame so far is " - << (tot_like / tot_t); - } // end looping over all utterances of the current speaker - BaseFloat objf_impr, t; - fmllr_accs.Update(regtree, opts, &fmllr_xforms, &objf_impr, &t); - KALDI_LOG << "fMLLR objf improvement for speaker " << spk << " is " - << (objf_impr/(t+1.0e-10)) << " per frame over " << t - << " frames."; - tot_objf_impr += objf_impr; - tot_t_objf += t; - fmllr_writer.Write(spk, fmllr_xforms); - } // end looping over speakers - } else { // per-utterance adaptation - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - string key = feature_reader.Key(); - if (!posteriors_reader.HasKey(key)) { - KALDI_WARN << "Did not find posteriors for utterance " - << key; - num_no_posterior++; - continue; - } - const Matrix &feats = feature_reader.Value(); - const Posterior &posterior = posteriors_reader.Value(key); - - if (static_cast(posterior.size()) != feats.NumRows()) { - KALDI_WARN << "Posteriors has wrong size " << (posterior.size()) - << " vs. " << (feats.NumRows()); - num_other_error++; - continue; - } - - num_done++; - BaseFloat file_like = 0.0, file_t = 0.0; - fmllr_accs.SetZero(); - Posterior pdf_posterior; - ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior); - for (size_t i = 0; i < posterior.size(); i++) { - for (size_t j = 0; j < pdf_posterior[i].size(); j++) { - int32 pdf_id = pdf_posterior[i][j].first; - BaseFloat prob = pdf_posterior[i][j].second; - file_like += fmllr_accs.AccumulateForGmm(regtree, am_gmm, - feats.Row(i), pdf_id, - prob); - file_t += prob; - } - } - KALDI_VLOG(2) << "Average like for this file is " << (file_like/file_t) - << " over " << file_t << " frames."; - tot_like += file_like; - tot_t += file_t; - if (num_done % 10 == 0) - KALDI_VLOG(1) << "Avg like per frame so far is " - << (tot_like / tot_t); - BaseFloat objf_impr, t; - fmllr_accs.Update(regtree, opts, &fmllr_xforms, &objf_impr, &t); - KALDI_LOG << "fMLLR objf improvement for utterance " << key << " is " - << (objf_impr/(t+1.0e-10)) << " per frame over " << t - << " frames."; - tot_objf_impr += objf_impr; - tot_t_objf += t; - fmllr_writer.Write(feature_reader.Key(), fmllr_xforms); - } - } - KALDI_LOG << "Done " << num_done << " files, " << num_no_posterior - << " with no posteriors, " << num_other_error - << " with other errors."; - KALDI_LOG << "Overall objf improvement from MLLR is " << (tot_objf_impr/tot_t_objf) - << " per frame " << " over " << tot_t_objf << " frames."; - KALDI_LOG << "Overall acoustic likelihood was " << (tot_like/tot_t) - << " over " << tot_t << " frames."; - return 0; - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - diff --git a/src/gmmbin/gmm-est-regtree-mllr.cc b/src/gmmbin/gmm-est-regtree-mllr.cc deleted file mode 100644 index a4df5cc84c1..00000000000 --- a/src/gmmbin/gmm-est-regtree-mllr.cc +++ /dev/null @@ -1,215 +0,0 @@ -// gmmbin/gmm-est-regtree-mllr.cc - -// Copyright 2009-2011 Saarland University; Microsoft Corporation -// 2014 Guoguo Chen - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -using std::string; -#include -using std::vector; - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" -#include "transform/regtree-mllr-diag-gmm.h" -#include "hmm/posterior.h" - -int main(int argc, char *argv[]) { - try { - typedef kaldi::int32 int32; - using namespace kaldi; - const char *usage = - "Compute MLLR transforms per-utterance (default) or per-speaker for " - "the supplied set of speakers (spk2utt option). Note: writes RegtreeMllrDiagGmm objects\n" - "Usage: gmm-est-regtree-mllr [options] " - " \n"; - - ParseOptions po(usage); - string spk2utt_rspecifier; - bool binary = true; - po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to " - "utterance-list map"); - po.Register("binary", &binary, "Write output in binary mode"); - // register other modules - RegtreeMllrOptions opts; - opts.Register(&po); - - po.Read(argc, argv); - - if (po.NumArgs() != 5) { - po.PrintUsage(); - exit(1); - } - - string model_filename = po.GetArg(1), - feature_rspecifier = po.GetArg(2), - posteriors_rspecifier = po.GetArg(3), - regtree_filename = po.GetArg(4), - xforms_wspecifier = po.GetArg(5); - - RandomAccessPosteriorReader posteriors_reader(posteriors_rspecifier); - RegtreeMllrDiagGmmWriter mllr_writer(xforms_wspecifier); - - AmDiagGmm am_gmm; - TransitionModel trans_model; - { - bool binary; - Input ki(model_filename, &binary); - trans_model.Read(ki.Stream(), binary); - am_gmm.Read(ki.Stream(), binary); - } - RegressionTree regtree; - { - bool binary; - Input in(regtree_filename, &binary); - regtree.Read(in.Stream(), binary, am_gmm); - } - - RegtreeMllrDiagGmm mllr_xforms; - RegtreeMllrDiagGmmAccs mllr_accs; - mllr_accs.Init(regtree.NumBaseclasses(), am_gmm.Dim()); - - double tot_like = 0.0, tot_t = 0; - - int32 num_done = 0, num_no_posterior = 0, num_other_error = 0; - double tot_objf_impr = 0.0, tot_t_objf = 0.0; - if (spk2utt_rspecifier != "") { // per-speaker adaptation - SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); - RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { - string spk = spk2utt_reader.Key(); - mllr_accs.SetZero(); - const vector &uttlist = spk2utt_reader.Value(); - for (vector::const_iterator utt_itr = uttlist.begin(), - itr_end = uttlist.end(); utt_itr != itr_end; ++utt_itr) { - if (!feature_reader.HasKey(*utt_itr)) { - KALDI_WARN << "Did not find features for utterance " << *utt_itr; - continue; - } - if (!posteriors_reader.HasKey(*utt_itr)) { - KALDI_WARN << "Did not find posteriors for utterance " - << *utt_itr; - num_no_posterior++; - continue; - } - const Matrix &feats = feature_reader.Value(*utt_itr); - const Posterior &posterior = posteriors_reader.Value(*utt_itr); - if (posterior.size() != feats.NumRows()) { - KALDI_WARN << "Posteriors has wrong size " << (posterior.size()) - << " vs. " << (feats.NumRows()); - num_other_error++; - continue; - } - - BaseFloat file_like = 0.0, file_t = 0.0; - Posterior pdf_posterior; - ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior); - for (size_t i = 0; i < posterior.size(); i++) { - for (size_t j = 0; j < pdf_posterior[i].size(); j++) { - int32 pdf_id = pdf_posterior[i][j].first; - BaseFloat prob = pdf_posterior[i][j].second; - file_like += mllr_accs.AccumulateForGmm(regtree, am_gmm, - feats.Row(i), pdf_id, - prob); - file_t += prob; - } - } - KALDI_VLOG(2) << "Average like for this file is " << (file_like/file_t) - << " over " << file_t << " frames."; - tot_like += file_like; - tot_t += file_t; - num_done++; - if (num_done % 10 == 0) - KALDI_VLOG(1) << "Avg like per frame so far is " - << (tot_like / tot_t); - } // end looping over all utterances of the current speaker - BaseFloat objf_impr, t; - mllr_accs.Update(regtree, opts, &mllr_xforms, &objf_impr, &t); - KALDI_LOG << "MLLR objf improvement for speaker " << spk << " is " - << (objf_impr/(t+1.0e-10)) << " per frame over " << t - << " frames."; - tot_objf_impr += objf_impr; - tot_t_objf += t; - mllr_writer.Write(spk, mllr_xforms); - } // end looping over speakers - } else { // per-utterance adaptation - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !feature_reader.Done(); feature_reader.Next()) { - string key = feature_reader.Key(); - if (!posteriors_reader.HasKey(key)) { - KALDI_WARN << "Did not find aligned transcription for utterance " - << key; - num_no_posterior++; - continue; - } - const Matrix &feats = feature_reader.Value(); - const Posterior &posterior = posteriors_reader.Value(key); - - if (posterior.size() != feats.NumRows()) { - KALDI_WARN << "Posteriors has wrong size " << (posterior.size()) - << " vs. " << (feats.NumRows()); - num_other_error++; - continue; - } - - num_done++; - BaseFloat file_like = 0.0, file_t = 0.0; - mllr_accs.SetZero(); - Posterior pdf_posterior; - ConvertPosteriorToPdfs(trans_model, posterior, &pdf_posterior); - for (size_t i = 0; i < posterior.size(); i++) { - for (size_t j = 0; j < pdf_posterior[i].size(); j++) { - int32 pdf_id = pdf_posterior[i][j].first; - BaseFloat prob = pdf_posterior[i][j].second; - file_like += mllr_accs.AccumulateForGmm(regtree, am_gmm, - feats.Row(i), pdf_id, - prob); - file_t += prob; - } - } - KALDI_VLOG(2) << "Average like for this file is " << (file_like/file_t) - << " over " << file_t << " frames."; - tot_like += file_like; - tot_t += file_t; - if (num_done % 10 == 0) - KALDI_VLOG(1) << "Avg like per frame so far is " << (tot_like / tot_t); - BaseFloat objf_impr, t; - mllr_accs.Update(regtree, opts, &mllr_xforms, &objf_impr, &t); - KALDI_LOG << "MLLR objf improvement for utterance " << key << " is " - << (objf_impr/(t+1.0e-10)) << " per frame over " << t - << " frames."; - tot_objf_impr += objf_impr; - tot_t_objf += t; - mllr_writer.Write(feature_reader.Key(), mllr_xforms); - } - } - KALDI_LOG << "Done " << num_done << " files, " << num_no_posterior - << " with no posteriors, " << num_other_error - << " with other errors."; - KALDI_LOG << "Overall objf improvement from MLLR is " << (tot_objf_impr/tot_t_objf) - << " per frame " << " over " << tot_t_objf << " frames."; - KALDI_LOG << "Overall acoustic likelihood was " << (tot_like/tot_t) - << " over " << tot_t << " frames."; - return 0; - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - diff --git a/src/gmmbin/gmm-est-rescale.cc b/src/gmmbin/gmm-est-rescale.cc index a432b3d77f6..1e9c1e2aa84 100644 --- a/src/gmmbin/gmm-est-rescale.cc +++ b/src/gmmbin/gmm-est-rescale.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/indirect-diff-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { using namespace kaldi; @@ -62,7 +62,7 @@ int main(int argc, char *argv[]) { model_wxfilename = po.GetArg(4); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_rxfilename, &binary_read); diff --git a/src/gmmbin/gmm-est-weights-ebw.cc b/src/gmmbin/gmm-est-weights-ebw.cc index f19343a7ac4..9cf2c2d7d04 100644 --- a/src/gmmbin/gmm-est-weights-ebw.cc +++ b/src/gmmbin/gmm-est-weights-ebw.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/ebw-diag-gmm.h" int main(int argc, char *argv[]) { @@ -62,7 +62,7 @@ int main(int argc, char *argv[]) { model_out_filename = po.GetArg(4); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-est.cc b/src/gmmbin/gmm-est.cc index 18c836a1f50..545bbc054ef 100644 --- a/src/gmmbin/gmm-est.cc +++ b/src/gmmbin/gmm-est.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" int main(int argc, char *argv[]) { @@ -79,7 +79,7 @@ int main(int argc, char *argv[]) { model_out_filename = po.GetArg(3); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-fmpe-acc-stats.cc b/src/gmmbin/gmm-fmpe-acc-stats.cc index 4868b63b6ae..17cba7dc489 100644 --- a/src/gmmbin/gmm-fmpe-acc-stats.cc +++ b/src/gmmbin/gmm-fmpe-acc-stats.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmpe.h" @@ -60,7 +60,7 @@ int main(int argc, char *argv[]) { stats_wxfilename = po.GetArg(6); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_rxfilename, &binary); diff --git a/src/gmmbin/gmm-get-stats-deriv.cc b/src/gmmbin/gmm-get-stats-deriv.cc index 939fe260b34..a6fd9764719 100644 --- a/src/gmmbin/gmm-get-stats-deriv.cc +++ b/src/gmmbin/gmm-get-stats-deriv.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/indirect-diff-diag-gmm.h" int main(int argc, char *argv[]) { @@ -64,7 +64,7 @@ int main(int argc, char *argv[]) { deriv_wxfilename = po.GetArg(5); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_rxfilename, &binary_read); diff --git a/src/gmmbin/gmm-global-est-fmllr.cc b/src/gmmbin/gmm-global-est-fmllr.cc index b1d5b68e594..951b8addf2d 100644 --- a/src/gmmbin/gmm-global-est-fmllr.cc +++ b/src/gmmbin/gmm-global-est-fmllr.cc @@ -25,7 +25,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" namespace kaldi { diff --git a/src/gmmbin/gmm-global-est-lvtln-trans.cc b/src/gmmbin/gmm-global-est-lvtln-trans.cc index 10bb5bec5d5..95b56503f2c 100644 --- a/src/gmmbin/gmm-global-est-lvtln-trans.cc +++ b/src/gmmbin/gmm-global-est-lvtln-trans.cc @@ -26,7 +26,7 @@ using std::vector; #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/lvtln.h" #include "hmm/posterior.h" diff --git a/src/gmmbin/gmm-global-info.cc b/src/gmmbin/gmm-global-info.cc index 7c21005b449..00222ef81c3 100644 --- a/src/gmmbin/gmm-global-info.cc +++ b/src/gmmbin/gmm-global-info.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { diff --git a/src/gmmbin/gmm-gselect.cc b/src/gmmbin/gmm-gselect.cc index a873b962591..357998e996d 100644 --- a/src/gmmbin/gmm-gselect.cc +++ b/src/gmmbin/gmm-gselect.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { diff --git a/src/gmmbin/gmm-info.cc b/src/gmmbin/gmm-info.cc index 31f7aea0921..f1c436cd57e 100644 --- a/src/gmmbin/gmm-info.cc +++ b/src/gmmbin/gmm-info.cc @@ -20,7 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { try { @@ -46,7 +46,7 @@ int main(int argc, char *argv[]) { std::string model_in_filename = po.GetArg(1); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-init-biphone.cc b/src/gmmbin/gmm-init-biphone.cc index 0775a5c7b23..10fc9ad4048 100644 --- a/src/gmmbin/gmm-init-biphone.cc +++ b/src/gmmbin/gmm-init-biphone.cc @@ -23,8 +23,8 @@ #include "gmm/am-diag-gmm.h" #include "tree/event-map.h" #include "tree/context-dep.h" -#include "hmm/hmm-topology.h" -#include "hmm/transition-model.h" +#include "hmm/topology.h" +#include "hmm/transitions.h" namespace kaldi { // This function reads a file like: @@ -314,7 +314,7 @@ int main(int argc, char *argv[]) { Vector glob_mean(dim); glob_mean.Set(1.0); - HmmTopology topo; + Topology topo; bool binary_in; Input ki(topo_filename, &binary_in); topo.Read(ki.Stream(), binary_in); @@ -375,7 +375,7 @@ int main(int argc, char *argv[]) { am_gmm.AddPdf(gmm); // Now the transition model: - TransitionModel trans_model(*ctx_dep, topo); + Transitions trans_model(*ctx_dep, topo); { Output ko(model_filename, binary); diff --git a/src/gmmbin/gmm-init-model-flat.cc b/src/gmmbin/gmm-init-model-flat.cc index fecd91f49fd..d41b99c35e6 100644 --- a/src/gmmbin/gmm-init-model-flat.cc +++ b/src/gmmbin/gmm-init-model-flat.cc @@ -21,7 +21,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" #include "tree/build-tree-utils.h" #include "tree/context-dep.h" @@ -104,7 +104,7 @@ int main(int argc, char *argv[]) { ContextDependency ctx_dep; ReadKaldiObject(tree_filename, &ctx_dep); - HmmTopology topo; + Topology topo; ReadKaldiObject(topo_filename, &topo); Vector global_inverse_var, global_mean; @@ -138,7 +138,7 @@ int main(int argc, char *argv[]) { for (int i = 0; i < num_pdfs; i++) am_gmm.AddPdf(gmm); - TransitionModel trans_model(ctx_dep, topo); + Transitions trans_model(ctx_dep, topo); { Output ko(model_out_filename, binary); diff --git a/src/gmmbin/gmm-init-model.cc b/src/gmmbin/gmm-init-model.cc index e2d943b19eb..a081f326b1c 100644 --- a/src/gmmbin/gmm-init-model.cc +++ b/src/gmmbin/gmm-init-model.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" #include "tree/build-tree-utils.h" #include "tree/context-dep.h" @@ -35,7 +35,7 @@ namespace kaldi { void InitAmGmm(const BuildTreeStatsType &stats, const EventMap &to_pdf_map, AmDiagGmm *am_gmm, - const TransitionModel &trans_model, + const Transitions &trans_model, BaseFloat var_floor) { // Get stats split by tree-leaf ( == pdf): std::vector split_stats; @@ -126,7 +126,7 @@ void InitAmGmmFromOld(const BuildTreeStatsType &stats, ContextDependency old_tree; { // Read old_gm_gmm bool binary_in; - TransitionModel old_trans_model; + Transitions old_trans_model; Input ki(old_model_rxfilename, &binary_in); old_trans_model.Read(ki.Stream(), binary_in); old_am_gmm.Read(ki.Stream(), binary_in); @@ -270,12 +270,12 @@ int main(int argc, char *argv[]) { } KALDI_LOG << "Number of separate statistics is " << stats.size(); - HmmTopology topo; + Topology topo; ReadKaldiObject(topo_filename, &topo); const EventMap &to_pdf = ctx_dep.ToPdfMap(); // not owned here. - TransitionModel trans_model(ctx_dep, topo); + Transitions trans_model(ctx_dep, topo); // Now, the summed_stats will be used to initialize the GMM. AmDiagGmm am_gmm; diff --git a/src/gmmbin/gmm-init-mono.cc b/src/gmmbin/gmm-init-mono.cc index 3c370c36515..a91948e446b 100644 --- a/src/gmmbin/gmm-init-mono.cc +++ b/src/gmmbin/gmm-init-mono.cc @@ -21,8 +21,8 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/hmm-topology.h" -#include "hmm/transition-model.h" +#include "hmm/topology.h" +#include "hmm/transitions.h" #include "tree/context-dep.h" namespace kaldi { @@ -116,7 +116,7 @@ int main(int argc, char *argv[]) { glob_mean.CopyFromVec(mean_stats); } - HmmTopology topo; + Topology topo; bool binary_in; Input ki(topo_filename, &binary_in); topo.Read(ki.Stream(), binary_in); @@ -164,7 +164,7 @@ int main(int argc, char *argv[]) { } // Now the transition model: - TransitionModel trans_model(*ctx_dep, topo); + Transitions trans_model(*ctx_dep, topo); { Output ko(model_filename, binary); diff --git a/src/gmmbin/gmm-ismooth-stats.cc b/src/gmmbin/gmm-ismooth-stats.cc index b29e1efc1c3..a524d27b47b 100644 --- a/src/gmmbin/gmm-ismooth-stats.cc +++ b/src/gmmbin/gmm-ismooth-stats.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/ebw-diag-gmm.h" int main(int argc, char *argv[]) { @@ -77,7 +77,7 @@ int main(int argc, char *argv[]) { stats.Write(ko.Stream(), binary_write); } else if (smooth_from_model) { // Smoothing from model... AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; Vector dst_transition_accs; AccumAmDiagGmm dst_stats; { // read src model diff --git a/src/gmmbin/gmm-latgen-biglm-faster.cc b/src/gmmbin/gmm-latgen-biglm-faster.cc index d4e0645b16c..0d881b41ebb 100644 --- a/src/gmmbin/gmm-latgen-biglm-faster.cc +++ b/src/gmmbin/gmm-latgen-biglm-faster.cc @@ -24,7 +24,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/lattice-biglm-faster-decoder.h" #include "gmm/decodable-am-diag-gmm.h" @@ -35,7 +35,7 @@ namespace kaldi { // Takes care of output. Returns true on success. bool DecodeUtterance(LatticeBiglmFasterDecoder &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. - const TransitionModel &trans_model, + const Transitions &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, @@ -186,7 +186,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(7), alignment_wspecifier = po.GetOptArg(8); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-latgen-faster-parallel.cc b/src/gmmbin/gmm-latgen-faster-parallel.cc index 41f414bcb9c..8cc0aa5dad4 100644 --- a/src/gmmbin/gmm-latgen-faster-parallel.cc +++ b/src/gmmbin/gmm-latgen-faster-parallel.cc @@ -24,7 +24,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "gmm/decodable-am-diag-gmm.h" @@ -82,7 +82,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(5), alignment_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-latgen-faster-regtree-fmllr.cc b/src/gmmbin/gmm-latgen-faster-regtree-fmllr.cc deleted file mode 100644 index 36031b13c1e..00000000000 --- a/src/gmmbin/gmm-latgen-faster-regtree-fmllr.cc +++ /dev/null @@ -1,218 +0,0 @@ -// gmmbin/gmm-latgen-faster-regtree-fmllr.cc - -// Copyright 2009-2012 Microsoft Corporation -// 2012-2013 Johns Hopkins University (author: Daniel Povey) -// 2014 Alpha Cephei Inc. - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "gmm/am-diag-gmm.h" -#include "tree/context-dep.h" -#include "hmm/transition-model.h" -#include "fstext/fstext-lib.h" -#include "decoder/decoder-wrappers.h" -#include "gmm/decodable-am-diag-gmm.h" -#include "base/timer.h" -#include "transform/regression-tree.h" -#include "transform/regtree-fmllr-diag-gmm.h" -#include "transform/decodable-am-diag-gmm-regtree.h" -#include "feat/feature-functions.h" // feature reversal - -int main(int argc, char *argv[]) { - try { - using namespace kaldi; - typedef kaldi::int32 int32; - using fst::SymbolTable; - using fst::Fst; - using fst::StdArc; - - const char *usage = - "Generate lattices using GMM-based model and RegTree-FMLLR adaptation.\n" - "Usage: gmm-latgen-faster-regtree-fmllr [options] model-in regtree-in (fst-in|fsts-rspecifier) features-rspecifier transform-rspecifier" - " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; - ParseOptions po(usage); - Timer timer; - bool allow_partial = false; - BaseFloat acoustic_scale = 0.1; - LatticeFasterDecoderConfig config; - - std::string word_syms_filename, utt2spk_rspecifier; - config.Register(&po); - po.Register("utt2spk", &utt2spk_rspecifier, "rspecifier for utterance to " - "speaker map used to load the transform"); - po.Register("acoustic-scale", &acoustic_scale, - "Scaling factor for acoustic likelihoods"); - po.Register("word-symbol-table", &word_syms_filename, - "Symbol table for words [for debug output]"); - po.Register("allow-partial", &allow_partial, - "If true, produce output even if end state was not reached."); - - po.Read(argc, argv); - - if (po.NumArgs() < 4 || po.NumArgs() > 6) { - po.PrintUsage(); - exit(1); - } - - std::string model_in_filename = po.GetArg(1), - regtree_in_str = po.GetArg(2), - fst_in_str = po.GetArg(3), - feature_rspecifier = po.GetArg(4), - xforms_rspecifier = po.GetArg(5), - lattice_wspecifier = po.GetArg(6), - words_wspecifier = po.GetOptArg(7), - alignment_wspecifier = po.GetOptArg(8); - - TransitionModel trans_model; - AmDiagGmm am_gmm; - { - bool binary; - Input ki(model_in_filename, &binary); - trans_model.Read(ki.Stream(), binary); - am_gmm.Read(ki.Stream(), binary); - } - - RegressionTree regtree; - { - bool binary_read; - Input in(regtree_in_str, &binary_read); - regtree.Read(in.Stream(), binary_read, am_gmm); - } - - RandomAccessRegtreeFmllrDiagGmmReaderMapped fmllr_reader(xforms_rspecifier, - utt2spk_rspecifier); - - bool determinize = config.determinize_lattice; - CompactLatticeWriter compact_lattice_writer; - LatticeWriter lattice_writer; - if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) - : lattice_writer.Open(lattice_wspecifier))) - KALDI_ERR << "Could not open table for writing lattices: " - << lattice_wspecifier; - - Int32VectorWriter words_writer(words_wspecifier); - - Int32VectorWriter alignment_writer(alignment_wspecifier); - - fst::SymbolTable *word_syms = NULL; - if (word_syms_filename != "") - if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) - KALDI_ERR << "Could not read symbol table from file " - << word_syms_filename; - - double tot_like = 0.0; - kaldi::int64 frame_count = 0; - int num_done = 0, num_err = 0; - - if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { - SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); - // Input FST is just one FST, not a table of FSTs. - Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); - - { - LatticeFasterDecoder decoder(*decode_fst, config); - - for (; !feature_reader.Done(); feature_reader.Next()) { - std::string utt = feature_reader.Key(); - Matrix features (feature_reader.Value()); - feature_reader.FreeCurrent(); - if (features.NumRows() == 0) { - KALDI_WARN << "Zero-length utterance: " << utt; - num_err++; - continue; - } - if (!fmllr_reader.HasKey(utt)) { - KALDI_WARN << "Not decoding utterance " << utt - << " because no transform available."; - num_err++; - continue; - } - - RegtreeFmllrDiagGmm fmllr(fmllr_reader.Value(utt)); - - kaldi::DecodableAmDiagGmmRegtreeFmllr gmm_decodable(am_gmm, trans_model, - features, fmllr, - regtree, - acoustic_scale); - double like; - if (DecodeUtteranceLatticeFaster( - decoder, gmm_decodable, trans_model, word_syms, utt, acoustic_scale, - determinize, allow_partial, &alignment_writer, &words_writer, - &compact_lattice_writer, &lattice_writer, &like)) { - tot_like += like; - frame_count += features.NumRows(); - num_done++; - } else num_err++; - } - } - delete decode_fst; // delete this only after decoder goes out of scope. - } else { // We have different FSTs for different utterances. - SequentialTableReader fst_reader(fst_in_str); - RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); - for (; !fst_reader.Done(); fst_reader.Next()) { - std::string utt = fst_reader.Key(); - const Matrix &features = feature_reader.Value(utt); - if (features.NumRows() == 0) { - KALDI_WARN << "Zero-length utterance: " << utt; - num_err++; - continue; - } - if (!fmllr_reader.HasKey(utt)) { - KALDI_WARN << "Not decoding utterance " << utt - << " because no transform available."; - num_err++; - continue; - } - - RegtreeFmllrDiagGmm fmllr(fmllr_reader.Value(utt)); - kaldi::DecodableAmDiagGmmRegtreeFmllr gmm_decodable(am_gmm, trans_model, - features, fmllr, - regtree, - acoustic_scale); - - LatticeFasterDecoder decoder(fst_reader.Value(), config); - double like; - if (DecodeUtteranceLatticeFaster( - decoder, gmm_decodable, trans_model, word_syms, utt, acoustic_scale, - determinize, allow_partial, &alignment_writer, &words_writer, - &compact_lattice_writer, &lattice_writer, &like)) { - tot_like += like; - frame_count += features.NumRows(); - num_done++; - } else num_err++; - } - } - - double elapsed = timer.Elapsed(); - KALDI_LOG << "Time taken "<< elapsed - << "s: real-time factor assuming 100 frames/sec is " - << (elapsed*100.0/frame_count); - KALDI_LOG << "Done " << num_done << " utterances, failed for " - << num_err; - KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " - << frame_count << " frames."; - - delete word_syms; - if (num_done != 0) return 0; - else return 1; - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} diff --git a/src/gmmbin/gmm-latgen-faster.cc b/src/gmmbin/gmm-latgen-faster.cc index 6bc475d1b79..75a9d95aacd 100644 --- a/src/gmmbin/gmm-latgen-faster.cc +++ b/src/gmmbin/gmm-latgen-faster.cc @@ -24,7 +24,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "gmm/decodable-am-diag-gmm.h" @@ -72,7 +72,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(5), alignment_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-latgen-map.cc b/src/gmmbin/gmm-latgen-map.cc index 541b031fe6c..b7462b93e0f 100644 --- a/src/gmmbin/gmm-latgen-map.cc +++ b/src/gmmbin/gmm-latgen-map.cc @@ -26,7 +26,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "gmm/mle-am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/fmllr-diag-gmm.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" @@ -84,7 +84,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(6), alignment_wspecifier = po.GetOptArg(7); - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input is(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-latgen-simple.cc b/src/gmmbin/gmm-latgen-simple.cc index 812bee7fef4..d7ffe86c4ae 100644 --- a/src/gmmbin/gmm-latgen-simple.cc +++ b/src/gmmbin/gmm-latgen-simple.cc @@ -24,7 +24,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "gmm/decodable-am-diag-gmm.h" @@ -71,7 +71,7 @@ int main(int argc, char *argv[]) { words_wspecifier = po.GetOptArg(5), alignment_wspecifier = po.GetOptArg(6); - TransitionModel trans_model; + Transitions trans_model; AmDiagGmm am_gmm; { bool binary; diff --git a/src/gmmbin/gmm-make-regtree.cc b/src/gmmbin/gmm-make-regtree.cc deleted file mode 100644 index 8c79d013e0d..00000000000 --- a/src/gmmbin/gmm-make-regtree.cc +++ /dev/null @@ -1,107 +0,0 @@ -// gmmbin/gmm-make-regtree.cc - -// Copyright 2009-2011 Saarland University; Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "util/kaldi-io.h" -#include "util/text-utils.h" -#include "gmm/mle-am-diag-gmm.h" -#include "tree/context-dep.h" -#include "hmm/transition-model.h" -#include "transform/regression-tree.h" - - -int main(int argc, char *argv[]) { - try { - typedef kaldi::int32 int32; - typedef kaldi::BaseFloat BaseFloat; - - const char *usage = - "Build regression class tree.\n" - "Usage: gmm-make-regtree [options] \n" - "E.g.: gmm-make-regtree --silphones=1:2:3 --state-occs=1.occs 1.mdl 1.regtree\n" - " [Note: state-occs come from --write-occs option of gmm-est]\n"; - - std::string occs_in_filename; - std::string sil_phones_str; - bool binary_write = true; - int32 max_leaves = 1; - kaldi::ParseOptions po(usage); - po.Register("state-occs", &occs_in_filename, "File containing state occupancies (use --write-occs in gmm-est)"); - po.Register("sil-phones", &sil_phones_str, "Colon-separated list of integer ids of silence phones, e.g. 1:2:3; if used, create top-level speech/sil split (only one reg-class for silence)."); - po.Register("binary", &binary_write, "Write output in binary mode"); - po.Register("max-leaves", &max_leaves, "Maximum number of leaves in regression tree."); - po.Read(argc, argv); - - if (po.NumArgs() != 2) { - po.PrintUsage(); - exit(1); - } - - std::string model_in_filename = po.GetArg(1), - tree_out_filename = po.GetArg(2); - - kaldi::AmDiagGmm am_gmm; - kaldi::TransitionModel trans_model; - { - bool binary_read; - kaldi::Input ki(model_in_filename, &binary_read); - trans_model.Read(ki.Stream(), binary_read); - am_gmm.Read(ki.Stream(), binary_read); - } - - kaldi::Vector state_occs; - if (occs_in_filename != "") { - bool binary_read; - kaldi::Input ki(occs_in_filename, &binary_read); - state_occs.Read(ki.Stream(), binary_read); - } else { - KALDI_LOG << "--state-occs option not provided so using constant occupancies."; - state_occs.Resize(am_gmm.NumPdfs()); - state_occs.Set(1.0); - } - - std::vector sil_pdfs; - if (sil_phones_str != "") { - std::vector sil_phones; - if (!kaldi::SplitStringToIntegers(sil_phones_str, ":", false, &sil_phones)) - KALDI_ERR << "invalid sil-phones option " << sil_phones_str; - std::sort(sil_phones.begin(), sil_phones.end()); - bool ans = GetPdfsForPhones(trans_model, sil_phones, &sil_pdfs); - if (!ans) - KALDI_WARN << "Pdfs associated with silence phones are not only " - "associated with silence phones: your speech-silence split " - "may not be meaningful."; - } - - kaldi::RegressionTree regtree; - regtree.BuildTree(state_occs, sil_pdfs, am_gmm, max_leaves); - // Write out the regression tree - { - kaldi::Output ko(tree_out_filename, binary_write); - regtree.Write(ko.Stream(), binary_write); - } - - KALDI_LOG << "Written regression tree to " << tree_out_filename; - } catch(const std::exception &e) { - std::cerr << e.what() << '\n'; - return -1; - } -} - - diff --git a/src/gmmbin/gmm-mixup.cc b/src/gmmbin/gmm-mixup.cc index a76b3805d89..51919560b10 100644 --- a/src/gmmbin/gmm-mixup.cc +++ b/src/gmmbin/gmm-mixup.cc @@ -21,7 +21,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "gmm/mle-am-diag-gmm.h" int main(int argc, char *argv[]) { @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) { model_out_filename = po.GetArg(3); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_filename, &binary_read); diff --git a/src/gmmbin/gmm-post-to-gpost.cc b/src/gmmbin/gmm-post-to-gpost.cc index 59da0f9a1ac..1260c9b922a 100644 --- a/src/gmmbin/gmm-post-to-gpost.cc +++ b/src/gmmbin/gmm-post-to-gpost.cc @@ -22,7 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/posterior.h" int main(int argc, char *argv[]) { @@ -56,7 +56,7 @@ int main(int argc, char *argv[]) { typedef kaldi::int32 int32; AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/gmmbin/gmm-rescore-lattice.cc b/src/gmmbin/gmm-rescore-lattice.cc index 54156442e64..36088cac304 100644 --- a/src/gmmbin/gmm-rescore-lattice.cc +++ b/src/gmmbin/gmm-rescore-lattice.cc @@ -22,7 +22,7 @@ #include "util/common-utils.h" #include "util/stl-utils.h" #include "gmm/am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "fstext/fstext-lib.h" #include "lat/kaldi-lattice.h" #include "lat/lattice-functions.h" @@ -61,7 +61,7 @@ int main(int argc, char *argv[]) { lats_wspecifier = po.GetArg(4); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary; Input ki(model_filename, &binary); diff --git a/src/gmmbin/gmm-sum-accs.cc b/src/gmmbin/gmm-sum-accs.cc index c9886e867f5..49146925bab 100644 --- a/src/gmmbin/gmm-sum-accs.cc +++ b/src/gmmbin/gmm-sum-accs.cc @@ -19,7 +19,7 @@ #include "util/common-utils.h" #include "gmm/mle-am-diag-gmm.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" int main(int argc, char *argv[]) { diff --git a/src/gmmbin/gmm-transform-means-global.cc b/src/gmmbin/gmm-transform-means-global.cc index 6b1a6be8330..857b602c19b 100644 --- a/src/gmmbin/gmm-transform-means-global.cc +++ b/src/gmmbin/gmm-transform-means-global.cc @@ -22,7 +22,7 @@ #include "util/common-utils.h" #include "gmm/diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/mllt.h" int main(int argc, char *argv[]) { diff --git a/src/gmmbin/gmm-transform-means.cc b/src/gmmbin/gmm-transform-means.cc index 5c08ec32b10..3a27d73a947 100644 --- a/src/gmmbin/gmm-transform-means.cc +++ b/src/gmmbin/gmm-transform-means.cc @@ -22,7 +22,7 @@ #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "tree/context-dep.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "transform/mllt.h" int main(int argc, char *argv[]) { @@ -55,7 +55,7 @@ int main(int argc, char *argv[]) { ReadKaldiObject(mat_rxfilename, &mat); AmDiagGmm am_gmm; - TransitionModel trans_model; + Transitions trans_model; { bool binary_read; Input ki(model_in_rxfilename, &binary_read); diff --git a/src/gst-plugin/gst-online-gmm-decode-faster.cc b/src/gst-plugin/gst-online-gmm-decode-faster.cc index 958bce41d80..094d398960a 100644 --- a/src/gst-plugin/gst-online-gmm-decode-faster.cc +++ b/src/gst-plugin/gst-online-gmm-decode-faster.cc @@ -389,7 +389,7 @@ gst_online_gmm_decode_faster_allocate(GstOnlineGmmDecodeFaster * filter) { Input ki(filter->lda_mat_rspecifier_, &binary_in); filter->lda_transform_->Read(ki.Stream(), binary_in); } - filter->trans_model_ = new TransitionModel(); + filter->trans_model_ = new Transitions(); filter->am_gmm_ = new AmDiagGmm(); { bool binary; diff --git a/src/gst-plugin/gst-online-gmm-decode-faster.h b/src/gst-plugin/gst-online-gmm-decode-faster.h index b950d1e0a12..529c510115a 100644 --- a/src/gst-plugin/gst-online-gmm-decode-faster.h +++ b/src/gst-plugin/gst-online-gmm-decode-faster.h @@ -65,7 +65,7 @@ struct _GstOnlineGmmDecodeFaster { OnlineFasterDecoder *decoder_; Matrix *lda_transform_; - TransitionModel *trans_model_; + Transitions *trans_model_; AmDiagGmm *am_gmm_; fst::Fst *decode_fst_; fst::SymbolTable *word_syms_; diff --git a/src/hmm/Makefile b/src/hmm/Makefile index 0ad5da74c28..0315a51b214 100644 --- a/src/hmm/Makefile +++ b/src/hmm/Makefile @@ -3,14 +3,13 @@ all: include ../kaldi.mk -TESTFILES = hmm-topology-test hmm-utils-test transition-model-test posterior-test +TESTFILES = topology-test hmm-utils-test transitions-test posterior-test -OBJFILES = hmm-topology.o transition-model.o hmm-utils.o tree-accu.o \ +OBJFILES = topology.o transitions.o hmm-utils.o tree-accu.o \ posterior.o hmm-test-utils.o LIBNAME = kaldi-hmm ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a include ../makefiles/default_rules.mk - diff --git a/src/hmm/hmm-test-utils.cc b/src/hmm/hmm-test-utils.cc index ceca116c828..5f00474219b 100644 --- a/src/hmm/hmm-test-utils.cc +++ b/src/hmm/hmm-test-utils.cc @@ -23,7 +23,7 @@ namespace kaldi { -TransitionModel *GenRandTransitionModel(ContextDependency **ctx_dep_out) { +Transitions *GenRandTransitionModel(ContextDependency **ctx_dep_out) { std::vector phones; phones.push_back(1); for (int32 i = 2; i < 20; i++) @@ -38,16 +38,16 @@ TransitionModel *GenRandTransitionModel(ContextDependency **ctx_dep_out) { GenRandContextDependencyLarge(phones, N, P, true, &num_pdf_classes); - HmmTopology topo = GenRandTopology(phones, num_pdf_classes); + Topology topo = GenRandTopology(phones, num_pdf_classes); - TransitionModel *trans_model = new TransitionModel(*ctx_dep, topo); + Transitions *trans_model = new TransitionModel(*ctx_dep, topo); if (ctx_dep_out == NULL) delete ctx_dep; else *ctx_dep_out = ctx_dep; return trans_model; } -HmmTopology GetDefaultTopology(const std::vector &phones_in) { +Topology GetDefaultTopology(const std::vector &phones_in) { std::vector phones(phones_in); std::sort(phones.begin(), phones.end()); KALDI_ASSERT(IsSortedAndUniq(phones) && !phones.empty()); @@ -76,7 +76,7 @@ HmmTopology GetDefaultTopology(const std::vector &phones_in) { " \n" " \n"; - HmmTopology topo; + Topology topo; std::istringstream iss(topo_string.str()); topo.Read(iss, false); return topo; @@ -84,7 +84,7 @@ HmmTopology GetDefaultTopology(const std::vector &phones_in) { } -HmmTopology GenRandTopology(const std::vector &phones_in, +Topology GenRandTopology(const std::vector &phones_in, const std::vector &num_pdf_classes) { std::vector phones(phones_in); std::sort(phones.begin(), phones.end()); @@ -165,13 +165,13 @@ HmmTopology GenRandTopology(const std::vector &phones_in, } topo_string << "\n"; - HmmTopology topo; + Topology topo; std::istringstream iss(topo_string.str()); topo.Read(iss, false); return topo; } -HmmTopology GenRandTopology() { +Topology GenRandTopology() { std::vector phones; phones.push_back(1); for (int32 i = 2; i < 20; i++) @@ -187,12 +187,12 @@ HmmTopology GenRandTopology() { } } -void GeneratePathThroughHmm(const HmmTopology &topology, +void GeneratePathThroughHmm(const Topology &topology, bool reorder, int32 phone, std::vector > *path) { path->clear(); - const HmmTopology::TopologyEntry &this_entry = + const Topology::TopologyEntry &this_entry = topology.TopologyForPhone(phone); int32 cur_state = 0; // start-state is always state zero. int32 num_states = this_entry.size(), final_state = num_states - 1; @@ -200,7 +200,7 @@ void GeneratePathThroughHmm(const HmmTopology &topology, // that's different from the start state. std::vector > pending_self_loops; while (cur_state != final_state) { - const HmmTopology::HmmState &cur_hmm_state = this_entry[cur_state]; + const Topology::HmmState &cur_hmm_state = this_entry[cur_state]; int32 num_transitions = cur_hmm_state.transitions.size(), transition_index = RandInt(0, num_transitions - 1); if (cur_hmm_state.forward_pdf_class != -1) { @@ -230,7 +230,7 @@ void GeneratePathThroughHmm(const HmmTopology &topology, void GenerateRandomAlignment(const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, bool reorder, const std::vector &phone_sequence, std::vector *alignment) { @@ -253,7 +253,7 @@ void GenerateRandomAlignment(const ContextDependencyInterface &ctx_dep, int32 phone = phone_sequence[i]; GeneratePathThroughHmm(trans_model.GetTopo(), reorder, phone, &path); for (size_t k = 0; k < path.size(); k++) { - const HmmTopology::TopologyEntry &entry = + const Topology::TopologyEntry &entry = trans_model.GetTopo().TopologyForPhone(phone); int32 hmm_state = path[k].first, transition_index = path[k].second, diff --git a/src/hmm/hmm-test-utils.h b/src/hmm/hmm-test-utils.h index 4faaa92fa66..f9f516e7d4c 100644 --- a/src/hmm/hmm-test-utils.h +++ b/src/hmm/hmm-test-utils.h @@ -21,38 +21,38 @@ #ifndef KALDI_HMM_HMM_TEST_UTILS_H_ #define KALDI_HMM_HMM_TEST_UTILS_H_ -#include "hmm/hmm-topology.h" -#include "hmm/transition-model.h" +#include "hmm/topology.h" +#include "hmm/transitions.h" #include "lat/kaldi-lattice.h" #include "tree/context-dep.h" namespace kaldi { -// Here we put a convenience function for generating a TransitionModel object -- +// Here we put a convenience function for generating a Transitions object -- // useful in test code. We may put other testing-related things here in time. -// This function returns a randomly generated TransitionModel object. +// This function returns a randomly generated Transitions object. // If 'ctx_dep' is not NULL, it outputs to *ctx_dep a pointer to the // tree that was used to generate the transition model. -TransitionModel *GenRandTransitionModel(ContextDependency **ctx_dep); +Transitions *GenRandTransitionModel(ContextDependency **ctx_dep); -/// This function returns a HmmTopology object giving a normal 3-state topology, +/// This function returns a Topology object giving a normal 3-state topology, /// covering all phones in the list "phones". This is mainly of use in testing /// code. -HmmTopology GetDefaultTopology(const std::vector &phones); +Topology GetDefaultTopology(const std::vector &phones); -/// This method of generating an arbitrary HmmTopology object allows you to +/// This method of generating an arbitrary Topology object allows you to /// specify the number of pdf-classes for each phone separately. /// 'num_pdf_classes' is indexed by the phone-index (so the length will be /// longer than the length of the 'phones' vector, which for example lacks the /// zero index and may have gaps). -HmmTopology GenRandTopology(const std::vector &phones, +Topology GenRandTopology(const std::vector &phones, const std::vector &num_pdf_classes); /// This version of GenRandTopology() generates the phone list and number of pdf /// classes randomly. -HmmTopology GenRandTopology(); +Topology GenRandTopology(); /// This function generates a random path through the HMM for the given /// phone. The 'path' output is a list of pairs (HMM-state, transition-index) @@ -60,7 +60,7 @@ HmmTopology GenRandTopology(); /// used in other test code. /// the 'reorder' option is as described in the documentation; if true, the /// self-loops from a state are reordered to come after the forward-transition. -void GeneratePathThroughHmm(const HmmTopology &topology, +void GeneratePathThroughHmm(const Topology &topology, bool reorder, int32 phone, std::vector > *path); @@ -69,7 +69,7 @@ void GeneratePathThroughHmm(const HmmTopology &topology, /// For use in test code, this function generates an alignment (a sequence of /// transition-ids) corresponding to a given phone sequence. void GenerateRandomAlignment(const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, bool reorder, const std::vector &phone_sequence, std::vector *alignment); diff --git a/src/hmm/hmm-topology-test.cc b/src/hmm/hmm-topology-test.cc index 14081d2355d..9a3a65b61a4 100644 --- a/src/hmm/hmm-topology-test.cc +++ b/src/hmm/hmm-topology-test.cc @@ -18,13 +18,13 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "hmm/hmm-test-utils.h" namespace kaldi { -void TestHmmTopology() { +void TestTopology() { bool binary = (Rand()%2 == 0); std::string input_str = "\n" @@ -69,7 +69,7 @@ void TestHmmTopology() { "\n" "\n"; - HmmTopology topo; + Topology topo; if (RandInt(0, 1) == 0) { topo = GenRandTopology(); @@ -83,7 +83,7 @@ void TestHmmTopology() { std::ostringstream oss; topo.Write(oss, binary); - HmmTopology topo2; + Topology topo2; // std::cout << oss.str() << '\n' << std::flush; std::istringstream iss2(oss.str()); topo2.Read(iss2, binary); @@ -96,7 +96,7 @@ void TestHmmTopology() { } { // test chain topology - HmmTopology chain_topo; + Topology chain_topo; std::istringstream chain_iss(chain_input_str); chain_topo.Read(chain_iss, false); KALDI_ASSERT(chain_topo.MinLength(3) == 1); @@ -116,7 +116,7 @@ void TestHmmTopology() { int main() { // repeat the test ten times for (int i = 0; i < 10; i++) { - kaldi::TestHmmTopology(); + kaldi::TestTopology(); } std::cout << "Test OK.\n"; } diff --git a/src/hmm/hmm-topology.h b/src/hmm/hmm-topology.h deleted file mode 100644 index 750d35bcfe4..00000000000 --- a/src/hmm/hmm-topology.h +++ /dev/null @@ -1,194 +0,0 @@ -// hmm/hmm-topology.h - -// Copyright 2009-2011 Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_HMM_HMM_TOPOLOGY_H_ -#define KALDI_HMM_HMM_TOPOLOGY_H_ - -#include "base/kaldi-common.h" -#include "util/const-integer-set.h" - - -namespace kaldi { - - -/// \addtogroup hmm_group -/// @{ - -/* - // The following would be the text form for the "normal" HMM topology. - // Note that the first state is the start state, and the final state, - // which must have no output transitions and must be nonemitting, has - // an exit probability of one (no other state can have nonzero exit - // probability; you can treat the transition probability to the final - // state as an exit probability). - // Note also that it's valid to omit the "" entry of the , which - // will mean we won't have a pdf on that state [non-emitting state]. This is equivalent - // to setting the to -1. We do this normally just for the final state. - // The Topology object can have multiple blocks. - // This is useful if there are multiple types of topology in the system. - - - - 1 2 3 4 5 6 7 8 - 0 0 - 0 0.5 - 1 0.5 - - 1 1 - 1 0.5 - 2 0.5 - - 2 2 - 2 0.5 - 3 0.5 - 0.5 - - 3 - - - -*/ - -// kNoPdf is used where pdf_class or pdf would be used, to indicate, -// none is there. Mainly useful in skippable models, but also used -// for end states. -// A caveat with nonemitting states is that their out-transitions -// are not trainable, due to technical issues with the way -// we decided to accumulate the stats. Any transitions arising from (*) -// HMM states with "kNoPdf" as the label are second-class transitions, -// They do not have "transition-states" or "transition-ids" associated -// with them. They are used to create the FST version of the -// HMMs, where they lead to epsilon arcs. -// (*) "arising from" is a bit of a technical term here, due to the way -// (if reorder == true), we put the transition-id associated with the -// outward arcs of the state, on the input transition to the state. - -/// A constant used in the HmmTopology class as the \ref pdf_class "pdf-class" -/// kNoPdf, which is used when a HMM-state is nonemitting (has no associated -/// PDF). - -static const int32 kNoPdf = -1; - -/// A class for storing topology information for phones. See \ref hmm for context. -/// This object is sometimes accessed in a file by itself, but more often -/// as a class member of the Transition class (this is for convenience to reduce -/// the number of files programs have to access). - -class HmmTopology { - public: - /// A structure defined inside HmmTopology to represent a HMM state. - struct HmmState { - /// The \ref pdf_class forward-pdf-class, typically 0, 1 or 2 (the same as the HMM-state index), - /// but may be different to enable us to hardwire sharing of state, and may be - /// equal to \ref kNoPdf == -1 in order to specify nonemitting states (unusual). - int32 forward_pdf_class; - - /// The \ref pdf_class self-loop pdf-class, similar to \ref pdf_class forward-pdf-class. - /// They will either both be \ref kNoPdf, or neither be \ref kNoPdf. - int32 self_loop_pdf_class; - - /// A list of transitions, indexed by what we call a 'transition-index'. - /// The first member of each pair is the index of the next HmmState, and the - /// second is the default transition probability (before training). - std::vector > transitions; - - explicit HmmState(int32 pdf_class) { - this->forward_pdf_class = pdf_class; - this->self_loop_pdf_class = pdf_class; - } - explicit HmmState(int32 forward_pdf_class, int32 self_loop_pdf_class) { - KALDI_ASSERT((forward_pdf_class != kNoPdf && self_loop_pdf_class != kNoPdf) || - (forward_pdf_class == kNoPdf && self_loop_pdf_class == kNoPdf)); - this->forward_pdf_class = forward_pdf_class; - this->self_loop_pdf_class = self_loop_pdf_class; - } - - bool operator == (const HmmState &other) const { - return (forward_pdf_class == other.forward_pdf_class && - self_loop_pdf_class == other.self_loop_pdf_class && - transitions == other.transitions); - } - - HmmState(): forward_pdf_class(-1), self_loop_pdf_class(-1) { } - }; - - /// TopologyEntry is a typedef that represents the topology of - /// a single (prototype) state. - typedef std::vector TopologyEntry; - - void Read(std::istream &is, bool binary); - void Write(std::ostream &os, bool binary) const; - - // Checks that the object is valid, and throw exception otherwise. - void Check(); - - /// Returns true if this HmmTopology is really 'hmm-like', i.e. the pdf-class on - /// the self-loops and forward transitions of all states are identical. [note: in HMMs, - /// the densities are associated with the states.] We have extended this to - /// support 'non-hmm-like' topologies (where those pdf-classes are different), - /// in order to make for more compact decoding graphs in our so-called 'chain models' - /// (AKA lattice-free MMI), where we use 1-state topologies that have different pdf-classes - /// for the self-loop and the forward transition. Note that we always use the 'reorder=true' - /// option so the 'forward transition' actually comes before the self-loop. - bool IsHmm() const; - - /// Returns the topology entry (i.e. vector of HmmState) for this phone; - /// will throw exception if phone not covered by the topology. - const TopologyEntry &TopologyForPhone(int32 phone) const; - - /// Returns the number of \ref pdf_class "pdf-classes" for this phone; - /// throws exception if phone not covered by this topology. - int32 NumPdfClasses(int32 phone) const; - - /// Returns a reference to a sorted, unique list of phones covered by - /// the topology (these phones will be positive integers, and usually - /// contiguous and starting from one but the toolkit doesn't assume - /// they are contiguous). - const std::vector &GetPhones() const { return phones_; }; - - /// Outputs a vector of int32, indexed by phone, that gives the - /// number of \ref pdf_class pdf-classes for the phones; this is - /// used by tree-building code such as BuildTree(). - void GetPhoneToNumPdfClasses(std::vector *phone2num_pdf_classes) const; - - // Returns the minimum number of frames it takes to traverse this model for - // this phone: e.g. 3 for the normal HMM topology. - int32 MinLength(int32 phone) const; - - HmmTopology() {} - - bool operator == (const HmmTopology &other) const { - return phones_ == other.phones_ && phone2idx_ == other.phone2idx_ - && entries_ == other.entries_; - } - // Allow default assignment operator and copy constructor. - private: - std::vector phones_; // list of all phones we have topology for. Sorted, uniq. no epsilon (zero) phone. - std::vector phone2idx_; // map from phones to indexes into the entries vector (or -1 for not present). - std::vector entries_; -}; - - -/// @} end "addtogroup hmm_group" - - -} // end namespace kaldi - - -#endif diff --git a/src/hmm/hmm-utils-test.cc b/src/hmm/hmm-utils-test.cc index 69728cc8ca7..cf282ac03c5 100644 --- a/src/hmm/hmm-utils-test.cc +++ b/src/hmm/hmm-utils-test.cc @@ -202,7 +202,7 @@ void TestAccumulateTreeStatsOptions() { void TestSplitToPhones() { ContextDependency *ctx_dep = NULL; - TransitionModel *trans_model = GenRandTransitionModel(&ctx_dep); + Transitions *trans_model = GenRandTransitionModel(&ctx_dep); std::vector phone_seq; int32 num_phones = RandInt(0, 10); const std::vector &phone_list = trans_model->GetPhones(); @@ -273,11 +273,11 @@ void TestConvertAlignment() { } - HmmTopology topo_old = GenRandTopology(phones, num_pdf_classes_old), + Topology topo_old = GenRandTopology(phones, num_pdf_classes_old), topo_new = (new_topology ? GenRandTopology(phones, num_pdf_classes_new) : topo_old); - TransitionModel trans_model_old(*ctx_dep_old, topo_old), + Transitions trans_model_old(*ctx_dep_old, topo_old), trans_model_new(*ctx_dep_new, topo_new); std::vector phone_sequence; diff --git a/src/hmm/hmm-utils.cc b/src/hmm/hmm-utils.cc index 06edf8d5976..a70dc5275c2 100644 --- a/src/hmm/hmm-utils.cc +++ b/src/hmm/hmm-utils.cc @@ -32,7 +32,7 @@ namespace kaldi { fst::VectorFst *GetHmmAsFsa( std::vector phone_window, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const HTransducerConfig &config, HmmCacheType *cache) { using namespace fst; @@ -48,8 +48,8 @@ fst::VectorFst *GetHmmAsFsa( KALDI_ERR << "phone == 0. Some mismatch happened, or there is " "a code error."; - const HmmTopology &topo = trans_model.GetTopo(); - const HmmTopology::TopologyEntry &entry = topo.TopologyForPhone(phone); + const Topology &topo = trans_model.GetTopo(); + const Topology::TopologyEntry &entry = topo.TopologyForPhone(phone); // vector of the pdfs, indexed by pdf-class (pdf-classes must start from zero // and be contiguous). @@ -154,7 +154,7 @@ fst::VectorFst *GetHmmAsFsa( fst::VectorFst* GetHmmAsFsaSimple(std::vector phone_window, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, BaseFloat prob_scale) { using namespace fst; @@ -167,8 +167,8 @@ GetHmmAsFsaSimple(std::vector phone_window, int32 phone = phone_window[P]; KALDI_ASSERT(phone != 0); - const HmmTopology &topo = trans_model.GetTopo(); - const HmmTopology::TopologyEntry &entry = topo.TopologyForPhone(phone); + const Topology &topo = trans_model.GetTopo(); + const Topology::TopologyEntry &entry = topo.TopologyForPhone(phone); VectorFst *ans = new VectorFst; @@ -253,7 +253,7 @@ static inline fst::VectorFst *MakeTrivialAcceptor(int32 label) { // The H transducer has a separate outgoing arc for each of the symbols in ilabel_info. fst::VectorFst *GetHTransducer(const std::vector > &ilabel_info, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const HTransducerConfig &config, std::vector *disambig_syms_left) { KALDI_ASSERT(ilabel_info.size() >= 1 && ilabel_info[0].size() == 0); // make sure that eps == eps. @@ -334,7 +334,7 @@ fst::VectorFst *GetHTransducer(const std::vector void GetIlabelMapping (const std::vector > &ilabel_info_old, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, std::vector *old2new_map) { KALDI_ASSERT(old2new_map != NULL); @@ -404,7 +404,7 @@ void GetIlabelMapping (const std::vector > &ilabel_info_old, -fst::VectorFst *GetPdfToTransitionIdTransducer(const TransitionModel &trans_model) { +fst::VectorFst *GetPdfToTransitionIdTransducer(const Transitions &trans_model) { using namespace fst; VectorFst *ans = new VectorFst; typedef VectorFst::Weight Weight; @@ -437,7 +437,7 @@ class TidToTstateMapper { // with values over 100000/kNontermBigNumber) to zero. // Its point is to provide an equivalence class on labels that's relevant to what // the self-loop will be on the following (or preceding) state. - TidToTstateMapper(const TransitionModel &trans_model, + TidToTstateMapper(const Transitions &trans_model, const std::vector &disambig_syms, bool check_no_self_loops): trans_model_(trans_model), @@ -461,7 +461,7 @@ class TidToTstateMapper { } private: - const TransitionModel &trans_model_; + const Transitions &trans_model_; const std::vector &disambig_syms_; // sorted. bool check_no_self_loops_; }; @@ -469,7 +469,7 @@ class TidToTstateMapper { // This is the code that expands an FST from transition-states to // transition-ids, in the case where reorder == true, i.e. the non-optional // transition is before the self-loop. -static void AddSelfLoopsReorder(const TransitionModel &trans_model, +static void AddSelfLoopsReorder(const Transitions &trans_model, const std::vector &disambig_syms, BaseFloat self_loop_scale, bool check_no_self_loops, @@ -553,7 +553,7 @@ static void AddSelfLoopsReorder(const TransitionModel &trans_model, // transition-ids, in the case where reorder == false, i.e. non-optional // transition is after the self-loop. static void AddSelfLoopsNoReorder( - const TransitionModel &trans_model, + const Transitions &trans_model, const std::vector &disambig_syms, BaseFloat self_loop_scale, bool check_no_self_loops, @@ -599,7 +599,7 @@ static void AddSelfLoopsNoReorder( } } -void AddSelfLoops(const TransitionModel &trans_model, +void AddSelfLoops(const Transitions &trans_model, const std::vector &disambig_syms, BaseFloat self_loop_scale, bool reorder, @@ -622,7 +622,7 @@ void AddSelfLoops(const TransitionModel &trans_model, // code doesn't care what the answer is. // The "alignment" vector contains a sequence of TransitionIds. -static bool IsReordered(const TransitionModel &trans_model, +static bool IsReordered(const Transitions &trans_model, const std::vector &alignment) { for (size_t i = 0; i + 1 < alignment.size(); i++) { int32 tstate1 = trans_model.TransitionIdToTransitionState(alignment[i]), @@ -656,7 +656,7 @@ static bool IsReordered(const TransitionModel &trans_model, // checks (if the input does not start at the start of a phone or does not // end at the end of a phone, we should expect that false will be returned). -static bool SplitToPhonesInternal(const TransitionModel &trans_model, +static bool SplitToPhonesInternal(const Transitions &trans_model, const std::vector &alignment, bool reordered, std::vector > *split_output) { @@ -720,7 +720,7 @@ static bool SplitToPhonesInternal(const TransitionModel &trans_model, } -bool SplitToPhones(const TransitionModel &trans_model, +bool SplitToPhones(const Transitions &trans_model, const std::vector &alignment, std::vector > *split_alignment) { KALDI_ASSERT(split_alignment != NULL); @@ -740,8 +740,8 @@ bool SplitToPhones(const TransitionModel &trans_model, 'subsample' value is not 1). */ static inline void ConvertAlignmentForPhone( - const TransitionModel &old_trans_model, - const TransitionModel &new_trans_model, + const Transitions &old_trans_model, + const Transitions &new_trans_model, const ContextDependencyInterface &new_ctx_dep, const std::vector &old_phone_alignment, const std::vector &new_phone_window, @@ -754,7 +754,7 @@ static inline void ConvertAlignmentForPhone( old_central_phone = old_trans_model.TransitionIdToPhone( old_phone_alignment[0]), new_central_phone = new_phone_window[P]; - const HmmTopology &old_topo = old_trans_model.GetTopo(), + const Topology &old_topo = old_trans_model.GetTopo(), &new_topo = new_trans_model.GetTopo(); bool topology_mismatch = !(old_topo.TopologyForPhone(old_central_phone) == @@ -846,7 +846,7 @@ static inline void ConvertAlignmentForPhone( reduced-frame-rate system. @param new_lengths [out] The vector for storing new lengths. */ -static bool ComputeNewPhoneLengths(const HmmTopology &topology, +static bool ComputeNewPhoneLengths(const Topology &topology, const std::vector &mapped_phones, const std::vector &old_lengths, int32 conversion_shift, @@ -923,8 +923,8 @@ static bool ComputeNewPhoneLengths(const HmmTopology &topology, 'conversion_shift' is for. */ -static bool ConvertAlignmentInternal(const TransitionModel &old_trans_model, - const TransitionModel &new_trans_model, +static bool ConvertAlignmentInternal(const Transitions &old_trans_model, + const Transitions &new_trans_model, const ContextDependencyInterface &new_ctx_dep, const std::vector &old_alignment, int32 conversion_shift, @@ -1010,8 +1010,8 @@ static bool ConvertAlignmentInternal(const TransitionModel &old_trans_model, return true; } -bool ConvertAlignment(const TransitionModel &old_trans_model, - const TransitionModel &new_trans_model, +bool ConvertAlignment(const Transitions &old_trans_model, + const Transitions &new_trans_model, const ContextDependencyInterface &new_ctx_dep, const std::vector &old_alignment, int32 subsample_factor, @@ -1062,7 +1062,7 @@ bool ConvertAlignment(const TransitionModel &old_trans_model, } // Returns the scaled, but not negated, log-prob, with the given scaling factors. -static BaseFloat GetScaledTransitionLogProb(const TransitionModel &trans_model, +static BaseFloat GetScaledTransitionLogProb(const Transitions &trans_model, int32 trans_id, BaseFloat transition_scale, BaseFloat self_loop_scale) { @@ -1085,7 +1085,7 @@ static BaseFloat GetScaledTransitionLogProb(const TransitionModel &trans_model, -void AddTransitionProbs(const TransitionModel &trans_model, +void AddTransitionProbs(const Transitions &trans_model, const std::vector &disambig_syms, // may be empty BaseFloat transition_scale, BaseFloat self_loop_scale, @@ -1118,7 +1118,7 @@ void AddTransitionProbs(const TransitionModel &trans_model, } } -void AddTransitionProbs(const TransitionModel &trans_model, +void AddTransitionProbs(const Transitions &trans_model, BaseFloat transition_scale, BaseFloat self_loop_scale, Lattice *lat) { @@ -1205,7 +1205,7 @@ bool ConvertPhnxToProns(const std::vector &phnx, void GetRandomAlignmentForPhone(const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const std::vector &phone_window, std::vector *alignment) { typedef fst::StdArc Arc; @@ -1257,7 +1257,7 @@ void GetRandomAlignmentForPhone(const ContextDependencyInterface &ctx_dep, delete fst; } -void ChangeReorderingOfAlignment(const TransitionModel &trans_model, +void ChangeReorderingOfAlignment(const Transitions &trans_model, std::vector *alignment) { int32 start_pos = 0, size = alignment->size(); while (start_pos != size) { diff --git a/src/hmm/hmm-utils.h b/src/hmm/hmm-utils.h index a8ad846949e..9cefa557bb3 100644 --- a/src/hmm/hmm-utils.h +++ b/src/hmm/hmm-utils.h @@ -20,8 +20,8 @@ #ifndef KALDI_HMM_HMM_UTILS_H_ #define KALDI_HMM_HMM_UTILS_H_ -#include "hmm/hmm-topology.h" -#include "hmm/transition-model.h" +#include "hmm/topology.h" +#include "hmm/transitions.h" #include "lat/kaldi-lattice.h" namespace kaldi { @@ -93,7 +93,7 @@ typedef unordered_map >, fst::VectorFst *GetHmmAsFsa( std::vector context_window, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const HTransducerConfig &config, HmmCacheType *cache = NULL); @@ -104,7 +104,7 @@ fst::VectorFst *GetHmmAsFsa( fst::VectorFst* GetHmmAsFsaSimple(std::vector context_window, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, BaseFloat prob_scale); @@ -126,7 +126,7 @@ GetHmmAsFsaSimple(std::vector context_window, fst::VectorFst* GetHTransducer(const std::vector > &ilabel_info, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const HTransducerConfig &config, std::vector *disambig_syms_left); @@ -148,7 +148,7 @@ GetHTransducer(const std::vector > &ilabel_info, */ void GetIlabelMapping(const std::vector > &ilabel_info_old, const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, std::vector *old2new_map); @@ -182,7 +182,7 @@ void GetIlabelMapping(const std::vector > &ilabel_info_old, * which emulates the behavior of older code. * @param fst [in, out] The FST to be modified. */ -void AddSelfLoops(const TransitionModel &trans_model, +void AddSelfLoops(const Transitions &trans_model, const std::vector &disambig_syms, // used as a check only. BaseFloat self_loop_scale, bool reorder, @@ -206,7 +206,7 @@ void AddSelfLoops(const TransitionModel &trans_model, * see \ref hmm_scale. * @param fst [in, out] The FST to be modified. */ -void AddTransitionProbs(const TransitionModel &trans_model, +void AddTransitionProbs(const Transitions &trans_model, const std::vector &disambig_syms, BaseFloat transition_scale, BaseFloat self_loop_scale, @@ -216,7 +216,7 @@ void AddTransitionProbs(const TransitionModel &trans_model, This is as AddSelfLoops(), but operates on a Lattice, where it affects the graph part of the weight (the first element of the pair). */ -void AddTransitionProbs(const TransitionModel &trans_model, +void AddTransitionProbs(const Transitions &trans_model, BaseFloat transition_scale, BaseFloat self_loop_scale, Lattice *lat); @@ -225,11 +225,11 @@ void AddTransitionProbs(const TransitionModel &trans_model, /// Returns a transducer from pdfs plus one (input) to transition-ids (output). /// Currenly of use only for testing. fst::VectorFst* -GetPdfToTransitionIdTransducer(const TransitionModel &trans_model); +GetPdfToTransitionIdTransducer(const Transitions &trans_model); /// Converts all transition-ids in the FST to pdfs plus one. /// Placeholder: not implemented yet! -void ConvertTransitionIdsToPdfs(const TransitionModel &trans_model, +void ConvertTransitionIdsToPdfs(const Transitions &trans_model, const std::vector &disambig_syms, fst::VectorFst *fst); @@ -248,7 +248,7 @@ void ConvertTransitionIdsToPdfs(const TransitionModel &trans_model, /// die or throw an exception. /// This function works out by itself whether the graph was created /// with "reordering", and just does the right thing. -bool SplitToPhones(const TransitionModel &trans_model, +bool SplitToPhones(const Transitions &trans_model, const std::vector &alignment, std::vector > *split_alignment); @@ -279,13 +279,13 @@ bool SplitToPhones(const TransitionModel &trans_model, the same as the input where possible.] @param reorder [in] True if you want the pdf-ids on the new alignment to be 'reordered'. (vs. the way they appear in - the HmmTopology object) + the Topology object) @param phone_map [in] If non-NULL, map from old to new phones. @param new_alignment [out] The converted alignment. */ -bool ConvertAlignment(const TransitionModel &old_trans_model, - const TransitionModel &new_trans_model, +bool ConvertAlignment(const Transitions &old_trans_model, + const Transitions &new_trans_model, const ContextDependencyInterface &new_ctx_dep, const std::vector &old_alignment, int32 subsample_factor, // 1 in the normal case -> no subsampling. @@ -319,14 +319,14 @@ bool ConvertPhnxToProns(const std::vector &phnx, The alignment will be without 'reordering'. */ void GetRandomAlignmentForPhone(const ContextDependencyInterface &ctx_dep, - const TransitionModel &trans_model, + const Transitions &trans_model, const std::vector &phone_window, std::vector *alignment); /* If the alignment was non-reordered makes it reordered, and vice versa. */ -void ChangeReorderingOfAlignment(const TransitionModel &trans_model, +void ChangeReorderingOfAlignment(const Transitions &trans_model, std::vector *alignment); /// @} end "addtogroup hmm_group" diff --git a/src/hmm/posterior.cc b/src/hmm/posterior.cc index 860a979a0ce..3089be237b2 100644 --- a/src/hmm/posterior.cc +++ b/src/hmm/posterior.cc @@ -299,8 +299,8 @@ void AlignmentToPosterior(const std::vector &ali, } struct ComparePosteriorByPdfs { - const TransitionModel *tmodel_; - ComparePosteriorByPdfs(const TransitionModel &tmodel): tmodel_(&tmodel) {} + const Transitions *tmodel_; + ComparePosteriorByPdfs(const Transitions &tmodel): tmodel_(&tmodel) {} bool operator() (const std::pair &a, const std::pair &b) { if (tmodel_->TransitionIdToPdf(a.first) @@ -311,7 +311,7 @@ struct ComparePosteriorByPdfs { } }; -void SortPosteriorByPdfs(const TransitionModel &tmodel, +void SortPosteriorByPdfs(const Transitions &tmodel, Posterior *post) { ComparePosteriorByPdfs compare(tmodel); for (size_t i = 0; i < post->size(); i++) { @@ -319,7 +319,7 @@ void SortPosteriorByPdfs(const TransitionModel &tmodel, } } -void ConvertPosteriorToPdfs(const TransitionModel &tmodel, +void ConvertPosteriorToPdfs(const Transitions &tmodel, const Posterior &post_in, Posterior *post_out) { post_out->clear(); @@ -345,7 +345,7 @@ void ConvertPosteriorToPdfs(const TransitionModel &tmodel, } } -void ConvertPosteriorToPhones(const TransitionModel &tmodel, +void ConvertPosteriorToPhones(const Transitions &tmodel, const Posterior &post_in, Posterior *post_out) { post_out->clear(); @@ -372,7 +372,7 @@ void ConvertPosteriorToPhones(const TransitionModel &tmodel, } -void WeightSilencePost(const TransitionModel &trans_model, +void WeightSilencePost(const Transitions &trans_model, const ConstIntegerSet &silence_set, BaseFloat silence_scale, Posterior *post) { @@ -395,7 +395,7 @@ void WeightSilencePost(const TransitionModel &trans_model, } -void WeightSilencePostDistributed(const TransitionModel &trans_model, +void WeightSilencePostDistributed(const Transitions &trans_model, const ConstIntegerSet &silence_set, BaseFloat silence_scale, Posterior *post) { @@ -537,7 +537,7 @@ template void PosteriorToMatrix(const Posterior &post, template void PosteriorToPdfMatrix(const Posterior &post, - const TransitionModel &model, + const Transitions &model, Matrix *mat) { // Allocate the matrix, int32 num_rows = post.size(), @@ -557,10 +557,10 @@ void PosteriorToPdfMatrix(const Posterior &post, } // instantiate the template function, template void PosteriorToPdfMatrix(const Posterior &post, - const TransitionModel &model, + const Transitions &model, Matrix *mat); template void PosteriorToPdfMatrix(const Posterior &post, - const TransitionModel &model, + const Transitions &model, Matrix *mat); } // End namespace kaldi diff --git a/src/hmm/posterior.h b/src/hmm/posterior.h index e153c249740..7663cf0ce42 100644 --- a/src/hmm/posterior.h +++ b/src/hmm/posterior.h @@ -26,7 +26,7 @@ #include "base/kaldi-common.h" #include "util/const-integer-set.h" #include "util/kaldi-table.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "matrix/kaldi-matrix.h" @@ -205,19 +205,19 @@ void AlignmentToPosterior(const std::vector &ali, /// Sorts posterior entries so that transition-ids with same pdf-id are next to /// each other. -void SortPosteriorByPdfs(const TransitionModel &tmodel, +void SortPosteriorByPdfs(const Transitions &tmodel, Posterior *post); /// Converts a posterior over transition-ids to be a posterior /// over pdf-ids. -void ConvertPosteriorToPdfs(const TransitionModel &tmodel, +void ConvertPosteriorToPdfs(const Transitions &tmodel, const Posterior &post_in, Posterior *post_out); /// Converts a posterior over transition-ids to be a posterior /// over phones. -void ConvertPosteriorToPhones(const TransitionModel &tmodel, +void ConvertPosteriorToPhones(const Transitions &tmodel, const Posterior &post_in, Posterior *post_out); @@ -225,7 +225,7 @@ void ConvertPosteriorToPhones(const TransitionModel &tmodel, /// in the set "silence_set" by scale "silence_scale". /// The interface was changed in Feb 2014 to do the modification /// "in-place" rather than having separate input and output. -void WeightSilencePost(const TransitionModel &trans_model, +void WeightSilencePost(const Transitions &trans_model, const ConstIntegerSet &silence_set, BaseFloat silence_scale, Posterior *post); @@ -236,7 +236,7 @@ void WeightSilencePost(const TransitionModel &trans_model, /// has the effect that frames that are mostly silence get down-weighted. /// The interface was changed in Feb 2014 to do the modification /// "in-place" rather than having separate input and output. -void WeightSilencePostDistributed(const TransitionModel &trans_model, +void WeightSilencePostDistributed(const Transitions &trans_model, const ConstIntegerSet &silence_set, BaseFloat silence_scale, Posterior *post); @@ -250,11 +250,11 @@ void PosteriorToMatrix(const Posterior &post, /// This converts a Posterior to a Matrix. The number of matrix-rows is the same /// as the 'post.size()', the number of matrix-columns is defined by 'NumPdfs' -/// in the TransitionModel. +/// in the Transitions. /// The elements which are not specified in 'Posterior' are equal to zero. template void PosteriorToPdfMatrix(const Posterior &post, - const TransitionModel &model, + const Transitions &model, Matrix *mat); /// @} end "addtogroup posterior_group" diff --git a/src/hmm/hmm-topology.cc b/src/hmm/topology.cc similarity index 89% rename from src/hmm/hmm-topology.cc rename to src/hmm/topology.cc index cf134065dbf..a0563f90c0d 100644 --- a/src/hmm/hmm-topology.cc +++ b/src/hmm/topology.cc @@ -1,7 +1,7 @@ -// hmm/hmm-topology.cc +// hmm/topology.cc // Copyright 2009-2011 Microsoft Corporation -// 2014 Johns Hopkins University (author: Daniel Povey) +// 2014-2019 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -20,7 +20,7 @@ #include -#include "hmm/hmm-topology.h" +#include "hmm/topology.h" #include "util/text-utils.h" @@ -28,7 +28,7 @@ namespace kaldi { -void HmmTopology::GetPhoneToNumPdfClasses(std::vector *phone2num_pdf_classes) const { +void Topology::GetPhoneToNumPdfClasses(std::vector *phone2num_pdf_classes) const { KALDI_ASSERT(!phones_.empty()); phone2num_pdf_classes->clear(); phone2num_pdf_classes->resize(phones_.back() + 1, -1); @@ -36,7 +36,7 @@ void HmmTopology::GetPhoneToNumPdfClasses(std::vector *phone2num_pdf_clas (*phone2num_pdf_classes)[phones_[i]] = NumPdfClasses(phones_[i]); } -void HmmTopology::Read(std::istream &is, bool binary) { +void Topology::Read(std::istream &is, bool binary) { ExpectToken(is, binary, ""); if (!binary) { // Text-mode read, different "human-readable" format. phones_.clear(); @@ -46,19 +46,19 @@ void HmmTopology::Read(std::istream &is, bool binary) { while ( ! (is >> token).fail() ) { if (token == "") { break; } // finished parsing. else if (token != "") { - KALDI_ERR << "Reading HmmTopology object, expected or , got "< or , got "<"); std::vector phones; std::string s; while (1) { is >> s; - if (is.fail()) KALDI_ERR << "Reading HmmTopology object, unexpected end of file while expecting phones."; + if (is.fail()) KALDI_ERR << "Reading Topology object, unexpected end of file while expecting phones."; if (s == "") break; else { int32 phone; if (!ConvertStringToInteger(s, &phone)) - KALDI_ERR << "Reading HmmTopology object, expected " + KALDI_ERR << "Reading Topology object, expected " << "integer, got instead " << s; phones.push_back(phone); } @@ -105,7 +105,7 @@ void HmmTopology::Read(std::istream &is, bool binary) { if(token == "") // TODO: remove this clause after a while. KALDI_ERR << "You are trying to read old-format topology with new Kaldi."; if (token != "") - KALDI_ERR << "Reading HmmTopology, unexpected token "<"); if (!binary) { // Text-mode write. @@ -228,15 +228,15 @@ void HmmTopology::Write(std::ostream &os, bool binary) const { if (!binary) os << "\n"; } -void HmmTopology::Check() { +void Topology::Check() { if (entries_.empty() || phones_.empty() || phone2idx_.empty()) - KALDI_ERR << "HmmTopology::Check(), empty object."; + KALDI_ERR << "Topology::Check(), empty object."; std::vector is_seen(entries_.size(), false); for (size_t i = 0; i < phones_.size(); i++) { int32 phone = phones_[i]; if (static_cast(phone) >= phone2idx_.size() || static_cast(phone2idx_[phone]) >= entries_.size()) - KALDI_ERR << "HmmTopology::Check(), phone has no valid index."; + KALDI_ERR << "Topology::Check(), phone has no valid index."; is_seen[phone2idx_[phone]] = true; } for (size_t i = 0; i < entries_.size(); i++) { @@ -244,13 +244,13 @@ void HmmTopology::Check() { KALDI_ERR << "HmmTopoloy::Check(), entry with no corresponding phones."; int32 num_states = static_cast(entries_[i].size()); if (num_states <= 1) - KALDI_ERR << "HmmTopology::Check(), cannot only have one state (i.e., must " + KALDI_ERR << "Topology::Check(), cannot only have one state (i.e., must " "have at least one emitting state)."; if (!entries_[i][num_states-1].transitions.empty()) - KALDI_ERR << "HmmTopology::Check(), last state must have no transitions."; + KALDI_ERR << "Topology::Check(), last state must have no transitions."; // not sure how necessary this next stipulation is. if (entries_[i][num_states-1].forward_pdf_class != kNoPdf) - KALDI_ERR << "HmmTopology::Check(), last state must not be emitting."; + KALDI_ERR << "Topology::Check(), last state must not be emitting."; std::vector has_trans_in(num_states, false); std::vector seen_pdf_classes; @@ -267,7 +267,7 @@ void HmmTopology::Check() { k++) { tot_prob += entries_[i][j].transitions[k].second; if (entries_[i][j].transitions[k].second <= 0.0) - KALDI_ERR << "HmmTopology::Check(), negative or zero transition prob."; + KALDI_ERR << "Topology::Check(), negative or zero transition prob."; int32 dst_state = entries_[i][j].transitions[k].first; // The commented code in the next few lines disallows a completely // skippable phone, as this would cause to stop working some mechanisms @@ -280,9 +280,9 @@ void HmmTopology::Check() { "stop the SplitToPhones function from identifying the last state " "of a phone."; if (dst_state < 0 || dst_state >= num_states) - KALDI_ERR << "HmmTopology::Check(), invalid dest state " << (dst_state); + KALDI_ERR << "Topology::Check(), invalid dest state " << (dst_state); if (seen_transition.count(dst_state) != 0) - KALDI_ERR << "HmmTopology::Check(), duplicate transition found."; + KALDI_ERR << "Topology::Check(), duplicate transition found."; if (dst_state == k) { // self_loop... KALDI_ASSERT(entries_[i][j].self_loop_pdf_class != kNoPdf && "Nonemitting states cannot have self-loops."); @@ -302,17 +302,17 @@ void HmmTopology::Check() { // make sure all but start state have input transitions. for (int32 j = 1; j < num_states; j++) if (!has_trans_in[j]) - KALDI_ERR << "HmmTopology::Check, state "<<(j)<<" has no input transitions."; + KALDI_ERR << "Topology::Check, state "<<(j)<<" has no input transitions."; SortAndUniq(&seen_pdf_classes); if (seen_pdf_classes.front() != 0 || seen_pdf_classes.back() != static_cast(seen_pdf_classes.size()) - 1) { - KALDI_ERR << "HmmTopology::Check(), pdf_classes are expected to be " + KALDI_ERR << "Topology::Check(), pdf_classes are expected to be " "contiguous and start from zero."; } } } -bool HmmTopology::IsHmm() const { +bool Topology::IsHmm() const { const std::vector &phones = GetPhones(); KALDI_ASSERT(!phones.empty()); for (size_t i = 0; i < phones.size(); i++) { @@ -328,14 +328,14 @@ bool HmmTopology::IsHmm() const { return true; } -const HmmTopology::TopologyEntry& HmmTopology::TopologyForPhone(int32 phone) const { // Will throw if phone not covered. +const Topology::TopologyEntry& HmmTopology::TopologyForPhone(int32 phone) const { // Will throw if phone not covered. if (static_cast(phone) >= phone2idx_.size() || phone2idx_[phone] == -1) { KALDI_ERR << "TopologyForPhone(), phone "<<(phone)<<" not covered."; } return entries_[phone2idx_[phone]]; } -int32 HmmTopology::NumPdfClasses(int32 phone) const { +int32 Topology::NumPdfClasses(int32 phone) const { // will throw if phone not covered. const TopologyEntry &entry = TopologyForPhone(phone); int32 max_pdf_class = 0; @@ -346,7 +346,7 @@ int32 HmmTopology::NumPdfClasses(int32 phone) const { return max_pdf_class+1; } -int32 HmmTopology::MinLength(int32 phone) const { +int32 Topology::MinLength(int32 phone) const { const TopologyEntry &entry = TopologyForPhone(phone); // min_length[state] gives the minimum length for sequences up to and // including that state. diff --git a/src/hmm/topology.h b/src/hmm/topology.h new file mode 100644 index 00000000000..eae0640af08 --- /dev/null +++ b/src/hmm/topology.h @@ -0,0 +1,138 @@ +// hmm/topology.h + +// Copyright 2009-2011 Microsoft Corporation +// 2019 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_HMM_TOPOLOGY_H_ +#define KALDI_HMM_HMM_TOPOLOGY_H_ + +#include +#include "base/kaldi-common.h" + + +namespace kaldi { + + +/// \addtogroup hmm_group +/// @{ + +/* + The following would be the text form for the "normal" 3-state HMM topology/ + "bakis model", with the typical reordering that we do to improve the + compactness of the compiled FSTs. The format is the OpenFst acceptor format. + The fields are, for transitions, + + and, for final-states, + + + The may be interpreted as negative log probabilities. + We normally set them so as to sum to one, in order to keep the fully + compiled (HCLG) graph fairly stochastic (meaning: sum-to-one, like an + HMM). + + The integers on the arcs, which we call 'pdf-classes', define which + arcs share the same "pdf" and which ones are distinct. + + Preconditions on topology: + - pdf-classes (3rd field on arcs) must + form a contiguous list of numbers starting from 1, although + different arcs with the same pdf-class are allowed. (We avoid 0 + because it is "special" in OpenFST, it is used for epsilon). + - The start state must be state 0 and there must be no + transitions entering it except (possibly) a self-loop (although + a self-loop on state 0 is not advised for decoding-graph-size + reasons) + - The start state must not be final. + + + + + 1 2 3 4 5 6 7 8 + 0 1 1 0.0 + 1 1 1 0.693 + 1 2 2 0.693 + 2 2 2 0.693 + 2 3 3 0.693 + 3 3 3 0.693 + 3 0.693 + + +*/ + + +/// A class for storing topology information for phones. See \ref hmm for context. +/// This object is sometimes accessed in a file by itself, but more often +/// as a class member of the Transition class (this is for convenience to reduce +/// the number of files programs have to access). + +class Topology { + public: + + void Read(std::istream &is, bool binary); + void Write(std::ostream &os, bool binary) const; + + // Checks that the object is valid, and throw exception otherwise. + void Check(); + + /// Returns the topology entry for this phone; + /// will throw exception if phone not covered by the topology. + const fst::StdFst &TopologyForPhone(int32 phone) const; + + /// Returns the number of \ref pdf_class "pdf-classes" for this phone; + /// throws exception if phone not covered by this topology. + int32 NumPdfClasses(int32 phone) const; + + /// Returns a reference to a sorted, unique list of phones covered by + /// the topology (these phones will be positive integers, and usually + /// contiguous and starting from one but the toolkit doesn't assume + /// they are contiguous). + const std::vector &GetPhones() const { return phones_; }; + + /// Outputs a vector of int32, indexed by phone, that gives the + /// number of \ref pdf_class pdf-classes for the phones; this is + /// used by tree-building code such as BuildTree(). + void GetPhoneToNumPdfClasses(std::vector *phone2num_pdf_classes) const; + + // Returns the minimum number of arcs/frames it takes to traverse this model + // for this phone: e.g. 3 for the normal HMM topology. + int32 MinLength(int32 phone) const; + + Topology() {} + + bool operator == (const Topology &other) const; + + // was: + //return phones_ == other.phones_ && phone2idx_ == other.phone2idx_ + //&& entries_ == other.entries_; + // TODO: implement this; we probably need Equal() on fsts. + + // Allow default assignment operator and copy constructor. + private: + std::vector phones_; // list of all phones we have topology for. Sorted, uniq. no epsilon (zero) phone. + std::vector phone2idx_; // map from phones to indexes into the entries vector (or -1 for not present). + std::vector entries_; +}; + + +/// @} end "addtogroup hmm_group" + + +} // end namespace kaldi + + +#endif diff --git a/src/hmm/transition-model.h b/src/hmm/transition-model.h deleted file mode 100644 index e453c24f9cb..00000000000 --- a/src/hmm/transition-model.h +++ /dev/null @@ -1,371 +0,0 @@ -// hmm/transition-model.h - -// Copyright 2009-2012 Microsoft Corporation -// Johns Hopkins University (author: Guoguo Chen) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_HMM_TRANSITION_MODEL_H_ -#define KALDI_HMM_TRANSITION_MODEL_H_ - -#include "base/kaldi-common.h" -#include "util/const-integer-set.h" -#include "fst/fst-decl.h" // forward declarations. -#include "hmm/hmm-topology.h" -#include "itf/options-itf.h" -#include "itf/context-dep-itf.h" -#include "matrix/kaldi-vector.h" - -namespace kaldi { - -/// \addtogroup hmm_group -/// @{ - -// The class TransitionModel is a repository for the transition probabilities. -// It also handles certain integer mappings. -// The basic model is as follows. Each phone has a HMM topology defined in -// hmm-topology.h. Each HMM-state of each of these phones has a number of -// transitions (and final-probs) out of it. Each HMM-state defined in the -// HmmTopology class has an associated "pdf_class". This gets replaced with -// an actual pdf-id via the tree. The transition model associates the -// transition probs with the (phone, HMM-state, pdf-id). We associate with -// each such triple a transition-state. Each -// transition-state has a number of associated probabilities to estimate; -// this depends on the number of transitions/final-probs in the topology for -// that (phone, HMM-state). Each probability has an associated transition-index. -// We associate with each (transition-state, transition-index) a unique transition-id. -// Each individual probability estimated by the transition-model is asociated with a -// transition-id. -// -// List of the various types of quantity referred to here and what they mean: -// phone: a phone index (1, 2, 3 ...) -// HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) -// pdf-id: a number output by the Compute function of ContextDependency (it -// indexes pdf's, either forward or self-loop). Zero-based. -// transition-state: the states for which we estimate transition probabilities for transitions -// out of them. In some topologies, will map one-to-one with pdf-ids. -// One-based, since it appears on FSTs. -// transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the -// "transitions" vector in HmmTopology::HmmState. [if it is out of range, -// equal to transitions.size(), it refers to the final-prob.] -// Zero-based. -// transition-id: identifier of a unique parameter of the TransitionModel. -// Associated with a (transition-state, transition-index) pair. -// One-based, since it appears on FSTs. -// -// List of the possible mappings TransitionModel can do: -// (phone, HMM-state, forward-pdf-id, self-loop-pdf-id) -> transition-state -// (transition-state, transition-index) -> transition-id -// Reverse mappings: -// transition-id -> transition-state -// transition-id -> transition-index -// transition-state -> phone -// transition-state -> HMM-state -// transition-state -> forward-pdf-id -// transition-state -> self-loop-pdf-id -// -// The main things the TransitionModel object can do are: -// Get initialized (need ContextDependency and HmmTopology objects). -// Read/write. -// Update [given a vector of counts indexed by transition-id]. -// Do the various integer mappings mentioned above. -// Get the probability (or log-probability) associated with a particular transition-id. - - -// Note: this was previously called TransitionUpdateConfig. -struct MleTransitionUpdateConfig { - BaseFloat floor; - BaseFloat mincount; - bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. - MleTransitionUpdateConfig(BaseFloat floor = 0.01, - BaseFloat mincount = 5.0, - bool share_for_pdfs = false): - floor(floor), mincount(mincount), share_for_pdfs(share_for_pdfs) {} - - void Register (OptionsItf *opts) { - opts->Register("transition-floor", &floor, - "Floor for transition probabilities"); - opts->Register("transition-min-count", &mincount, - "Minimum count required to update transitions from a state"); - opts->Register("share-for-pdfs", &share_for_pdfs, - "If true, share all transition parameters where the states " - "have the same pdf."); - } -}; - -struct MapTransitionUpdateConfig { - BaseFloat tau; - bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. - MapTransitionUpdateConfig(): tau(5.0), share_for_pdfs(false) { } - - void Register (OptionsItf *opts) { - opts->Register("transition-tau", &tau, "Tau value for MAP estimation of transition " - "probabilities."); - opts->Register("share-for-pdfs", &share_for_pdfs, - "If true, share all transition parameters where the states " - "have the same pdf."); - } -}; - -class TransitionModel { - - public: - /// Initialize the object [e.g. at the start of training]. - /// The class keeps a copy of the HmmTopology object, but not - /// the ContextDependency object. - TransitionModel(const ContextDependencyInterface &ctx_dep, - const HmmTopology &hmm_topo); - - - /// Constructor that takes no arguments: typically used prior to calling Read. - TransitionModel(): num_pdfs_(0) { } - - void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. - void Write(std::ostream &os, bool binary) const; - - - /// return reference to HMM-topology object. - const HmmTopology &GetTopo() const { return topo_; } - - /// \name Integer mapping functions - /// @{ - - int32 TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const; - int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; - int32 TransitionIdToTransitionState(int32 trans_id) const; - int32 TransitionIdToTransitionIndex(int32 trans_id) const; - int32 TransitionStateToPhone(int32 trans_state) const; - int32 TransitionStateToHmmState(int32 trans_state) const; - int32 TransitionStateToForwardPdfClass(int32 trans_state) const; - int32 TransitionStateToSelfLoopPdfClass(int32 trans_state) const; - int32 TransitionStateToForwardPdf(int32 trans_state) const; - int32 TransitionStateToSelfLoopPdf(int32 trans_state) const; - int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if - // this state doesn't have a self-loop. - - inline int32 TransitionIdToPdf(int32 trans_id) const; - // TransitionIdToPdfFast is as TransitionIdToPdf but skips an assertion - // (unless we're in paranoid mode). - inline int32 TransitionIdToPdfFast(int32 trans_id) const; - - int32 TransitionIdToPhone(int32 trans_id) const; - int32 TransitionIdToPdfClass(int32 trans_id) const; - int32 TransitionIdToHmmState(int32 trans_id) const; - - /// @} - - bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state - // (which is bound to be nonemitting). - bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop. - - /// Returns the total number of transition-ids (note, these are one-based). - inline int32 NumTransitionIds() const { return id2state_.size()-1; } - - /// Returns the number of transition-indices for a particular transition-state. - /// Note: "Indices" is the plural of "index". Index is not the same as "id", - /// here. A transition-index is a zero-based offset into the transitions - /// out of a particular transition state. - int32 NumTransitionIndices(int32 trans_state) const; - - /// Returns the total number of transition-states (note, these are one-based). - int32 NumTransitionStates() const { return tuples_.size(); } - - // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one. - // In normal cases this should equal the number of pdfs in the system, but if you - // initialized this object with fewer than all the phones, and it happens that - // an unseen phone has the highest-numbered pdf, this might be different. - int32 NumPdfs() const { return num_pdfs_; } - - // This loops over the tuples and finds the highest phone index present. If - // the FST symbol table for the phones is created in the expected way, i.e.: - // starting from 1 ( is 0) and numbered contiguously till the last phone, - // this will be the total number of phones. - int32 NumPhones() const; - - /// Returns a sorted, unique list of phones. - const std::vector &GetPhones() const { return topo_.GetPhones(); } - - // Transition-parameter-getting functions: - BaseFloat GetTransitionProb(int32 trans_id) const; - BaseFloat GetTransitionLogProb(int32 trans_id) const; - - // The following functions are more specialized functions for getting - // transition probabilities, that are provided for convenience. - - /// Returns the log-probability of a particular non-self-loop transition - /// after subtracting the probability mass of the self-loop and renormalizing; - /// will crash if called on a self-loop. Specifically: - /// for non-self-loops it returns the log of (that prob divided by (1 minus - /// self-loop-prob-for-that-state)). - BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const; - - /// Returns the log-prob of the non-self-loop probability - /// mass for this transition state. (you can get the self-loop prob, if a self-loop - /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state)). - BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const; - - /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed - /// by transition-id. This was previously called Update(). - void MleUpdate(const Vector &stats, - const MleTransitionUpdateConfig &cfg, - BaseFloat *objf_impr_out, - BaseFloat *count_out); - - /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights, - /// indexed by transition-id. - void MapUpdate(const Vector &stats, - const MapTransitionUpdateConfig &cfg, - BaseFloat *objf_impr_out, - BaseFloat *count_out); - - /// Print will print the transition model in a human-readable way, for purposes of human - /// inspection. The "occs" are optional (they are indexed by pdf-id). - void Print(std::ostream &os, - const std::vector &phone_names, - const Vector *occs = NULL); - - - void InitStats(Vector *stats) const { stats->Resize(NumTransitionIds()+1); } - - void Accumulate(BaseFloat prob, int32 trans_id, Vector *stats) const { - KALDI_ASSERT(trans_id <= NumTransitionIds()); - (*stats)(trans_id) += prob; - // This is trivial and doesn't require class members, but leaves us more open - // to design changes than doing it manually. - } - - /// returns true if all the integer class members are identical (but does not - /// compare the transition probabilities. - bool Compatible(const TransitionModel &other) const; - - private: - void MleUpdateShared(const Vector &stats, - const MleTransitionUpdateConfig &cfg, - BaseFloat *objf_impr_out, BaseFloat *count_out); - void MapUpdateShared(const Vector &stats, - const MapTransitionUpdateConfig &cfg, - BaseFloat *objf_impr_out, BaseFloat *count_out); - void ComputeTuples(const ContextDependencyInterface &ctx_dep); // called from constructor. initializes tuples_. - void ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep); - void ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep); - void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_. - void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just - // non_self_loop_log_probs_; called whenever log-probs change. - void InitializeProbs(); // called from constructor. - void Check() const; - bool IsHmm() const; - - struct Tuple { - int32 phone; - int32 hmm_state; - int32 forward_pdf; - int32 self_loop_pdf; - Tuple() { } - Tuple(int32 phone, int32 hmm_state, int32 forward_pdf, int32 self_loop_pdf): - phone(phone), hmm_state(hmm_state), forward_pdf(forward_pdf), self_loop_pdf(self_loop_pdf) { } - bool operator < (const Tuple &other) const { - if (phone < other.phone) return true; - else if (phone > other.phone) return false; - else if (hmm_state < other.hmm_state) return true; - else if (hmm_state > other.hmm_state) return false; - else if (forward_pdf < other.forward_pdf) return true; - else if (forward_pdf > other.forward_pdf) return false; - else return (self_loop_pdf < other.self_loop_pdf); - } - bool operator == (const Tuple &other) const { - return (phone == other.phone && hmm_state == other.hmm_state - && forward_pdf == other.forward_pdf && self_loop_pdf == other.self_loop_pdf); - } - }; - - HmmTopology topo_; - - /// Tuples indexed by transition state minus one; - /// the tuples are in sorted order which allows us to do the reverse mapping from - /// tuple to transition state - std::vector tuples_; - - /// Gives the first transition_id of each transition-state; indexed by - /// the transition-state. Array indexed 1..num-transition-states+1 (the last one - /// is needed so we can know the num-transitions of the last transition-state. - std::vector state2id_; - - /// For each transition-id, the corresponding transition - /// state (indexed by transition-id). - std::vector id2state_; - - std::vector id2pdf_id_; - - /// For each transition-id, the corresponding log-prob. Indexed by transition-id. - Vector log_probs_; - - /// For each transition-state, the log of (1 - self-loop-prob). Indexed by - /// transition-state. - Vector non_self_loop_log_probs_; - - /// This is actually one plus the highest-numbered pdf we ever got back from the - /// tree (but the tree numbers pdfs contiguously from zero so this is the number - /// of pdfs). - int32 num_pdfs_; - - KALDI_DISALLOW_COPY_AND_ASSIGN(TransitionModel); -}; - -inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const { - KALDI_ASSERT( - static_cast(trans_id) < id2pdf_id_.size() && - "Likely graph/model mismatch (graph built from wrong model?)"); - return id2pdf_id_[trans_id]; -} - -inline int32 TransitionModel::TransitionIdToPdfFast(int32 trans_id) const { - // Note: it's a little dangerous to assert this only in paranoid mode. - // However, this function is called in the inner loop of decoders and - // the assertion likely takes a significant amount of time. We make - // sure that past the end of thd id2pdf_id_ array there are big - // numbers, which will make the calling code more likely to segfault - // (rather than silently die) if this is called for out-of-range values. - KALDI_PARANOID_ASSERT( - static_cast(trans_id) < id2pdf_id_.size() && - "Likely graph/model mismatch (graph built from wrong model?)"); - return id2pdf_id_[trans_id]; -} - -/// Works out which pdfs might correspond to the given phones. Will return true -/// if these pdfs correspond *just* to these phones, false if these pdfs are also -/// used by other phones. -/// @param trans_model [in] Transition-model used to work out this information -/// @param phones [in] A sorted, uniq vector that represents a set of phones -/// @param pdfs [out] Will be set to a sorted, uniq list of pdf-ids that correspond -/// to one of this set of phones. -/// @return Returns true if all of the pdfs output to "pdfs" correspond to phones from -/// just this set (false if they may be shared with phones outside this set). -bool GetPdfsForPhones(const TransitionModel &trans_model, - const std::vector &phones, - std::vector *pdfs); - -/// Works out which phones might correspond to the given pdfs. Similar to the -/// above GetPdfsForPhones(, ,) -bool GetPhonesForPdfs(const TransitionModel &trans_model, - const std::vector &pdfs, - std::vector *phones); -/// @} - - -} // end namespace kaldi - - -#endif diff --git a/src/hmm/transition-model-test.cc b/src/hmm/transitions-test.cc similarity index 87% rename from src/hmm/transition-model-test.cc rename to src/hmm/transitions-test.cc index 841c714efb1..9b9d7099801 100644 --- a/src/hmm/transition-model-test.cc +++ b/src/hmm/transitions-test.cc @@ -17,22 +17,22 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "hmm/hmm-test-utils.h" namespace kaldi { -void TestTransitionModel() { +void TestTransitions() { - TransitionModel *trans_model = GenRandTransitionModel(NULL); + Transitions *trans_model = GenRandTransitionModel(NULL); bool binary = (rand() % 2 == 0); std::ostringstream os; trans_model->Write(os, binary); - TransitionModel trans_model2; + Transitions trans_model2; std::istringstream is2(os.str()); trans_model2.Read(is2, binary); @@ -50,7 +50,7 @@ void TestTransitionModel() { int main() { for (int i = 0; i < 2; i++) - kaldi::TestTransitionModel(); + kaldi::TestTransitions(); KALDI_LOG << "Test OK.\n"; } diff --git a/src/hmm/transition-model.cc b/src/hmm/transitions.cc similarity index 88% rename from src/hmm/transition-model.cc rename to src/hmm/transitions.cc index 5ecb7776f00..4198ea9cd45 100644 --- a/src/hmm/transition-model.cc +++ b/src/hmm/transitions.cc @@ -1,7 +1,9 @@ -// hmm/transition-model.cc +// hmm/transitions.cc -// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +// Copyright 2009-2012 Microsoft Corporation // Johns Hopkins University (author: Guoguo Chen) +// 2012-2019 Johns Hopkins University (Author: Daniel Povey) + // See ../../COPYING for clarification regarding multiple authors // @@ -19,12 +21,17 @@ // limitations under the License. #include -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "tree/context-dep.h" namespace kaldi { -void TransitionModel::ComputeTuples(const ContextDependencyInterface &ctx_dep) { +bool Transitions::operator == (const Transitions &other) { + return topo_ == other.topo_ && info_ == other.info_ && + num_pdfs_ == other.num_pdfs_; +} + +void Transitions::ComputeTuples(const ContextDependencyInterface &ctx_dep) { if (IsHmm()) ComputeTuplesIsHmm(ctx_dep); else @@ -35,7 +42,7 @@ void TransitionModel::ComputeTuples(const ContextDependencyInterface &ctx_dep) { // this sorting defines the transition-ids. } -void TransitionModel::ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep) { +void Transitions::ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep) { const std::vector &phones = topo_.GetPhones(); KALDI_ASSERT(!phones.empty()); @@ -54,7 +61,7 @@ void TransitionModel::ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_d // can correspond to. for (size_t i = 0; i < phones.size(); i++) { // setting up to_hmm_state_list. int32 phone = phones[i]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(phone); for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... int32 pdf_class = entry[j].forward_pdf_class; if (pdf_class != kNoPdf) { @@ -79,7 +86,7 @@ void TransitionModel::ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_d } } -void TransitionModel::ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep) { +void Transitions::ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep) { const std::vector &phones = topo_.GetPhones(); KALDI_ASSERT(!phones.empty()); @@ -94,7 +101,7 @@ void TransitionModel::ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_ pdf_class_pairs.resize(1 + *std::max_element(phones.begin(), phones.end())); for (size_t i = 0; i < phones.size(); i++) { int32 phone = phones[i]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(phone); for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... int32 forward_pdf_class = entry[j].forward_pdf_class, self_loop_pdf_class = entry[j].self_loop_pdf_class; if (forward_pdf_class != kNoPdf) @@ -110,7 +117,7 @@ void TransitionModel::ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_ // can correspond to. for (size_t i = 0; i < phones.size(); i++) { // setting up to_hmm_state_list. int32 phone = phones[i]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(phone); std::map, std::vector > phone_to_hmm_state_list; for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... int32 forward_pdf_class = entry[j].forward_pdf_class, self_loop_pdf_class = entry[j].self_loop_pdf_class; @@ -141,7 +148,7 @@ void TransitionModel::ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_ } } -void TransitionModel::ComputeDerived() { +void Transitions::ComputeDerived() { state2id_.resize(tuples_.size()+2); // indexed by transition-state, which // is one based, but also an entry for one past end of list. @@ -158,7 +165,7 @@ void TransitionModel::ComputeDerived() { self_loop_pdf = tuples_[tstate-1].self_loop_pdf; num_pdfs_ = std::max(num_pdfs_, 1 + forward_pdf); num_pdfs_ = std::max(num_pdfs_, 1 + self_loop_pdf); - const HmmTopology::HmmState &state = topo_.TopologyForPhone(phone)[hmm_state]; + const Topology::HmmState &state = topo_.TopologyForPhone(phone)[hmm_state]; int32 my_num_ids = static_cast(state.transitions.size()); cur_transition_id += my_num_ids; // # trans out of this state. } @@ -187,26 +194,26 @@ void TransitionModel::ComputeDerived() { id2pdf_id_.resize(cur_transition_id); } -void TransitionModel::InitializeProbs() { +void Transitions::InitializeProbs() { log_probs_.Resize(NumTransitionIds()+1); // one-based array, zeroth element empty. for (int32 trans_id = 1; trans_id <= NumTransitionIds(); trans_id++) { int32 trans_state = id2state_[trans_id]; int32 trans_index = trans_id - state2id_[trans_state]; const Tuple &tuple = tuples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); KALDI_ASSERT(static_cast(tuple.hmm_state) < entry.size()); BaseFloat prob = entry[tuple.hmm_state].transitions[trans_index].second; if (prob <= 0.0) - KALDI_ERR << "TransitionModel::InitializeProbs, zero " + KALDI_ERR << "Transitions::InitializeProbs, zero " "probability [should remove that entry in the topology]"; if (prob > 1.0) - KALDI_WARN << "TransitionModel::InitializeProbs, prob greater than one."; + KALDI_WARN << "Transitions::InitializeProbs, prob greater than one."; log_probs_(trans_id) = Log(prob); } ComputeDerivedOfProbs(); } -void TransitionModel::Check() const { +void Transitions::Check() const { KALDI_ASSERT(NumTransitionIds() != 0 && NumTransitionStates() != 0); { int32 sum = 0; @@ -228,12 +235,12 @@ void TransitionModel::Check() const { } } -bool TransitionModel::IsHmm() const { +bool Transitions::IsHmm() const { const std::vector &phones = topo_.GetPhones(); KALDI_ASSERT(!phones.empty()); for (size_t i = 0; i < phones.size(); i++) { int32 phone = phones[i]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(phone); for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... if (entry[j].forward_pdf_class != entry[j].self_loop_pdf_class) return false; @@ -242,8 +249,8 @@ bool TransitionModel::IsHmm() const { return true; } -TransitionModel::TransitionModel(const ContextDependencyInterface &ctx_dep, - const HmmTopology &hmm_topo): topo_(hmm_topo) { +Transitions::TransitionModel(const ContextDependencyInterface &ctx_dep, + const Topology &hmm_topo): topo_(hmm_topo) { // First thing is to get all possible tuples. ComputeTuples(ctx_dep); ComputeDerived(); @@ -251,7 +258,7 @@ TransitionModel::TransitionModel(const ContextDependencyInterface &ctx_dep, Check(); } -int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const { +int32 Transitions::TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const { Tuple tuple(phone, hmm_state, pdf, self_loop_pdf); // Note: if this ever gets too expensive, which is unlikely, we can refactor // this code to sort first on pdf, and then index on pdf, so those @@ -259,7 +266,7 @@ int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int3 std::vector::const_iterator iter = std::lower_bound(tuples_.begin(), tuples_.end(), tuple); if (iter == tuples_.end() || !(*iter == tuple)) { - KALDI_ERR << "TransitionModel::TupleToTransitionState, tuple not found." + KALDI_ERR << "Transitions::TupleToTransitionState, tuple not found." << " (incompatible tree and model?)"; } // tuples_ is indexed by transition_state-1, so add one. @@ -267,68 +274,68 @@ int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int3 } -int32 TransitionModel::NumTransitionIndices(int32 trans_state) const { +int32 Transitions::NumTransitionIndices(int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); return static_cast(state2id_[trans_state+1]-state2id_[trans_state]); } -int32 TransitionModel::TransitionIdToTransitionState(int32 trans_id) const { +int32 Transitions::TransitionIdToTransitionState(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); return id2state_[trans_id]; } -int32 TransitionModel::TransitionIdToTransitionIndex(int32 trans_id) const { +int32 Transitions::TransitionIdToTransitionIndex(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); return trans_id - state2id_[id2state_[trans_id]]; } -int32 TransitionModel::TransitionStateToPhone(int32 trans_state) const { +int32 Transitions::TransitionStateToPhone(int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); return tuples_[trans_state-1].phone; } -int32 TransitionModel::TransitionStateToForwardPdf(int32 trans_state) const { +int32 Transitions::TransitionStateToForwardPdf(int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); return tuples_[trans_state-1].forward_pdf; } -int32 TransitionModel::TransitionStateToForwardPdfClass( +int32 Transitions::TransitionStateToForwardPdfClass( int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); const Tuple &t = tuples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); KALDI_ASSERT(static_cast(t.hmm_state) < entry.size()); return entry[t.hmm_state].forward_pdf_class; } -int32 TransitionModel::TransitionStateToSelfLoopPdfClass( +int32 Transitions::TransitionStateToSelfLoopPdfClass( int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); const Tuple &t = tuples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); KALDI_ASSERT(static_cast(t.hmm_state) < entry.size()); return entry[t.hmm_state].self_loop_pdf_class; } -int32 TransitionModel::TransitionStateToSelfLoopPdf(int32 trans_state) const { +int32 Transitions::TransitionStateToSelfLoopPdf(int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); return tuples_[trans_state-1].self_loop_pdf; } -int32 TransitionModel::TransitionStateToHmmState(int32 trans_state) const { +int32 Transitions::TransitionStateToHmmState(int32 trans_state) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); return tuples_[trans_state-1].hmm_state; } -int32 TransitionModel::PairToTransitionId(int32 trans_state, int32 trans_index) const { +int32 Transitions::PairToTransitionId(int32 trans_state, int32 trans_index) const { KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); KALDI_ASSERT(trans_index < state2id_[trans_state+1] - state2id_[trans_state]); return state2id_[trans_state] + trans_index; } -int32 TransitionModel::NumPhones() const { +int32 Transitions::NumPhones() const { int32 num_trans_state = tuples_.size(); int32 max_phone_id = 0; for (int32 i = 0; i < num_trans_state; ++i) { @@ -339,12 +346,12 @@ int32 TransitionModel::NumPhones() const { } -bool TransitionModel::IsFinal(int32 trans_id) const { +bool Transitions::IsFinal(int32 trans_id) const { KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; int32 trans_index = trans_id - state2id_[trans_state]; const Tuple &tuple = tuples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); KALDI_ASSERT(static_cast(tuple.hmm_state) < entry.size()); KALDI_ASSERT(static_cast(tuple.hmm_state) < entry.size()); KALDI_ASSERT(static_cast(trans_index) < @@ -357,12 +364,12 @@ bool TransitionModel::IsFinal(int32 trans_id) const { -int32 TransitionModel::SelfLoopOf(int32 trans_state) const { // returns the self-loop transition-id, +int32 Transitions::SelfLoopOf(int32 trans_state) const { // returns the self-loop transition-id, KALDI_ASSERT(static_cast(trans_state-1) < tuples_.size()); const Tuple &tuple = tuples_[trans_state-1]; // or zero if does not exist. int32 phone = tuple.phone, hmm_state = tuple.hmm_state; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(phone); KALDI_ASSERT(static_cast(hmm_state) < entry.size()); for (int32 trans_index = 0; trans_index < static_cast(entry[hmm_state].transitions.size()); @@ -372,7 +379,7 @@ int32 TransitionModel::SelfLoopOf(int32 trans_state) const { // returns the sel return 0; // invalid transition id. } -void TransitionModel::ComputeDerivedOfProbs() { +void Transitions::ComputeDerivedOfProbs() { non_self_loop_log_probs_.Resize(NumTransitionStates()+1); // this array indexed // by transition-state with nothing in zeroth element. for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { @@ -391,8 +398,8 @@ void TransitionModel::ComputeDerivedOfProbs() { } } -void TransitionModel::Read(std::istream &is, bool binary) { - ExpectToken(is, binary, ""); +void Transitions::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, ""); topo_.Read(is, binary); std::string token; ReadToken(is, binary, &token); @@ -414,14 +421,14 @@ void TransitionModel::Read(std::istream &is, bool binary) { ExpectToken(is, binary, ""); log_probs_.Read(is, binary); ExpectToken(is, binary, ""); - ExpectToken(is, binary, ""); + ExpectToken(is, binary, ""); ComputeDerivedOfProbs(); Check(); } -void TransitionModel::Write(std::ostream &os, bool binary) const { +void Transitions::Write(std::ostream &os, bool binary) const { bool is_hmm = IsHmm(); - WriteToken(os, binary, ""); + WriteToken(os, binary, ""); if (!binary) os << "\n"; topo_.Write(os, binary); if (is_hmm) @@ -448,31 +455,31 @@ void TransitionModel::Write(std::ostream &os, bool binary) const { log_probs_.Write(os, binary); WriteToken(os, binary, ""); if (!binary) os << "\n"; - WriteToken(os, binary, ""); + WriteToken(os, binary, ""); if (!binary) os << "\n"; } -BaseFloat TransitionModel::GetTransitionProb(int32 trans_id) const { +BaseFloat Transitions::GetTransitionProb(int32 trans_id) const { return Exp(log_probs_(trans_id)); } -BaseFloat TransitionModel::GetTransitionLogProb(int32 trans_id) const { +BaseFloat Transitions::GetTransitionLogProb(int32 trans_id) const { return log_probs_(trans_id); } -BaseFloat TransitionModel::GetNonSelfLoopLogProb(int32 trans_state) const { +BaseFloat Transitions::GetNonSelfLoopLogProb(int32 trans_state) const { KALDI_ASSERT(trans_state != 0); return non_self_loop_log_probs_(trans_state); } -BaseFloat TransitionModel::GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const { +BaseFloat Transitions::GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const { KALDI_ASSERT(trans_id != 0); KALDI_PARANOID_ASSERT(!IsSelfLoop(trans_id)); return log_probs_(trans_id) - GetNonSelfLoopLogProb(TransitionIdToTransitionState(trans_id)); } // stats are counts/weights, indexed by transition-id. -void TransitionModel::MleUpdate(const Vector &stats, +void Transitions::MleUpdate(const Vector &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out) { @@ -525,7 +532,7 @@ void TransitionModel::MleUpdate(const Vector &stats, } } } - KALDI_LOG << "TransitionModel::Update, objf change is " + KALDI_LOG << "Transitions::Update, objf change is " << (objf_impr_sum / count_sum) << " per frame over " << count_sum << " frames. "; KALDI_LOG << num_floored << " probabilities floored, " << num_skipped @@ -538,7 +545,7 @@ void TransitionModel::MleUpdate(const Vector &stats, // stats are counts/weights, indexed by transition-id. -void TransitionModel::MapUpdate(const Vector &stats, +void Transitions::MapUpdate(const Vector &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out) { @@ -596,7 +603,7 @@ void TransitionModel::MapUpdate(const Vector &stats, /// This version of the Update() function is for if the user specifies /// --share-for-pdfs=true. We share the transitions for all states that /// share the same pdf. -void TransitionModel::MleUpdateShared(const Vector &stats, +void Transitions::MleUpdateShared(const Vector &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out) { @@ -695,7 +702,7 @@ void TransitionModel::MleUpdateShared(const Vector &stats, /// This version of the MapUpdate() function is for if the user specifies /// --share-for-pdfs=true. We share the transitions for all states that /// share the same pdf. -void TransitionModel::MapUpdateShared(const Vector &stats, +void Transitions::MapUpdateShared(const Vector &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out) { @@ -782,18 +789,18 @@ void TransitionModel::MapUpdateShared(const Vector &stats, } -int32 TransitionModel::TransitionIdToPhone(int32 trans_id) const { +int32 Transitions::TransitionIdToPhone(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; return tuples_[trans_state-1].phone; } -int32 TransitionModel::TransitionIdToPdfClass(int32 trans_id) const { +int32 Transitions::TransitionIdToPdfClass(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; const Tuple &t = tuples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); KALDI_ASSERT(static_cast(t.hmm_state) < entry.size()); if (IsSelfLoop(trans_id)) return entry[t.hmm_state].self_loop_pdf_class; @@ -802,14 +809,14 @@ int32 TransitionModel::TransitionIdToPdfClass(int32 trans_id) const { } -int32 TransitionModel::TransitionIdToHmmState(int32 trans_id) const { +int32 Transitions::TransitionIdToHmmState(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; const Tuple &t = tuples_[trans_state-1]; return t.hmm_state; } -void TransitionModel::Print(std::ostream &os, +void Transitions::Print(std::ostream &os, const std::vector &phone_names, const Vector *occs) { if (occs != NULL) @@ -841,7 +848,7 @@ void TransitionModel::Print(std::ostream &os, if (IsSelfLoop(tid)) os << " [self-loop]\n"; else { int32 hmm_state = tuple.hmm_state; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); KALDI_ASSERT(static_cast(hmm_state) < entry.size()); int32 next_hmm_state = entry[hmm_state].transitions[tidx].first; KALDI_ASSERT(next_hmm_state != hmm_state); @@ -851,7 +858,7 @@ void TransitionModel::Print(std::ostream &os, } } -bool GetPdfsForPhones(const TransitionModel &trans_model, +bool GetPdfsForPhones(const Transitions &trans_model, const std::vector &phones, std::vector *pdfs) { KALDI_ASSERT(IsSortedAndUniq(phones)); @@ -877,7 +884,7 @@ bool GetPdfsForPhones(const TransitionModel &trans_model, return true; } -bool GetPhonesForPdfs(const TransitionModel &trans_model, +bool GetPhonesForPdfs(const Transitions &trans_model, const std::vector &pdfs, std::vector *phones) { KALDI_ASSERT(IsSortedAndUniq(pdfs)); @@ -903,19 +910,19 @@ bool GetPhonesForPdfs(const TransitionModel &trans_model, return true; } -bool TransitionModel::Compatible(const TransitionModel &other) const { +bool Transitions::Compatible(const TransitionModel &other) const { return (topo_ == other.topo_ && tuples_ == other.tuples_ && state2id_ == other.state2id_ && id2state_ == other.id2state_ && num_pdfs_ == other.num_pdfs_); } -bool TransitionModel::IsSelfLoop(int32 trans_id) const { +bool Transitions::IsSelfLoop(int32 trans_id) const { KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; int32 trans_index = trans_id - state2id_[trans_state]; const Tuple &tuple = tuples_[trans_state-1]; int32 phone = tuple.phone, hmm_state = tuple.hmm_state; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + const Topology::TopologyEntry &entry = topo_.TopologyForPhone(phone); KALDI_ASSERT(static_cast(hmm_state) < entry.size()); return (static_cast(trans_index) < entry[hmm_state].transitions.size() && entry[hmm_state].transitions[trans_index].first == hmm_state); diff --git a/src/hmm/transitions.h b/src/hmm/transitions.h new file mode 100644 index 00000000000..b446e4cc6c4 --- /dev/null +++ b/src/hmm/transitions.h @@ -0,0 +1,263 @@ +// hmm/transitions.h + +// Copyright 2009-2012 Microsoft Corporation +// 2015 Guoguo Chen +// 2019 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_TRANSITION_MODEL_H_ +#define KALDI_HMM_TRANSITION_MODEL_H_ + +#include "base/kaldi-common.h" +#include "util/const-integer-set.h" +#include "fst/fst-decl.h" // forward declarations. +#include "hmm/topology.h" +#include "itf/options-itf.h" +#include "itf/context-dep-itf.h" +#include "matrix/kaldi-vector.h" + +namespace kaldi { + + +// The class Transitions handles various integer mappings. +// It used to be the home for the trainable transitions, but these +// no longer exist. This class can be initialized from the +// tree and the topology. +// +// The topology of an individual phone is as defined in topology.h. +// +// This class basically defines the concept of a "transition-id", +// which is a construct that we use in compiled decoding graphs +// to make it easy to look up the 'pdf-id' (think of this as the +// distribution or neural net output column associated with this +// state) and also figure out which phone we are in and which +// arc in that phone. +// +// In the original Kaldi, this object contained trainable transition +// probabilities, but these have been removed to simplify things. +// +// A transition-id maps to a 4-tuple as follows: +// (pdf-id, phone, topo-state, arc-index) +// where 'topo-state' is the state index in the fst::StdFst +// for the topology, and 'arc-index' is the index of +// the arc leaving that state (zero for the first-listed one, +// one for the second, etc.) + + +// List of the various types of quantity referred to here and what they mean: +// phone: a phone index (1, 2, 3 ...) +// topo-state: a state index in the phone-topology FST (see topology.h) +// arc-index: The index of the arc leaving this topo-state: +// 0 for the first-listed one, 1 for the second. Will be used +// to Seek() in the ArcIterator. +// pdf-id: A number output by the Compute() function of ContextDependency (it +// indexes pdf's, either forward or self-loop). Zero-based. +// In DNN-based systems this would be the column index of +// the neural net output. +// (*)self-loop-pdf-id: The pdf-id associated with the self-loop of this state, +// if there is one (we do not allow >1), or -1 if there is no +// self-loop. This will be the same as pdf-id' if this transition +// *is* the self-loop. It might seem odd that we require this +// to get the transition-id for a non-self-loop arc; the reason +// why it's necessary is that we initially create the graph +// without self-loops (for efficiency) and we need to be able +// to look up the corresponding self-loop transition-id to +// add self-loops to the graph. +// +// transition-id: The numbers that we put on the decoding-graph arcs. +// Each transition-id is associated with a 4-tuple +// (pdf-id, phone, topo-state, arc-index). +// + + +class Transitions { + + public: + /// Initialize the object. This is deterministic, so initializing + /// from the same objects will give you an equivalent numbering. + /// The class keeps a copy of the Topology object, but not + /// the ContextDependency object. + Transitions(const ContextDependencyInterface &ctx_dep, + const Topology &topo); + + + /// Constructor that takes no arguments: typically used prior to calling Read. + Transitions(): num_pdfs_(0) { } + + void Read(std::istream &is, bool binary); + void Write(std::ostream &os, bool binary) const; + + // This struct is the information associated with one transition-id. + // You can work out the transition-id from the first 5 fields. + struct TransitionIdInfo { + int32 phone; // The phone + int32 topo_state; // The state in the topology FST for this phone + int32 arc_index; // The arc-index leaving this state + int32 pdf_id; // The pdf-id associated with this arc (obtained from the + // tree and phonetic-context information, etc.) + + int32 self_loop_pdf_id; // The pdf-id associated with the self-loop + // transition (if any) leaving the *destiation* + // state of this arc, or zero if that state has no + // self-loop. Search for (*) above for + // explanation. + + // The remaining fields are 'derived information' that are worked out + // from the information above and from the phone topology, and placed + // here for convenience. + + // is_self_loop is true if this is a self-loop (a transition to the same + // state). We often need to know this, so it's convenient to have this + // information here. + bool is_self_loop; + // is_initial is true if this is a transition leaving the + // initial state. + // you transition through the HMM (we check that the topology has no + // other transitions to the first HMM-state). + bool is_initial; + + // is_final is true if this is a transition entering a final + // state. This is used together with is_initial (and boundary + // information) to locate phone boundaries, e.g. for lattice + // word alignment: an 'is_final' transition-id followed by an + // 'is_initial' transition-id marks a phone boundary, which + // we know because we do not allow the start-state in + // topologies to be final. + bool is_final; + + // transition_cost is the cost (negative log-prob) of this transition). + BaseFloat transition_cost; + // The transition-id associated with the self-loop of the *destination* of + // this arc, if there is one, or 0 if there is no such self-loop. + int32 self_loop_transition_id; + + bool operator < (const TransitionIdInfo &other) const { + if (phone < other.phone) return true; + else if (phone > other.phone) return false; + else if (topo_state < other.topo_state) return true; + else if (topo_state > other.topo_state) return false; + else if (pdf_id < other.pdf_id) return true; + else if (pdf_id > other.pdf_id) return false; + else return (self_loop_pdf_id < other.self_loop_pdf_id); + } + // TODO. operator == can compare all members. + bool operator == (const TransitionIdInfo &other) const; + }; + + + /// return reference to HMM-topology object. + const Topology &GetTopo() const { return topo_; } + + const TransitionIdInfo &InfoForTransitionId(int32 transition_id) const; + + inline int32 TransitionIdToPdfFast(int32 trans_id) const; + + /// This allows you to look up a transition-id. It returns 0 if nothing + /// was found. + int32 TupleToTransitionId(int32 phone, int32 topo_state, int32 arc_index, + int32 pdf_id, int32 self_loop_pdf_id) const; + + + /// Returns the total number of transition-ids (note, these are one-based). + inline int32 NumTransitionIds() const { return info_.size()-1; } + + // NumPdfs() returns the number of pdfs (pdf-ids) in the tree, + // as returned by ctx_dep.NumPdfs() for the tree passed to the constructor. + int32 NumPdfs() const { return num_pdfs_; } + + /// Returns a sorted, unique list of phones. + const std::vector &GetPhones() const { return topo_.GetPhones(); } + + + /// Print will print the transition model in a human-readable way, for purposes of human + /// inspection. The "occs" are optional (they are indexed by pdf-id). + void Print(std::ostream &os, + const std::vector &phone_names, + const Vector *occs = NULL); + + /// returns true if this is identical to 'other' + bool operator == (const Transitions &other); + + private: + + // Called from constructor. initializes info_ (at least, the first 5 + // fields); you then have to call ComputeDerived() to initalize teh rest. + void ComputeInfo(const ContextDependencyInterface &ctx_dep); + + void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_. + + void Check() const; + + + Topology topo_; + + /// Information about transition-ids, indexed by transition-id. + /// the tuples are in sorted order which allows us to do the reverse mapping from + /// tuple to transition state + std::vector info_; + + + /// Accessing pdf_ids_[i] allows us to look up info_[i].pdf_id in a way that + /// is more friendly to memory caches than accessing info_; this is done in + /// the inner loops of decoders so it makes sense to optimize for it. + std::vector pdf_ids_; + + /// This is a copy of the NumPdfs() returned by the tree when we constructed + /// this object. Note: pdf-ids are zero-based. + int32 num_pdfs_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(Transitions); +}; + +inline int32 Transitions::TransitionIdToPdfFast(int32 trans_id) const { + // Note: it's a little dangerous to assert this only in paranoid mode. + // However, this function is called in the inner loop of decoders and + // the assertion likely takes a significant amount of time. We make + // sure that past the end of thd id2pdf_id_ array there are big + // numbers, which will make the calling code more likely to segfault + // (rather than silently die) if this is called for out-of-range values. + KALDI_PARANOID_ASSERT( + static_cast(trans_id) < pdf_ids_.size() && + "Likely graph/model mismatch (graph built from wrong model?)"); + return pdf_ids_[trans_id]; +} + +/// Works out which pdfs might correspond to the given phones. Will return true +/// if these pdfs correspond *just* to these phones, false if these pdfs are also +/// used by other phones. +/// @param trans_model [in] Transition-model used to work out this information +/// @param phones [in] A sorted, uniq vector that represents a set of phones +/// @param pdfs [out] Will be set to a sorted, uniq list of pdf-ids that correspond +/// to one of this set of phones. +/// @return Returns true if all of the pdfs output to "pdfs" correspond to phones from +/// just this set (false if they may be shared with phones outside this set). +bool GetPdfsForPhones(const Transitions &trans_model, + const std::vector &phones, + std::vector *pdfs); + +/// Works out which phones might correspond to the given pdfs. Similar to the +/// above GetPdfsForPhones(, ,) +bool GetPhonesForPdfs(const Transitions &trans_model, + const std::vector &pdfs, + std::vector *phones); +/// @} + + +} // end namespace kaldi + + +#endif diff --git a/src/hmm/tree-accu.cc b/src/hmm/tree-accu.cc index c8ce49d9bc7..80041d275e6 100644 --- a/src/hmm/tree-accu.cc +++ b/src/hmm/tree-accu.cc @@ -33,7 +33,7 @@ static int32 MapPhone(const std::vector &phone_map, } -void AccumulateTreeStats(const TransitionModel &trans_model, +void AccumulateTreeStats(const Transitions &trans_model, const AccumulateTreeStatsInfo &info, const std::vector &alignment, const Matrix &features, diff --git a/src/hmm/tree-accu.h b/src/hmm/tree-accu.h index 92e83c535c7..fd3e09567b5 100644 --- a/src/hmm/tree-accu.h +++ b/src/hmm/tree-accu.h @@ -23,7 +23,7 @@ #include // For isspace. #include #include "base/kaldi-common.h" -#include "hmm/transition-model.h" +#include "hmm/transitions.h" #include "tree/clusterable-classes.h" #include "tree/build-tree-questions.h" // needed for this typedef: // typedef std::vector > BuildTreeStatsType; @@ -74,7 +74,7 @@ struct AccumulateTreeStatsInfo { /// "normal" way). It adds to 'stats' the stats obtained from this file. Any /// new GaussClusterable* pointers in "stats" will be allocated with "new". -void AccumulateTreeStats(const TransitionModel &trans_model, +void AccumulateTreeStats(const Transitions &trans_model, const AccumulateTreeStatsInfo &info, const std::vector &alignment, const Matrix &features, diff --git a/src/itf/context-dep-itf.h b/src/itf/context-dep-itf.h index 40681bb5ccd..1fda7b93020 100644 --- a/src/itf/context-dep-itf.h +++ b/src/itf/context-dep-itf.h @@ -62,9 +62,9 @@ class ContextDependencyInterface { /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which /// pairs of (phone, pdf-class) it can correspond to. (Usually just one). - /// c.f. hmm/hmm-topology.h for meaning of pdf-class. + /// c.f. hmm/topology.h for meaning of pdf-class. /// This is the old, simpler interface of GetPdfInfo(), and that this one can - /// only be called if the HmmTopology object's IsHmm() function call returns + /// only be called if the Topology object's IsHmm() function call returns /// true. virtual void GetPdfInfo( const std::vector &phones, // list of phones diff --git a/src/lat/determinize-lattice-pruned.cc b/src/lat/determinize-lattice-pruned.cc index 22eae8199ff..64d8c3fffc0 100644 --- a/src/lat/determinize-lattice-pruned.cc +++ b/src/lat/determinize-lattice-pruned.cc @@ -1290,7 +1290,7 @@ bool DeterminizeLatticePruned(const ExpandedFst > &ifst, template typename ArcTpl::Label DeterminizeLatticeInsertPhones( - const kaldi::TransitionModel &trans_model, + const kaldi::Transitions &trans_model, MutableFst > *fst) { // Define some types. typedef ArcTpl Arc; @@ -1312,32 +1312,28 @@ typename ArcTpl::Label DeterminizeLatticeInsertPhones( !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); - // Note: the words are on the input symbol side and transition-id's are on + // Note: the words are on the input symbol side and transition-ids are on // the output symbol side. - if ((arc.olabel != 0) - && (trans_model.TransitionIdToHmmState(arc.olabel) == 0) - && (!trans_model.IsSelfLoop(arc.olabel))) { - Label phone = - static_cast