Commit: generate
hkvision committed Dec 2, 2024
1 parent 26adb82 commit 18cabaf
Showing 2 changed files with 81 additions and 6 deletions.
77 changes: 71 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -14,16 +14,17 @@
# limitations under the License.


from ipex_llm.utils.common.log4Error import invalidInputError
import os
import time
import torch
import importlib
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
import time
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from transformers import GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import CausalLMOutputWithPast
from ipex_llm.transformers.utils import module_name_process
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
from ipex_llm.utils.common.log4Error import invalidInputError


def module_optimization(func) -> torch.nn.Module:
@@ -133,6 +134,14 @@ def convert_forward(m, target_m, new_forward):
        convert_forward(sub_m, target_m, new_forward)


def general_convert(m, target_m, new_func, func_name="forward"):
    if isinstance(m, target_m):
        bound_method = new_func.__get__(m, m.__class__)
        setattr(m, func_name, bound_method)
    for _, sub_m in m.named_children():
        general_convert(sub_m, target_m, new_func, func_name)
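
A minimal usage sketch of general_convert (not part of the commit; ToyBlock and patched_forward are hypothetical names used only to illustrate the recursive method binding):

import torch

class ToyBlock(torch.nn.Module):
    def forward(self, x):
        return x

def patched_forward(self, x):
    # bound via __get__, so `self` is the matched ToyBlock instance
    return x + 1

root = torch.nn.Sequential(ToyBlock(), torch.nn.Sequential(ToyBlock()))
general_convert(root, ToyBlock, patched_forward)             # rebinds forward on every ToyBlock
general_convert(root, ToyBlock, patched_forward, "my_hook")  # or bind under a custom attribute name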


def optimize_llm(model: torch.nn.Module):
if model.config.model_type == "llama":
from ipex_llm.transformers.npu_models.llama import merge_qkv
@@ -390,10 +399,66 @@ def optimize_llm_single_process(
        model.kv_len = kv_len
        model.model_ptr = model_ptr
        model.vocab_size = model.config.vocab_size
        model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
    except:
        invalidInputError(False,
                          "Failed to InitLLMPipeline.")
    # patch generate function
    import types
    model.generate = types.MethodType(generate, model)
    # import types
    # model.generate = types.MethodType(generate, model)
    from transformers.modeling_utils import PreTrainedModel
    general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
    general_convert(model, PreTrainedModel, causal_lm_forward)
    return model
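
A hedged sketch of how generation would flow once these patches are applied (illustrative only; assumes `model` was returned by optimize_llm_single_process and `tokenizer` matches it):

# generate() now routes through the two patched hooks:
#   prepare_inputs_for_generation -> prepare_input_ids (resets the native pipeline on prefill)
#   forward                       -> causal_lm_forward (run_prefill / run_decode + get_logits)
inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(inputs.input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))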


def prepare_input_ids(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    if past_key_values is not None:  # kvcache
        input_ids = input_ids[:, -1]
    else:  # prefill, reset the model here
        from .npu_llm_cpp import reset
        reset(self.model_ptr)
    model_inputs = {
        "input_ids": input_ids
    }
    return model_inputs


def causal_lm_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    start = time.perf_counter()
    from .npu_llm_cpp import run_decode, run_prefill, get_logits
    if isinstance(input_ids[0], torch.Tensor):
        input_list = input_ids[0].flatten().tolist()
    else:
        input_list = input_ids[0]
    input_length = len(input_list)
    if input_length > 1:
        run_prefill(self.model_ptr, input_list, self.vocab_size)
    else:
        run_decode(self.model_ptr, input_list[0], self.vocab_size)
    logits = get_logits(self.model_ptr, self.logits_buffer)
    end = time.perf_counter()
    overall = (end - start) * 1000
    print("Overall time: ", overall)

    return CausalLMOutputWithPast(
        loss=None,
        logits=logits,
        past_key_values=1,  # just an indicator
        hidden_states=None,
        attentions=None,
    )
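
The constant past_key_values=1 is only a non-None marker: on the next step the patched prepare_input_ids takes its kv-cache branch and feeds just the last token, while the real cache lives behind model_ptr in the native library. A small sketch of the two forward paths (token ids are made up; calling forward directly skips the reset performed in prepare_input_ids):

import torch
prompt = torch.tensor([[1, 15043, 3186]])             # hypothetical prompt token ids
out = model(input_ids=prompt)                          # length > 1  -> run_prefill path
next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
out = model(input_ids=next_id)                         # length == 1 -> run_decode path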
10 changes: 10 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
@@ -60,6 +60,9 @@ def get_shared_lib_info(lib_base_name: str):
_lib.reset.argtypes = [ctypes.c_void_p]
_lib.reset.restype = None

_lib.get_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_float)]
_lib.get_logits.restype = None


def load_model_from_file(model_dir: str):
    return _lib.load_model_from_file(model_dir.encode('utf-8'))
@@ -81,3 +84,10 @@ def run_decode(model_ptr, input_id, vocab_size):

def reset(model_ptr):
    _lib.reset(model_ptr)


def get_logits(model_ptr, logits):
    src = logits.data.data_ptr()
    src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
    _lib.get_logits(model_ptr, src)
    return logits
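
A hedged sketch of the intended calling pattern, assuming the caller pre-allocates the float32 buffer the way convert.py does (1 x 1 x vocab_size); the model directory, vocab size, and prompt token ids below are placeholders:

import torch
vocab_size = 32000                                             # placeholder; convert.py uses model.config.vocab_size
model_ptr = load_model_from_file("/path/to/converted/model")   # hypothetical model directory
buffer = torch.empty(1, 1, vocab_size, dtype=torch.float32)    # caller-owned, filled in place by the C side
run_prefill(model_ptr, [1, 15043, 3186], vocab_size)           # hypothetical prompt token ids
logits = get_logits(model_ptr, buffer)
next_token = int(logits[0, -1].argmax())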
