unroller_only option for DistributedUnroller
Haichao-Zhang committed Mar 5, 2025
1 parent 455bea1 commit dc18a00
Showing 1 changed file with 15 additions and 3 deletions.
alf/algorithms/distributed_off_policy_algorithm.py
@@ -561,14 +561,16 @@ def _train_iter_off_policy(self):
         return steps
 
 
-@alf.configurable(whitelist=['episode_length', 'name', 'optimizer'])
+@alf.configurable(
+    whitelist=['episode_length', 'name', 'optimizer', 'unroller_only'])
 class DistributedUnroller(DistributedOffPolicyAlgorithm):
     def __init__(self,
                  core_alg_ctor: Callable,
                  *args,
                  episode_length: int = 200,
                  env: AlfEnvironment = None,
                  config: TrainerConfig = None,
+                 unroller_only: bool = False,
                  debug_summaries: bool = False,
                  name: str = "DistributedUnroller",
                  **kwargs):
@@ -585,6 +587,11 @@ def __init__(self,
                 step type to an artificial ``StepType.LAST``, and switching.
                 For training safety, it is recommended to always set this value to a
                 positive number.
+            unroller_only: if True, skip the training-related steps (including
+                registering to all trainer workers, ``_create_pull_params_subprocess``,
+                ``_check_paramss_update``, and ``observe_for_replay``) and only do
+                the unroll. The unroller is then never blocked by the trainer, which
+                is useful for visualizing unroll behaviors without training.
             *args: additional args to pass to ``core_alg_ctor``.
             **kwargs: additional kwargs to pass to ``core_alg_ctor``.
         """
@@ -605,6 +612,7 @@ def __init__(self,
         self._episode_length = episode_length
         self._num_exps = 0
         self._is_first_step = True
+        self._unroller_only = unroller_only
 
         ip = get_local_ip()
         self._id = f"unroller-{ip}-{self._port}"
@@ -686,6 +694,9 @@ def observe_for_replay(self, exp: Experience):
         Every time we make sure a full episode is sent to the same DDP rank, if
         multi-gpu training is enabled on the trainer.
         """
+        if self._unroller_only:
+            return
+
         # Get the current worker id to send the exp to
         worker_id = f'worker-{self._current_worker}'
         self._num_exps += 1
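The docstring above promises that a full episode is always sent to the same DDP rank. A self-contained sketch of that routing idea, assuming round-robin advancement at episode boundaries (``num_workers``, ``route``, and ``step_is_last`` are illustrative names, not ALF's actual implementation):

num_workers = 4      # assumed number of trainer DDP ranks
current_worker = 0   # index of the rank receiving the current episode

def route(step_is_last: bool) -> str:
    """Return the destination worker id for one experience step."""
    global current_worker
    worker_id = f"worker-{current_worker}"
    if step_is_last:
        # Only advance at an episode boundary, so every step of one
        # episode lands on the same rank.
        current_worker = (current_worker + 1) % num_workers
    return worker_id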
@@ -765,7 +776,7 @@ def train_iter(self):
         There is actually no training happening in this function. But the unroller
         will check if there are updated params available.
         """
-        if not self._registered:
+        if not self._registered and not self._unroller_only:
             # We need lazy registration so that trainer's params have a higher
             # priority than the unroller's loaded params (if enabled).
             self._register_to_trainer()
@@ -784,5 +795,6 @@ def train_iter(self):
             torch.cuda.empty_cache()
         # Experience will be sent to the trainer in this function
         self._unroll_iter_off_policy()
-        self._check_paramss_update()
+        if not self._unroller_only:
+            self._check_paramss_update()
         return 0
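Taken together, when ``unroller_only=True`` the method above reduces to roughly the following (a condensed, illustrative sketch of the post-commit control flow, not the verbatim method body):

def train_iter(self):
    # With unroller_only=True there is no trainer registration, no
    # parameter-update check, and (via observe_for_replay's early
    # return) no experience is sent to the trainer, so nothing here
    # can block on the trainer.
    self._unroll_iter_off_policy()  # the rollout itself still runs
    return 0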
