
[Doc] Fix tutorials (#2768)
(cherry picked from commit 75f113f)
vmoens committed Feb 7, 2025
1 parent 3ff13ff commit ef4d7a1
Showing 6 changed files with 21 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: ["3.10"]
+        python_version: ["3.9"]
         cuda_arch_version: ["12.4"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@ jobs:
           bash ./miniconda.sh -b -f -p "${conda_dir}"
           eval "$(${conda_dir}/bin/conda shell.bash hook)"
           printf "* Creating a test environment\n"
-          conda create --prefix "${env_dir}" -y python=3.10
+          conda create --prefix "${env_dir}" -y python=3.9
           printf "* Activating\n"
           conda activate "${env_dir}"
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -94,10 +94,10 @@
     "filename_pattern": "reference/generated/tutorials/",  # files to parse
     "notebook_images": "reference/generated/tutorials/media/",  # images to parse
     "download_all_examples": True,
-    "abort_on_example_error": False,
-    "only_warn_on_example_error": True,
+    "abort_on_example_error": True,
     "show_memory": True,
     "capture_repr": ("_repr_html_", "__repr__"),  # capture representations
+    "write_computation_times": True,
 }

 napoleon_use_ivar = True
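For context: with ``abort_on_example_error`` set to ``True``, sphinx-gallery raises on the first failing tutorial instead of downgrading it to a warning (the behavior the removed ``only_warn_on_example_error`` flag selected). A minimal sketch of the relevant configuration, hypothetical and trimmed down rather than the real conf.py:

    # minimal sphinx-gallery sketch (illustrative paths): a tutorial that
    # raises during execution now fails the docs build outright
    extensions = ["sphinx_gallery.gen_gallery"]

    sphinx_gallery_conf = {
        "examples_dirs": "tutorials",           # scripts to execute
        "gallery_dirs": "generated/tutorials",  # rendered output
        "abort_on_example_error": True,         # raise instead of warn
    }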
10 changes: 3 additions & 7 deletions torchrl/trainers/trainers.py
@@ -640,7 +640,7 @@ class ReplayBufferTrainer(TrainerHookBase):
         memmap (bool, optional): if ``True``, a memmap tensordict is created.
             Default is ``False``.
         device (device, optional): device where the samples must be placed.
-            Default is ``cpu``.
+            Defaults to ``None``.
         flatten_tensordicts (bool, optional): if ``True``, the tensordicts will be
             flattened (or equivalently masked with the valid mask obtained from
             the collector) before being passed to the replay buffer. Otherwise,
@@ -666,7 +666,7 @@ def __init__(
         replay_buffer: TensorDictReplayBuffer,
         batch_size: Optional[int] = None,
         memmap: bool = False,
-        device: DEVICE_TYPING = "cpu",
+        device: DEVICE_TYPING | None = None,
         flatten_tensordicts: bool = False,
         max_dims: Optional[Sequence[int]] = None,
     ) -> None:
@@ -695,15 +695,11 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase:
             pads += [0, pad_value]
             batch = pad(batch, pads)
         batch = batch.cpu()
-        if self.memmap:
-            # We can already place the tensords on the device if they're memmap,
-            # as this is a lazy op
-            batch = batch.memmap_().to(self.device)
         self.replay_buffer.extend(batch)

     def sample(self, batch: TensorDictBase) -> TensorDictBase:
         sample = self.replay_buffer.sample(batch_size=self.batch_size)
-        return sample.to(self.device, non_blocking=True)
+        return sample.to(self.device) if self.device is not None else sample

     def update_priority(self, batch: TensorDictBase) -> None:
         self.replay_buffer.update_tensordict_priority(batch)
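A usage sketch of the new default (the buffer construction is illustrative): with ``device=None`` the hook returns samples exactly as the buffer yields them, and passing a device opts back in to an explicit move on every sampled batch:

    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
    from torchrl.trainers import ReplayBufferTrainer

    rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), batch_size=64)

    # new default: sample() returns the tensordict as the buffer produced it
    hook = ReplayBufferTrainer(rb, batch_size=64)

    # passing a device restores an explicit .to() on every sampled batch
    hook_gpu = ReplayBufferTrainer(rb, batch_size=64, device="cuda")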
6 changes: 6 additions & 0 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -1185,6 +1185,12 @@ def ceil_div(x, y):
 collector.shutdown()
 del collector

+try:
+    parallel_env.close()
+    del parallel_env
+except Exception:
+    pass
+
 ###############################################################################
 # Experiment results
 # ------------------
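The guarded ``close()`` matters because ``ParallelEnv`` runs each sub-environment in a worker process, and lingering workers can keep the tutorial's interpreter from exiting cleanly. A standalone sketch of the pattern (assumes a Gym backend is installed):

    from torchrl.envs import ParallelEnv
    from torchrl.envs.libs.gym import GymEnv

    env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
    env.rollout(3)  # workers are spawned lazily on first use
    env.close()     # shut the worker processes down explicitly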
8 changes: 6 additions & 2 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -380,11 +380,12 @@ def make_model(dummy_env):
 # time must always have the same shape.


-def get_replay_buffer(buffer_size, n_optim, batch_size):
+def get_replay_buffer(buffer_size, n_optim, batch_size, device):
     replay_buffer = TensorDictReplayBuffer(
         batch_size=batch_size,
         storage=LazyMemmapStorage(buffer_size),
         prefetch=n_optim,
+        transform=lambda td: td.to(device),
     )
     return replay_buffer

@@ -660,7 +661,7 @@ def get_loss_module(actor, gamma):
 # requires 3 hooks (``extend``, ``sample`` and ``update_priority``) which
 # can be cumbersome to implement.
 buffer_hook = ReplayBufferTrainer(
-    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
+    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size, device=device),
     flatten_tensordicts=True,
 )
 buffer_hook.register(trainer)
@@ -750,6 +751,9 @@ def print_csv_files_in_folder(folder_path):

 print_csv_files_in_folder(logger.experiment.log_dir)

+trainer.shutdown()
+del trainer
+
 ###############################################################################
 # Conclusion and possible improvements
 # ------------------------------------
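A sketch of what the new ``transform`` argument buys (illustrative keys and shapes): the callable runs on every tensordict leaving the buffer, so ``sample()`` hands back batches already on the training device and the manual ``.to(device)`` disappears from the training loop:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    rb = TensorDictReplayBuffer(
        batch_size=4,
        storage=LazyMemmapStorage(100),
        transform=lambda td: td.to(device),  # applied on the way out
    )
    rb.extend(TensorDict({"obs": torch.randn(10, 3)}, batch_size=[10]))
    print(rb.sample()["obs"].device)  # samples arrive on `device`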
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/pretrained_models.py
@@ -37,7 +37,7 @@
 import torch.cuda
 from tensordict.nn import TensorDictSequential
 from torch import nn
-from torchrl.envs import R3MTransform, TransformedEnv
+from torchrl.envs import Compose, R3MTransform, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import Actor

@@ -115,7 +115,7 @@
 from torchrl.data import LazyMemmapStorage, ReplayBuffer

 storage = LazyMemmapStorage(1000)
-rb = ReplayBuffer(storage=storage, transform=r3m)
+rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))

 ##############################################################################
 # We can now collect the data (random rollouts for our purpose) and fill the replay
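``Compose`` chains replay-buffer transforms and applies them in order at sampling time, so in the diff above the batch is moved to ``device`` before the R3M embedding runs on it. A self-contained sketch of the same ordering with illustrative stand-in transforms (no pretrained weights needed):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, ReplayBuffer
    from torchrl.envs import Compose

    device = "cuda" if torch.cuda.is_available() else "cpu"
    rb = ReplayBuffer(
        storage=LazyMemmapStorage(100),
        transform=Compose(
            lambda td: td.to(device),  # runs first: device move
            lambda td: td.float(),     # runs second, already on `device`
        ),
    )
    rb.extend(TensorDict({"pixels": torch.randint(0, 256, (8, 3, 32, 32))}, [8]))
    print(rb.sample(4)["pixels"].dtype)  # torch.float32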
