[Performance] Faster clone #1040

Open · wants to merge 5 commits into base: gh/vmoens/29/base
3 changes: 3 additions & 0 deletions tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None and self.device.type == "cuda":
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
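The fast path above only triggers for recursive clones of CUDA-resident tensordicts, where each per-tensor `clone()` pays a separate kernel-launch cost. A minimal standalone sketch of the idea (not the library code; assumes a CUDA device is available and uses the private `torch._foreach_*` API, which may change between PyTorch releases):

```python
import torch

# One kernel launch per tensor:
tensors = [torch.randn(256, 256, device="cuda") for _ in range(100)]
clones_loop = [t.clone() for t in tensors]

# Batched copy: adding the scalar 0 out-of-place materializes fresh
# tensors for the whole list in a few fused launches.
clones_batched = torch._foreach_add(tensors, 0)

assert all(a.equal(b) for a, b in zip(clones_loop, clones_batched))
```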
17 changes: 17 additions & 0 deletions tensordict/base.py
@@ -8164,6 +8164,23 @@ def cosh_(self) -> T:
         torch._foreach_cosh_(self._values_list(True, True))
         return self
 
+    def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
+        keys, vals = self._items_list(True, True)
+        vals = torch._foreach_add(vals, 0)
+        items = dict(zip(keys, vals))
+        result = self._fast_apply(
+            lambda name, val: items.pop(name, None),
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=True,
+            filter_empty=True,
+            default=None,
+        )
+        if items:
+            result.update(items)
+        return result
+
     def add(
         self,
         other: TensorDictBase | torch.Tensor,
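`_clone_recurse` gathers all leaves once, copies them with a single batched `_foreach_add`, then rebuilds the nesting; anything the batched pass did not consume is folded back in via `update`. A simplified sketch of the same flatten-copy-rebuild pattern over a plain nested dict (hypothetical helper, not the tensordict internals):

```python
import torch

def clone_nested(d: dict) -> dict:
    """Copy every tensor leaf of a nested dict with one batched foreach call."""
    keys, vals = [], []

    def collect(prefix, node):
        for k, v in node.items():
            if isinstance(v, dict):
                collect(prefix + (k,), v)
            else:
                keys.append(prefix + (k,))
                vals.append(v)

    collect((), d)
    cloned = torch._foreach_add(vals, 0)  # batched copy of all leaves
    out: dict = {}
    for path, v in zip(keys, cloned):
        node = out
        for k in path[:-1]:
            node = node.setdefault(k, {})
        node[path[-1]] = v
    return out

nested = {"a": torch.randn(3), "b": {"c": torch.ones(2, 2)}}
copy = clone_nested(nested)
assert copy["b"]["c"].data_ptr() != nested["b"]["c"].data_ptr()
```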
20 changes: 9 additions & 11 deletions tensordict/nn/cudagraphs.py
@@ -267,6 +267,8 @@ def _call(
                     "The output of the function must be a tensordict, a tensorclass or None. Got "
                     f"type(out)={type(out)}."
                 )
+            if is_tensor_collection(out):
+                out.lock_()
             self._out = out
             self.counter += 1
             if self._out_matches_in:
@@ -302,14 +304,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
             torch._foreach_copy_(dests, srcs)
             torch.cuda.synchronize()
             self.graph.replay()
-            if self._return_unchanged == "clone":
-                result = self._out.clone()
-            elif self._return_unchanged:
+            if self._return_unchanged:
                 result = self._out
             else:
-                result = tree_map(
-                    lambda x: x.detach().clone() if x is not None else x,
-                    self._out,
+                result = tree_unflatten(
+                    [out.clone() for out in self._out], self._out_struct
                 )
             return result
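On replay, the cached flat outputs are cloned leaf by leaf and reassembled with the container structure recorded at warmup. A sketch of the pytree round-trip this relies on (`torch.utils._pytree` is a private module; the dict below is illustrative):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Warmup: flatten once, remembering the container structure.
out = {"loss": torch.tensor(1.0), "logits": (torch.zeros(2), torch.ones(2))}
leaves, spec = tree_flatten(out)

# Replay: clone the flat leaves, then rebuild the original nesting.
result = tree_unflatten([leaf.clone() for leaf in leaves], spec)
assert result["logits"][1].equal(out["logits"][1])
assert result["logits"][1] is not out["logits"][1]
```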

@@ -340,7 +339,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
             self.graph = torch.cuda.CUDAGraph()
             with torch.cuda.graph(self.graph):
                 out = self.module(*self._args, **self._kwargs)
-            self._out = out
+            self._out, self._out_struct = tree_flatten(out)
             self.counter += 1
             # Check that there is no intersection between the identity of inputs
             # and outputs, otherwise warn the user.
@@ -356,11 +355,10 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
f"and the identity between input and output will not match anymore. "
f"Make sure you don't rely on input-output identity further in the code."
)
if isinstance(self._out, torch.Tensor) or self._out is None:
self._return_unchanged = (
"clone" if self._out is not None else True
)
if not self._out:
self._return_unchanged = True
else:
self._out = [out.lock_() if is_tensor_collection(out) else out for out in self._out]
self._return_unchanged = False
return this_out
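Locking the cached outputs at warmup (`lock_()`) freezes their key structure, so later code cannot accidentally add or remove entries that alias graph-owned storage. A small illustration of what the lock enforces (assumes tensordict is installed; `lock_` is part of its public API):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[])
td.lock_()
try:
    td["b"] = torch.ones(3)  # structural changes on a locked tensordict raise
except RuntimeError as exc:
    print("rejected:", exc)
```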
