Skip to content

Commit

Permalink
Merge pull request #33171 from vespa-engine/arnej/detect-no-token-typ…
Browse files Browse the repository at this point in the history
…e-ids

detect if model does not use token_type_ids
  • Loading branch information
arnej27959 authored Jan 25, 2025
2 parents effc2d1 + 2e88fc4 commit 1fc09f8
Showing 1 changed file with 16 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
this.runtime = runtime;
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
prependQuery = config.prependQuery();
Expand All @@ -75,15 +74,29 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
tokenTypeIdsName = detectTokenTypeIds(config, evaluator);
validateModel();
}

private static String detectTokenTypeIds(HuggingFaceEmbedderConfig config, OnnxEvaluator evaluator) {
String configured = config.transformerTokenTypeIds();
Map<String, TensorType> inputs = evaluator.getInputInfo();
if (inputs.size() < 3) {
// newer models have only 2 inputs (they do not use token type IDs)
return "";
} else {
// could detect fallback from inputs here, currently set as default in .def file
return configured;
}
}

private void validateModel() {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
validateName(inputs, attentionMaskName, "input");
if (!tokenTypeIdsName.isEmpty()) validateName(inputs, tokenTypeIdsName, "input");

if (!tokenTypeIdsName.isEmpty()) {
validateName(inputs, tokenTypeIdsName, "input");
}
Map<String, TensorType> outputs = evaluator.getOutputInfo();
validateName(outputs, outputName, "output");
}
Expand Down Expand Up @@ -250,4 +263,3 @@ protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, S
protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { }

}

0 comments on commit 1fc09f8

Please sign in to comment.