From 69f13c78b8d078b9e0beadcda5e3583a8db65fed Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:25:19 +0800 Subject: [PATCH] [NPU] Update layernorm node on MTL/ARL (#12738) * Update layernorm node on MTL/ARL * Fix on style --- .../src/ipex_llm/transformers/npu_models/mp_models_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 27e4469fd99..ad687741c34 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -472,7 +472,9 @@ def layer_norm(self, hidden_states, layernorm_weight): ) eps = self.constant(self.rms_norm_eps) hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) - if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"]: + if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"] or \ + os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or \ + os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1": # to support special drivers hidden_states = self.convert_to_fp16(hidden_states) else: