-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding sentence transformer layer (#254)
- Loading branch information
1 parent
815a62c
commit 1c20605
Showing
7 changed files
with
308 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
python/ml4ir/base/features/feature_fns/sentence_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import tensorflow as tf | ||
|
||
from ml4ir.base.features.feature_fns.base import BaseFeatureLayerOp | ||
from ml4ir.base.io.file_io import FileIO | ||
from ml4ir.base.model.layers.sentence_transformers import SentenceTransformerWithTokenizerLayer | ||
|
||
|
||
class SentenceTransformerWithTokenizer(BaseFeatureLayerOp):
    """
    Feature transform op that maps a string tensor to dense embeddings.

    The input strings are tokenized and passed through a pretrained sentence
    transformer model. Both steps are delegated to
    SentenceTransformerWithTokenizerLayer; this class only adapts that keras
    layer to the ml4ir feature transform interface.
    """
    LAYER_NAME = "sentence_transformer_with_tokenizer"

    MODEL_NAME_OR_PATH = "model_name_or_path"
    TRAINABLE = "trainable"

    def __init__(self, feature_info: dict, file_io: FileIO, **kwargs):
        """
        Initialize the sentence transformer feature transform layer

        Parameters
        ----------
        feature_info : dict
            Dictionary representing the configuration parameters for the specific feature from the FeatureConfig
        file_io : FileIO object
            FileIO handler object for reading and writing

        Notes
        -----
        Args under feature_layer_info:
            model_name_or_path: str
                Name or path to the sentence transformer embedding model
            trainable: bool
                Whether to finetune the pretrained embedding model
        """
        super().__init__(feature_info=feature_info, file_io=file_io, **kwargs)

        layer_args = self.feature_layer_args
        self.sentence_transformer_with_tokenizer_op = SentenceTransformerWithTokenizerLayer(
            model_name_or_path=layer_args.get(self.MODEL_NAME_OR_PATH, "intfloat/e5-base"),
            trainable=layer_args.get(self.TRAINABLE, False),
        )

    def call(self, inputs, training=None):
        """
        Run the forward pass of the feature transform on the inputs tensor

        Parameters
        ----------
        inputs: tensor
            Input tensor on which the feature transforms are applied
        training: boolean
            Boolean flag indicating if the layer is being used in training mode or not

        Returns
        -------
        tf.Tensor
            Resulting tensor after the forward pass through the feature transform layer
        """
        transformer_op = self.sentence_transformer_with_tokenizer_op
        return transformer_op(inputs, training=training)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
158 changes: 158 additions & 0 deletions
158
python/ml4ir/base/model/layers/sentence_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
import json | ||
from pathlib import Path | ||
|
||
import tensorflow as tf | ||
import torch | ||
from sentence_transformers.models import Dense as SentenceTransformersDense | ||
from tensorflow.keras.layers import Layer, Dense | ||
from transformers import TFAutoModel, TFBertTokenizer | ||
from sentence_transformers import SentenceTransformer | ||
|
||
# NOTE: We set device CPU for the torch backend so that the sentence-transformers model does not use GPU resources
# NOTE(review): torch.device("cpu") only CONSTRUCTS a device object and discards it -- it is a
# no-op and does not change where torch allocates tensors. If the intent is to keep the torch
# backend off the GPU, this likely needs torch.set_default_device("cpu") (torch>=2.0) or
# explicit device= arguments at the model-loading call sites -- TODO confirm.
torch.device("cpu")
|
||
|
||
class SentenceTransformerLayerKey:
    """Stores the names of the sentence-transformer model layers

    The values mirror the fully-qualified `type` strings that appear in a
    sentence-transformers model's modules.json config, and are compared against
    module["type"] when assembling SentenceTransformerWithTokenizerLayer.
    """
    TRANSFORMER = "sentence_transformers.models.Transformer"  # contextual token encoder
    POOLING = "sentence_transformers.models.Pooling"  # token embeddings -> sentence embedding
    DENSE = "sentence_transformers.models.Dense"  # optional projection layer
    NORMALIZE = "sentence_transformers.models.Normalize"  # optional L2 normalization
|
||
|
||
class SentenceTransformerWithTokenizerLayer(Layer):
    """
    Converts a string tensor into embeddings using sentence transformers
    by first tokenizing the string tensor and then passing through the transformer model

    The layer reads the model's modules.json config and reproduces each configured
    module (transformer, pooling, dense projection, normalization) inside the
    tensorflow graph, with pretrained weights copied over from the torch model.

    Some of this code is inspired from -> https://www.philschmid.de/tensorflow-sentence-transformers
    """

    def __init__(self,
                 name="sentence_transformer",
                 model_name_or_path: str = "intfloat/e5-base",
                 trainable: bool = False,
                 **kwargs):
        """
        Parameters
        ----------
        name: str
            Layer name
        model_name_or_path: str
            Name or path to the sentence transformer embedding model
        trainable: bool
            Finetune the pretrained embedding model
        kwargs:
            Additional key-value args that will be used for configuring the layer

        Raises
        ------
        FileNotFoundError
            If the model files cannot be located on disk after a download attempt
        """
        super().__init__(name=name, **kwargs)

        self.model_name_or_path = Path(model_name_or_path)
        if not self.model_name_or_path.exists():  # The user provided a model name, not a local path
            # Initialize the torch model once to trigger a download of the model
            # files into the local sentence-transformers cache
            st_model = SentenceTransformer(model_name_or_path)

            # NOTE(review): relies on the private torch.hub._get_torch_home() and the legacy
            # sentence-transformers cache layout -- verify against the installed library version
            self.model_name_or_path = Path(
                torch.hub._get_torch_home()) / "sentence_transformers" / model_name_or_path.replace("/", "_")
            if not self.model_name_or_path.exists():
                raise FileNotFoundError(
                    f"{self.model_name_or_path} does not exist. Verify the `model_name_or_path` argument")

            # Free the torch model; only the downloaded files on disk are needed from here on
            del st_model

        # Load the modules.json config to add custom model layers
        # FIX: use a context manager instead of json.load(open(...)) so the file handle is closed
        with open(self.model_name_or_path / "modules.json") as modules_file:
            self.modules = json.load(modules_file)
        self.module_names = [module["name"] for module in self.modules]

        self.trainable = trainable

        # Define tokenizer as part of the tensorflow graph
        self.tokenizer = TFBertTokenizer.from_pretrained(self.model_name_or_path, **kwargs)

        self.transformer_model = None
        self.dense = None
        self.apply_dense = False
        self.normalize_embeddings = False
        self.pool_embeddings = False
        for module in self.modules:
            # Define the transformer model and initialize pretrained weights
            if module["type"] == SentenceTransformerLayerKey.TRANSFORMER:
                self.transformer_model = TFAutoModel.from_pretrained(self.model_name_or_path,
                                                                     from_pt=True,
                                                                     trainable=self.trainable,
                                                                     **kwargs)

            # Define mean pooling op
            if module["type"] == SentenceTransformerLayerKey.POOLING:
                self.pool_embeddings = True

            # Define normalization op
            if module["type"] == SentenceTransformerLayerKey.NORMALIZE:
                self.normalize_embeddings = True

            # Define dense layer if present in the model and initialize pretrained weights
            if module["type"] == SentenceTransformerLayerKey.DENSE:
                self.apply_dense = True
                st_dense = SentenceTransformersDense.load(self.model_name_or_path / module["path"])
                st_dense_config = st_dense.get_config_dict()
                # Copy the pretrained torch weights into a keras Dense layer;
                # torch stores the kernel transposed relative to keras, hence the .T
                self.dense = Dense(units=st_dense_config["out_features"],
                                   kernel_initializer=tf.keras.initializers.Constant(
                                       st_dense.state_dict()["linear.weight"].T),
                                   bias_initializer=tf.keras.initializers.Constant(
                                       st_dense.state_dict()["linear.bias"]),
                                   # e.g. "torch.nn.modules.activation.Tanh" -> "tanh"
                                   activation=st_dense_config["activation_function"].split(".")[-1].lower(),
                                   trainable=self.trainable)
                del st_dense

    @classmethod
    def mean_pooling(cls, token_embeddings, attention_mask):
        """Mean pool the token embeddings with the attention mask to generate the embeddings"""
        input_mask_expanded = tf.cast(
            tf.broadcast_to(tf.expand_dims(attention_mask, -1), tf.shape(token_embeddings)),
            tf.float32
        )

        embeddings_sum = tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1)
        # Clamp the mask sum at 1e-9 so fully-masked rows do not divide by zero
        embeddings_count = tf.clip_by_value(tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max)

        return tf.math.divide_no_nan(embeddings_sum, embeddings_count)

    @classmethod
    def normalize(cls, embeddings):
        """L2 normalize sentence embeddings"""
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings

    def call(self, inputs, training=None):
        """
        Defines the forward pass for the layer on the inputs tensor

        Parameters
        ----------
        inputs: tensor
            Input tensor on which the feature transforms are applied
        training: boolean
            Boolean flag indicating if the layer is being used in training mode or not

        Returns
        -------
        tf.Tensor
            Resulting tensor after the forward pass through the feature transform layer
        """
        # Tokenize string tensors
        tokens = self.tokenizer(inputs)
        # NOTE: TFDistilBertModel does not expect the token_type_ids key
        tokens.pop("token_type_ids", None)

        # Apply the modules as configured in modules.json
        embeddings = self.transformer_model(tokens, training=training)

        if self.pool_embeddings:
            embeddings = self.mean_pooling(embeddings[0], tokens["attention_mask"])

        if self.apply_dense:
            embeddings = self.dense(embeddings, training=training)

        if self.normalize_embeddings:
            embeddings = self.normalize(embeddings)

        return embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
from sentence_transformers import SentenceTransformer | ||
|
||
from ml4ir.base.model.layers.sentence_transformers import SentenceTransformerWithTokenizerLayer | ||
|
||
|
||
class TestSentenceTransformerWithTokenizerLayer(unittest.TestCase):
    """Tests comparing SentenceTransformerWithTokenizerLayer against pretrained reference models"""

    TEST_PHRASES = ["test query to test the embedding layer",
                    "Another test query which does more testing!"]

    def _embed(self, model_name_or_path, **layer_kwargs):
        """Embed the test phrases with the TF layer and return a numpy array"""
        layer = SentenceTransformerWithTokenizerLayer(model_name_or_path=model_name_or_path,
                                                      **layer_kwargs)
        return layer(self.TEST_PHRASES).numpy()

    def _reference_embeddings(self, model_name):
        """Embed the test phrases with the reference sentence-transformers torch model"""
        return SentenceTransformer(model_name).encode(self.TEST_PHRASES)

    def test_e5_base(self):
        embeddings = self._embed("intfloat/e5-base")

        self.assertEqual(embeddings.shape, (2, 768))
        self.assertTrue(np.allclose(embeddings[0, :5], [-0.01958332, 0.02002536, 0.00893079, -0.02941261, 0.06580967]))
        self.assertTrue(np.allclose(embeddings[1, :5], [-0.0034735, 0.04219092, -0.00087385, -0.0156969, 0.06526384]))

    def test_distiluse(self):
        embeddings = self._embed("sentence-transformers/distiluse-base-multilingual-cased-v1")

        self.assertEqual(embeddings.shape, (2, 512))
        self.assertTrue(np.allclose(embeddings[0, :5], [0.00174321, 0.01326918, -0.01836516, 0.05429131, 0.06062959]))
        self.assertTrue(
            np.allclose(embeddings[1, :5], [0.03018673, -0.00636012, -0.00495617, -0.04597681, -0.05619023]))

    def test_e5_base_with_sentence_transformers(self):
        embeddings = self._embed("intfloat/e5-base")
        st_embeddings = self._reference_embeddings("intfloat/e5-base")

        self.assertTrue(np.allclose(embeddings, st_embeddings, atol=1e-5))

    def test_distiluse_with_sentence_transformers(self):
        embeddings = self._embed("sentence-transformers/distiluse-base-multilingual-cased-v1")
        st_embeddings = self._reference_embeddings("distiluse-base-multilingual-cased-v1")

        self.assertTrue(np.allclose(embeddings, st_embeddings, atol=1e-5))

    def test_trainable(self):
        # The trainable flag must propagate to both the transformer and the dense projection
        for trainable in (True, False):
            model = SentenceTransformerWithTokenizerLayer(
                model_name_or_path="sentence-transformers/distiluse-base-multilingual-cased-v1",
                trainable=trainable)
            # Run a forward pass so the layer weights are materialized
            model(self.TEST_PHRASES)

            self.assertEqual(model.trainable, trainable)
            self.assertEqual(model.transformer_model.trainable, trainable)
            self.assertEqual(model.dense.trainable, trainable)
            self.assertEqual(len(model.transformer_model.trainable_weights) > 0, trainable)
            self.assertEqual(len(model.dense.trainable_weights) > 0, trainable)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters