diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 0d44594..6566370 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -13,7 +13,7 @@
 from einops import rearrange, repeat
 from torch import einsum, isnan, nn
 from tqdm.auto import tqdm
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer
 
 from .attn import ein_attn, sdp_attn
 from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
@@ -417,6 +417,8 @@ def __init__(
         self,
         image_size,
         transformer: MaskGitTransformer,
+        clip: CLIPTextModel = None,
+        clip_tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast] = None,
         accelerator: Optional[Accelerator] = None,
         noise_schedule: Callable = cosine_schedule,
         token_critic: Optional[TokenCritic] = None,
@@ -446,6 +448,9 @@ def __init__(
         self.resize_image_for_cond_image = exists(cond_image_size)
         self.cond_drop_prob = cond_drop_prob
 
+        self.clip = clip
+        self.clip_tokenizer = clip_tokenizer
+
         self.transformer = transformer
         self.self_cond = transformer.self_cond
         if not self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens:
@@ -524,22 +529,54 @@ def generate(
             cond_ids = None
 
-        text_embeds = self.transformer.encode_text(texts)
-
-        demask_fn = self.transformer.forward_with_cond_scale
-
-        # whether to use token critic for scores
-        use_token_critic = exists(self.token_critic) and not force_not_use_token_critic
-
+        # whether to use token critic for scores
         if use_token_critic:
             token_critic_fn = self.token_critic.forward_with_cond_scale
 
-        # negative prompting, as in paper
+        if self.clip is not None and self.clip_tokenizer is not None:
+            clip_model = self.clip
+            clip_tokenizer = self.clip_tokenizer
+            print(texts)
+            inputs = [token[1:-1] for token in clip_tokenizer(texts, truncation=True).input_ids]
+
+            inputs = torch.tensor(inputs, device=self.accelerator.device)
+            max_embeddings_multiples = (inputs.shape[1] - 2) // (75 - 2)
+            if max_embeddings_multiples > 1:
+                text_embeddings = []
+                for i in range(max_embeddings_multiples):
+                    # extract the i-th chunk
+                    text_input_chunk = inputs[:, i * (75 - 2) : (i + 1) * (75 - 2) + 2].clone()
+
+                    # cover the head and the tail by the starting and the ending tokens
+                    text_input_chunk[:, 0] = inputs[0, 0]
+                    text_input_chunk[:, -1] = inputs[0, -1]
+                    text_embedding = clip_model(text_input_chunk)[0]
+
+                    if i == 0:
+                        # discard the ending token
+                        text_embedding = text_embedding[:, :-1]
+                    elif i == max_embeddings_multiples - 1:
+                        # discard the starting token
+                        text_embedding = text_embedding[:, 1:]
+                    else:
+                        # discard both starting and ending tokens
+                        text_embedding = text_embedding[:, 1:-1]
+
+                    text_embeddings.append(text_embedding)
+                text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device)
+            else:
+                text_embeds = clip_model(inputs)[0].to(self.accelerator.device)
+        else:
+            text_embeds = self.transformer.encode_text(texts)
 
-        neg_text_embeds = None
-        if exists(negative_texts):
-            assert len(texts) == len(negative_texts)
+        demask_fn = self.transformer.forward_with_cond_scale
+
+        # negative prompting, as in paper
+
+        neg_text_embeds = None
+        if exists(negative_texts):
+            assert len(texts) == len(negative_texts)
 
             neg_text_embeds = self.transformer.encode_text(negative_texts)
 
             demask_fn = partial(
@@ -547,11 +584,11 @@ def generate(
                 neg_text_embeds=neg_text_embeds,
             )
 
-            if use_token_critic:
-                token_critic_fn = partial(
-                    self.token_critic.forward_with_neg_prompt,
-                    neg_text_embeds=neg_text_embeds,
-                )
+        if use_token_critic:
+            token_critic_fn = partial(
+                self.token_critic.forward_with_neg_prompt,
+                neg_text_embeds=neg_text_embeds,
+            )
 
         if self.resize_image_for_cond_image:
             if cond_images is None:
@@ -576,14 +613,18 @@ def generate(
             ids = ids.scatter(1, masked_indices, self.mask_id)
 
-            logits, embed = demask_fn(
-                ids,
-                text_embeds=text_embeds,
-                self_cond_embed=self_cond_embed,
-                conditioning_token_ids=cond_ids,
-                cond_scale=cond_scale,
-                return_embed=True,
-            )
+            if self.clip is None:
+                logits, embed = demask_fn(
+                    ids,
+                    text_embeds=text_embeds,
+                    self_cond_embed=self_cond_embed,
+                    conditioning_token_ids=cond_ids,
+                    cond_scale=cond_scale,
+                    return_embed=True,
+                )
+            else:
+                embed = text_embeds
+                logits = text_embeds
 
             self_cond_embed = embed if self.self_cond else None
diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index 0389952..fb29613 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -149,14 +149,17 @@ def save_validation_images(
         self.accelerator.print(
             f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}"
         )
-
-        images = self.model.generate(
-            validation_prompts,
-            cond_images=cond_image,
-            cond_scale=cond_scale,
-            temperature=temperature,
-            timesteps=timesteps,
-        ).to(self.accelerator.device)
+        images = []
+        for text in validation_prompts:
+            images.append(
+                self.model.generate(
+                    (text,),
+                    cond_images=cond_image,
+                    cond_scale=cond_scale,
+                    temperature=temperature,
+                    timesteps=timesteps,
+                ).to(self.accelerator.device)
+            )
 
         save_dir = self.results_dir.joinpath("MaskGit")
         save_dir.mkdir(exist_ok=True, parents=True)
@@ -189,6 +192,7 @@ def train(self):
                         input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
                     )
                 else:
+                    print(text)
                     clip_model, clip_tokenizer = self.clip_model
                     inputs = [token[1:-1] for token in clip_tokenizer(text, truncation=True).input_ids]
 
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 337b258..d0ed167 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -761,10 +761,24 @@ def main():
         use_clip=args.use_clip,
     )
 
+    if args.use_clip:
+        model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to(
+            accelerator.device
+        )
+        tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path)
+
+        clip = (model, tokenizer)
+    else:
+        model = None
+        tokenizer = None
+        clip = None
+
     # (2) pass your trained VAE and the base transformer to MaskGit
     maskgit = MaskGit(
         vae=vae,  # vqgan vae
         transformer=transformer,  # transformer
+        clip=model,
+        clip_tokenizer=tokenizer,
         accelerator=accelerator,  # accelerator
         image_size=args.image_size,  # image size
         cond_drop_prob=args.cond_drop_prob,  # conditional dropout, for classifier free guidance
@@ -999,16 +1013,6 @@ def main():
             args.batch_size,
         )
 
-    if args.use_clip:
-        model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to(
-            accelerator.device
-        )
-        tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path)
-
-        clip = (model, tokenizer)
-    else:
-        clip = None
-
     # Create the trainer
     accelerator.wait_for_everyone()
     trainer = MaskGitTrainer(
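
Note on the generate() hunk above: CLIP's text encoder only accepts 77 positions (BOS + 75 content tokens + EOS), so the patch splits longer prompts into 75-token windows, re-wraps each window with the start/end token ids, encodes each window separately, and concatenates the per-chunk hidden states after stripping the duplicated BOS/EOS embeddings; naive truncation would silently drop everything past the first 77 tokens. Below is a minimal standalone sketch of the same chunking scheme, assuming the openai/clip-vit-base-patch32 checkpoint that train_muse_maskgit.py loads; the helper name encode_long_prompt is illustrative and not part of the patch.

import torch
from transformers import CLIPTextModel, CLIPTokenizerFast

tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval()

@torch.no_grad()
def encode_long_prompt(texts, chunk_len=75):
    # Tokenize without truncation, then strip BOS/EOS; they are re-added per chunk.
    ids = tokenizer(texts, padding=True, return_tensors="pt").input_ids
    body = ids[:, 1:-1]
    bos = torch.full((body.size(0), 1), tokenizer.bos_token_id, dtype=torch.long)
    eos = torch.full((body.size(0), 1), tokenizer.eos_token_id, dtype=torch.long)
    embeddings = []
    for chunk in body.split(chunk_len, dim=1):
        # Re-wrap each <=75-token window so it looks like a complete CLIP input.
        wrapped = torch.cat([bos, chunk, eos], dim=1)
        hidden = model(wrapped).last_hidden_state
        # Drop the per-chunk BOS/EOS states before concatenating along the sequence.
        embeddings.append(hidden[:, 1:-1])
    return torch.cat(embeddings, dim=1)  # (batch, total_tokens, hidden_dim)

text_embeds = encode_long_prompt(["a highly detailed matte painting of a castle ..."])

Unlike the patch, which covers each chunk's head and tail in place with the first prompt's start/end ids, this sketch wraps every chunk explicitly; both produce one embedding per content token across an arbitrarily long prompt.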