114114]
115115
116116
def fit_bf16_to_uint16_np(tensor):
    """Convert a numpy bfloat16 array to uint16 when running on an XPU device.

    Paddle represents bfloat16 storage as uint16 in several code paths, and
    ``paddle.Tensor.__call__(..., zero_copy=True)`` on XPU apparently cannot
    ingest a numpy array whose dtype string is ``"bfloat16"`` directly
    (TODO confirm against the XPU runtime).  Non-XPU devices, non-ndarray
    inputs, and all other dtypes pass through untouched.

    NOTE(review): ``astype("uint16")`` performs a *value* conversion, not a
    bitcast/view — confirm this matches paddle's uint16-as-bfloat16 storage
    convention before relying on it for arbitrary values.

    Args:
        tensor: typically a ``np.ndarray`` weight slice, but any object is
            accepted and returned unchanged when the conversion predicate
            does not hold.

    Returns:
        The converted ``np.ndarray`` (dtype uint16) when on XPU and the
        input is a bfloat16 ndarray; otherwise the input object unchanged.
    """
    # Cheap guard first: device query decides whether conversion is needed at all.
    if "xpu" not in paddle.device.get_device():
        return tensor
    needs_cast = isinstance(tensor, np.ndarray) and str(tensor.dtype) == "bfloat16"
    return tensor.astype("uint16") if needs_cast else tensor
121+
122+
117123def dy2st_nocheck_guard_context ():
118124 try :
119125 context = paddle .framework ._no_check_dy2st_diff ()
@@ -433,7 +439,7 @@ def _transpose_hf_weight(key, weight):
433439 and not key .endswith ("_scale" )
434440 ):
435441 # numpy.array -> paddle.tensor
436- weight = paddle .Tensor .__call__ (py_safe_slice_ [:], zero_copy = True )
442+ weight = paddle .Tensor .__call__ (fit_bf16_to_uint16_np ( py_safe_slice_ [:]) , zero_copy = True )
437443 weight = _transpose_hf_weight (key , weight )
438444 key_name = key .split (".weight" )[0 ]
439445 quant_key_name = key_name + ".quant_weight"
@@ -478,7 +484,7 @@ def _transpose_hf_weight(key, weight):
478484 weight = py_safe_slice_ [:]
479485 if not return_numpy and device == "expected" :
480486 with device_guard ():
481- weight = paddle .Tensor .__call__ (weight , zero_copy = True )
487+ weight = paddle .Tensor .__call__ (fit_bf16_to_uint16_np ( weight ) , zero_copy = True )
482488 weight = weight ._copy_to (paddle .framework ._current_expected_place (), False )
483489 weight = _transpose_hf_weight (key , weight )
484490 part_state_dict [key ] = weight
@@ -492,7 +498,7 @@ def _transpose_hf_weight(key, weight):
492498 scale = f .get_tensor (key )
493499 if not return_numpy and device == "expected" :
494500 with device_guard ():
495- scale = paddle .Tensor .__call__ (scale , zero_copy = True )
501+ scale = paddle .Tensor .__call__ (fit_bf16_to_uint16_np ( scale ) , zero_copy = True )
496502 scale = scale ._copy_to (paddle .framework ._current_expected_place (), False )
497503 scale_dict [key ] = scale
498504 return part_state_dict , scale_dict
@@ -583,10 +589,14 @@ def load_state_dict(
583589 if device == "cpu" :
584590 with device_guard ():
585591 for k in list (state_dict .keys ()):
586- state_dict [k ] = paddle .Tensor .__call__ (state_dict .pop (k ), zero_copy = True )
592+ state_dict [k ] = paddle .Tensor .__call__ (
593+ fit_bf16_to_uint16_np (state_dict .pop (k )), zero_copy = True
594+ )
587595 elif device == "pin_memory" :
588596 for k in list (state_dict .keys ()):
589- state_dict [k ] = paddle .to_tensor (state_dict .pop (k ), place = paddle .CUDAPinnedPlace ())
597+ state_dict [k ] = paddle .to_tensor (
598+ fit_bf16_to_uint16_np (state_dict .pop (k )), place = paddle .CUDAPinnedPlace ()
599+ )
590600
591601 if len (scale_dict ) != 0 :
592602 if ckpt_quant_stage == "O0" :
@@ -2784,7 +2794,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
27842794 for k in list (state_dict .keys ()):
27852795 if not isinstance (state_dict [k ], paddle .Tensor ):
27862796 with device_guard ():
2787- state_dict [k ] = paddle .Tensor .__call__ (state_dict .pop (k ), zero_copy = True )
2797+ state_dict [k ] = paddle .Tensor .__call__ (
2798+ fit_bf16_to_uint16_np (state_dict .pop (k )), zero_copy = True
2799+ )
27882800 else :
27892801 if is_sharded :
27902802 loaded_state_dict_keys = sharded_metadata ["all_checkpoint_keys" ]
@@ -2799,7 +2811,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
27992811 for k in list (state_dict .keys ()):
28002812 if not isinstance (state_dict [k ], paddle .Tensor ):
28012813 with device_guard ():
2802- state_dict [k ] = paddle .Tensor .__call__ (state_dict .pop (k ), zero_copy = True )
2814+ state_dict [k ] = paddle .Tensor .__call__ (
2815+ fit_bf16_to_uint16_np (state_dict .pop (k )), zero_copy = True
2816+ )
28032817 # 3. init the model
28042818 init_args = config ["init_args" ] or ()
28052819 with ContextManagers (init_contexts ):
@@ -3380,7 +3394,9 @@ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
33803394 if not return_numpy :
33813395 for key in list (state_dict .keys ()):
33823396 if isinstance (state_dict [key ], np .ndarray ):
3383- state_dict [key ] = paddle .Tensor .__call__ (state_dict .pop (key ), zero_copy = True )
3397+ state_dict [key ] = paddle .Tensor .__call__ (
3398+ fit_bf16_to_uint16_np (state_dict .pop (key )), zero_copy = True
3399+ )
33843400 return state_dict
33853401
33863402 index_file = os .path .join (folder , _add_variant (PADDLE_WEIGHTS_INDEX_NAME , variant ))
@@ -3423,7 +3439,7 @@ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
34233439 if not return_numpy :
34243440 for key in list (ret .keys ()):
34253441 if isinstance (ret [key ], np .ndarray ):
3426- ret [key ] = paddle .Tensor .__call__ (ret .pop (key ), zero_copy = True )
3442+ ret [key ] = paddle .Tensor .__call__ (fit_bf16_to_uint16_np ( ret .pop (key ) ), zero_copy = True )
34273443
34283444 return ret
34293445
0 commit comments