someone a lot smarter than me will have to figure out negative prompting
also now doing image generation sequentially because adding clip to inference is harder than adding it to training
korakoe committed Oct 10, 2023
1 parent 2070a69 commit 701fd85
Showing 3 changed files with 92 additions and 43 deletions.
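
The sequential-generation change described in the commit message amounts to calling generate once per prompt instead of once for the whole validation batch. A minimal sketch of the idea, assuming a MaskGit instance and the generate signature shown in the trainer diff below (maskgit, validation_prompts, cond_image, cond_scale, temperature, and timesteps are placeholder names, not a verbatim excerpt):

images = []
for prompt in validation_prompts:
    # one prompt per call: a single-element tuple keeps the batch dimension
    image = maskgit.generate(
        (prompt,),
        cond_images=cond_image,
        cond_scale=cond_scale,
        temperature=temperature,
        timesteps=timesteps,
    )
    images.append(image)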
91 changes: 66 additions & 25 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -13,7 +13,7 @@
from einops import rearrange, repeat
from torch import einsum, isnan, nn
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer

from .attn import ein_attn, sdp_attn
from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
@@ -417,6 +417,8 @@ def __init__(
self,
image_size,
transformer: MaskGitTransformer,
clip: CLIPTextModel = None,
clip_tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast] = None,
accelerator: Optional[Accelerator] = None,
noise_schedule: Callable = cosine_schedule,
token_critic: Optional[TokenCritic] = None,
@@ -446,6 +448,9 @@ def __init__(
self.resize_image_for_cond_image = exists(cond_image_size)
self.cond_drop_prob = cond_drop_prob

self.clip = clip
self.clip_tokenizer = clip_tokenizer

self.transformer = transformer
self.self_cond = transformer.self_cond
if not self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens:
@@ -524,34 +529,66 @@ def generate(

cond_ids = None

text_embeds = self.transformer.encode_text(texts)

demask_fn = self.transformer.forward_with_cond_scale

# whether to use token critic for scores

use_token_critic = exists(self.token_critic) and not force_not_use_token_critic

# whether to use token critic for scores
if use_token_critic:
token_critic_fn = self.token_critic.forward_with_cond_scale

# negative prompting, as in paper
if self.clip is not None and self.clip_tokenizer is not None:
clip_model = self.clip
clip_tokenizer = self.clip_tokenizer
print(texts)
inputs = [token[1:-1] for token in clip_tokenizer(texts, truncation=True).input_ids]

inputs = torch.tensor(inputs, device=self.accelerator.device)
max_embeddings_multiples = (inputs.shape[1] - 2) // (75 - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = inputs[:, i * (75 - 2) : (i + 1) * (75 - 2) + 2].clone()

# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = inputs[0, 0]
text_input_chunk[:, -1] = inputs[0, -1]
text_embedding = clip_model(text_input_chunk)[0]

if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]

text_embeddings.append(text_embedding)
text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device)
else:
text_embeds = clip_model(inputs)[0].to(self.accelerator.device)
else:
text_embeds = self.transformer.encode_text(texts)

neg_text_embeds = None
if exists(negative_texts):
assert len(texts) == len(negative_texts)
demask_fn = self.transformer.forward_with_cond_scale

# negative prompting, as in paper

neg_text_embeds = None
if exists(negative_texts):
assert len(texts) == len(negative_texts)

neg_text_embeds = self.transformer.encode_text(negative_texts)
demask_fn = partial(
self.transformer.forward_with_neg_prompt,
neg_text_embeds=neg_text_embeds,
)

if use_token_critic:
token_critic_fn = partial(
self.token_critic.forward_with_neg_prompt,
neg_text_embeds=neg_text_embeds,
)
if use_token_critic:
token_critic_fn = partial(
self.token_critic.forward_with_neg_prompt,
neg_text_embeds=neg_text_embeds,
)

if self.resize_image_for_cond_image:
if cond_images is None:
@@ -576,14 +613,18 @@ def generate(

ids = ids.scatter(1, masked_indices, self.mask_id)

logits, embed = demask_fn(
ids,
text_embeds=text_embeds,
self_cond_embed=self_cond_embed,
conditioning_token_ids=cond_ids,
cond_scale=cond_scale,
return_embed=True,
)
if self.clip is None:
logits, embed = demask_fn(
ids,
text_embeds=text_embeds,
self_cond_embed=self_cond_embed,
conditioning_token_ids=cond_ids,
cond_scale=cond_scale,
return_embed=True,
)
else:
embed = text_embeds
logits = text_embeds

self_cond_embed = embed if self.self_cond else None

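For reference, the chunking arithmetic in generate above works around CLIP's 77-token text context: 75 content tokens plus a start and an end token. Prompts longer than that are split into 75-token chunks, each chunk is re-wrapped with the start/end tokens, encoded separately, and the embeddings are concatenated along the sequence dimension. A standalone sketch of the same pattern using the Hugging Face CLIP classes; it drops the start/end embeddings from every chunk for brevity, whereas the diff keeps the first chunk's start token and the last chunk's end token, and it is not a verbatim excerpt of the repository:

import torch
from transformers import CLIPTextModel, CLIPTokenizerFast

clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

def encode_long_prompt(text: str, chunk_len: int = 75) -> torch.Tensor:
    # tokenize without truncation, then strip the start/end tokens the tokenizer adds
    ids = clip_tokenizer(text, truncation=False).input_ids[1:-1]
    bos, eos = clip_tokenizer.bos_token_id, clip_tokenizer.eos_token_id
    chunk_embeds = []
    for i in range(0, max(len(ids), 1), chunk_len):
        # re-wrap each chunk with start/end tokens so it is a valid CLIP input
        chunk = torch.tensor([[bos] + ids[i : i + chunk_len] + [eos]])
        with torch.no_grad():
            emb = clip_model(chunk).last_hidden_state  # (1, seq_len, hidden)
        chunk_embeds.append(emb[:, 1:-1])  # drop start/end embeddings before concatenating
    return torch.cat(chunk_embeds, dim=1)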
20 changes: 12 additions & 8 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -149,14 +149,17 @@ def save_validation_images(
self.accelerator.print(
f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}"
)

images = self.model.generate(
validation_prompts,
cond_images=cond_image,
cond_scale=cond_scale,
temperature=temperature,
timesteps=timesteps,
).to(self.accelerator.device)
images = []
for text in validation_prompts:
images.append(
self.model.generate(
(text,),
cond_images=cond_image,
cond_scale=cond_scale,
temperature=temperature,
timesteps=timesteps,
).to(self.accelerator.device)
)

save_dir = self.results_dir.joinpath("MaskGit")
save_dir.mkdir(exist_ok=True, parents=True)
@@ -189,6 +192,7 @@ def train(self):
input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
)
else:
print(text)
clip_model, clip_tokenizer = self.clip_model
inputs = [token[1:-1] for token in clip_tokenizer(text, truncation=True).input_ids]

24 changes: 14 additions & 10 deletions train_muse_maskgit.py
@@ -761,10 +761,24 @@ def main():
use_clip=args.use_clip,
)

if args.use_clip:
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to(
accelerator.device
)
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path)

clip = (model, tokenizer)
else:
model = None
tokenizer = None
clip = None

# (2) pass your trained VAE and the base transformer to MaskGit
maskgit = MaskGit(
vae=vae, # vqgan vae
transformer=transformer, # transformer
clip=model,
clip_tokenizer=tokenizer,
accelerator=accelerator, # accelerator
image_size=args.image_size, # image size
cond_drop_prob=args.cond_drop_prob, # conditional dropout, for classifier free guidance
@@ -999,16 +1013,6 @@ def main():
args.batch_size,
)

if args.use_clip:
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to(
accelerator.device
)
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path)

clip = (model, tokenizer)
else:
clip = None

# Create the trainer
accelerator.wait_for_everyone()
trainer = MaskGitTrainer(
