Skip to content

Commit 97f7a22

Browse files
authored
add trust_remote_code in tokenizer to fix baichuan issue (#2725)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent b1f9044 commit 97f7a22

File tree

3 files changed: 12 additions (+12) and 3 deletions (-3).

3 files changed: 12 additions (+12) and 3 deletions (-3).

router/src/lib.rs

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ pub enum Tokenizer {
2727
Python {
2828
tokenizer_name: String,
2929
revision: Option<String>,
30+
trust_remote_code: bool,
3031
},
3132
Rust(tokenizers::Tokenizer),
3233
}
@@ -38,15 +39,20 @@ impl<'a> PyTokenizer<'a> {
3839
py: Python<'a>,
3940
tokenizer_name: String,
4041
revision: Option<String>,
42+
trust_remote_code: bool,
4143
) -> PyResult<PyTokenizer<'a>> {
4244
let transformers = py.import_bound("transformers")?;
4345
let auto = transformers.getattr("AutoTokenizer")?;
4446
let from_pretrained = auto.getattr("from_pretrained")?;
4547
let args = (tokenizer_name,);
4648
let kwargs = if let Some(rev) = &revision {
47-
[("revision", rev.to_string())].into_py_dict_bound(py)
49+
[
50+
("revision", rev.to_string().into_py(py)),
51+
("trust_remote_code", trust_remote_code.into_py(py)),
52+
]
53+
.into_py_dict_bound(py)
4854
} else {
49-
pyo3::types::PyDict::new_bound(py)
55+
[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
5056
};
5157
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
5258
tracing::info!("Loaded a python tokenizer");

router/src/server.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1829,6 +1829,7 @@ pub async fn run(
18291829
Tokenizer::Python {
18301830
tokenizer_name: tokenizer_name.clone(),
18311831
revision: revision.clone(),
1832+
trust_remote_code,
18321833
}
18331834
}
18341835
};

router/src/validation.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -439,9 +439,11 @@ fn tokenizer_worker(
439439
Tokenizer::Python {
440440
tokenizer_name,
441441
revision,
442+
trust_remote_code,
442443
} => {
443444
pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
444-
let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
445+
let tokenizer =
446+
PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;
445447
// Loop over requests
446448
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
447449
receiver.blocking_recv()

0 commit comments

Comments
 (0)