Commit d9e0dbc

fit xpu not support bf16 (#2809)
Co-authored-by: llbdyiu66 <[email protected]>
1 parent: 569640b

1 file changed: +25 -9 lines changed
paddleformers/transformers/model_utils.py

Lines changed: 25 additions & 9 deletions
@@ -114,6 +114,12 @@
 ]
 
 
+def fit_bf16_to_uint16_np(tensor):
+    if "xpu" in paddle.device.get_device() and isinstance(tensor, np.ndarray) and str(tensor.dtype) == "bfloat16":
+        return tensor.astype("uint16")
+    return tensor
+
+
 def dy2st_nocheck_guard_context():
     try:
         context = paddle.framework._no_check_dy2st_diff()
@@ -433,7 +439,7 @@ def _transpose_hf_weight(key, weight):
                 and not key.endswith("_scale")
             ):
                 # numpy.array -> paddle.tensor
-                weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True)
+                weight = paddle.Tensor.__call__(fit_bf16_to_uint16_np(py_safe_slice_[:]), zero_copy=True)
                 weight = _transpose_hf_weight(key, weight)
                 key_name = key.split(".weight")[0]
                 quant_key_name = key_name + ".quant_weight"
@@ -478,7 +484,7 @@ def _transpose_hf_weight(key, weight):
                 weight = py_safe_slice_[:]
                 if not return_numpy and device == "expected":
                     with device_guard():
-                        weight = paddle.Tensor.__call__(weight, zero_copy=True)
+                        weight = paddle.Tensor.__call__(fit_bf16_to_uint16_np(weight), zero_copy=True)
                     weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                 weight = _transpose_hf_weight(key, weight)
             part_state_dict[key] = weight
@@ -492,7 +498,7 @@ def _transpose_hf_weight(key, weight):
             scale = f.get_tensor(key)
             if not return_numpy and device == "expected":
                 with device_guard():
-                    scale = paddle.Tensor.__call__(scale, zero_copy=True)
+                    scale = paddle.Tensor.__call__(fit_bf16_to_uint16_np(scale), zero_copy=True)
                 scale = scale._copy_to(paddle.framework._current_expected_place(), False)
             scale_dict[key] = scale
     return part_state_dict, scale_dict
@@ -583,10 +589,14 @@ def load_state_dict(
     if device == "cpu":
         with device_guard():
             for k in list(state_dict.keys()):
-                state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
+                state_dict[k] = paddle.Tensor.__call__(
+                    fit_bf16_to_uint16_np(state_dict.pop(k)), zero_copy=True
+                )
     elif device == "pin_memory":
         for k in list(state_dict.keys()):
-            state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace())
+            state_dict[k] = paddle.to_tensor(
+                fit_bf16_to_uint16_np(state_dict.pop(k)), place=paddle.CUDAPinnedPlace()
+            )
 
     if len(scale_dict) != 0:
         if ckpt_quant_stage == "O0":
@@ -2784,7 +2794,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             for k in list(state_dict.keys()):
                 if not isinstance(state_dict[k], paddle.Tensor):
                     with device_guard():
-                        state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
+                        state_dict[k] = paddle.Tensor.__call__(
+                            fit_bf16_to_uint16_np(state_dict.pop(k)), zero_copy=True
+                        )
         else:
             if is_sharded:
                 loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
@@ -2799,7 +2811,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                 for k in list(state_dict.keys()):
                     if not isinstance(state_dict[k], paddle.Tensor):
                         with device_guard():
-                            state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
+                            state_dict[k] = paddle.Tensor.__call__(
+                                fit_bf16_to_uint16_np(state_dict.pop(k)), zero_copy=True
+                            )
         # 3. init the model
         init_args = config["init_args"] or ()
         with ContextManagers(init_contexts):
@@ -3380,7 +3394,9 @@ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
         if not return_numpy:
            for key in list(state_dict.keys()):
                if isinstance(state_dict[key], np.ndarray):
-                    state_dict[key] = paddle.Tensor.__call__(state_dict.pop(key), zero_copy=True)
+                    state_dict[key] = paddle.Tensor.__call__(
+                        fit_bf16_to_uint16_np(state_dict.pop(key)), zero_copy=True
+                    )
        return state_dict
 
    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
@@ -3423,7 +3439,7 @@ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
    if not return_numpy:
        for key in list(ret.keys()):
            if isinstance(ret[key], np.ndarray):
-                ret[key] = paddle.Tensor.__call__(ret.pop(key), zero_copy=True)
+                ret[key] = paddle.Tensor.__call__(fit_bf16_to_uint16_np(ret.pop(key)), zero_copy=True)
 
    return ret
 

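For reference, here is the pattern the commit applies at each tensor-construction call site, sketched as a self-contained snippet. The helper definition is copied from the diff; the float32 placeholder array, the variable names, and the explicit paddle.set_device("cpu") call (standing in for the repo's internal device_guard() context) are illustrative assumptions, not part of the commit.

import numpy as np
import paddle

# Helper added by this commit: on XPU devices, a numpy array carrying a
# bfloat16 dtype is cast to uint16 (the carrier dtype Paddle uses for
# bfloat16 data on the numpy side) before a paddle.Tensor is built from it.
# On other devices, or for other dtypes, the array passes through unchanged.
def fit_bf16_to_uint16_np(tensor):
    if "xpu" in paddle.device.get_device() and isinstance(tensor, np.ndarray) and str(tensor.dtype) == "bfloat16":
        return tensor.astype("uint16")
    return tensor

# Illustrative usage mirroring the load_state_dict call sites in the diff.
# A real bfloat16 numpy array only shows up when a bf16 safetensors
# checkpoint is read; this float32 placeholder is returned untouched.
paddle.set_device("cpu")  # stand-in for the repo's device_guard() context
weight_np = np.zeros((4, 4), dtype=np.float32)
weight = paddle.Tensor.__call__(fit_bf16_to_uint16_np(weight_np), zero_copy=True)

Because the helper is a no-op everywhere except for bfloat16 numpy input on XPU, the existing CPU and GPU load paths behave exactly as before.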