From fdac3edf930dcd70158d7939a3e8dcf6c15df405 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sun, 18 Aug 2024 10:04:19 +0100 Subject: [PATCH 1/4] tch version update --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b63f11ab..200e5d7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,14 +76,14 @@ features = ["doc-only"] [dependencies] rust_tokenizers = "8.1.1" -tch = { version = "0.16.0", features = ["download-libtorch"] } +tch = { version = "0.17.0", features = ["download-libtorch"] } serde_json = "1" serde = { version = "1", features = ["derive"] } ordered-float = "4.2.0" uuid = { version = "1", features = ["v4"] } thiserror = "1" half = "2" -regex = "1.6" +regex = "1.10" cached-path = { version = "0.6", default-features = false, optional = true } dirs = { version = "5", optional = true } @@ -92,7 +92,7 @@ ort = { version = "1.16.3", optional = true, default-features = false, features "half", ] } ndarray = { version = "0.15", optional = true } -tokenizers = { version = "0.19.1", optional = true, default-features = false, features = [ +tokenizers = { version = "0.20", optional = true, default-features = false, features = [ "onig", ] } From e876d3deb4320fe603796901ac92004da8a60d31 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sun, 18 Aug 2024 10:14:14 +0100 Subject: [PATCH 2/4] Updated readmes --- README.md | 202 ++++++++++++++++++++++++++--------------------------- src/lib.rs | 4 +- 2 files changed, 103 insertions(+), 103 deletions(-) diff --git a/README.md b/README.md index 2fd91270..a78e16ed 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@ translation, summarization, text generation, conversational agents and more in just a few lines of code: ```rust - let qa_model = QuestionAnsweringModel::new(Default::default())?; - - let question = String::from("Where does Amy live ?"); - let context = String::from("Amy lives in Amsterdam"); + let qa_model = 
QuestionAnsweringModel::new(Default::default ()) ?; - let answers = qa_model.predict(&[QaInput { question, context }], 1, 32); +let question = String::from("Where does Amy live ?"); +let context = String::from("Amy lives in Amsterdam"); + +let answers = qa_model.predict( & [QaInput { question, context }], 1, 32); ``` Output: @@ -54,32 +54,32 @@ The tasks currently supported include: Expand to display the supported models/tasks matrix | | **Sequence classification** | **Token classification** | **Question answering** | **Text Generation** | **Summarization** | **Translation** | **Masked LM** | **Sentence Embeddings** | -| :----------: | :-------------------------: | :----------------------: | :--------------------: | :-----------------: | :---------------: | :-------------: | :-----------: | :---------------------: | -| DistilBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| MobileBERT | ✅ | ✅ | ✅ | | | | ✅ | | -| DeBERTa | ✅ | ✅ | ✅ | | | | ✅ | | -| DeBERTa (v2) | ✅ | ✅ | ✅ | | | | ✅ | | -| FNet | ✅ | ✅ | ✅ | | | | ✅ | | -| BERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| RoBERTa | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| GPT | | | | ✅ | | | | | -| GPT2 | | | | ✅ | | | | | -| GPT-Neo | | | | ✅ | | | | | -| GPT-J | | | | ✅ | | | | | -| BART | ✅ | | | ✅ | ✅ | | | | -| Marian | | | | | | ✅ | | | -| MBart | ✅ | | | ✅ | | | | | -| M2M100 | | | | ✅ | | | | | -| NLLB | | | | ✅ | | | | | -| Electra | | ✅ | | | | | ✅ | | -| ALBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| T5 | | | | ✅ | ✅ | ✅ | | ✅ | -| LongT5 | | | | ✅ | ✅ | | | | -| XLNet | ✅ | ✅ | ✅ | ✅ | | | ✅ | | -| Reformer | ✅ | | ✅ | ✅ | | | ✅ | | -| ProphetNet | | | | ✅ | ✅ | | | | -| Longformer | ✅ | ✅ | ✅ | | | | ✅ | | -| Pegasus | | | | | ✅ | | | | +|:------------:|:---------------------------:|:------------------------:|:----------------------:|:-------------------:|:-----------------:|:---------------:|:-------------:|:-----------------------:| +| DistilBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| MobileBERT | ✅ | ✅ | ✅ | | | | ✅ | | +| DeBERTa | ✅ | ✅ | ✅ | | | | ✅ | 
| +| DeBERTa (v2) | ✅ | ✅ | ✅ | | | | ✅ | | +| FNet | ✅ | ✅ | ✅ | | | | ✅ | | +| BERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| RoBERTa | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| GPT | | | | ✅ | | | | | +| GPT2 | | | | ✅ | | | | | +| GPT-Neo | | | | ✅ | | | | | +| GPT-J | | | | ✅ | | | | | +| BART | ✅ | | | ✅ | ✅ | | | | +| Marian | | | | | | ✅ | | | +| MBart | ✅ | | | ✅ | | | | | +| M2M100 | | | | ✅ | | | | | +| NLLB | | | | ✅ | | | | | +| Electra | | ✅ | | | | | ✅ | | +| ALBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| T5 | | | | ✅ | ✅ | ✅ | | ✅ | +| LongT5 | | | | ✅ | ✅ | | | | +| XLNet | ✅ | ✅ | ✅ | ✅ | | | ✅ | | +| Reformer | ✅ | | ✅ | ✅ | | | ✅ | | +| ProphetNet | | | | ✅ | ✅ | | | | +| Longformer | ✅ | ✅ | ✅ | | | | ✅ | | +| Pegasus | | | | | ✅ | | | | @@ -100,10 +100,10 @@ models used by this library are in the order of the 100s of MBs to GBs. ### Manual installation (recommended) 1. Download `libtorch` from https://pytorch.org/get-started/locally/. This - package requires `v2.2`: if this version is no longer available on the "get + package requires `v2.4`: if this version is no longer available on the "get started" page, the file should be accessible by modifying the target link, for example - `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` + `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip` for a Linux version with CUDA12. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package @@ -188,7 +188,7 @@ files. These files are expected (but not all are necessary) for use in this library as per the table below: | Architecture | Encoder file | Decoder without past file | Decoder with past file | -| --------------------------- | ------------ | ------------------------- | ---------------------- | +|-----------------------------|--------------|---------------------------|------------------------| | Encoder (e.g. 
BERT) | required | not used | not used | | Decoder (e.g. GPT2) | not used | required | optional | | Encoder-decoder (e.g. BART) | required | required | optional | @@ -227,12 +227,12 @@ Extractive question answering from a given question and context. DistilBERT model fine-tuned on SQuAD (Stanford Question Answering Dataset) ```rust - let qa_model = QuestionAnsweringModel::new(Default::default())?; - - let question = String::from("Where does Amy live ?"); - let context = String::from("Amy lives in Amsterdam"); + let qa_model = QuestionAnsweringModel::new(Default::default ()) ?; + +let question = String::from("Where does Amy live ?"); +let context = String::from("Amy lives in Amsterdam"); - let answers = qa_model.predict(&[QaInput { question, context }], 1, 32); +let answers = qa_model.predict( & [QaInput { question, context }], 1, 32); ``` Output: @@ -281,7 +281,7 @@ is available in the ```rust use rust_bert::pipelines::translation::{Language, TranslationModelBuilder}; fn main() -> anyhow::Result<()> { -let model = TranslationModelBuilder::new() + let model = TranslationModelBuilder::new() .with_source_languages(vec![Language::English]) .with_target_languages(vec![Language::Spanish, Language::French, Language::Italian]) .create_model()?; @@ -308,9 +308,9 @@ Il s'agit d'une phrase à traduire Abstractive summarization using a pretrained BART model. 
```rust - let summarization_model = SummarizationModel::new(Default::default())?; - - let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \ + let summarization_model = SummarizationModel::new(Default::default ()) ?; + +let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \ from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \ from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \ a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \ @@ -332,7 +332,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \ about exoplanets like K2-18b."]; - let output = summarization_model.summarize(&input); +let output = summarization_model.summarize( & input); ``` (example from: @@ -367,11 +367,11 @@ generate responses to them. ```rust use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager}; -let conversation_model = ConversationModel::new(Default::default()); +let conversation_model = ConversationModel::new(Default::default ()); let mut conversation_manager = ConversationManager::new(); let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?"); -let output = conversation_model.generate_responses(&mut conversation_manager); +let output = conversation_model.generate_responses( & mut conversation_manager); ``` Example output: @@ -393,17 +393,17 @@ present, the unknown token otherwise. 
This may impact the results, it is recommended to submit prompts of similar length for best results ```rust - let model = GPT2Generator::new(Default::default())?; - - let input_context_1 = "The dog"; - let input_context_2 = "The cat was"; + let model = GPT2Generator::new(Default::default ()) ?; - let generate_options = GenerateOptions { - max_length: 30, - ..Default::default() - }; +let input_context_1 = "The dog"; +let input_context_2 = "The cat was"; - let output = model.generate(Some(&[input_context_1, input_context_2]), generate_options); +let generate_options = GenerateOptions { +max_length: 30, +..Default::default () +}; + +let output = model.generate(Some( & [input_context_1, input_context_2]), generate_options); ``` Example output: @@ -428,18 +428,18 @@ Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference. ```rust - let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; + let sequence_classification_model = ZeroShotClassificationModel::new(Default::default ()) ?; - let input_sentence = "Who are you voting for in 2020?"; - let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; - let candidate_labels = &["politics", "public health", "economics", "sports"]; +let input_sentence = "Who are you voting for in 2020?"; +let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; +let candidate_labels = & ["politics", "public health", "economics", "sports"]; - let output = sequence_classification_model.predict_multilabel( - &[input_sentence, input_sequence_2], - candidate_labels, - None, - 128, - ); +let output = sequence_classification_model.predict_multilabel( +& [input_sentence, input_sequence_2], +candidate_labels, +None, +128, +); ``` Output: @@ -460,15 +460,15 @@ Predicts the binary sentiment for a sentence. 
DistilBERT model fine-tuned on SST-2. ```rust - let sentiment_classifier = SentimentModel::new(Default::default())?; - - let input = [ - "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", - "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", - "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", - ]; + let sentiment_classifier = SentimentModel::new(Default::default ()) ?; + +let input = [ +"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", +"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", +"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", +]; - let output = sentiment_classifier.predict(&input); +let output = sentiment_classifier.predict( & input); ``` (Example courtesy of [IMDb](http://www.imdb.com)) @@ -494,14 +494,14 @@ BERT cased large model fine-tuned on CoNNL03, contributed by the Models are currently available for English, German, Spanish and Dutch. ```rust - let ner_model = NERModel::new(default::default())?; + let ner_model = NERModel::new( default::default ()) ?; + +let input = [ +"My name is Amy. I live in Paris.", +"Paris is a city in France." +]; - let input = [ - "My name is Amy. I live in Paris.", - "Paris is a city in France." 
- ]; - - let output = ner_model.predict(&input); +let output = ner_model.predict( & input); ``` Output: @@ -529,7 +529,7 @@ Extract keywords and keyphrases extractions from input documents ```rust fn main() -> anyhow::Result<()> { let keyword_extraction_model = KeywordExtractionModel::new(Default::default())?; - + let input = "Rust is a multi-paradigm, general-purpose programming language. \ Rust emphasizes performance, type safety, and concurrency. Rust enforces memory safety—that is, \ that all references point to valid memory—without requiring the use of a garbage collector or \ @@ -560,11 +560,11 @@ Output: Extracts Part of Speech tags (Noun, Verb, Adjective...) from text. ```rust - let pos_model = POSModel::new(default::default())?; + let pos_model = POSModel::new( default::default ()) ?; + +let input = ["My name is Bob"]; - let input = ["My name is Bob"]; - - let output = pos_model.predict(&input); +let output = pos_model.predict( & input); ``` Output: @@ -588,15 +588,15 @@ applications including dense information retrieval. ```rust let model = SentenceEmbeddingsBuilder::remote( - SentenceEmbeddingsModelType::AllMiniLmL12V2 - ).create_model()?; +SentenceEmbeddingsModelType::AllMiniLmL12V2 +).create_model() ?; + +let sentences = [ +"this is an example sentence", +"each sentence is converted" +]; - let sentences = [ - "this is an example sentence", - "each sentence is converted" - ]; - - let output = model.encode(&sentences)?; +let output = model.encode( & sentences) ?; ``` Output: @@ -616,14 +616,14 @@ Output: Predict masked words in input sentences. ```rust - let model = MaskedLanguageModel::new(Default::default())?; - - let sentences = [ - "Hello I am a student", - "Paris is the of France. It is in Europe.", - ]; - - let output = model.predict(&sentences); + let model = MaskedLanguageModel::new(Default::default ()) ?; + +let sentences = [ +"Hello I am a student", +"Paris is the of France. 
It is in Europe.", +]; + +let output = model.predict( & sentences); ``` Output: diff --git a/src/lib.rs b/src/lib.rs index 45f7ef06..58cd01b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,8 +90,8 @@ //! //! ### Manual installation (recommended) //! -//! 1. Download `libtorch` from . This package requires `v2.2`: if this version is no longer available on the "get started" page, -//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12. +//! 1. Download `libtorch` from . This package requires `v2.4`: if this version is no longer available on the "get started" page, +//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip` for a Linux version with CUDA12. //! 2. Extract the library to a location of your choice //! 3. Set the following environment variables //! ##### Linux: From 0c65b21250aec9de4efdcce214e3883627de87d1 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sun, 18 Aug 2024 10:19:25 +0100 Subject: [PATCH 3/4] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a78e16ed..3df54844 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ models used by this library are in the order of the 100s of MBs to GBs. package requires `v2.4`: if this version is no longer available on the "get started" page, the file should be accessible by modifying the target link, for example - `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip` + `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu121.zip` for a Linux version with CUDA12. 
**NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package From be58b9412cce76e461a19ebac15bf6b2e11471c2 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sun, 18 Aug 2024 10:24:04 +0100 Subject: [PATCH 4/4] fix readme --- README.md | 204 ++++++++++++++++++++++++++--------------------------- src/lib.rs | 2 +- 2 files changed, 103 insertions(+), 103 deletions(-) diff --git a/README.md b/README.md index 3df54844..dcebf15c 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@ translation, summarization, text generation, conversational agents and more in just a few lines of code: ```rust - let qa_model = QuestionAnsweringModel::new(Default::default ()) ?; + let qa_model = QuestionAnsweringModel::new(Default::default())?; + + let question = String::from("Where does Amy live ?"); + let context = String::from("Amy lives in Amsterdam"); -let question = String::from("Where does Amy live ?"); -let context = String::from("Amy lives in Amsterdam"); - -let answers = qa_model.predict( & [QaInput { question, context }], 1, 32); + let answers = qa_model.predict(&[QaInput { question, context }], 1, 32); ``` Output: @@ -54,32 +54,32 @@ The tasks currently supported include: Expand to display the supported models/tasks matrix | | **Sequence classification** | **Token classification** | **Question answering** | **Text Generation** | **Summarization** | **Translation** | **Masked LM** | **Sentence Embeddings** | -|:------------:|:---------------------------:|:------------------------:|:----------------------:|:-------------------:|:-----------------:|:---------------:|:-------------:|:-----------------------:| -| DistilBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| MobileBERT | ✅ | ✅ | ✅ | | | | ✅ | | -| DeBERTa | ✅ | ✅ | ✅ | | | | ✅ | | -| DeBERTa (v2) | ✅ | ✅ | ✅ | | | | ✅ | | -| FNet | ✅ | ✅ | ✅ | | | | ✅ | | -| BERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| RoBERTa | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| GPT | | | | 
✅ | | | | | -| GPT2 | | | | ✅ | | | | | -| GPT-Neo | | | | ✅ | | | | | -| GPT-J | | | | ✅ | | | | | -| BART | ✅ | | | ✅ | ✅ | | | | -| Marian | | | | | | ✅ | | | -| MBart | ✅ | | | ✅ | | | | | -| M2M100 | | | | ✅ | | | | | -| NLLB | | | | ✅ | | | | | -| Electra | | ✅ | | | | | ✅ | | -| ALBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | -| T5 | | | | ✅ | ✅ | ✅ | | ✅ | -| LongT5 | | | | ✅ | ✅ | | | | -| XLNet | ✅ | ✅ | ✅ | ✅ | | | ✅ | | -| Reformer | ✅ | | ✅ | ✅ | | | ✅ | | -| ProphetNet | | | | ✅ | ✅ | | | | -| Longformer | ✅ | ✅ | ✅ | | | | ✅ | | -| Pegasus | | | | | ✅ | | | | +| :----------: | :-------------------------: | :----------------------: | :--------------------: | :-----------------: | :---------------: | :-------------: | :-----------: | :---------------------: | +| DistilBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| MobileBERT | ✅ | ✅ | ✅ | | | | ✅ | | +| DeBERTa | ✅ | ✅ | ✅ | | | | ✅ | | +| DeBERTa (v2) | ✅ | ✅ | ✅ | | | | ✅ | | +| FNet | ✅ | ✅ | ✅ | | | | ✅ | | +| BERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| RoBERTa | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| GPT | | | | ✅ | | | | | +| GPT2 | | | | ✅ | | | | | +| GPT-Neo | | | | ✅ | | | | | +| GPT-J | | | | ✅ | | | | | +| BART | ✅ | | | ✅ | ✅ | | | | +| Marian | | | | | | ✅ | | | +| MBart | ✅ | | | ✅ | | | | | +| M2M100 | | | | ✅ | | | | | +| NLLB | | | | ✅ | | | | | +| Electra | | ✅ | | | | | ✅ | | +| ALBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ | +| T5 | | | | ✅ | ✅ | ✅ | | ✅ | +| LongT5 | | | | ✅ | ✅ | | | | +| XLNet | ✅ | ✅ | ✅ | ✅ | | | ✅ | | +| Reformer | ✅ | | ✅ | ✅ | | | ✅ | | +| ProphetNet | | | | ✅ | ✅ | | | | +| Longformer | ✅ | ✅ | ✅ | | | | ✅ | | +| Pegasus | | | | | ✅ | | | | @@ -100,10 +100,10 @@ models used by this library are in the order of the 100s of MBs to GBs. ### Manual installation (recommended) 1. Download `libtorch` from https://pytorch.org/get-started/locally/. 
This - package requires `v2.4`: if this version is no longer available on the "get + package requires `v2.4`: if this version is no longer available on the "get started" page, the file should be accessible by modifying the target link, for example - `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu121.zip` + `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip` for a Linux version with CUDA12. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package @@ -140,7 +140,7 @@ Alternatively, you can let the `build` script automatically download the `libtorch` library for you. The `download-libtorch` feature flag needs to be enabled. The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to -`cu118`. Note that the libtorch library is large (order of several GBs for the +`cu124`. Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete. @@ -188,7 +188,7 @@ files. These files are expected (but not all are necessary) for use in this library as per the table below: | Architecture | Encoder file | Decoder without past file | Decoder with past file | -|-----------------------------|--------------|---------------------------|------------------------| +| --------------------------- | ------------ | ------------------------- | ---------------------- | | Encoder (e.g. BERT) | required | not used | not used | | Decoder (e.g. GPT2) | not used | required | optional | | Encoder-decoder (e.g. BART) | required | required | optional | @@ -227,12 +227,12 @@ Extractive question answering from a given question and context. 
DistilBERT model fine-tuned on SQuAD (Stanford Question Answering Dataset) ```rust - let qa_model = QuestionAnsweringModel::new(Default::default ()) ?; - -let question = String::from("Where does Amy live ?"); -let context = String::from("Amy lives in Amsterdam"); + let qa_model = QuestionAnsweringModel::new(Default::default())?; + + let question = String::from("Where does Amy live ?"); + let context = String::from("Amy lives in Amsterdam"); -let answers = qa_model.predict( & [QaInput { question, context }], 1, 32); + let answers = qa_model.predict(&[QaInput { question, context }], 1, 32); ``` Output: @@ -281,7 +281,7 @@ is available in the ```rust use rust_bert::pipelines::translation::{Language, TranslationModelBuilder}; fn main() -> anyhow::Result<()> { - let model = TranslationModelBuilder::new() +let model = TranslationModelBuilder::new() .with_source_languages(vec![Language::English]) .with_target_languages(vec![Language::Spanish, Language::French, Language::Italian]) .create_model()?; @@ -308,9 +308,9 @@ Il s'agit d'une phrase à traduire Abstractive summarization using a pretrained BART model. ```rust - let summarization_model = SummarizationModel::new(Default::default ()) ?; - -let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \ + let summarization_model = SummarizationModel::new(Default::default())?; + + let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \ from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \ from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \ a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \ @@ -332,7 +332,7 @@ on K2-18b lasts 33 Earth days. 
According to The Guardian, astronomers were optim telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \ about exoplanets like K2-18b."]; -let output = summarization_model.summarize( & input); + let output = summarization_model.summarize(&input); ``` (example from: @@ -367,11 +367,11 @@ generate responses to them. ```rust use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager}; -let conversation_model = ConversationModel::new(Default::default ()); +let conversation_model = ConversationModel::new(Default::default()); let mut conversation_manager = ConversationManager::new(); let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?"); -let output = conversation_model.generate_responses( & mut conversation_manager); +let output = conversation_model.generate_responses(&mut conversation_manager); ``` Example output: @@ -393,17 +393,17 @@ present, the unknown token otherwise. This may impact the results, it is recommended to submit prompts of similar length for best results ```rust - let model = GPT2Generator::new(Default::default ()) ?; + let model = GPT2Generator::new(Default::default())?; + + let input_context_1 = "The dog"; + let input_context_2 = "The cat was"; -let input_context_1 = "The dog"; -let input_context_2 = "The cat was"; + let generate_options = GenerateOptions { + max_length: 30, + ..Default::default() + }; -let generate_options = GenerateOptions { -max_length: 30, -..Default::default () -}; - -let output = model.generate(Some( & [input_context_1, input_context_2]), generate_options); + let output = model.generate(Some(&[input_context_1, input_context_2]), generate_options); ``` Example output: @@ -428,18 +428,18 @@ Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference. 
```rust - let sequence_classification_model = ZeroShotClassificationModel::new(Default::default ()) ?; + let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?; -let input_sentence = "Who are you voting for in 2020?"; -let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; -let candidate_labels = & ["politics", "public health", "economics", "sports"]; + let input_sentence = "Who are you voting for in 2020?"; + let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition."; + let candidate_labels = &["politics", "public health", "economics", "sports"]; -let output = sequence_classification_model.predict_multilabel( -& [input_sentence, input_sequence_2], -candidate_labels, -None, -128, -); + let output = sequence_classification_model.predict_multilabel( + &[input_sentence, input_sequence_2], + candidate_labels, + None, + 128, + ); ``` Output: @@ -460,15 +460,15 @@ Predicts the binary sentiment for a sentence. DistilBERT model fine-tuned on SST-2. ```rust - let sentiment_classifier = SentimentModel::new(Default::default ()) ?; - -let input = [ -"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", -"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", -"If you like original gut wrenching laughter you will like this movie. 
If you are young or old then you will love this movie, hell even my mom liked it.", -]; + let sentiment_classifier = SentimentModel::new(Default::default())?; + + let input = [ + "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", + "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", + "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", + ]; -let output = sentiment_classifier.predict( & input); + let output = sentiment_classifier.predict(&input); ``` (Example courtesy of [IMDb](http://www.imdb.com)) @@ -494,14 +494,14 @@ BERT cased large model fine-tuned on CoNNL03, contributed by the Models are currently available for English, German, Spanish and Dutch. ```rust - let ner_model = NERModel::new( default::default ()) ?; - -let input = [ -"My name is Amy. I live in Paris.", -"Paris is a city in France." -]; + let ner_model = NERModel::new(default::default())?; -let output = ner_model.predict( & input); + let input = [ + "My name is Amy. I live in Paris.", + "Paris is a city in France." + ]; + + let output = ner_model.predict(&input); ``` Output: @@ -529,7 +529,7 @@ Extract keywords and keyphrases extractions from input documents ```rust fn main() -> anyhow::Result<()> { let keyword_extraction_model = KeywordExtractionModel::new(Default::default())?; - + let input = "Rust is a multi-paradigm, general-purpose programming language. \ Rust emphasizes performance, type safety, and concurrency. Rust enforces memory safety—that is, \ that all references point to valid memory—without requiring the use of a garbage collector or \ @@ -560,11 +560,11 @@ Output: Extracts Part of Speech tags (Noun, Verb, Adjective...) from text. 
```rust - let pos_model = POSModel::new( default::default ()) ?; - -let input = ["My name is Bob"]; + let pos_model = POSModel::new(default::default())?; -let output = pos_model.predict( & input); + let input = ["My name is Bob"]; + + let output = pos_model.predict(&input); ``` Output: @@ -588,15 +588,15 @@ applications including dense information retrieval. ```rust let model = SentenceEmbeddingsBuilder::remote( -SentenceEmbeddingsModelType::AllMiniLmL12V2 -).create_model() ?; - -let sentences = [ -"this is an example sentence", -"each sentence is converted" -]; + SentenceEmbeddingsModelType::AllMiniLmL12V2 + ).create_model()?; -let output = model.encode( & sentences) ?; + let sentences = [ + "this is an example sentence", + "each sentence is converted" + ]; + + let output = model.encode(&sentences)?; ``` Output: @@ -616,14 +616,14 @@ Output: Predict masked words in input sentences. ```rust - let model = MaskedLanguageModel::new(Default::default ()) ?; - -let sentences = [ -"Hello I am a student", -"Paris is the of France. It is in Europe.", -]; - -let output = model.predict( & sentences); + let model = MaskedLanguageModel::new(Default::default())?; + + let sentences = [ + "Hello I am a student", + "Paris is the of France. It is in Europe.", + ]; + + let output = model.predict(&sentences); ``` Output: diff --git a/src/lib.rs b/src/lib.rs index 2949448e..aebc99ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,7 +109,7 @@ //! ### Automatic installation //! //! Alternatively, you can let the `build` script automatically download the `libtorch` library for you. The `download-libtorch` feature flag needs to be enabled. -//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu118`. +//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu124`. //! 
Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete. //! //! ## ONNX Support (Optional)