diff --git a/nmma/utils/gitlab.py b/nmma/utils/gitlab.py
--- a/nmma/utils/gitlab.py
+++ b/nmma/utils/gitlab.py
@@ -3,7 +3,6 @@ from os.path import exists
 from pathlib import Path
 from os import makedirs
 
-
 import requests
 from requests.exceptions import ConnectionError
 from yaml import load
@@ -14,11 +13,25 @@
 
 MODELS = {}
 
+# mpi4py is optional: without it the process behaves as a single rank.
+try:
+    from mpi4py import MPI
+
+    mpi_enabled = True
+except ImportError:
+    mpi_enabled = False
+
 
 def download_and_decompress(file_info):
     download(file_info)
     decompress(file_info[1])
 
 
+def mpi_barrier(comm):
+    """Synchronise all ranks of ``comm``; no-op without MPI or a communicator."""
+    if mpi_enabled and comm is not None:
+        comm.Barrier()
+
+
 def download_models_list(models_home=None):
     # first we load the models list from gitlab
@@ -184,15 +197,29 @@ def get_model(
         [f"{base_url}/{core_model_name}.{core_format}"] if not filters_only else []
     ) + [f"{base_url}/{model_name}/{f}.{filter_format}" for f in filters]
 
+    # Under MPI only rank 0 touches the network; without mpi4py this process
+    # acts as rank 0 of a one-process world.
+    comm = MPI.COMM_WORLD if mpi_enabled else None
+    rank = 0 if comm is None else comm.Get_rank()
+
     missing = [(f"{u}", f"{f}") for u, f in zip(urls, filepaths) if not f.exists()]
-    if len(missing) > 0:
+    n_missing = len(missing)
+    if comm is not None:
+        # Every rank must take the same branch below, otherwise the Barrier is
+        # unmatched: a rank that checks the disk after rank 0 has finished would
+        # see nothing missing and skip it. Rank 0's view is authoritative.
+        n_missing = comm.bcast(n_missing, root=0)
+    if n_missing > 0:
         if not download_if_missing:
             raise OSError("Data not found and `download_if_missing` is False")
 
-        print(f"downloading {len(missing)} files for model {model_name}:")
-        with ThreadPoolExecutor(
-            max_workers=min(len(missing), max(cpu_count(), 8))
-        ) as executor:
-            executor.map(download_and_decompress, missing)
+        if rank == 0:
+            print(f"downloading {len(missing)} files for model {model_name}:")
+            with ThreadPoolExecutor(
+                max_workers=min(len(missing), max(cpu_count(), 8))
+            ) as executor:
+                executor.map(download_and_decompress, missing)
+        # Non-root ranks wait here until rank 0 has written the files.
+        mpi_barrier(comm)
 
     return [str(f) for f in filepaths], filters + skipped_filters
diff --git a/nmma/utils/models.py b/nmma/utils/models.py
--- a/nmma/utils/models.py
+++ b/nmma/utils/models.py
@@ -111,5 +111,3 @@ def main(args=None):
 
 if __name__ == "__main__":
     main()
-
-# python nmma/utils/models.py --model="Bu2019lm" --filters=ztfr,ztfg,ztfi --svd-path='./svdmodels' --source='gitlab'