Skip to content

Commit

Permalink
Merge pull request nuclear-multimessenger-astronomy#361 from sahiljhawar/utils-mpi
Browse files (browse the repository at this point in the history)

Prevent parallel model download when using MPI
  • Loading branch information
sahiljhawar authored May 2, 2024
2 parents 1f6b7b7 + 147a663 commit ffa5f1a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 18 deletions.
39 changes: 30 additions & 9 deletions nmma/utils/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from os.path import exists
from pathlib import Path
from os import makedirs

import requests
from requests.exceptions import ConnectionError
from yaml import load
Expand All @@ -14,11 +13,19 @@

MODELS = {}

try:
from mpi4py import MPI
mpi_enabled = True
except ImportError:
mpi_enabled = False

def download_and_decompress(file_info):
    """Download one file and then decompress it.

    Intended as a single unit of work for an executor: ``get_model`` maps this
    function over its list of missing ``(url, filepath)`` string pairs.

    Parameters
    ----------
    file_info : sequence
        Passed unchanged to ``download``; its second element (index 1) is the
        value handed to ``decompress``.
    """
    # The fetch must finish before the decompression step runs.
    download(file_info)
    local_target = file_info[1]
    decompress(local_target)

def mpi_barrier(comm):
    """Synchronise all MPI ranks on ``comm``; do nothing without a communicator.

    Callers may invoke this unconditionally: it is a no-op both when mpi4py
    could not be imported (``mpi_enabled`` is False) and when no communicator
    was obtained (``comm`` is None).

    Parameters
    ----------
    comm : object or None
        Communicator exposing a ``Barrier()`` method (``MPI.COMM_WORLD`` in
        this module), or ``None`` when MPI is unavailable.
    """
    # ``get_model`` falls back to ``comm = None`` if ``MPI.COMM_WORLD`` cannot be
    # obtained even though the mpi4py import succeeded; calling ``Barrier`` on
    # ``None`` would raise AttributeError, so check the communicator first.
    if comm is not None and mpi_enabled:
        comm.Barrier()

def download_models_list(models_home=None):
# first we load the models list from gitlab
Expand All @@ -29,7 +36,6 @@ def download_models_list(models_home=None):
with open(Path(models_home, "models.yaml"), "wb") as f:
f.write(r.content)


def load_models_list(models_home=None):

models_home = get_models_home(models_home)
Expand Down Expand Up @@ -96,7 +102,6 @@ def load_models_list(models_home=None):

return models, downloaded_if_missing is False


def refresh_models_list(models_home=None):
global MODELS
models_home = get_models_home(models_home)
Expand All @@ -110,7 +115,6 @@ def refresh_models_list(models_home=None):
raise ValueError(f"Could not load models list: {str(e)}")
return models


def get_model(
models_home=None,
model_name=None,
Expand Down Expand Up @@ -184,15 +188,32 @@ def get_model(
[f"{base_url}/{core_model_name}.{core_format}"] if not filters_only else []
) + [f"{base_url}/{model_name}/{f}.{filter_format}" for f in filters]

comm = None
if mpi_enabled:
try:
comm = MPI.COMM_WORLD
except Exception as e:
print("MPI could not be initialized:", e)
comm = None

rank = 0
if comm:
try:
rank = comm.Get_rank()
except Exception as e:
print("Error getting MPI rank:", e)

missing = [(f"{u}", f"{f}") for u, f in zip(urls, filepaths) if not f.exists()]
if len(missing) > 0:
if not download_if_missing:
raise OSError("Data not found and `download_if_missing` is False")

print(f"downloading {len(missing)} files for model {model_name}:")
with ThreadPoolExecutor(
max_workers=min(len(missing), max(cpu_count(), 8))
) as executor:
executor.map(download_and_decompress, missing)
if rank == 0 or not comm:
print(f"downloading {len(missing)} files for model {model_name}:")
with ThreadPoolExecutor(
max_workers=min(len(missing), max(cpu_count(), 8))
) as executor:
executor.map(download_and_decompress, missing)
mpi_barrier(comm)

return [str(f) for f in filepaths], filters + skipped_filters
42 changes: 33 additions & 9 deletions nmma/utils/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from .models_tools import SOURCES, get_models_home, get_parser # noqa

try:
from mpi4py import MPI
mpi_enabled = True
except ImportError:
mpi_enabled = False

def mpi_barrier(comm):
    """Synchronise all MPI ranks on ``comm``; do nothing without a communicator.

    Safe to call unconditionally: it is a no-op when mpi4py could not be
    imported (``mpi_enabled`` is False) and when ``comm`` is None.

    Parameters
    ----------
    comm : object or None
        Communicator exposing a ``Barrier()`` method (``MPI.COMM_WORLD`` in
        this module), or ``None`` when MPI is unavailable.
    """
    # ``get_model`` resets ``comm`` to ``None`` when ``MPI.COMM_WORLD`` cannot be
    # obtained despite a successful mpi4py import; guard against that so the
    # barrier never dereferences ``None``.
    if comm is not None and mpi_enabled:
        comm.Barrier()

def refresh_models_list(models_home=None, source=None):

Expand Down Expand Up @@ -37,6 +46,21 @@ def get_model(
source=None,
):

comm = None
if mpi_enabled:
try:
comm = MPI.COMM_WORLD
except Exception as e:
print("MPI could not be initialized:", e)
comm = None

rank = 0
if comm:
try:
rank = comm.Get_rank()
except Exception as e:
print("Error getting MPI rank:", e)

if source is None:
source = SOURCES[0]
if source not in ["gitlab"]:
Expand All @@ -48,13 +72,15 @@ def get_model(
if source == "gitlab":
from .gitlab import get_model

files, filters = get_model(
models_home=models_home,
model_name=model_name,
filters=filters,
download_if_missing=download_if_missing,
filters_only=filters_only,
)
if rank == 0 or not MPI.Is_initialized():
files, filters = get_model(
models_home=models_home,
model_name=model_name,
filters=filters,
download_if_missing=download_if_missing,
filters_only=filters_only,
)
mpi_barrier(comm)
break
except Exception as e:
print(f"Error while getting model from {source}: {str(e)}")
Expand Down Expand Up @@ -111,5 +137,3 @@ def main(args=None):

if __name__ == "__main__":
main()

# python nmma/utils/models.py --model="Bu2019lm" --filters=ztfr,ztfg,ztfi --svd-path='./svdmodels' --source='gitlab'

0 comments on commit ffa5f1a

Please sign in to comment.