From ffa861cc04c0ab1839acc557200b3c364fdcadc4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 21 Feb 2020 22:02:33 +0800 Subject: [PATCH] add training script. --- egs/aishell/s10b/ctc/common.py | 28 ++++ egs/aishell/s10b/ctc/dataset.py | 200 +++++++++++++++++++++++++ egs/aishell/s10b/ctc/model.py | 145 ++++++++++++++++++ egs/aishell/s10b/ctc/options.py | 162 ++++++++++++++++++++ egs/aishell/s10b/ctc/train.py | 107 +++++++++++++ egs/aishell/s10b/local/generate_tlg.sh | 2 +- egs/aishell/s10b/local/run_ctc.sh | 59 ++++++++ egs/aishell/s10b/run.sh | 10 +- 8 files changed, 711 insertions(+), 2 deletions(-) create mode 100644 egs/aishell/s10b/ctc/common.py create mode 100644 egs/aishell/s10b/ctc/dataset.py create mode 100644 egs/aishell/s10b/ctc/model.py create mode 100644 egs/aishell/s10b/ctc/options.py create mode 100644 egs/aishell/s10b/ctc/train.py create mode 100755 egs/aishell/s10b/local/run_ctc.sh diff --git a/egs/aishell/s10b/ctc/common.py b/egs/aishell/s10b/ctc/common.py new file mode 100644 index 00000000000..22915cfe408 --- /dev/null +++ b/egs/aishell/s10b/ctc/common.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +from datetime import datetime +import logging + + +def setup_logger(log_filename, log_level='info'): + now = datetime.now() + date_time = now.strftime('%Y-%m-%d-%H-%M-%S') + log_filename = '{}-{}'.format(log_filename, date_time) + formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s' + if log_level == 'debug': + level = logging.DEBUG + elif log_level == 'info': + level = logging.INFO + elif log_level == 'warning': + level = logging.WARNING + logging.basicConfig(filename=log_filename, + format=formatter, + level=level, + filemode='w') + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger('').addHandler(console) diff --git a/egs/aishell/s10b/ctc/dataset.py b/egs/aishell/s10b/ctc/dataset.py new file mode 100644 index 00000000000..acbe4be9d3e --- /dev/null +++ b/egs/aishell/s10b/ctc/dataset.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import os +import logging + +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +import kaldi + + +def get_ctc_dataloader(feats_scp, + labels_scp=None, + batch_size=1, + shuffle=False, + num_workers=0): + + dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp) + + collate_fn = CtcDatasetCollateFunc() + + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn) + + return dataloader + + +class CtcDataset(Dataset): + + def __init__(self, feats_scp, labels_scp=None): + ''' + Args: + feats_scp: filename for feats.scp + labels_scp: if provided, it is the filename of labels.scp + ''' + assert os.path.isfile(feats_scp) + if labels_scp: + assert os.path.isfile(labels_scp) + logging.info('labels scp: {}'.format(labels_scp)) + else: + logging.warn('No labels scp is given.') + + # items is a dict of [uttid, feat_rxfilename, None] + # or [uttid, feat_rxfilename, label_rxfilename] if labels_scp is not None + items = dict() + + with open(feats_scp, 'r') as f: + for line in f: + # every line has the following format: + # uttid feat_rxfilename + uttid_rxfilename = line.split() + assert len(uttid_rxfilename) == 2 + + uttid, rxfilename = uttid_rxfilename + + assert uttid not in items + + items[uttid] = [uttid, rxfilename, None] + + if labels_scp: + expected_count = len(items) + n = 0 + with open(labels_scp, 'r') as f: + for line in f: + # every line has the following format: + # uttid rxfilename + uttid_rxfilename = line.split() + + assert len(uttid_rxfilename) == 2 + + uttid, rxfilename = uttid_rxfilename + + assert uttid in items + + items[uttid][-1] = rxfilename + + n += 1 + + # every utterance should have a label if + # labels_scp is given + assert n == expected_count + + self.items = list(items.values()) + self.num_items = len(self.items) + self.feats_scp = feats_scp + self.labels_scp = labels_scp + + def __len__(self): + return self.num_items + + def __getitem__(self, i): + ''' + Returns: + a list [key, feat_rxfilename, label_rxfilename] + Note that label_rxfilename may be None. + ''' + return self.items[i] + + def __str__(self): + s = 'feats scp: {}\n'.format(self.feats_scp) + + if self.labels_scp: + s += 'labels scp: {}\n'.format(self.labels_scp) + + s += 'num utterances: {}\n'.format(self.num_items) + + return s + + +class CtcDatasetCollateFunc: + + def __call__(self, batch): + ''' + Args: + batch: a list of [uttid, feat_rxfilename, label_rxfilename]. + Note that label_rxfilename may be None. + + Returns: + uttid_list: a list of utterance id + + feat: a 3-D float tensor of shape [batch_size, seq_len, feat_dim] + + feat_len_list: number of frames of each utterance before padding + + label_list: a list of labels of each utterance; It may be None. + + label_len_list: label length of each utterance; It is None if label_list is None. + ''' + uttid_list = [] # utterance id of each utterance + feat_len_list = [] # number of frames of each utterance + label_list = [] # label of each utterance + label_len_list = [] # label length of each utterance + + feat_list = [] + + for b in batch: + uttid, feat_rxfilename, label_rxfilename = b + + uttid_list.append(uttid) + + feat = kaldi.read_mat(feat_rxfilename).numpy() + feat = torch.from_numpy(feat).float() + feat_list.append(feat) + + feat_len_list.append(feat.size(0)) + + if label_rxfilename: + label = kaldi.read_vec_int(label_rxfilename) + label_list.append(label) + label_len_list.append(len(label)) + + feat = pad_sequence(feat_list, batch_first=True) + + if not label_list: + label_list = None + label_len_list = None + + return uttid_list, feat, feat_len_list, label_list, label_len_list + + +def _test_dataset(): + feats_scp = 'data/train_sp/feats.scp' + labels_scp = 'data/train_sp/labels.scp' + + dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp) + + print(dataset) + + +def _test_dataloader(): + feats_scp = 'data/test/feats.scp' + labels_scp = 'data/test/labels.scp' + + dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp) + + dataloader = DataLoader(dataset, + batch_size=2, + num_workers=10, + shuffle=True, + collate_fn=CtcDatasetCollateFunc()) + i = 0 + for batch in dataloader: + uttid_list, feat, feat_len_list, label_list, label_len_list = batch + print(uttid_list, feat.shape, feat_len_list, label_len_list) + i += 1 + if i > 10: + break + + +if __name__ == '__main__': + # _test_dataset() + _test_dataloader() diff --git a/egs/aishell/s10b/ctc/model.py b/egs/aishell/s10b/ctc/model.py new file mode 100644 index 00000000000..dfc1efed6e5 --- /dev/null +++ b/egs/aishell/s10b/ctc/model.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 + +# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + + +def get_ctc_model(input_dim, + output_dim, + num_layers=4, + hidden_dim=512, + proj_dim=256): + model = CtcModel(input_dim=input_dim, + output_dim=output_dim, + num_layers=num_layers, + hidden_dim=hidden_dim, + proj_dim=proj_dim) + + return model + + +class CtcModel(nn.Module): + + def __init__(self, input_dim, output_dim, num_layers, hidden_dim, proj_dim): + ''' + Args: + input_dim: input dimension of the network + + output_dim: output dimension of the network + + num_layers: number of LSTM layers of the network + + hidden_dim: the dimension of the hidden state of LSTM layers + + proj_dim: dimension of the affine layer after every LSTM layer + ''' + super().__init__() + + lstm_layer_list = [] + proj_layer_list = [] + + # batchnorm requires input of shape [N, C, L] == [batch_size, dim, seq_len] + self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim, + affine=False) + + for i in range(num_layers): + if i == 0: + lstm_input_dim = input_dim + else: + lstm_input_dim = proj_dim + + lstm_layer = nn.LSTM(input_size=lstm_input_dim, + hidden_size=hidden_dim, + num_layers=1, + batch_first=True) + + proj_layer = nn.Linear(in_features=hidden_dim, + out_features=proj_dim) + + lstm_layer_list.append(lstm_layer) + proj_layer_list.append(proj_layer) + + self.lstm_layer_list = nn.ModuleList(lstm_layer_list) + self.proj_layer_list = nn.ModuleList(proj_layer_list) + + self.num_layers = num_layers + + self.prefinal_affine = nn.Linear(in_features=proj_dim, + out_features=output_dim) + + def forward(self, feat, feat_len_list): + ''' + Args: + feat: a 3-D tensor of shape [batch_size, seq_len, feat_dim] + feat_len_list: feat length of each utterance before padding + + Returns: + a 3-D tensor of shape [batch_size, seq_len, output_dim] + representing log prob, i.e., the output of log_softmax. + ''' + x = feat + + # at his point, x is of shape [batch_size, seq_len, feat_dim] + x = x.permute(0, 2, 1) + + # at his point, x is of shape [batch_size, feat_dim, seq_len] == [N, C, L] + x = self.input_batch_norm(x) + + x = x.permute(0, 2, 1) + + # at his point, x is of shape [batch_size, seq_len, feat_dim] == [N, L, C] + + for i in range(self.num_layers): + x = pack_padded_sequence(input=x, + lengths=feat_len_list, + batch_first=True, + enforce_sorted=False) + + # TODO(fangjun): save intermediate LSTM state to support streaming inference + x, _ = self.lstm_layer_list[i](x) + + x, _ = pad_packed_sequence(x, batch_first=True) + + x = self.proj_layer_list[i](x) + + x = torch.tanh(x) + + x = self.prefinal_affine(x) + + x = F.log_softmax(x, dim=-1) + + return x + + +def _test_ctc_model(): + input_dim = 5 + output_dim = 20 + model = CtcModel(input_dim=input_dim, + output_dim=output_dim, + num_layers=2, + hidden_dim=3, + proj_dim=4) + + feat1 = torch.randn((6, input_dim)) + feat2 = torch.randn((8, input_dim)) + + from torch.nn.utils.rnn import pad_sequence + feat = pad_sequence([feat1, feat2], batch_first=True) + assert feat.shape == torch.Size([2, 8, input_dim]) + + feat_len_list = [6, 8] + x = model(feat, feat_len_list) + + assert x.shape == torch.Size([2, 8, output_dim]) + + +if __name__ == '__main__': + _test_ctc_model() diff --git a/egs/aishell/s10b/ctc/options.py b/egs/aishell/s10b/ctc/options.py new file mode 100644 index 00000000000..117cd7a00fa --- /dev/null +++ b/egs/aishell/s10b/ctc/options.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import argparse +import os + + +def _str2bool(v): + ''' + This function is modified from + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + ''' + if isinstance(v, bool): + return v + elif v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def _set_training_args(parser): + parser.add_argument('--train.labels-scp', + dest='labels_scp', + help='filename of labels.scp', + type=str) + + parser.add_argument('--train.num-epochs', + dest='num_epochs', + help='number of epochs to train', + type=int) + + parser.add_argument('--train.lr', + dest='learning_rate', + help='learning rate', + type=float) + + parser.add_argument('--train.l2-regularize', + dest='l2_regularize', + help='l2 regularize', + type=float) + + # TODO(fangjun): add validation feats_scp + + +def _check_training_args(args): + assert os.path.isfile(args.labels_scp) + + assert args.num_epochs > 0 + assert args.learning_rate > 0 + assert args.l2_regularize >= 0 + + if args.checkpoint: + assert os.path.exists(args.checkpoint) + + +def _check_args(args): + if args.is_training: + _check_training_args(args) + + assert os.path.isdir(args.dir) + assert os.path.isfile(args.feats_scp) + + assert args.batch_size > 0 + assert args.device_id >= 0 + + assert args.input_dim > 0 + assert args.output_dim > 0 + assert args.num_layers > 0 + assert args.hidden_dim > 0 + assert args.proj_dim > 0 + + assert args.log_level in ['debug', 'info', 'warning'] + + +def get_args(): + parser = argparse.ArgumentParser( + description='chain training in PyTorch with kaldi pybind') + + _set_training_args(parser) + + parser.add_argument('--is-training', + dest='is_training', + help='true for training, false for inference', + required=True, + type=_str2bool) + + parser.add_argument('--dir', + help='dir to save results. The user has to ' + 'create it before calling this script.', + required=True, + type=str) + + parser.add_argument('--feats-scp', + dest='feats_scp', + help='filename of feats.scp', + required=True, + type=str) + + parser.add_argument('--device-id', + dest='device_id', + help='GPU device id', + required=True, + type=int) + + parser.add_argument('--batch-size', + dest='batch_size', + help='batch size used in training and inference', + required=True, + type=int) + + parser.add_argument('--input-dim', + dest='input_dim', + help='input dimension of the network', + required=True, + type=int) + + parser.add_argument('--output-dim', + dest='output_dim', + help='output dimension of the network', + required=True, + type=int) + + parser.add_argument('--num-layers', + dest='num_layers', + help="number of LSTM layers in the network", + required=True, + type=int) + + parser.add_argument('--hidden-dim', + dest='hidden_dim', + help="dimension of the LSTM cell state", + required=True, + type=int) + + parser.add_argument( + '--proj-dim', + dest='proj_dim', + help="dimension of the affine layer after every LSTM layer", + required=True, + type=int) + + parser.add_argument('--log-level', + dest='log_level', + help='log level. valid values: debug, info, warning', + type=str, + default='info') + + parser.add_argument( + '--checkpoint', + dest='checkpoint', + help='filename of the checkpoint, required for inference', + type=str) + + args = parser.parse_args() + + _check_args(args) + + return args diff --git a/egs/aishell/s10b/ctc/train.py b/egs/aishell/s10b/ctc/train.py new file mode 100644 index 00000000000..bec4e363cf9 --- /dev/null +++ b/egs/aishell/s10b/ctc/train.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import logging +import os +import sys +import warnings + +# disable warnings when loading tensorboard +warnings.simplefilter(action='ignore', category=FutureWarning) + +import torch +import torch.optim as optim +from torch.nn.utils import clip_grad_value_ +import torch.nn.functional as F + +import kaldi + +from options import get_args +from common import setup_logger +from model import get_ctc_model +from dataset import get_ctc_dataloader + + +def main(): + args = get_args() + setup_logger('{}/log-train'.format(args.dir), args.log_level) + logging.info(' '.join(sys.argv)) + + if torch.cuda.is_available() == False: + logging.error('No GPU detected!') + sys.exit(-1) + + device = torch.device('cuda', args.device_id) + + model = get_ctc_model(input_dim=args.input_dim, + output_dim=args.output_dim, + num_layers=args.num_layers, + hidden_dim=args.hidden_dim, + proj_dim=args.proj_dim) + + model.to(device) + + dataloader = get_ctc_dataloader(feats_scp=args.feats_scp, + labels_scp=args.labels_scp, + batch_size=args.batch_size, + shuffle=True, + num_workers=8) + + lr = args.learning_rate + optimizer = optim.Adam(model.parameters(), + lr=lr, + weight_decay=args.l2_regularize) + + model.train() + + for epoch in range(args.num_epochs): + learning_rate = lr * pow(0.4, epoch) + + for param_group in optimizer.param_groups: + param_group['lr'] = learning_rate + + logging.info('epoch {}, learning rate {}'.format(epoch, learning_rate)) + + for batch_idx, batch in enumerate(dataloader): + uttidlist, feat, feat_len_list, label_list, label_len_list = batch + + feat = feat.to(device) + log_probs = model(feat, feat_len_list) + + # at this point log_probs is of shape: [batch_size, seq_len, output_dim] + # CTCLoss requires a layout: [seq_len, batch_size, output_dim] + + log_probs = log_probs.permute(1, 0, 2) + # now log_probs is of shape [seq_len, batch_size, output_dim] + + label_tensor_list = [torch.tensor(x) for x in label_list] + + targets = torch.cat(label_tensor_list).to(device) + + input_lengths = torch.tensor(feat_len_list).to(device) + + target_lengths = torch.tensor(label_len_list).to(device) + + loss = F.ctc_loss(log_probs=log_probs, + targets=targets, + input_lengths=input_lengths, + target_lengths=target_lengths, + blank=0, + reduction='mean') + + optimizer.zero_grad() + + loss.backward() + + # clip_grad_value_(model.parameters(), 5.0) + + optimizer.step() + + logging.info('batch {}, loss {}'.format(batch_idx, loss.item())) + + +if __name__ == '__main__': + torch.manual_seed(20200221) + main() diff --git a/egs/aishell/s10b/local/generate_tlg.sh b/egs/aishell/s10b/local/generate_tlg.sh index 6a63ebba4c5..1435bf69533 100755 --- a/egs/aishell/s10b/local/generate_tlg.sh +++ b/egs/aishell/s10b/local/generate_tlg.sh @@ -141,6 +141,6 @@ set -e # remove files not needed any more for f in G.fst L.fst T.fst LG.fst disambig.list \ - lexiconp.txt lexiconp_disambig.txt phones.list; do + lexiconp.txt lexiconp_disambig.txt; do rm $dir/$f done diff --git a/egs/aishell/s10b/local/run_ctc.sh b/egs/aishell/s10b/local/run_ctc.sh new file mode 100755 index 00000000000..f13f738c90c --- /dev/null +++ b/egs/aishell/s10b/local/run_ctc.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +set -e + +echo "$0 $@" # Print the command line for logging + +stage=0 + +device_id=1 + +train_data_dir=data/train_sp +dev_data_dir=data/dev_sp +test_data_dir=data/test +lang_dir=data/lang + +lr=1e-4 +num_epochs=6 +l2_regularize=1e-5 +num_layers=4 +hidden_dim=512 +proj_dim=200 +batch_size=64 + + +dir=exp/ctc + +. ./path.sh +. ./cmd.sh + +. parse_options.sh + +feat_dim=$(feat-to-dim --print-args=false scp:$train_data_dir/feats.scp -) +output_dim=$(cat $lang_dir/phones.list | wc -l) +# added by one since we have an extra blank symbol +output_dim=$[$output_dim+1] + +if [[ $stage -le 0 ]]; then + mkdir -p $dir + + # sort options alphabetically + python3 ./ctc/train.py \ + --batch-size $batch_size \ + --device-id $device_id \ + --dir=$dir \ + --feats-scp $train_data_dir/feats.scp \ + --hidden-dim $hidden_dim \ + --input-dim $feat_dim \ + --is-training true \ + --num-layers $num_layers \ + --output-dim $output_dim \ + --proj-dim $proj_dim \ + --train.l2-regularize $l2_regularize \ + --train.labels-scp $train_data_dir/labels.scp \ + --train.lr $lr \ + --train.num-epochs $num_epochs +fi diff --git a/egs/aishell/s10b/run.sh b/egs/aishell/s10b/run.sh index 6d5c9e421ea..f780be75ea9 100755 --- a/egs/aishell/s10b/run.sh +++ b/egs/aishell/s10b/run.sh @@ -13,7 +13,7 @@ data_url=www.openslr.org/resources/33 nj=30 -stage=6 +stage=0 if [[ $stage -le 0 ]]; then local/download_and_untar.sh $data $data_url data_aishell || exit 1 @@ -63,3 +63,11 @@ if [[ $stage -le 6 ]]; then ./local/convert_text_to_labels.sh data/$x data/lang done fi + +if [[ $stage -le 7 ]]; then + ./local/run_ctc.sh \ + --train-data-dir data/train_sp \ + --dev-data-dir data/dev_sp \ + --test-data-dir data/test \ + --lang-dir data/lang +fi