From 436e37609e31a5b78aee39b5c52e386f359284ea Mon Sep 17 00:00:00 2001 From: soltanianaref <101413285+soltanianaref@users.noreply.github.com> Date: Tue, 5 Apr 2022 21:23:36 +0430 Subject: [PATCH] Update encoders.py --- steganogan/encoders.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/steganogan/encoders.py b/steganogan/encoders.py index d1a93c8..84042fa 100644 --- a/steganogan/encoders.py +++ b/steganogan/encoders.py @@ -2,6 +2,14 @@ import torch from torch import nn +import torch.onnx +from torchvision.ops.deform_conv import DeformConv2d + +input = torch.rand(4, 3, 10, 10) +kh, kw = 3, 3 +weight = torch.rand(5, 3, kh, kw) +offset = torch.rand(4, 2 * kh * kw, 8, 8) +mask = torch.rand(4, kh * kw, 8, 8) class BasicEncoder(nn.Module): @@ -16,11 +24,11 @@ class BasicEncoder(nn.Module): add_image = False def _conv2d(self, in_channels, out_channels): - return nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1 + return DeformConv2d( + input, + offset=offset, + weight=weight, + mask=mask ) def _build_models(self):