# train_abc.py
from ruamel.yaml import YAML
# NeMo's "core" package
import nemo
# NeMo's ASR collection
import nemo.collections.asr as nemo_asr
# Create a Neural Factory
# It creates log files and tensorboard writers for us among other functions
nf = nemo.core.NeuralModuleFactory(log_dir='QuartzNet12x1_vivos_abc', create_tb_writer=True)
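# The factory runs on GPU by default when one is available. Device placement
# and mixed precision can also be requested explicitly; the options below are
# assumed from NeMo 0.x docs and may differ on your version:
# nf = nemo.core.NeuralModuleFactory(
#     log_dir='QuartzNet12x1_vivos_abc', create_tb_writer=True,
#     placement=nemo.core.DeviceType.GPU,
#     optimization_level=nemo.core.Optimization.mxprO1)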
tb_writer = nf.tb_writer
# Path to our training manifest
train_dataset = "data/vivos_train.json"
# Path to our validation manifest
eval_datasets = "data/vivos_test.json"
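# NeMo manifests are JSON-lines files: one utterance per line with its audio
# path, duration in seconds, and transcript, e.g. (illustrative values only):
# {"audio_filepath": "data/waves/spk01_001.wav", "duration": 3.2, "text": "xin chao"}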
# QuartzNet Model definition
# Here we will be using separable convolutions
# with 12 blocks, each repeated once (B=12, R=1 in QuartzNet BxR notation)
yaml = YAML(typ="safe")
with open("config/quartznet12x1_char.yaml") as f:
    quartznet_model_definition = yaml.load(f)
labels = quartznet_model_definition['labels']
print(len(labels), labels)
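# A rough sketch of the expected config layout (illustrative, not verbatim):
#   labels: [" ", "a", "b", "c", ...]   # CTC character vocabulary
#   JasperEncoder:
#     jasper: [...]                     # per-block convolution parameters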
# Instantiate neural modules
data_layer = nemo_asr.AudioToTextDataLayer(
    manifest_filepath=train_dataset,
    labels=labels, batch_size=32)
data_layer_val = nemo_asr.AudioToTextDataLayer(
    manifest_filepath=eval_datasets,
    labels=labels, batch_size=32, shuffle=False)
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
spec_augment = nemo_asr.SpectrogramAugmentation(rect_masks=5)
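# rect_masks=5 cuts five random rectangular patches out of each training
# spectrogram, a Cutout/SpecAugment-style regularizer (mask sizes are left
# at the module defaults).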
encoder = nemo_asr.JasperEncoder(feat_in=64, **quartznet_model_definition['JasperEncoder'])
decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024, num_classes=len(labels))
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(labels))
greedy_decoder = nemo_asr.GreedyCTCDecoder()
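# Note: the decoder and CTC loss account for the CTC blank symbol internally,
# so num_classes above is simply len(labels) (NeMo 0.x convention).
# Instantiating a module only allocates its weights; *calling* it like a
# function, as below, wires it into a lazily executed DAG.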
# Training DAG (Model)
audio_signal, audio_signal_len, transcript, transcript_len = data_layer()
processed_signal, processed_signal_len = data_preprocessor(
    input_signal=audio_signal, length=audio_signal_len)
# Data augmentation
aug_signal = spec_augment(input_spec=processed_signal)
encoded, encoded_len = encoder(audio_signal=aug_signal, length=processed_signal_len)
log_probs = decoder(encoder_output=encoded)
predictions = greedy_decoder(log_probs=log_probs)
loss = ctc_loss(
    log_probs=log_probs, targets=transcript,
    input_length=encoded_len, target_length=transcript_len)
# Validation DAG (Model)
# We need to instantiate an additional data layer neural module
# for the validation data
audio_signal_v, audio_signal_len_v, transcript_v, transcript_len_v = data_layer_val()
processed_signal_v, processed_signal_len_v = data_preprocessor(
    input_signal=audio_signal_v, length=audio_signal_len_v)
# Note that we are not using data-augmentation in validation DAG
encoded_v, encoded_len_v = encoder(
    audio_signal=processed_signal_v, length=processed_signal_len_v)
log_probs_v = decoder(encoder_output=encoded_v)
predictions_v = greedy_decoder(log_probs=log_probs_v)
loss_v = ctc_loss(
    log_probs=log_probs_v, targets=transcript_v,
    input_length=encoded_len_v, target_length=transcript_len_v)
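# Because the validation DAG calls the very same preprocessor/encoder/decoder
# instances, it always evaluates the weights currently being trained; only the
# data layer differs, and augmentation is deliberately skipped.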
# These helper functions are needed to compute and print various metrics,
# such as word error rate, and to log them to TensorBoard.
# They are domain-specific and provided by NeMo's collections.
from nemo.collections.asr.helpers import (
    monitor_asr_train_progress,
    process_evaluation_batch,
    process_evaluation_epoch,
)
from functools import partial
# Callback to track loss and print predictions during training
train_callback = nemo.core.SimpleLossLoggerCallback(
    tb_writer=tb_writer,
    # Define the tensors that you want SimpleLossLoggerCallback to
    # operate on.
    # Here we want to print our loss and our word error rate, which
    # is a function of our predictions, transcript, and transcript_len.
    tensors=[loss, predictions, transcript, transcript_len],
    # To print logs to screen, define a print_func.
    print_func=partial(
        monitor_asr_train_progress,
        labels=labels
    ))
saver_callback = nemo.core.CheckpointCallback(
    folder="QuartzNet12x1_vivos_abc/checkpoints/",
    # Set how often we want to save checkpoints
    step_freq=100)
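# If the checkpoint folder is non-empty, NeMo 0.x should resume training from
# the latest checkpoint found there (behavior assumed from NeMo 0.x docs;
# verify on your version).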
# PRO TIP: while you can have only one train DAG, you can have as many
# validation DAGs and callbacks as you want. This is useful for monitoring
# progress on more than one validation dataset at once (say, LibriSpeech
# dev-clean and dev-other); a sketch follows the callback below.
eval_callback = nemo.core.EvaluatorCallback(
    eval_tensors=[loss_v, predictions_v, transcript_v, transcript_len_v],
    # How to process an evaluation batch - e.g. compute WER
    user_iter_callback=partial(
        process_evaluation_batch,
        labels=labels
    ),
    # How to aggregate statistics (e.g. WER) for the evaluation epoch
    user_epochs_done_callback=partial(
        process_evaluation_epoch, tag="valid"
    ),
    eval_step=500,
    tb_writer=tb_writer)
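# Sketch of a second validation monitor (hypothetical manifest path and tag),
# reusing the modules above so no new weights are created:
# data_layer_val2 = nemo_asr.AudioToTextDataLayer(
#     manifest_filepath="data/vivos_dev_other.json",  # hypothetical
#     labels=labels, batch_size=32, shuffle=False)
# # ...wire a DAG through data_preprocessor/encoder/decoder/ctc_loss exactly
# # like the validation DAG above, then:
# eval_callback2 = nemo.core.EvaluatorCallback(
#     eval_tensors=[...],  # the second DAG's loss/predictions/transcripts
#     user_iter_callback=partial(process_evaluation_batch, labels=labels),
#     user_epochs_done_callback=partial(process_evaluation_epoch, tag="valid_other"),
#     eval_step=500, tb_writer=tb_writer)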
# Run training using your Neural Factory
# Once this "action" is called data starts flowing along train and eval DAGs
# and computations start to happen
nf.train(
    # Specify the loss to optimize for
    tensors_to_optimize=[loss],
    # Specify which callbacks you want to run
    callbacks=[train_callback, eval_callback, saver_callback],
    # Specify what optimizer to use
    optimizer="novograd",
    # Specify optimizer parameters such as num_epochs and lr
    optimization_params={
        "num_epochs": 100, "lr": 0.02, "weight_decay": 1e-4
    }
)
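# After training, a held-out WER can be computed offline. A minimal sketch,
# assuming the NeMo 0.x helpers post_process_predictions,
# post_process_transcripts, and word_error_rate (module paths can shift
# between 0.x releases):
# from nemo.collections.asr.helpers import (
#     post_process_predictions, post_process_transcripts, word_error_rate)
# evaluated = nf.infer(
#     tensors=[loss_v, predictions_v, transcript_v, transcript_len_v],
#     checkpoint_dir="QuartzNet12x1_vivos_abc/checkpoints/")
# hypotheses = post_process_predictions(evaluated[1], labels)
# references = post_process_transcripts(evaluated[2], evaluated[3], labels)
# print("Greedy WER:", word_error_rate(hypotheses=hypotheses, references=references))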