-
Notifications
You must be signed in to change notification settings - Fork 2
/
sub_quadratic_attention.py
214 lines (186 loc) · 7.14 KB
/
sub_quadratic_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT License (see LICENSE-sub-quadratic-attention.txt)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, List
def narrow_trunc(
input: Tensor,
dim: int,
start: int,
length: int
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> AttnChunk: ...
class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention(
query: Tensor,
key: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
) -> Tensor:
batch_x_heads, k_tokens, k_channels_per_head = key.shape
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
key_chunk = narrow_trunc(
key,
1,
chunk_idx,
kv_chunk_size
)
value_chunk = narrow_trunc(
value,
1,
chunk_idx,
kv_chunk_size
)
return summarize_chunk(query, key_chunk, value_chunk)
chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= torch.unsqueeze(max_diffs, -1)
chunk_weights *= max_diffs
all_values = chunk_values.sum(dim=0)
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
):
"""Computes efficient dot-product attention given query, key, and value.
This is efficient version of attention presented in
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
Args:
query: queries for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
key: keys for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
value: values to be used in attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
query_chunk_size: int: query chunks size
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
"""
batch_x_heads, q_tokens, q_channels_per_head = query.shape
_, k_tokens, _ = key.shape
scale = q_channels_per_head ** -0.5
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
def get_query_chunk(chunk_idx: int) -> Tensor:
return narrow_trunc(
query,
1,
chunk_idx,
min(query_chunk_size, q_tokens)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(
_query_chunk_attention,
kv_chunk_size=kv_chunk_size,
summarize_chunk=summarize_chunk,
)
)
if q_tokens <= query_chunk_size:
# fast-path for when there's just 1 query chunk
return compute_query_chunk_attn(
query=query,
key=key,
value=value,
)
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([
compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res