forked from ggerganov/whisper.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert-whisper-to-coreml.py
331 lines (253 loc) · 12.6 KB
/
convert-whisper-to-coreml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
import argparse
import torch
import torch.nn.functional as F
import coremltools as ct
from torch import Tensor
from torch import nn
from typing import Dict
from typing import Optional
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
from coremltools.models.neural_network.quantization_utils import quantize_weights
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model
# Use for changing dim of input in encoder and decoder embeddings
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """Load-state-dict pre-hook that reshapes ``nn.Linear`` weights for ``nn.Conv2d``.

    A tensor is reshaped when its key contains both ``'attn'`` and ``'.weight'``,
    or ends in ``'mlp.0.weight'`` / ``'mlp.2.weight'``, and it is 2-D. Two trailing
    singleton dims are appended, ``(out, in) -> (out, in, 1, 1)``, which is the
    weight shape of a 1x1 ``nn.Conv2d``.

    ``state_dict`` is modified in place and nothing is returned. The remaining
    arguments exist only to satisfy the pre-hook signature and are unused.
    """
    mlp_suffixes = ('mlp.0.weight', 'mlp.2.weight')
    for name in state_dict:
        looks_like_attention = 'attn' in name and '.weight' in name
        # Test the key first so .shape is only read for candidate entries.
        if not (looks_like_attention or name.endswith(mlp_suffixes)):
            continue
        weight = state_dict[name]
        if len(weight.shape) == 2:
            # Replacing the value of an existing key does not resize the
            # dict, so this is safe while iterating.
            state_dict[name] = weight[:, :, None, None]
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    """Load-state-dict pre-hook that divides this module's bias by its weight.

    Replaces ``state_dict[prefix + 'bias']`` with ``bias / weight`` in place and
    returns the same ``state_dict`` object. The other arguments belong to the
    pre-hook signature and are unused.

    Note: ``(x + b / w) * w == x * w + b``. Presumably the ANE layer norm adds
    its bias before scaling, which is what the function name suggests. TODO:
    confirm against ``ane_transformers``.
    """
    bias_key = f"{prefix}bias"
    weight_key = f"{prefix}weight"
    state_dict[bias_key] = state_dict[bias_key] / state_dict[weight_key]
    return state_dict
class LayerNormANE(LayerNormANEBase):
    """``ane_transformers`` LayerNormANE that adjusts its bias when loading weights.

    Registers ``correct_for_bias_scale_order_inversion`` as a load-state-dict
    pre-hook. The bias read from a checkpoint is therefore divided by the
    weight before it is assigned to this module.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        bias_fixup_hook = correct_for_bias_scale_order_inversion
        self._register_load_state_dict_pre_hook(bias_fixup_hook)
class MultiHeadAttentionANE(MultiHeadAttention):
    """Whisper multi-head attention with 1x1 ``nn.Conv2d`` projections.

    The query/key/value/out projections of the parent are replaced with 1x1
    convolutions. Activations are therefore expected in a 4-D
    ``(batch, channels, 1, seq)`` layout. Checkpoint ``nn.Linear`` weights must
    be reshaped on load; ``linear_to_conv2d_map`` does that.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__(n_state, n_head)
        # Swap the parent's Linear layers for equivalent 1x1 convolutions;
        # the key projection has no bias.
        self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
        self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.out = nn.Conv2d(n_state, n_state, kernel_size=1)

    def forward(self,
                x: Tensor,
                xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                kv_cache: Optional[dict] = None):
        """Attend from ``x`` to itself (``xa is None``) or to ``xa``.

        Returns ``(output, qk)``:
        - ``output`` is the ``out`` projection of the attention result.
        - ``qk`` is the detached float concatenation of the per-head scores
          (after masking, before softmax).
        """
        q = self.query(x)

        reuse_cached_kv = (
            kv_cache is not None and xa is not None and self.key in kv_cache
        )
        if reuse_cached_kv:
            # Cross-attention whose keys/values were already stored in the
            # cache on an earlier call: reuse them instead of recomputing.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # Self-attention, or the first cross-attention call. Any forward
            # hooks installed on key/value (e.g. kv-cache hooks) fire here.
            kv_source = x if xa is None else xa
            k = self.key(kv_source)
            v = self.value(kv_source)

        wv, qk = self.qkv_attention_ane(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """Scaled dot-product attention, computed head by head.

        ``q``, ``k`` and ``v`` must be 4-D; presumably ``(batch, dim, 1, seq)``
        (TODO confirm for every caller). Returns ``(attn, qk)``:
        - ``attn`` is the per-head outputs concatenated on dim 1.
        - ``qk`` is the per-head scores concatenated on dim 1, as a detached
          float tensor.
        """
        _, n_channels, _, q_len = q.size()
        head_dim = n_channels // self.n_head

        # Scale queries once by 1/sqrt(head_dim) instead of scaling each score.
        q_scaled = q * float(head_dim) ** -0.5
        q_heads = q_scaled.split(head_dim, dim=1)
        k_heads = k.transpose(1, 3).split(head_dim, dim=3)
        v_heads = v.split(head_dim, dim=1)

        # Per head: (batch, k_len, 1, q_len) -- keys on dim 1, queries on dim 3.
        scores = [
            torch.einsum('bchq,bkhc->bkhq', [q_head, k_head])
            for q_head, k_head in zip(q_heads, k_heads)
        ]

        if mask is not None:
            # Add the same sliced mask to the first n_head score maps.
            mask_slice = mask[:, :q_len, :, :q_len]
            scores[:self.n_head] = [s + mask_slice for s in scores[:self.n_head]]

        # Softmax over dim 1, i.e. over the keys.
        weights = [s.softmax(dim=1) for s in scores]
        # Per head: (batch, head_dim, 1, q_len)
        head_outputs = [
            torch.einsum('bkhq,bchk->bchq', w, v_head)
            for w, v_head in zip(weights, v_heads)
        ]
        attn = torch.cat(head_outputs, dim=1)  # (batch, dim, 1, q_len)
        return attn, torch.cat(scores, dim=1).float().detach()
class ResidualAttentionBlockANE(ResidualAttentionBlock):
    """Whisper residual attention block built from ANE-friendly sub-modules.

    After the parent constructor runs, its sub-modules are replaced:
    - attention layers become ``MultiHeadAttentionANE``;
    - layer norms become ``LayerNormANE``;
    - the MLP becomes a pair of 1x1 ``nn.Conv2d`` layers around a GELU.

    The forward pass is inherited from the parent unchanged.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__(n_state, n_head, cross_attention)

        self.attn = MultiHeadAttentionANE(n_state, n_head)
        self.attn_ln = LayerNormANE(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttentionANE(n_state, n_head)
            self.cross_attn_ln = LayerNormANE(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        # The two convs must stay at Sequential indices 0 and 2. Their
        # state-dict keys ('mlp.0.weight' / 'mlp.2.weight') are what
        # linear_to_conv2d_map matches.
        hidden = 4 * n_state
        self.mlp = nn.Sequential(
            nn.Conv2d(n_state, hidden, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(hidden, n_state, kernel_size=1),
        )
        self.mlp_ln = LayerNormANE(n_state)
class AudioEncoderANE(AudioEncoder):
    """Whisper audio encoder using ``ResidualAttentionBlockANE`` and ``LayerNormANE``.

    The parent's convolutional stem and positional embedding are reused. The
    transformer blocks and the final layer norm are replaced. Activations
    between blocks carry an extra singleton dim at index 2.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
        ane_blocks = [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(ane_blocks)
        self.ln_post = LayerNormANE(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_frames)
            the mel spectrogram of the audio

        Returns the encoded features with a singleton dim at index 2,
        presumably (batch_size, n_state, 1, n_ctx).
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"

        # The assert above shows the positional embedding is stored transposed
        # relative to x, so transpose before adding. Then insert the dummy
        # dim 2 expected by the ANE blocks.
        pos = self.positional_embedding.transpose(0, 1)
        x = (x + pos).to(x.dtype).unsqueeze(2)

        for block in self.blocks:
            x = block(x)
        x = self.ln_post(x)

        # TODO: whisper.cpp may expect (batch_size, n_ctx, n_state) memory order.
        # The original author tried `x = x.transpose(1, 3)` here. Results were
        # less wrong but still wrong, and it is unclear why the OpenAI
        # implementation needs no transpose. The output is left untransposed.
        return x
class TextDecoderANE(TextDecoder):
    """Whisper text decoder using ``ResidualAttentionBlockANE`` and ``LayerNormANE``.

    Token and positional embeddings come from the parent. Activations are
    moved to ``(batch, n_state, 1, n_tokens)`` for the transformer blocks and
    back to ``(batch, n_tokens, n_state)`` for the output projection.
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNormANE(n_state)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_state, 1, n_audio_ctx)
            the encoded audio features to be attended on, in the 4-D layout
            used by the ANE blocks (convert_decoder traces with this shape)
        kv_cache : optional dict filled by WhisperANE.install_kv_cache_hooks.
            Cached tensors are concatenated on dim 3, so dim 3 is the number
            of tokens already processed.

        Returns logits of shape (batch_size, n_tokens, n_vocab).
        """
        offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        # Reformat for ANE: (batch, n_tokens, n_state) -> (batch, n_state, 1, n_tokens).
        # permute(0, 3, 1, 2) moves the mask's last axis to dim 1. Presumably
        # this lines it up with the key axis of the (batch, k_len, 1, q_len)
        # scores in MultiHeadAttentionANE.
        mask = self.mask[None, None, :, :].permute(0, 3, 1, 2)
        x = x.transpose(1, 2).unsqueeze(2)

        for block in self.blocks:
            x = block(x, xa, mask=mask, kv_cache=kv_cache)
        x = self.ln(x)

        # Reformat back from ANE: (batch, n_state, 1, n_tokens) -> (batch, n_tokens, n_state).
        # Drop the dummy dim explicitly. The previous permute(...).squeeze(0)
        # only produced a 3-D tensor when batch_size == 1.
        x = x.squeeze(2).transpose(1, 2)

        # ANE can only load tensors with dim size of at most 16,384 - whisper uses
        # 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks.
        weight = self.token_embedding.weight
        n_vocab = weight.shape[0]
        if n_vocab >= 51865:
            n_chunks = 11  # 4715 rows per chunk for 51,865 tokens
        else:
            assert n_vocab == 51864
            n_chunks = 12  # 4322 rows per chunk
        splits = weight.split(n_vocab // n_chunks, dim=0)

        # Concatenate along the vocabulary axis. The previous default dim=0
        # followed by .view() interleaved chunks across tokens whenever
        # batch * n_tokens > 1, and failed outright on unequal chunk sizes.
        logits = torch.cat(
            [torch.einsum('bid,jd->bij', x, split) for split in splits], dim=-1
        )
        return logits
class WhisperANE(Whisper):
    """Whisper model whose encoder and decoder are ``AudioEncoderANE`` / ``TextDecoderANE``.

    A load-state-dict pre-hook (``linear_to_conv2d_map``) reshapes checkpoint
    ``nn.Linear`` weights into 1x1 conv weights. This lets a state dict from a
    stock Whisper model be loaded directly.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)
        self.encoder = AudioEncoderANE(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoderANE(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        """Encode ``mel`` and return the decoder's logits for ``tokens``.

        The return annotation previously claimed ``Dict[str, Tensor]``. The
        method actually returns the tensor produced by ``self.decoder``.
        """
        return self.decoder(tokens, self.encoder(mel))

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """Hook every decoder ``MultiHeadAttentionANE`` key/value projection into a cache.

        Args:
            cache: optional initial cache. It is shallow-copied, not mutated.

        Returns:
            ``(cache, hooks)``:
            - ``cache`` maps each hooked key/value module to its accumulated
              output, concatenated on dim 3.
            - ``hooks`` are the handles returned by ``register_forward_hook``.
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # Outputs longer than the text positional embedding (cross
            # attention), and the first output of each module, are stored
            # as-is. Later outputs are appended along dim 3.
            if module not in cache or output.shape[3] > self.decoder.positional_embedding.shape[0]:
                cache[module] = output  # save as-is, for the first token or cross attention
            else:
                cache[module] = torch.cat([cache[module], output], dim=3).detach()
            # Returning a value from a forward hook replaces the module's output.
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttentionANE):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        # Module.apply visits every submodule of the decoder recursively.
        self.decoder.apply(install_hooks)
        return cache, hooks
def convert_encoder(hparams, model, quantize=False):
    """Trace an audio encoder with TorchScript and convert it to Core ML.

    Args:
        hparams: model dimensions. ``hparams.n_mels`` sets the mel-bin count of
            the example input; it was previously hard-coded to 80.
        model: encoder module. It is switched to eval mode in place.
        quantize: if True, convert to the ``neuralnetwork`` format and quantize
            the weights to 16 bits with ``quantize_weights``. Otherwise
            convert to ``mlprogram``.

    Returns:
        The converted coremltools model.
    """
    model.eval()

    # 3000 mel frames: the fixed input length this script traces with.
    input_shape = (1, hparams.n_mels, 3000)
    input_data = torch.randn(input_shape)
    traced_model = torch.jit.trace(model, input_data)

    # quantize_weights (coremltools.models.neural_network.quantization_utils)
    # only handles neuralnetwork models. Request that format explicitly rather
    # than relying on convert_to's default, which newer coremltools releases
    # changed to "mlprogram".
    model = ct.convert(
        traced_model,
        convert_to="neuralnetwork" if quantize else "mlprogram",
        inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
        outputs=[ct.TensorType(name="output")],
        compute_units=ct.ComputeUnit.ALL
    )

    if quantize:
        model = quantize_weights(model, nbits=16)

    return model
def convert_decoder(hparams, model, quantize=False):
    """Trace a text decoder with TorchScript and convert it to Core ML.

    Args:
        hparams: model dimensions. ``n_audio_state`` and ``n_audio_ctx`` size
            the example encoder features.
        model: a ``TextDecoderANE`` or the stock whisper decoder. It is
            switched to eval mode in place.
        quantize: if True, convert to the ``neuralnetwork`` format and quantize
            the weights to 16 bits with ``quantize_weights``. Otherwise
            convert to ``mlprogram``.

    Returns:
        The converted coremltools model.
    """
    model.eval()

    tokens_shape = (1, 1)
    if isinstance(model, TextDecoderANE):
        # The ANE decoder consumes encoder features as (batch, n_state, 1, n_ctx).
        audio_shape = (1, hparams.n_audio_state, 1, hparams.n_audio_ctx)
    else:
        # The stock whisper decoder's Linear cross-attention projects the last
        # axis, so it needs (batch, n_ctx, n_state). Tracing it with the ANE
        # layout failed.
        audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)

    audio_data = torch.randn(audio_shape)
    token_data = torch.randint(50257, tokens_shape).long()
    traced_model = torch.jit.trace(model, (token_data, audio_data))

    # See convert_encoder: quantize_weights only supports neuralnetwork models,
    # so that format is requested explicitly when quantizing.
    model = ct.convert(
        traced_model,
        convert_to="neuralnetwork" if quantize else "mlprogram",
        inputs=[
            ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
            ct.TensorType(name="audio_data", shape=audio_shape)
        ]
    )

    if quantize:
        model = quantize_weights(model, nbits=16)

    return model
def parse_bool_arg(value):
    """Parse a command-line boolean such as ``True``/``False``, ``yes``/``no`` or ``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string, including ``"False"``. This converter is used
    instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import os

    supported_models = [
        "tiny", "tiny.en", "base", "base.en", "small", "small.en",
        "medium", "medium.en", "large", "large-v1", "large-v2",
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
    # Boolean flags accept either a bare flag (--quantize) or an explicit value
    # (--quantize True / --quantize False).
    parser.add_argument("--encoder-only", type=parse_bool_arg, nargs="?", const=True, help="only convert encoder", default=False)
    parser.add_argument("--quantize", type=parse_bool_arg, nargs="?", const=True, help="quantize weights to F16", default=False)
    parser.add_argument("--optimize-ane", type=parse_bool_arg, nargs="?", const=True, help="optimize for ANE execution (currently broken)", default=False)
    args = parser.parse_args()

    if args.model not in supported_models:
        raise ValueError("Invalid model name")

    whisper = load_model(args.model).cpu()
    hparams = whisper.dims
    print(hparams)

    if args.optimize_ane:
        whisperANE = WhisperANE(hparams).eval()
        whisperANE.load_state_dict(whisper.state_dict())
        encoder = whisperANE.encoder
        decoder = whisperANE.decoder
    else:
        encoder = whisper.encoder
        decoder = whisper.decoder

    # Make sure the output directory exists before saving.
    os.makedirs("models", exist_ok=True)

    # Convert encoder
    encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
    encoder.save(f"models/coreml-encoder-{args.model}.mlpackage")

    if not args.encoder_only:
        # Convert decoder
        decoder = convert_decoder(hparams, decoder, quantize=args.quantize)
        decoder.save(f"models/coreml-decoder-{args.model}.mlpackage")

    print("done converting")