Commit 4ed17ef

initial release
1 parent 3214a0b commit 4ed17ef

30 files changed: 4634 additions, 2 deletions

Experiments_FashionMNIST.ipynb

Lines changed: 1888 additions & 0 deletions
Large diffs are not rendered by default.

Experiments_MNIST.ipynb

Lines changed: 1803 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 192 additions & 2 deletions
# Siamese and triplet learning with online pair/triplet mining

PyTorch implementation of siamese and triplet networks for learning embeddings.
Siamese and triplet networks are useful for learning mappings from images to a compact Euclidean space where distances correspond to a measure of similarity [2]. Embeddings trained in such a way can be used as feature vectors for classification or few-shot learning tasks.
# Installation

Requires [PyTorch](http://pytorch.org/) 0.3.1 with torchvision 0.2.0.
# Code structure

- **datasets.py**
  - *SiameseMNIST* class - wrapper for an MNIST-like dataset, returning random positive and negative pairs
  - *TripletMNIST* class - wrapper for an MNIST-like dataset, returning random triplets (anchor, positive and negative)
  - *BalancedBatchSampler* class - BatchSampler for a data loader, randomly chooses *n_classes* classes and samples *n_samples* examples from each of them in an MNIST-like dataset
- **networks.py**
  - *EmbeddingNet* - base network for encoding images into an embedding vector
  - *ClassificationNet* - wrapper for an embedding network, adds a fully connected layer and log softmax for classification
  - *SiameseNet* - wrapper for an embedding network, processes pairs of inputs
  - *TripletNet* - wrapper for an embedding network, processes triplets of inputs
- **losses.py**
  - *ContrastiveLoss* - contrastive loss for pairs of embeddings and a pair target (same/different)
  - *TripletLoss* - triplet loss for triplets of embeddings
  - *OnlineContrastiveLoss* - contrastive loss for a mini-batch of embeddings. Uses a *PairSelector* object to find positive and negative pairs within a mini-batch using ground-truth class labels and computes the contrastive loss for these pairs
  - *OnlineTripletLoss* - triplet loss for a mini-batch of embeddings. Uses a *TripletSelector* object to find triplets within a mini-batch using ground-truth class labels and computes the triplet loss
- **trainer.py**
  - *fit* - unified function for training a network with a different number of inputs and different types of loss functions (see the usage sketch below)
- **metrics.py**
  - Sample metrics that can be used with the *fit* function from *trainer.py*
- **utils.py**
  - *PairSelector* - abstract class defining objects that generate pairs based on embeddings and ground-truth class labels. Can be used with *OnlineContrastiveLoss*.
  - *AllPositivePairSelector*, *HardNegativePairSelector* - PairSelector implementations
  - *TripletSelector* - abstract class defining objects that generate triplets based on embeddings and ground-truth class labels. Can be used with *OnlineTripletLoss*.
  - *AllTripletSelector*, *HardestNegativeTripletSelector*, *RandomNegativeTripletSelector*, *SemihardNegativeTripletSelector* - TripletSelector implementations
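The pieces above compose roughly like this; a minimal sketch, assuming a `fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval)` signature and the constructor arguments shown, so check the source files for the exact interfaces:

```python
# Hypothetical end-to-end wiring for pair-based training; the fit() signature
# and constructor arguments are assumptions, not a documented API.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from datasets import SiameseMNIST
from networks import EmbeddingNet, SiameseNet
from losses import ContrastiveLoss
from trainer import fit

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = datasets.MNIST('data/', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST('data/', train=False, download=True, transform=transform)

# The wrappers turn plain MNIST into a pair dataset
train_loader = DataLoader(SiameseMNIST(mnist_train), batch_size=128, shuffle=True)
test_loader = DataLoader(SiameseMNIST(mnist_test), batch_size=128, shuffle=False)

cuda = torch.cuda.is_available()
model = SiameseNet(EmbeddingNet())
if cuda:
    model.cuda()

loss_fn = ContrastiveLoss(margin=1.)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

fit(train_loader, test_loader, model, loss_fn, optimizer, scheduler,
    n_epochs=20, cuda=cuda, log_interval=100)
```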
# Examples

We'll train embeddings on the MNIST dataset. Experiments were run in a [Jupyter notebook](Experiments_MNIST.ipynb).

We'll go through learning supervised feature embeddings using different loss functions on the MNIST dataset. This is just for visualization purposes, so we'll be using 2-dimensional embeddings, which isn't the best choice in practice.

For every experiment the same embedding network is used (32 conv 5x5 -> PReLU -> MaxPool 2x2 -> 64 conv 5x5 -> PReLU -> MaxPool 2x2 -> Dense 256 -> PReLU -> Dense 256 -> PReLU -> Dense 2) and we don't perform any hyperparameter search.
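A minimal sketch of that architecture in PyTorch (assuming 28x28 inputs and no convolution padding, so the flattened feature map is 64\*4\*4; the actual implementation lives in *networks.py* and may differ in details):

```python
import torch.nn as nn

class EmbeddingNet(nn.Module):
    """Sketch of the embedding network described above.
    Assumes 1x28x28 inputs and unpadded convolutions: 28 -> 24 -> 12 -> 8 -> 4."""
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(1, 32, 5), nn.PReLU(), nn.MaxPool2d(2, stride=2),
            nn.Conv2d(32, 64, 5), nn.PReLU(), nn.MaxPool2d(2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 256), nn.PReLU(),
            nn.Linear(256, 256), nn.PReLU(),
            nn.Linear(256, 2),  # 2-dimensional embedding for visualization
        )

    def forward(self, x):
        x = self.convnet(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*4*4)
        return self.fc(x)

    def get_embedding(self, x):
        return self.forward(x)
```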
## Baseline - classification with softmax

We add a fully connected layer with the number of classes and train the network for classification with softmax and cross-entropy. The network trains to ~99% accuracy. We extract 2-dimensional embeddings from the penultimate layer.
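A sketch of such a classification wrapper (it assumes the 2-dimensional embedding network above; the actual class is *ClassificationNet* in *networks.py*, whose details may differ):

```python
import torch.nn as nn
import torch.nn.functional as F

class ClassificationNet(nn.Module):
    """Sketch: wraps an embedding network with a classification head, so the
    2-D embedding can still be read out from the penultimate layer."""
    def __init__(self, embedding_net, n_classes=10):
        super(ClassificationNet, self).__init__()
        self.embedding_net = embedding_net
        self.nonlinear = nn.PReLU()
        self.fc = nn.Linear(2, n_classes)  # assumes a 2-D embedding, as above

    def forward(self, x):
        output = self.nonlinear(self.embedding_net(x))
        return F.log_softmax(self.fc(output), dim=-1)

    def get_embedding(self, x):
        # embeddings used for the visualizations below
        return self.nonlinear(self.embedding_net(x))
```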
Train set:

![](images/mnist_softmax_train.png)

Test set:

![](images/mnist_softmax_test.png)

While the embeddings look separable (which is what we trained them for), they don't have good metric properties. They might not be the best choice as a descriptor for new classes.
## Siamese network

Now we'll train a siamese network that takes a pair of images and trains the embeddings so that the distance between them is minimized if they're from the same class and is greater than some margin value if they represent different classes.
We'll minimize a contrastive loss function [1]:

$$L_{contrastive}(x_0, x_1, y) = \frac{1}{2} y \lVert f(x_0)-f(x_1)\rVert_2^2 + \frac{1}{2}(1-y)\left\{\max\left(0, m-\lVert f(x_0)-f(x_1)\rVert_2\right)\right\}^2$$

The *SiameseMNIST* class samples random positive and negative pairs that are then fed to the siamese network.
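A sketch of this loss as a PyTorch module (the small epsilon before the square root is an assumption for numerical stability at zero distance):

```python
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Sketch of the contrastive loss above: target = 1 for a positive pair,
    0 for a negative pair, with margin m on the negative term."""
    def __init__(self, margin=1.):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, target):
        distances = (output2 - output1).pow(2).sum(1)  # squared L2 distances
        losses = 0.5 * (target.float() * distances +
                        (1 - target.float()) *
                        F.relu(self.margin - (distances + 1e-9).sqrt()).pow(2))
        return losses.mean()
```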
After 20 epochs of training, here are the embeddings we get for the training set:

![](images/mnist_siamese_train.png)

Test set:

![](images/mnist_siamese_test.png)

The learned embeddings are clustered much better within each class.
## Triplet network

We'll train a triplet network that takes an anchor, a positive example (of the same class as the anchor) and a negative example (of a different class than the anchor). The objective is to learn embeddings such that the anchor is closer to the positive example than it is to the negative example by some margin value.

![alt text](images/anchor_negative_positive.png "Source: FaceNet")
Source: *Schroff, Florian, Dmitry Kalenichenko, and James Philbin. [Facenet: A unified embedding for face recognition and clustering.](https://arxiv.org/abs/1503.03832) CVPR 2015.*

**Triplet loss**: $L_{triplet}(x_a, x_p, x_n) = \max\left(0, m + \lVert f(x_a)-f(x_p)\rVert_2^2 - \lVert f(x_a)-f(x_n)\rVert_2^2\right)$

The *TripletMNIST* class samples a positive and a negative example for every possible anchor.
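A sketch of the triplet loss as a PyTorch module:

```python
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    """Sketch of the triplet loss above, with the hinge at zero."""
    def __init__(self, margin=1.):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()
```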
After 20 epochs of training, here are the embeddings we get for the training set:

![](images/mnist_triplet_train.png)

Test set:

![](images/mnist_triplet_test.png)

The learned embeddings are not as tightly clustered within each class as in the case of the siamese network, but that's not what we optimized them for. We wanted the embeddings to be closer to other embeddings from the same class than to embeddings from other classes, and we can see that's the direction the training is taking.
## Online pair/triplet selection - negative mining

There are a couple of problems with siamese and triplet networks:
1. The **number of possible pairs/triplets** grows **quadratically/cubically** with the number of examples. It's infeasible to process them all, and the training converges slowly.
2. We generate pairs/triplets *randomly*. As training continues, more and more pairs/triplets become **easy** to deal with (their loss value is very small or even 0), *preventing the network from training*. We need to provide the network with **hard examples**.
3. Each image fed to the network is used to compute the contrastive/triplet loss for only one pair/triplet. The computation is somewhat wasted; once an embedding is computed, it could be reused for many pairs/triplets.

To deal with these issues efficiently, we'll feed the network with standard mini-batches, as we did for classification. The loss function will be responsible for selecting hard pairs and triplets within each mini-batch. If we feed the network 16 images from each of 10 classes, we can process up to $160*159/2 = 12720$ pairs and $10*16*15/2*(9*16) = 172800$ triplets, compared to 80 pairs and 53 triplets in the previous implementation.

Usually it's not the best idea to process all possible pairs or triplets within a mini-batch. Strategies for selecting triplets can be found in [2] and [3].
### Online pair selection

We'll feed the network with mini-batches, as we did for the classification network. This time we'll use a special BatchSampler that samples *n_classes* classes and *n_samples* examples within each class, resulting in mini-batches of size *n_classes\*n_samples*.

For each mini-batch, positive and negative pairs are selected using the provided labels.
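A sketch of how this fits together (class names follow the code structure section above; the constructor arguments, such as the margin and the pair selector, are assumptions):

```python
# Hypothetical setup for online pair mining; constructor arguments are assumptions.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from datasets import BalancedBatchSampler
from losses import OnlineContrastiveLoss
from networks import EmbeddingNet
from utils import HardNegativePairSelector

mnist_train = datasets.MNIST('data/', train=True, download=True,
                             transform=transforms.ToTensor())

# 10 classes x 25 samples per class = mini-batches of 250 labeled images
batch_sampler = BalancedBatchSampler(mnist_train, n_classes=10, n_samples=25)
train_loader = DataLoader(mnist_train, batch_sampler=batch_sampler)

model = EmbeddingNet()
# The loss mines hard pairs inside each batch from the embeddings and labels
loss_fn = OnlineContrastiveLoss(1., HardNegativePairSelector())
```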
MNIST is a rather easy dataset, and the embeddings from randomly selected pairs were already quite good, so we don't see much improvement here.

**Train embeddings:**

![](images/mnist_ocl_train.png)

**Test embeddings:**

![](images/mnist_ocl_test.png)
### Online triplet selection

We'll feed the network with mini-batches, just like with online pair selection. There are a couple of strategies we can use for triplet selection given labels and predicted embeddings (a sketch of one of them follows the list):

- All possible triplets (might be too many)
- Hardest negative for each positive pair (will result in the same negative for each anchor)
- Random hard negative for each positive pair (consider only triplets with a positive triplet loss value)
- Semi-hard negative for each positive pair (similar to [2])
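For illustration, here is a sketch of the random-hard-negative rule for a single anchor-positive pair (the helper name and the form of `loss_values` are illustrative, not this library's API):

```python
import numpy as np

def random_hard_negative(loss_values):
    """Illustrative helper: among the candidate negatives for one
    anchor-positive pair, pick a random one whose triplet loss is positive."""
    hard_negatives = np.where(loss_values > 0)[0]
    return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None

# Example: loss_values[j] = d(a, p) - d(a, n_j) + margin for each candidate negative j
loss_values = np.array([-0.3, 0.1, 0.0, 0.7])
print(random_hard_negative(loss_values))  # index of a violating negative (1 or 3)
```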
The strategy for triplet selection must be chosen carefully. A bad strategy might lead to inefficient training or, even worse, to model collapse (all embeddings ending up with the same values).

Here's what we got with random hard negatives for each positive pair.

**Training set:**

![](images/mnist_otl_train.png)

**Test set:**

![](images/mnist_otl_test.png)
# FashionMNIST

Similar experiments were conducted for the FashionMNIST dataset, where the advantages of online negative mining are more visible. The exact same network architecture with only 2-dimensional embeddings was used, which is probably not complex enough for learning good embeddings.

## Baseline - classification

![](images/fmnist_softmax_test.png)

## Siamese vs online contrastive loss with negative mining

Siamese network with randomly selected pairs:

![](images/fmnist_siamese_test.png)

Online contrastive loss with negative mining:

![](images/fmnist_ocl_test.png)

## Triplet vs online triplet loss with negative mining

Triplet network with random triplets:

![](images/fmnist_triplet_test.png)

Online triplet loss with negative mining:

![](images/fmnist_otl_test.png)
# TODO

- [ ] Optimize triplet selection
- [ ] Evaluate with a metric that is comparable between approaches
- [ ] Evaluate in a one-shot setting where classes from the test set are not in the train set
- [ ] Show an online triplet selection example on more difficult datasets
# References

[1] Raia Hadsell, Sumit Chopra, Yann LeCun, [Dimensionality reduction by learning an invariant mapping](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf), CVPR 2006

[2] Florian Schroff, Dmitry Kalenichenko, James Philbin, [FaceNet: A unified embedding for face recognition and clustering](https://arxiv.org/abs/1503.03832), CVPR 2015

[3] Alexander Hermans, Lucas Beyer, Bastian Leibe, [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737), 2017

[4] Brandon Amos, Bartosz Ludwiczuk, Mahadev Satyanarayanan, [OpenFace: A general-purpose face recognition library with mobile applications](http://reports-archive.adm.cs.cmu.edu/anon/2016/CMU-CS-16-118.pdf), 2016

[5] Yi Sun, Xiaogang Wang, Xiaoou Tang, [Deep Learning Face Representation by Joint Identification-Verification](http://papers.nips.cc/paper/5416-deep-learning-face-representation-by-joint-identification-verification), NIPS 2014

datasets.py

Lines changed: 188 additions & 0 deletions
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler


class SiameseMNIST(Dataset):
    """
    Train: For each sample, randomly creates a positive or a negative pair
    Test: Creates fixed pairs for testing
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset

        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}
        else:
            # generate fixed pairs for testing
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            # seeded RandomState so the test pairs are the same on every run
            random_state = np.random.RandomState(29)

            # even indices become positive pairs (target 1)
            positive_pairs = [[i,
                               random_state.choice(self.label_to_indices[self.test_labels[i]]),
                               1]
                              for i in range(0, len(self.test_data), 2)]

            # odd indices become negative pairs (target 0); the negative label is
            # also drawn from random_state so the fixed pairs stay reproducible
            negative_pairs = [[i,
                               random_state.choice(self.label_to_indices[
                                                       random_state.choice(
                                                           list(self.labels_set - set([self.test_labels[i]]))
                                                       )
                                                   ]),
                               0]
                              for i in range(1, len(self.test_data), 2)]
            self.test_pairs = positive_pairs + negative_pairs

    def __getitem__(self, index):
        if self.train:
            # flip a coin: 1 -> positive pair, 0 -> negative pair
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index]
            if target == 1:
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        return len(self.mnist_dataset)


class TripletMNIST(Dataset):
    """
    Train: For each sample (anchor), randomly chooses a positive and a negative sample
    Test: Creates fixed triplets for testing
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}

        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            # generate fixed triplets for testing
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            # seeded RandomState so the test triplets are the same on every run
            random_state = np.random.RandomState(29)

            # the negative label is also drawn from random_state for reproducibility
            triplets = [[i,
                         random_state.choice(self.label_to_indices[self.test_labels[i]]),
                         random_state.choice(self.label_to_indices[
                                                 random_state.choice(
                                                     list(self.labels_set - set([self.test_labels[i]]))
                                                 )
                                             ])
                         ]
                        for i in range(len(self.test_data))]
            self.test_triplets = triplets

    def __getitem__(self, index):
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index]
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            img1 = self.test_data[self.test_triplets[index][0]]
            img2 = self.test_data[self.test_triplets[index][1]]
            img3 = self.test_data[self.test_triplets[index][2]]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        return len(self.mnist_dataset)


class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from an MNIST-like dataset, samples n_classes classes and, within these classes, n_samples examples.
    Returns batches of size n_classes * n_samples
    """

    def __init__(self, dataset, n_classes, n_samples):
        if dataset.train:
            self.labels = dataset.train_labels
        else:
            self.labels = dataset.test_labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            # pick n_classes labels, then take the next n_samples unused indices of each
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:self.used_label_indices_count[class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # reshuffle a class once another draw would run past its last index
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size
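A quick usage sketch for the classes above (paths and transform values are illustrative; note that the wrappers read the older `train_data`/`train_labels` attributes, matching the torchvision version pinned in the README):

```python
# Sketch: wrapping torchvision's MNIST with SiameseMNIST; normalization
# constants are the usual MNIST values, the data path is illustrative.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = datasets.MNIST('data/', train=True, download=True, transform=transform)

siamese_train = SiameseMNIST(mnist_train)  # yields ((img1, img2), same_class)
(img1, img2), target = siamese_train[0]
print(img1.shape, img2.shape, target)      # torch.Size([1, 28, 28]) twice, target 0 or 1

loader = DataLoader(siamese_train, batch_size=128, shuffle=True)
```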

images/anchor_negative_positive.png (25.9 KB)

images/fmnist_ocl_test.png (140 KB)

images/fmnist_ocl_train.png (169 KB)

images/fmnist_otl_test.png (134 KB)

images/fmnist_otl_train.png (131 KB)

images/fmnist_siamese_test.png (106 KB)
