@@ -60,8 +60,11 @@ def get_shared_lib_info(lib_base_name: str):
60
60
_lib .reset .argtypes = [ctypes .c_void_p ]
61
61
_lib .reset .restype = None
62
62
63
- _lib .get_logits .argtypes = [ctypes .c_void_p , ctypes .POINTER (ctypes .c_float )]
64
- _lib .reset .restype = None
63
+ _lib .run_prefill_with_logits .argtypes = [ctypes .c_void_p , ctypes .POINTER (ctypes .c_int ), ctypes .c_int , ctypes .POINTER (ctypes .c_float ), ctypes .c_int ]
64
+ _lib .run_prefill_with_logits .restype = None
65
+
66
+ _lib .run_decode_with_logits .argtypes = [ctypes .c_void_p , ctypes .c_int , ctypes .POINTER (ctypes .c_float ), ctypes .c_int ]
67
+ _lib .run_decode_with_logits .restype = None
65
68
66
69
67
70
def load_model_from_file (model_dir : str ):
@@ -82,12 +85,21 @@ def run_decode(model_ptr, input_id, vocab_size):
82
85
return new_token
83
86
84
87
85
- def reset (model_ptr ):
86
- _lib .reset (model_ptr )
88
+ def run_prefill_with_logits (model_ptr , input_ids , logits , vocab_size ):
89
+ input_ptr = (ctypes .c_int32 * len (input_ids ))(* input_ids )
90
+ input_len = len (input_ids )
91
+ logits_ptr = logits .data .data_ptr ()
92
+ logits_ptr = ctypes .cast (logits_ptr , ctypes .POINTER (ctypes .c_float ))
93
+ _lib .run_prefill_with_logits (model_ptr , input_ptr , input_len , logits_ptr , vocab_size )
94
+ return logits
87
95
88
96
89
- def get_logits (model_ptr , logits ):
90
- src = logits .data .data_ptr ()
91
- src = ctypes .cast (src , ctypes .POINTER (ctypes .c_float ))
92
- _lib .get_logits (model_ptr , src )
97
+ def run_decode_with_logits (model_ptr , input_id , logits , vocab_size ):
98
+ logits_ptr = logits .data .data_ptr ()
99
+ logits_ptr = ctypes .cast (logits_ptr , ctypes .POINTER (ctypes .c_float ))
100
+ _lib .run_decode_with_logits (model_ptr , input_id , logits_ptr , vocab_size )
93
101
return logits
102
+
103
+
104
+ def reset (model_ptr ):
105
+ _lib .reset (model_ptr )
0 commit comments