Skip to content

Commit

Permalink
forgot to normalise
Browse files Browse the repository at this point in the history
  • Loading branch information
korakoe committed Oct 9, 2023
1 parent 1c129e7 commit 1574763
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,12 @@ def train(self):
input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
)
else:
img_for_embed = transforms.ToPILImage(imgs)

model, _, preprocess = self.clip_model
text = open_clip.tokenize(text)

with torch.no_grad():
text_embeds = model.encode_text(text)
text_embeds /= text_embeds.norm(dim=-1, keepdim=True)

with self.accelerator.accumulate(self.model), self.accelerator.autocast():
loss = self.model(imgs, text_embeds=text_embeds)
Expand Down

0 comments on commit 1574763

Please sign in to comment.