Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#61 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Fixed the dataset not converting properly the images to RGBA when using 4 channels for training on the maskgit.
  • Loading branch information
ZeroCool940711 committed Aug 1, 2023
2 parents 2f99bfe + 9477603 commit 1382c71
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
11 changes: 4 additions & 7 deletions muse_maskgit_pytorch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,6 @@ def __init__(
self.caption_pair.append(captions)

transform_list = [
T.Lambda(
lambda img: img.convert("RGBA")
if img.mode != "RGBA" and alpha_channel
else img
if img.mode == "RGB" and not alpha_channel
else img.convert("RGB")
),
T.Resize(image_size),
]
if flip:
Expand All @@ -258,6 +251,10 @@ def __init__(
transform_list.append(T.CenterCrop(image_size))
if random_crop:
transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
if alpha_channel:
transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
else:
transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
transform_list.append(T.ToTensor())
self.transform = T.Compose(transform_list)

Expand Down
4 changes: 2 additions & 2 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@

@dataclass
class Arguments:
total_params: Optional[int] = None
only_save_last_checkpoint: bool = False
validation_image_scale: float = 1.0
no_center_crop: bool = False
Expand Down Expand Up @@ -493,7 +494,6 @@ class Arguments:
debug: bool = False
config_path: Optional[str] = None
attention_type: str = "flash"
total_params: Optional[int] = None


def main():
Expand Down Expand Up @@ -714,7 +714,7 @@ def main():

# load the maskgit transformer from disk if we have previously trained one
with accelerator.main_process_first():
if args.resume_path:
if args.resume_path is not None and len(args.resume_path) > 1:
load = True

accelerator.print("Loading Muse MaskGit...")
Expand Down

0 comments on commit 1382c71

Please sign in to comment.