from typing import Optional, Tuple, Union
import gc
import openvino as ov
+ from openvino.runtime import opset13
+ import nncf
+ import numpy as np
import torch
- from transformers import AutoModelForCausalLM
- from transformers import AutoProcessor
+ from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig
from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
- from transformers import AutoConfig
-
- from typing import Optional, Tuple, List
- from openvino.runtime import opset13
- import numpy as np
- import nncf


def model_has_state(ov_model: ov.Model):
@@ -203,7 +199,14 @@ def convert_phi3_model(model_id, output_dir, quantization_config):
    img_projection_path = output_dir / "img_projection.xml"
    embed_token_path = output_dir / "embed_token.xml"

-     if all([lang_model_path.exists(), image_embed_path.exists(), img_projection_path.exists(), embed_token_path.exists()]):
+     if all(
+         [
+             lang_model_path.exists(),
+             image_embed_path.exists(),
+             img_projection_path.exists(),
+             embed_token_path.exists(),
+         ]
+     ):
        print(f"✅ Phi-3-vision model already converted. You can find results in {output_dir}")
        return
    print("⌛ Phi-3-vision conversion started. Be patient, it may takes some time.")
@@ -216,7 +219,10 @@ def convert_phi3_model(model_id, output_dir, quantization_config):

    if not embed_token_path.exists():
        print("⌛ Convert Input embedding model")
-         ov_model = ov.convert_model(model.model.embed_tokens, example_input=torch.ones([2, 2], dtype=torch.int64))
+         ov_model = ov.convert_model(
+             model.model.embed_tokens,
+             example_input=torch.ones([2, 2], dtype=torch.int64),
+         )
        ov.save_model(ov_model, embed_token_path)
        del ov_model
        cleanup_torchscript_cache()
@@ -236,7 +242,10 @@ def convert_phi3_model(model_id, output_dir, quantization_config):

    if not img_projection_path.exists():
        print("⌛ Convert Image projection model")
-         ov_model = ov.convert_model(vision_embed_tokens.img_projection, example_input=torch.ones([1, 1921, 4096]))
+         ov_model = ov.convert_model(
+             vision_embed_tokens.img_projection,
+             example_input=torch.ones([1, 1921, 4096]),
+         )
        ov.save_model(ov_model, img_projection_path)
        del ov_model
        cleanup_torchscript_cache()
@@ -246,16 +255,29 @@ def convert_phi3_model(model_id, output_dir, quantization_config):
    if not lang_model_path.exists():
        print("⌛ Convert Language model")

-         def forward_wrap(self, attention_mask, position_ids=None, past_key_values=None, inputs_embeds=None):
+         def forward_wrap(
+             self,
+             attention_mask,
+             position_ids=None,
+             past_key_values=None,
+             inputs_embeds=None,
+         ):
            result = self._orig_forward(
-                 input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds
+                 input_ids=None,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_values=past_key_values,
+                 inputs_embeds=inputs_embeds,
            )
            return tuple(result.values())

        model._orig_forward = model.forward
        model.forward = types.MethodType(forward_wrap, model)
        llm_input = torch.zeros([2, 2, 3072])
-         pkv = model(inputs_embeds=llm_input, attention_mask=torch.ones((2, 2), dtype=torch.int64))[1]
+         pkv = model(
+             inputs_embeds=llm_input,
+             attention_mask=torch.ones((2, 2), dtype=torch.int64),
+         )[1]
        model_inputs = ["attention_mask", "position_ids"]
        model_outputs = ["logits"]
        for idx in range(len(pkv)):
@@ -393,7 +415,14 @@ def _get_past_length(self, past_key_values=None):
        return self._past_length

    def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         pixel_values=None,
+         image_sizes=None,
+         **kwargs,
    ):
        if past_key_values is not None:
            past_length = self._get_past_length(past_key_values)
@@ -435,7 +464,12 @@ def prepare_inputs_for_generation(
        )
        return model_inputs

-     def vision_embed_tokens(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor:
+     def vision_embed_tokens(
+         self,
+         input_ids: torch.LongTensor,
+         pixel_values: torch.FloatTensor,
+         image_sizes=None,
+     ) -> torch.FloatTensor:
        MAX_INPUT_ID = int(1e9)
        img_embeds = pixel_values
        img_sizes = image_sizes