Skip to content

Commit

Permalink
cleanup + embed saving + loading
Browse files Browse the repository at this point in the history
will easily produce a massive file
  • Loading branch information
korakoe committed Sep 29, 2023
1 parent b0986be commit 3e94841
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"omegaconf",
"xformers>=0.0.20",
"wandb",
"bz2file",
],
classifiers=[
"Development Status :: 4 - Beta",
Expand Down
34 changes: 33 additions & 1 deletion train_muse_maskgit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import argparse
import logging
import os
import pickle
from dataclasses import dataclass
from typing import Optional, Union

import accelerate
import bz2file as bz2
import datasets
import diffusers
import torch
Expand Down Expand Up @@ -48,6 +50,18 @@
)
from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_optimizer


def compressed_pickle(title, data):
    """Serialize *data* with pickle and store it bz2-compressed at path *title*.

    Parameters:
        title: destination file path for the compressed pickle.
        data: any picklable object (e.g. a list of precomputed embeds).
    """
    payload = pickle.dumps(data)
    # BZ2File compresses transparently on write; the context manager
    # guarantees the stream is flushed and closed.
    with bz2.BZ2File(title, "w") as sink:
        sink.write(payload)


def decompress_pickle(file):
    """Load and return a pickled object from a bz2-compressed file.

    Parameters:
        file: path to a file previously written by ``compressed_pickle``.

    Returns:
        The unpickled object.
    """
    # Use a context manager so the file handle is closed even if
    # pickle.load raises — the original left the BZ2File open (leak).
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # embed files from trusted sources.
    with bz2.BZ2File(file, "rb") as f:
        return pickle.load(f)


# remove some unnecessary errors from transformer shown on the console.
transformers.logging.set_verbosity_error()

Expand Down Expand Up @@ -431,7 +445,13 @@
"--precompute",
action="store_true",
default=False,
help="whether to precompute text embeds",
help="whether to precompute text embeds (only use if we wan to compute, not load)",
)
parser.add_argument(
"--precompute_path",
type=str,
default="",
help="The path to save or load embeds",
)


Expand Down Expand Up @@ -507,6 +527,7 @@ class Arguments:
config_path: Optional[str] = None
attention_type: str = "flash"
precompute: bool = False
precompute_path: str = ""


def main():
Expand Down Expand Up @@ -757,6 +778,11 @@ def main():

accelerator.print(f"Total number of parameters: {format(total_params, ',d')}")

if args.precompute_path and not args.precompute:
embeds = decompress_pickle(args.precompute_path)
else:
embeds = []

# Create the dataset objects
with accelerator.main_process_first():
if args.no_cache and args.train_data_dir:
Expand All @@ -769,6 +795,7 @@ def main():
using_taming=False if not args.taming_model_path else True,
random_crop=args.random_crop if args.random_crop else False,
alpha_channel=False if args.channels == 3 else True,
embeds=embeds,
)
elif args.link:
if not args.dataset_name:
Expand All @@ -783,6 +810,7 @@ def main():
center_crop=False if args.no_center_crop else True,
flip=False if args.no_flip else True,
using_taming=False if not args.taming_model_path else True,
embeds=embeds,
)
else:
dataset = ImageTextDataset(
Expand All @@ -795,6 +823,7 @@ def main():
flip=False if args.no_flip else True,
stream=args.streaming,
using_taming=False if not args.taming_model_path else True,
embeds=embeds,
)

# Create the dataloaders
Expand Down Expand Up @@ -871,6 +900,9 @@ def main():
embedding = t5_encode_text_from_encoded(input_ids, attn_mask, maskgit.transformer.t5, "cpu")
embeds.append(embedding)

if args.precompute_path:
compressed_pickle(args.precompute_path, embeds)

with accelerator.main_process_first():
if args.no_cache and args.train_data_dir:
dataset = LocalTextImageDataset(
Expand Down

0 comments on commit 3e94841

Please sign in to comment.