Merge pull request Sygil-Dev#62 from Sygil-Dev/dev
Merge dev to main.
ZeroCool940711 committed Aug 1, 2023
2 parents 0872074 + 1382c71 commit 7d7c00c
Showing 12 changed files with 166 additions and 49 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/sem-version-release.yml
@@ -0,0 +1,14 @@
name: Bump version
on:
push:
branches:
- master
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- name: Bump version and push tag
uses: hennejg/[email protected]
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -15,13 +15,13 @@ repos:
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.272"
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.278"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
3 changes: 3 additions & 0 deletions README.md
@@ -255,6 +255,9 @@ Training should be done in 4 stages.
accelerate launch train_muse_maskgit.py --dataset_name="cifar10" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="0" --image_column="img" --caption_column="label"
```

## Checkpoints and Pretrained Models
We currently do not have a usable pretrained model for Muse, but we are training one with whatever resources we have available. For more information, check the [Sygil Muse](https://huggingface.co/Sygil/Sygil-Muse) repository on HuggingFace, where we upload the checkpoints from the different tests we have performed and where we will publish the final weights once we have something everyone can use.

## Appreciation
- [Lucidrains](https://github.com/lucidrains/muse-maskgit-pytorch) for the original Muse-Maskgit-Pytorch implementation.
- The [ShoukanLabs](https://github.com/ShoukanLabs) team for contributing so much to improving the code and adding new features.
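As a companion to the checkpoint note added to the README above, here is a hedged sketch of how one of the test checkpoints on the Sygil/Sygil-Muse repository could be fetched with `huggingface_hub`; the filename is a placeholder rather than a documented artifact, and the downloaded path would then be handed to the existing scripts through flags such as `--vae_path` or `--resume_path`.

```python
# Hypothetical sketch only: pull a test checkpoint from the Sygil/Sygil-Muse repo
# mentioned in the README. The filename below is a placeholder for illustration;
# check the repository for the files that are actually published.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="Sygil/Sygil-Muse",
    filename="vae.pt",  # assumed name, not a documented artifact
)

# The local path can then be passed to the existing scripts, e.g. as --vae_path
# for infer_vae.py or --resume_path for the training scripts.
print(ckpt_path)
```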
33 changes: 26 additions & 7 deletions infer_vae.py
@@ -84,7 +84,7 @@
parser.add_argument(
"--logging_dir",
type=str,
default="results/logs",
default=None,
help="Path to log the losses and LR",
)

@@ -112,6 +112,11 @@
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
parser.add_argument(
"--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
)
parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
parser.add_argument(
"--image_size",
type=int,
@@ -224,7 +229,7 @@ class Arguments:
mixed_precision: str = "no"
use_8bit_adam: bool = False
results_dir: str = "results"
logging_dir: str = "results/logs"
logging_dir: Optional[str] = None
resume_path: Optional[str] = None
dataset_name: Optional[str] = None
streaming: bool = False
@@ -288,7 +293,7 @@ def main():
args = parser.parse_args(namespace=Arguments())

project_config = ProjectConfiguration(
project_dir=args.logging_dir,
project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
automatic_checkpoint_naming=True,
)

@@ -345,7 +350,12 @@ def main():
if args.vae_path and not args.use_paintmind:
accelerator.print("Loading Muse VQGanVAE")
vae = VQGanVAE(
dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim
dim=args.dim,
vq_codebook_size=args.vq_codebook_size,
vq_codebook_dim=args.vq_codebook_dim,
channels=args.channels,
layers=args.layers,
discr_layers=args.discr_layers,
).to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

if args.latest_checkpoint:
@@ -413,6 +423,11 @@ def main():
# move vae to device
vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

# Use the parameters() method to get an iterator over all the learnable parameters of the model
total_params = sum(p.numel() for p in vae.parameters())

print(f"Total number of parameters: {format(total_params, ',d')}")

# then you plug the vae and transformer into your MaskGit like so

dataset = ImageDataset(
@@ -422,6 +437,7 @@
center_crop=True if not args.no_center_crop and not args.random_crop else False,
flip=not args.no_flip,
random_crop=args.random_crop if args.random_crop else False,
alpha_channel=False if args.channels == 3 else True,
)

if args.input_image and not args.input_folder:
@@ -430,7 +446,9 @@ def main():
os.makedirs(f"{args.results_dir}/outputs", exist_ok=True)

save_image(
dataset[image_id], f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}"
dataset[image_id],
f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}",
format="PNG",
)

_, ids, _ = vae.encode(
@@ -489,7 +507,8 @@ def main():

# Create horizontal grid with input and output images
grid_image = PIL.Image.new(
"RGB", (input_image.width + output_image.width, input_image.height)
"RGB" if args.channels == 3 else "RGBA",
(input_image.width + output_image.width, input_image.height),
)
grid_image.paste(input_image, (0, 0))
grid_image.paste(output_image, (input_image.width, 0))
@@ -499,7 +518,7 @@ def main():
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png"
grid_image.save(f"{output_dir}/{filename}")
grid_image.save(f"{output_dir}/{filename}", format="PNG")

# Remove input and output images after the grid was made.
os.remove(f"{output_dir}/input.png")
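A minimal sketch of what the new `--channels`, `--layers`, and `--discr_layers` flags configure in infer_vae.py, together with the parameter count the script now prints. It assumes `VQGanVAE` is importable from the package root as in the upstream implementation, and the `dim`/codebook values are placeholders.

```python
# Sketch: build the VAE the way infer_vae.py now does, with an explicit channel and
# layer layout, then report its size. dim/codebook values are placeholders.
from muse_maskgit_pytorch import VQGanVAE  # assumes the package-root export

vae = VQGanVAE(
    dim=128,
    vq_codebook_size=256,
    vq_codebook_dim=256,
    channels=4,       # --channels: 3 for RGB, 4 for RGBA
    layers=4,         # --layers
    discr_layers=4,   # --discr_layers
)

# parameters() yields every learnable tensor; numel() counts its elements.
total_params = sum(p.numel() for p in vae.parameters())
print(f"Total number of parameters: {total_params:,d}")
```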
23 changes: 20 additions & 3 deletions muse_maskgit_pytorch/dataset.py
@@ -45,21 +45,27 @@ def __init__(
stream=False,
using_taming=False,
random_crop=False,
alpha_channel=True,
):
super().__init__()
self.dataset = dataset
self.image_column = image_column
self.stream = stream
transform_list = [
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Resize(image_size),
]

if flip:
transform_list.append(T.RandomHorizontalFlip())
if center_crop and not random_crop:
transform_list.append(T.CenterCrop(image_size))
if random_crop:
transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
if alpha_channel:
transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
else:
transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))

transform_list.append(T.ToTensor())
self.transform = T.Compose(transform_list)
self.using_taming = using_taming
@@ -199,7 +205,15 @@ def __getitem__(self, index):

class LocalTextImageDataset(Dataset):
def __init__(
self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False
self,
path,
image_size,
tokenizer,
flip=True,
center_crop=True,
using_taming=False,
random_crop=False,
alpha_channel=False,
):
super().__init__()
self.tokenizer = tokenizer
@@ -229,7 +243,6 @@ def __init__(
self.caption_pair.append(captions)

transform_list = [
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Resize(image_size),
]
if flip:
@@ -238,6 +251,10 @@ def __init__(
transform_list.append(T.CenterCrop(image_size))
if random_crop:
transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
if alpha_channel:
transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
else:
transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
transform_list.append(T.ToTensor())
self.transform = T.Compose(transform_list)

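A trimmed-down sketch of the new `alpha_channel` switch in the dataset transforms: the same pipeline yields 4-channel tensors when RGBA is requested and 3-channel tensors otherwise (flip and crop options omitted for brevity).

```python
# Trimmed-down version of the dataset transform: alpha_channel decides whether
# images are converted to RGBA (4-channel tensors) or RGB (3-channel tensors).
from PIL import Image
import torchvision.transforms as T

def build_transform(image_size: int, alpha_channel: bool) -> T.Compose:
    transform_list = [T.Resize(image_size)]
    if alpha_channel:
        transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
    else:
        transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
    transform_list.append(T.ToTensor())
    return T.Compose(transform_list)

img = Image.new("RGB", (512, 512))
print(build_transform(256, alpha_channel=True)(img).shape)   # torch.Size([4, 256, 256])
print(build_transform(256, alpha_channel=False)(img).shape)  # torch.Size([3, 256, 256])
```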
6 changes: 5 additions & 1 deletion muse_maskgit_pytorch/t5.py
@@ -1,3 +1,4 @@
import warnings
from dataclasses import dataclass, field
from functools import cached_property
from os import PathLike
@@ -8,6 +9,9 @@
from torch import Tensor
from transformers import T5Config, T5EncoderModel, T5Tokenizer

# disable t5 warnings and a few others to keep the console clean and nice.
warnings.filterwarnings("ignore")


# dataclass for T5 model info
@dataclass
@@ -41,7 +45,7 @@ def tokenizer(self) -> T5Tokenizer:


# config
MAX_LENGTH = 256
MAX_LENGTH = 512
DEFAULT_T5_NAME = "google/t5-v1_1-base"
T5_OBJECTS: Dict[str, T5ModelInfo] = {}

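For context on the `MAX_LENGTH` bump from 256 to 512: a hedged sketch of tokenizing a prompt against the new budget; the exact padding and truncation strategy used inside t5.py may differ.

```python
# Sketch: tokenize a prompt against the new 512-token budget before T5 encoding.
# The padding strategy here is illustrative; t5.py may pad differently.
from transformers import T5Tokenizer

MAX_LENGTH = 512
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-base")

encoded = tokenizer(
    ["a photo of a dog"],
    return_tensors="pt",
    padding="max_length",
    max_length=MAX_LENGTH,
    truncation=True,
)
print(encoded["input_ids"].shape)  # torch.Size([1, 512])
```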
6 changes: 3 additions & 3 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
@@ -291,17 +291,17 @@ def load(self, path: Union[str, PathLike]):
def log_validation_images(self, images, step, prompts=None):
if self.validation_image_scale != 1:
# Calculate the new height based on the scale factor
new_height = int(images[0].shape[0] * self.validation_image_scale)
new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)

# Calculate the aspect ratio of the original image
aspect_ratio = images[0].shape[1] / images[0].shape[0]
aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]

# Calculate the new width based on the new height and aspect ratio
new_width = int(new_height * aspect_ratio)

# Resize the images using the new width and height
output_size = (new_width, new_height)
images_pil = [Image.fromarray(image) for image in images]
images_pil = [Image.fromarray(np.array(image)) for image in images]
images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
images = [np.array(image_pil) for image_pil in images_pil_resized]

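The validation-image scaling change above wraps each image in `np.array()` so that PIL Images and ndarrays are handled the same way; a small sketch of the arithmetic, assuming a PIL input:

```python
# Sketch: scale a validation image that may arrive as a PIL Image. np.array() gives
# an (H, W, C) array in either case, so the shape-based math works for both.
import numpy as np
from PIL import Image

validation_image_scale = 0.5
image = Image.new("RGB", (640, 480))        # stand-in for a validation image

arr = np.array(image)                       # shape (480, 640, 3)
new_height = int(arr.shape[0] * validation_image_scale)
aspect_ratio = arr.shape[1] / arr.shape[0]
new_width = int(new_height * aspect_ratio)

resized = Image.fromarray(arr).resize((new_width, new_height))
print(resized.size)                         # (320, 240)
```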
13 changes: 10 additions & 3 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -53,6 +53,7 @@ def __init__(
ema_update_after_step=0,
ema_update_every=1,
validation_prompts=["a photo of a dog"],
timesteps=18,
clear_previous_experiments=False,
validation_image_scale: float = 1.0,
only_save_last_checkpoint=False,
@@ -79,6 +80,8 @@ def __init__(
self.save_results_every = save_results_every
self.log_metrics_every = log_metrics_every
self.batch_size = batch_size
self.current_step = current_step
self.timesteps = timesteps

# arguments used for the training script,
# we are going to use them later to save them to a config file.
@@ -113,7 +116,7 @@ def __init__(
self.info_bar = tqdm(total=0, bar_format="{desc}")

def save_validation_images(
self, validation_prompts, step: int, cond_image=None, cond_scale=3, temperature=1
self, validation_prompts, step: int, cond_image=None, cond_scale=3, temperature=1, timesteps=18
):
# moved the print to the top of the function so it shows before the progress bar for readability.
if validation_prompts:
@@ -126,6 +129,7 @@ def save_validation_images(
cond_images=cond_image,
cond_scale=cond_scale,
temperature=temperature,
timesteps=timesteps,
).to(self.accelerator.device)

save_dir = self.results_dir.joinpath("MaskGit")
@@ -147,7 +151,7 @@ def train(self):
proc_label = f"[P{self.accelerator.process_index}][Worker]"

# logs
for epoch in range(self.num_epochs):
for epoch in range(self.current_step // len(self.dl), self.num_epochs):
for imgs, input_ids, attn_mask in iter(self.dl):
train_loss = 0.0
steps = int(self.steps.item())
@@ -244,7 +248,10 @@ def train(self):
)

saved_image = self.save_validation_images(
self.validation_prompts, steps, cond_image=cond_image
self.validation_prompts,
steps,
cond_image=cond_image,
timesteps=self.timesteps,
)
if self.on_tpu:
self.accelerator.print(
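Both trainers now start the epoch loop at `current_step // len(self.dl)` instead of 0, so a resumed run skips the epochs it has already completed; a quick sketch of that arithmetic with placeholder numbers:

```python
# Sketch of the resume arithmetic used by both trainers: with a fixed-length
# dataloader, the epoch that contains current_step is current_step // steps_per_epoch.
def resume_epoch(current_step: int, steps_per_epoch: int) -> int:
    return current_step // steps_per_epoch

num_epochs = 10
steps_per_epoch = 500                        # stand-in for len(self.dl)

start = resume_epoch(current_step=1500, steps_per_epoch=steps_per_epoch)
print(list(range(start, num_epochs)))        # [3, 4, 5, 6, 7, 8, 9], epochs 0-2 are skipped
```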
4 changes: 3 additions & 1 deletion muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -91,6 +91,8 @@ def __init__(
# we are going to use them later to save them to a config file.
self.args = args

self.current_step = current_step

# vae
self.model = vae

@@ -220,7 +222,7 @@ def train(self):
else:
proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"

for epoch in range(self.num_epochs):
for epoch in range(self.current_step // len(self.dl), self.num_epochs):
for img in self.dl:
loss = 0.0
steps = int(self.steps.item())
9 changes: 5 additions & 4 deletions muse_maskgit_pytorch/vqgan_vae.py
@@ -156,7 +156,7 @@ def forward(self, x):

# discriminator
class Discriminator(nn.Module):
def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
def __init__(self, dims, channels=4, groups=16, init_kernel_size=5):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])

@@ -194,7 +194,7 @@ def __init__(
self,
dim: int,
*,
channels=3,
channels=4,
layers=4,
layer_mults=None,
num_resnet_blocks=1,
@@ -337,7 +337,7 @@ def __init__(
*,
dim: int,
accelerator: Accelerator = None,
channels=3,
channels=4,
layers=4,
l2_recon_loss=False,
use_hinge_loss=True,
@@ -407,7 +407,8 @@ def vgg(self):
if exists(self._vgg):
return self._vgg

vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
vgg = torchvision.models.vgg16(pretrained=True)
vgg.features[0] = nn.Conv2d(self.channels, 64, kernel_size=3, stride=1, padding=1)
vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
self._vgg = vgg.to(self.device)
return self._vgg
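The VAE now defaults to 4 channels, and the perceptual-loss VGG16 gets its first convolution replaced so it accepts them. As a hedged aside, an alternative to a plain re-initialization of that layer is to copy the pretrained RGB filters into the wider convolution and zero-initialize the extra alpha slice; the sketch below shows that approach using torchvision's `weights=` API and is illustrative only, not the code in this commit.

```python
# Hedged sketch (alternative approach, not the code in this commit): widen VGG16's
# first convolution to 4 input channels while keeping its pretrained RGB filters.
import torch
import torchvision
from torch import nn

channels = 4
vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)

old_conv = vgg.features[0]                          # Conv2d(3, 64, kernel_size=3, padding=1)
new_conv = nn.Conv2d(channels, 64, kernel_size=3, stride=1, padding=1)
with torch.no_grad():
    new_conv.weight[:, :3] = old_conv.weight        # reuse the pretrained RGB filters
    new_conv.weight[:, 3:].zero_()                  # alpha channel starts with no influence
    new_conv.bias.copy_(old_conv.bias)
vgg.features[0] = new_conv

print(vgg.features[0])                              # Conv2d(4, 64, kernel_size=(3, 3), ...)
```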