
Commit

maybe fix it?
korakoe committed Oct 9, 2023
1 parent 883c3ed commit de06cf0
Showing 3 changed files with 46 additions and 29 deletions.
67 changes: 40 additions & 27 deletions muse_maskgit_pytorch/dataset.py
@@ -167,16 +167,20 @@ def __getitem__(self, index):
         else:
             text = descriptions
         # max length from the paper
-        encoded = self.tokenizer.batch_encode_plus(
-            [str(text)],
-            return_tensors="pt",
-            padding="max_length",
-            max_length=MAX_LENGTH,
-            truncation=True,
-        )
+        if self.tokenizer is not None:
+            encoded = self.tokenizer.batch_encode_plus(
+                [str(text)],
+                return_tensors="pt",
+                padding="max_length",
+                max_length=MAX_LENGTH,
+                truncation=True,
+            )
 
-        input_ids = encoded.input_ids
-        attn_mask = encoded.attention_mask
+            input_ids = encoded.input_ids
+            attn_mask = encoded.attention_mask
+        else:
+            input_ids = []
+            attn_mask = []
 
         if self.using_taming:
             if self.embeds:
@@ -242,16 +246,20 @@ def __getitem__(self, index):
         else:
             text = descriptions
         # max length from the paper
-        encoded = self.tokenizer.batch_encode_plus(
-            [str(text)],
-            return_tensors="pt",
-            padding="max_length",
-            max_length=MAX_LENGTH,
-            truncation=True,
-        )
+        if self.tokenizer is not None:
+            encoded = self.tokenizer.batch_encode_plus(
+                [str(text)],
+                return_tensors="pt",
+                padding="max_length",
+                max_length=MAX_LENGTH,
+                truncation=True,
+            )
 
-        input_ids = encoded.input_ids
-        attn_mask = encoded.attention_mask
+            input_ids = encoded.input_ids
+            attn_mask = encoded.attention_mask
+        else:
+            input_ids = []
+            attn_mask = []
 
         if self.using_taming:
             if self.embeds:
@@ -338,16 +346,21 @@ def __getitem__(self, index):
             embed = self.embeds[index]
 
         # max length from the paper
-        encoded = self.tokenizer.batch_encode_plus(
-            [str(text)],
-            return_tensors="pt",
-            padding="max_length",
-            max_length=MAX_LENGTH,
-            truncation=True,
-        )
+        if self.tokenizer is not None:
+            encoded = self.tokenizer.batch_encode_plus(
+                [str(text)],
+                return_tensors="pt",
+                padding="max_length",
+                max_length=MAX_LENGTH,
+                truncation=True,
+            )
+
+            input_ids = encoded.input_ids
+            attn_mask = encoded.attention_mask
+        else:
+            input_ids = []
+            attn_mask = []
 
-        input_ids = encoded.input_ids
-        attn_mask = encoded.attention_mask
         if self.using_taming:
             if self.embeds:
                 return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
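
For reference, a minimal standalone sketch of the pattern these three hunks introduce: only call the tokenizer when one is attached, and fall back to empty placeholders otherwise. The T5 tokenizer and the MAX_LENGTH value below are illustrative assumptions, not the repository's exact configuration.

# Sketch of the guard added in dataset.py (assumed tokenizer and MAX_LENGTH).
from transformers import T5Tokenizer

MAX_LENGTH = 77  # assumed caption length cap ("max length from the paper")

def encode_caption(text, tokenizer=None):
    if tokenizer is not None:
        encoded = tokenizer.batch_encode_plus(
            [str(text)],
            return_tensors="pt",
            padding="max_length",
            max_length=MAX_LENGTH,
            truncation=True,
        )
        return encoded.input_ids, encoded.attention_mask
    # No tokenizer attached (e.g. precomputed CLIP embeds are used instead):
    # return empty placeholders so __getitem__ can still build its tuple.
    return [], []

tokenizer = T5Tokenizer.from_pretrained("t5-small")
ids, mask = encode_caption("a photo of a cat", tokenizer)  # (1, MAX_LENGTH) tensors
ids, mask = encode_caption("a photo of a cat", None)       # [], []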
1 change: 1 addition & 0 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -185,6 +185,7 @@ def __init__(
         self.norm = LayerNorm(dim)
 
         self.use_clip = use_clip
+        self.tokenizer = None
 
         self.dim_out = default(dim_out, num_tokens)
         self.to_logits = nn.Linear(dim, self.dim_out, bias=False)
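
A small sketch (class name and setup are illustrative, not the actual module) of why defaulting the attribute to None matters: downstream code can test for the tokenizer's presence instead of hitting an AttributeError.

class TextConditionedTransformer:
    def __init__(self, tokenizer=None):
        # None is an explicit "no tokenizer" state that the dataset guard
        # above and the training script below can check for.
        self.tokenizer = tokenizer

transformer = TextConditionedTransformer()
use_precomputed_embeds = transformer.tokenizer is None  # True by default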
7 changes: 5 additions & 2 deletions train_muse_maskgit.py
@@ -824,6 +824,9 @@ def main():
     else:
         embeds = []
 
+    if args.use_metaclip:
+        transformer.tokenizer = None
+
     # Create the dataset objects
     with accelerator.main_process_first():
         if args.no_cache and args.train_data_dir:
@@ -1003,8 +1006,8 @@ def main():
             clip_precision = args.mixed_precision
 
         clip = open_clip.create_model_and_transforms(
-            "ViT-L-14",
-            pretrained="metaclip/l14_400m.pt",
+            "convnext_base_w",
+            pretrained="laion2b_s13b_b82k_augreg",
             cache_dir=args.cache_path,
             precision=clip_precision,
             device=accelerator.device,
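
The last hunk swaps the MetaCLIP ViT-L-14 checkpoint for a LAION ConvNeXt one. A minimal standalone sketch of loading and querying that model with open_clip (the dummy image and caption are illustrative; the unpacking follows open_clip's usual model / train-transform / eval-transform return):

import torch
import open_clip
from PIL import Image

# Load the checkpoint referenced in the diff; create_model_and_transforms
# returns the model plus train/eval preprocessing transforms.
model, _, preprocess = open_clip.create_model_and_transforms(
    "convnext_base_w",
    pretrained="laion2b_s13b_b82k_augreg",
)
tokenizer = open_clip.get_tokenizer("convnext_base_w")

image = preprocess(Image.new("RGB", (256, 256))).unsqueeze(0)  # dummy image
text = tokenizer(["a photo of a cat"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)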
