Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
hkvision committed Dec 3, 2024
1 parent b62ae49 commit 03bf892
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,11 @@ def causal_lm_forward(
input_list = input_ids[0]
input_length = len(input_list)
if input_length > 1:
logits = run_prefill_with_logits(self.model_ptr, input_list, self.logits_buffer, self.vocab_size)
logits = run_prefill_with_logits(self.model_ptr, input_list,
self.logits_buffer, self.vocab_size)
else:
logits = run_decode_with_logits(self.model_ptr, input_list[0], self.logits_buffer, self.vocab_size)
logits = run_decode_with_logits(self.model_ptr, input_list[0],
self.logits_buffer, self.vocab_size)

return CausalLMOutputWithPast(
loss=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ def get_shared_lib_info(lib_base_name: str):
_lib.reset.argtypes = [ctypes.c_void_p]
_lib.reset.restype = None

_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int,
ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_prefill_with_logits.restype = None

_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_decode_with_logits.restype = None


Expand Down

0 comments on commit 03bf892

Please sign in to comment.