Skip to content

Commit

Permalink
Revert CudaRNGStatesTracker
Browse files Browse the repository at this point in the history
Signed-off-by: Robin Zhang <[email protected]>
  • Loading branch information
buptzyb authored and yifeis-nv committed Nov 20, 2024
1 parent a01602e commit 8017b6d
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,24 +720,18 @@ class CudaRNGStatesTracker:
"""

def __init__(self):
self.reset()

def is_initialized(self):
"""Checks if the internal RNG state has been set wirth set_states()."""
return self._is_initialized

def reset(self):
"""Set to the initial state (no tracker)."""

# Track if initialized.
self._is_initialized = False

# Map from a string name to the cuda rng state.
self.states_ = {}

# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_ = set()

def reset(self):
"""
Set to the initial state (no tracker).
"""
self.states_ = {}
self.seeds_ = set()

def get_states(self) -> Dict[str, torch.Tensor]:
"""
Get rng states. Copy the dictionary so we have direct pointers
Expand All @@ -756,7 +750,6 @@ def set_states(self, states: Dict[str, torch.Tensor]) -> None:
states: Dict[str, torch.Tensor]
A mapping from string names to RNG states.
"""
self._is_initialized = True
self.states_ = states

def add(self, name: str, seed: int) -> None:
Expand All @@ -768,7 +761,6 @@ def add(self, name: str, seed: int) -> None:
seed: int
PyTorch seed for the RNG state.
"""
self._is_initialized = True
# Check seed is not already used.
if seed in self.seeds_:
raise RuntimeError(f"seed {seed} already exists")
Expand Down

0 comments on commit 8017b6d

Please sign in to comment.