Trying to implement Autoencoder in Pytorch #15

Open
dhirajsuvarna opened this issue May 20, 2020 · 18 comments

Comments

@dhirajsuvarna

Hi,

I am trying to implement an autoencoder in PyTorch, and I wrote a model that I believe is exactly what is present in this repo.

Model in pytorch

class PCAutoEncoder(nn.Module):
    def __init__(self, point_dim, num_points):
        super(PCAutoEncoder, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=point_dim, out_channels=64, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv4 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1)
        self.conv5 = nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1)
        self.fc1 = nn.Linear(in_features=1024, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=1024)
        self.fc3 = nn.Linear(in_features=1024, out_features=num_points*3)
        #batch norm
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
    
    def forward(self, x):
        batch_size = x.shape[0]
        point_dim = x.shape[1]
        num_points = x.shape[2]
        #encoder
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn1(self.conv2(x)))
        x = F.relu(self.bn1(self.conv3(x)))
        x = F.relu(self.bn2(self.conv4(x)))
        x = F.relu(self.bn3(self.conv5(x)))
        # do max pooling 
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        # get the global embedding
        global_feat = x
        #decoder
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn3(self.fc2(x)))
        reconstructed_points = self.fc3(x)
        #do reshaping
        reconstructed_points = reconstructed_points.reshape(batch_size, point_dim, num_points)
        return reconstructed_points, global_feat

However, after training this model for 200 epochs, when I try to generate the output point cloud, all I get is scattered points, as shown below:
[image: autoencoder_wrong_output]

Any direction to figure out the problem would be helpful.

@dhirajsuvarna
Author

+1 @charlesq34

@skyir0n

skyir0n commented Jun 8, 2020

It helps to train this autoencoder on only one category of the dataset (for example, from ShapeNet).

@dhirajsuvarna
Author

@skyir0n - the above result is from training on only one category of the dataset.

@skyir0n

skyir0n commented Jun 8, 2020

In your code, the same BN module is reused for several layers. Because each BN module has its own learnable parameters (gamma and beta, for scaling and shifting) and its own running statistics, a separate BN module should be defined for each layer.
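As a rough illustration of this advice, the sketch below defines one `nn.BatchNorm1d` per convolution, even where the channel counts match. The class name `TinyEncoder` and the layer widths are assumptions made for the example; this is not the repository's model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Illustrative encoder (hypothetical name): one BatchNorm1d per conv layer."""

    def __init__(self, point_dim=3, widths=(64, 64, 128)):
        super().__init__()
        channels = (point_dim, *widths)
        # Pointwise convolutions (kernel_size=1) act as a shared per-point MLP.
        self.convs = nn.ModuleList(
            [nn.Conv1d(c_in, c_out, kernel_size=1)
             for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        # A separate BN module for every layer, so each layer keeps its own
        # weight, bias, and running statistics.
        self.bns = nn.ModuleList([nn.BatchNorm1d(c) for c in widths])

    def forward(self, x):  # x: (batch_size, point_dim, num_points)
        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(bn(conv(x)))
        # Max over the point axis gives a global feature: (batch_size, widths[-1]).
        return torch.max(x, dim=2)[0]


encoder = TinyEncoder()
features = encoder(torch.randn(4, 3, 256))
```

The first two layers both have 64 channels, which is exactly the case where reusing a single module is tempting; here they still get distinct modules.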

@siddharthKatageri

@dhirajsuvarna Hi, did you solve the problem? What was going wrong?

@dvirginz

@dhirajsuvarna
Author

@dvirginz , @siddharthKatageri,
I do have some idea of what's going wrong, but I haven't had the time to try it out. I will check this again and update you within a week (fingers crossed).

@siddharthKatageri

@dhirajsuvarna @dvirginz
Here's a simple implementation of a point cloud autoencoder in PyTorch:

import torch.nn as nn
import torch.nn.functional as F


# Class name is assumed; the original snippet omitted the class header.
class PointCloudAE(nn.Module):
  def __init__(self, num_points):
    super().__init__()
    # Encoder: shared per-point MLP implemented as pointwise convolutions.
    self.conv1 = nn.Conv1d(3, 64, 1)
    self.conv2 = nn.Conv1d(64, 128, 1)
    self.conv3 = nn.Conv1d(128, 256, 1)
    self.conv4 = nn.Conv1d(256, 1024, 1)

    # One BatchNorm1d per encoder layer.
    self.bn1 = nn.BatchNorm1d(64)
    self.bn2 = nn.BatchNorm1d(128)
    self.bn3 = nn.BatchNorm1d(256)
    self.bn4 = nn.BatchNorm1d(1024)

    # Decoder: fully connected layers mapping the embedding back to points.
    self.fc1 = nn.Linear(1024, 1024)
    self.fc2 = nn.Linear(1024, 1024)
    self.fc3 = nn.Linear(1024, num_points*3)

    # One BatchNorm1d per hidden decoder layer.
    self.bn5 = nn.BatchNorm1d(1024)
    self.bn6 = nn.BatchNorm1d(1024)

  def forward(self, input):
    # input: (batch_size, 3, num_points)
    batchsize, dim, npoints = input.shape
    xb = F.relu(self.bn1(self.conv1(input)))
    xb = F.relu(self.bn2(self.conv2(xb)))
    xb = F.relu(self.bn3(self.conv3(xb)))
    xb = self.bn4(self.conv4(xb))
    # Max-pool over the point axis to get one global feature per cloud.
    xb = nn.MaxPool1d(xb.size(-1))(xb)

    # Flatten to (batch_size, 1024); can also be written as xb.view(-1, 1024).
    embedding = nn.Flatten(1)(xb)

    xb = F.relu(self.bn5(self.fc1(embedding)))
    xb = F.relu(self.bn6(self.fc2(xb)))
    output = self.fc3(xb)
    output = output.view(batchsize, dim, npoints)
    return output, embedding

I used the Chamfer distance loss provided by pytorch3d.
I tried it on the bed class of the ModelNet dataset, and the model is able to reconstruct the point cloud from the embedding computed by the encoder.
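For readers unfamiliar with the loss mentioned here, the following is a minimal plain-PyTorch sketch of a symmetric Chamfer distance. It is not the pytorch3d implementation; the function name `chamfer_distance_sketch`, the use of squared distances, and the mean reduction are assumptions for illustration, and libraries differ on those conventions.

```python
import torch


def chamfer_distance_sketch(a, b):
    """Symmetric Chamfer distance between two batches of point sets.

    a: (batch_size, n_points_a, 3), b: (batch_size, n_points_b, 3).
    """
    # Pairwise squared distances: (batch_size, n_points_a, n_points_b).
    diff = a.unsqueeze(2) - b.unsqueeze(1)
    d = (diff ** 2).sum(dim=-1)
    # Nearest neighbour in b for each point of a, and in a for each point of b.
    a_to_b = d.min(dim=2)[0].mean(dim=1)
    b_to_a = d.min(dim=1)[0].mean(dim=1)
    # Average the two directions' sum over the batch.
    return (a_to_b + b_to_a).mean()


cloud = torch.rand(2, 128, 3)
same = chamfer_distance_sketch(cloud, cloud)
shifted = chamfer_distance_sketch(cloud, cloud + 0.5)
```

Note that the sketch expects points along the second axis, (batch_size, num_points, 3). The broadcasting approach uses memory proportional to n_points_a * n_points_b, so it is only meant for small clouds.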

@saltoricristiano

Hi @siddharthKatageri,
thanks for sharing your autoencoder! I'm trying to run your model on the ModelNet dataset with the pytorch3d Chamfer loss, but I'm having trouble getting reasonable results.
ORIGINAL
[image: original]

RECONSTRUCTED
[image: recon]

I'm using all of a smaller version of ModelNet with 10 classes. Did you use only one of them? Do you have any additional hints?
If you can help, thank you very much in advance!

@Xin546946

I have the same issue as well. Are there any tricks for training the PointNet AE? I don't think a PointNet AE is the best choice for reconstructing point clouds, since the MLP decoder cannot enforce permutation invariance, but it should still produce something better than a clustered cloud.
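To make the permutation point concrete, here is a small sketch (an illustration, not code from the repository) showing that a shared per-point layer followed by a max over the point axis gives the same embedding regardless of point order. The fully connected decoder has no such property and emits points in a fixed order, which is why order-independent losses such as Chamfer or EMD are used. The layer size and the helper name `encode` are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A shared per-point layer (kernel_size=1) applied to every point independently.
pointwise = nn.Conv1d(in_channels=3, out_channels=16, kernel_size=1)


def encode(points):  # points: (batch_size, 3, num_points)
    # Max over the point axis is a symmetric function of the points.
    return torch.max(pointwise(points), dim=2)[0]  # (batch_size, 16)


points = torch.randn(2, 3, 100)
shuffled = points[:, :, torch.randperm(100)]  # same clouds, different point order

with torch.no_grad():
    same_embedding = torch.allclose(encode(points), encode(shuffled), atol=1e-6)
```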

@L1nn97

L1nn97 commented May 19, 2021

I have the same issue as you. I tried using EMD loss, and the reconstructed point cloud always looks like a strange cube.

@saltoricristiano

Hi all, are you training the auto encoder on all the classes? I was able to obtain good reconstructions only with per-class trainings...

@L1nn97

L1nn97 commented May 19, 2021

Thanks! I'm wondering how many points you sampled for each batch during training?

@piseabhijeet

piseabhijeet commented Jun 17, 2021

Hi @saltoricristiano

Thank you for reporting this. But an AE should scale to multiple classes too, shouldn't it?

@piseabhijeet

Hi @dhirajsuvarna

Did you get it to work finally?

@piseabhijeet

Also @saltoricristiano

check this out:
#1 (comment)

I hope you are doing the same

@CUN-bjy

CUN-bjy commented Nov 10, 2021

I tested with @dhirajsuvarna's code (https://github.com/dhirajsuvarna/pointnet-autoencoder-pytorch), and I have the same problem.

In my case, the shape of the tensor inputs was wrong, specifically for the Chamfer loss module (https://github.com/dhirajsuvarna/pointnet-autoencoder-pytorch/blob/3bb4a90a8bc016c1d3ab3ab7433f039fb3759196/train_shapenet.py#L64-L83).

For the Chamfer loss, the tensor shape has to be (batch_size, num_points, num_dim), but in the above code it was (batch_size, num_dim, num_points).

            points = data
            # The model expects (batch_size, num_dim, num_points).
            points = points.transpose(2, 1)
            points = points.to(device)

            optimizer.zero_grad()
            reconstructed_points, latent_vector = autoencoder(points)

            # The Chamfer loss expects (batch_size, num_points, num_dim), so transpose back.
            points = points.transpose(1, 2)
            reconstructed_points = reconstructed_points.transpose(1, 2)
            dist1, dist2 = chamfer_dist(points, reconstructed_points)
            train_loss = (torch.mean(dist1)) + (torch.mean(dist2))

In addition, only 3 batch norm layers were defined, while batch norm is applied at 7 layers:
https://github.com/dhirajsuvarna/pointnet-autoencoder-pytorch/blob/3bb4a90a8bc016c1d3ab3ab7433f039fb3759196/model/model.py#L31-L33

With only 3 batch norm layers, the autoencoder works well in train mode but not in eval mode; see pytorch/pytorch#5406.

After changing this, it works well.
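One plausible way to see the train/eval gap described here is the toy sketch below, in which a single `nn.BatchNorm1d` is reused for two inputs with very different statistics. The tensors `layer_a` and `layer_b` are made-up stand-ins for the activations of two layers, not outputs of the actual model. In train mode each call normalizes with its own batch statistics, while in eval mode both calls share one set of running statistics that matches neither input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

shared_bn = nn.BatchNorm1d(4)

# Stand-ins for two layers' activations that reuse the same BN module.
layer_a = torch.randn(32, 4) + 10.0  # mean near 10
layer_b = torch.randn(32, 4) - 10.0  # mean near -10

shared_bn.train()
for _ in range(200):
    shared_bn(layer_a)
    shared_bn(layer_b)

# Train mode: normalized with this batch's own mean and variance.
train_mean_a = shared_bn(layer_a).mean().item()

# Eval mode: normalized with the single running mean, a blend of both inputs.
shared_bn.eval()
eval_mean_a = shared_bn(layer_a).mean().item()

calls_tracked = shared_bn.num_batches_tracked.item()
```

Here `train_mean_a` is close to zero, while `eval_mean_a` is far from zero because the running mean sits between the two inputs. With a separate module per layer, each set of running statistics tracks only its own layer.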

@PoopBear1

Thanks for sharing the tensor shape that the Chamfer distance requires.
Do you happen to know what is going on after changing the code to (batch_size, num_points, num_dim)?
My predicted output changes from sparse points to a weird 3D cube (as in output 2).

I appreciate your help in advance.
