Triplet loss implementation
Triplet loss is well suited to classification problems where the number of classes far exceeds the number of samples per class (N_CLASSES >> N_SAMPLES_PER_CLASS). Face recognition is a typical example.
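For reference, the common (FaceNet-style) form of the triplet loss pulls an anchor a toward a positive p of the same class and pushes it away from a negative n of a different class by at least a margin alpha: L = max(||f(a) - f(p)||^2 - ||f(a) - f(n)||^2 + alpha, 0). The NumPy sketch below computes this over a batch of embeddings. The function name `triplet_loss` and the default margin of 0.2 are assumptions, not necessarily what train.py uses.

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Mean triplet loss over a batch of embeddings.

    anchor, positive, negative: arrays of shape (batch, dim).
    Uses squared Euclidean distances, as in the FaceNet formulation.
    The default margin of 0.2 is an assumption for illustration.
    """
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)  # ||a - p||^2 per triplet
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)  # ||a - n||^2 per triplet
    # Hinge: the loss is zero once the negative is at least `margin` farther than the positive
    losses = np.maximum(pos_dist - neg_dist + margin, 0.0)
    return float(losses.mean())
```

PyTorch also ships a built-in, `torch.nn.TripletMarginLoss`, which by default uses plain (non-squared) Euclidean distances and a margin of 1.0.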
The CNN used with triplet loss must be cut off before its classification layer, so that it outputs an embedding instead of class scores. In addition, an L2 normalization layer must be added so that every embedding has unit norm.
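A minimal NumPy sketch of the normalization step, applied to hypothetical backbone features that stand in for the output of the truncated network. Once embeddings have unit norm, the squared distance between two of them equals 2 - 2 cos(theta), so every distance falls in [0, 4] and a fixed margin has a consistent meaning. In PyTorch the same step is commonly written as `torch.nn.functional.normalize(x, p=2, dim=1)`. How this repository implements it is not stated here.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale each row of x (shape (batch, dim)) to unit Euclidean norm."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)  # eps guards against division by zero


# Hypothetical backbone features (stand-ins for the layer before the classifier)
features = np.array([[3.0, 4.0], [0.0, 2.0]])
emb = l2_normalize(features)

# For unit vectors, squared distance depends only on the cosine similarity:
# ||u - v||^2 = 2 - 2 * (u . v)
u, v = emb
sq_dist = np.sum((u - v) ** 2)
identity_rhs = 2 - 2 * np.dot(u, v)
```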
python train.py
python extract_embeddings.py
python model_on_top.py
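Training with triplet loss needs (anchor, positive, negative) triplets. The sketch below is a plain random sampler written for illustration only. The function `sample_triplets` is hypothetical, and train.py may instead mine hard or semi-hard triplets within each batch.

```python
import numpy as np


def sample_triplets(labels, n_triplets, rng=None):
    """Randomly draw (anchor, positive, negative) index triplets from class labels.

    Anchor and positive share a label; the negative has a different label.
    Classes with a single sample can only serve as negatives.
    Hypothetical helper, not taken from this repository.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    anchor_classes = classes[counts >= 2]  # a positive needs a second sample of the class
    if len(anchor_classes) == 0 or len(classes) < 2:
        raise ValueError("need a class with >= 2 samples and at least 2 classes")

    triplets = []
    for _ in range(n_triplets):
        c = rng.choice(anchor_classes)
        same = np.flatnonzero(labels == c)   # indices sharing the anchor's class
        other = np.flatnonzero(labels != c)  # indices of every other class
        a, p = rng.choice(same, size=2, replace=False)
        n = rng.choice(other)
        triplets.append((int(a), int(p), int(n)))
    return triplets
```

Random sampling is the simplest option, but many random triplets already satisfy the margin and contribute zero loss, which is why mining strategies are common in practice.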
Set --data-path to the location of your own dataset.
The default backbone model is resnet18.
Download the dataset here
The demo uses the Dogs vs. Cats dataset. Results:
KNN (k-nearest neighbors): 0.9825
SVM: 0.985
Linear SVM: 0.985
RF (random forest): 0.9825
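The scores above come from conventional classifiers fitted on the learned embeddings (the model_on_top.py step). As a sketch of the idea behind the KNN entry, here is a minimal k-nearest-neighbors classifier in NumPy. The function name `knn_predict` and k=3 are assumptions, and the repository may rely on a library implementation (for example scikit-learn's `KNeighborsClassifier`) rather than code like this.

```python
import numpy as np


def knn_predict(train_emb, train_labels, query_emb, k=3):
    """Predict labels for query embeddings by majority vote of the k nearest training embeddings."""
    train_labels = np.asarray(train_labels)
    # Pairwise squared Euclidean distances, shape (n_query, n_train)
    d = np.sum((query_emb[:, None, :] - train_emb[None, :, :]) ** 2, axis=2)
    nearest = np.argsort(d, axis=1)[:, :k]  # indices of the k closest training points
    preds = []
    for idx in nearest:
        values, counts = np.unique(train_labels[idx], return_counts=True)
        preds.append(values[np.argmax(counts)])  # ties resolve to the smallest label
    return np.array(preds)
```

Because the embeddings are L2-normalized, ranking neighbors by Euclidean distance gives the same order as ranking them by cosine similarity.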
Thanks to https://github.com/alfonmedela/triplet-loss-pytorch/tree/master and https://github.com/chencodeX/triplet-loss-pytorch for the inspiration.
All code in this repository is released under the MIT License.