Update _load_sbert_model Parameters and Fix Tokenize Padding #122

Open: wants to merge 7 commits into main
56 changes: 34 additions & 22 deletions InstructorEmbedding/instructor.py
@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
     return batch


-class InstructorPooling(nn.Module):
+class INSTRUCTOR_Pooling(nn.Module):
     """Performs pooling (max or mean) on the token embeddings.

     Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
@@ -245,7 +245,7 @@ def load(input_path):
         ) as config_file:
             config = json.load(config_file)

-        return InstructorPooling(**config)
+        return INSTRUCTOR_Pooling(**config)


 def import_from_string(dotted_path):
@@ -271,7 +271,7 @@ def import_from_string(dotted_path):
         raise ImportError(msg)


-class InstructorTransformer(Transformer):
+class INSTRUCTORTransformer(Transformer):
     def __init__(
         self,
         model_name_or_path: str,
@@ -378,7 +378,7 @@ def load(input_path: str):

         with open(sbert_config_path, encoding="UTF-8") as config_file:
             config = json.load(config_file)
-        return InstructorTransformer(model_name_or_path=input_path, **config)
+        return INSTRUCTORTransformer(model_name_or_path=input_path, **config)

     def tokenize(self, texts):
         """
@@ -395,7 +395,7 @@ def tokenize(self, texts):

             input_features = self.tokenizer(
                 *to_tokenize,
-                padding="max_length",
+                padding=True,
                 truncation="longest_first",
                 return_tensors="pt",
                 max_length=self.max_seq_length,
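
Note on the padding fix: padding="max_length" pads every batch out to the full max_seq_length, while padding=True pads only to the longest sequence in the current batch, so short batches no longer carry hundreds of padding tokens through the encoder. A minimal sketch of the difference with a Hugging Face tokenizer (the tokenizer checkpoint is purely illustrative):

# Sketch: compare batch shapes under the two padding modes.
# Requires transformers and torch; the checkpoint name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
texts = ["a short query", "a noticeably longer sentence used to show dynamic padding"]

fixed = tokenizer(texts, padding="max_length", truncation="longest_first",
                  max_length=512, return_tensors="pt")
dynamic = tokenizer(texts, padding=True, truncation="longest_first",
                    max_length=512, return_tensors="pt")

print(fixed["input_ids"].shape)    # torch.Size([2, 512]): always padded to max_length
print(dynamic["input_ids"].shape)  # torch.Size([2, N]): N is the longest sequence in the batch
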
@@ -420,7 +420,7 @@ def tokenize(self, texts):

             input_features = self.tokenize(instruction_prepended_input_texts)
             instruction_features = self.tokenize(instructions)
-            input_features = Instructor.prepare_input_features(
+            input_features = INSTRUCTOR.prepare_input_features(
                 input_features, instruction_features
             )
else:
Expand All @@ -430,7 +430,7 @@ def tokenize(self, texts):
return output


class Instructor(SentenceTransformer):
class INSTRUCTOR(SentenceTransformer):
@staticmethod
def prepare_input_features(
input_features, instruction_features, return_data_type: str = "pt"
@@ -510,27 +510,39 @@ def smart_batching_collate(self, batch):

             input_features = self.tokenize(instruction_prepended_input_texts)
             instruction_features = self.tokenize(instructions)
-            input_features = Instructor.prepare_input_features(
+            input_features = INSTRUCTOR.prepare_input_features(
                 input_features, instruction_features
             )
             batched_input_features.append(input_features)

         return batched_input_features, labels

-    def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
+    def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False, local_files_only=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None):
         """
         Loads a full sentence-transformers model
         """
-        # Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544
-        download_kwargs = {
-            "repo_id": model_path,
-            "revision": revision,
-            "library_name": "sentence-transformers",
-            "token": token,
-            "cache_dir": cache_folder,
-            "tqdm_class": disabled_tqdm,
-        }
-        model_path = snapshot_download(**download_kwargs)
+        # copied from https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L559
+        # because we need to get files outside of the allow_patterns too
+        # If file is local
+        if os.path.isdir(model_path):
+            model_path = str(model_path)
+        else:
+            # If model_path is a Hugging Face repository ID, download the model
+            download_kwargs = {
+                "repo_id": model_path,
+                "revision": revision,
+                "library_name": "InstructorEmbedding",
+                "token": token,
+                "cache_dir": cache_folder,
+                "tqdm_class": disabled_tqdm,
+            }
+            # Try to download from the remote
+            try:
+                model_path = snapshot_download(**download_kwargs)
+            except Exception:
+                # Otherwise, try local (i.e. cache) only
+                download_kwargs["local_files_only"] = True
+                model_path = snapshot_download(**download_kwargs)

         # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
         config_sentence_transformers_json_path = os.path.join(
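
The widened signature accepts the extra keyword arguments (local_files_only, model_kwargs, tokenizer_kwargs, config_kwargs) that newer sentence-transformers releases forward to _load_sbert_model, so the override no longer fails on unexpected arguments, and the body now uses a local directory directly before trying the Hub. A standalone sketch of just that resolution flow, assuming huggingface_hub is installed (the helper name and defaults are illustrative, not part of the PR):

# Standalone sketch of the path-resolution flow above, not the full loader.
import os
from huggingface_hub import snapshot_download

def resolve_model_path(model_path, revision=None, token=None, cache_folder=None):
    # A local checkpoint directory is used as-is.
    if os.path.isdir(model_path):
        return str(model_path)
    download_kwargs = {
        "repo_id": model_path,
        "revision": revision,
        "library_name": "InstructorEmbedding",
        "token": token,
        "cache_dir": cache_folder,
    }
    try:
        # First attempt: fetch (or reuse) a snapshot from the Hugging Face Hub.
        return snapshot_download(**download_kwargs)
    except Exception:
        # Offline fallback: only use files already present in the local cache.
        download_kwargs["local_files_only"] = True
        return snapshot_download(**download_kwargs)

# e.g. resolve_model_path("hkunlp/instructor-large") downloads on first use and
# falls back to the cached snapshot when the Hub is unreachable.
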
@@ -559,9 +571,9 @@ def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
         modules = OrderedDict()
         for module_config in modules_config:
             if module_config["idx"] == 0:
-                module_class = InstructorTransformer
+                module_class = INSTRUCTORTransformer
             elif module_config["idx"] == 1:
-                module_class = InstructorPooling
+                module_class = INSTRUCTOR_Pooling
             else:
                 module_class = import_from_string(module_config["type"])
             module = module_class.load(os.path.join(model_path, module_config["path"]))
@@ -619,7 +631,7 @@ def encode(
             input_was_string = True

         if device is None:
-            device = self._target_device
+            device = self.device

         self.to(device)

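The last hunk swaps the private _target_device attribute for the public device property that recent sentence-transformers versions expose in its place. With the loader, tokenizer, and device changes applied, the public API is unchanged; a minimal smoke test, assuming the patched package is installed and using the instruction format from the project README:

# Minimal smoke test for the patched loader and encoder.
# Checkpoint and instruction strings follow the project README; adjust as needed.
from InstructorEmbedding import INSTRUCTOR

model = INSTRUCTOR("hkunlp/instructor-large")
sentences = [
    ["Represent the Science title:", "3D ActionSLAM: wearable person tracking"],
    ["Represent the Science title:", "Parton distributions with LHC data"],
]
embeddings = model.encode(sentences)
print(embeddings.shape)  # (2, 768) for instructor-large
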
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,7 +6,7 @@ numpy
 requests>=2.26.0
 scikit_learn>=1.0.2
 scipy
-sentence_transformers>=2.2.0
+sentence_transformers>=2.3.0
 torch
 tqdm
 rich
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
 setup(
     name='InstructorEmbedding',
     packages=['InstructorEmbedding'],
-    version='1.0.1',
+    version='1.0.2',
     license='Apache License 2.0',
     description='Text embedding tool',
     long_description=readme,