WIP: begin to add CTC training with kaldi pybind and PyTorch. #3947
Status: Open. csukuangfj wants to merge 11 commits into kaldi-asr:pybind11 from mobvoi:fangjun-ctc.
Commits (11):

- e6526be csukuangfj: begin to add CTC training with kaldi pybind and PyTorch.
- b832be4 csukuangfj: add more documentation.
- 0148732 csukuangfj: add unittest for the convert text to labels program.
- ffa861c csukuangfj: add training script.
- 5c7fdec csukuangfj: add kaldi's equivalent `add-deltas` to PyTorch.
- 150b497 csukuangfj: change the implementation of `add-deltas` to be a subclass of nn.Module.
- 906d57d csukuangfj: remove `permute` and disable padding in add deltas layer.
- 676385e csukuangfj: wrap Baidu's warp-ctc to PyTorch.
- 35820f2 csukuangfj: use only lstm layers.
- ae5dfbc csukuangfj: finish the CTC training pipeline.
- 9d686a0 csukuangfj: replace LSTM with TDNN-F.
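Three of the commits above (5c7fdec, 150b497, 906d57d) concern an `add-deltas` layer whose code is not shown in this diff. For orientation only, here is a minimal sketch of how Kaldi-style delta features can be computed as an `nn.Module`, assuming Kaldi's defaults (delta order 2, window 2) and replicate padding at the edges; commit 906d57d says the PR disables padding, so this is not the PR's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AddDeltas(nn.Module):
    '''Append first- and second-order deltas to a feature tensor.

    A sketch of Kaldi's add-deltas with delta-order 2 and window 2;
    not the implementation from this PR.
    '''

    def __init__(self):
        super().__init__()
        # First-order filter over offsets [-2..2]:
        # d[t] = (1*(x[t+1]-x[t-1]) + 2*(x[t+2]-x[t-2])) / 10.
        delta1 = torch.tensor([-2., -1., 0., 1., 2.]) / 10.0
        # Second-order filter: the first-order filter convolved with itself.
        delta2 = torch.tensor([4., 4., 1., -4., -10., -4., 1., 4., 4.]) / 100.0
        self.register_buffer('delta1', delta1.view(1, 1, -1))
        self.register_buffer('delta2', delta2.view(1, 1, -1))

    def forward(self, x):
        # x: (batch, num_frames, feat_dim)
        b, t, d = x.shape
        y = x.transpose(1, 2).reshape(b * d, 1, t)
        # Replicate edge frames as Kaldi does (the PR disables padding instead).
        d1 = F.conv1d(F.pad(y, (2, 2), mode='replicate'), self.delta1)
        d2 = F.conv1d(F.pad(y, (4, 4), mode='replicate'), self.delta2)
        d1 = d1.reshape(b, d, t).transpose(1, 2)
        d2 = d2.reshape(b, d, t).transpose(1, 2)
        # (batch, num_frames, 3 * feat_dim)
        return torch.cat([x, d1, d2], dim=2)
```

Applied to a 40-dim fbank tensor of shape `(batch, frames, 40)`, this returns `(batch, frames, 120)`.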
New file: cmd.sh, the command/queue configuration:

```bash
# you can change cmd.sh depending on what type of queue you are using.
# If you have no queueing system and want to run on a local machine, you
# can change all instances 'queue.pl' to run.pl (but be careful and run
# commands one by one: most recipes will exhaust the memory on your
# machine). queue.pl works with GridEngine (qsub). slurm.pl works
# with slurm. Different queues are configured differently, with different
# queue names and different ways of specifying things like memory;
# to account for these differences you can create and edit the file
# conf/queue.conf to match your queue's configuration. Search for
# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.

export train_cmd="run.pl"
export decode_cmd="run.pl"
export mkgraph_cmd="run.pl"
export cuda_cmd="run.pl"
```
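On a cluster, the same variables would point at `queue.pl` with per-job resource options; illustrative GridEngine values (an assumption, to be tuned per cluster via conf/queue.conf as described above):

```bash
export train_cmd="queue.pl --mem 2G"
export decode_cmd="queue.pl --mem 4G"
export mkgraph_cmd="queue.pl --mem 8G"
export cuda_cmd="queue.pl --gpu 1"
```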
New file: filterbank feature configuration:

```
--num-mel-bins=40
```
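This config selects 40-bin log mel filterbank features. A hypothetical invocation of the stock Kaldi wrapper that consumes such a config (the config path and directory names are assumptions):

```bash
steps/make_fbank.sh --fbank-config conf/fbank.conf --nj 10 --cmd "$train_cmd" \
  data/train exp/make_fbank/train fbank
```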
New file: AISHELL data preparation script:

```bash
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0

. ./path.sh || exit 1;

if [ $# != 2 ]; then
  echo "Usage: $0 <audio-path> <text-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt

train_dir=data/local/train
dev_dir=data/local/dev
test_dir=data/local/test
tmp_dir=data/local/tmp

mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir

# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
  echo "Error: $0 requires a valid audio directory and transcript file"
  exit 1;
fi

# find the wav audio files for train, dev and test respectively
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=`cat $tmp_dir/wav.flist | wc -l`
[ $n -ne 141925 ] && \
  echo Warning: expected 141925 data files, found $n

grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;

rm -r $tmp_dir

# Transcription preparation
for dir in $train_dir $dev_dir $test_dir; do
  echo Preparing $dir transcriptions
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{i=NF-1;printf("%s %s\n",$NF,$i)}' > $dir/utt2spk_all
  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/utt2spk_all | sort -u > $dir/utt2spk
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
  sort -u $dir/transcripts.txt > $dir/text
  utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt
done

mkdir -p data/train data/dev data/test

for f in spk2utt utt2spk wav.scp text; do
  cp $train_dir/$f data/train/$f || exit 1;
  cp $dev_dir/$f data/dev/$f || exit 1;
  cp $test_dir/$f data/test/$f || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
```
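For reference, the loop above produces the standard Kaldi per-utterance maps; with AISHELL's naming the entries look like this (paths and IDs illustrative):

```
# wav.scp: <utterance-id> <wav-path>
BAC009S0002W0122 /path/to/data_aishell/wav/train/S0002/BAC009S0002W0122.wav

# utt2spk: <utterance-id> <speaker-id>  (the speaker is the wav's parent directory)
BAC009S0002W0122 S0002

# text: <utterance-id> <transcript>
BAC009S0002W0122 <the transcript words from aishell_transcript_v0.8.txt>
```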
New file: AISHELL dictionary preparation script:

```bash
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0

# prepare dict resources

. ./path.sh

[ $# != 1 ] && echo "Usage: $0 <resource-path>" && exit 1;

res_dir=$1
dict_dir=data/local/dict
mkdir -p $dict_dir
cp $res_dir/lexicon.txt $dict_dir

cat $dict_dir/lexicon.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | \
  perl -e 'while(<>){ chomp($_); $phone = $_; next if ($phone eq "sil");
    m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$1} .= "$phone "; }
    foreach $l (values %q) {print "$l\n";}
  ' | sort -k1 > $dict_dir/nonsilence_phones.txt || exit 1;

echo sil > $dict_dir/silence_phones.txt

echo sil > $dict_dir/optional_silence.txt

# Extra questions: one line for the silence phones, then the nonsilence
# phones grouped by tone marker (the digit suffix); toneless phones form
# a group of their own.
cat $dict_dir/silence_phones.txt | awk '{printf("%s ", $1);} END{printf "\n";}' > $dict_dir/extra_questions.txt || exit 1;
cat $dict_dir/nonsilence_phones.txt | perl -e 'while(<>){ foreach $p (split(" ", $_)) {
  $p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$2} .= "$p "; } } foreach $l (values %q) {print "$l\n";}' \
  >> $dict_dir/extra_questions.txt || exit 1;

echo "$0: AISHELL dict preparation succeeded"
exit 0;
```
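The first perl snippet collects all tone variants of a base phone onto one line of nonsilence_phones.txt, so the decision tree can ask questions about the base phone. Illustrative lines (AISHELL uses tonal pinyin phones; toneless initials end up on lines of their own):

```
a1 a2 a3 a4 a5
ai1 ai2 ai3 ai4 ai5
b
```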
New file: AISHELL language model training script:

```bash
#!/bin/bash

# To be run from one directory above this script.
. ./path.sh

text=data/local/train/text
lexicon=data/local/dict/lexicon.txt

for f in "$text" "$lexicon"; do
  [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done

# This script takes no arguments.  It assumes you have already run
# aishell_data_prep.sh.
# It takes as input the files
# data/local/train/text
# data/local/dict/lexicon.txt
dir=data/local/lm
mkdir -p $dir

kaldi_lm=`which train_lm.sh`
if [ -z "$kaldi_lm" ]; then
  echo "$0: train_lm.sh is not found. That might mean it's not installed"
  echo "$0: or it is not added to PATH"
  echo "$0: Use the script tools/extras/install_kaldi_lm.sh to install it"
  exit 1
fi

cleantext=$dir/text.no_oov

cat $text | awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
  {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
  > $cleantext || exit 1;

cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \
  sort -nr > $dir/word.counts || exit 1;

# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
  cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
  sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1;

# note: we probably won't really make use of <SPOKEN_NOISE> as there aren't any OOVs
cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "<s>" "</s>" "<SPOKEN_NOISE>" > $dir/word_map \
  || exit 1;

# note: ignore 1st field of train.txt, it's the utterance-id.
cat $cleantext | awk -v wmap=$dir/word_map 'BEGIN{while((getline<wmap)>0)map[$1]=$2;}
  { for(n=2;n<=NF;n++) { printf map[$n]; if(n<NF){ printf " "; } else { print ""; }}}' | gzip -c >$dir/train.gz \
  || exit 1;

train_lm.sh --arpa --lmtype 3gram-mincount $dir || exit 1;

# LM is small enough that we don't need to prune it (only about 0.7M N-grams).
# Perplexity over 128254.000000 words is 90.446690

# note: output is
# data/local/lm/3gram-mincount/lm_unpruned.gz

exit 0

# From here on are some commands to do a baseline with SRILM (assuming
# you have it installed); they never run because of the exit above.
heldout_sent=10000 # Don't change this if you want results to be comparable
                   # with kaldi_lm results
sdir=$dir/srilm  # in case we want to use SRILM to double-check perplexities.
mkdir -p $sdir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
  head -$heldout_sent > $sdir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
  tail -n +$heldout_sent > $sdir/train

cat $dir/word_map | awk '{print $1}' | cat - <(echo "<s>"; echo "</s>" ) > $sdir/wordlist

ngram-count -text $sdir/train -order 3 -limit-vocab -vocab $sdir/wordlist -unk \
  -map-unk "<SPOKEN_NOISE>" -kndiscount -interpolate -lm $sdir/srilm.o3g.kn.gz
ngram -lm $sdir/srilm.o3g.kn.gz -ppl $sdir/heldout
# 0 zeroprobs, logprob= -250954 ppl= 90.5091 ppl1= 132.482

# Note: the perplexity SRILM gives to the kaldi_lm model is the same as
# kaldi_lm reports above.  The difference in WSJ must have been due to
# different treatment of <SPOKEN_NOISE>.
ngram -lm $dir/3gram-mincount/lm_unpruned.gz -ppl $sdir/heldout
# 0 zeroprobs, logprob= -250913 ppl= 90.4439 ppl1= 132.379
```
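To sanity-check the result, one can peek at the ARPA header of the unpruned LM at the output path named in the note above:

```bash
gunzip -c data/local/lm/3gram-mincount/lm_unpruned.gz | head
```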
New file: Python script that converts transcripts (data/*/text) into integer label sequences for CTC:

```python
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import argparse
import os

import kaldi


def get_args():
    parser = argparse.ArgumentParser(description='convert text to labels')

    parser.add_argument('--lexicon-filename', dest='lexicon_filename', type=str)
    parser.add_argument('--tokens-filename', dest='tokens_filename', type=str)
    parser.add_argument('--dir', help='input/output dir', type=str)

    args = parser.parse_args()

    assert os.path.isfile(args.lexicon_filename)
    assert os.path.isfile(args.tokens_filename)
    assert os.path.isfile(os.path.join(args.dir, 'text'))

    return args


def read_lexicon(filename):
    '''
    Returns:
        a dict whose keys are words and values are phones.
    '''
    lexicon = dict()
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            word_phones = line.split()
            assert len(word_phones) >= 2

            word = word_phones[0]
            phones = word_phones[1:]

            if word not in lexicon:
                # if there are multiple pronunciations for a word,
                # we choose only the first one and drop other alternatives
                lexicon[word] = phones

    return lexicon


def read_tokens(filename):
    '''
    Returns:
        a dict whose keys are phones and values are phone indices
    '''
    tokens = dict()
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            phone_index = line.split()
            assert len(phone_index) == 2

            phone = phone_index[0]
            index = int(phone_index[1])

            if phone == '<eps>':
                continue

            # decreased by one since we removed <eps> above
            index -= 1

            assert phone not in tokens

            tokens[phone] = index

    assert '<blk>' in tokens

    # WARNING(fangjun): we assume that the blank symbol has index 0
    # in the neural network output.
    # Do NOT confuse it with `<eps>` in fst.
    assert tokens['<blk>'] == 0

    return tokens


def read_text(filename):
    '''
    Returns:
        a dict whose keys are utterance IDs and values are texts
    '''
    transcript = dict()

    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            utt_text = line.split()
            assert len(utt_text) >= 2

            utt = utt_text[0]
            text = utt_text[1:]

            assert utt not in transcript
            transcript[utt] = text

    return transcript


def phones_to_indices(phone_list, tokens):
    index_list = []

    for phone in phone_list:
        assert phone in tokens

        index = tokens[phone]
        index_list.append(index)

    return index_list


def main():
    args = get_args()

    lexicon = read_lexicon(args.lexicon_filename)

    tokens = read_tokens(args.tokens_filename)

    transcript = read_text(os.path.join(args.dir, 'text'))

    transcript_labels = dict()

    for utt, text in transcript.items():
        labels = []
        for t in text:
            # TODO(fangjun): add support for OOV.
            phones = lexicon[t]

            indices = phones_to_indices(phones, tokens)

            labels.extend(indices)

        assert utt not in transcript_labels

        transcript_labels[utt] = labels

    wspecifier = 'ark,scp:{dir}/labels.ark,{dir}/labels.scp'.format(
        dir=args.dir)

    writer = kaldi.IntVectorWriter(wspecifier)

    for utt, labels in transcript_labels.items():
        writer.Write(utt, labels)

    writer.Close()

    print('Generated label file {}/labels.scp successfully'.format(args.dir))


if __name__ == '__main__':
    main()
```
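A hypothetical invocation: the flag names come from get_args() above, but the script's path and the location of tokens.txt are assumptions.

```bash
python3 local/convert_text_to_labels.py \
  --lexicon-filename data/local/dict/lexicon.txt \
  --tokens-filename data/local/dict/tokens.txt \
  --dir data/train
# on success this writes data/train/labels.ark and data/train/labels.scp
```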
Review comment:

Please use the standard OpenFST symbol-table format for these tokens. I'm open to other opinions, but since we'll probably have these symbols present in FSTs I think symbol 0 should be reserved for `<eps>` and `<blk>` should be 1, and we can just apply an offset of 1 when interpreting the nnet outputs.
Review comment (follow-up):

... if the format is already the symbol-table format, bear in mind that the order of lines is actually arbitrary; what matters is the integer there.
Reply:

I reuse the notation from EESEN (https://github.com/srvk/eesen), which calls `phones.txt` `tokens.txt`; `tokens.txt` is actually a phone symbol table.

The code here does not pose any constraint on the order of lines. What matters here is only the integer of the symbols. The first two integers `0` and `1` are reserved: I think `0` is reserved for `<eps>`, and here I reserve `1` for the blank symbol. The script generating `tokens.txt` has considered the above constraint.

Since there is a `T` in `TLG.fst`, I keep using `tokens.txt` here instead of `phones.txt`. I can switch to `phones.txt` if you think that is more natural in kaldi.
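Under the convention described in this thread, a `tokens.txt` in standard OpenFST symbol-table format would start like this (phone names illustrative; only the integers matter, with 0 for `<eps>` and 1 for the blank):

```
<eps> 0
<blk> 1
a1 2
a2 3
ai1 4
```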