Failure in implementation using PyTorch (sorry to open an issue here...) #18

Open
JingyunLiang opened this issue Dec 30, 2017 · 1 comment

JingyunLiang commented Dec 30, 2017

When I try to implement it in PyTorch, the accuracy rises to 35% in the first stage, which pretrains only the FC layer (around epoch 10). In the second stage, however, the accuracy drops to 20%. The key operations are the outer product, average pooling, signed square root, and L2 normalization. The code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Bilinear CNN built on VGG16
# input: [batch, channel, height, width]; the reshape below assumes 448x448 images
class VggBasedNet_bilinear(nn.Module):
    def __init__(self, originalModel):
        super(VggBasedNet_bilinear, self).__init__()
        # feature extraction up to Conv5_3 + ReLU (drops the final max-pool)
        self.features = nn.Sequential(*list(originalModel.features)[:-1])

        # args.numClasses comes from the surrounding training script
        self.classifier = nn.Linear(512 * 512, args.numClasses)

    def forward(self, x):
        # Conv5_3 + ReLU features, flattened to [batch, 512, 28*28]
        x = self.features(x).view(-1, 512, 784)

        # outer product of the features at each position, averaged over the 784 positions
        x = torch.matmul(x, x.permute(0, 2, 1)).view(-1, 512 * 512) / 784.0

        # signed square root
        x = torch.mul(torch.sign(x), torch.sqrt(torch.abs(x) + 1e-12))

        # L2 normalization
        x = F.normalize(x, p=2, dim=1)

        # final FC layer
        x = self.classifier(x)

        return x

I am confident that the rest of the code is correct, because I only changed the network structure in an existing VGG16 fine-tuning script.
Does anyone here know PyTorch? Is there a problem with the code above? Does each step do what its comment says?
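
One way to check whether the pooling steps do what their comments say is to compare them with a slow, explicit version on a small random tensor. The sketch below is only an illustration under assumptions: the function names `bilinear_pool` and `bilinear_pool_reference` are hypothetical, the tensor is tiny (4 channels, 9 positions) instead of 512 x 784, and it tests only the forward computation, not training behaviour or gradients.

```python
import torch
import torch.nn.functional as F


def bilinear_pool(feats, eps=1e-12):
    """feats: [batch, channels, positions] -> [batch, channels * channels]."""
    batch, channels, positions = feats.shape
    # Batched matrix product = sum of per-position outer products;
    # dividing by the number of positions turns the sum into an average.
    pooled = torch.bmm(feats, feats.transpose(1, 2)) / positions
    pooled = pooled.reshape(batch, channels * channels)
    # Signed square root, with eps inside the sqrt as in the post.
    pooled = torch.sign(pooled) * torch.sqrt(torch.abs(pooled) + eps)
    # L2-normalize each row.
    return F.normalize(pooled, p=2, dim=1)


def bilinear_pool_reference(feats, eps=1e-12):
    """Same computation written as an explicit loop over positions."""
    batch, channels, positions = feats.shape
    acc = torch.zeros(batch, channels, channels, dtype=feats.dtype)
    for p in range(positions):
        v = feats[:, :, p]                      # [batch, channels]
        acc += v.unsqueeze(2) * v.unsqueeze(1)  # outer product per sample
    pooled = (acc / positions).reshape(batch, channels * channels)
    pooled = torch.sign(pooled) * torch.sqrt(torch.abs(pooled) + eps)
    return F.normalize(pooled, p=2, dim=1)


torch.manual_seed(0)
feats = torch.randn(2, 4, 9, dtype=torch.float64)  # toy stand-in for conv features
fast = bilinear_pool(feats)
slow = bilinear_pool_reference(feats)
print(torch.allclose(fast, slow))
```

If the two outputs agree, the matmul-based forward pass matches the "outer product, average, signed sqrt, L2 normalize" description, and the accuracy drop would have to come from somewhere else, for example the fine-tuning hyperparameters or the gradients.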

@theFool32

I ran into a similar result.
Have you found a solution?
Thanks.
