diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index f6cfa1f79d6..8a107bff62c 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -204,12 +204,15 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
 
 
 def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
-
-    invalidInputError(tensor.dtype == torch.uint8,
-                      "Input tensor must be uint8")
+    if qtype == NF4:
+        invalidInputError(tensor.dtype == torch.bfloat16,
+                          "NF4 Input tensor must be bfloat16")
+    else:
+        invalidInputError(tensor.dtype == torch.uint8,
+                          "Input tensor must be uint8")
     invalidInputError(tensor.device == torch.device('cpu'),
-                      "Input tensor must be uint8")
+                      "Input tensor must be on cpu")
     src = ctypes.c_void_p(tensor.data.data_ptr())
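For context, here is a minimal, self-contained sketch of the dtype dispatch the patch introduces. The `NF4` value and the `invalidInputError` helper below are simplified placeholders standing in for ipex_llm's actual constant and error utility, and `check_xpu2cpu_input` is a hypothetical name, not the library's API:

```python
import torch

# Hypothetical placeholder for ipex_llm's NF4 qtype id (not the real value).
NF4 = 10


def invalidInputError(condition, err_msg):
    # Simplified stand-in for ipex_llm's invalidInputError helper.
    if not condition:
        raise ValueError(err_msg)


def check_xpu2cpu_input(tensor, qtype):
    # NF4 weights coming back from XPU are held as bfloat16; every other
    # qtype is packed into a uint8 buffer, so the dtype check must branch.
    if qtype == NF4:
        invalidInputError(tensor.dtype == torch.bfloat16,
                          "NF4 Input tensor must be bfloat16")
    else:
        invalidInputError(tensor.dtype == torch.uint8,
                          "Input tensor must be uint8")
    invalidInputError(tensor.device == torch.device('cpu'),
                      "Input tensor must be on cpu")


check_xpu2cpu_input(torch.zeros(8, dtype=torch.bfloat16), NF4)  # passes
check_xpu2cpu_input(torch.zeros(8, dtype=torch.uint8), 0)       # passes
```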