Triplet-loss-pytorch

PyTorch implementation of triplet loss.

Introduction

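Triplet loss is a good choice for classification problems where the number of classes greatly exceeds the number of samples per class (N_CLASSES >> N_SAMPLES_PER_CLASS), such as face recognition. Instead of learning one output per class, the network learns an embedding in which samples of the same class lie closer to each other than to samples of other classes.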
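For reference, a common form of the loss for one triplet (anchor a, positive p of the same class, negative n of a different class) with embedding function f and margin α is L = max(‖f(a) − f(p)‖ − ‖f(a) − f(n)‖ + α, 0). Below is a minimal batched PyTorch sketch using plain (non-squared) Euclidean distances, as torch.nn.TripletMarginLoss does by default. The helper name `triplet_loss` and the margin of 0.2 are illustrative assumptions, not necessarily what train.py uses.

```python
import torch
import torch.nn.functional as F


def triplet_loss(anchor: torch.Tensor,
                 positive: torch.Tensor,
                 negative: torch.Tensor,
                 margin: float = 0.2) -> torch.Tensor:
    """Mean triplet loss over a batch of (anchor, positive, negative) embeddings.

    Each input has shape (batch, dim). Hypothetical helper for illustration only.
    """
    # Euclidean distance between each anchor and its positive / negative
    d_pos = (anchor - positive).norm(p=2, dim=1)
    d_neg = (anchor - negative).norm(p=2, dim=1)
    # Hinge: a triplet contributes zero once the negative is at least
    # `margin` farther from the anchor than the positive is
    return F.relu(d_pos - d_neg + margin).mean()


if __name__ == "__main__":
    anchors = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
    positives = torch.tensor([[3.0, 4.0], [0.0, 1.0]])  # distances 5 and 1
    negatives = torch.tensor([[0.0, 1.0], [3.0, 4.0]])  # distances 1 and 5
    loss = triplet_loss(anchors, positives, negatives, margin=0.2)
    print(f"{loss.item():.4f}")  # 2.1000: (4.2 + 0) / 2
```

The first triplet violates the margin (its positive is farther than its negative) and is penalized; the second already satisfies it and contributes nothing.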

The CNN used with triplet loss must be cut off before its classification layer, and an L2 normalization layer must be added on top of the remaining features.

Usage

python train.py
python extract_embeddings.py
python model_on_top.py

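Run the scripts in this order: train the embedding network, extract the embeddings, then fit models on top of them.
Change --data-path to the path of your own dataset.
The default model is resnet18.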
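The extraction step can be sketched as follows, assuming a DataLoader that yields (images, labels) batches and a model that returns one embedding per image. The function name `extract_embeddings` and its signature are hypothetical; extract_embeddings.py may work differently.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients are needed when only computing embeddings
def extract_embeddings(model: nn.Module, loader: DataLoader, device: str = "cpu"):
    """Run `model` over `loader` and return (embeddings, labels) as NumPy arrays.

    Hypothetical helper: assumes the loader yields (images, labels) batches.
    """
    # eval() freezes batch-norm statistics and disables dropout
    model.eval().to(device)
    all_embeddings, all_labels = [], []
    for images, labels in loader:
        embeddings = model(images.to(device))
        all_embeddings.append(embeddings.cpu().numpy())
        all_labels.append(labels.numpy())
    # Stack per-batch arrays into (N, dim) embeddings and (N,) labels
    return np.concatenate(all_embeddings), np.concatenate(all_labels)
```

The returned arrays can be saved (for example with `np.save`) and passed to any classifier in the last step.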

Data

Download the dataset here.
I use the Dogs vs. Cats dataset for the demo.

Results
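
Scores of k-nearest neighbors (KNN), SVM, linear SVM and random forest (RF) classifiers fitted on the extracted embeddings (dogs vs cats demo):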
KNN: 0.9825
SVM: 0.985
Linear SVM: 0.985
RF: 0.9825
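A sketch of the "model on top" step with scikit-learn, assuming the embeddings and labels are already NumPy arrays. The helper `evaluate_classifiers`, the hyperparameters and the synthetic `make_blobs` data are illustrative assumptions; this does not reproduce the scores above and does not necessarily mirror model_on_top.py.

```python
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import normalize
from sklearn.svm import SVC, LinearSVC


def evaluate_classifiers(X_train, y_train, X_test, y_test):
    """Fit each classifier on the embeddings and return its test accuracy."""
    classifiers = {
        "KNN": KNeighborsClassifier(n_neighbors=5),
        "SVM": SVC(kernel="rbf"),
        "Linear SVM": LinearSVC(),
        "RF": RandomForestClassifier(n_estimators=100, random_state=0),
    }
    # .score() returns mean accuracy on the held-out split
    return {name: clf.fit(X_train, y_train).score(X_test, y_test)
            for name, clf in classifiers.items()}


if __name__ == "__main__":
    # Synthetic stand-in for extracted embeddings: two well-separated classes
    # (real ResNet-18 embeddings would be 512-dimensional)
    X, y = make_blobs(n_samples=400, centers=2, n_features=32, random_state=0)
    X = normalize(X)  # mimic L2-normalized embeddings
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)
    for name, acc in evaluate_classifiers(X_tr, y_tr, X_te, y_te).items():
        print(f"{name}: {acc:.4f}")
```

Because triplet training groups embeddings by class, even simple classifiers such as KNN can separate them well, so the choice of model on top tends to matter less than the quality of the embedding.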

Acknowledgement

Thanks to https://github.com/alfonmedela/triplet-loss-pytorch/tree/master and https://github.com/chencodeX/triplet-loss-pytorch for the inspiration.

License

All code in this repository is released under the MIT License.