We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 29efe34 + 4c286eb commit da51a44 (Copy full SHA for da51a44)
memory_efficient_attention/attention_torch.py
@@ -27,7 +27,7 @@ def summarize_chunk(key_idx, query, key, value, mask, bias):
27
mask = mask_calc_fn(query_idx, key_idx, mask, attn_weights, calc_fn_data)
28
if mask is not None:
29
big_neg = torch.finfo(attn_weights.dtype).min
30
- big_neg = torch.tensor(big_neg, dtype=torch.float32)
+ big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32)
31
mask = torch.einsum('...hqk->...qhk', mask)
32
attn_weights = torch.where(mask, attn_weights, big_neg)
33
if weights_calc_fn is not None:
0 commit comments