diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index a9dad42ed33..3aa99e2eac3 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -47,7 +47,7 @@
 import torch
 import torch.distributed
 import torch.nn.functional as F
-from torch import Tensor, device, dtype, nn
+from torch import Tensor, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
@@ -294,10 +294,10 @@
 def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
     if hard_condition:
         return (
             batch_size > 1
-            or (device in ["arc"] and qtype in [SYM_INT8, FP4])
-            or (device in ["arc", "mtl"] and qtype in [FP8E4])
-            or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
-            or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5])
+            or (device_name in ["arc"] and qtype in [SYM_INT8, FP4])
+            or (device_name in ["arc", "mtl"] and qtype in [FP8E4])
+            or (device_name in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
+            or (device_name in ["bmg"] and qtype in [SYM_INT4, FP8E5])
         )
     return False
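
A note on why the two hunks belong together: with `device` in the `from torch import ...` list, any reference to `device` inside `use_batch_forward` that misses its local string binding silently resolves to the `torch.device` type instead of raising a `NameError`, and a string membership test against it is always `False`. Renaming the local to `device_name` removes that shadowing hazard and lets the now-unused import be dropped. A minimal sketch of the failure mode (not the ipex-llm code; the string values stand in for whatever device-name helper the module uses):

```python
from torch import device  # the old import: `device` here is torch.device, a type

# Without a local string binding, the membership test is silently False
# instead of an error, so the whole batch-forward branch goes dead:
assert (device in ["arc", "mtl", "lnl", "bmg"]) is False

# The renamed variable makes the lookup unambiguous; "arc" stands in for
# the value a device-name helper would return on an Arc GPU (assumed):
device_name = "arc"
assert device_name in ["arc", "mtl", "lnl", "bmg"]
```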