-
Notifications
You must be signed in to change notification settings - Fork 0
/
encoder.py
29 lines (23 loc) · 814 Bytes
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torchvision
from torch import nn as nn
class Encoder(nn.Module):
    """
    CNN image encoder.

    An ImageNet-pretrained ResNet-50 whose weights are frozen, with its
    classification head replaced by a trainable linear projection. The
    projection is followed by a BatchNorm1d layer.
    """

    def __init__(self, linear_layer_size: int = 1000) -> None:
        """
        CNN for the images.

        :param linear_layer_size: size of the output feature vector. This is
            the width of the replacement ``fc`` layer and of ``batch_norm``.
        """
        super().__init__()
        self.mod = self._load_backbone()
        # Freeze every pretrained weight. Only the new head (created below,
        # after this loop) and batch_norm keep requires_grad=True.
        for param in self.mod.parameters():
            param.requires_grad = False
        # NOTE(review): requires_grad=False does not stop the backbone's own
        # BatchNorm layers from updating their running statistics while this
        # module is in train mode. Keep self.mod in eval() if that is unwanted.
        # Read the head's input width instead of hard-coding 2048, so the
        # projection always matches the backbone's feature size.
        self.mod.fc = nn.Linear(self.mod.fc.in_features, linear_layer_size)
        self.batch_norm = nn.BatchNorm1d(linear_layer_size, momentum=0.01)

    @staticmethod
    def _load_backbone():
        """
        Return an ImageNet-pretrained ResNet-50.

        Uses the ``weights=`` API when the installed torchvision provides it,
        avoiding the deprecated ``pretrained=`` keyword.
        """
        weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return torchvision.models.resnet50(pretrained=True)
        # IMAGENET1K_V1 is the checkpoint that the legacy pretrained=True
        # selects, so the loaded weights are unchanged.
        return torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, images):
        """
        Forward pass into the encoder.

        :param images: batch of images fed to the ResNet-50 backbone.
            Presumably shaped (batch, 3, H, W) and normalised as the pretrained
            weights expect. TODO confirm against the caller.
        :return: feature vector of shape (batch, linear_layer_size).
        """
        features = self.mod(images)
        # Flatten everything after the batch dimension so BatchNorm1d sees
        # (batch, features). This is a no-op for the Linear head's 2-D output.
        # In training mode BatchNorm1d needs more than one sample per batch.
        return self.batch_norm(features.reshape(features.size(0), -1))