diff --git a/steganogan/encoders.py b/steganogan/encoders.py index d1a93c8..275edff 100644 --- a/steganogan/encoders.py +++ b/steganogan/encoders.py @@ -2,6 +2,8 @@ import torch from torch import nn +import torch.onnx +from torchvision.ops.deform_conv import DeformConv2d class BasicEncoder(nn.Module): @@ -16,7 +18,7 @@ class BasicEncoder(nn.Module): add_image = False def _conv2d(self, in_channels, out_channels): - return nn.Conv2d( + return DeformConv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=3,