Skip to content

Commit

Permalink
Update IPEX_LLM_PERFORMANCE_MODE with input length threshold (#11908)
Browse files Browse the repository at this point in the history
* Update IPEX_LLM_PERFORMANCE_MODE with input length threshold

* Update based on comments. Add judgement for inputs_embeds

* Fix for benchmarking purposes

* Update based on comments

* Small fix
  • Loading branch information
Oscilloscope98 authored Aug 23, 2024
1 parent 303a090 commit 24c279e
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/llm/src/ipex_llm/transformers/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
# Module-level reference to GenerationMixin.generate.
original_generate = GenerationMixin.generate
query_group_size = 16

# Minimum size of dimension 1 of `inputs` (or of `inputs_embeds` when `inputs`
# is None) at which generate() defaults `lookahead` to 2 when
# IPEX_LLM_PERFORMANCE_MODE is "1" and no explicit `lookahead` was passed.
# Presumably dimension 1 is the input sequence length -- confirm.
# The value may be tuned further as more test data becomes available.
PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100


@torch.no_grad()
def generate(
Expand All @@ -57,7 +60,14 @@ def generate(
lookahead = kwargs.pop("lookahead", None)
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
if perf_mode == "1" and lookahead is None:
lookahead = 2 # default to 2 now
if inputs is not None:
if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now
else:
inputs_embeds = kwargs.get("inputs_embeds", None)
if inputs_embeds is not None:
if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now
if lookahead:
from ipex_llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()
Expand Down

0 comments on commit 24c279e

Please sign in to comment.