[Performance] Faster clone
ghstack-source-id: 6eecbac5e946a4d93d3d6e148e8c18aaa2501b00
Pull Request resolved: #1040
vmoens committed Oct 14, 2024
1 parent 088d953 commit 4786d54
Showing 3 changed files with 26 additions and 11 deletions.
3 changes: 3 additions & 0 deletions tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None and self.device.type == "cuda":
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
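Why this helps (a sketch, not part of the commit): on CUDA, cloning each leaf separately launches one kernel per tensor, whereas a single torch._foreach_add(values, 0) call produces fresh copies of all leaves with far fewer kernel launches. A minimal illustration of the idea, assuming a CUDA device is available:

import torch

# One hundred small CUDA tensors to clone.
leaves = {f"t{i}": torch.randn(64, device="cuda") for i in range(100)}

# Per-leaf clone: one kernel launch per tensor.
cloned_slow = {k: v.clone() for k, v in leaves.items()}

# Batched clone: adding 0 to every tensor in a single foreach call
# yields new tensors holding the same values.
vals = torch._foreach_add(list(leaves.values()), 0)
cloned_fast = dict(zip(leaves.keys(), vals))

assert all(torch.equal(cloned_slow[k], cloned_fast[k]) for k in leaves)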
17 changes: 17 additions & 0 deletions tensordict/base.py
@@ -8156,6 +8156,23 @@ def cosh_(self) -> T:
         torch._foreach_cosh_(self._values_list(True, True))
         return self
 
+    def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
+        keys, vals = self._items_list(True, True)
+        vals = torch._foreach_add(vals, 0)
+        items = dict(zip(keys, vals))
+        result = self._fast_apply(
+            lambda name, val: items.pop(name, None),
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=True,
+            filter_empty=True,
+            default=None,
+        )
+        if items:
+            result.update(items)
+        return result
+
     def add(
         self,
         other: TensorDictBase | torch.Tensor,
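_clone_recurse gathers every leaf with _items_list, copies them all in one torch._foreach_add(vals, 0) call, and rebuilds the structure with _fast_apply, restoring any keys the apply did not visit via update. A hedged usage sketch of the resulting fast path (the tensordict contents are illustrative), assuming a CUDA device:

import torch
from tensordict import TensorDict

# A flat CUDA tensordict with many leaves.
td = TensorDict(
    {f"key{i}": torch.randn(1024, device="cuda") for i in range(256)},
    batch_size=[],
)

# On CUDA, clone() now dispatches to the batched _clone_recurse path.
td_clone = td.clone()
assert (td_clone["key0"] == td["key0"]).all()
# The copy owns fresh storage.
assert td_clone["key0"].data_ptr() != td["key0"].data_ptr()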
17 changes: 6 additions & 11 deletions tensordict/nn/cudagraphs.py
@@ -302,14 +302,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 torch._foreach_copy_(dests, srcs)
                 torch.cuda.synchronize()
                 self.graph.replay()
-                if self._return_unchanged == "clone":
-                    result = self._out.clone()
-                elif self._return_unchanged:
+                if self._return_unchanged:
                     result = self._out
                 else:
-                    result = tree_map(
-                        lambda x: x.detach().clone() if x is not None else x,
-                        self._out,
+                    result = tree_unflatten(
+                        torch._foreach_add(self._out, 0.0), self._out_struct
                     )
                 return result

@@ -340,7 +337,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 self.graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.graph):
                     out = self.module(*self._args, **self._kwargs)
-                self._out = out
+                self._out, self._out_struct = tree_flatten(out)
                 self.counter += 1
                 # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
                 # user.
@@ -356,10 +353,8 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         f"and the identity between input and output will not match anymore. "
                         f"Make sure you don't rely on input-output identity further in the code."
                     )
-                if isinstance(self._out, torch.Tensor) or self._out is None:
-                    self._return_unchanged = (
-                        "clone" if self._out is not None else True
-                    )
+                if not self._out:
+                    self._return_unchanged = True
                 else:
                     self._return_unchanged = False
             return this_out
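On the CUDA-graph side, the module now flattens its recorded output once with tree_flatten and, at replay time, clones all flat leaves in a single torch._foreach_add(self._out, 0.0) call before rebuilding the original structure with tree_unflatten, replacing the per-leaf tree_map of detach().clone(). A minimal sketch of that flatten/clone/unflatten pattern in isolation, assuming the pytree helpers from torch.utils._pytree and a CUDA device:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# An arbitrary pytree of CUDA tensors standing in for the graph output.
out = {"a": torch.randn(3, device="cuda"), "b": (torch.randn(4, device="cuda"),)}

# Flatten once, clone every leaf with one foreach op, rebuild the structure.
flat, spec = tree_flatten(out)
flat_clones = torch._foreach_add(flat, 0.0)
out_clone = tree_unflatten(flat_clones, spec)

assert torch.equal(out_clone["a"], out["a"])
assert out_clone["a"].data_ptr() != out["a"].data_ptr()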
