forked from LTH14/mar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_cache.py
123 lines (97 loc) · 4.16 KB
/
main_cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import util.misc as misc
from util.loader import ImageFolderWithFilename
from models.vae import AutoencoderKL
from engine_mar import cache_latents
from util.crop import center_crop_arr
def get_args_parser():
    """Build the command-line parser for the VAE latent-caching script.

    The parser is created with ``add_help=False``, so ``-h/--help`` is not
    registered on it.

    Returns:
        argparse.ArgumentParser: parser with batch, VAE, dataset, device/seed,
        DataLoader, distributed-training and latent-caching options.
    """
    parser = argparse.ArgumentParser('Cache VAE latents', add_help=False)
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')

    # VAE parameters
    parser.add_argument('--img_size', default=256, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained_models/vae/kl16.ckpt", type=str,
                        help='path to the pretrained VAE checkpoint')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/imagenet', type=str,
                        help='dataset path')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem and --no_pin_mem write the same destination; pinning is on by
    # default (set_defaults below) and --no_pin_mem turns it off.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    # NOTE(review): empty default — presumably this must be set to an output
    # location consumed by cache_latents; confirm there.
    parser.add_argument('--cached_path', default='', help='path to cached latents')
    return parser
def main(args):
    """Encode the training split with the pretrained VAE and cache the latents.

    Steps performed here:
      1. Initialise (optional) distributed mode via ``misc.init_distributed_mode``.
      2. Seed torch and numpy with ``args.seed`` offset by the process rank.
      3. Build a deterministic preprocessing pipeline (center crop, ToTensor,
         normalize to [-1, 1]) over ``<args.data_path>/train``.
      4. Wrap it in a non-shuffling ``DistributedSampler`` so each rank handles
         its own shard. ``DistributedSampler`` may repeat a few samples to even
         out shard sizes.
      5. Load ``AutoencoderKL`` from ``args.vae_path`` in eval mode on
         ``args.device``.
      6. Delegate encoding and saving to ``cache_latents``, which presumably
         writes under ``args.cached_path`` (TODO confirm in engine_mar).

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.
    """
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility; offset by rank so each process gets a
    # distinct but reproducible RNG stream
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()

    # Preprocessing following DiT and ADM: center crop to img_size, then
    # ToTensor ([0, 1]) + Normalize(0.5, 0.5) maps pixels to [-1, 1].
    # No random augmentation is applied, so each image is always preprocessed
    # the same way. Horizontal flip is deliberately omitted here; whether
    # flipped latents are produced elsewhere (e.g. in cache_latents) is not
    # visible from this file — TODO confirm.
    transform_train = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset_train = ImageFolderWithFilename(os.path.join(args.data_path, 'train'), transform=transform_train)
    print(dataset_train)

    # shuffle=False: caching walks the dataset in a fixed order.
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=False,
    )
    print("Sampler_train = %s" % str(sampler_train))

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,  # Don't drop in cache: every image needs a latent
    )

    # define the vae
    # Move to `device` (from --device) rather than a hard-coded .cuda(), so the
    # model lives on the same device that is passed to cache_latents below.
    vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).to(device).eval()

    print("Start caching VAE latents")
    start_time = time.time()
    cache_latents(
        vae,
        data_loader_train,
        device,
        args=args,
    )
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Caching time {}'.format(total_time_str))
if __name__ == '__main__':
    # Script entry point: build the CLI, parse sys.argv, run the caching job.
    args = get_args_parser().parse_args()
    main(args)