@@ -267,6 +267,8 @@ def _call(
267
267
"The output of the function must be a tensordict, a tensorclass or None. Got "
268
268
f"type(out)={ type (out )} ."
269
269
)
270
+ if is_tensor_collection (out ):
271
+ out .lock_ ()
270
272
self ._out = out
271
273
self .counter += 1
272
274
if self ._out_matches_in :
@@ -302,14 +304,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
302
304
torch ._foreach_copy_ (dests , srcs )
303
305
torch .cuda .synchronize ()
304
306
self .graph .replay ()
305
- if self ._return_unchanged == "clone" :
306
- result = self ._out .clone ()
307
- elif self ._return_unchanged :
307
+ if self ._return_unchanged :
308
308
result = self ._out
309
309
else :
310
- result = tree_map (
311
- lambda x : x .detach ().clone () if x is not None else x ,
312
- self ._out ,
310
+ result = tree_unflatten (
311
+ [out .clone () for out in self ._out ], self ._out_struct
313
312
)
314
313
return result
315
314
@@ -340,7 +339,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
340
339
self .graph = torch .cuda .CUDAGraph ()
341
340
with torch .cuda .graph (self .graph ):
342
341
out = self .module (* self ._args , ** self ._kwargs )
343
- self ._out = out
342
+ self ._out , self . _out_struct = tree_flatten ( out )
344
343
self .counter += 1
345
344
# Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
346
345
# user.
@@ -356,11 +355,10 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
356
355
f"and the identity between input and output will not match anymore. "
357
356
f"Make sure you don't rely on input-output identity further in the code."
358
357
)
359
- if isinstance (self ._out , torch .Tensor ) or self ._out is None :
360
- self ._return_unchanged = (
361
- "clone" if self ._out is not None else True
362
- )
358
+ if not self ._out :
359
+ self ._return_unchanged = True
363
360
else :
361
+ self ._out = [out .lock_ () if is_tensor_collection (out ) else out for out in self ._out ]
364
362
self ._return_unchanged = False
365
363
return this_out
366
364
0 commit comments