Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#72 from ZeroCool940711/dev
Browse files: browse the repository at this point in the history
  • Loading branch information
ZeroCool940711 authored Sep 9, 2023
2 parents 09e5f92 + 5b2099a commit 70c7bde
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 57 deletions.
15 changes: 11 additions & 4 deletions infer_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,11 @@ def main():
print("Loading Muse VQGanVAE")

if args.latest_checkpoint:
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema, model_type="vae")
args.vae_path, ema_model_path = get_latest_checkpoints(
args.vae_path, use_ema=args.use_ema, model_type="vae"
)
print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
print("Resuming VAE from: ", args.vae_path)
Expand Down Expand Up @@ -309,9 +311,14 @@ def main():
accelerator.print("Loading Muse MaskGit...")

if args.latest_checkpoint:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
args.resume_path, ema_model_path = get_latest_checkpoints(
args.resume_path,
use_ema=args.use_ema,
model_type="maskgit",
cond_image_size=args.cond_image_size,
)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)
Expand Down
7 changes: 4 additions & 3 deletions infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

from muse_maskgit_pytorch.vqvae import VQVAE

# Create the parser
Expand Down Expand Up @@ -380,8 +379,10 @@ def main():

if args.latest_checkpoint:
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}")
#if args.use_ema:
print(
f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}"
)
# if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming VAE from: ", args.vae_path)
Expand Down
5 changes: 4 additions & 1 deletion muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,10 @@ def load(self, path):
if not path.exists() and path.is_file():
raise ValueError(f"cannot find file {path} (does not exist or is not a file)")
state_dict = torch.load(str(path), map_location="cpu")
self.load_state_dict(state_dict)
try:
self.load_state_dict(state_dict)
except RuntimeError:
self.load_state_dict(state_dict, strict=False)

def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs) if self.accelerator else print(*args, **kwargs)
Expand Down
25 changes: 12 additions & 13 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from beartype import beartype
from datasets import Dataset
from lion_pytorch import Lion
from PIL import Image
from torch import nn
from torch.optim import Adam, AdamW, Optimizer
from torch.utils.data import DataLoader, random_split
Expand Down Expand Up @@ -289,21 +288,21 @@ def load(self, path: Union[str, PathLike]):
return pkg

def log_validation_images(self, images, step, prompts=None):
if self.validation_image_scale != 1:
# Calculate the new height based on the scale factor
new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)
# if self.validation_image_scale > 1:
## Calculate the new height based on the scale factor
# new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)

# Calculate the aspect ratio of the original image
aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]
## Calculate the aspect ratio of the original image
# aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]

# Calculate the new width based on the new height and aspect ratio
new_width = int(new_height * aspect_ratio)
## Calculate the new width based on the new height and aspect ratio
# new_width = int(new_height * aspect_ratio)

# Resize the images using the new width and height
output_size = (new_width, new_height)
images_pil = [Image.fromarray(np.array(image)) for image in images]
images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
images = [np.array(image_pil) for image_pil in images_pil_resized]
## Resize the images using the new width and height
# output_size = (new_width, new_height)
# images_pil = [Image.fromarray(np.array(image)) for image in images]
# images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
# images = [np.array(image_pil) for image_pil in images_pil_resized]

for tracker in self.accelerator.trackers:
if tracker.name == "tensorboard":
Expand Down
2 changes: 2 additions & 0 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
logging_dir="./results/logs",
apply_grad_penalty_every=4,
use_ema=True,
ema_vae=None,
ema_update_after_step=0,
ema_update_every=1,
validation_prompts=["a photo of a dog"],
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
if use_ema:
ema_model = EMA(
self.model,
ema_model=ema_vae,
update_after_step=ema_update_after_step,
update_every=ema_update_every,
)
Expand Down
19 changes: 18 additions & 1 deletion muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from accelerate import Accelerator
from diffusers.optimization import get_scheduler
from einops import rearrange
from ema_pytorch import EMA
from omegaconf import OmegaConf
from PIL import Image
Expand Down Expand Up @@ -55,6 +54,7 @@ def __init__(
lr_warmup_steps=500,
discr_max_grad_norm=None,
use_ema=True,
ema_vae=None,
ema_beta=0.995,
ema_update_after_step=0,
ema_update_every=1,
Expand Down Expand Up @@ -155,6 +155,7 @@ def __init__(
if use_ema:
self.ema_model = EMA(
vae,
ema_model=ema_vae,
update_after_step=ema_update_after_step,
update_every=ema_update_every,
)
Expand Down Expand Up @@ -217,6 +218,22 @@ def log_validation_images(self, logs, steps):
if self.use_ema:
ema_grid = make_grid([sample, ema_recon], nrow=2)

# Scale the images
if self.validation_image_scale > 1:
grid = torch.nn.functional.interpolate(
grid.unsqueeze(0),
scale_factor=self.validation_image_scale,
mode="bicubic",
align_corners=False,
)
if self.use_ema:
ema_grid = torch.nn.functional.interpolate(
ema_grid.unsqueeze(0),
scale_factor=self.validation_image_scale,
mode="bicubic",
align_corners=False,
)

# Save grid
grid_file = f"{steps}_{i}.png"
if self.use_ema:
Expand Down
50 changes: 36 additions & 14 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from __future__ import print_function
import re, glob, os, torch

import glob
import os
import re

import torch


def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
"""Gets the latest checkpoint paths for both the non-ema and ema VAEs.
Expand All @@ -15,8 +21,10 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
if cond_image_size:
checkpoint_files = glob.glob(os.path.join(vae_path, "maskgit_superres.*.pt"))
else:
checkpoint_files = glob.glob(os.path.join(vae_path, "vae.*.pt" if model_type == "vae" else "maskgit.*.pt"))
#print(checkpoint_files)
checkpoint_files = glob.glob(
os.path.join(vae_path, "vae.*.pt" if model_type == "vae" else "maskgit.*.pt")
)
# print(checkpoint_files)

print(f"Finding latest {'VAE' if model_type == 'vae' else 'MaskGit'} checkpoint...")

Expand All @@ -29,7 +37,9 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_non_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1))
key=lambda x: int(
re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1)
)
if not x.endswith("ema.pt")
else -1,
)
Expand All @@ -38,9 +48,7 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
if os.path.getsize(latest_non_ema_checkpoint_file) == 0 or not os.access(
latest_non_ema_checkpoint_file, os.R_OK
):
print(
f"Warning: latest checkpoint {latest_non_ema_checkpoint_file} is empty or unreadable."
)
print(f"Warning: latest checkpoint {latest_non_ema_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if cond_image_size:
Expand All @@ -51,7 +59,11 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_non_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1))
key=lambda x: int(
re.search(
r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x
).group(1)
)
if not x.endswith("ema.pt")
else -1,
)
Expand All @@ -71,17 +83,19 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x).group(1))
key=lambda x: int(
re.search(
r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x
).group(1)
)
if x.endswith("ema.pt")
else -1,
)

if os.path.getsize(latest_ema_checkpoint_file) == 0 or not os.access(
latest_ema_checkpoint_file, os.R_OK
):
print(
f"Warning: latest EMA checkpoint {latest_ema_checkpoint_file} is empty or unreadable."
)
print(f"Warning: latest EMA checkpoint {latest_ema_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if cond_image_size:
Expand All @@ -94,7 +108,14 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x).group(1))
key=lambda x: int(
re.search(
r"vae\.(\d+)\.ema\.pt$"
if model_type == "vae"
else r"maskgit\.(\d+)\.ema\.pt$",
x,
).group(1)
)
if x.endswith("ema.pt")
else -1,
)
Expand All @@ -104,6 +125,7 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im

return latest_non_ema_checkpoint_file, latest_ema_checkpoint_file


def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
"""Removes duplicate weights from the ema state dictionary.
Expand All @@ -119,4 +141,4 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
for key, value in ema_state_dict.items():
if key in non_ema_state_dict and torch.equal(ema_state_dict[key], non_ema_state_dict[key]):
del ema_state_dict_copy[key]
return ema_state_dict_copy
return ema_state_dict_copy
1 change: 0 additions & 1 deletion muse_maskgit_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ def load_state_dict(self, *args, **kwargs):
except RuntimeError:
return super().load_state_dict(*args, **kwargs, strict=False)


def save(self, path):
if self.accelerator is not None:
self.accelerator.save(self.state_dict(), path)
Expand Down
47 changes: 37 additions & 10 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from omegaconf import OmegaConf
from rich import inspect
from torch.optim import Optimizer

from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)
Expand Down Expand Up @@ -594,13 +595,28 @@ def main():
accelerator.print("Loading Muse VQGanVAE")

if args.latest_checkpoint:
args.vae_path, ema_vae_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_vae_path}")
#if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_vae_path}")
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
if ema_model_path:
ema_vae = VQGanVAE(
dim=args.dim,
vq_codebook_dim=args.vq_codebook_dim,
vq_codebook_size=args.vq_codebook_size,
l2_recon_loss=args.use_l2_recon_loss,
channels=args.channels,
layers=args.layers,
discr_layers=args.discr_layers,
accelerator=accelerator,
)
print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")

ema_vae.load(ema_model_path, map="cpu")
else:
ema_vae = None

print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
else:
accelerator.print("Resuming VAE from: ", args.vae_path)
ema_vae_path = None
ema_vae = None

# use config next to checkpoint if there is one and merge the cli arguments to it
# the cli arguments will take priority so we can use it to override any value we want.
Expand All @@ -621,7 +637,7 @@ def main():
discr_layers=args.discr_layers,
).to(accelerator.device)

vae.load(args.vae_path if not args.use_ema or not ema_vae_path else ema_vae_path, map="cpu")
vae.load(args.vae_path, map="cpu")

elif args.taming_model_path is not None and args.taming_config_path is not None:
accelerator.print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
Expand Down Expand Up @@ -691,10 +707,20 @@ def main():
accelerator.print("Loading Muse MaskGit...")

if args.latest_checkpoint:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
try:
args.resume_path, ema_model_path = get_latest_checkpoints(
args.resume_path,
use_ema=args.use_ema,
model_type="maskgit",
cond_image_size=args.cond_image_size,
)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
# if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")

except ValueError:
load = False

else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

Expand Down Expand Up @@ -844,6 +870,7 @@ def main():
results_dir=args.results_dir,
logging_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
use_ema=args.use_ema,
ema_vae=ema_vae,
ema_update_after_step=args.ema_update_after_step,
ema_update_every=args.ema_update_every,
apply_grad_penalty_every=args.apply_grad_penalty_every,
Expand Down
Loading

0 comments on commit 70c7bde

Please sign in to comment.