Early precomputation implementation
I still have to implement embed loading, but it's a very slow, working POC
korakoe committed Sep 29, 2023
1 parent bb96aad commit b0986be
Showing 3 changed files with 122 additions and 12 deletions.
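For readers skimming the diff: the commit caches T5 text embeddings ahead of training so the trainer can skip the text encoder on steps where a cached embedding exists. A minimal sketch of that cache-then-reuse pattern, assuming a generic frozen encoder callable (`encode_fn` and the batch layout here are illustrative assumptions, not this repo's exact API):

    import torch

    def precompute_text_embeds(dataloader, encode_fn, device="cpu"):
        # Run the frozen text encoder once and cache every batch's output.
        # Storing on CPU keeps GPU memory free for the image model.
        cache = []
        with torch.no_grad():
            for _imgs, input_ids, attn_mask, _embeds in dataloader:
                cache.append(encode_fn(input_ids, attn_mask).to(device))
        return cache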
50 changes: 44 additions & 6 deletions muse_maskgit_pytorch/dataset.py
@@ -119,6 +119,7 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
+        embeds=[],
     ):
         super().__init__(
             dataset,
@@ -132,21 +133,28 @@
         )
         self.caption_column: str = caption_column
         self.tokenizer: T5Tokenizer = tokenizer
+        self.embeds: list = embeds

     def __getitem__(self, index):
         try:
             image = self.dataset[index][self.image_column]
             descriptions = self.dataset[index][self.caption_column]
+            if self.embeds:
+                embed = self.embeds[index]
         except PIL.UnidentifiedImageError:
             print("Error reading image, most likely corrupt, skipping...")
             image_found = False
+            embed = None
             current_index = 1
             while not image_found:
                 try:
                     image = self.dataset[index + current_index][self.image_column]
                     descriptions = self.dataset[index + current_index][self.caption_column]
+                    if self.embeds:
+                        embed = self.embeds[index + current_index]
                     image_found = True
                 except PIL.UnidentifiedImageError:
+                    embed = None
                     current_index += 1

         if self.caption_column is None or descriptions is None:
@@ -171,9 +179,15 @@ def __getitem__(self, index):
         attn_mask = encoded.attention_mask

         if self.using_taming:
-            return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []
         else:
-            return self.transform(image), input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image), input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []


 class URLTextDataset(ImageDataset):
@@ -187,6 +201,7 @@ def __init__(
         flip=True,
         center_crop=True,
         using_taming=True,
+        embeds=[],
     ):
         super().__init__(
             dataset,
@@ -198,16 +213,21 @@
         )
         self.caption_column: str = caption_column
         self.tokenizer: T5Tokenizer = tokenizer
+        self.embeds: list = embeds

     def __getitem__(self, index):
         try:
             image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content))
+            if self.embeds:
+                embed = self.embeds[index]
         except ConnectionError:
             try:
                 print("Image request failure, attempting next image")
                 index += 1

                 image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content))
+                if self.embeds:
+                    embed = self.embeds[index]
             except ConnectionError:
                 raise ConnectionError("Unable to request image from the Dataset")

@@ -232,10 +252,17 @@ def __getitem__(self, index):

         input_ids = encoded.input_ids
         attn_mask = encoded.attention_mask
+
         if self.using_taming:
-            return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []
         else:
-            return self.transform(image), input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image), input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []


 class LocalTextImageDataset(Dataset):
@@ -249,10 +276,12 @@ def __init__(
         using_taming=False,
         random_crop=False,
         alpha_channel=False,
+        embeds=[],
     ):
         super().__init__()
         self.tokenizer = tokenizer
         self.using_taming = using_taming
+        self.embeds: list = embeds

         print("Building dataset...")

@@ -305,6 +334,9 @@ def __getitem__(self, index):
         else:
             text = Path(descriptions).read_text(encoding="utf-8").split("\n")

+        if self.embeds:
+            embed = self.embeds[index]
+
         # max length from the paper
         encoded = self.tokenizer.batch_encode_plus(
             [str(text)],
@@ -317,9 +349,15 @@ def __getitem__(self, index):
         input_ids = encoded.input_ids
         attn_mask = encoded.attention_mask
         if self.using_taming:
-            return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []
         else:
-            return self.transform(image), input_ids[0], attn_mask[0]
+            if self.embeds:
+                return self.transform(image), input_ids[0], attn_mask[0], embed
+            else:
+                return self.transform(image), input_ids[0], attn_mask[0], []


 def get_directory_size(path):
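All three dataset classes above now return a 4-tuple, using an empty list as the sentinel for "no precomputed embedding". (The `embeds=[]` default argument is Python's shared-mutable-default pattern, but since none of the classes mutate it, it behaves as a plain falsy sentinel here.) A hedged sketch of what a consumer sees, assuming a dataset built with the signatures above:

    # `ds` is assumed to be e.g. ImageTextDataset(..., embeds=precomputed)
    image, input_ids, attn_mask, text_embed = ds[0]

    if isinstance(text_embed, list) and len(text_embed) == 0:
        # Empty-list sentinel: no cached embedding, so the caption
        # must be encoded by the trainer at step time.
        ...
    else:
        # A tensor cached during the precompute pass.
        print(text_embed.shape)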
11 changes: 6 additions & 5 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -154,14 +154,15 @@ def train(self):

         # logs
         for epoch in range(self.current_step // len(self.dl), self.num_epochs):
-            for imgs, input_ids, attn_mask in iter(self.dl):
+            for imgs, input_ids, attn_mask, text_embeds in iter(self.dl):
                 train_loss = 0.0
                 steps = int(self.steps.item())

-                with torch.no_grad():
-                    text_embeds = t5_encode_text_from_encoded(
-                        input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
-                    )
+                if not text_embeds:
+                    with torch.no_grad():
+                        text_embeds = t5_encode_text_from_encoded(
+                            input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
+                        )

                 with self.accelerator.accumulate(self.model), self.accelerator.autocast():
                     loss = self.model(imgs, text_embeds=text_embeds)
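The trainer now calls T5 only when the batch carries no cached embeddings. One subtlety worth flagging: `not text_embeds` is well-defined for the empty-list sentinel, but raises a RuntimeError once the dataloader collates a multi-element tensor, so a defensive variant of the fallback might look like this (a sketch reusing the repo's `t5_encode_text_from_encoded` helper; the batch layout is an assumption):

    import torch
    from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded

    def resolve_text_embeds(text_embeds, input_ids, attn_mask, t5, device):
        # len() is safe for both an empty-list sentinel and a collated
        # tensor, where `not text_embeds` would raise on multiple elements.
        if len(text_embeds) > 0:
            return text_embeds
        with torch.no_grad():
            return t5_encode_text_from_encoded(input_ids, attn_mask, t5, device)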
73 changes: 72 additions & 1 deletion train_muse_maskgit.py
@@ -7,15 +7,18 @@
 import accelerate
 import datasets
 import diffusers
+import torch
 import transformers
-import wandb
 from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from diffusers.optimization import SchedulerType, get_scheduler
 from omegaconf import OmegaConf
 from rich import inspect
 from torch.optim import Optimizer
 from tqdm import tqdm

+import wandb
+from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded
 from muse_maskgit_pytorch.utils import (
     get_latest_checkpoints,
 )
@@ -424,6 +427,12 @@
     default="flash",
     help="what type of attention to use [ein, flash, xformers] | Default: flash",
 )
+parser.add_argument(
+    "--precompute",
+    action="store_true",
+    default=False,
+    help="whether to precompute text embeds",
+)


 @dataclass
@@ -497,6 +506,7 @@ class Arguments:
     debug: bool = False
     config_path: Optional[str] = None
     attention_type: str = "flash"
+    precompute: bool = False


 def main():
@@ -852,6 +862,67 @@ def main():
         },
     )

+    embeds = []
+    if args.precompute:
+        accelerator.print("Beginning pre-computation of embeddings using T5...")
+        maskgit.transformer.t5.requires_grad_(False)
+        for imgs, input_ids, attn_mask, _ in tqdm(iter(dataloader)):
+            with torch.no_grad():
+                embedding = t5_encode_text_from_encoded(input_ids, attn_mask, maskgit.transformer.t5, "cpu")
+                embeds.append(embedding)
+
+        with accelerator.main_process_first():
+            if args.no_cache and args.train_data_dir:
+                dataset = LocalTextImageDataset(
+                    args.train_data_dir,
+                    args.image_size,
+                    tokenizer=transformer.tokenizer,
+                    center_crop=False if args.no_center_crop else True,
+                    flip=False if args.no_flip else True,
+                    using_taming=False if not args.taming_model_path else True,
+                    random_crop=args.random_crop if args.random_crop else False,
+                    alpha_channel=False if args.channels == 3 else True,
+                    embeds=embeds,
+                )
+            elif args.link:
+                if not args.dataset_name:
+                    raise AssertionError("You can only use links in huggingface datasets")
+
+                dataset = URLTextDataset(
+                    dataset,
+                    args.image_size,
+                    transformer.tokenizer,
+                    image_column=args.image_column,
+                    caption_column=args.caption_column,
+                    center_crop=False if args.no_center_crop else True,
+                    flip=False if args.no_flip else True,
+                    using_taming=False if not args.taming_model_path else True,
+                    embeds=embeds,
+                )
+            else:
+                dataset = ImageTextDataset(
+                    dataset,
+                    args.image_size,
+                    transformer.tokenizer,
+                    image_column=args.image_column,
+                    caption_column=args.caption_column,
+                    center_crop=False if args.no_center_crop else True,
+                    flip=False if args.no_flip else True,
+                    stream=args.streaming,
+                    using_taming=False if not args.taming_model_path else True,
+                    embeds=embeds,
+                )
+
+        accelerator.print("Embeddings pre-computed!")
+
+        # Create the dataloaders
+        dataloader, validation_dataloader = split_dataset_into_dataloaders(
+            dataset,
+            args.valid_frac if not args.streaming else 0,
+            args.seed,
+            args.batch_size,
+        )
+
     # Create the trainer
     accelerator.wait_for_everyone()
     trainer = MaskGitTrainer(
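As the commit message notes, embed loading is still unfinished, and one gap stands out: the precompute loop appends one tensor per batch, while the dataset classes index `self.embeds[index]` per sample. A hedged sketch of flattening the cache into per-sample entries, assuming batch-first tensors of shape (batch, seq_len, dim):

    # Flatten per-batch embeddings so Dataset.__getitem__ can index
    # self.embeds[index] one sample at a time.
    per_sample_embeds = []
    for batch_embedding in embeds:  # each entry: (batch, seq_len, dim)
        per_sample_embeds.extend(batch_embedding.unbind(dim=0))

With the new flag, precomputation stays opt-in; an invocation might look like `python train_muse_maskgit.py --precompute ...` alongside the usual dataset flags.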