chore: change ratio for hybrid dbfs
sigridjineth committed Dec 13, 2024
1 parent 630223c commit e3ad344
Showing 1 changed file with 78 additions and 20 deletions.
@@ -5,6 +5,26 @@


def pad_colbert_vecs(colbert_vecs_list, device):
"""
Since ColBERT embeddings are computed on a token-level basis, each document (or query)
may produce a different number of token embeddings. This function aligns all embeddings
to the same length by padding shorter sequences with zeros, ensuring that every input
ends up with a uniform shape.
Steps:
1. Determine the maximum sequence length (i.e., the largest number of tokens in any
query or passage within the batch).
2. For each set of token embeddings, pad it with zeros until it matches the max
sequence length. Zeros here act as placeholders and do not affect the similarity
computations since they represent "no token."
3. Convert all padded embeddings into a single, consistent tensor and move it to the
specified device (e.g., GPU) for efficient batch computation.
By performing this padding operation, subsequent tensor operations (like the einsum
computations for ColBERT scoring) become simpler and more efficient, as all sequences
share a common shape.
"""

lengths = [vec.shape[0] for vec in colbert_vecs_list]
max_len = max(lengths)
dim = colbert_vecs_list[0].shape[1]
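The padding loop itself is collapsed in this diff view. A minimal, self-contained sketch of how such a function can be written, following the steps in the docstring above (the names below are illustrative, not necessarily the committed code):

import torch

def pad_colbert_vecs_sketch(colbert_vecs_list, device):
    # 1. find the longest token sequence in the batch
    lengths = [vec.shape[0] for vec in colbert_vecs_list]
    max_len = max(lengths)
    dim = colbert_vecs_list[0].shape[1]

    # 2. allocate a zero tensor and copy each sequence into it;
    #    positions beyond a sequence's length stay zero ("no token")
    padded = torch.zeros(len(colbert_vecs_list), max_len, dim, dtype=torch.float, device=device)
    for i, vec in enumerate(colbert_vecs_list):
        n_tokens = vec.shape[0]
        padded[i, :n_tokens, :] = torch.as_tensor(vec, dtype=torch.float, device=device)

    # 3. one uniform (N, max_len, dim) tensor on the target device
    return padded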
@@ -18,18 +38,57 @@ def pad_colbert_vecs(colbert_vecs_list, device):


def compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs):
# query_colbert_vecs: (Q, Tq, D)
# passage_colbert_vecs: (P, Tp, D)
# In the einsum expression: q = queries, p = passages, r = query-token dim, c = passage-token dim, d = embedding dim
"""
Compute ColBERT scores:
ColBERT (Contextualized Late Interaction over BERT) evaluates the similarity
between a query and a passage at the token level. Instead of producing a single
dense vector for each query or passage, ColBERT maintains embeddings for every
token. This allows for finer-grained matching, capturing more subtle similarities.
Definitions of variables:
- q: Number of queries (Q)
- p: Number of passages (P)
- r: Number of tokens in each query (Tq)
- c: Number of tokens in each passage (Tp)
- d: Embedding dimension (D)
I used the operation `einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs)`:
- einsum (Einstein summation) is a powerful notation and function for
expressing and computing multi-dimensional tensor contractions. It allows you
to specify how dimensions in input tensors correspond to each other and how
they should be combined (multiplied and summed) to produce the output.
In this particular case:
- "qrd" corresponds to (Q, Tq, D) for query token embeddings.
- "pcd" corresponds to (P, Tp, D) for passage token embeddings.
- "qrd,pcd->qprc" means:
1. For each query q and passage p, compute the dot product between every query token
embedding (r) and every passage token embedding (c) across the embedding dimension d.
2. This results in a (Q, P, Tq, Tp) tensor (qprc), where each element is the similarity
score between a single query token and a single passage token.
After computing this full matrix of token-to-token scores:
- We take the maximum over the passage token dimension (c) for each query token (r).
This step identifies, for each query token, which passage token is the "best match."
- Then we sum over all query tokens (r) to aggregate their best matches into a single
score per query-passage pair.
In summary:
1. einsum to get all pairwise token similarities.
2. max over passage tokens to find the best matching passage token for each query token.
3. sum over query tokens to combine all the best matches into a final ColBERT score
for each query-passage pair.
"""

dot_products = torch.einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs) # Q,P,Tq,Tp
max_per_query_token, _ = dot_products.max(dim=3) # max over c (Tp)
colbert_scores = max_per_query_token.sum(dim=2) # sum over r (Tq)
max_per_query_token, _ = dot_products.max(dim=3)
colbert_scores = max_per_query_token.sum(dim=2)
return colbert_scores
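For intuition, the three steps described in the docstring can be checked on toy tensors, with an equivalent broadcasting/matmul formulation alongside the einsum form; the shapes and names below are illustrative only:

import torch

# toy shapes: 2 queries x 3 query tokens, 4 passages x 5 passage tokens, dim 8
q = torch.randn(2, 3, 8)
p = torch.randn(4, 5, 8)

# step 1: all pairwise token similarities, shape (Q, P, Tq, Tp)
sims = torch.einsum("qrd,pcd->qprc", q, p)
# step 2: best-matching passage token for every query token
best = sims.max(dim=3).values
# step 3: sum the best matches over query tokens -> (Q, P) score matrix
scores = best.sum(dim=2)

# same result written with broadcasting and matmul instead of einsum
scores_matmul = (q.unsqueeze(1) @ p.unsqueeze(0).transpose(-1, -2)).max(dim=3).values.sum(dim=2)
assert torch.allclose(scores, scores_matmul, atol=1e-5)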


def hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores, weights=(0.33, 0.33, 0.34)):
def hybrid_dbfs_ensemble_simple_linear_combination(dense_scores, sparse_scores, colbert_scores, weights=(0.45, 0.45, 0.1)):
w_dense, w_sparse, w_colbert = weights
# The operation below works as expected when all inputs are torch.Tensor
return w_dense * dense_scores + w_sparse * sparse_scores + w_colbert * colbert_scores
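A quick illustration of the weighted combination with made-up score values (not outputs of the test below), assuming the default (0.45, 0.45, 0.1) weights:

import torch

dense = torch.tensor([[0.63, 0.35]])
sparse = torch.tensor([[0.20, 0.01]])
colbert = torch.tensor([[5.81, 3.12]])

# 0.45*0.63 + 0.45*0.20 + 0.1*5.81 = 0.9545 (first column), 0.4740 (second column)
print(hybrid_dbfs_ensemble_simple_linear_combination(dense, sparse, colbert))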


@@ -42,12 +101,12 @@ def test_m3_single_device():
)

queries = [
"What is BGE M3?",
"Defination of BM25"
"What is Sionic AI?",
"Try https://sionicstorm.ai today!"
] * 100
passages = [
"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
"Sionic AI delivers more accessible and cost-effective AI technology addressing the various needs to boost productivity and drive innovation.",
"The Large Language Model (LLM) is not for research and experimentation. We offer solutions that leverage LLM to add value to your business. Anyone can easily train and control AI."
] * 100

queries_embeddings = model.encode_queries(
@@ -56,36 +115,32 @@ def test_m3_single_device():
return_sparse=True,
return_colbert_vecs=True,
)

passages_embeddings = model.encode_corpus(
passages,
return_dense=True,
return_sparse=True,
return_colbert_vecs=True,
)

# set up the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# dense_vecs, lexical_weights, etc. may come back as numpy arrays, so convert them to tensors
q_dense = torch.tensor(queries_embeddings["dense_vecs"], dtype=torch.float, device=device)
p_dense = torch.tensor(passages_embeddings["dense_vecs"], dtype=torch.float, device=device)
dense_scores = q_dense @ p_dense.T

# sparse_scores is also a numpy array; convert it to a tensor
sparse_scores_np = model.compute_lexical_matching_score(
queries_embeddings["lexical_weights"],
passages_embeddings["lexical_weights"]
)

sparse_scores = torch.tensor(sparse_scores_np, dtype=torch.float, device=device)

# pad colbert_vecs and convert them to tensors
query_colbert_vecs = pad_colbert_vecs(queries_embeddings["colbert_vecs"], device)
passage_colbert_vecs = pad_colbert_vecs(passages_embeddings["colbert_vecs"], device)

colbert_scores = compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs)

# all scores are torch.Tensor, so the combination runs without errors
hybrid_scores = hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores)
hybrid_scores = hybrid_dbfs_ensemble_simple_linear_combination(dense_scores, sparse_scores, colbert_scores)

print("Dense score:\n", dense_scores[:2, :2])
print("Sparse score:\n", sparse_scores[:2, :2])
@@ -95,11 +150,14 @@ def test_m3_single_device():

if __name__ == '__main__':
test_m3_single_device()
print("Expected Vector Scores")
print("--------------------------------")
print("Expected Output for Dense & Sparse (original):")
print("Dense score:")
print(" [[0.626 0.3477]\n [0.3496 0.678 ]]")
print("Sparse score:")
print(" [[0.19554901 0.00880432]\n [0. 0.18036556]]")
print("ColBERT score:")
print("[[5.8061, 3.1195] \n [5.6822, 4.6513]]")
print("Hybrid DBSF Ensemble score:")
print("[[0.9822, 0.5125] \n [0.8127, 0.6958]]")
print("--------------------------------")
print("ColBERT and Hybrid DBSF scores will vary depending on the actual embeddings.")
