diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 7acafaa6..4ed6e733 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -50,7 +50,10 @@
     print_git_commit,
     setup_wandb,
 )
-from mace.tools.slurm_distributed import DistributedEnvironment
+from mace.tools.slurm_distributed import (
+    DistributedEnvironmentOpenmpi,
+    DistributedEnvironmentSlurm,
+)
 from mace.tools.utils import AtomicNumberTable
@@ -78,7 +81,10 @@ def run(args: argparse.Namespace) -> None:
         ) from e
     if args.distributed:
         try:
-            distr_env = DistributedEnvironment()
+            if args.distributed_env == "slurm":
+                distr_env = DistributedEnvironmentSlurm()
+            elif args.distributed_env == "openmpi":
+                distr_env = DistributedEnvironmentOpenmpi()
         except Exception as e:  # pylint: disable=W0703
             logging.error(f"Failed to initialize distributed environment: {e}")
             return
@@ -86,8 +92,8 @@ def run(args: argparse.Namespace) -> None:
         local_rank = distr_env.local_rank
         rank = distr_env.rank
         if rank == 0:
-            print(distr_env)
-        torch.distributed.init_process_group(backend="nccl")
+            print("Using Distributed Environment: ", distr_env)
+        torch.distributed.init_process_group(backend=args.distributed_backend)
     else:
         rank = int(0)
@@ -99,7 +105,8 @@ def run(args: argparse.Namespace) -> None:
         logging.log(level=loglevel, msg=message)

     if args.distributed:
-        torch.cuda.set_device(local_rank)
+        if args.device == "cuda":
+            torch.cuda.set_device(local_rank)

     logging.info(f"Process group initialized: {torch.distributed.is_initialized()}")
     logging.info(f"Processes: {world_size}")
@@ -568,7 +575,10 @@ def run(args: argparse.Namespace) -> None:
         setup_wandb(args)

     if args.distributed:
-        distributed_model = DDP(model, device_ids=[local_rank])
+        if args.device == "cuda":
+            distributed_model = DDP(model, device_ids=[local_rank])
+        elif args.device == "cpu":
+            distributed_model = DDP(model)
     else:
         distributed_model = None
@@ -664,7 +674,7 @@ def run(args: argparse.Namespace) -> None:
             )
             try:
                 drop_last = test_set.drop_last
-            except AttributeError as e:  # pylint: disable=W0612
+            except AttributeError as e:  # pylint: disable=W0612  # noqa: F841
                 drop_last = False
             test_loader = torch_geometric.dataloader.DataLoader(
                 test_set,
@@ -686,7 +696,10 @@ def run(args: argparse.Namespace) -> None:
         )
         model.to(device)
         if args.distributed:
-            distributed_model = DDP(model, device_ids=[local_rank])
+            if args.device == "cuda":
+                distributed_model = DDP(model, device_ids=[local_rank])
+            elif args.device == "cpu":
+                distributed_model = DDP(model)
         model_to_evaluate = model if not args.distributed else distributed_model
         if swa_eval:
             logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
@@ -751,7 +764,7 @@ def run(args: argparse.Namespace) -> None:
                     path_complied,
                     _extra_files=extra_files,
                 )
-            except Exception as e:  # pylint: disable=W0703
+            except Exception as e:  # pylint: disable=W0703  # noqa: F841
                 pass
         else:
             torch.save(model, Path(args.model_dir) / (args.name + ".model"))
@@ -766,7 +779,7 @@ def run(args: argparse.Namespace) -> None:
                     path_complied,
                     _extra_files=extra_files,
                 )
-            except Exception as e:  # pylint: disable=W0703
+            except Exception as e:  # pylint: disable=W0703  # noqa: F841
                 pass

     if args.distributed:
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index 046f04d6..511bee70 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -76,6 +76,21 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "--distributed_backend",
+        help="PyTorch distributed backend",
+        type=str,
+        choices=["nccl", "gloo", "mpi"],
+        default="nccl",
+    )
+    parser.add_argument(
+        "--distributed_env",
+        help="The parallel environment to use for distributed training",
+        type=str,
+        choices=["slurm", "openmpi"],
+        default="slurm",
+    )
+
     parser.add_argument("--log_level", help="log level", type=str, default="INFO")

     parser.add_argument(
diff --git a/mace/tools/slurm_distributed.py b/mace/tools/slurm_distributed.py
index 78de52a1..866cbab3 100644
--- a/mace/tools/slurm_distributed.py
+++ b/mace/tools/slurm_distributed.py
@@ -10,7 +10,7 @@
 import hostlist


-class DistributedEnvironment:
+class DistributedEnvironmentSlurm:
     def __init__(self):
         self._setup_distr_env()
         self.master_addr = os.environ["MASTER_ADDR"]
@@ -32,3 +32,10 @@ def _setup_distr_env(self):
         )
         os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
         os.environ["RANK"] = os.environ["SLURM_PROCID"]
+
+
+class DistributedEnvironmentOpenmpi:
+    def __init__(self):
+        self.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+        self.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+        self.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])