From e31e0c8546c399828031244e973e9f542cde2e38 Mon Sep 17 00:00:00 2001 From: Royal Cities Date: Tue, 4 Feb 2025 14:11:50 -0500 Subject: [PATCH] Update conditioners.py Fix T5-base tokenizer config loading - Use T5Tokenizer directly instead of AutoTokenizer - Add explicit tokenizer parameters - Disable auth checks for public models - Enable legacy mode for older T5 models --- stable_audio_tools/models/conditioners.py | 54 ++++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py index 6f2fe67d..a530c45e 100644 --- a/stable_audio_tools/models/conditioners.py +++ b/stable_audio_tools/models/conditioners.py @@ -275,10 +275,9 @@ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] class T5Conditioner(Conditioner): - T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", - "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", - "google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl"] + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl"] T5_MODEL_DIMS = { "t5-small": 512, @@ -301,14 +300,14 @@ def __init__( self, output_dim: int, t5_model_name: str = "t5-base", - max_length: str = 128, + max_length: int = 128, # Changed from str to int enable_grad: bool = False, project_out: bool = False ): assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) - from transformers import T5EncoderModel, AutoTokenizer + from transformers import T5EncoderModel, T5Tokenizer # Changed to T5Tokenizer self.max_length = max_length self.enable_grad = enable_grad @@ -319,10 +318,26 @@ def __init__( with warnings.catch_warnings(): warnings.simplefilter("ignore") try: - # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) - # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) - self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) - model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + # Explicit tokenizer configuration with legacy support + self.tokenizer = T5Tokenizer.from_pretrained( + t5_model_name, + model_max_length=max_length, + bos_token="", + eos_token="", + unk_token="", + pad_token="", + use_auth_token=False, # Disable auth checks + legacy=True # Handle older T5 models + ) + + # Model initialization with conditional precision + model = T5EncoderModel.from_pretrained(t5_model_name) + model = model.train(enable_grad).requires_grad_(enable_grad) + if enable_grad: + model = model.to(torch.float16) + else: + model = model.to(torch.float32) + finally: logging.disable(previous_level) @@ -331,9 +346,7 @@ def __init__( else: self.__dict__["model"] = model - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - self.model.to(device) self.proj_out.to(device) @@ -346,25 +359,22 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t ) input_ids = encoded["input_ids"].to(device) - attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + attention_mask = encoded["attention_mask"].to(device).bool() self.model.eval() - with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + with torch.cuda.amp.autocast(enabled=self.enable_grad, dtype=torch.float16 if self.enable_grad else torch.float32): embeddings = self.model( - input_ids=input_ids, attention_mask=attention_mask - )["last_hidden_state"] + input_ids=input_ids, + attention_mask=attention_mask + ).last_hidden_state - # Cast embeddings to same type as proj_out, unless proj_out is Identity if not isinstance(self.proj_out, nn.Identity): - proj_out_dtype = next(self.proj_out.parameters()).dtype - embeddings = embeddings.to(proj_out_dtype) - - embeddings = self.proj_out(embeddings) + embeddings = embeddings.to(next(self.proj_out.parameters()).dtype) + embeddings = self.proj_out(embeddings) - embeddings = embeddings * attention_mask.unsqueeze(-1).float() + return embeddings * attention_mask.unsqueeze(-1).float(), attention_mask - return embeddings, attention_mask class PhonemeConditioner(Conditioner): """