
Commit 02a5274

Enable support for sparse tensors for multi_tensor_apply (#6)
1 parent 2d0f9cf commit 02a5274

1 file changed: 10 additions (+), 3 deletions (-)

csrc/multi_tensor_apply.cuh

@@ -56,7 +56,7 @@ void multi_tensor_apply(
     for(int t = 0; t < tensor_lists[l].size(); t++)
     {
       // TODO: Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+      bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
 #ifdef VERSION_GE_1_5
       contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
@@ -78,8 +78,15 @@ void multi_tensor_apply(
     for(int t = 0; t < ntensors; t++)
     {
       tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
-      for(int d = 0; d < depth; d++)
-        tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+      for(int d = 0; d < depth; d++) {
+        if (tensor_lists[d][t].is_sparse()) {
+          at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
+          dst.add_(tensor_lists[d][t]);
+          tl.addresses[d][loc_tensor_info] = dst.data_ptr();
+        } else {
+          tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+        }
+      }
       loc_tensor_info++;
 
       int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
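The sparse branch in the second hunk materializes a dense copy of each sparse tensor before taking its raw pointer, since the multi_tensor_apply kernels index flat, contiguous memory. A minimal standalone sketch of that densification pattern with ATen follows; the densify helper name is illustrative and not part of the commit.

#include <ATen/ATen.h>

// Return a dense (strided) tensor holding the same values as the input.
// Sparse inputs are scattered into a freshly allocated dense buffer,
// mirroring the pattern in the diff above; dense inputs pass through as-is.
at::Tensor densify(const at::Tensor& t) {
  if (!t.is_sparse())
    return t;
  // Same shape, dtype, and device as t, but with a strided (dense) layout.
  at::Tensor dst = at::zeros(t.sizes(), t.options().layout(at::kStrided));
  // add_ with a sparse argument accumulates the stored values into dst.
  dst.add_(t);
  return dst;
}

The at::zeros + add_ combination produces the same result as calling to_dense() on the sparse tensor, while keeping the allocation of the dense buffer explicit.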
