update instructor for sentence transformers #129

Open · wants to merge 9 commits into base: main
213 changes: 67 additions & 146 deletions InstructorEmbedding/instructor.py
@@ -1,4 +1,3 @@
# This script is based on the modifications from https://github.com/UKPLab/sentence-transformers
import importlib
import json
import os
@@ -24,24 +23,6 @@ def batch_to_device(batch, target_device: str):


class INSTRUCTORPooling(nn.Module):
"""Performs pooling (max or mean) on the token embeddings.

Using pooling, it generates a fixed-size sentence embedding from a variable-length sentence.
This layer also allows using the CLS token if it is returned by the underlying word embedding model.
You can concatenate multiple poolings together.

:param word_embedding_dimension: Dimensions for the word embeddings
:param pooling_mode: Can be a string: mean/max/cls. If set, overwrites the other pooling_mode_* settings
:param pooling_mode_cls_token: Use the first token (CLS token) as text representations
:param pooling_mode_max_tokens: Use max in each dimension over all tokens.
:param pooling_mode_mean_tokens: Perform mean-pooling
:param pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but divide by sqrt(input_length).
:param pooling_mode_weightedmean_tokens: Perform (position) weighted mean pooling,
see https://arxiv.org/abs/2202.08904
:param pooling_mode_lasttoken: Perform last token pooling,
see https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005
"""

def __init__(
self,
word_embedding_dimension: int,
@@ -65,7 +46,7 @@ def __init__(
"pooling_mode_lasttoken",
]

if pooling_mode is not None: # Set pooling mode by string
if pooling_mode is not None:
pooling_mode = pooling_mode.lower()
assert pooling_mode in ["mean", "max", "cls", "weightedmean", "lasttoken"]
pooling_mode_cls_token = pooling_mode == "cls"
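Reviewer note: the docstring above lists the supported pooling modes; as a refresher, the default mean-pooling path that `forward()` implements further down boils down to a few lines. A standalone sketch with made-up toy shapes (not part of this diff):

```python
import torch

# Toy inputs: batch of 2 sequences, max length 4, hidden size 8
token_embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 1 = real token, 0 = padding

# Masked mean pooling: zero out padding, sum over the sequence,
# then divide by the number of real tokens
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * mask, dim=1)
sum_mask = torch.clamp(mask.sum(dim=1), min=1e-9)
sentence_embedding = sum_embeddings / sum_mask  # shape: (2, 8)
```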
@@ -100,9 +81,6 @@ def __repr__(self):
return f"Pooling({self.get_config_dict()})"

def get_pooling_mode_str(self) -> str:
"""
Returns the pooling mode as string
"""
modes = []
if self.pooling_mode_cls_token:
modes.append("cls")
@@ -120,24 +98,22 @@ def get_pooling_mode_str(self) -> str:
return "+".join(modes)

def forward(self, features):
# print(features.keys())
token_embeddings = features["token_embeddings"]
attention_mask = features["attention_mask"]

## Pooling strategy
output_vectors = []
if self.pooling_mode_cls_token:
cls_token = features.get(
"cls_token_embeddings", token_embeddings[:, 0]
) # Take first token by default
)
output_vectors.append(cls_token)
if self.pooling_mode_max_tokens:
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
token_embeddings[
input_mask_expanded == 0
] = -1e9 # Set padding tokens to large negative value
] = -1e9
max_over_time = torch.max(token_embeddings, 1)[0]
output_vectors.append(max_over_time)
if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
@@ -146,7 +122,6 @@ def forward(self, features):
)
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

# If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
if "token_weights_sum" in features:
sum_mask = (
features["token_weights_sum"]
@@ -166,7 +141,6 @@ def forward(self, features):
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
# token_embeddings shape: bs, seq, hidden_dim
weights = (
torch.arange(start=1, end=token_embeddings.shape[1] + 1)
.unsqueeze(0)
@@ -180,7 +154,6 @@ def forward(self, features):

sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

# If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
if "token_weights_sum" in features:
sum_mask = (
features["token_weights_sum"]
@@ -194,26 +167,16 @@ def forward(self, features):
output_vectors.append(sum_embeddings / sum_mask)
if self.pooling_mode_lasttoken:
batch_size, _, hidden_dim = token_embeddings.shape
# attention_mask shape: (bs, seq_len)
# Get shape [bs] indices of the last token (i.e. the last token for each batch item)
# argmin gives us the index of the first 0 in the attention mask;
# We get the last 1 index by subtracting 1
gather_indices = (
torch.argmin(attention_mask, 1, keepdim=False) - 1
) # Shape [bs]
)

# If there are empty sequences, the index would become -1, which would crash the gather
gather_indices = torch.clamp(gather_indices, min=0)

# Turn indices from shape [bs] --> [bs, 1, hidden_dim]
gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
gather_indices = gather_indices.unsqueeze(1)
assert gather_indices.shape == (batch_size, 1, hidden_dim)

# Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
# Strictly, the attention mask is not needed here, since we gather the last token where attn_mask = 1;
# but because clamp may have set some indices (which should not be attended to) to 0,
# we use the attention mask to ignore them again
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
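Reviewer note: the comments above describe how the last non-padding token is located with `argmin` and then gathered. A standalone sketch of that step with toy tensors (not part of this diff):

```python
import torch

token_embeddings = torch.randn(2, 4, 8)            # (bs, seq_len, hidden_dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# argmin returns the index of the first 0; subtracting 1 gives the last real token
gather_indices = torch.argmin(attention_mask, dim=1) - 1
gather_indices = torch.clamp(gather_indices, min=0)  # guard against all-padding rows

# Reshape to (bs, 1, hidden_dim) so gather picks one vector per sequence
hidden_dim = token_embeddings.shape[2]
gather_indices = gather_indices.view(-1, 1, 1).expand(-1, 1, hidden_dim)
last_token_embedding = torch.gather(token_embeddings, 1, gather_indices).squeeze(1)
# last_token_embedding: (bs, hidden_dim)
```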
@@ -249,10 +212,6 @@ def load(input_path):


def import_from_string(dotted_path):
"""
Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import failed.
"""
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError:
@@ -307,6 +266,9 @@ def __init__(
)

if load_model:
import inspect
# Newer sentence-transformers releases accept a `backend` argument in `_load_model`;
# default to the torch backend when it is supported.
if 'backend' in inspect.signature(self._load_model).parameters:
model_args['backend'] = 'torch'
self._load_model(self.model_name_or_path, config, cache_dir, **model_args)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path
@@ -353,7 +315,7 @@ def forward(self, features):
all_layer_idx = 2
if (
len(output_states) < 3
): # Some models only output last_hidden_states and all_hidden_states
):
all_layer_idx = 1
hidden_states = output_states[all_layer_idx]
features.update({"all_layer_embeddings": hidden_states})
@@ -362,7 +324,6 @@

@staticmethod
def load(input_path: str):
# Old classes used other config names than 'sentence_bert_config.json'
for config_name in [
"sentence_bert_config.json",
"sentence_roberta_config.json",
@@ -381,15 +342,11 @@ def load(input_path: str):
return INSTRUCTORTransformer(model_name_or_path=input_path, **config)

def tokenize(self, texts):
"""
Tokenizes a text and maps tokens to token-ids
"""
output = {}
if isinstance(texts[0], str):
to_tokenize = [texts]
to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

# Lowercase
if self.do_lower_case:
to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

@@ -446,30 +403,15 @@ def prepare_input_features(
input_attention_mask_shape = input_features["attention_mask"].shape
instruction_attention_mask = instruction_features["attention_mask"]

# reducing the attention length by 1 in order to omit the attention corresponding to the end_token
instruction_attention_mask = instruction_attention_mask[:, 1:]

# creating an instruction attention matrix matching the size of the input attention matrix
expanded_instruction_attention_mask = torch.zeros(
input_attention_mask_shape, dtype=torch.int64
)
# assigning the actual instruction attention matrix to the expanded_instruction_attention_mask
# eg:
# instruction_attention_mask: 3x3
# [[1,1,1],
# [1,1,0],
# [1,0,0]]
# expanded_instruction_attention_mask: 3x4
# [[1,1,1,0],
# [1,1,0,0],
# [1,0,0,0]]
expanded_instruction_attention_mask[
: instruction_attention_mask.size(0), : instruction_attention_mask.size(1)
] = instruction_attention_mask

# In the pooling layer we want to consider only the tokens corresponding to the input text
# and not the instruction. This is achieved by inverting the
# attention_mask corresponding to the instruction.
expanded_instruction_attention_mask = 1 - expanded_instruction_attention_mask
input_features["instruction_mask"] = expanded_instruction_attention_mask
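Reviewer note: the commented 3x3 to 3x4 example above can be reproduced directly. A standalone sketch of the expansion and inversion with toy tensors (not part of this diff):

```python
import torch

input_attention_mask_shape = torch.Size([3, 4])
instruction_attention_mask = torch.tensor([[1, 1, 1],
                                           [1, 1, 0],
                                           [1, 0, 0]])

# Pad the instruction mask out to the shape of the input attention mask
expanded = torch.zeros(input_attention_mask_shape, dtype=torch.int64)
expanded[: instruction_attention_mask.size(0), : instruction_attention_mask.size(1)] = (
    instruction_attention_mask
)
# expanded == [[1, 1, 1, 0],
#              [1, 1, 0, 0],
#              [1, 0, 0, 0]]

# Invert so that 1 marks input-text tokens (kept in pooling) rather than instruction tokens
instruction_mask = 1 - expanded
# instruction_mask == [[0, 0, 0, 1],
#                      [0, 0, 1, 1],
#                      [0, 1, 1, 1]]
```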
if return_data_type == "np":
@@ -517,62 +459,62 @@ def smart_batching_collate(self, batch):

return batched_input_features, labels

def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
"""
Loads a full sentence-transformers model
"""
# copied from https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L559
# because we need to get files outside of the allow_patterns too
# If file is local
if os.path.isdir(model_path):
model_path = str(model_path)
else:
# If model_path is a Hugging Face repository ID, download the model
download_kwargs = {
"repo_id": model_path,
"revision": revision,
"library_name": "InstructorEmbedding",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
}

# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = os.path.join(
model_path, "config_sentence_transformers.json"
)
if os.path.exists(config_sentence_transformers_json_path):
with open(
config_sentence_transformers_json_path, encoding="UTF-8"
) as config_file:
self._model_config = json.load(config_file)

# Check if a readme exists
model_card_path = os.path.join(model_path, "README.md")
if os.path.exists(model_card_path):
try:
with open(model_card_path, encoding="utf8") as config_file:
self._model_card_text = config_file.read()
except:
pass

# Load the modules of sentence transformer
modules_json_path = os.path.join(model_path, "modules.json")
with open(modules_json_path, encoding="UTF-8") as config_file:
modules_config = json.load(config_file)

modules = OrderedDict()
for module_config in modules_config:
if module_config["idx"] == 0:
module_class = INSTRUCTORTransformer
elif module_config["idx"] == 1:
module_class = INSTRUCTORPooling
else:
module_class = import_from_string(module_config["type"])
module = module_class.load(os.path.join(model_path, module_config["path"]))
modules[module_config["name"]] = module

return modules
def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False, local_files_only=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None):
import inspect
base_signature = inspect.signature(SentenceTransformer.__init__)

if os.path.isdir(model_path):
model_path = str(model_path)
else:
download_kwargs = {
"repo_id": model_path,
"revision": revision,
"library_name": "sentence-transformers",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
"local_files_only": local_files_only,
}
model_path = snapshot_download(**download_kwargs)

config_sentence_transformers_json_path = os.path.join(
model_path, "config_sentence_transformers.json"
)
if os.path.exists(config_sentence_transformers_json_path):
with open(
config_sentence_transformers_json_path, encoding="UTF-8"
) as config_file:
self._model_config = json.load(config_file)

model_card_path = os.path.join(model_path, "README.md")
if os.path.exists(model_card_path):
try:
with open(model_card_path, encoding="utf8") as config_file:
self._model_card_text = config_file.read()
except:
pass

modules_json_path = os.path.join(model_path, "modules.json")
with open(modules_json_path, encoding="UTF-8") as config_file:
modules_config = json.load(config_file)

modules = OrderedDict()
# Newer sentence-transformers versions (those exposing a `backend` parameter)
# expect _load_sbert_model to also return per-module kwargs.
if 'backend' in base_signature.parameters:
module_kwargs = {}

for module_config in modules_config:
if module_config["idx"] == 0:
module_class = INSTRUCTORTransformer
elif module_config["idx"] == 1:
module_class = INSTRUCTORPooling
else:
module_class = import_from_string(module_config["type"])
module = module_class.load(os.path.join(model_path, module_config["path"]))
modules[module_config["name"]] = module

if 'backend' in base_signature.parameters:
return modules, module_kwargs
return modules

def encode(
self,
@@ -585,26 +527,6 @@ def encode(
device: Union[str, None] = None,
normalize_embeddings: bool = False,
):
"""
Computes sentence embeddings

:param sentences: the sentences to embed
:param batch_size: the batch size used for the computation
:param show_progress_bar: Output a progress bar when encoding sentences
:param output_value: Default is sentence_embedding, which returns sentence embeddings.
Can be set to token_embeddings to get wordpiece token embeddings, or to None to get all output values
:param convert_to_numpy: If true, the output is a list of numpy vectors.
Else, it is a list of pytorch tensors.
:param convert_to_tensor: If true, you get one large tensor as return.
Overwrites any setting from convert_to_numpy
:param device: Which torch.device to use for the computation
:param normalize_embeddings: If set to true, returned vectors will have length 1.
In that case, the faster dot-product (util.dot_score) can be used instead of cosine similarity.

:return:
By default, a list of tensors is returned. If convert_to_tensor,
a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
"""
self.eval()
if show_progress_bar is None:
show_progress_bar = False
@@ -619,7 +541,7 @@ def encode(
input_was_string = False
if isinstance(sentences, str) or not hasattr(
sentences, "__len__"
): # Cast an individual sentence to a list with length 1
):
sentences = [sentences]
input_was_string = True

@@ -660,22 +582,21 @@ def encode(
last_mask_id -= 1

embeddings.append(token_emb[0 : last_mask_id + 1])
elif output_value is None: # Return all outputs
elif output_value is None:
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {
name: out_features[name][sent_idx] for name in out_features
}
embeddings.append(row)
else: # Sentence embeddings
else:
embeddings = out_features[output_value]
embeddings = embeddings.detach()
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(
embeddings, p=2, dim=1
)

# fixes for #522 and #487 to avoid OOM problems on GPU with large datasets
if convert_to_numpy:
embeddings = embeddings.cpu()
