Skip to content

Commit

Permalink
updated PitchDNN export script
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed Sep 29, 2023
1 parent ce28695 commit 0459a57
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions dnn/torch/neural-pitch/export_neuralpitch_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,47 +44,59 @@
import torch
import numpy as np

from models import large_if_ccode
from models import PitchDNN
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector

def c_export(args, model):

message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
enc_writer.header.write(
writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
writer.header.write(
f"""
#include "opus_types.h"
"""
)


# encoder
encoder_dense_layers = [
('initial' , 'initial', 'TANH'),
('upsample' , 'upsample', 'TANH')
layers = [
('if_upsample.0', "dense_if_upsampler_1"),
('if_upsample.2', "dense_if_upsampler_2"),
('conv.1', "conv2d_1"),
('conv.4', "conv2d_2"),
('conv.7', "conv2d_3"),
('downsample.0', "dense_downsampler"),
("upsample.0", "dense_final_upsampler")
]

for name, export_name, _ in encoder_dense_layers:

for name, export_name in layers:
layer = model.get_submodule(name)
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
dump_torch_weights(writer, layer, name=export_name, verbose=True)


encoder_gru_layers = [
('gru' , 'gru', 'TANH'),
gru_layers = [
("GRU", "gru_1"),
]

enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
for name, export_name, _ in encoder_gru_layers])
max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
for name, export_name in gru_layers])

writer.header.write(
f"""
#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
"""
)

del enc_writer
writer.close()


if __name__ == "__main__":

os.makedirs(args.output_dir, exist_ok=True)
model = large_if_ccode()
checkpoint = torch.load(args.checkpoint ,map_location='cpu')
model = PitchDNN()
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
c_export(args, model)

0 comments on commit 0459a57

Please sign in to comment.