ONNX conversion #107

Open · wants to merge 9 commits into base: master
6 changes: 5 additions & 1 deletion .gitignore
@@ -51,4 +51,8 @@ pip-delete-this-directory.txt
_site/
.sass-cache/
.jekyll-cache/
.jekyll-metadata
# FT:
data/
checkpoints/
*.onnx
50 changes: 50 additions & 0 deletions convert-to-jit.py
@@ -0,0 +1,50 @@
import os
import argparse
import torch
from models.forward_tacotron import ForwardTacotron


"""
TorchScript exporter for ⏩ ForwardTacotron
"""


# The convertor: load a checkpoint, script it, smoke-test it and save it.
def run_convertor(model_path, save_path):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    tts_model = ForwardTacotron.from_checkpoint(model_path)
    tts_model.eval()
    # Script the TTS model defined in models/forward_tacotron.py:
    model_script = torch.jit.script(tts_model)
    # Generate a dummy input and smoke-test the scripted model with it:
    x = torch.ones((1, 5)).long()
    y = model_script.generate_jit(x)
    if save_path is None:
        save_path = model_path[:-3] + ".ts"
    # Finally, we export it:
    torch.jit.save(model_script, save_path)
    print("Model successfully converted to TorchScript.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TorchScript convertor for ⏩ForwardTacotron")
    parser.add_argument(
        '--checkpoint_path', '-c',
        required=True,
        type=str,
        help='The full checkpoint (*.pt) file to convert.'
    )
    parser.add_argument(
        '--output_path', '-o',
        default=None,
        type=str,
        help='Output path to save the converted TorchScript model.'
    )
    args = parser.parse_args()
    run_convertor(args.checkpoint_path, args.output_path)
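
Once exported, the .ts file can be loaded without the ForwardTacotron class definition. A minimal usage sketch, where the model path is an assumption and generate_jit mirrors the call used in the smoke test above:

import torch

model = torch.jit.load('checkpoints/forward_tacotron.ts')  # assumed output path
model.eval()
tokens = torch.ones((1, 5)).long()  # dummy token IDs, as in the smoke test
with torch.no_grad():
    out = model.generate_jit(tokens)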
91 changes: 91 additions & 0 deletions convert-to-onnx.py
@@ -0,0 +1,91 @@
import os
import argparse
import torch
from models.forward_tacotron import ForwardTacotron
from utils.text.symbols import phonemes

"""
ONNX convertor for ⏩ ForwardTacotron
Lately, ONNX stuff for TTS models is popular, because these models provides a faster inference than the full PyTorch models. Faster inference is good to use these models, for example, in a screen reader. Also, ONNX models can be used in a multi-platform/system way such as IOS, Android Phone devices, etc.
The onnx compatibility has been fixed by Matthew C. (rmcpantoja).
"""

# ======================global vars======================
OPSET = 17
SEED = 1234
# ======================end global vars======================

# The convertor: load a checkpoint, wrap generate() and export to ONNX.
def run_convertor(model_path, save_path):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    tts_model = ForwardTacotron.from_checkpoint(model_path)
    tts_model.eval()
    # Configure the seeds:
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    # A custom generate() that returns only the mel postnet output and packs
    # the synthesizer options (alpha, pitch, energy) into a single tensor:
    def custom_generate(text, synth_options):
        alpha = synth_options[0]
        pitch = synth_options[1]
        energy = synth_options[2]
        # TODO: try inferring pitch/energy with an ONNX model:
        pitch_function = lambda x: x * pitch
        energy_function = lambda x: x * energy
        infer = tts_model.generate(
            text,
            alpha=alpha,
            pitch_function=pitch_function,
            energy_function=energy_function,
            onnx=True
        )
        mel = infer['mel_post']
        return mel

    # Replace the forward function with the custom one:
    tts_model.forward = custom_generate
    # Set the inputs and outputs for the ONNX model:
    dummy_input_length = 50
    rand = torch.randint(low=0, high=len(phonemes), size=(1, dummy_input_length), dtype=torch.long)
    synth_inputs = torch.FloatTensor(
        [1.0, 1.0, 1.0]  # alpha, pitch, energy
    )
    model_inputs = (rand, synth_inputs)
    input_names = [
        "input",
        "synth_options"
    ]
    if save_path is None:
        save_path = model_path[:-3] + ".onnx"
    # Finally, we export it:
    torch.onnx.export(
        model=tts_model,
        args=model_inputs,
        f=save_path,
        opset_version=OPSET,
        input_names=input_names,
        output_names=['output'],
        dynamic_axes={
            "input": {0: "batch_size", 1: "text"},
            "output": {0: "batch_size", 1: "time"}
        }
    )
    print("Checkpoint successfully converted to ONNX.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ONNX convertor for ⏩ForwardTacotron")
    parser.add_argument(
        '--checkpoint_path', '-c',
        required=True,
        type=str,
        help='The full checkpoint (*.pt) file to convert.'
    )
    parser.add_argument(
        '--output_path', '-o',
        default=None,
        type=str,
        help='Output path to save the converted ONNX model.'
    )
    args = parser.parse_args()
    run_convertor(args.checkpoint_path, args.output_path)
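
The dynamic_axes above declare the batch and text/time dimensions as variable, so the exported model should accept inputs of any length. A minimal sanity-check sketch, where the model path is an assumption and the input names match the export code:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession('checkpoints/forward_tacotron.onnx')  # assumed path
synth_options = np.array([1.0, 1.0, 1.0], dtype=np.float32)  # alpha, pitch, energy
for length in (10, 50, 120):
    tokens = np.zeros((1, length), dtype=np.int64)  # dummy token IDs
    mel = sess.run(None, {'input': tokens, 'synth_options': synth_options})[0]
    print(length, mel.shape)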
90 changes: 90 additions & 0 deletions gen_forward_onnx.py
@@ -0,0 +1,90 @@
import argparse
from pathlib import Path
import numpy as np
import onnxruntime
import torch  # needed for torch.save in the melgan branch below
from utils.display import simple_table
from utils.dsp import DSP
from utils.files import read_config
from utils.paths import Paths
from utils.text.cleaners import Cleaner
from utils.text.tokenizer import Tokenizer


if __name__ == '__main__':

    # Parse arguments
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--input_text', '-i', default=None, type=str, help='[string] Type in something here and TTS will generate it!')
    parser.add_argument('--checkpoint', type=str, default=None, help='[string/path] Path to the .onnx model file.')
    parser.add_argument('--config', metavar='FILE', default='default.yaml', help='The config containing all hyperparams.')
    parser.add_argument('--speaker', type=str, default=None, help='Speaker to generate audio for (multispeaker models only).')

    parser.add_argument('--alpha', type=float, default=1., help='Parameter for controlling the length regulator for speed-up '
                                                                'or slow-down of generated speech, e.g. alpha=2.0 is double-time.')
    parser.add_argument('--amp', type=float, default=1., help='Parameter for controlling pitch amplification.')

    # The name of the subcommand goes to args.vocoder
    subparsers = parser.add_subparsers(dest='vocoder')
    gl_parser = subparsers.add_parser('griffinlim')
    mg_parser = subparsers.add_parser('melgan')
    hg_parser = subparsers.add_parser('hifigan')

    args = parser.parse_args()

    assert args.vocoder in {'griffinlim', 'melgan', 'hifigan'}, \
        'Please provide a valid vocoder! Choices: [griffinlim, melgan, hifigan]'

    checkpoint_path = args.checkpoint
    if checkpoint_path is None:
        config = read_config(args.config)
        paths = Paths(config['data_path'], config['tts_model_id'])
        checkpoint_path = paths.forward_checkpoints / 'latest_model.onnx'
    sess_options = onnxruntime.SessionOptions()
    checkpoint = onnxruntime.InferenceSession(str(checkpoint_path), sess_options=sess_options)
    config = read_config(args.config)
    dsp = DSP.from_config(config)

    voc_model, voc_dsp = None, None
    out_path = Path('model_outputs')
    out_path.mkdir(parents=True, exist_ok=True)
    cleaner = Cleaner.from_config(config)
    tokenizer = Tokenizer()

    if args.input_text:
        texts = [args.input_text]
    else:
        with open('sentences.txt', 'r', encoding='utf-8') as f:
            texts = f.readlines()

    pitch_function = lambda x: x * args.amp
    energy_function = lambda x: x

    for i, x in enumerate(texts, 1):
        print(f'\n| Generating {i}/{len(texts)}')
        text = x
        x = cleaner(x)
        x = tokenizer(x)
        text = np.expand_dims(np.array(x, dtype=np.int64), 0)
        synth_options = np.array(
            [args.alpha, 1.0, 1.0],  # alpha, pitch, energy
            dtype=np.float32,
        )
        speaker_name = args.speaker if args.speaker is not None else 'default_speaker'
        wav_name = f"test{i}"
        m = checkpoint.run(
            None,
            {
                "input": text,
                "synth_options": synth_options,
            },
        )[0]
        # m = (m * 32767).astype(np.int16)
        if args.vocoder == 'melgan':
            torch.save(m, out_path / f'{wav_name}.mel')
        elif args.vocoder == 'hifigan':
            np.save(str(out_path / f'{wav_name}.npy'), m, allow_pickle=False)
        elif args.vocoder == 'griffinlim':
            wav = dsp.griffinlim(m)
            dsp.save_wav(wav, out_path / f'{wav_name}.wav')

    print('\n\nDone.\n')
63 changes: 63 additions & 0 deletions gen_forward_onnx_benchmark.py
@@ -0,0 +1,63 @@
import argparse
from pathlib import Path
import numpy as np
import onnxruntime
from utils.files import read_config
from utils.dsp import DSP
from utils.paths import Paths
from utils.text.cleaners import Cleaner
from utils.text.tokenizer import Tokenizer
import time

if __name__ == '__main__':

    # Parse arguments
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--checkpoint', type=str, default=None, help='[string/path] Path to the .onnx model file.')
    parser.add_argument('--config', metavar='FILE', default='default.yaml', help='The config containing all hyperparams.')
    parser.add_argument('--speaker', type=str, default=None, help='Speaker to generate audio for (multispeaker models only).')

    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    if checkpoint_path is None:
        config = read_config(args.config)
        paths = Paths(config['data_path'], config['tts_model_id'])
        checkpoint_path = paths.forward_checkpoints / 'latest_model.onnx'
    sess_options = onnxruntime.SessionOptions()
    checkpoint = onnxruntime.InferenceSession(str(checkpoint_path), sess_options=sess_options)
    config = read_config(args.config)
    dsp = DSP.from_config(config)
    cleaner = Cleaner.from_config(config)
    tokenizer = Tokenizer()

    with open('sentences.txt', 'r', encoding='utf-8') as f:
        texts = f.readlines()
    for i, x in enumerate(texts, 1):
        print(f'\n| Generating {i}/{len(texts)}')
        text = x
        x = cleaner(x)
        x = tokenizer(x)
        text = np.expand_dims(np.array(x, dtype=np.int64), 0)
        synth_options = np.array(
            [1.0, 1.0, 1.0],  # alpha, pitch, energy
            dtype=np.float32,
        )
        speaker_name = args.speaker if args.speaker is not None else 'default_speaker'
        start_time = time.perf_counter()
        m = checkpoint.run(
            None,
            {
                "input": text,
                "synth_options": synth_options,
            },
        )[0]
        end_time = time.perf_counter()
        # Duration of the generated audio implied by the mel length:
        mel_length = m.shape[-1]
        spec_length = mel_length * dsp.hop_length
        spec_sec = spec_length / dsp.sample_rate
        infer_sec = end_time - start_time
        # Real-time factor: inference time divided by generated audio time.
        rtf = infer_sec / spec_sec
        print(f"Sentence {i} generation time: {infer_sec * 1000:.1f} ms, RTF: {rtf:.3f}.")

    print('\n\nDone.\n')
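
For reference, the real-time factor relates inference time to the duration of the audio that the mel frames represent. A worked example with assumed values (hop_length=256 and sample_rate=22050 are typical LJSpeech-style settings, not read from the config):

mel_frames = 200                                   # frames in the generated mel
hop_length = 256                                   # assumed DSP setting
sample_rate = 22050                                # assumed DSP setting
audio_sec = mel_frames * hop_length / sample_rate  # ≈ 2.32 s of audio
infer_sec = 0.10                                   # measured inference time in seconds
rtf = infer_sec / audio_sec                        # ≈ 0.043, i.e. ~23x faster than real time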
34 changes: 33 additions & 1 deletion models/common_layers.py
@@ -6,8 +6,24 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm, MultiheadAttention
from torch.nn.utils.rnn import pad_sequence  # still needed by LengthRegulator.forward below

class CustomPadSequence(nn.Module):
    """ONNX-exportable replacement for torch.nn.utils.rnn.pad_sequence (batch_first=True)."""

    def __init__(self, padding_value=0.):
        super().__init__()
        self.padding_value = padding_value

    def forward(self, sequences):
        # Find the maximum time length among the sequences
        max_length = max(seq.size(0) for seq in sequences)

        # Pad each sequence along the time dimension (dim 0) with the padding value
        padded_sequences = [
            F.pad(seq, (0, 0, 0, max_length - seq.size(0)), value=self.padding_value)
            for seq in sequences
        ]

        # Stack the padded sequences into a (batch, time, channels) tensor
        return torch.stack(padded_sequences, dim=0)

class LengthRegulator(nn.Module):

@@ -23,6 +39,22 @@ def forward(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor:
        x_expanded = pad_sequence(x_expanded, padding_value=0., batch_first=True)
        return x_expanded

class LengthRegulator_onnx(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor:
        dur[dur < 0] = 0.
        x_expanded = []
        for i in range(x.size(0)):
            x_exp = torch.repeat_interleave(x[i], (dur[i] + 0.5).long(), dim=0)
            x_expanded.append(x_exp)
        # Pad with the ONNX-friendly module instead of pad_sequence:
        custom_pad_sequence = CustomPadSequence(padding_value=0.)
        x_expanded = custom_pad_sequence(x_expanded)
        return x_expanded



class HighwayNetwork(nn.Module):

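
A quick way to validate the padding change above is to compare CustomPadSequence against the pad_sequence call it replaces. A minimal check, assuming (time, channels) inputs as produced by the length regulator:

import torch
from torch.nn.utils.rnn import pad_sequence
from models.common_layers import CustomPadSequence

seqs = [torch.randn(3, 4), torch.randn(5, 4)]  # unequal time lengths, same channels
reference = pad_sequence(seqs, batch_first=True, padding_value=0.)
custom = CustomPadSequence(padding_value=0.)(seqs)
assert torch.equal(reference, custom)  # both should be (2, 5, 4)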