from typing import Optional, Tuple, Union
import gc
import openvino as ov
+from openvino.runtime import opset13
+import nncf
+import numpy as np
import torch
-from transformers import AutoModelForCausalLM
-from transformers import AutoProcessor
+from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig
from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
-from transformers import AutoConfig
-
-from typing import Optional, Tuple, List
-from openvino.runtime import opset13
-import numpy as np
-import nncf


def model_has_state(ov_model: ov.Model):
@@ -203,7 +199,14 @@ def convert_phi3_model(model_id, output_dir, quantization_config):
    img_projection_path = output_dir / "img_projection.xml"
    embed_token_path = output_dir / "embed_token.xml"

-    if all([lang_model_path.exists(), image_embed_path.exists(), img_projection_path.exists(), embed_token_path.exists()]):
+    if all(
+        [
+            lang_model_path.exists(),
+            image_embed_path.exists(),
+            img_projection_path.exists(),
+            embed_token_path.exists(),
+        ]
+    ):
        print(f"✅ Phi-3-vision model already converted. You can find results in {output_dir}")
        return
    print("⌛ Phi-3-vision conversion started. Be patient, it may take some time.")
@@ -216,7 +219,10 @@ def convert_phi3_model(model_id, output_dir, quantization_config):

    if not embed_token_path.exists():
        print("⌛ Convert Input embedding model")
-        ov_model = ov.convert_model(model.model.embed_tokens, example_input=torch.ones([2, 2], dtype=torch.int64))
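+        # trace the token embedding layer with a dummy [batch, seq_len] int64 input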
+        ov_model = ov.convert_model(
+            model.model.embed_tokens,
+            example_input=torch.ones([2, 2], dtype=torch.int64),
+        )
        ov.save_model(ov_model, embed_token_path)
        del ov_model
        cleanup_torchscript_cache()
@@ -236,7 +242,10 @@ def convert_phi3_model(model_id, output_dir, quantization_config):

    if not img_projection_path.exists():
        print("⌛ Convert Image projection model")
-        ov_model = ov.convert_model(vision_embed_tokens.img_projection, example_input=torch.ones([1, 1921, 4096]))
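+        # trace the image projection module with a dummy feature tensor of shape [1, 1921, 4096]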
+        ov_model = ov.convert_model(
+            vision_embed_tokens.img_projection,
+            example_input=torch.ones([1, 1921, 4096]),
+        )
        ov.save_model(ov_model, img_projection_path)
        del ov_model
        cleanup_torchscript_cache()
@@ -246,16 +255,29 @@ def convert_phi3_model(model_id, output_dir, quantization_config):
    if not lang_model_path.exists():
        print("⌛ Convert Language model")

-        def forward_wrap(self, attention_mask, position_ids=None, past_key_values=None, inputs_embeds=None):
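+        # wrapper that drops input_ids (the model is traced from inputs_embeds) and returns outputs as a plain tuple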
+        def forward_wrap(
+            self,
+            attention_mask,
+            position_ids=None,
+            past_key_values=None,
+            inputs_embeds=None,
+        ):
            result = self._orig_forward(
-                input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds
+                input_ids=None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
            )
            return tuple(result.values())

        model._orig_forward = model.forward
        model.forward = types.MethodType(forward_wrap, model)
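+        # run one dummy forward pass to obtain example past_key_values for registering the KV-cache inputs/outputs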
        llm_input = torch.zeros([2, 2, 3072])
-        pkv = model(inputs_embeds=llm_input, attention_mask=torch.ones((2, 2), dtype=torch.int64))[1]
+        pkv = model(
+            inputs_embeds=llm_input,
+            attention_mask=torch.ones((2, 2), dtype=torch.int64),
+        )[1]
        model_inputs = ["attention_mask", "position_ids"]
        model_outputs = ["logits"]
        for idx in range(len(pkv)):
@@ -393,7 +415,14 @@ def _get_past_length(self, past_key_values=None):
        return self._past_length

    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        pixel_values=None,
+        image_sizes=None,
+        **kwargs,
    ):
        if past_key_values is not None:
            past_length = self._get_past_length(past_key_values)
@@ -435,7 +464,12 @@ def prepare_inputs_for_generation(
        )
        return model_inputs

-    def vision_embed_tokens(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor:
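+    # combines text token embeddings with projected image features into a single inputs_embeds tensor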
+    def vision_embed_tokens(
+        self,
+        input_ids: torch.LongTensor,
+        pixel_values: torch.FloatTensor,
+        image_sizes=None,
+    ) -> torch.FloatTensor:
        MAX_INPUT_ID = int(1e9)
        img_embeds = pixel_values
        img_sizes = image_sizes