Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: begin to add CTC training with kaldi pybind and PyTorch. #3947

Open
wants to merge 11 commits into
base: pybind11
Choose a base branch
from
16 changes: 16 additions & 0 deletions egs/aishell/s10b/cmd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# you can change cmd.sh depending on what type of queue you are using.
# If you have no queueing system and want to run on a local machine, you
# can change all instances 'queue.pl' to run.pl (but be careful and run
# commands one by one: most recipes will exhaust the memory on your
# machine). queue.pl works with GridEngine (qsub). slurm.pl works
# with slurm. Different queues are configured differently, with different
# queue names and different ways of specifying things like memory;
# to account for these differences you can create and edit the file
# conf/queue.conf to match your queue's configuration. Search for
# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.

export train_cmd="run.pl"
export decode_cmd="run.pl"
export mkgraph_cmd="run.pl"
export cuda_cmd="run.pl"
1 change: 1 addition & 0 deletions egs/aishell/s10b/conf/fbank.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--num-mel-bins=40
96 changes: 96 additions & 0 deletions egs/aishell/s10b/ctc/add_deltas_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_delta_feat(x, weight):
'''
Args:
x: input feat of shape [batch_size, feat_dim, seq_len]

weight: coefficients for computing delta features;
it has shape [feat_dim, 1, kernel_size].

Returns:
a tensor of shape [batch_size, feat_dim, seq_len]
'''

assert x.ndim == 3

assert weight.ndim == 3
assert weight.size(0) == x.size(1)
assert weight.size(1) == 1
assert weight.size(2) % 2 == 1

feat_dim = x.size(1)

# NOTE(fangjun): we perform a depthwise convolution here by
# setting groups == number of channels
y = F.conv1d(input=x, weight=weight, groups=feat_dim)

return y


class AddDeltasLayer(nn.Module):
'''
This class implements `add-deltas` with order == 2 and window == 2.

Note that it has no trainable `nn.Parameter`s.
'''

def __init__(self,
first_order_coef=[-1, 0, 1],
second_order_coef=[1, 0, -2, 0, 1]):
'''
Args:
first_order_coef: coefficient to compute the first order delta feature

second_order_coef: coefficient to compute the second order delta feature
'''
super().__init__()

self.first_order_coef = torch.tensor(first_order_coef).float()
self.second_order_coef = torch.tensor(second_order_coef).float()

def forward(self, x):
'''
Args:
x: a tensor of shape [batch_size, feat_dim, seq_len]

Returns:
a tensor of shape [batch_size, feat_dim * 3, seq_len]
'''
if self.first_order_coef.ndim != 3:
num_duplicates = x.size(1)

# yapf: disable
self.first_order_coef = self.first_order_coef.reshape(1, 1, -1)
self.first_order_coef = torch.cat([self.first_order_coef] * num_duplicates, dim=0)

self.second_order_coef = self.second_order_coef.reshape(1, 1, -1)
self.second_order_coef = torch.cat([self.second_order_coef] * num_duplicates, dim=0)
# yapf: enable

device = x.device
self.first_order_coef = self.first_order_coef.to(device)
self.second_order_coef = self.second_order_coef.to(device)

first_order = compute_delta_feat(x, self.first_order_coef)
second_order = compute_delta_feat(x, self.second_order_coef)

# since we did not perform padding, we have to remove some frames
# from the 0th and 1st order features
zeroth_valid = (x.size(2) - second_order.size(2)) // 2
first_valid = (first_order.size(2) - second_order.size(2)) // 2

y = torch.cat([
x[:, :, zeroth_valid:-zeroth_valid,],
first_order[:, :, first_valid:-first_valid],
second_order,
],
dim=1)

return y
79 changes: 79 additions & 0 deletions egs/aishell/s10b/ctc/add_deltas_layer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import os
import shutil
import tempfile
import unittest

import numpy as np

import torch
import torch.nn.functional as F

import kaldi

from add_deltas_layer import AddDeltasLayer


class AddDeltasLayerTest(unittest.TestCase):

def test(self):
x = torch.tensor([
[1, 3],
[5, 10],
[0, 1],
[10, 20],
[3, 1],
[3, 2],
[5, 1],
[10, -2],
[10, 20],
[100, 200],
]).float()

x = x.unsqueeze(0)

transform = AddDeltasLayer(first_order_coef=[-0.2, -0.1, 0, 0.1, 0.2],
second_order_coef=[
0.04, 0.04, 0.01, -0.04, -0.1, -0.04,
0.01, 0.04, 0.04
])
y = transform(x.permute(0, 2, 1)).permute(0, 2, 1)

# now use kaldi's add-deltas to compute the ground truth
d = tempfile.mkdtemp()

wspecifier = 'ark:{}/feats.ark'.format(d)

writer = kaldi.MatrixWriter(wspecifier)
writer.Write('utt1', x.squeeze(0).numpy())
writer.Close()

delta_feats_specifier = 'ark:{dir}/delta.ark'.format(dir=d)

cmd = '''
add-deltas --print-args=false --delta-order=2 --delta-window=2 {} {}
'''.format(wspecifier, delta_feats_specifier)

os.system(cmd)

reader = kaldi.RandomAccessMatrixReader(delta_feats_specifier)

expected = reader['utt1']

y = y.squeeze(0)

np.testing.assert_array_almost_equal(y.numpy(),
expected.numpy()[4:-4, :],
decimal=5)

reader.Close()

shutil.rmtree(d)


if __name__ == '__main__':
unittest.main()
98 changes: 98 additions & 0 deletions egs/aishell/s10b/ctc/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

from datetime import datetime
import logging

import torch


def setup_logger(log_filename, log_level='info'):
now = datetime.now()
date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
log_filename = '{}-{}'.format(log_filename, date_time)
formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
if log_level == 'debug':
level = logging.DEBUG
elif log_level == 'info':
level = logging.INFO
elif log_level == 'warning':
level = logging.WARNING
logging.basicConfig(filename=log_filename,
format=formatter,
level=level,
filemode='w')
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger('').addHandler(console)


def load_checkpoint(filename, model):
logging.info('Loading checkpoint from {}'.format(filename))

checkpoint = torch.load(filename, map_location='cpu')

keys = ['state_dict', 'epoch', 'learning_rate', 'loss']
for k in keys:
assert k in checkpoint

if not list(model.state_dict().keys())[0].startswith('module.') \
and list(checkpoint['state_dict'])[0].startswith('module.'):
# the checkpoint was saved by DDP
logging.info('load checkpoint from DDP')
dst_state_dict = model.state_dict()
src_state_dict = checkpoint['state_dict']
for key in dst_state_dict.keys():
src_key = '{}.{}'.format('module', key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict)
else:
model.load_state_dict(checkpoint['state_dict'])

epoch = checkpoint['epoch']
learning_rate = checkpoint['learning_rate']
loss = checkpoint['loss']

return epoch, learning_rate, loss


def save_checkpoint(filename, model, epoch, learning_rate, loss, local_rank=0):
if local_rank != 0:
return
logging.info('Saving checkpoint to {filename}: epoch={epoch}, '
'learning_rate={learning_rate}, loss={loss}'.format(
filename=filename,
epoch=epoch,
learning_rate=learning_rate,
loss=loss))
checkpoint = {
'state_dict': model.state_dict(),
'epoch': epoch,
'learning_rate': learning_rate,
'loss': loss
}
torch.save(checkpoint, filename)


def save_training_info(filename,
model_path,
current_epoch,
learning_rate,
loss,
best_loss,
best_epoch,
local_rank=0):
if local_rank != 0:
return

with open(filename, 'w') as f:
f.write('model_path: {}\n'.format(model_path))
f.write('epoch: {}\n'.format(current_epoch))
f.write('learning rate: {}\n'.format(learning_rate))
f.write('loss: {}\n'.format(loss))
f.write('best loss: {}\n'.format(best_loss))
f.write('best epoch: {}\n'.format(best_epoch))
Loading