Skip to content

Commit

Permalink
Specify different seeds in distributed runs for torch and jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jul 11, 2024
1 parent 2db38e9 commit dc466d4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions skrl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
seed = int(seed)

# set different seeds in distributed runs
if config.torch.is_distributed or config.jax.is_distributed:
seed += max(config.torch.rank, config.jax.rank)
if config.torch.is_distributed:
seed += config.torch.rank
if config.jax.is_distributed:
seed += config.jax.rank

logger.info(f"Seed: {seed}")

Expand Down

0 comments on commit dc466d4

Please sign in to comment.