query embeddings using glove #273

Open · wants to merge 7 commits into master
186 changes: 184 additions & 2 deletions python/ml4ir/applications/ranking/features/feature_fns/string.py
@@ -1,6 +1,8 @@
import string
import re
import tensorflow as tf
import numpy as np
from nltk.corpus import stopwords

from ml4ir.base.features.feature_fns.base import BaseFeatureLayerOp
from ml4ir.base.io.file_io import FileIO
@@ -15,6 +17,8 @@ class QueryLength(BaseFeatureLayerOp):

TOKENIZE = "tokenize"
SEPARATOR = "sep"
ONE_HOT_VECTOR = "one_hot_vector"
MAX_LENGTH = "max_length"  # maximum query length bucket for one-hot encoding

def __init__(self, feature_info: dict, file_io: FileIO, **kwargs):
"""
@@ -42,6 +46,8 @@ def __init__(self, feature_info: dict, file_io: FileIO, **kwargs):

self.tokenize = self.feature_layer_args.get(self.TOKENIZE, True)
self.sep = self.feature_layer_args.get(self.SEPARATOR, " ")
self.one_hot = self.feature_layer_args.get(self.ONE_HOT_VECTOR, False)
self.max_length = self.feature_layer_args.get(self.MAX_LENGTH, 10)

def call(self, inputs, training=None):
"""
@@ -64,8 +70,15 @@ def call(self, inputs, training=None):
else:
query_len = tf.strings.length(inputs)

query_len = tf.expand_dims(tf.cast(query_len, tf.float32), axis=-1)
return query_len
if self.one_hot:
# Clip the query lengths to the max_length
query_len = tf.clip_by_value(query_len, 0, self.max_length)
# Convert to one-hot encoding
query_len_one_hot = tf.one_hot(query_len, depth=self.max_length + 1)
return query_len_one_hot
else:
query_len = tf.expand_dims(tf.cast(query_len, tf.float32), axis=-1)
return query_len
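
For reviewers: a minimal sketch (not part of the diff) of what the new one-hot branch produces, assuming eager TensorFlow and the default max_length=10. Lengths above the cap all land in the final bucket.

import tensorflow as tf

query_len = tf.constant([3, 7, 25])           # token counts per query
clipped = tf.clip_by_value(query_len, 0, 10)  # 25 is capped to 10
one_hot = tf.one_hot(clipped, depth=11)       # shape [3, 11]; row i is 1 at index clipped[i]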


class QueryTypeVector(BaseFeatureLayerOp):
@@ -158,3 +171,172 @@ def call(self, inputs, training=None):
query_type_vector = self.categorical_vector_op(query_type, training=training)

return query_type_vector


class QueryEmbeddingVector(BaseFeatureLayerOp):
"""
A feature layer operation to define a query embedding vectorizer using pre-trained word embeddings.

Attributes
----------
LAYER_NAME : str
Name of the layer, set to "query_embedding_vector".
feature_info : dict
Configuration parameters for the specific feature from the FeatureConfig.
file_io : FileIO
FileIO handler object for reading and writing.
embedding_size : int
Dimension size of the query embedding.
glove_path : str
Path to the pre-trained GloVe embeddings file.
max_entries : int
Maximum number of entries to load from the GloVe embeddings file.
stop_words : set
Set of English stopwords.
word_vectors : dict
Dictionary to store word embeddings.
embedding_dim : int
Dimension of the embeddings.
"""

LAYER_NAME = "query_embedding_vector"

def __init__(self, feature_info: dict, file_io: FileIO, **kwargs):
"""
Initialize layer to define a query embedding vectorizer using pre-trained word embeddings.

Parameters
----------
feature_info : dict
Dictionary representing the configuration parameters for the specific feature from the FeatureConfig.
file_io : FileIO
FileIO handler object for reading and writing.

Notes
-----
Args under feature_layer_info:
embedding_size : int
Dimension size of the query embedding.
glove_path : str
Path to the pre-trained GloVe embeddings file.
max_entries : int
Maximum number of entries to load from the GloVe embeddings file.
"""
super().__init__(feature_info=feature_info, file_io=file_io, **kwargs)
self.feature_info = feature_info
self.file_io = file_io
self.embedding_size = feature_info["feature_layer_info"]["args"]["embedding_size"]
self.glove_path = feature_info["feature_layer_info"]["args"]["glove_path"]
self.max_entries = feature_info["feature_layer_info"]["args"]["max_entries"]
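# Hardcoded English stopwords (this appears to mirror NLTK's english
# stopword list, avoiding a runtime nltk corpus download)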
self.stop_words = {"you're", 'itself', 'but', 'against', 'until', 'where', 'as', 'from', 'own', 'again', 's', "wasn't", 'about', 'out', 'his', 'an', 'those', 've', 'should', 'doing', 'ourselves', 'or', 'down', 'such', "she's", 't', 're', 'me', 'what', 'to', 'didn', "wouldn't", 'hers', 'been', 'which', 'further', 'there', "shouldn't", 'them', "couldn't", 'is', 'wouldn', 'he', 'over', "hasn't", 'their', 'after', 'during', 'few', 'up', 'ma', 'yourselves', 'i', 'themselves', "won't", 'having', "you'll", 'these', 'were', 'most', "isn't", 'how', 'ours', 'y', 'and', 'if', 'not', 'between', 'its', "that'll", 'then', 'that', 'above', 'hadn', 'can', 'each', 'aren', 'whom', 'don', 'we', 'won', 'who', 'be', 'here', 'in', 'our', 'any', 'your', 'shan', 'all', 'd', 'same', 'you', 'nor', 'theirs', 'am', 'isn', 'below', 'o', 'couldn', 'into', "hadn't", 'shouldn', 'very', 'haven', 'it', 'wasn', 'other', 'they', 'are', 'both', 'no', 'through', 'at', 'now', 'himself', 'was', 'off', 'herself', 'doesn', 'mightn', "weren't", "you've", 'too', "mustn't", 'when', 'only', 'on', 'him', 'by', 'hasn', 'once', "haven't", 'yourself', 'have', "you'd", 'a', "doesn't", 'll', 'so', "should've", 'does', 'had', 'my', 'yours', 'she', 'than', 'some', 'why', 'with', 'the', 'will', 'needn', 'did', 'mustn', "needn't", 'more', 'her', 'before', 'for', 'has', 'because', 'of', 'do', "didn't", 'myself', "mightn't", 'just', 'weren', "aren't", 'this', 'ain', "don't", 'while', 'under', 'm', 'being', "it's", "shan't"}

self.word_vectors = {}

# Load GloVe embeddings with a limit on the number of entries
with open(self.glove_path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if self.max_entries is not None and idx >= self.max_entries:
break
values = line.split()
word = values[0]
vector = np.array(values[1:], dtype='float32')
# Key by the plain string token; word_lookup decodes tensors to match
self.word_vectors[word] = vector

# Determine the embedding dimension from the last parsed vector
# (assumes a non-empty file with consistent vector dimensions)
self.embedding_dim = len(vector)
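
For context (not part of the diff): the loader above assumes the standard GloVe text format, one whitespace-separated entry per line with the token first and the vector components after it. A minimal parsing sketch with illustrative values:

line = "king 0.50451 0.68607 -0.59517"
values = line.split()
word, vector = values[0], np.array(values[1:], dtype="float32")  # "king", 3-dim vector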

def preprocess_text(self, text):
"""
Preprocess the input text by converting to lowercase, removing punctuation, tokenizing,
and filtering out stopwords.

Parameters
----------
text : tf.Tensor
Input text tensor.

Returns
-------
tf.RaggedTensor
Preprocessed and tokenized text.
"""
# Convert to lowercase
text = tf.strings.lower(text)

# Remove punctuation
text = tf.strings.regex_replace(text, f"[{string.punctuation}]", " ")

# Tokenize the text
tokens = tf.strings.split(text)

# Filter out stopwords with an element-wise mask, so only the matching
# tokens are dropped rather than the whole query row
def filter_stopwords(tokens):
stopword_pattern = '|'.join(self.stop_words)
return tf.ragged.boolean_mask(tokens, ~tf.strings.regex_full_match(tokens, stopword_pattern))

tokens = filter_stopwords(tokens)
return tokens
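
A standalone illustration of the preprocessing pipeline above (eager mode; the tiny stopword pattern here is for demonstration only):

import string
import tensorflow as tf

text = tf.strings.lower(tf.constant(["The Quick, Brown Fox!"]))
text = tf.strings.regex_replace(text, f"[{string.punctuation}]", " ")
tokens = tf.strings.split(text)               # [[b'the', b'quick', b'brown', b'fox']]
keep = ~tf.strings.regex_full_match(tokens, "the|a|an")
print(tf.ragged.boolean_mask(tokens, keep))   # [[b'quick', b'brown', b'fox']]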

def word_lookup(self, word):
"""
Look up the word embedding for a given word.

Parameters
----------
word : str or tf.Tensor
The word to look up.

Returns
-------
np.ndarray
The embedding vector for the word, or a zero vector if the word is not found.
"""
# Decode scalar string tensors (as passed by tf.map_fn in eager mode)
# so lookups hit the plain string keys stored in __init__
if tf.is_tensor(word):
word = word.numpy().decode("utf-8")
return self.word_vectors.get(str(word), np.zeros((self.embedding_dim,), dtype=np.float32))

def build_embeddings(self, query):
"""
Build the embedding for a given query by summing the embeddings of its words.

Parameters
----------
query : tf.Tensor
Tensor containing the words of the query.

Returns
-------
tf.Tensor
Tensor of shape (embedding_dim,) containing the summed word embeddings.
"""
# Each query arrives as a [1, max_len] row of tokens (see call below)
if query.shape[0] == 1:
word_embeddings = tf.map_fn(lambda word: self.word_lookup(word), query[0], dtype=tf.float32)
query_embedding = tf.reduce_sum(word_embeddings, axis=0)
return query_embedding
else:
# Fallback for unexpectedly shaped inputs: a zero embedding
return np.zeros((self.embedding_dim,), dtype=np.float32)
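
What build_embeddings computes, sketched in plain numpy with toy 2-dim vectors: the query embedding is the elementwise sum of its word vectors, with out-of-vocabulary words contributing zeros.

import numpy as np

word_vectors = {"red": np.array([1.0, 0.0], dtype=np.float32),
                "shoes": np.array([0.0, 2.0], dtype=np.float32)}
oov = np.zeros(2, dtype=np.float32)
query = ["red", "shoes", "zzz"]               # "zzz" is out of vocabulary
embedding = np.sum([word_vectors.get(w, oov) for w in query], axis=0)
print(embedding)                              # [1. 2.]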

def call(self, queries, training=None):
"""
Defines the forward pass for the layer on the input queries tensor.

Parameters
----------
queries : tf.Tensor
Input tensor containing the queries.
training : bool, optional
Boolean flag indicating if the layer is being used in training mode or not.

Returns
-------
tf.Tensor
Resulting tensor after the forward pass through the feature transform layer.
"""
inputs = self.preprocess_text(queries)
query_embeddings = tf.map_fn(lambda query: self.build_embeddings(query), inputs.to_tensor(),
dtype=tf.float32)
query_embeddings = tf.expand_dims(query_embeddings, axis=1)
return query_embeddings
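
A hypothetical end-to-end usage sketch (the path, sizes, and the stubbed file_io are made up; assumes eager execution):

feature_info = {
    "name": "query_text",
    "feature_layer_info": {
        "args": {
            "embedding_size": 50,
            "glove_path": "glove.6B.50d.txt",  # hypothetical local GloVe file
            "max_entries": 100000,
        }
    },
}
layer = QueryEmbeddingVector(feature_info, file_io=None)  # file_io stubbed for illustration
embeddings = layer(tf.constant([["red shoes"], ["running jacket"]]))
print(embeddings.shape)  # expected (2, 1, 50): one summed GloVe vector per query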

2 changes: 2 additions & 0 deletions python/ml4ir/applications/ranking/tests/test_feature_fns.py
@@ -1,6 +1,8 @@
import tensorflow as tf
import numpy as np
import copy
import unittest
from unittest.mock import MagicMock

from ml4ir.applications.ranking.features.feature_fns import categorical
from ml4ir.applications.ranking.features.feature_fns import normalization
56 changes: 56 additions & 0 deletions python/ml4ir/applications/ranking/tests/test_query_embedding.py
@@ -0,0 +1,56 @@
import tensorflow as tf
import numpy as np
import unittest
from unittest.mock import MagicMock
from ml4ir.applications.ranking.features.feature_fns import string as string_transforms
from ml4ir.base.tests.test_base import RelevanceTestBase

class TestQueryEmbeddingVectorUsingGlove(RelevanceTestBase):
def setUp(self):
self.feature_info = {
"name": "default_feature",
"feature_layer_info": {
"args": {
"embedding_size": 3,
"glove_path": "mock_glove_path.txt",
"max_entries": 2
}
}
}
self.file_io = MagicMock()
glove_content = "word1 0.1 0.2 0.3\nword2 0.4 0.5 0.6\nword 0.9 0.9 0.9"
self.open_mock = unittest.mock.mock_open(read_data=glove_content)
with unittest.mock.patch('builtins.open', self.open_mock):
self.query_embedding_vector = string_transforms.QueryEmbeddingVector(self.feature_info, self.file_io)

def test_glove_file_loading(self):
# Assert the loaded vocabulary and dimensions
self.assertIn("word1", self.query_embedding_vector.word_vectors)
self.assertIn("word2", self.query_embedding_vector.word_vectors)
self.assertEqual(len(self.query_embedding_vector.word_vectors), 2)
self.assertEqual(self.query_embedding_vector.embedding_dim, 3)

def test_preprocess_text(self):
text = tf.constant(["Hello, world!"])
processed_text = self.query_embedding_vector.preprocess_text(text)
self.assertEqual(processed_text.numpy().tolist(), [[b'hello', b'world']])

def test_word_lookup_existing_word(self):
word1 = tf.constant("word1", dtype=tf.string)
embedding = self.query_embedding_vector.word_lookup(word1)
expected_embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32)
np.testing.assert_array_equal(embedding, expected_embedding)

def test_word_lookup_unknown_word(self):
embedding = self.query_embedding_vector.word_lookup("unknown_word")
expected_embedding = np.zeros((3,), dtype=np.float32)
np.testing.assert_array_equal(embedding, expected_embedding)

def test_build_embeddings(self):
query = tf.constant([["word1", "word2"]])
query_embedding = self.query_embedding_vector.build_embeddings(query)
expected_embedding = np.array([0.5, 0.7, 0.9], dtype=np.float32)
self.assertTrue(np.allclose(query_embedding.numpy(), expected_embedding))
5 changes: 3 additions & 2 deletions python/ml4ir/base/features/feature_layer.py
@@ -14,7 +14,7 @@
from ml4ir.applications.ranking.features.feature_fns.categorical import CategoricalVector
from ml4ir.applications.ranking.features.feature_fns.normalization import TheoreticalMinMaxNormalization
from ml4ir.applications.ranking.features.feature_fns.rank_transform import ReciprocalRank
from ml4ir.applications.ranking.features.feature_fns.string import QueryLength, QueryTypeVector
from ml4ir.applications.ranking.features.feature_fns.string import QueryLength, QueryTypeVector, QueryEmbeddingVector


class FeatureLayerMap:
@@ -42,7 +42,8 @@ def __init__(self):
TheoreticalMinMaxNormalization.LAYER_NAME: TheoreticalMinMaxNormalization,
ReciprocalRank.LAYER_NAME: ReciprocalRank,
QueryLength.LAYER_NAME: QueryLength,
QueryTypeVector.LAYER_NAME: QueryTypeVector
QueryTypeVector.LAYER_NAME: QueryTypeVector,
QueryEmbeddingVector.LAYER_NAME: QueryEmbeddingVector
}

def add_fn(self, key, fn):
2 changes: 1 addition & 1 deletion python/ml4ir/base/model/scoring/monte_carlo_scorer.py
@@ -131,5 +131,5 @@ def call(self, inputs: Dict[str, tf.Tensor], training=None):
scores = super().call(inputs, training=False)[self.output_name]
for _ in range(monte_carlo_trials):
scores += super().call(inputs, training=True)[self.output_name]
scores = tf.divide(scores,monte_carlo_trials_tf)
scores = tf.divide(scores, monte_carlo_trials_tf)
return {self.output_name: scores}
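
For context, the Monte Carlo pattern this whitespace fix touches, sketched standalone: one deterministic pass plus monte_carlo_trials stochastic passes (dropout active via training=True), with the accumulated scores divided by the trial-count tensor.

scores = model(inputs, training=False)             # deterministic pass
for _ in range(monte_carlo_trials):
    scores += model(inputs, training=True)         # stochastic passes
scores = tf.divide(scores, monte_carlo_trials_tf)  # average the accumulated scores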
3 changes: 2 additions & 1 deletion python/ml4ir/base/model/scoring/scorer_factory.py
@@ -69,7 +69,7 @@ def get_scorer(model_config: dict,
if (MonteCarloInferenceKey.MONTE_CARLO_TRIALS in model_config and
(model_config[MonteCarloInferenceKey.MONTE_CARLO_TRIALS].get(MonteCarloInferenceKey.NUM_TEST_TRIALS, 0) or
model_config[MonteCarloInferenceKey.MONTE_CARLO_TRIALS].get(MonteCarloInferenceKey.NUM_TRAINING_TRIALS, 0))):
logger.info("Using Monte Carlo scorer.")
scorer = MonteCarloScorer(
feature_config=feature_config,
model_config=model_config,
@@ -84,6 +84,7 @@
logs_dir=logs_dir
)
else:
logger.info("Using default scorer.")
scorer = RelevanceScorer(
feature_config=feature_config,
model_config=model_config,
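
A hypothetical model_config fragment that would take the Monte Carlo branch above (a sketch using the same MonteCarloInferenceKey constants the condition reads; the values are made up):

model_config = {
    MonteCarloInferenceKey.MONTE_CARLO_TRIALS: {
        MonteCarloInferenceKey.NUM_TEST_TRIALS: 10,
        MonteCarloInferenceKey.NUM_TRAINING_TRIALS: 0,
    }
}
# get_scorer(model_config, ...) would then log "Using Monte Carlo scorer." and
# return a MonteCarloScorer; omitting the key falls through to RelevanceScorer.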