
added rudimentary support for dumping nn.Conv2d layers
Jan Buethe committed Sep 29, 2023
1 parent 25c65a0 commit c5c214d
Showing 4 changed files with 58 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dnn/torch/weight-exchange/wexchange/c_export/__init__.py
@@ -28,4 +28,4 @@
*/
"""

from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_vector
from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer, print_vector
24 changes: 24 additions & 0 deletions dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -291,13 +291,37 @@ def print_conv1d_layer(writer : CWriter,
    lin_weight = np.reshape(weight, (-1, weight.shape[-1]))
    print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize)


    writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n")
    writer.header.write(f"\n#define {name.upper()}_IN_SIZE {weight.shape[1]}\n")
    writer.header.write(f"\n#define {name.upper()}_STATE_SIZE ({weight.shape[1]} * ({weight.shape[0] - 1}))\n")
    writer.header.write(f"\n#define {name.upper()}_DELAY {(weight.shape[0] - 1) // 2}\n") # CAVE: delay is not a property of the conv layer

    return weight.shape[0] * weight.shape[1]

def print_conv2d_layer(writer : CWriter,
                       name : str,
                       weight : np.ndarray,
                       bias : np.ndarray,
                       scale : float=1/128,
                       quantize : bool=False):

    if quantize:
        print("[print_conv2d_layer] warning: quantize argument ignored")

    bias_name = name + "_bias"
    float_weight_name = name + "_weight_float"

    print_vector(writer, weight, float_weight_name)
    print_vector(writer, bias, bias_name)

    # init function
    out_channels, in_channels, ksize1, ksize2 = weight.shape
    init_call = f'conv2d_init(&model->{name}, arrays, "{bias_name}", "{float_weight_name}", {in_channels}, {out_channels}, {ksize1}, {ksize2})'

    writer.layer_dict[name] = ('Conv2dLayer', init_call)



def print_gru_layer(writer : CWriter,
name : str,
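A minimal usage sketch for the new print_conv2d_layer routine (the CWriter constructor argument and the layer name below are illustrative assumptions, not part of this commit):

import numpy as np
from wexchange.c_export import CWriter, print_conv2d_layer

# hypothetical writer; the exact CWriter setup depends on the surrounding export script
writer = CWriter("my_model_data")

# print_conv2d_layer expects the weight in (out_channels, in_channels, ksize1, ksize2) layout
weight = np.random.randn(8, 4, 3, 3).astype(np.float32)
bias = np.zeros(8, dtype=np.float32)

print_conv2d_layer(writer, "enc_conv", weight, bias)
# writer.layer_dict["enc_conv"] now holds ('Conv2dLayer',
#   'conv2d_init(&model->enc_conv, arrays, "enc_conv_bias", "enc_conv_weight_float", 4, 8, 3, 3)')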
1 change: 1 addition & 0 deletions dnn/torch/weight-exchange/wexchange/torch/__init__.py
@@ -28,6 +28,7 @@
"""

from .torch import dump_torch_conv1d_weights, load_torch_conv1d_weights
from .torch import dump_torch_conv2d_weights, load_torch_conv2d_weights
from .torch import dump_torch_dense_weights, load_torch_dense_weights
from .torch import dump_torch_gru_weights, load_torch_gru_weights
from .torch import dump_torch_embedding_weights, load_torch_embedding_weights
33 changes: 32 additions & 1 deletion dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -32,7 +32,7 @@
import torch
import numpy as np

from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer
from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer

def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):

@@ -138,6 +138,33 @@ def load_torch_conv1d_weights(where, conv):
            conv.bias.set_(torch.from_numpy(b))


def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
    w = conv.weight.detach().cpu().permute(0, 1, 3, 2).numpy().copy()
    if conv.bias is None:
        b = np.zeros(conv.out_channels, dtype=w.dtype)
    else:
        b = conv.bias.detach().cpu().numpy().copy()

    if isinstance(where, CWriter):
        return print_conv2d_layer(where, name, w, b, scale=scale, quantize=quantize)

    else:
        os.makedirs(where, exist_ok=True)

        np.save(os.path.join(where, 'weight_oiwh.npy'), w)

        np.save(os.path.join(where, 'bias.npy'), b)

def load_torch_conv2d_weights(where, conv):
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oiwh.npy'))
        conv.weight.set_(torch.from_numpy(w).permute(0, 1, 3, 2))
        if conv.bias is not None:
            b = np.load(os.path.join(where, 'bias.npy'))
            conv.bias.set_(torch.from_numpy(b))

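A quick round-trip sketch for the directory-based path (the .npy file names come from the functions above; the output directory and layer sizes are arbitrary):

import numpy as np
import torch

from wexchange.torch import dump_torch_conv2d_weights, load_torch_conv2d_weights

conv = torch.nn.Conv2d(4, 8, kernel_size=(3, 5))
dump_torch_conv2d_weights('/tmp/conv2d_test', conv)

# the weight is stored with the kernel dimensions swapped, i.e. in (out, in, width, height) order
w = np.load('/tmp/conv2d_test/weight_oiwh.npy')
assert w.shape == (8, 4, 5, 3)

# loading permutes back to PyTorch's native (out, in, height, width) layout
conv2 = torch.nn.Conv2d(4, 8, kernel_size=(3, 5))
load_torch_conv2d_weights('/tmp/conv2d_test', conv2)
assert torch.allclose(conv.weight, conv2.weight)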

def dump_torch_embedding_weights(where, emb):
    os.makedirs(where, exist_ok=True)

@@ -162,6 +189,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
        return dump_torch_gru_weights(where, module, name, **kwargs)
    elif isinstance(module, torch.nn.Conv1d):
        return dump_torch_conv1d_weights(where, module, name, **kwargs)
    elif isinstance(module, torch.nn.Conv2d):
        return dump_torch_conv2d_weights(where, module, name, **kwargs)
    elif isinstance(module, torch.nn.Embedding):
        return dump_torch_embedding_weights(where, module)
    else:
@@ -175,6 +204,8 @@ def load_torch_weights(where, module):
        load_torch_gru_weights(where, module)
    elif isinstance(module, torch.nn.Conv1d):
        load_torch_conv1d_weights(where, module)
    elif isinstance(module, torch.nn.Conv2d):
        load_torch_conv2d_weights(where, module)
    elif isinstance(module, torch.nn.Embedding):
        load_torch_embedding_weights(where, module)
    else:
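With the dispatcher updated, nn.Conv2d modules can be exported through the same generic entry point as the other supported layer types. A minimal sketch (the output directory and layer names are arbitrary; dump_torch_weights is imported from the torch.py module shown above):

import torch

from wexchange.torch.torch import dump_torch_weights

layers = {
    'enc_conv1': torch.nn.Conv1d(4, 8, kernel_size=3),
    'enc_conv2': torch.nn.Conv2d(4, 8, kernel_size=(3, 3)),
    'enc_gru': torch.nn.GRU(8, 16),
}

# dump_torch_weights dispatches on the module type, so Conv1d, Conv2d and GRU
# all go through the same call
for name, module in layers.items():
    dump_torch_weights(f'/tmp/toy_model/{name}', module, name=name)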
