
Commit

add training script.
csukuangfj committed Feb 21, 2020

1 parent 0148732 commit ffa861c
Showing 8 changed files with 711 additions and 2 deletions.
28 changes: 28 additions & 0 deletions egs/aishell/s10b/ctc/common.py
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

from datetime import datetime
import logging


def setup_logger(log_filename, log_level='info'):
now = datetime.now()
date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
log_filename = '{}-{}'.format(log_filename, date_time)
formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
if log_level == 'debug':
level = logging.DEBUG
elif log_level == 'info':
level = logging.INFO
elif log_level == 'warning':
level = logging.WARNING
else:
raise ValueError('Unsupported log level: {}'.format(log_level))
logging.basicConfig(filename=log_filename,
format=formatter,
level=level,
filemode='w')
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger('').addHandler(console)
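For reference, a minimal usage sketch of setup_logger (not part of the commit; the log path below is hypothetical):

import logging
from common import setup_logger

setup_logger('exp/ctc/log-train', log_level='info')
logging.info('training started')  # written to the console and to exp/ctc/log-train-<timestamp>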
200 changes: 200 additions & 0 deletions egs/aishell/s10b/ctc/dataset.py
@@ -0,0 +1,200 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import os
import logging

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import kaldi


def get_ctc_dataloader(feats_scp,
labels_scp=None,
batch_size=1,
shuffle=False,
num_workers=0):

dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

collate_fn = CtcDatasetCollateFunc()

dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn)

return dataloader


class CtcDataset(Dataset):

def __init__(self, feats_scp, labels_scp=None):
'''
Args:
feats_scp: filename for feats.scp
labels_scp: if provided, it is the filename of labels.scp
'''
assert os.path.isfile(feats_scp)
if labels_scp:
assert os.path.isfile(labels_scp)
logging.info('labels scp: {}'.format(labels_scp))
else:
logging.warning('No labels scp is given.')

# items maps uttid to [uttid, feat_rxfilename, None],
# or to [uttid, feat_rxfilename, label_rxfilename] if labels_scp is given
items = dict()

with open(feats_scp, 'r') as f:
for line in f:
# every line has the following format:
# uttid feat_rxfilename
uttid_rxfilename = line.split()
assert len(uttid_rxfilename) == 2

uttid, rxfilename = uttid_rxfilename

assert uttid not in items

items[uttid] = [uttid, rxfilename, None]

if labels_scp:
expected_count = len(items)
n = 0
with open(labels_scp, 'r') as f:
for line in f:
# every line has the following format:
# uttid rxfilename
uttid_rxfilename = line.split()

assert len(uttid_rxfilename) == 2

uttid, rxfilename = uttid_rxfilename

assert uttid in items

items[uttid][-1] = rxfilename

n += 1

# every utterance should have a label if
# labels_scp is given
assert n == expected_count

self.items = list(items.values())
self.num_items = len(self.items)
self.feats_scp = feats_scp
self.labels_scp = labels_scp

def __len__(self):
return self.num_items

def __getitem__(self, i):
'''
Returns:
a list [key, feat_rxfilename, label_rxfilename]
Note that label_rxfilename may be None.
'''
return self.items[i]

def __str__(self):
s = 'feats scp: {}\n'.format(self.feats_scp)

if self.labels_scp:
s += 'labels scp: {}\n'.format(self.labels_scp)

s += 'num utterances: {}\n'.format(self.num_items)

return s


class CtcDatasetCollateFunc:

def __call__(self, batch):
'''
Args:
batch: a list of [uttid, feat_rxfilename, label_rxfilename].
Note that label_rxfilename may be None.
Returns:
uttid_list: a list of utterance id
feat: a 3-D float tensor of shape [batch_size, seq_len, feat_dim]
feat_len_list: number of frames of each utterance before padding
label_list: a list of labels, one per utterance; it is None if no labels_scp was given.
label_len_list: the label length of each utterance; it is None if label_list is None.
'''
uttid_list = [] # utterance id of each utterance
feat_len_list = [] # number of frames of each utterance
label_list = [] # label of each utterance
label_len_list = [] # label length of each utterance

feat_list = []

for b in batch:
uttid, feat_rxfilename, label_rxfilename = b

uttid_list.append(uttid)

feat = kaldi.read_mat(feat_rxfilename).numpy()
feat = torch.from_numpy(feat).float()
feat_list.append(feat)

feat_len_list.append(feat.size(0))

if label_rxfilename:
label = kaldi.read_vec_int(label_rxfilename)
label_list.append(label)
label_len_list.append(len(label))

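# pad every utterance in the batch with zeros to the length of the longest one,
# yielding a float tensor of shape [batch_size, max_seq_len, feat_dim]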
feat = pad_sequence(feat_list, batch_first=True)

if not label_list:
label_list = None
label_len_list = None

return uttid_list, feat, feat_len_list, label_list, label_len_list


def _test_dataset():
feats_scp = 'data/train_sp/feats.scp'
labels_scp = 'data/train_sp/labels.scp'

dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

print(dataset)


def _test_dataloader():
feats_scp = 'data/test/feats.scp'
labels_scp = 'data/test/labels.scp'

dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

dataloader = DataLoader(dataset,
batch_size=2,
num_workers=10,
shuffle=True,
collate_fn=CtcDatasetCollateFunc())
i = 0
for batch in dataloader:
uttid_list, feat, feat_len_list, label_list, label_len_list = batch
print(uttid_list, feat.shape, feat_len_list, label_len_list)
i += 1
if i > 10:
break


if __name__ == '__main__':
# _test_dataset()
_test_dataloader()
145 changes: 145 additions & 0 deletions egs/aishell/s10b/ctc/model.py
@@ -0,0 +1,145 @@
#!/usr/bin/env python3

# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


def get_ctc_model(input_dim,
output_dim,
num_layers=4,
hidden_dim=512,
proj_dim=256):
model = CtcModel(input_dim=input_dim,
output_dim=output_dim,
num_layers=num_layers,
hidden_dim=hidden_dim,
proj_dim=proj_dim)

return model


class CtcModel(nn.Module):

def __init__(self, input_dim, output_dim, num_layers, hidden_dim, proj_dim):
'''
Args:
input_dim: input dimension of the network
output_dim: output dimension of the network
num_layers: number of LSTM layers of the network
hidden_dim: the dimension of the hidden state of LSTM layers
proj_dim: dimension of the affine layer after every LSTM layer
'''
super().__init__()

lstm_layer_list = []
proj_layer_list = []

# batchnorm requires input of shape [N, C, L] == [batch_size, dim, seq_len]
self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
affine=False)

for i in range(num_layers):
if i == 0:
lstm_input_dim = input_dim
else:
lstm_input_dim = proj_dim

lstm_layer = nn.LSTM(input_size=lstm_input_dim,
hidden_size=hidden_dim,
num_layers=1,
batch_first=True)

proj_layer = nn.Linear(in_features=hidden_dim,
out_features=proj_dim)

lstm_layer_list.append(lstm_layer)
proj_layer_list.append(proj_layer)

self.lstm_layer_list = nn.ModuleList(lstm_layer_list)
self.proj_layer_list = nn.ModuleList(proj_layer_list)

self.num_layers = num_layers

self.prefinal_affine = nn.Linear(in_features=proj_dim,
out_features=output_dim)

def forward(self, feat, feat_len_list):
'''
Args:
feat: a 3-D tensor of shape [batch_size, seq_len, feat_dim]
feat_len_list: feat length of each utterance before padding
Returns:
a 3-D tensor of shape [batch_size, seq_len, output_dim]
representing log prob, i.e., the output of log_softmax.
'''
x = feat

# at this point, x is of shape [batch_size, seq_len, feat_dim]
x = x.permute(0, 2, 1)

# at this point, x is of shape [batch_size, feat_dim, seq_len] == [N, C, L]
x = self.input_batch_norm(x)

x = x.permute(0, 2, 1)

# at this point, x is of shape [batch_size, seq_len, feat_dim] == [N, L, C]

for i in range(self.num_layers):
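# pack the padded batch so the LSTM skips the padded frames;
# enforce_sorted=False keeps the utterances in their original order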
x = pack_padded_sequence(input=x,
lengths=feat_len_list,
batch_first=True,
enforce_sorted=False)

# TODO(fangjun): save intermediate LSTM state to support streaming inference
x, _ = self.lstm_layer_list[i](x)

x, _ = pad_packed_sequence(x, batch_first=True)

x = self.proj_layer_list[i](x)

x = torch.tanh(x)

x = self.prefinal_affine(x)

x = F.log_softmax(x, dim=-1)

return x


def _test_ctc_model():
input_dim = 5
output_dim = 20
model = CtcModel(input_dim=input_dim,
output_dim=output_dim,
num_layers=2,
hidden_dim=3,
proj_dim=4)

feat1 = torch.randn((6, input_dim))
feat2 = torch.randn((8, input_dim))

from torch.nn.utils.rnn import pad_sequence
feat = pad_sequence([feat1, feat2], batch_first=True)
assert feat.shape == torch.Size([2, 8, input_dim])

feat_len_list = [6, 8]
x = model(feat, feat_len_list)

assert x.shape == torch.Size([2, 8, output_dim])


if __name__ == '__main__':
_test_ctc_model()
162 changes: 162 additions & 0 deletions egs/aishell/s10b/ctc/options.py
@@ -0,0 +1,162 @@
#!/usr/bin/env python3

# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import argparse
import os


def _str2bool(v):
'''
This function is modified from
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
'''
if isinstance(v, bool):
return v
elif v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


def _set_training_args(parser):
parser.add_argument('--train.labels-scp',
dest='labels_scp',
help='filename of labels.scp',
type=str)

parser.add_argument('--train.num-epochs',
dest='num_epochs',
help='number of epochs to train',
type=int)

parser.add_argument('--train.lr',
dest='learning_rate',
help='learning rate',
type=float)

parser.add_argument('--train.l2-regularize',
dest='l2_regularize',
help='l2 regularize',
type=float)

# TODO(fangjun): add validation feats_scp


def _check_training_args(args):
assert os.path.isfile(args.labels_scp)

assert args.num_epochs > 0
assert args.learning_rate > 0
assert args.l2_regularize >= 0

if args.checkpoint:
assert os.path.exists(args.checkpoint)


def _check_args(args):
if args.is_training:
_check_training_args(args)

assert os.path.isdir(args.dir)
assert os.path.isfile(args.feats_scp)

assert args.batch_size > 0
assert args.device_id >= 0

assert args.input_dim > 0
assert args.output_dim > 0
assert args.num_layers > 0
assert args.hidden_dim > 0
assert args.proj_dim > 0

assert args.log_level in ['debug', 'info', 'warning']


def get_args():
parser = argparse.ArgumentParser(
description='CTC training in PyTorch with kaldi pybind')

_set_training_args(parser)

parser.add_argument('--is-training',
dest='is_training',
help='true for training, false for inference',
required=True,
type=_str2bool)

parser.add_argument('--dir',
help='dir to save results. The user has to '
'create it before calling this script.',
required=True,
type=str)

parser.add_argument('--feats-scp',
dest='feats_scp',
help='filename of feats.scp',
required=True,
type=str)

parser.add_argument('--device-id',
dest='device_id',
help='GPU device id',
required=True,
type=int)

parser.add_argument('--batch-size',
dest='batch_size',
help='batch size used in training and inference',
required=True,
type=int)

parser.add_argument('--input-dim',
dest='input_dim',
help='input dimension of the network',
required=True,
type=int)

parser.add_argument('--output-dim',
dest='output_dim',
help='output dimension of the network',
required=True,
type=int)

parser.add_argument('--num-layers',
dest='num_layers',
help="number of LSTM layers in the network",
required=True,
type=int)

parser.add_argument('--hidden-dim',
dest='hidden_dim',
help="dimension of the LSTM cell state",
required=True,
type=int)

parser.add_argument(
'--proj-dim',
dest='proj_dim',
help="dimension of the affine layer after every LSTM layer",
required=True,
type=int)

parser.add_argument('--log-level',
dest='log_level',
help='log level. valid values: debug, info, warning',
type=str,
default='info')

parser.add_argument(
'--checkpoint',
dest='checkpoint',
help='filename of the checkpoint, required for inference',
type=str)

args = parser.parse_args()

_check_args(args)

return args
107 changes: 107 additions & 0 deletions egs/aishell/s10b/ctc/train.py
@@ -0,0 +1,107 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import logging
import os
import sys
import warnings

# disable warnings when loading tensorboard
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_value_
import torch.nn.functional as F

import kaldi

from options import get_args
from common import setup_logger
from model import get_ctc_model
from dataset import get_ctc_dataloader


def main():
args = get_args()
setup_logger('{}/log-train'.format(args.dir), args.log_level)
logging.info(' '.join(sys.argv))

if not torch.cuda.is_available():
logging.error('No GPU detected!')
sys.exit(-1)

device = torch.device('cuda', args.device_id)

model = get_ctc_model(input_dim=args.input_dim,
output_dim=args.output_dim,
num_layers=args.num_layers,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim)

model.to(device)

dataloader = get_ctc_dataloader(feats_scp=args.feats_scp,
labels_scp=args.labels_scp,
batch_size=args.batch_size,
shuffle=True,
num_workers=8)

lr = args.learning_rate
optimizer = optim.Adam(model.parameters(),
lr=lr,
weight_decay=args.l2_regularize)

model.train()

for epoch in range(args.num_epochs):
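# exponential learning rate decay: the initial lr is scaled by 0.4 ** epoch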
learning_rate = lr * pow(0.4, epoch)

for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate

logging.info('epoch {}, learning rate {}'.format(epoch, learning_rate))

for batch_idx, batch in enumerate(dataloader):
uttid_list, feat, feat_len_list, label_list, label_len_list = batch

feat = feat.to(device)
log_probs = model(feat, feat_len_list)

# at this point log_probs is of shape: [batch_size, seq_len, output_dim]
# CTCLoss requires a layout: [seq_len, batch_size, output_dim]

log_probs = log_probs.permute(1, 0, 2)
# now log_probs is of shape [seq_len, batch_size, output_dim]

label_tensor_list = [torch.tensor(x) for x in label_list]

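# F.ctc_loss accepts the targets as one 1-D tensor of concatenated label sequences;
# target_lengths (below) gives the number of labels of each utterance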
targets = torch.cat(label_tensor_list).to(device)

input_lengths = torch.tensor(feat_len_list).to(device)

target_lengths = torch.tensor(label_len_list).to(device)

loss = F.ctc_loss(log_probs=log_probs,
targets=targets,
input_lengths=input_lengths,
target_lengths=target_lengths,
blank=0,
reduction='mean')

optimizer.zero_grad()

loss.backward()

# clip_grad_value_(model.parameters(), 5.0)

optimizer.step()

logging.info('batch {}, loss {}'.format(batch_idx, loss.item()))


if __name__ == '__main__':
torch.manual_seed(20200221)
main()
2 changes: 1 addition & 1 deletion egs/aishell/s10b/local/generate_tlg.sh
@@ -141,6 +141,6 @@ set -e

# remove files not needed any more
for f in G.fst L.fst T.fst LG.fst disambig.list \
lexiconp.txt lexiconp_disambig.txt phones.list; do
lexiconp.txt lexiconp_disambig.txt; do
rm $dir/$f
done
59 changes: 59 additions & 0 deletions egs/aishell/s10b/local/run_ctc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/bin/bash

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

set -e

echo "$0 $@" # Print the command line for logging

stage=0

device_id=1

train_data_dir=data/train_sp
dev_data_dir=data/dev_sp
test_data_dir=data/test
lang_dir=data/lang

lr=1e-4
num_epochs=6
l2_regularize=1e-5
num_layers=4
hidden_dim=512
proj_dim=200
batch_size=64


dir=exp/ctc

. ./path.sh
. ./cmd.sh

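# parse_options.sh (Kaldi's standard option parser) allows the variables defined
# above, e.g. --stage or --batch-size, to be overridden from the command line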
. parse_options.sh

feat_dim=$(feat-to-dim --print-args=false scp:$train_data_dir/feats.scp -)
output_dim=$(wc -l < $lang_dir/phones.list)
# add one for the extra blank symbol <blk>
output_dim=$((output_dim + 1))

if [[ $stage -le 0 ]]; then
mkdir -p $dir

# sort options alphabetically
python3 ./ctc/train.py \
--batch-size $batch_size \
--device-id $device_id \
--dir=$dir \
--feats-scp $train_data_dir/feats.scp \
--hidden-dim $hidden_dim \
--input-dim $feat_dim \
--is-training true \
--num-layers $num_layers \
--output-dim $output_dim \
--proj-dim $proj_dim \
--train.l2-regularize $l2_regularize \
--train.labels-scp $train_data_dir/labels.scp \
--train.lr $lr \
--train.num-epochs $num_epochs
fi
10 changes: 9 additions & 1 deletion egs/aishell/s10b/run.sh
@@ -13,7 +13,7 @@ data_url=www.openslr.org/resources/33

nj=30

stage=6
stage=0

if [[ $stage -le 0 ]]; then
local/download_and_untar.sh $data $data_url data_aishell || exit 1
@@ -63,3 +63,11 @@ if [[ $stage -le 6 ]]; then
./local/convert_text_to_labels.sh data/$x data/lang
done
fi

if [[ $stage -le 7 ]]; then
./local/run_ctc.sh \
--train-data-dir data/train_sp \
--dev-data-dir data/dev_sp \
--test-data-dir data/test \
--lang-dir data/lang
fi
