From 2db38e9c2b34538630e42a43e9faa5685ec60bd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Wed, 10 Jul 2024 22:38:37 -0400 Subject: [PATCH] Initialize distributed runs if JAX_WORLD_SIZE > 1 --- skrl/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skrl/__init__.py b/skrl/__init__.py index 8b2e30eb..a363b974 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -124,9 +124,8 @@ def __init__(self) -> None: # device self._device = f"cuda:{self._local_rank}" - # TODO: find a better place for it # set up distributed runs - if self._is_distributed and "jax" in sys.modules: + if self._is_distributed: import jax logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})") jax.distributed.initialize(coordinator_address=self._coordinator_address,