Skip to content

Commit

Permalink
Fix API retrieval error. (#4408)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

#4403

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
  • Loading branch information
KevinHuSh authored Jan 8, 2025
1 parent b7ce4e7 commit 3d66d78
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions api/apps/sdk/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType
from api.db.services.llm_service import TenantLLMService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
import xxhash
import re
Expand Down Expand Up @@ -1331,18 +1331,14 @@ def retrieval_test(tenant_id):
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_error_data_result(message="Dataset not found!")
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id
)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)

rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]
)
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])

if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)

retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
Expand Down

0 comments on commit 3d66d78

Please sign in to comment.