We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents f6f65ad + 8de8825 commit 29efe34Copy full SHA for 29efe34
memory_efficient_attention/utils.py
@@ -6,7 +6,7 @@ def dynamic_slice(x, starts, sizes):
6
# start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
7
starts = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))]
8
for i, (start, size) in enumerate(zip(starts, sizes)):
9
- x = torch.index_select(x, i, torch.tensor(range(start, start + size)))
+ x = torch.index_select(x, i, torch.tensor(range(start, start + size), device=x.device))
10
return x
11
12
0 commit comments