Commit

Merge pull request #458 from SeanNaren/feature/V2

Feature/v2

Sean Naren authored Oct 1, 2019
2 parents d5dbadf + 8a5ecdf commit 9b9c96a
Showing 13 changed files with 224 additions and 281 deletions.
129 changes: 51 additions & 78 deletions README.md
@@ -35,13 +35,6 @@
```
export CUDA_HOME="/usr/local/cuda"
cd ../pytorch_binding && python setup.py install
```

Install pytorch audio:
```
sudo apt-get install sox libsox-dev libsox-fmt-all
git clone https://github.com/pytorch/audio.git
cd audio && python setup.py install
```

Install NVIDIA apex:
```
git clone --recursive https://github.com/NVIDIA/apex.git
@@ -59,69 +52,11 @@
Finally clone this repo and run this within the repo:
```
pip install -r requirements.txt
```

## Training

### Datasets

Currently supports AN4, TEDLIUM, Voxforge and LibriSpeech. Scripts will set up the dataset and create the manifest files used in data loading.

#### AN4

To download and set up the AN4 dataset, run the command below in the root folder of the repo:

```
cd data; python an4.py
```

#### TEDLIUM

You have the option to download the raw dataset file manually or through the script (which will cache it).
The file is found [here](http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz).

To download and set up the TEDLIUM_V2 dataset, run the command below in the root folder of the repo:

```
cd data; python ted.py # Optionally if you have downloaded the raw dataset file, pass --tar_path /path/to/TEDLIUM_release2.tar.gz
```
#### Voxforge

To download and set up the Voxforge dataset, run the command below in the root folder of the repo:

```
cd data; python voxforge.py
```

Note that this dataset does not come with a validation or test set.

#### LibriSpeech

To download and set up the LibriSpeech dataset, run the command below in the root folder of the repo:

```
cd data; python librispeech.py
```

You have the option to download the raw dataset files manually or through the script (which will cache them as well).
To download them manually, create the following folder structure and place into it the corresponding tar files downloaded from [here](http://www.openslr.org/12/).

```
cd data/
mkdir LibriSpeech/ # This can be anything as long as you specify the directory path as --target-dir when running the librispeech.py script
mkdir LibriSpeech/val/
mkdir LibriSpeech/test/
mkdir LibriSpeech/train/
```

Now put the `tar.gz` files in the correct folders. They will be used in the data pre-processing for LibriSpeech and removed after the dataset has been formatted.

Optionally, you can specify exactly which LibriSpeech files you want if you don't want to add all of them. This can be done as below:

```
cd data/
python librispeech.py --files-to-use "train-clean-100.tar.gz, train-clean-360.tar.gz,train-other-500.tar.gz, dev-clean.tar.gz,dev-other.tar.gz, test-clean.tar.gz,test-other.tar.gz"
```
Currently supports AN4, TEDLIUM, Voxforge, Common Voice and LibriSpeech. Scripts will set up the dataset and create the manifest files used in data loading. The scripts can be found in the data/ folder. Many of the scripts allow you to download the raw datasets separately if you choose to do so.
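
For example, to download and set up AN4 (a quick sketch; substitute the script for whichever dataset you want):

```
cd data; python an4.py
```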

#### Custom Dataset

@@ -146,7 +81,7 @@
```
cd data/
python merge_manifests.py --output-path merged_manifest.csv --merge-dir all-manifests/ --min-duration 1 --max-duration 15 # durations in seconds
```

### Training a Model

```
python train.py --train-manifest data/train_manifest.csv --val-manifest data/val_manifest.csv
@@ -168,7 +103,7 @@ python train.py --tensorboard --logdir log_dir/ # Make sure the Tensorboard inst

For both visualisation tools, you can add your own name to the run by changing the `--id` parameter when training.
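
For example (the run name here is just a placeholder):

```
python train.py --visdom --id my_experiment # Add your other training parameters as normal
```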

### Multi-GPU Training

We support multi-GPU training via the distributed parallel wrapper (see [here](https://github.com/NVIDIA/sentiment-discovery/blob/master/analysis/scale.md) and [here](https://github.com/SeanNaren/deepspeech.pytorch/issues/211) to see why we don't use DataParallel).

@@ -180,22 +115,26 @@ python -m multiproc train.py --visdom --cuda # Add your parameters as normal, mu

multiproc will open a log for all processes other than the main process.

You can also specify specific GPU IDs rather than allowing the script to use all available GPUs:

```
python -m multiproc train.py --visdom --cuda --device-ids 0,1,2,3 # Add your parameters as normal, will only run on 4 GPUs
```

We suggest using the NCCL backend, which defaults to TCP if Infiniband isn't available.

### Mixed Precision

If you are using NVIDIA Volta cards or above to train your model, it's highly suggested to turn on mixed precision for speed/memory benefits. More information can be found [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html).

Different optimization levels are available. More information on the NVIDIA Apex API can be found [here](https://nvidia.github.io/apex/amp.html#opt-levels).

```
python train.py --train-manifest data/train_manifest.csv --val-manifest data/val_manifest.csv --opt-level O1 --loss-scale 1.0
```

Training a model in mixed precision means you can use 32-bit float or half precision at runtime. Float is the default; to use half precision (which on V100s comes with a speedup and better memory use), pass the `--half` flag when testing or transcribing.

### Noise Augmentation/Injection

There is support for two different types of noise: noise augmentation and noise injection.
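
As a rough sketch, noise injection on a single file can be tried with the `noise_inject.py` utility shown later in this diff (the flag names are assumptions inferred from the arguments the script reads):

```
python noise_inject.py --input-path /path/to/audio.wav --noise-path /path/to/noise.wav --noise-level 0.5 --output-path mixed.wav
```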
@@ -282,7 +221,9 @@
An example script to output a transcription has been provided:
```
python transcribe.py --model-path models/deepspeech.pth --audio-path /path/to/audio.wav
```

If you used mixed-precision or half precision when training the model, you can use the `--half` flag for a speed/memory benefit.
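
For example, a sketch of transcribing with half precision (assuming `transcribe.py` exposes the same `--half` inference flag):

```
python transcribe.py --model-path models/deepspeech.pth --audio-path /path/to/audio.wav --half
```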

## Inference Server

Included is a basic server script that allows POST requests to be sent to the server to transcribe files.

@@ -292,6 +233,38 @@
```
python server.py --host 0.0.0.0 --port 8000 # Run on one window
curl -X POST http://0.0.0.0:8000/transcribe -H "Content-type: multipart/form-data" -F "file=@/path/to/input.wav"
```

## Using an ARPA LM

We support using KenLM-based LMs. Below are instructions on how to take the LibriSpeech LMs found [here](http://www.openslr.org/11/) and tune them to give you the best decoding parameters, based on LibriSpeech.

### Tuning the LibriSpeech LMs

First ensure you've set up the LibriSpeech datasets from the data/ folder.
In addition, download the latest pre-trained LibriSpeech model from the releases page, as well as the ARPA model you want to tune from [here](http://www.openslr.org/11/). For the example below we use the 3-gram ARPA model (3e-7 prune).

First we need to generate the acoustic output used to evaluate the model on the LibriSpeech validation set.
```
python test.py --test-manifest data/librispeech_val_manifest.csv --model-path librispeech_pretrained_v2.pth --cuda --half --save-output librispeech_val_output.npy
```

We use a beam width of 128, which gives reasonable results. We suggest using a CPU-intensive node to carry out the grid search.

```
python search_lm_params.py --num-workers 16 --saved-output librispeech_val_output.npy --output-path libri_tune_output.json --lm-alpha-from 0 --lm-alpha-to 5 --lm-beta-from 0 --lm-beta-to 3 --lm-path 3-gram.pruned.3e-7.arpa --model-path librispeech_pretrained_v2.pth --beam-width 128 --lm-workers 16
```

This will run a grid search across the alpha/beta parameters using a beam width of 128. Use the below script to find the best alpha/beta params:

```
python select_lm_params.py --input-path libri_tune_output.json
```

Use the resulting alpha/beta parameters when running the beam decoder.
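
For example, a sketch of decoding with the tuned values (the `--alpha`/`--beta` flag names are assumptions; check the decoder options added by `add_decoder_args` in opts.py):

```
python test.py --test-manifest data/librispeech_val_manifest.csv --model-path librispeech_pretrained_v2.pth --cuda --half --decoder beam --lm-path 3-gram.pruned.3e-7.arpa --beam-width 128 --alpha <best_alpha> --beta <best_beta>
```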

### Building your own LM

To build your own LM you need to use the KenLM repo found [here](https://github.com/kpu/kenlm). Have a read of the documentation to get a sense of how to train your own LM. Once it is trained, the above steps can be used to find the appropriate decoding parameters.
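
For reference, a minimal sketch using KenLM's command-line tools, assuming `corpus.txt` is your training text and the KenLM binaries are on your PATH:

```
lmplz -o 3 < corpus.txt > lm.arpa # train a 3-gram ARPA LM
build_binary lm.arpa lm.binary # optional: binarize for faster loading
```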

### Alternate Decoders
By default, `test.py` and `transcribe.py` use a `GreedyDecoder` which picks the highest-likelihood output label at each timestep. Repeated and blank symbols are then filtered to give the final output.
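
A minimal sketch of the greedy decoding idea (not the repo's `GreedyDecoder` implementation; the label set and blank index below are placeholders):

```
import torch

def greedy_decode(probs, labels, blank_index=0):
    # probs: (seq_len, num_labels) tensor of per-timestep label probabilities
    best_path = torch.argmax(probs, dim=1).tolist()  # highest-likelihood label per timestep
    transcript = []
    prev = None
    for idx in best_path:
        if idx != prev and idx != blank_index:  # collapse repeats, then drop blanks
            transcript.append(labels[idx])
        prev = idx
    return ''.join(transcript)

# Example with a dummy distribution over the labels "_ab" (blank at index 0)
probs = torch.rand(50, 3).softmax(dim=1)
print(greedy_decode(probs, "_ab"))
```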

6 changes: 3 additions & 3 deletions data/data_loader.py
@@ -10,7 +10,7 @@
import numpy as np
import scipy.signal
import torch
import torchaudio
from scipy.io.wavfile import read
import math
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
@@ -20,8 +20,8 @@


def load_audio(path):
sound, _ = torchaudio.load(path, normalization=True)
sound = sound.numpy().T
sample_rate, sound = read(path)
sound = sound.astype('float32') / 32767 # normalize audio
if len(sound.shape) > 1:
if sound.shape[1] == 1:
sound = sound.squeeze()
14 changes: 4 additions & 10 deletions model.py
@@ -57,7 +57,7 @@ def forward(self, x, lengths):
"""
for module in self.seq_module:
x = module(x)
mask = torch.ByteTensor(x.size()).fill_(0)
mask = torch.BoolTensor(x.size()).fill_(0)
if x.is_cuda:
mask = mask.cuda()
for i, length in enumerate(lengths):
@@ -129,7 +129,7 @@ def __repr__(self):

class DeepSpeech(nn.Module):
def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layers=5, audio_conf=None,
bidirectional=True, context=20, mixed_precision=False):
bidirectional=True, context=20):
super(DeepSpeech, self).__init__()

# model metadata needed for serialization/deserialization
@@ -142,7 +142,6 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
self.audio_conf = audio_conf or {}
self.labels = labels
self.bidirectional = bidirectional
self.mixed_precision = mixed_precision

sample_rate = self.audio_conf.get("sample_rate", 16000)
window_size = self.audio_conf.get("window_size", 0.02)
@@ -187,8 +186,6 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
self.inference_softmax = InferenceBatchSoftmax()

def forward(self, x, lengths):
if x.is_cuda and self.mixed_precision:
x = x.half()
lengths = lengths.cpu().int()
output_lengths = self.get_seq_lens(lengths)
x, _ = self.conv(x, output_lengths)
@@ -230,8 +227,7 @@ def load_model(cls, path):
labels=package['labels'],
audio_conf=package['audio_conf'],
rnn_type=supported_rnns[package['rnn_type']],
bidirectional=package.get('bidirectional', True),
mixed_precision=package.get('mixed_precision', False))
bidirectional=package.get('bidirectional', True))
model.load_state_dict(package['state_dict'])
for x in model.rnns:
x.flatten_parameters()
@@ -244,8 +240,7 @@ def load_model_package(cls, package):
labels=package['labels'],
audio_conf=package['audio_conf'],
rnn_type=supported_rnns[package['rnn_type']],
bidirectional=package.get('bidirectional', True),
mixed_precision=package.get('mixed_precision', False))
bidirectional=package.get('bidirectional', True))
model.load_state_dict(package['state_dict'])
return model

@@ -261,7 +256,6 @@ def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=No
'labels': model.labels,
'state_dict': model.state_dict(),
'bidirectional': model.bidirectional,
'mixed_precision': model.mixed_precision
}
if optimizer is not None:
package['optim_dict'] = optimizer.state_dict()
6 changes: 4 additions & 2 deletions noise_inject.py
@@ -1,7 +1,7 @@
import argparse

import torch
import torchaudio
from scipy.io.wavfile import write

from data.data_loader import load_audio, NoiseInjection

@@ -18,5 +18,7 @@
data = load_audio(args.input_path)
mixed_data = noise_injector.inject_noise_sample(data, args.noise_path, args.noise_level)
mixed_data = torch.tensor(mixed_data, dtype=torch.float).unsqueeze(1) # Add channels dim
torchaudio.save(args.output_path, mixed_data, args.sample_rate)
write(filename=args.output_path,
data=mixed_data.numpy(),
rate=args.sample_rate)
print('Saved mixed file to %s' % args.output_path)
4 changes: 3 additions & 1 deletion opts.py
@@ -17,7 +17,9 @@ def add_decoder_args(parser):


def add_inference_args(parser):
parser.add_argument('--cuda', action="store_true", help='Use cuda to test model')
parser.add_argument('--cuda', action="store_true", help='Use cuda')
parser.add_argument('--half', action="store_true",
help='Use half precision. This is recommended when using mixed-precision at training time')
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
parser.add_argument('--model-path', default='models/deepspeech_final.pth',
help='Path to model file created by training')
84 changes: 84 additions & 0 deletions search_lm_params.py
@@ -0,0 +1,84 @@
import argparse
import json
import sys
from multiprocessing.pool import Pool

import numpy as np
import torch
from tqdm import tqdm

from decoder import BeamCTCDecoder
from model import DeepSpeech
from opts import add_decoder_args

parser = argparse.ArgumentParser(description='Tune an ARPA LM based on a pre-trained acoustic model output')
parser.add_argument('--model-path', default='models/deepspeech_final.pth',
help='Path to model file created by training')
parser.add_argument('--saved-output', default="", type=str, help='Path to output from test.py')
parser.add_argument('--num-workers', default=16, type=int, help='Number of parallel decodes to run')
parser.add_argument('--output-path', default="tune_results.json", help="Where to save tuning results")
parser.add_argument('--lm-alpha-from', default=0.0, type=float, help='Language model weight start tuning')
parser.add_argument('--lm-alpha-to', default=3.0, type=float, help='Language model weight end tuning')
parser.add_argument('--lm-beta-from', default=0.0, type=float,
help='Language model word bonus (all words) start tuning')
parser.add_argument('--lm-beta-to', default=0.5, type=float,
help='Language model word bonus (all words) end tuning')
parser.add_argument('--lm-num-alphas', default=45, type=float, help='Number of alpha candidates for tuning')
parser.add_argument('--lm-num-betas', default=8, type=float, help='Number of beta candidates for tuning')
parser = add_decoder_args(parser)
args = parser.parse_args()

if args.lm_path is None:
print("error: LM must be provided for tuning")
sys.exit(1)

model = DeepSpeech.load_model(args.model_path)

saved_output = np.load(args.saved_output)


def init(beam_width, blank_index, lm_path):
global decoder
decoder = BeamCTCDecoder(model.labels, lm_path=lm_path, beam_width=beam_width, num_processes=args.lm_workers,
blank_index=blank_index)


def decode_dataset(params):
lm_alpha, lm_beta = params
global decoder
decoder._decoder.reset_params(lm_alpha, lm_beta)

total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
for out, sizes, target_strings in saved_output:
out = torch.Tensor(out).float()
sizes = torch.Tensor(sizes).int()
decoded_output, _, = decoder.decode(out, sizes)
for x in range(len(target_strings)):
transcript, reference = decoded_output[x][0], target_strings[x][0]
wer_inst = decoder.wer(transcript, reference)
cer_inst = decoder.cer(transcript, reference)
total_cer += cer_inst
total_wer += wer_inst
num_tokens += len(reference.split())
num_chars += len(reference)

wer = float(total_wer) / num_tokens
cer = float(total_cer) / num_chars

return [lm_alpha, lm_beta, wer * 100, cer * 100]


if __name__ == '__main__':
p = Pool(args.num_workers, init, [args.beam_width, model.labels.index('_'), args.lm_path])

cand_alphas = np.linspace(args.lm_alpha_from, args.lm_alpha_to, args.lm_num_alphas)
cand_betas = np.linspace(args.lm_beta_from, args.lm_beta_to, args.lm_num_betas)
params_grid = [(float(alpha), float(beta)) for alpha in cand_alphas
for beta in cand_betas]

scores = []
for params in tqdm(p.imap(decode_dataset, params_grid), total=len(params_grid)):
scores.append(list(params))
print("Saving tuning results to: {}".format(args.output_path))
with open(args.output_path, "w") as fh:
json.dump(scores, fh)
12 changes: 12 additions & 0 deletions select_lm_params.py
@@ -0,0 +1,12 @@
import argparse
import json

parser = argparse.ArgumentParser(description='Select the best parameters based on the WER')
parser.add_argument('--input-path', type=str, help='Output json file from search_lm_params')
args = parser.parse_args()

with open(args.input_path) as f:
results = json.load(f)

min_results = min(results, key=lambda x: x[2]) # Find the minimum WER (alpha, beta, WER, CER)
print("Alpha: %f \nBeta: %f \nWER: %f\nCER: %f" % tuple(min_results))