diff --git a/.env.template b/.env.template
index 48fe88a..8177e9f 100644
--- a/.env.template
+++ b/.env.template
@@ -1,4 +1,4 @@
-GOOGLE_API_KEY=
+GOOGLE_API_KEY=your_api_key_here
 
 # To use vertexai keep it true and false to use gemini
 GEMINI_USE_VERTEX=true
diff --git a/backend/agents.py b/backend/agents.py
index 155ce0a..fc812ed 100644
--- a/backend/agents.py
+++ b/backend/agents.py
@@ -13,6 +13,7 @@
 
 from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
 from retrieval import get_retriever
+from rrf import reciprocal_rank_fusion
 
 
 # LLM (Gemini) client setup
@@ -436,24 +437,16 @@ async def execute_search(state: AgentState) -> Dict[str, Any]:
 
 
 def fuse_results(state: AgentState) -> AgentState:
-    logger.info("Node: Result Fusion")
+    logger.info("Node: Result Fusion (RRF)")
     ks_results = state.get("ks_results", [])
     vector_results = state.get("vector_results", [])
-    combined: Dict[str, dict] = {}
-    for res in vector_results:
-        if isinstance(res, dict):
-            doc_id = res.get("id") or res.get("_id") or f"vec_{len(combined)}"
-            combined[doc_id] = {**res, "final_score": res.get("similarity", 0) * 0.6}
-    for res in ks_results:
-        if isinstance(res, dict):
-            doc_id = res.get("_id") or res.get("id") or f"ks_{len(combined)}"
-            if doc_id in combined:
-                combined[doc_id]["final_score"] += res.get("_score", 0) * 0.4
-            else:
-                combined[doc_id] = {**res, "final_score": res.get("_score", 0) * 0.4}
-    all_sorted = sorted(combined.values(), key=lambda x: x.get("final_score", 0), reverse=True)
+
+    # We pass both lists to RRF. RRF handles deduplication and ranking.
+    # It takes care of ranking documents that appear in either or both lists.
+    all_sorted = reciprocal_rank_fusion([vector_results, ks_results], k=60, top_k=60)
+
     logger.info(
-        "Results summary: KS=%d, Vector=%d, Combined=%d",
+        "RRF fusion: KS=%d, Vector=%d → Combined=%d unique results",
         len(ks_results),
         len(vector_results),
         len(all_sorted),
diff --git a/backend/rrf.py b/backend/rrf.py
new file mode 100644
--- /dev/null
+++ b/backend/rrf.py
@@ -0,0 +1,100 @@
+import logging
+from typing import List, Dict, Any, Set
+
+logger = logging.getLogger("rrf")
+
+
+def extract_doc_id(result: Dict[str, Any]) -> str:
+    """
+    Extract a document ID from a search result dictionary.
+
+    Looks at "id" first and then "_id", so Keyword Search (KS) and Vector
+    Search results resolve to the same key. Values are compared against
+    None and "" instead of being tested for truthiness, so an ID of 0 is
+    kept. Returns "" when neither key holds a usable value.
+    """
+    for key in ("id", "_id"):
+        value = result.get(key)
+        if value is not None and value != "":
+            return str(value)
+    return ""
+
+
+def reciprocal_rank_fusion(
+    ranked_lists: List[List[Dict[str, Any]]],
+    k: int = 60,
+    top_k: int = 15
+) -> List[Dict[str, Any]]:
+    """
+    Combine several ranked lists of documents into one list using
+    Reciprocal Rank Fusion (RRF).
+
+    Formula: RRF_score(d) = sum(1 / (k + rank_i(d)))
+    where `rank_i(d)` is the 1-based rank of document `d` in list `i`.
+
+    Args:
+        ranked_lists: A list of lists; each inner list holds document dicts
+            ordered best-first. Entries that are not dicts are ignored.
+        k: The smoothing constant (default: 60, standard from literature).
+        top_k: Maximum number of fused results to return; values below 1
+            yield an empty list.
+
+    Returns:
+        Shallow copies of the fused documents, ordered by RRF score
+        descending (ties keep first-seen order). Each copy carries an
+        'rrf_score' field and a matching 'final_score' field for
+        compatibility with the rest of the application.
+    """
+    # Fusion key -> accumulated RRF score. Keys are tagged tuples so a real
+    # ID can never collide with a title-based or positional fallback key.
+    rrf_scores: Dict[tuple, float] = {}
+
+    # Fusion key -> copy of the first document seen under that key.
+    doc_map: Dict[tuple, Dict[str, Any]] = {}
+
+    for list_idx, ranked_list in enumerate(ranked_lists):
+        # A document contributes once per list, at its best rank there.
+        seen_in_list: Set[tuple] = set()
+
+        for rank, doc in enumerate(ranked_list, start=1):
+            if not isinstance(doc, dict):
+                continue
+
+            doc_id = extract_doc_id(doc)
+            title = doc.get("title_guess")
+            if doc_id:
+                key = ("id", doc_id)
+            elif title:
+                # Weak fallback: fuse ID-less documents by their title.
+                key = ("title", str(title))
+            else:
+                # Nothing identifies this document, so keep it distinct
+                # instead of merging every anonymous result together.
+                key = ("pos", list_idx, rank)
+
+            if key in seen_in_list:
+                continue
+            seen_in_list.add(key)
+
+            rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k + rank)
+
+            if key not in doc_map:
+                # Shallow copy so the caller's dicts are not mutated.
+                doc_map[key] = dict(doc)
+
+    # sorted() is stable, so documents with equal scores keep insertion order.
+    ranked_keys = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)
+
+    fused_results: List[Dict[str, Any]] = []
+    for key in ranked_keys[:max(top_k, 0)]:
+        doc = doc_map[key]
+        score = rrf_scores[key]
+        doc["rrf_score"] = score
+        # Maintain backward compatibility with agents.py expectations
+        doc["final_score"] = score
+        fused_results.append(doc)
+
+    logger.debug(
+        "Combined %d lists into %d results.", len(ranked_lists), len(fused_results)
+    )
+    return fused_results
diff --git a/backend/tests/test_rrf.py b/backend/tests/test_rrf.py
new file mode 100644
--- /dev/null
+++ b/backend/tests/test_rrf.py
@@ -0,0 +1,94 @@
+import pytest
+from rrf import reciprocal_rank_fusion, extract_doc_id
+
+
+def test_extract_doc_id():
+    assert extract_doc_id({"id": "123"}) == "123"
+    assert extract_doc_id({"_id": "456"}) == "456"
+    assert extract_doc_id({"id": "123", "_id": "456"}) == "123"  # Prefers 'id'
+    assert extract_doc_id({"id": 0}) == "0"  # Falsy but valid ID
+    assert extract_doc_id({}) == ""
+
+
+def test_rrf_single_list():
+    list1 = [{"id": "A"}, {"id": "B"}, {"id": "C"}]
+    fused = reciprocal_rank_fusion([list1], k=60, top_k=10)
+
+    assert len(fused) == 3
+    assert fused[0]["id"] == "A"
+    assert fused[1]["id"] == "B"
+    assert fused[2]["id"] == "C"
+
+    # Check score math: A=1/61, B=1/62, C=1/63
+    assert fused[0]["rrf_score"] == pytest.approx(1 / 61)
+    assert fused[1]["rrf_score"] == pytest.approx(1 / 62)
+    assert fused[2]["rrf_score"] == pytest.approx(1 / 63)
+
+
+def test_rrf_two_lists_same_order():
+    list1 = [{"id": "A"}, {"id": "B"}]
+    list2 = [{"_id": "A"}, {"_id": "B"}]  # Note list2 uses _id
+    fused = reciprocal_rank_fusion([list1, list2], k=60, top_k=10)
+
+    assert len(fused) == 2
+    assert fused[0]["id"] == "A"  # Source dict comes from list1 first
+    assert fused[1]["id"] == "B"
+
+    # A is rank 1 in both: 1/61 + 1/61
+    assert fused[0]["rrf_score"] == pytest.approx((1 / 61) + (1 / 61))
+
+
+def test_rrf_boosts_overlap():
+    # A is in both lists but ranked lower. B is rank 1 in list1 only. C is rank 1 in list2 only.
+    list1 = [{"id": "B"}, {"id": "A"}, {"id": "X"}]
+    list2 = [{"id": "C"}, {"id": "A"}, {"id": "Y"}]
+
+    fused = reciprocal_rank_fusion([list1, list2], k=60, top_k=10)
+
+    weights = {doc["id"]: doc["rrf_score"] for doc in fused}
+
+    # A: rank 2 + rank 2 = 1/62 + 1/62 = 0.032258
+    # B: rank 1 + none = 1/61 + 0 = 0.016393
+    # C: rank 1 + none = 1/61 + 0 = 0.016393
+
+    assert fused[0]["id"] == "A"
+    assert weights["A"] > weights["B"]
+    assert weights["A"] > weights["C"]
+
+
+def test_rrf_empty_lists():
+    assert reciprocal_rank_fusion([], k=60) == []
+    assert reciprocal_rank_fusion([[], []], k=60) == []
+
+    list1 = [{"id": "A"}]
+    # Fuses one empty list and one populated list
+    fused = reciprocal_rank_fusion([list1, []], k=60)
+    assert len(fused) == 1
+    assert fused[0]["id"] == "A"
+
+
+def test_rrf_top_k_truncates():
+    list1 = [{"id": str(i)} for i in range(100)]
+    fused = reciprocal_rank_fusion([list1], k=60, top_k=5)
+    assert len(fused) == 5
+    assert fused[-1]["id"] == "4"  # Indices 0, 1, 2, 3, 4
+
+
+def test_rrf_id_fallback():
+    # Documents without id or _id are fused by title_guess when they have one.
+    list1 = [{"title_guess": "Unique Title"}, {"title_guess": "Another Title"}]
+    fused = reciprocal_rank_fusion([list1])
+    assert len(fused) == 2
+    assert fused[0].get("rrf_score") is not None
+
+
+def test_rrf_anonymous_docs_stay_distinct():
+    # No id, _id, or title_guess: each result must survive on its own.
+    fused = reciprocal_rank_fusion([[{"text": "a"}, {"text": "b"}]])
+    assert [doc["text"] for doc in fused] == ["a", "b"]
+
+
+def test_rrf_ignores_non_dict_and_repeated_entries():
+    fused = reciprocal_rank_fusion([[{"id": "A"}, None, {"id": "A"}]])
+    assert len(fused) == 1
+    assert fused[0]["rrf_score"] == pytest.approx(1 / 61)