20 Newsgroups sentence-Bert classifier

API for a scikit-learn 20 Newsgroups classifier with Faiss (approximate nearest-neighbor library)
Author: Arnauld Adjovi

Introduction

This repository fine-tunes a DistilBERT model on the scikit-learn 20 Newsgroups dataset with a triplet network structure to produce semantically meaningful sentence embeddings. These embeddings can be used in supervised scenarios: semantic textual similarity search via Facebook Faiss (an approximate nearest-neighbor library) and label prediction.

We fine-tuned the pretrained distilbert-base-nli-mean-tokens model with a TripletLoss function for 20 Newsgroups label prediction. We chose the Facebook Faiss-IVF library for semantic search because of its advantages, mainly:
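As a rough illustration of the triplet objective mentioned above, the sketch below computes a hinge-style triplet loss on toy 2-D vectors. The Euclidean distance, the margin value, and the vectors themselves are assumptions made for the example; the actual training setup in this repository may differ.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge-style triplet loss: max(d(a, p) - d(a, n) + margin, 0).

    The margin of 1.0 is an assumed value, not one taken from this repository.
    """
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


# Toy "embeddings", made up for illustration only.
anchor = [0.0, 0.0]    # e.g. a post from one newsgroup
positive = [3.0, 4.0]  # another post with the same label (distance 5)
negative = [6.0, 8.0]  # a post with a different label (distance 10)

# The negative is already farther than the positive by more than the margin.
loss_far = triplet_loss(anchor, positive, negative)

# A negative sitting exactly where the positive is gets penalized by the margin.
loss_near = triplet_loss(anchor, positive, [3.0, 4.0])
```

The loss is zero once the negative is farther from the anchor than the positive by at least the margin, which is what pushes posts with the same label together in the embedding space.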

  • fast search time (good trade-off between recall and queries per second)
  • good accuracy
  • low memory footprint per index vector
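To give an intuition for why an inverted-file (IVF) index is fast, here is a toy pure-Python sketch of the idea: vectors are bucketed by their nearest centroid, and a query only scans the few buckets closest to it. This is not Faiss code. The class `ToyIVFIndex`, the hand-picked centroids, and the `nprobe` parameter name are illustrative assumptions; in Faiss the centroids are learned by k-means during training.

```python
import math
from collections import defaultdict


def dist(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


class ToyIVFIndex:
    """Toy inverted-file index (hypothetical, for illustration only)."""

    def __init__(self, centroids):
        self.centroids = centroids
        self.cells = defaultdict(list)  # centroid id -> [(vector id, vector)]
        self.size = 0

    def _nearest_cells(self, vec, n):
        """Ids of the n centroids closest to vec."""
        order = sorted(range(len(self.centroids)),
                       key=lambda c: dist(vec, self.centroids[c]))
        return order[:n]

    def add(self, vectors):
        """Store each vector in the cell of its nearest centroid."""
        for vec in vectors:
            cell = self._nearest_cells(vec, 1)[0]
            self.cells[cell].append((self.size, vec))
            self.size += 1

    def search(self, query, k=1, nprobe=1):
        """Return ids of the k closest vectors among the nprobe nearest cells."""
        candidates = [
            (dist(query, vec), vec_id)
            for cell in self._nearest_cells(query, nprobe)
            for vec_id, vec in self.cells[cell]
        ]
        return [vec_id for _, vec_id in sorted(candidates)[:k]]


# Hand-picked centroids and vectors, made up for the example.
index = ToyIVFIndex(centroids=[[0.0, 0.0], [10.0, 10.0]])
index.add([[0.0, 1.0], [1.0, 0.0], [9.0, 10.0], [10.0, 9.0]])

# Only the cell around [10, 10] is scanned, so half the vectors are skipped.
nearest = index.search([9.5, 10.0], k=2, nprobe=1)
```

Raising `nprobe` scans more cells, trading speed for recall; that trade-off is the approximate part of approximate nearest-neighbor search.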

The final fine-tuned model is available on Google Drive

Results

  • pretrained model: 44.03% accuracy on the 20 Newsgroups test set
  • our fine-tuning: 60.45% accuracy on the 20 Newsgroups test set

Pipeline benchmark on the test set for Faiss and pretrained SBERT

How it works

Installation tutorial: INSTALL

The following files are the main components of the application:

Using the REST API

After starting the Docker image, the API will be available at localhost:5000/predict. To run a model prediction on a sentence, send a GET request with the sentence in the query parameter.

https://localhost:5000/predict?query=It was a 2-door sports car, looked to be from the late 60s early 70s. It was called a Bricklin. The doors were really small.

You should get a JSON response as output:

{
  "label": "rec.autos",
  "query": "It was a 2-door sports car, looked to be from the late 60s early 70s. It was called a Bricklin. The doors were really small."
}
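The request above can also be sent from a short Python client. The sketch below uses only the standard library; the helper names `build_predict_url` and `predict_label` are hypothetical, and the base URL simply repeats the one shown in this README (adjust the scheme or port if your container is exposed differently).

```python
import json
from urllib.parse import urlencode
from urllib.request import urlopen

# Endpoint as given in this README; change it if your setup differs.
BASE_URL = "https://localhost:5000/predict"


def build_predict_url(sentence, base_url=BASE_URL):
    """Return the GET URL with the sentence encoded in the `query` parameter."""
    return base_url + "?" + urlencode({"query": sentence})


def predict_label(sentence, base_url=BASE_URL):
    """Call the API and return the "label" field of the JSON response.

    Requires the Docker container to be running and reachable.
    """
    with urlopen(build_predict_url(sentence, base_url)) as response:
        return json.load(response)["label"]


# Building the URL does not need a running server.
url = build_predict_url("It was called a Bricklin.")
```

Encoding the sentence with `urlencode` matters because spaces, commas, and other punctuation in free text are not safe to paste into a URL as-is.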

Improvements

The following improvements could help achieve better performance:

Sentence-Bert embedding:

  • Use ALBERT instead of DistilBERT to encode the corpus (ALBERT's SentencePiece tokenization yields better embeddings than BERT's WordPiece)
  • Fine-tune on more labeled data to improve the model's accuracy, and use the BatchSemiHardTripletLoss function for better training
  • Leverage quantization and computational graph optimization with ONNX Runtime to improve ALBERT inference time
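For context on the semi-hard triplet idea, the sketch below shows the selection criterion it is named after: a negative is "semi-hard" when it is farther from the anchor than the positive, but still within the margin. This is a simplified illustration on made-up vectors with an assumed margin, not the sentence-transformers implementation of BatchSemiHardTripletLoss.

```python
import math


def dist(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def semi_hard_negatives(anchor, positive, negatives, margin=1.0):
    """Return the negatives n with d(a, p) < d(a, n) < d(a, p) + margin."""
    d_ap = dist(anchor, positive)
    return [n for n in negatives if d_ap < dist(anchor, n) < d_ap + margin]


# Toy vectors, made up for illustration only.
anchor = [0.0, 0.0]
positive = [1.0, 0.0]  # distance 1.0 from the anchor
negatives = [
    [0.5, 0.0],  # distance 0.5: hard (closer than the positive)
    [1.5, 0.0],  # distance 1.5: semi-hard (inside the margin band)
    [5.0, 0.0],  # distance 5.0: easy (triplet loss is already zero)
]

selected = semi_hard_negatives(anchor, positive, negatives, margin=1.0)
```

Training on such negatives avoids both the easy triplets, which contribute no gradient, and the hardest ones, which can destabilize training.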

Approximate Nearest-Neighbors library:

  • Faiss-HNSW is the best option if you have a lot of RAM

Reference

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "http://arxiv.org/abs/1908.10084",
}