train.py
import os

import hydra
import ignite.distributed as idist
import torch
from omegaconf import DictConfig, OmegaConf

# Training entry points are resolved by name via globals() below, so this
# import must stay even though `kyn` looks unused.
from models.kyn.trainer import training as kyn

# Anomaly detection helps localize NaN/Inf gradients but slows training;
# disable it for production runs.
torch.autograd.set_detect_anomaly(True)


@hydra.main(version_base=None, config_path="configs")
def main(config: DictConfig):
    # Allow new keys to be added to the config at runtime.
    OmegaConf.set_struct(config, False)
    os.environ["NCCL_DEBUG"] = "INFO"

    backend = config.get("backend", None)
    nproc_per_node = config.get("nproc_per_node", None)
    with_amp = config.get("with_amp", False)

    spawn_kwargs = {"nproc_per_node": nproc_per_node}
    if backend == "xla-tpu" and with_amp:
        raise RuntimeError("with_amp must be False when the backend is xla-tpu")

    # Look up the training function named by the `model` key, e.g. "kyn".
    training = globals()[config["model"]]
    with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
        parallel.run(training, config)


if __name__ == "__main__":
    main()
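
For reference, `idist.Parallel.run(training, config)` invokes the handler once per process and passes the process-local rank as the first argument. The real entry point lives in `models/kyn/trainer.py` and is not shown on this page; the sketch below is a hypothetical illustration of the signature the script assumes, with an invented body.

import ignite.distributed as idist
from omegaconf import DictConfig

def training(local_rank: int, config: DictConfig) -> None:
    # Parallel.run supplies `local_rank` automatically; `config` is the
    # positional argument forwarded from parallel.run(training, config).
    device = idist.device()  # cuda:N, xla, or cpu depending on the backend
    if idist.get_rank() == 0:
        print(f"world size {idist.get_world_size()}, device {device}")
    # ... build the model and data loaders, then run the training loop ...

Because `@hydra.main` is given only `config_path="configs"` and no `config_name`, the config file must be selected at launch time, e.g. `python train.py --config-name=<config> model=kyn backend=nccl nproc_per_node=2` (the override values here are illustrative, not taken from the repo).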