From 43cca3be279f457959cc237999e04386713635c3 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Wed, 14 Aug 2024 05:43:33 +0300 Subject: [PATCH] fix gemma2 runtime error caused by sliding window (#11788) * fix runtime error * revert workflow --- .github/workflows/llm_performance_tests.yml | 55 +++++++++---------- .../ipex_llm/transformers/models/gemma2.py | 14 +++-- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/.github/workflows/llm_performance_tests.yml b/.github/workflows/llm_performance_tests.yml index 53cb9c642c6..f48fb742afe 100644 --- a/.github/workflows/llm_performance_tests.yml +++ b/.github/workflows/llm_performance_tests.yml @@ -1207,35 +1207,32 @@ jobs: call conda deactivate - # NOTE: Gemma2 not working for 4096-512. - # When it works, uncomment this section and remember to change "'s/{today}_test3/{today}_test1/g'" in next section. + - name: Prepare igpu perf test for transformers 4.43 (4096-512 int4+fp16) + shell: bash + run: | + sed -i 's/{today}_test3/{today}_test4/g' python/llm/dev/benchmark/all-in-one/run.py + sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/4096-512_int4_fp16_443.yaml + + - name: Test on igpu for transformers 4.43 (4096-512 int4+fp16) + shell: cmd + run: | + call conda activate igpu-perf + pip install transformers==4.43.1 + pip install trl - #- name: Prepare igpu perf test for transformers 4.43 (4096-512 int4+fp16) - # shell: bash - # run: | - # sed -i 's/{today}_test3/{today}_test4/g' python/llm/dev/benchmark/all-in-one/run.py - # sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/4096-512_int4_fp16_443.yaml - - #- name: Test on igpu for transformers 4.43 (4096-512 int4+fp16) - # shell: cmd - # run: | - # call conda activate igpu-perf - # pip install transformers==4.43.1 - # pip install trl - # - # set SYCL_CACHE_PERSISTENT=1 - # set BIGDL_LLM_XMX_DISABLED=1 - # - # cd python\llm\dev\benchmark\all-in-one - # move ..\..\..\test\benchmark\igpu-perf\4096-512_int4_fp16_443.yaml config.yaml - # set PYTHONIOENCODING=utf-8 - # python run.py >> %CSV_SAVE_PATH%\4096-512_int4_fp16\log\%LOG_FILE% 2>&1 - # if %ERRORLEVEL% neq 0 (exit /b 1) - # python ..\..\..\test\benchmark\igpu-perf\check_csv_results.py --yaml-file config.yaml --suffix test4 - # if %ERRORLEVEL% neq 0 (exit /b 1) - # - # pip uninstall trl -y - # call conda deactivate + set SYCL_CACHE_PERSISTENT=1 + set BIGDL_LLM_XMX_DISABLED=1 + + cd python\llm\dev\benchmark\all-in-one + move ..\..\..\test\benchmark\igpu-perf\4096-512_int4_fp16_443.yaml config.yaml + set PYTHONIOENCODING=utf-8 + python run.py >> %CSV_SAVE_PATH%\4096-512_int4_fp16\log\%LOG_FILE% 2>&1 + if %ERRORLEVEL% neq 0 (exit /b 1) + python ..\..\..\test\benchmark\igpu-perf\check_csv_results.py --yaml-file config.yaml --suffix test4 + if %ERRORLEVEL% neq 0 (exit /b 1) + + pip uninstall trl -y + call conda deactivate - name: Concat csv and generate html (4096-512 int4+fp16) shell: cmd @@ -1259,7 +1256,7 @@ jobs: shell: bash run: | sed -i 's/4096-512/1024-128/g' python/llm/dev/benchmark/all-in-one/run.py - sed -i 's/{today}_test3/{today}_test1/g' python/llm/dev/benchmark/all-in-one/run.py + sed -i 's/{today}_test4/{today}_test1/g' python/llm/dev/benchmark/all-in-one/run.py sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/1024-128_int4_fp16_loadlowbit.yaml - name: Test on igpu (load_low_bit 1024-128 int4+fp16) diff --git a/python/llm/src/ipex_llm/transformers/models/gemma2.py b/python/llm/src/ipex_llm/transformers/models/gemma2.py index 33201864223..07f8314a021 100644 --- a/python/llm/src/ipex_llm/transformers/models/gemma2.py +++ b/python/llm/src/ipex_llm/transformers/models/gemma2.py @@ -129,7 +129,8 @@ def gemma2_attention_forward( # IPEX_LLM OPT: sdp kv_seq_len = q_len if past_key_value is None else past_key_value.kv_seq_len if (use_sdp_causal(q_len, kv_seq_len, -1, query_states, self.training) - and kv_seq_len <= key_states.size(2)): + and kv_seq_len <= key_states.size(2) and + (self.sliding_window is None or kv_seq_len < self.sliding_window)): import xe_addons attn_weights = None attn_output = xe_addons.gemma2_sdp_causal(query_states, @@ -141,10 +142,15 @@ def gemma2_attention_forward( elif use_sdp(q_len, kv_seq_len, -1, query_states): import xe_addons attn_weights = None + if self.sliding_window is not None: + attn_mask = attention_mask[:, :, :q_len, : key_states.shape[-2]] + else: + attn_mask = attention_mask + attn_output = xe_addons.gemma2_sdp(query_states, - key_states[:, :, :kv_seq_len, :], - value_states[:, :, :kv_seq_len, :], - attention_mask[:, :, :q_len, :kv_seq_len], + key_states, + value_states, + attn_mask, self.config.attn_logit_softcapping, self.scaling) else: