Skip to content

Commit 29efe34

Browse files
Merge pull request #5 from yhgon/patch-1
Handle different device in `torch.index_select`
2 parents f6f65ad + 8de8825 commit 29efe34

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

memory_efficient_attention/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def dynamic_slice(x, starts, sizes):
66
# start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
77
starts = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))]
88
for i, (start, size) in enumerate(zip(starts, sizes)):
9-
x = torch.index_select(x, i, torch.tensor(range(start, start + size)))
9+
x = torch.index_select(x, i, torch.tensor(range(start, start + size), device=x.device))
1010
return x
1111

1212

0 commit comments

Comments
 (0)