diff --git a/pnet2_layers/layers.py b/pnet2_layers/layers.py index 00c0c61..6e214f3 100644 --- a/pnet2_layers/layers.py +++ b/pnet2_layers/layers.py @@ -33,10 +33,6 @@ def build(self, input_shape): def call(self, xyz, points, training=True): - if points is not None: - if len(points.shape) < 3: - points = tf.expand_dims(points, axis=0) - if self.group_all: nsample = xyz.get_shape()[1] new_xyz, new_points, idx, grouped_xyz = utils.sample_and_group_all(xyz, points, self.use_xyz) @@ -56,7 +52,7 @@ def call(self, xyz, points, training=True): new_points = tf.math.reduce_max(new_points, axis=2, keepdims=True) - return new_xyz, tf.squeeze(new_points) + return new_xyz, tf.squeeze(new_points, [2]) class Pointnet_SA_MSG(Layer): @@ -89,10 +85,6 @@ def build(self, input_shape): def call(self, xyz, points, training=True): - if points is not None: - if len(points.shape) < 3: - points = tf.expand_dims(points, axis=0) - new_xyz = utils.gather_point(xyz, utils.farthest_point_sample(self.npoint, xyz)) new_points_list = [] @@ -145,13 +137,6 @@ def build(self, input_shape): def call(self, xyz1, xyz2, points1, points2, training=True): - if points1 is not None: - if len(points1.shape) < 3: - points1 = tf.expand_dims(points1, axis=0) - if points2 is not None: - if len(points2.shape) < 3: - points2 = tf.expand_dims(points2, axis=0) - dist, idx = utils.three_nn(xyz1, xyz2) dist = tf.maximum(dist, 1e-10) norm = tf.reduce_sum((1.0/dist),axis=2, keepdims=True) @@ -168,8 +153,6 @@ def call(self, xyz1, xyz2, points1, points2, training=True): for i, mlp_layer in enumerate(self.mlp_list): new_points1 = mlp_layer(new_points1, training=training) - new_points1 = tf.squeeze(new_points1) - if len(new_points1.shape) < 3: - new_points1 = tf.expand_dims(new_points1, axis=0) + new_points1 = tf.squeeze(new_points1, [2]) return new_points1