Skip to content

Commit

Permalink
Added GODEL support
Browse files Browse the repository at this point in the history
  • Loading branch information
Emulator000 committed May 11, 2023
1 parent 9fd7983 commit 7e00a22
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/pipelines/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// limitations under the License.

//! # Multi-turn dialogue
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or
//! [GODEL](https://github.com/microsoft/GODEL).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! The DialoGPT's page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
Expand Down Expand Up @@ -55,6 +56,7 @@
//! from the 3rd party utilization of the pretrained system.
use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
use crate::t5::T5Generator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
Expand Down Expand Up @@ -695,12 +697,14 @@ impl Default for ConversationManager {
pub enum ConversationOption {
/// Conversation based on GPT2 model
GPT2(GPT2Generator),
/// Conversation based on T5 model (e.g. GODEL)
T5(T5Generator),
}

impl ConversationOption {
pub fn new(config: ConversationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)),
ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new(config.into())?)),
_ => Err(RustBertError::InvalidConfigurationError(
"GPT2 is currently the only supported model for conversation generation"
.to_string(),
Expand All @@ -717,6 +721,10 @@ impl ConversationOption {
config.into(),
tokenizer,
)?)),
ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
_ => Err(RustBertError::InvalidConfigurationError(
"GPT2 is currently the only supported model for conversation generation"
.to_string(),
Expand All @@ -729,27 +737,33 @@ impl ConversationOption {
Self::GPT2(model_ref) => {
Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap())
}
Self::T5(model_ref) => {
Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap())
}
}
}

/// Get a reference to the model tokenizer.
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
Self::T5(model_ref) => model_ref._get_tokenizer(),
}
}

/// Get a mutable reference to the model tokenizer.
pub fn get_tokenizer_mut(&mut self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
}
}

/// Returns the `ModelType` for this ConversationOption
pub fn model_type(&self) -> ModelType {
match *self {
Self::GPT2(_) => ModelType::GPT2,
Self::T5(_) => ModelType::T5,
}
}

Expand All @@ -765,6 +779,11 @@ impl ConversationOption {
.into_iter()
.map(|output| output.indices)
.collect(),
Self::T5(ref model) => model
.generate_from_ids_and_past(input_ids, attention_mask, None)
.into_iter()
.map(|output| output.indices)
.collect(),
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/t5/t5_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ impl T5ModelResources {
"sentence-t5-base/model",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
);
/// GODEL v1.1 base (seq2seq) model weights (`rust_model.ot`).
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
"godel-v1-1-base/model",
"https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/rust_model.ot",
);
}

impl T5ConfigResources {
Expand All @@ -79,6 +84,11 @@ impl T5ConfigResources {
"sentence-t5-base/config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
);
/// GODEL v1.1 base (seq2seq) model configuration (`config.json`).
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
"godel-v1-1-base/config",
"https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json",
);
}

impl T5VocabResources {
Expand All @@ -97,6 +107,11 @@ impl T5VocabResources {
"sentence-t5-base/spiece",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
);
/// GODEL v1.1 base (seq2seq) vocabulary file (`spiece.model`).
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
"godel-v1-1-base/spiece",
"https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/spiece.model",
);
}

const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
Expand Down

0 comments on commit 7e00a22

Please sign in to comment.