diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index fba9ceb..698fe0f 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -17,6 +17,7 @@
 class DPSentenceTransformer(SentenceTransformer):
+    """SentenceBERT with pytorch torch.nn.DataParallel"""
     def __init__(self, sbert_model: SentenceTransformer):
         super(DPSentenceTransformer, self).__init__()
@@ -209,6 +210,7 @@ def _encode_with_auto_batch_size(batch_size, self, text, prefix):
         batch_size=batch_size,
         normalize_embeddings=self.normalize_embeddings,
     )
+    self.batch_size = batch_size
     return out