-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_T5.py
203 lines (174 loc) · 10.7 KB
/
train_T5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import torch
from custom_datasets import OnlineFontSquare, HFDataCollector, TextSampler, FixedTextSampler, dataset_factory, GibberishSampler
from pathlib import Path
from torch.utils.data import DataLoader
import argparse
from tqdm import tqdm
from utils import MetricCollector
from torchvision.utils import make_grid, save_image
from emuru import Emuru
import pickle
import wandb
def _resume_from_checkpoint(model, args):
    """Load the newest ``*.pth`` checkpoint found under ``args.resume_dir`` into ``model``.

    When ``args.resume_wandb`` is true, ``args.start_epoch`` is set to the epoch after the
    saved one and ``args.wandb_id`` to the saved wandb run id, so the run continues where it
    stopped; otherwise only the model weights are reused. If the checkpoint lacks one of the
    expected keys (``KeyError``), it is treated as the old checkpoint format and loaded with
    ``model.load_pretrained(args.resume_dir)`` instead.

    Raises:
        FileNotFoundError: if no ``*.pth`` file exists anywhere under ``args.resume_dir``.
    """
    # train() names checkpoints with a zero-padded epoch, so within one directory
    # lexicographic order is epoch order and the newest checkpoint sorts last.
    checkpoints = sorted(Path(args.resume_dir).rglob('*.pth'))
    if not checkpoints:
        raise FileNotFoundError(f'--resume was given but no *.pth checkpoint exists under {args.resume_dir}')
    checkpoint_path = checkpoints[-1]
    try:
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
        model.load_state_dict(checkpoint['model'], strict=False)
        # NOTE: the optimizer state stored in the checkpoint is not restored, so AdamW
        # restarts with fresh moment estimates after a resume.
        if args.resume_wandb:
            args.start_epoch = checkpoint['epoch'] + 1
            args.wandb_id = checkpoint['wandb_id']
        print(f'Resumed training from {checkpoint_path}')
    except KeyError:
        model.load_pretrained(args.resume_dir)
        print(f'Resumed with the old checkpoint system: {checkpoint_path}')


def _build_train_loader(model, args):
    """Return a shuffled DataLoader over the synthetic OnlineFontSquare training set."""
    if args.to_width == 768:
        sampler = TextSampler(32, 64, (4, 7))
    else:
        sampler = TextSampler(4, 128, (1, 32))

    renderers = None
    if args.renderers:
        # NOTE(review): pickle.load can execute arbitrary code; only point --renderers at trusted files.
        with open(args.renderers, 'rb') as f:
            renderers = pickle.load(f)

    dataset = OnlineFontSquare(args.fonts, args.backgrounds, sampler, renderers=renderers,
                               _to_width=args.to_width, max_fonts=args.max_fonts)
    if args.to_width != 768:
        # Pop index 2 before index 1 so the second pop still hits the intended transform.
        dataset.transform.transforms.pop(2)  # remove the random warping
        dataset.transform.transforms.pop(1)  # remove the random rotation
    dataset.length *= args.db_multiplier
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=model.data_collator,
                      num_workers=args.dataloader_num_workers)


def _build_eval_loader(args):
    """Return an ordered DataLoader over the 'test' split of the 'iam_lines' dataset."""
    eval_dataset = dataset_factory('test', ['iam_lines'], root_path=args.datasets)
    eval_dataset.batch_keys('style')
    return DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=eval_dataset.collate_fn,
                      num_workers=args.dataloader_num_workers_eval, persistent_workers=False)


def _add_reconstruction_images(model, batch, gt, pred, wandb_data, img_key, gen_key):
    """Store a reconstruction grid and a generation-test image for ``batch`` in ``wandb_data``.

    ``wandb_data[img_key]`` holds up to 16 rows of ``batch['img'] | gt | pred`` concatenated
    along the last dimension; ``wandb_data[gen_key]`` holds the image returned by
    ``model.continue_gen_test``.
    """
    pred, gt, gen_test = model.continue_gen_test(gt, batch, pred)
    # gt/pred are repeated 3x along dim 1 before concatenating with batch['img'];
    # presumably they are single-channel while the input is 3-channel — TODO confirm.
    gt = gt.repeat(1, 3, 1, 1)
    pred = pred.repeat(1, 3, 1, 1)
    grid = torch.cat([batch['img'], gt, pred], dim=-1)[:16]
    wandb_data[img_key] = wandb.Image(make_grid(grid, nrow=1, normalize=True))
    wandb_data[gen_key] = wandb.Image(gen_test)


def train(args):
    """Train an Emuru model on synthetic text lines and evaluate it on IAM lines.

    An "epoch" here is ``args.dataloader_chunk`` training batches drawn from a training
    DataLoader that is restarted whenever it is exhausted. After every epoch the model is
    evaluated on the IAM-lines test split, metrics (and, with ``--wandb``, reconstruction
    images) are logged, a checkpoint is written every 5 epochs (never at epoch 0), and
    ``model.alpha`` is decreased by ``args.decrement_alpha`` down to ``args.end_alpha``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block). On resume,
            ``args.start_epoch`` and ``args.wandb_id`` may be overwritten from the checkpoint.
    """
    if args.device == 'cpu':
        print('WARNING: Using CPU')
    model = Emuru(args.t5_checkpoint, args.vae_checkpoint, args.ocr_checkpoint, args.slices_per_query).to(args.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    if args.resume:
        _resume_from_checkpoint(model, args)

    loader = _build_train_loader(model, args)
    eval_loader = _build_eval_loader(args)

    if args.wandb:
        args.wandb_id = wandb.util.generate_id() if not hasattr(args, 'wandb_id') else args.wandb_id
        wandb.init(project='Emuru', name=Path(args.output_dir).name, config=args, id=args.wandb_id, resume='allow')

    collector = MetricCollector()
    loader_iter = iter(loader)
    model.alpha = args.start_alpha
    for epoch in range(args.start_epoch, args.num_train_epochs):
        model.train()
        for i in tqdm(range(args.dataloader_chunk), desc=f'Epoch {epoch}'):
            try:
                batch = next(loader_iter)
            except StopIteration:
                # Epochs are fixed-size chunks, so the loader is simply restarted when exhausted.
                loader_iter = iter(loader)
                batch = next(loader_iter)
            batch = {k: v.to(args.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            batch['noise'] = args.teacher_noise
            losses, pred, gt = model(**batch)
            losses['loss'].backward()

            # Step after every full group of `gradient_acc` backward passes. The +1 keeps the
            # first group the same size as the others (the old check stepped after batch 0).
            global_step = epoch * args.dataloader_chunk + i
            if (global_step + 1) % args.gradient_acc == 0:
                optimizer.step()
                optimizer.zero_grad()

            collector.update({f'train/{k}': v for k, v in losses.items()})

        with torch.no_grad():
            model.eval()
            wandb_data = {'train/alpha': model.alpha}
            if args.wandb:
                # Visualise the last training batch of this epoch.
                _add_reconstruction_images(model, batch, gt, pred, wandb_data, 'synth_img', 'synth_gen_test')

            for eval_batch in tqdm(eval_loader, total=len(eval_loader), desc=f'Eval'):
                eval_batch = {k: v.to(args.device) if isinstance(v, torch.Tensor) else v
                              for k, v in eval_batch.items()}
                res = model.tokenizer(eval_batch['style_text'], return_tensors='pt', padding=True,
                                      return_attention_mask=True, return_length=True)
                res = {k: v.to(args.device) if isinstance(v, torch.Tensor) else v for k, v in res.items()}
                losses, pred, gt = model(img=eval_batch['style_img'], **res)
                collector.update({f'eval/{k}': v for k, v in losses.items()})

            if args.wandb:
                # Visualise the last eval batch; continue_gen_test reads 'input_ids' and 'img'.
                eval_batch['input_ids'] = model.tokenizer(eval_batch['style_text'], return_tensors='pt',
                                                          padding=True).input_ids.to(args.device)
                eval_batch['img'] = eval_batch['style_img']
                _add_reconstruction_images(model, eval_batch, gt, pred, wandb_data, 'real_img', 'real_gen_test')
                wandb.log(wandb_data | collector.dict())

        if epoch % 5 == 0 and epoch > 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'wandb_id': args.wandb_id if args.wandb else None,
            }
            # One file per 100-epoch bucket: saves within the same bucket overwrite each other.
            checkpoint_path = Path(args.output_dir) / f'{epoch // 100 * 100:05d}.pth'
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(checkpoint, checkpoint_path)
            print(f'Saved model at epoch {epoch} in {checkpoint_path}')

        collector.reset()
        model.alpha = max(args.end_alpha, model.alpha - args.decrement_alpha)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train a T5 model with a VAE')
    parser.add_argument('--device', type=str, default='cuda', help='Device')
    parser.add_argument('--t5_checkpoint', type=str, default='google-t5/t5-small', help='T5 checkpoint')
    parser.add_argument('--vae_checkpoint', type=str, default='results_vae/a912/model_0205', help='VAE checkpoint')
    parser.add_argument('--ocr_checkpoint', type=str, default='files/checkpoints/Origami_bw_img/origami.pth', help='OCR checkpoint')
    parser.add_argument('--resume_dir', type=str, default=None, help='Resume directory (defaults to --output_dir)')
    parser.add_argument('--output_dir', type=str, default='files/checkpoints/Emuru_100k', help='Output directory')
    parser.add_argument('--fonts', type=str, default='files/font_square/clean_fonts', help='Fonts path')
    parser.add_argument('--backgrounds', type=str, default='files/font_square/backgrounds', help='Backgrounds path')
    parser.add_argument('--datasets', type=str, default='/home/vpippi/Teddy/files/datasets/', help='Root datasets path')
    parser.add_argument('--renderers', type=str, help='Renderers path (pickle file; only load trusted files)')
    # NOTE(review): --checkpoint_tag and --report_to are parsed but never read by train().
    parser.add_argument('--checkpoint_tag', type=str, default='', help='Checkpoint tag')
    parser.add_argument('--db_multiplier', type=int, default=1, help='Dataset multiplier')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--num_train_epochs', type=int, default=10 ** 10, help='Number of train epochs')
    parser.add_argument('--report_to', type=str, default='none', help='Report to')
    parser.add_argument('--dataloader_chunk', type=int, default=2000, help='Number of training batches per epoch')
    parser.add_argument('--dataloader_num_workers', type=int, default=15, help='Train dataloader num workers')
    parser.add_argument('--dataloader_num_workers_eval', type=int, default=4, help='Eval dataloader num workers')
    parser.add_argument('--slices_per_query', type=int, default=1, help='Number of slices to predict in each query')
    parser.add_argument('--wandb', action='store_true', help='Use wandb')
    parser.add_argument('--wandb_id', type=str, default=wandb.util.generate_id(), help='Wandb id')
    parser.add_argument('--resume', action='store_true', help='Resume training')
    parser.add_argument('--start_epoch', type=int, default=0, help='Start epoch')
    parser.add_argument('--teacher_noise', type=float, default=0.1, help='How much noise add during training')
    parser.add_argument('--start_alpha', type=float, default=1.0, help='Alpha between the mse_loss (alpha=1) and the ocr_loss (alpha=0)')
    parser.add_argument('--end_alpha', type=float, default=1.0, help='Lower bound alpha is clamped to while decrementing')
    parser.add_argument('--decrement_alpha', type=float, default=0., help='Amount subtracted from alpha after every epoch')
    parser.add_argument('--gradient_acc', type=int, default=1, help='Number of batches to accumulate gradients over per optimizer step (>= 1)')
    parser.add_argument('--to_width', type=int, default=768, help='Passed to OnlineFontSquare as _to_width; 768 keeps the random warp/rotation augmentations')
    parser.add_argument('--max_fonts', type=int, help='Passed to OnlineFontSquare as max_fonts')
    args = parser.parse_args()

    # gradient_acc is used as a modulus in the training loop; reject it before any model loading.
    if args.gradient_acc < 1:
        parser.error('--gradient_acc must be >= 1')

    if args.resume_dir is None:
        args.resume_dir = args.output_dir
    # Continue the epoch counter and wandb run only when resuming from this run's own output
    # directory. Resolved paths make e.g. relative and absolute spellings compare equal.
    args.resume_wandb = Path(args.resume_dir).resolve() == Path(args.output_dir).resolve()
    train(args)