
Commit

idk if this works, but here's a small clip implementation
korakoe committed Oct 9, 2023
1 parent a1ca68d commit 1c129e7
Showing 5 changed files with 77 additions and 28 deletions.
24 changes: 12 additions & 12 deletions muse_maskgit_pytorch/dataset.py
@@ -180,14 +180,14 @@ def __getitem__(self, index):

         if self.using_taming:
             if self.embeds:
-                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
             else:
-                return self.transform(image), input_ids[0], attn_mask[0], []
+                return self.transform(image), input_ids[0], attn_mask[0], [], text
         else:
             if self.embeds:
-                return self.transform(image), input_ids[0], attn_mask[0], embed
+                return self.transform(image), input_ids[0], attn_mask[0], embed, text
             else:
-                return self.transform(image), input_ids[0], attn_mask[0], []
+                return self.transform(image), input_ids[0], attn_mask[0], [], text


 class URLTextDataset(ImageDataset):
@@ -255,14 +255,14 @@ def __getitem__(self, index):

         if self.using_taming:
             if self.embeds:
-                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
             else:
-                return self.transform(image), input_ids[0], attn_mask[0], []
+                return self.transform(image), input_ids[0], attn_mask[0], [], text
         else:
             if self.embeds:
-                return self.transform(image), input_ids[0], attn_mask[0], embed
+                return self.transform(image), input_ids[0], attn_mask[0], embed, text
             else:
-                return self.transform(image), input_ids[0], attn_mask[0], []
+                return self.transform(image), input_ids[0], attn_mask[0], [], text


 class LocalTextImageDataset(Dataset):
@@ -350,14 +350,14 @@ def __getitem__(self, index):
         attn_mask = encoded.attention_mask
         if self.using_taming:
             if self.embeds:
-                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
             else:
-                return self.transform(image), input_ids[0], attn_mask[0], []
+                return self.transform(image), input_ids[0], attn_mask[0], [], text
         else:
             if self.embeds:
-                return self.transform(image), input_ids[0], attn_mask[0], embed
+                return self.transform(image), input_ids[0], attn_mask[0], embed, text
             else:
-                return self.transform(image), input_ids[0], attn_mask[0], []
+                return self.transform(image), input_ids[0], attn_mask[0], [], text


 def get_directory_size(path):
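All three __getitem__ variants now return the raw caption string as a fifth element, alongside the image tensor, token ids, attention mask, and the (possibly empty) precomputed embed, so any code that unpacks these tuples has to accept the extra field. A minimal, self-contained sketch of the tuple shape a consumer now sees (FakeCaptionDataset is a hypothetical stand-in, not a class from the repo; sizes are arbitrary):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class FakeCaptionDataset(Dataset):
        """Stand-in mimicking the 5-tuple the real datasets now return."""
        def __len__(self):
            return 8

        def __getitem__(self, index):
            image = torch.zeros(3, 256, 256)
            input_ids = torch.zeros(128, dtype=torch.long)
            attn_mask = torch.ones(128, dtype=torch.long)
            return image, input_ids, attn_mask, [], f"caption {index}"

    for imgs, input_ids, attn_mask, embeds, text in DataLoader(FakeCaptionDataset(), batch_size=4):
        # default collation turns the per-sample caption strings into a list of captions
        print(imgs.shape, input_ids.shape, len(text))
        break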
27 changes: 17 additions & 10 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -169,6 +169,7 @@ def __init__(
         self_cond: bool = False,
         add_mask_id: bool = False,
         cache_path: PathLike = None,
+        use_clip=False,
         **kwargs,
     ):
         super().__init__()
@@ -183,29 +184,35 @@ def __init__(
         self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
         self.norm = LayerNorm(dim)

+        self.use_clip = use_clip
+
         self.dim_out = default(dim_out, num_tokens)
         self.to_logits = nn.Linear(dim, self.dim_out, bias=False)

         # text conditioning
-        t5, tokenizer = get_model_and_tokenizer(t5_name, cache_path)
-        self.t5: T5EncoderModel = t5
-        self.tokenizer: T5Tokenizer = tokenizer
+        if not use_clip:
+            t5, tokenizer = get_model_and_tokenizer(t5_name, cache_path)
+            self.t5: T5EncoderModel = t5
+            self.tokenizer: T5Tokenizer = tokenizer

-        self.t5.eval()
+            self.t5.eval()

-        text_embed_dim = get_encoded_dim(t5_name)
+            text_embed_dim = get_encoded_dim(t5_name)

-        self.text_embed_proj = (
-            nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity()
-        )
+            self.text_embed_proj = (
+                nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity()
+            )

         # optional self conditioning
         self.self_cond = self_cond
         self.self_cond_to_init_embed = FeedForward(dim)

     def encode_text(self, *args, **kwargs):
-        kwargs.update(tokenizer=self.tokenizer, t5=self.t5)
-        return t5_encode_text(*args, **kwargs)
+        if not self.use_clip:
+            kwargs.update(tokenizer=self.tokenizer, t5=self.t5)
+            return t5_encode_text(*args, **kwargs)
+        else:
+            print("Using clip instead, this function shouldn't be accessed")

     def forward_with_cond_scale(self, *args, cond_scale=3.0, return_embed=False, **kwargs):
         if cond_scale == 1:
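With use_clip=True the transformer no longer owns a text encoder: the T5 weights, tokenizer, and text_embed_proj are never created, and encode_text degrades to a print warning, so text embeddings have to be produced outside the model and passed in as text_embeds. A hedged sketch of caller-side routing that mirrors this split (get_text_embeds is a hypothetical helper, not part of the repo; the CLIP branch follows what the trainer below does):

    import open_clip
    import torch

    def get_text_embeds(transformer, captions, clip_model=None, device="cpu"):
        """Route caption encoding to whichever encoder the transformer was built for."""
        if not transformer.use_clip:
            # T5 path: the transformer still owns its tokenizer and encoder
            return transformer.encode_text(captions)
        # CLIP path: the encoder lives outside the transformer
        assert clip_model is not None, "use_clip=True means embeddings come from CLIP"
        tokens = open_clip.tokenize(captions).to(device)
        with torch.no_grad():
            return clip_model.encode_text(tokens)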
27 changes: 21 additions & 6 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -6,7 +6,6 @@
 from diffusers.optimization import SchedulerType
 from ema_pytorch import EMA
 from omegaconf import OmegaConf
-from PIL import Image
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchvision.utils import save_image
@@ -24,6 +23,9 @@
     xm = None
     met = None

+import open_clip
+import torchvision.transforms as transforms
+from PIL import Image
 from tqdm import tqdm


@@ -59,6 +61,7 @@ def __init__(
         validation_image_scale: float = 1.0,
         only_save_last_checkpoint=False,
         args=None,
+        clip=None,
     ):
         super().__init__(
             dataloader=dataloader,
@@ -96,6 +99,9 @@ def __init__(
         self.optim: Optimizer = optimizer
         self.lr_scheduler: SchedulerType = scheduler

+        self.use_clip = True if clip is not None else False
+        self.clip_model = clip
+
         self.use_ema = use_ema
         self.validation_prompts: List[str] = validation_prompts
         if use_ema:
@@ -154,15 +160,24 @@ def train(self):

         # logs
         for epoch in range(self.current_step // len(self.dl), self.num_epochs):
-            for imgs, input_ids, attn_mask, text_embeds in iter(self.dl):
+            for imgs, input_ids, attn_mask, text_embeds, text in iter(self.dl):
                 train_loss = 0.0
                 steps = int(self.steps.item())

-                if not text_embeds:
+                if not self.use_clip:
+                    if not text_embeds:
+                        with torch.no_grad():
+                            text_embeds = t5_encode_text_from_encoded(
+                                input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
+                            )
+                else:
+                    img_for_embed = transforms.ToPILImage(imgs)
+
+                    model, _, preprocess = self.clip_model
+                    text = open_clip.tokenize(text)
+
                     with torch.no_grad():
-                        text_embeds = t5_encode_text_from_encoded(
-                            input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
-                        )
+                        text_embeds = model.encode_text(text)

                 with self.accelerator.accumulate(self.model), self.accelerator.autocast():
                     loss = self.model(imgs, text_embeds=text_embeds)
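In the CLIP branch the trainer unpacks the (model, preprocess_train, preprocess_val) tuple from open_clip.create_model_and_transforms, tokenizes the raw captions, and encodes them with model.encode_text. Two details look fragile as written: transforms.ToPILImage(imgs) constructs a transform object rather than converting the batch (and img_for_embed is never used afterwards), and the token tensor is not moved to the model's device. A minimal sketch of the text-embedding path with those points addressed (the helper name and device handling are illustrative, not the repo's code):

    import open_clip
    import torch

    def clip_text_embeds(clip_bundle, captions, device):
        """Encode a batch of caption strings with an open_clip model.

        clip_bundle is the (model, preprocess_train, preprocess_val) tuple
        returned by open_clip.create_model_and_transforms.
        """
        model, _, _ = clip_bundle
        tokens = open_clip.tokenize(captions).to(device)  # (batch, 77) token ids
        with torch.no_grad():
            return model.encode_text(tokens)              # (batch, embed_dim) pooled vectors

Note that the T5 path presumably yields per-token embeddings while CLIP's encode_text returns one pooled vector per caption, so the downstream MaskGit forward may need to accommodate both shapes.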
1 change: 1 addition & 0 deletions setup.py
@@ -50,6 +50,7 @@
         "xformers>=0.0.20",
         "wandb",
         "bz2file",
+        "open_clip_torch",
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
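The dependency is published on PyPI as open_clip_torch but imported as open_clip, which is how the trainer and the training script reference it:

    # pip install open_clip_torch
    import open_clip

    # list_pretrained() returns the (architecture, tag) pairs that ship with the library
    print(open_clip.list_pretrained()[:5])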
26 changes: 26 additions & 0 deletions train_muse_maskgit.py
@@ -9,6 +9,7 @@
 import bz2file as bz2
 import datasets
 import diffusers
+import open_clip
 import torch
 import transformers
 from accelerate.utils import ProjectConfiguration
@@ -124,6 +125,12 @@ def decompress_pickle(file):
 parser.add_argument("--heads", type=int, default=8, help="Attention heads")
 parser.add_argument("--ff_mult", type=int, default=4, help="Feed forward expansion factor")
 parser.add_argument("--t5_name", type=str, default="t5-small", help="Name of your t5 model")
+parser.add_argument(
+    "--use_metaclip",
+    action="store_true",
+    default=False,
+    help="whether to use MetaClip instead of a T5",
+)
 parser.add_argument("--cond_image_size", type=int, default=None, help="Conditional image size.")
 parser.add_argument(
     "--validation_prompt",
@@ -480,6 +487,7 @@ class Arguments:
     heads: int = 8
     ff_mult: int = 4
     t5_name: str = "t5-small"
+    use_metaclip: bool = False
     mixed_precision: str = "no"
     cond_image_size: Optional[int] = None
     validation_prompt: str = "A photo of a dog"
@@ -750,6 +758,7 @@ def main():
         cache_path=args.cache_path,
         flash=flash,
         xformers=xformers,
+        use_clip=args.use_metaclip,
     )

     # (2) pass your trained VAE and the base transformer to MaskGit
@@ -987,6 +996,22 @@ def main():
             args.batch_size,
         )

+    if args.use_metaclip:
+        if args.mixed_precision == "no":
+            clip_precision = "fp32"
+        else:
+            clip_precision = args.mixed_precision
+
+        clip = open_clip.create_model_and_transforms(
+            "ViT-B-32-quickgelu",
+            pretrained="metaclip/b32_400m.pt",
+            cache_dir=args.cache_path,
+            precision=clip_precision,
+            device=accelerator.device,
+        )
+    else:
+        clip = None
+
     # Create the trainer
     accelerator.wait_for_everyone()
     trainer = MaskGitTrainer(
@@ -1017,6 +1042,7 @@ def main():
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         num_epochs=args.num_epochs,
         args=args,
+        clip=clip,
     )

     # Prepare the trainer for distributed training
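open_clip.create_model_and_transforms returns a (model, preprocess_train, preprocess_val) tuple, and the script passes the whole tuple to MaskGitTrainer, which unpacks it as model, _, preprocess = self.clip_model. A small sketch of building the encoder the same way and sanity-checking it on a prompt (the "metaclip/b32_400m.pt" checkpoint path is taken from the diff and assumed to exist locally; newer open_clip releases may also ship MetaCLIP weights under named pretrained tags):

    import open_clip
    import torch

    # Build the same MetaCLIP text encoder the training script creates.
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
        "ViT-B-32-quickgelu",
        pretrained="metaclip/b32_400m.pt",  # local checkpoint path used in the diff
        precision="fp32",
        device="cpu",
    )

    tokens = open_clip.tokenize(["A photo of a dog"])
    with torch.no_grad():
        text_embeds = model.encode_text(tokens)
    print(text_embeds.shape)  # ViT-B/32 text embeddings are 512-dimensional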
