Skip to content

Commit

Permalink
feat: cm export
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Nov 21, 2024
1 parent c56fa2d commit acb41b7
Showing 1 changed file with 50 additions and 5 deletions.
55 changes: 50 additions & 5 deletions util/export.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
import torch


from models import gan_networks
from models import gan_networks, diffusion_networks


class ConsistencyWrapper(torch.nn.Module):
"""
Consistency model wrapper for onnx & jit trace
"""

def __init__(self, model, sigmas):
super().__init__()
self.model = model
self.sigmas = sigmas

def forward(self, x, mask):
return self.model.restoration(x, None, self.sigmas, mask)


def export(opt, cuda, model_in_file, model_out_file, opset_version, export_type):
model = gan_networks.define_G(**vars(opt))
if opt.model_type == "cm":
opt.alg_palette_sampling_method = ""
opt.alg_diffusion_cond_embed_dim = 256
model = diffusion_networks.define_G(**vars(opt))
model_input_nc = opt.model_input_nc + opt.model_output_nc
else:
model = gan_networks.define_G(**vars(opt))
model_input_nc = opt.model_input_nc

model.eval()
model.load_state_dict(torch.load(model_in_file))
Expand All @@ -18,21 +39,45 @@ def export(opt, cuda, model_in_file, model_out_file, opset_version, export_type)
device = "cuda"
else:
device = "cpu"

dummy_input = torch.randn(
1, opt.model_input_nc, opt.data_crop_size, opt.data_crop_size, device=device
1,
model_input_nc,
opt.data_crop_size,
opt.data_crop_size,
device=device,
)
# print("dummy input size=", dummy_input.size())
dummy_inputs = [dummy_input]

if opt.model_type == "cm":
# at the moment, consistency models have two inputs: origin image and mask
# TODO allow to change number of sigmas
sigmas = [80.0, 24.4, 5.84, 0.9, 0.661]
model = ConsistencyWrapper(model, sigmas)
dummy_inputs += [
torch.randn(
1,
opt.model_input_nc,
opt.data_crop_size,
opt.data_crop_size,
device=device,
),
]

dummy_inputs = tuple(dummy_inputs)

if export_type == "onnx":
torch.onnx.export(
model,
dummy_input,
dummy_inputs,
model_out_file,
verbose=False,
opset_version=opset_version,
)

elif export_type == "jit":
jit_model = torch.jit.trace(model, dummy_input)
jit_model = torch.jit.trace(model, dummy_inputs)
jit_model.save(model_out_file)

else:
Expand Down

0 comments on commit acb41b7

Please sign in to comment.