-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixed switched token_type_ids and attention_mask #412
base: main
Are you sure you want to change the base?
Conversation
Just noticed that the exported onnx model does only work if we switch the attention_mask and token_type_ids of the generated dcitionary after tokenization, which is probably caused by my change. I will investigate further and report back soon. |
I reverted my previous changes and implemented the fix directly in function |
I forgot to include the import Edit: okay, there are still errors, I'll analyse and address them soon! |
Tested with Code snippet for reproductibility: from setfit import SetFitModel
from sentence_transformers import SentenceTransformer
from setfit import SetFitHead, SetFitHead, SetFitModel
from setfit.exporters.onnx import export_onnx
model_id = "sentence-transformers/distiluse-base-multilingual-cased-v2"
model_body = SentenceTransformer(model_id)
model_head = SetFitHead(in_features = model_body.get_sentence_embedding_dimension(), out_features = 4)
model = SetFitModel(model_body = model_body, model_head = model_head)
export_onnx(model.model_body,
model.model_head,
opset=12,
output_path="dummy_path")
|
Perhaps we can adopt the approach from #435 for ONNX, rather than sticking with the current |
I was having the same error as mentioned in #338 where I could not export my model with model_base
stsb-xlm-roberta-base
. After some debugging, I noticed that the attention_mask and token_type_ids were switched in the functionforward
(line 50) insetfit/exporters/onnx.py
. The error then occurs because we are trying to look up both the token_type_id embedding with index 0 and the one with index 1, but there is only one embedding in the matrix. I believe that this did not happen with other model bases because they have more than two token_type embeddings.However, I must confess that I was not yet able to test this fix with other models that previously worked. We should definitely do this before we merge this code. To make the code safer, I also made us of kwargs when calling
self.model_body
instead of positional arguments. In my case, I was able to export the model after this small fix.