Skip to content

Commit 8eb7778

Browse files
author
Vincent Moens
committed
[Performance] Faster clone
ghstack-source-id: e710b72 Pull Request resolved: #1040
1 parent 8d845f7 commit 8eb7778

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

tensordict/_td.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
30093009
return all([value.is_contiguous() for _, value in self.items()])
30103010

30113011
def _clone(self, recurse: bool = True) -> T:
3012+
if recurse and self.device is not None and self.device.type == "cuda":
3013+
return self._clone_recurse()
3014+
30123015
result = TensorDict._new_unsafe(
30133016
source={key: _clone_value(value, recurse) for key, value in self.items()},
30143017
batch_size=self.batch_size,

tensordict/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8164,6 +8164,23 @@ def cosh_(self) -> T:
81648164
torch._foreach_cosh_(self._values_list(True, True))
81658165
return self
81668166

8167+
def _clone_recurse(self) -> TensorDictBase: # noqa: D417
8168+
keys, vals = self._items_list(True, True)
8169+
vals = torch._foreach_add(vals, 0)
8170+
items = dict(zip(keys, vals))
8171+
result = self._fast_apply(
8172+
lambda name, val: items.pop(name, None),
8173+
named=True,
8174+
nested_keys=True,
8175+
is_leaf=_NESTED_TENSORS_AS_LISTS,
8176+
propagate_lock=True,
8177+
filter_empty=True,
8178+
default=None,
8179+
)
8180+
if items:
8181+
result.update(items)
8182+
return result
8183+
81678184
def add(
81688185
self,
81698186
other: TensorDictBase | torch.Tensor,

tensordict/nn/cudagraphs.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def _call(
267267
"The output of the function must be a tensordict, a tensorclass or None. Got "
268268
f"type(out)={type(out)}."
269269
)
270+
if is_tensor_collection(out):
271+
out.lock_()
270272
self._out = out
271273
self.counter += 1
272274
if self._out_matches_in:
@@ -302,14 +304,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
302304
torch._foreach_copy_(dests, srcs)
303305
torch.cuda.synchronize()
304306
self.graph.replay()
305-
if self._return_unchanged == "clone":
306-
result = self._out.clone()
307-
elif self._return_unchanged:
307+
if self._return_unchanged:
308308
result = self._out
309309
else:
310-
result = tree_map(
311-
lambda x: x.detach().clone() if x is not None else x,
312-
self._out,
310+
result = tree_unflatten(
311+
[out.clone() for out in self._out], self._out_struct
313312
)
314313
return result
315314

@@ -340,7 +339,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
340339
self.graph = torch.cuda.CUDAGraph()
341340
with torch.cuda.graph(self.graph):
342341
out = self.module(*self._args, **self._kwargs)
343-
self._out = out
342+
self._out, self._out_struct = tree_flatten(out)
344343
self.counter += 1
345344
# Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
346345
# user.
@@ -356,11 +355,10 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
356355
f"and the identity between input and output will not match anymore. "
357356
f"Make sure you don't rely on input-output identity further in the code."
358357
)
359-
if isinstance(self._out, torch.Tensor) or self._out is None:
360-
self._return_unchanged = (
361-
"clone" if self._out is not None else True
362-
)
358+
if not self._out:
359+
self._return_unchanged = True
363360
else:
361+
self._out = [out.lock_() if is_tensor_collection(out) else out for out in self._out]
364362
self._return_unchanged = False
365363
return this_out
366364

0 commit comments

Comments
 (0)