diff --git a/ml3d/torch/models/point_transformer.py b/ml3d/torch/models/point_transformer.py
index a28293e19..ebec8e87c 100644
--- a/ml3d/torch/models/point_transformer.py
+++ b/ml3d/torch/models/point_transformer.py
@@ -34,6 +34,7 @@ class PointTransformer(BaseModel):
     """
 
     def __init__(self,
+                 device,
                  name="PointTransformer",
                  blocks=[2, 2, 2, 2, 2],
                  in_channels=6,
@@ -52,6 +53,7 @@ def __init__(self,
                          batcher=batcher,
                          augment=augment,
                          **kwargs)
+        self.device = torch.device(device)
         cfg = self.cfg
         self.in_channels = in_channels
         self.augmenter = SemsegAugmentation(cfg.augment)
@@ -108,12 +110,13 @@ def _make_enc(self,
         """
         layers = []
         layers.append(
-            TransitionDown(self.in_planes, planes * block.expansion, stride,
+            TransitionDown(self.device, self.in_planes, planes * block.expansion, stride,
                            nsample))
         self.in_planes = planes * block.expansion
         for _ in range(1, blocks):
             layers.append(
-                block(self.in_planes,
+                block(self.device,
+                      self.in_planes,
                       self.in_planes,
                       share_planes,
                       nsample=nsample))
@@ -141,12 +144,14 @@ def _make_dec(self,
         """
         layers = []
         layers.append(
-            TransitionUp(self.in_planes,
+            TransitionUp(self.device,
+                         self.in_planes,
                          None if is_head else planes * block.expansion))
         self.in_planes = planes * block.expansion
         for _ in range(1, blocks):
             layers.append(
-                block(self.in_planes,
+                block(self.device,
+                      self.in_planes,
                       self.in_planes,
                       share_planes,
                       nsample=nsample))
@@ -377,7 +382,7 @@ def get_optimizer(self, cfg_pipeline):
 class Transformer(nn.Module):
     """Transformer layer of the model, uses self attention."""
 
-    def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
+    def __init__(self, device, in_planes, out_planes, share_planes=8, nsample=16):
         """Constructor for Transformer Layer.
 
         Args:
@@ -388,6 +393,7 @@ def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
         """
         super().__init__()
+        self.device = device
         self.mid_planes = mid_planes = out_planes // 1
         self.out_planes = out_planes
         self.share_planes = share_planes
         self.nsample = nsample
@@ -427,7 +433,8 @@ def forward(self, pxo):
         point, feat, row_splits = pxo  # (n, 3), (n, c), (b)
         feat_q, feat_k, feat_v = self.linear_q(feat), self.linear_k(
             feat), self.linear_v(feat)  # (n, c)
-        feat_k = queryandgroup(self.nsample,
+        feat_k = queryandgroup(self.device,
+                               self.nsample,
                                point,
                                point,
                                feat_k,
@@ -435,7 +442,8 @@ def forward(self, pxo):
                                row_splits,
                                row_splits,
                                use_xyz=True)  # (n, nsample, 3+c)
-        feat_v = queryandgroup(self.nsample,
+        feat_v = queryandgroup(self.device,
+                               self.nsample,
                                point,
                                point,
                                feat_v,
@@ -473,7 +481,7 @@ class TransitionDown(nn.Module):
     Subsamples points and increase receptive field.
     """
 
-    def __init__(self, in_planes, out_planes, stride=1, nsample=16):
+    def __init__(self, device, in_planes, out_planes, stride=1, nsample=16):
         """Constructor for TransitionDown Layer.
 
         Args:
@@ -484,6 +492,7 @@ def __init__(self, in_planes, out_planes, stride=1, nsample=16):
         """
         super().__init__()
+        self.device = device
         self.stride, self.nsample = stride, nsample
         if stride != 1:
             self.linear = nn.Linear(3 + in_planes, out_planes, bias=False)
             self.pool = nn.MaxPool1d(nsample)
@@ -504,7 +513,13 @@ def forward(self, pxo):
             List of point, feat, row_splits.
         """
+        point, feat, row_splits = pxo  # (n, 3), (n, c), (b+1)
+        # move inputs to the model's device (no copy if already there)
+        point = torch.as_tensor(point, device=self.device)
+        feat = torch.as_tensor(feat, device=self.device)
+        row_splits = torch.as_tensor(row_splits, device=self.device)
+
         if self.stride != 1:
             new_row_splits = [0]
             count = 0
             for i in range(1, row_splits.shape[0]):
                 count += (row_splits[i].item() -
@@ -513,12 +528,12 @@ def forward(self, pxo):
                     row_splits[i - 1].item()) // self.stride
                 new_row_splits.append(count)
 
-            new_row_splits = torch.LongTensor(new_row_splits).to(
-                row_splits.device)
+            new_row_splits = torch.LongTensor(new_row_splits).to(self.device)
             idx = furthest_point_sample_v2(point, row_splits,
                                            new_row_splits)  # (m)
             new_point = point[idx.long(), :]  # (m, 3)
-            feat = queryandgroup(self.nsample,
+            feat = queryandgroup(self.device,
+                                 self.nsample,
                                  point,
                                  new_point,
                                  feat,
@@ -542,7 +557,7 @@ class TransitionUp(nn.Module):
     Interpolate points based on corresponding encoder layer.
     """
 
-    def __init__(self, in_planes, out_planes=None):
+    def __init__(self, device, in_planes, out_planes=None):
         """Constructor for TransitionUp Layer.
 
         Args:
@@ -551,6 +566,7 @@ def __init__(self, in_planes, out_planes=None):
         """
         super().__init__()
+        self.device = device
         if out_planes is None:
             self.linear1 = nn.Sequential(nn.Linear(2 * in_planes, in_planes),
                                          nn.BatchNorm1d(in_planes),
                                          nn.ReLU(inplace=True))
@@ -595,7 +611,7 @@ def forward(self, pxo1, pxo2=None):
             point_1, feat_1, row_splits_1 = pxo1
             point_2, feat_2, row_splits_2 = pxo2
             feat = self.linear1(feat_1) + interpolation(
-                point_2, point_1, self.linear2(feat_2), row_splits_2,
+                self.device, point_2, point_1, self.linear2(feat_2), row_splits_2,
                 row_splits_1)
         return feat
 
@@ -607,7 +623,7 @@ class Bottleneck(nn.Module):
     """
     expansion = 1
 
-    def __init__(self, in_planes, planes, share_planes=8, nsample=16):
+    def __init__(self, device, in_planes, planes, share_planes=8, nsample=16):
         """Constructor for Bottleneck Layer.
 
         Args:
@@ -620,7 +636,7 @@ def __init__(self, in_planes, planes, share_planes=8, nsample=16):
         super(Bottleneck, self).__init__()
         self.linear1 = nn.Linear(in_planes, planes, bias=False)
         self.bn1 = nn.BatchNorm1d(planes)
-        self.transformer2 = Transformer(planes, planes, share_planes, nsample)
+        self.transformer2 = Transformer(device, planes, planes, share_planes, nsample)
         self.bn2 = nn.BatchNorm1d(planes)
         self.linear3 = nn.Linear(planes, planes * self.expansion, bias=False)
         self.bn3 = nn.BatchNorm1d(planes * self.expansion)
@@ -647,7 +663,8 @@ def forward(self, pxo):
         return [point, feat, row_splits]
 
 
-def queryandgroup(nsample,
+def queryandgroup(device,
+                  nsample,
                   points,
                   queries,
                   feat,
@@ -677,7 +694,8 @@ def queryandgroup(nsample,
     if queries is None:
         queries = points
     if idx is None:
-        idx = knn_batch(points,
+        idx = knn_batch(device,
+                        points,
                         queries,
                         k=nsample,
                         points_row_splits=points_row_splits,
@@ -697,7 +715,8 @@ def queryandgroup(nsample,
     return grouped_feat
 
 
-def knn_batch(points,
+def knn_batch(device,
+              points,
               queries,
               k,
               points_row_splits,
@@ -729,12 +748,13 @@ def knn_batch(points,
                           return_distances=True)
     if return_distances:
         return ans.neighbors_index.reshape(
-            -1, k).long().cuda(), ans.neighbors_distance.reshape(-1, k).cuda()
+            -1, k).long().to(device), ans.neighbors_distance.reshape(-1, k).to(device)
     else:
-        return ans.neighbors_index.reshape(-1, k).long().cuda()
+        return ans.neighbors_index.reshape(-1, k).long().to(device)
 
 
-def interpolation(points,
+def interpolation(device,
+                  points,
                   queries,
                   feat,
                   points_row_splits,
@@ -756,7 +776,8 @@ def interpolation(points,
     if not (points.is_contiguous and queries.is_contiguous() and
             feat.is_contiguous()):
         raise ValueError("Interpolation (points/queries/feat not contiguous)")
-    idx, dist = knn_batch(points,
+    idx, dist = knn_batch(device,
+                          points,
                           queries,
                           k=k,
                           points_row_splits=points_row_splits,
@@ -770,7 +791,7 @@ def interpolation(points,
     weight = dist_recip / norm  # (n, k)
 
     new_feat = torch.FloatTensor(queries.shape[0],
-                                 feat.shape[1]).zero_().to(feat.device)
+                                 feat.shape[1]).zero_().to(device)
     for i in range(k):
         new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)
     return new_feat
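Note on point_transformer.py: `device` is now the first constructor argument (with no default) and is threaded through every layer, while `knn_batch` replaces the hard-coded `.cuda()` transfers with `.to(device)`, which is what makes CPU-only runs possible. A minimal sketch of the resulting calling convention; the `num_classes` value and the final `.to()` call are illustrative assumptions, not part of this diff:

    import torch
    from ml3d.torch.models.point_transformer import PointTransformer

    # Pick whatever backend is available; any torch.device spec works.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # `device` has no default, so every call site (including configs that
    # build the model from kwargs) must now supply it explicitly.
    model = PointTransformer(device, in_channels=6, num_classes=13)
    model.to(model.device)  # module parameters still need the usual move

Since the constructor stores `torch.device(device)`, both strings ("cpu", "cuda:0") and `torch.device` objects are accepted.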
diff --git a/ml3d/torch/models/sparseconvnet.py b/ml3d/torch/models/sparseconvnet.py
index 86c555f33..5c71d5c76 100644
--- a/ml3d/torch/models/sparseconvnet.py
+++ b/ml3d/torch/models/sparseconvnet.py
@@ -54,17 +54,18 @@ def __init__(
                          augment=augment,
                          **kwargs)
         cfg = self.cfg
-        self.device = device
+        self.device = torch.device(device)
         self.augmenter = SemsegAugmentation(cfg.augment, seed=self.rng)
         self.multiplier = cfg.multiplier
         self.input_layer = InputLayer()
-        self.sub_sparse_conv = SubmanifoldSparseConv(in_channels=in_channels,
+        self.sub_sparse_conv = SubmanifoldSparseConv(device=self.device,
+                                                     in_channels=in_channels,
                                                      filters=multiplier,
                                                      kernel_size=[3, 3, 3])
-        self.unet = UNet(conv_block_reps, [
+        self.unet = UNet(self.device, conv_block_reps, [
             multiplier, 2 * multiplier, 3 * multiplier, 4 * multiplier,
             5 * multiplier, 6 * multiplier, 7 * multiplier
         ], residual_blocks)
         self.batch_norm = BatchNormBlock(multiplier)
         self.relu = ReLUBlock()
         self.linear = LinearBlock(multiplier, num_classes)
@@ -158,14 +159,14 @@ def transform(self, data, attr):
 
         return data
 
-    def update_probs(self, inputs, results, test_probs, test_labels):
+    def update_probs(self, inputs, results, test_probs):
         result = results.reshape(-1, self.cfg.num_classes)
         probs = torch.nn.functional.softmax(result,
                                             dim=-1).cpu().data.numpy()
         labels = np.argmax(probs, 1)
         self.trans_point_sampler(patchwise=False)
 
-        return probs, labels
+        return probs
 
     def inference_begin(self, data):
         data = self.preprocess(data, {'split': 'test'})
@@ -344,6 +345,7 @@ def forward(self, features_list, index_map_list):
 class SubmanifoldSparseConv(nn.Module):
 
     def __init__(self,
+                 device,
                  in_channels,
                  filters,
                  kernel_size,
@@ -359,6 +361,7 @@ def __init__(self,
             offset = 0.5
 
         offset = torch.full((3,), offset, dtype=torch.float32)
+        self.device = device
         self.net = SparseConv(in_channels=in_channels,
                               filters=filters,
                               kernel_size=kernel_size,
@@ -377,7 +380,10 @@ def forward(self,
         out_feat = []
         for feat, in_pos, out_pos in zip(features_list, in_positions_list,
                                          out_positions_list):
-            out_feat.append(self.net(feat, in_pos, out_pos, voxel_size))
+            out_feat.append(self.net(feat.to(self.device),
+                                     in_pos.to(self.device),
+                                     out_pos.to(self.device),
+                                     voxel_size))
 
         return out_feat
 
@@ -404,6 +410,7 @@ def calculate_grid(in_positions):
 class Convolution(nn.Module):
 
     def __init__(self,
+                 device,
                  in_channels,
                  filters,
                  kernel_size,
@@ -411,6 +418,7 @@ def __init__(self,
                  offset=None,
                  normalize=False):
         super(Convolution, self).__init__()
+        self.device = device
 
         if offset is None:
             if kernel_size[0] % 2:
@@ -434,7 +442,9 @@ def forward(self, features_list, in_positions_list, voxel_size=1.0):
         out_feat = []
         for feat, in_pos, out_pos in zip(features_list, in_positions_list,
                                          out_positions_list):
-            out_feat.append(self.net(feat, in_pos, out_pos, voxel_size))
+            out_feat.append(self.net(feat.to(self.device),
+                                     in_pos.to(self.device),
+                                     out_pos.to(self.device), voxel_size))
 
         out_positions_list = [out / 2 for out in out_positions_list]
 
@@ -447,6 +457,7 @@ def __name__(self):
 class DeConvolution(nn.Module):
 
     def __init__(self,
+                 device,
                  in_channels,
                  filters,
                  kernel_size,
@@ -454,6 +465,7 @@ def __init__(self,
                  offset=None,
                  normalize=False):
         super(DeConvolution, self).__init__()
+        self.device = device
 
         if offset is None:
             if kernel_size[0] % 2:
@@ -477,7 +489,10 @@ def forward(self,
         out_feat = []
         for feat, in_pos, out_pos in zip(features_list, in_positions_list,
                                          out_positions_list):
-            out_feat.append(self.net(feat, in_pos, out_pos, voxel_size))
+            out_feat.append(self.net(feat.to(self.device),
+                                     in_pos.to(self.device),
+                                     out_pos.to(self.device),
+                                     voxel_size))
 
         return out_feat
 
@@ -532,20 +547,22 @@ def forward(self, inputs):
 
 class ResidualBlock(nn.Module):
 
-    def __init__(self, nIn, nOut):
+    def __init__(self, device, nIn, nOut):
         super(ResidualBlock, self).__init__()
-
+        self.device = device
         self.lin = NetworkInNetwork(nIn, nOut)
 
         self.batch_norm1 = BatchNormBlock(nIn)
         self.relu1 = ReLUBlock()
-        self.sub_sparse_conv1 = SubmanifoldSparseConv(in_channels=nIn,
+        self.sub_sparse_conv1 = SubmanifoldSparseConv(device=device,
+                                                      in_channels=nIn,
                                                       filters=nOut,
                                                       kernel_size=[3, 3, 3])
 
         self.batch_norm2 = BatchNormBlock(nOut)
         self.relu2 = ReLUBlock()
-        self.sub_sparse_conv2 = SubmanifoldSparseConv(in_channels=nOut,
+        self.sub_sparse_conv2 = SubmanifoldSparseConv(device=device,
+                                                      in_channels=nOut,
                                                       filters=nOut,
                                                       kernel_size=[3, 3, 3])
 
@@ -567,20 +584,22 @@ def __name__(self):
 
 class UNet(nn.Module):
 
     def __init__(self,
+                 device,
                  conv_block_reps,
                  nPlanes,
                  residual_blocks=False,
                  downsample=[2, 2],
                  leakiness=0):
         super(UNet, self).__init__()
+        self.device = device
         self.net = nn.ModuleList(
-            self.get_UNet(nPlanes, residual_blocks, conv_block_reps))
+            self.get_UNet(device, nPlanes, residual_blocks, conv_block_reps))
         self.residual_blocks = residual_blocks
 
     @staticmethod
-    def block(layers, a, b, residual_blocks):
+    def block(device, layers, a, b, residual_blocks):
         if residual_blocks:
-            layers.append(ResidualBlock(a, b))
+            layers.append(ResidualBlock(device, a, b))
         else:
             layers.append(BatchNormBlock(a))
@@ -591,32 +610,34 @@ def block(layers, a, b, residual_blocks):
                                   kernel_size=[3, 3, 3]))
 
     @staticmethod
-    def get_UNet(nPlanes, residual_blocks, conv_block_reps):
+    def get_UNet(device, nPlanes, residual_blocks, conv_block_reps):
         layers = []
         for i in range(conv_block_reps):
-            UNet.block(layers, nPlanes[0], nPlanes[0], residual_blocks)
+            UNet.block(device, layers, nPlanes[0], nPlanes[0], residual_blocks)
 
         if len(nPlanes) > 1:
             layers.append(ConcatFeat())
             layers.append(BatchNormBlock(nPlanes[0]))
             layers.append(ReLUBlock())
             layers.append(
-                Convolution(in_channels=nPlanes[0],
+                Convolution(device=device,
+                            in_channels=nPlanes[0],
                             filters=nPlanes[1],
                             kernel_size=[2, 2, 2]))
-            layers = layers + UNet.get_UNet(nPlanes[1:], residual_blocks,
+            layers = layers + UNet.get_UNet(device, nPlanes[1:], residual_blocks,
                                             conv_block_reps)
             layers.append(BatchNormBlock(nPlanes[1]))
             layers.append(ReLUBlock())
             layers.append(
-                DeConvolution(in_channels=nPlanes[1],
+                DeConvolution(device=device,
+                              in_channels=nPlanes[1],
                               filters=nPlanes[0],
                               kernel_size=[2, 2, 2]))
             layers.append(JoinFeat())
 
             for i in range(conv_block_reps):
-                UNet.block(layers, nPlanes[0] * (2 if i == 0 else 1),
+                UNet.block(device, layers, nPlanes[0] * (2 if i == 0 else 1),
                            nPlanes[0], residual_blocks)
 
         return layers
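The sparseconvnet.py changes follow the same pattern: each block keeps a `device` handle and moves its inputs inside `forward()` via `.to(self.device)`, so features prepared on the host end up on the right backend. A rough smoke test of the threading; the class name matches this file, but the constructor defaults it relies on are assumed:

    import torch
    from ml3d.torch.models.sparseconvnet import SparseConvUnet

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SparseConvUnet(device=device).to(device)
    assert model.device == device
    assert model.unet.device == device  # handed down through UNet.get_UNet

Threading the device explicitly at construction time is a deliberate trade-off: deriving it from `next(self.parameters()).device` at forward time would avoid the extra arguments, but these blocks receive raw lists of tensors, and some of them own no parameters to consult.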