File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -56,7 +56,7 @@ void multi_tensor_apply(
56
56
for (int t = 0 ; t < tensor_lists[l].size (); t++)
57
57
{
58
58
// TODO: Print which tensor fails.
59
- bool contiguous_memory = tensor_lists[l][t].is_contiguous ();
59
+ bool contiguous_memory = (tensor_lists[l][t]. is_sparse ()) ? tensor_lists[l][t]. _values (). is_contiguous () : tensor_lists[l][t].is_contiguous ();
60
60
#ifdef VERSION_GE_1_5
61
61
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous (at::MemoryFormat::ChannelsLast));
62
62
#endif
@@ -78,8 +78,15 @@ void multi_tensor_apply(
78
78
for (int t = 0 ; t < ntensors; t++)
79
79
{
80
80
tl.sizes [loc_tensor_info] = tensor_lists[0 ][t].numel ();
81
- for (int d = 0 ; d < depth; d++)
82
- tl.addresses [d][loc_tensor_info] = tensor_lists[d][t].data_ptr ();
81
+ for (int d = 0 ; d < depth; d++) {
82
+ if (tensor_lists[d][t].is_sparse ()) {
83
+ at::Tensor dst = at::zeros (tensor_lists[d][t].sizes (), tensor_lists[d][t].options ().layout (at::kStrided ));
84
+ dst.add_ (tensor_lists[d][t]);
85
+ tl.addresses [d][loc_tensor_info] = dst.data_ptr ();
86
+ } else {
87
+ tl.addresses [d][loc_tensor_info] = tensor_lists[d][t].data_ptr ();
88
+ }
89
+ }
83
90
loc_tensor_info++;
84
91
85
92
int chunks_this_tensor = (tensor_lists[0 ][t].numel () + chunk_size - 1 )/chunk_size;
You can’t perform that action at this time.
0 commit comments