@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import six
-import os
 import math
-import numpy as np
-import onnxruntime as ort
+import os
+import re

-import paddle
+import onnxruntime as ort
 import paddle2onnx
+import six

 from paddlenlp.transformers import AutoTokenizer
 from paddlenlp.utils.tools import get_bool_ids_greater_than, get_span
@@ -45,8 +44,8 @@ def __init__(self, model_path_prefix, device="cpu", use_fp16=False, device_id=0)
             print(">>> [InferBackend] Use GPU to inference ...")
             if use_fp16:
                 print(">>> [InferBackend] Use FP16 to inference ...")
-                from onnxconverter_common import float16
                 import onnx
+                from onnxconverter_common import float16

                 fp16_model_file = os.path.join(infer_model_dir, "fp16_model.onnx")
                 onnx_model = onnx.load_model(float_onnx_file)
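
For reference, the FP16 path this hunk reorders follows the standard onnxconverter-common recipe: load the float32 graph with onnx, convert it with float16.convert_float_to_float16, and save the result. A minimal standalone sketch, assuming a float32 model at "model.onnx" (hypothetical path); keep_io_types=True is one possible choice, since the PR's exact conversion arguments are not shown in this hunk:

    # Sketch of the FP16 conversion step, under the assumptions above.
    import onnx
    from onnxconverter_common import float16

    onnx_model = onnx.load_model("model.onnx")  # float32 graph
    # keep_io_types=True keeps graph inputs/outputs at their original
    # dtypes so callers' feeds keep working unchanged (assumption: the
    # PR may or may not pass this flag).
    fp16_model = float16.convert_float_to_float16(onnx_model, keep_io_types=True)
    onnx.save_model(fp16_model, "fp16_model.onnx")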
@@ -62,7 +61,7 @@ def __init__(self, model_path_prefix, device="cpu", use_fp16=False, device_id=0)
         self.predictor = ort.InferenceSession(onnx_model, sess_options=sess_options, providers=providers)
         if device == "gpu":
             assert "CUDAExecutionProvider" in self.predictor.get_providers(), (
-                f"The environment for GPU inference is not set properly. "
+                "The environment for GPU inference is not set properly. "
                 "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. "
                 "Please run the following commands to reinstall:\n"
                 "1) pip uninstall -y onnxruntime onnxruntime-gpu\n2) pip install onnxruntime-gpu"
@@ -87,6 +86,7 @@ def __init__(self, args):
         self._position_prob = args.position_prob
         self._max_seq_len = args.max_seq_len
         self._batch_size = args.batch_size
+        self._multilingual = args.multilingual
         self._schema_tree = None
         self.set_schema(args.schema)
         if args.device == "cpu":
@@ -167,12 +167,18 @@ def _single_stage_predict(self, inputs):
         end_probs = []
         for idx in range(0, len(texts), self._batch_size):
             l, r = idx, idx + self._batch_size
-            input_dict = {
-                "input_ids": encoded_inputs["input_ids"][l:r].astype("int64"),
-                "token_type_ids": encoded_inputs["token_type_ids"][l:r].astype("int64"),
-                "pos_ids": encoded_inputs["position_ids"][l:r].astype("int64"),
-                "att_mask": encoded_inputs["attention_mask"][l:r].astype("int64"),
-            }
+            if self._multilingual:
+                input_dict = {
+                    "input_ids": encoded_inputs["input_ids"][l:r].astype("int64"),
+                    "position_ids": encoded_inputs["position_ids"][l:r].astype("int64"),
+                }
+            else:
+                input_dict = {
+                    "input_ids": encoded_inputs["input_ids"][l:r].astype("int64"),
+                    "token_type_ids": encoded_inputs["token_type_ids"][l:r].astype("int64"),
+                    "position_ids": encoded_inputs["position_ids"][l:r].astype("int64"),
+                    "attention_mask": encoded_inputs["attention_mask"][l:r].astype("int64"),
+                }
             start_prob, end_prob = self._infer(input_dict)
             start_prob = start_prob.tolist()
             end_prob = end_prob.tolist()
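
Taken together, the new branch feeds the multilingual export only input_ids and position_ids, while the monolingual export keeps all four inputs (note the renames from pos_ids/att_mask to position_ids/attention_mask, which must match the exported graph's input names). A standalone sketch of the same feed-building logic, where build_feed is a hypothetical helper and encoded_inputs is assumed to map names to numpy arrays from the tokenizer:

    def build_feed(encoded_inputs, l, r, multilingual):
        # Both exports take token ids and positions as int64 tensors.
        feed = {
            "input_ids": encoded_inputs["input_ids"][l:r].astype("int64"),
            "position_ids": encoded_inputs["position_ids"][l:r].astype("int64"),
        }
        if not multilingual:
            # The monolingual export additionally consumes segment ids
            # and the attention mask.
            feed["token_type_ids"] = encoded_inputs["token_type_ids"][l:r].astype("int64")
            feed["attention_mask"] = encoded_inputs["attention_mask"][l:r].astype("int64")
        return feed

    # e.g. start_prob, end_prob = sess.run(None, build_feed(encoded_inputs, 0, 4, True))
    # where `sess` is the ort.InferenceSession built in InferBackend above.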