Extend device data node binding API to not clone specified input tensors #9054

Merged
63 changes: 63 additions & 0 deletions test/test_input_output_aliases.py
@@ -340,6 +340,69 @@ def test_no_op_mark_step_keep_buffer_donation(self):
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))


  def test_device_data_node_tracing_aliasing(self):
    """
    Test that _get_tensors_xla_device_data_node does not return new XLA tensors
    for a given set of unmutated input tensors during tracing. This helps
    ensure that aliasing can be retained when using the binding for tracing
    purposes.
    """
    xla_device = xm.xla_device()
    t0 = torch.tensor(10).to(xla_device)

    t1 = t0 + 5
    t0_input_tensor_id = torch_xla._XLAC._xla_get_tensor_id(t0)
    t1_output_tensor_id = torch_xla._XLAC._xla_get_tensor_id(t1)

    # We feed t0 as an input to the API that computes the tensor values of all
    # the specified nodes, ensuring that it does not return a new XLA tensor
    # for the same backend data if it is not mutated. Note that t0 is captured
    # when doing a post-order traversal of t1.
    results_with_inputs = torch_xla._XLAC._get_tensors_xla_device_data_node(
        [t1], [t0])
    self.assertEqual(len(results_with_inputs), 2)
    try:
      input_index = results_with_inputs[0].index(t0_input_tensor_id)
      non_input_index = 0 if input_index == 0 else 1
    except ValueError:
      self.fail(
          f"Input tensor ID {t0_input_tensor_id} is not present in the results: {results_with_inputs[0]}"
      )

    # Since t0 is an input tensor and not mutated, we expect the resulting
    # tensor ID and the ID associated with the XLA tensor to match the
    # original value.
    self.assertEqual(results_with_inputs[0][input_index], t0_input_tensor_id)
    self.assertEqual(
        torch_xla._XLAC._xla_get_tensor_id(
            results_with_inputs[1][input_index]), t0_input_tensor_id)

    # Since t1 is not an input to the API, we expect a new XLA tensor to be
    # generated for the resulting values that map to t1.
    self.assertNotEqual(results_with_inputs[0][non_input_index],
                        t1_output_tensor_id)
    self.assertNotEqual(
        torch_xla._XLAC._xla_get_tensor_id(
            results_with_inputs[1][non_input_index]), t1_output_tensor_id)

    torch_xla._XLAC._xla_sync_multi([t0, t1], [str(xla_device)], True, False)
    self.assertEqual(t1.item(), 15)

    # In case we do have a mutation of the input, we should expect that a
    # different tensor ID is returned.
    t0 += 10
    t1 = t0 + 5

    t0_input_tensor_id = torch_xla._XLAC._xla_get_tensor_id(t0)
    results_with_inputs = torch_xla._XLAC._get_tensors_xla_device_data_node(
        [t1], [t0])

    self.assertFalse(t0_input_tensor_id in results_with_inputs[0])
    self.assertFalse(t0_input_tensor_id in [
        torch_xla._XLAC._xla_get_tensor_id(tensor)
        for tensor in results_with_inputs[1]
    ])


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
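
For reference, a minimal usage sketch of the call pattern this test exercises, assuming a live XLA device and the private torch_xla._XLAC bindings shown above; the variable names and checks are illustrative, not part of the change:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
inp = torch.tensor(10).to(device)  # backed by a DeviceData node
out = inp + 5  # lazy IR rooted at `out`; its post-order traversal captures `inp`

# Default form: every captured DeviceData node is returned as a newly created
# XLA tensor (a clone over the same backend data).
ids_cloned, vals_cloned = torch_xla._XLAC._get_tensors_xla_device_data_node([out])

# Extended form: tensors listed as `uncloned_tensors` are returned as-is when
# captured and unmutated, so both the reported ID and the returned tensor's ID
# match `inp`, and aliasing can be preserved for tracing purposes.
ids, vals = torch_xla._XLAC._get_tensors_xla_device_data_node([out], [inp])
inp_id = torch_xla._XLAC._xla_get_tensor_id(inp)
assert inp_id in ids
assert inp_id in [torch_xla._XLAC._xla_get_tensor_id(v) for v in vals]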
93 changes: 56 additions & 37 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2877,50 +2877,69 @@ void InitXlaModuleBindings(py::module m) {
// -------------Dynamo Integration API Start-------------------------
/*
* Return tensor ids and at::tensors for all DeviceData nodes that is needed
* to compute the value of tensors. Note that all DeviceData nodes return a
* clone of the tensor. In case the user provides a list of uncloned tensors,
* then the API will ensure that any captured tensor that is included in this
* list does not clone the XLA tensor. This ensures that uses of the resulting
* tensor values can be correctly aliased, as the tensor retains the original
* tensor IDs.
*/
m.def("_get_tensors_xla_device_data_node",
[](const std::vector<at::Tensor>& tensors)
-> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
std::vector<int64_t> tensor_ids;
std::vector<at::IValue> ivalues;
std::vector<const torch::lazy::Node*> roots;
for (const at::Tensor& tensor : tensors) {
auto xtensor = bridge::TryGetXlaTensor(tensor);
if (xtensor) {
roots.push_back(xtensor->GetIrValue().node.get());
}
m.def(
"_get_tensors_xla_device_data_node",
[](const std::vector<at::Tensor>& output_tensors,
const std::vector<at::Tensor>& uncloned_tensors)
-> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
std::vector<const torch::lazy::Node*> roots;
for (const at::Tensor& tensor : output_tensors) {
auto xtensor = bridge::TryGetXlaTensor(tensor);
if (xtensor) {
roots.push_back(xtensor->GetIrValue().node.get());
}
auto post_order = torch::lazy::Util::ComputePostOrder(roots);
std::unordered_set<torch::lazy::BackendData::Handle> data_handles;

for (const torch::lazy::Node* nodeptr : post_order) {
const auto backend_data =
torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr);
if (!backend_data) {
continue;
}
}

// Dedup by handle
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
if (!data_handles.insert(handle).second) {
continue;
}
auto* infoptr =
static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
backend_data->info());
if (infoptr) {
tensor_ids.push_back(infoptr->tensor_id);
} else {
// TODO(JackCaoG): Make sure this device data is actually seed.
tensor_ids.push_back(seed_info_id);
}
std::unordered_map<int64_t, at::Tensor> uncloned_tensor_map;
uncloned_tensor_map.reserve(uncloned_tensors.size());
for (const at::Tensor& tensor : uncloned_tensors) {
int64_t tensor_id = GetTensorId(tensor);
uncloned_tensor_map[tensor_id] = tensor;
}

auto post_order = torch::lazy::Util::ComputePostOrder(roots);
std::unordered_set<torch::lazy::BackendData::Handle> data_handles;

std::vector<int64_t> tensor_ids;
std::vector<at::IValue> ivalues;
for (const torch::lazy::Node* nodeptr : post_order) {
const auto backend_data =
torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr);
if (!backend_data) {
continue;
}

// Dedup by handle
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
if (!data_handles.insert(handle).second) {
continue;
}
auto* infoptr =
static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
backend_data->info());

// TODO(JackCaoG): Make sure this device data is actually seed.
int64_t tensor_id = infoptr ? infoptr->tensor_id : seed_info_id;
tensor_ids.push_back(tensor_id);
if (uncloned_tensor_map.find(tensor_id) !=
uncloned_tensor_map.end()) {
ivalues.emplace_back(uncloned_tensor_map[tensor_id]);
} else {
at::Tensor tensor = bridge::AtenFromXlaTensor(
torch_xla::XLATensor::Create(backend_data));
ivalues.emplace_back(tensor);
}
return std::make_pair(tensor_ids, ivalues);
});
}
return std::make_pair(tensor_ids, ivalues);
},
py::arg("output_tensors"), py::arg("uncloned_tensors") = py::list());

m.def("_get_seed_info_id", []() -> int64_t { return seed_info_id; });
