diff --git a/.gitignore b/.gitignore index 243a7b4..c997595 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,5 @@ logs/ # virtualenv venv/ ENV/ + +.idea/ diff --git a/train.py b/train.py index 0e733ec..5cb87f8 100644 --- a/train.py +++ b/train.py @@ -113,6 +113,7 @@ if args.gpu_ids[0] >= 0 else torch.device("cpu") ) +torch.cuda.set_device(device) # Print config and args. print(yaml.dump(config, default_flow_style=False)) @@ -279,7 +280,7 @@ def lr_lambda_fun(current_iteration: int) -> float: scheduler.step(global_iteration_step) global_iteration_step += 1 - torch.cuda.empty_cache() + torch.cuda.empty_cache() # ------------------------------------------------------------------------- # ON EPOCH END (checkpointing and validation)