Extend device data node binding API to not clone specified input tensors #9054
In this PR, we extend the `_get_tensors_xla_device_data_node` binding API to return the same tensor values for a given set of specified unmutated tensor inputs. It currently returns a list of tensors that capture the XLATensor values of the graph inputs. However, these tensors end up creating new ATen tensors, which are effectively clones of the original tensors (see xla/torch_xla/csrc/init_python_bindings.cpp, line 2919 in c4b45a9).
As a result, the returned tensors are not eligible to be aliased at the step barrier. Instead, we extend the API to let users specify a list of known input tensors, so that if a graph input matches one of them, the same ATen tensor is returned to the user instead of a clone.
The input tensors argument is optional, so the API change is backwards compatible.
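The intended behavior can be illustrated with a small pure-Python sketch. It does not call torch_xla; `FakeTensor`, `get_device_data_tensors`, and the integer `handle` used for matching are hypothetical stand-ins, and how the real binding matches a graph input to a known tensor is an assumption here.

```python
from typing import List, Optional, Sequence


class FakeTensor:
    """Hypothetical stand-in for an ATen tensor backed by XLA device data."""

    def __init__(self, handle: int) -> None:
        # Assumed identifier of the underlying device data.
        self.handle = handle


def get_device_data_tensors(
    graph_input_handles: Sequence[int],
    known_inputs: Optional[Sequence[FakeTensor]] = None,
) -> List[FakeTensor]:
    """Return one tensor per graph input, reusing known inputs when they match."""
    # With no known inputs, every result is a fresh wrapper (the old behavior).
    by_handle = {t.handle: t for t in (known_inputs or [])}
    result = []
    for handle in graph_input_handles:
        match = by_handle.get(handle)
        # Reuse the caller's object on a match; otherwise build a new wrapper,
        # which plays the role of the "clone" described above.
        result.append(match if match is not None else FakeTensor(handle))
    return result


a = FakeTensor(1)
with_inputs = get_device_data_tensors([1, 2], known_inputs=[a])
without_inputs = get_device_data_tensors([1, 2])
```

In this sketch `with_inputs[0]` is the same object as `a`, while `without_inputs[0]` is a new wrapper. Omitting `known_inputs` reproduces the pre-change behavior, which is what keeps the change backwards compatible.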
cc: @mcuiaws