From 6c7721c8bbb43e62933df63eb73aa02616b1c61e Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Mon, 23 Dec 2024 12:12:15 +0800 Subject: [PATCH] Fetch chunk by batches. (#4177) ### What problem does this PR solve? #4173 ### Type of change - [x] Performance Improvement --- rag/nlp/search.py | 15 +++++++++++---- rag/utils/es_conn.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 837969761a4..39fbc015080 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -70,7 +70,7 @@ def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=Non pg = int(req.get("page", 1)) - 1 topk = int(req.get("topk", 1024)) ps = int(req.get("size", topk)) - offset, limit = pg * ps, (pg + 1) * ps + offset, limit = pg * ps, ps src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int", "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks", @@ -380,6 +380,13 @@ def sql_retrieval(self, sql, fetch_size=128, format="json"): def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]): condition = {"doc_id": doc_id} - res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids) - dict_chunks = self.dataStore.getFields(res, fields) - return dict_chunks.values() + res = [] + bs = 128 + for p in range(0, max_count, bs): + res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids) + dict_chunks = self.dataStore.getFields(res, fields) + if dict_chunks: + res.extend(dict_chunks.values()) + if len(dict_chunks.values()) < bs: + break + return res diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 35d64286959..f4de03aba7d 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -196,7 +196,7 @@ def search(self, selectFields: list[str], highlightFields: list[str], condition: s = s.sort(*orders) if limit > 0: - s = s[offset:limit] + s = s[offset:offset+limit] q = s.to_dict() logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))