Godthumb/triplet-loss

Triplet-loss-pytorch

A PyTorch implementation of triplet loss.

Introduction

Triplet loss is well suited to classification problems where N_CLASSES >> N_SAMPLES_PER_CLASS, that is, many classes with only a few samples each. Face recognition is a typical example.
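For readers new to the loss, here is a minimal sketch of the standard triplet margin formulation, max(d(a, p) - d(a, n) + margin, 0), written with plain PyTorch ops. The function name `triplet_loss` and the margin of 1.0 are illustrative assumptions, not values taken from this repository's `train.py`.

```python
import torch
import torch.nn.functional as F


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Standard triplet margin loss (illustrative; margin value is an assumption)."""
    # Euclidean distance between each anchor and its positive / negative sample.
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    # Penalise triplets where the negative is not at least `margin` farther
    # from the anchor than the positive is.
    return F.relu(d_pos - d_neg + margin).mean()


# Toy batch of 4 triplets with 8-dimensional embeddings.
torch.manual_seed(0)
anchor = torch.randn(4, 8)
positive = torch.randn(4, 8)
negative = torch.randn(4, 8)
loss = triplet_loss(anchor, positive, negative)
```

PyTorch's built-in `torch.nn.TripletMarginLoss` computes the same quantity and is usually the more convenient choice in a training loop.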

The CNN architecture used with triplet loss must be cut off before its classification layer, so that it outputs an embedding rather than class scores. In addition, an L2 normalization layer has to be added.

Usage

python train.py
python extract_embeddings.py
python model_on_top.py

Set --data-path to the path of your own dataset.
The default model is resnet18.

Data

Download dataset here
The Dogs vs. Cats dataset is used here as a demo.

Results


KNN: 0.9825
SVM: 0.985
Linear SVM: 0.985
RF: 0.9825
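
The four names above correspond to classifiers that can be fitted on top of the extracted embeddings. A minimal sketch with scikit-learn is shown below. It runs on synthetic unit-norm vectors rather than real embeddings, and the function name `evaluate_on_embeddings`, the hyperparameters, and the train/test split are all assumptions. It is not the contents of `model_on_top.py`, and it reports accuracy only because that is scikit-learn's default `score`.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC


def evaluate_on_embeddings(embeddings, labels, seed=0):
    """Fit four classifiers on embeddings and return test accuracy for each."""
    x_train, x_test, y_train, y_test = train_test_split(
        embeddings, labels, test_size=0.25, random_state=seed, stratify=labels
    )
    # Hyperparameters are illustrative defaults, not tuned values.
    classifiers = {
        "KNN": KNeighborsClassifier(n_neighbors=5),
        "SVM": SVC(kernel="rbf"),
        "Linear SVM": LinearSVC(),
        "RF": RandomForestClassifier(n_estimators=100, random_state=seed),
    }
    return {
        name: clf.fit(x_train, y_train).score(x_test, y_test)
        for name, clf in classifiers.items()
    }


# Synthetic stand-in for extracted embeddings: two noisy clusters, L2-normalized.
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 100)
centers = np.array([[1.0, 0.0], [0.0, 1.0]])
embeddings = centers[labels] + 0.1 * rng.normal(size=(200, 2))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

scores = evaluate_on_embeddings(embeddings, labels)
```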

Acknowledgement

Thanks to https://github.com/alfonmedela/triplet-loss-pytorch/tree/master and https://github.com/chencodeX/triplet-loss-pytorch for the great inspiration.

License

All code in this repository is released under the MIT license.
