Skip to content

Commit cb8a852

Browse files
committed
update
1 parent 8740621 commit cb8a852

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

python/llm/src/ipex_llm/transformers/npu_models/convert.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ def optimize_llm_single_process(
         invalidInputError(False,
                           "False to InitLLMPipeline.")
     # patch generate function
-    # import types
-    # model.generate = types.MethodType(generate, model)
+    import types
+    model.simple_generate = types.MethodType(generate, model)
     from transformers.modeling_utils import PreTrainedModel
     general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
     general_convert(model, PreTrainedModel, causal_lm_forward)
@@ -439,17 +439,16 @@ def causal_lm_forward(
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     start = time.perf_counter()
-    from .npu_llm_cpp import run_decode, run_prefill, get_logits
+    from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
     if isinstance(input_ids[0], torch.Tensor):
         input_list = input_ids[0].flatten().tolist()
     else:
         input_list = input_ids[0]
     input_length = len(input_list)
     if input_length > 1:
-        run_prefill(self.model_ptr, input_list, self.vocab_size)
+        logits = run_prefill_with_logits(self.model_ptr, input_list, self.logits_buffer, self.vocab_size)
     else:
-        run_decode(self.model_ptr, input_list[0], self.vocab_size)
-        logits = get_logits(self.model_ptr, self.logits_buffer)
+        logits = run_decode_with_logits(self.model_ptr, input_list[0], self.logits_buffer, self.vocab_size)
     end = time.perf_counter()
     overall = (end - start) * 1000
     print("Overall time: ", overall)

python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ def get_shared_lib_info(lib_base_name: str):
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None

-_lib.get_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_float)]
-_lib.reset.restype = None
+_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+_lib.run_prefill_with_logits.restype = None
+
+_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+_lib.run_decode_with_logits.restype = None

6669

6770
def load_model_from_file(model_dir: str):
@@ -82,12 +85,21 @@ def run_decode(model_ptr, input_id, vocab_size):
     return new_token


-def reset(model_ptr):
-    _lib.reset(model_ptr)
+def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
+    input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
+    input_len = len(input_ids)
+    logits_ptr = logits.data.data_ptr()
+    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
+    _lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
+    return logits


-def get_logits(model_ptr, logits):
-    src = logits.data.data_ptr()
-    src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
-    _lib.get_logits(model_ptr, src)
+def run_decode_with_logits(model_ptr, input_id, logits, vocab_size):
+    logits_ptr = logits.data.data_ptr()
+    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
+    _lib.run_decode_with_logits(model_ptr, input_id, logits_ptr, vocab_size)
     return logits
+
+
+def reset(model_ptr):
+    _lib.reset(model_ptr)

0 commit comments

Comments
 (0)