Commit da51a44

Merge pull request #6 from yhgon/patch-2: Handling device for `big_neg`

2 parents: 29efe34 + 4c286eb

File tree: 1 file changed (+1, −1)

memory_efficient_attention/attention_torch.py (1 addition, 1 deletion)

@@ -27,7 +27,7 @@ def summarize_chunk(key_idx, query, key, value, mask, bias):
         mask = mask_calc_fn(query_idx, key_idx, mask, attn_weights, calc_fn_data)
     if mask is not None:
         big_neg = torch.finfo(attn_weights.dtype).min
-        big_neg = torch.tensor(big_neg, dtype=torch.float32)
+        big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32)
         mask = torch.einsum('...hqk->...qhk', mask)
     attn_weights = torch.where(mask, attn_weights, big_neg)
     if weights_calc_fn is not None:
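To illustrate why the one-line change matters: `torch.tensor(...)` without a `device` argument allocates on the CPU, so when `mask` and `attn_weights` live on a GPU, the subsequent `torch.where` would fail with a cross-device error. Placing `big_neg` on `mask.device` keeps all three operands together. A minimal sketch of the patched logic (the helper name `masked_fill_big_neg` is hypothetical, not part of the repo):

```python
import torch

def masked_fill_big_neg(attn_weights: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Replace masked-out attention weights with the dtype's most negative value.

    Creating big_neg on mask.device (the fix in this commit) avoids a
    device mismatch when the model runs on GPU, since a bare
    torch.tensor(...) call would otherwise allocate on CPU.
    """
    big_neg = torch.finfo(attn_weights.dtype).min
    big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32)
    # big_neg is a 0-dim tensor and broadcasts against attn_weights
    return torch.where(mask, attn_weights, big_neg)

# CPU-only check: positions where mask is False become the dtype minimum
w = torch.zeros(2, 3)
m = torch.tensor([[True, False, True], [False, True, True]])
out = masked_fill_big_neg(w, m)
```

On a CUDA machine the same call works with `w` and `m` moved to `"cuda"`, which is exactly the case the patch addresses.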
