train.py
import argparse
import glob
import os

import torch
import yaml
import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

from datasets.dataset import MyDataset, collate_fn
from model import MyModel

assert torch.cuda.is_available(), "CPU training is not allowed."
# Allow TF32 for faster matmuls on Ampere-class GPUs.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
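# Note: TF32 trades a small amount of float32 matmul precision for speed;
# drop these flags if bit-exact float32 reproducibility is required.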
def main(args):
    # Load the experiment configuration from YAML.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
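    # Expected config layout (a sketch inferred from how keys are used below;
    # the actual configs/default.yaml may define more fields):
    #   seed: 1234
    #   save_every_n_steps: 5000
    #   dataloader:        # kwargs forwarded to DataLoader
    #     batch_size: 16
    #     num_workers: 4
    #     shuffle: true
    #   trainer:           # kwargs forwarded to pl.Trainer
    #     max_steps: 100000
    #     accelerator: gpu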
    # Seed everything (including dataloader workers) for reproducibility.
    pl.seed_everything(config['seed'], workers=True)

    # Build the training dataset and loader.
    training_data = MyDataset()
    train_loader = DataLoader(training_data, collate_fn=collate_fn, **config['dataloader'])

    # Define the model.
    model = MyModel(config)
    # Resolve the checkpoint to resume from, if any.
    ckpt_dir = os.path.join(args.output_dir, args.expname + '_ckpt')
    if args.restore_step != "":
        pattern = f'{ckpt_dir}/epoch=*-step={args.restore_step}.ckpt'
    else:
        pattern = f'{ckpt_dir}/epoch=*-step=*.ckpt'
    ckpts = glob.glob(pattern)
    if len(ckpts) == 0:
        ckpt = None
    else:
        # Pick the checkpoint with the highest step number.
        ckpt = max(ckpts, key=lambda x: int(x.split('=')[-1].split('.')[0]))
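    # The glob pattern above matches Lightning's default ModelCheckpoint
    # filenames ('epoch=E-step=S.ckpt'), so no custom filename template is needed.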
    # Load EMA weights from the checkpoint, if present.
    if ckpt is not None:
        sd = torch.load(ckpt, map_location='cpu')
        if 'ema' in sd:
            model.ema.load_state_dict(sd['ema'])
    ###################### training ######################
    lr_monitor = LearningRateMonitor(logging_interval='step')
    ckpt_monitor = ModelCheckpoint(verbose=True, every_n_train_steps=config['save_every_n_steps'],
                                   save_last=True, save_top_k=-1, dirpath=ckpt_dir)
    logger = TensorBoardLogger(save_dir=args.output_dir, name='lightning_logs',
                               version=args.expname, default_hp_metric=False)
    trainer = pl.Trainer(callbacks=[lr_monitor, ckpt_monitor], logger=logger,
                         profiler="simple", **config['trainer'])
    # No validation set is constructed here, so only the training loader is passed.
    # Passing ckpt_path lets Lightning restore model, optimizer, and step state.
    trainer.fit(model, train_dataloaders=train_loader, ckpt_path=ckpt)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        default="configs/default.yaml",
                        help="config file path",
                        )
    parser.add_argument("--expname",
                        type=str,
                        default="test",
                        help="experiment name, used for checkpoint and log directories",
                        )
    parser.add_argument("--restore_step",
                        type=str,
                        default="",
                        help="step number of the checkpoint to restore; empty resumes from the latest",
                        )
    # Note: --finetune is parsed but not currently used in main().
    parser.add_argument('--finetune', '-ft',
                        action='store_true',
                        default=False,
                        help='if set, this is a finetuning run that loads a base model',
                        )
    parser.add_argument('--output_dir',
                        type=str,
                        default="lightning_logs/",
                        help='directory to save checkpoints and logs',
                        )
    args = parser.parse_args()
    main(args)
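# Example invocations (a sketch; the config path and experiment name are illustrative):
#   python train.py --config configs/default.yaml --expname baseline
# Resume the same experiment from the checkpoint saved at step 50000:
#   python train.py --config configs/default.yaml --expname baseline --restore_step 50000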