From e8243dd1e57c2dffe10c877e9580ea5182364066 Mon Sep 17 00:00:00 2001 From: hwany Date: Tue, 2 Mar 2021 19:16:04 +0900 Subject: [PATCH] [Fix] model architecture --- pointnet/model.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pointnet/model.py b/pointnet/model.py index 48de610c2..2d87200ce 100644 --- a/pointnet/model.py +++ b/pointnet/model.py @@ -89,11 +89,13 @@ def __init__(self, global_feat = True, feature_transform = False): super(PointNetfeat, self).__init__() self.stn = STN3d() self.conv1 = torch.nn.Conv1d(3, 64, 1) - self.conv2 = torch.nn.Conv1d(64, 128, 1) - self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.conv2 = torch.nn.Conv1d(64, 64, 1) + self.conv3 = torch.nn.Conv1d(64, 128, 1) + self.conv4 = torch.nn.Conv1d(128, 1024, 1) self.bn1 = nn.BatchNorm1d(64) - self.bn2 = nn.BatchNorm1d(128) - self.bn3 = nn.BatchNorm1d(1024) + self.bn2 = nn.BatchNorm1d(64) + self.bn3 = nn.BatchNorm1d(128) + self.bn4 = nn.BatchNorm1d(1024) self.global_feat = global_feat self.feature_transform = feature_transform if self.feature_transform: @@ -106,6 +108,7 @@ def forward(self, x): x = torch.bmm(x, trans) x = x.transpose(2, 1) x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) if self.feature_transform: trans_feat = self.fstn(x) @@ -116,8 +119,8 @@ def forward(self, x): trans_feat = None pointfeat = x - x = F.relu(self.bn2(self.conv2(x))) - x = self.bn3(self.conv3(x)) + x = F.relu(self.bn3(self.conv3(x))) + x = self.bn4(self.conv4(x)) x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 1024) if self.global_feat: @@ -211,3 +214,4 @@ def feature_transform_regularizer(trans): seg = PointNetDenseCls(k = 3) out, _, _ = seg(sim_data) print('seg', out.size()) +