diff --git a/backend/app/api/admin_routes/chat_engine.py b/backend/app/api/admin_routes/chat_engine.py index 57d48f479..a7a31bf6b 100644 --- a/backend/app/api/admin_routes/chat_engine.py +++ b/backend/app/api/admin_routes/chat_engine.py @@ -2,7 +2,7 @@ from fastapi_pagination import Params, Page from app.api.deps import SessionDep, CurrentSuperuserDep -from app.rag.chat_config import ChatEngineConfig +from app.rag.chat.config import ChatEngineConfig from app.repositories import chat_engine_repo from app.models import ChatEngine, ChatEngineUpdate diff --git a/backend/app/api/admin_routes/evaluation/evaluation_task.py b/backend/app/api/admin_routes/evaluation/evaluation_task.py index 05271f1c0..7a1f53088 100644 --- a/backend/app/api/admin_routes/evaluation/evaluation_task.py +++ b/backend/app/api/admin_routes/evaluation/evaluation_task.py @@ -119,7 +119,7 @@ def get_evaluation_task_summary( evaluation_task_id: int, session: SessionDep, user: CurrentSuperuserDep ) -> EvaluationTaskSummary: task = must_get(session, EvaluationTask, evaluation_task_id) - return get_evaluation_task_summary(task, session) + return get_summary_for_evaluation_task(task, session) @router.get("/admin/evaluation/tasks") @@ -135,7 +135,7 @@ def list_evaluation_task( task_page: Page[EvaluationTask] = paginate(session, stmt, params) summaries: List[EvaluationTaskSummary] = [] for task in task_page.items: - summaries.append(get_evaluation_task_summary(task, session)) + summaries.append(get_summary_for_evaluation_task(task, session)) return Page[EvaluationTaskSummary]( items=summaries, @@ -169,7 +169,7 @@ def list_evaluation_task_items( return paginate(session, stmt, params) -def get_evaluation_task_summary( +def get_summary_for_evaluation_task( evaluation_task: EvaluationTask, session: Session ) -> EvaluationTaskSummary: status_counts = ( diff --git a/backend/app/api/admin_routes/evaluation/tools.py b/backend/app/api/admin_routes/evaluation/tools.py index 81e1f3843..7db94db4a 100644 --- a/backend/app/api/admin_routes/evaluation/tools.py +++ b/backend/app/api/admin_routes/evaluation/tools.py @@ -1,6 +1,4 @@ -from http.client import HTTPException from typing import TypeVar, Type - from fastapi import status, HTTPException from sqlmodel import SQLModel, Session diff --git a/backend/app/api/admin_routes/knowledge_base/graph/routes.py b/backend/app/api/admin_routes/knowledge_base/graph/routes.py index 09b7eb8b8..6f6c51d1e 100644 --- a/backend/app/api/admin_routes/knowledge_base/graph/routes.py +++ b/backend/app/api/admin_routes/knowledge_base/graph/routes.py @@ -226,7 +226,7 @@ def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchReq try: kb = knowledge_base_repo.must_get(session, kb_id) graph_store = get_kb_tidb_graph_store(session, kb) - entities, relations = graph_store.retrieve_with_weight( + entities, relationships = graph_store.retrieve_with_weight( request.query, [], request.depth, @@ -236,7 +236,7 @@ def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchReq ) return { "entities": entities, - "relationships": relations, + "relationships": relationships, } except KBNotFound as e: raise e diff --git a/backend/app/api/admin_routes/legacy_retrieve.py b/backend/app/api/admin_routes/legacy_retrieve.py new file mode 100644 index 000000000..1f3d425c1 --- /dev/null +++ b/backend/app/api/admin_routes/legacy_retrieve.py @@ -0,0 +1,133 @@ +import logging +from typing import Optional, List + +from fastapi import APIRouter +from sqlmodel import Session +from app.models import Document 
+from app.api.admin_routes.models import ChatEngineBasedRetrieveRequest +from app.api.deps import SessionDep, CurrentSuperuserDep +from llama_index.core.schema import NodeWithScore + +from app.exceptions import InternalServerError, KBNotFound +from app.rag.chat.config import ChatEngineConfig +from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def get_override_engine_config( + db_session: Session, + engine_name: str, + # Override chat engine config. + top_k: Optional[int] = None, + similarity_top_k: Optional[int] = None, + oversampling_factor: Optional[int] = None, + refine_question_with_kg: Optional[bool] = None, +) -> ChatEngineConfig: + engine_config = ChatEngineConfig.load_from_db(db_session, engine_name) + if similarity_top_k is not None: + engine_config.vector_search.similarity_top_k = similarity_top_k + if oversampling_factor is not None: + engine_config.vector_search.oversampling_factor = oversampling_factor + if top_k is not None: + engine_config.vector_search.top_k = top_k + if refine_question_with_kg is not None: + engine_config.refine_question_with_kg = refine_question_with_kg + return engine_config + + +@router.get("/admin/retrieve/documents", deprecated=True) +def legacy_retrieve_documents( + session: SessionDep, + user: CurrentSuperuserDep, + question: str, + chat_engine: str = "default", + # Override chat engine config. + top_k: Optional[int] = 5, + similarity_top_k: Optional[int] = None, + oversampling_factor: Optional[int] = 5, + refine_question_with_kg: Optional[bool] = True, +) -> List[Document]: + try: + engine_config = get_override_engine_config( + db_session=session, + engine_name=chat_engine, + top_k=top_k, + similarity_top_k=similarity_top_k, + oversampling_factor=oversampling_factor, + refine_question_with_kg=refine_question_with_kg, + ) + retriever = RetrieveFlow( + db_session=session, + engine_name=chat_engine, + engine_config=engine_config, + ) + return retriever.retrieve_documents(question) + except KBNotFound as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.get("/admin/embedding_retrieve", deprecated=True) +def legacy_retrieve_chunks( + session: SessionDep, + user: CurrentSuperuserDep, + question: str, + chat_engine: str = "default", + # Override chat engine config. 
+ top_k: Optional[int] = 5, + similarity_top_k: Optional[int] = None, + oversampling_factor: Optional[int] = 5, + refine_question_with_kg=False, +) -> List[NodeWithScore]: + try: + engine_config = get_override_engine_config( + db_session=session, + engine_name=chat_engine, + top_k=top_k, + similarity_top_k=similarity_top_k, + oversampling_factor=oversampling_factor, + refine_question_with_kg=refine_question_with_kg, + ) + retriever = RetrieveFlow( + db_session=session, + engine_name=chat_engine, + engine_config=engine_config, + ) + return retriever.retrieve(question) + except KBNotFound as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.post("/admin/embedding_retrieve", deprecated=True) +def legacy_retrieve_chunks_2( + session: SessionDep, + user: CurrentSuperuserDep, + request: ChatEngineBasedRetrieveRequest, +) -> List[NodeWithScore]: + try: + engine_config = get_override_engine_config( + db_session=session, + engine_name=request.chat_engine, + top_k=request.top_k, + similarity_top_k=request.similarity_top_k, + oversampling_factor=request.oversampling_factor, + refine_question_with_kg=request.refine_question_with_kg, + ) + retriever = RetrieveFlow( + db_session=session, + engine_name=request.chat_engine, + engine_config=engine_config, + ) + return retriever.retrieve(request.query) + except KBNotFound as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() diff --git a/backend/app/api/admin_routes/models.py b/backend/app/api/admin_routes/models.py index 26724363b..1a644ee12 100644 --- a/backend/app/api/admin_routes/models.py +++ b/backend/app/api/admin_routes/models.py @@ -26,6 +26,9 @@ class KnowledgeBaseDescriptor(BaseModel): id: int name: str + def __hash__(self): + return hash(self.id) + class DataSourceDescriptor(BaseModel): id: int @@ -44,4 +47,4 @@ class ChatEngineBasedRetrieveRequest(BaseModel): top_k: Optional[int] = 5 similarity_top_k: Optional[int] = None oversampling_factor: Optional[int] = 5 - enable_kg_enhance_query_refine: Optional[bool] = False + refine_question_with_kg: Optional[bool] = False diff --git a/backend/app/api/admin_routes/retrieve_old.py b/backend/app/api/admin_routes/retrieve_old.py deleted file mode 100644 index 330de259c..000000000 --- a/backend/app/api/admin_routes/retrieve_old.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging -from typing import Optional, List - -from fastapi import APIRouter -from app.models import Document -from app.api.admin_routes.models import ChatEngineBasedRetrieveRequest -from app.api.deps import SessionDep, CurrentSuperuserDep -from llama_index.core.schema import NodeWithScore -from app.rag.retrieve import retrieve_service - -from app.exceptions import InternalServerError, KBNotFound - -router = APIRouter() -logger = logging.getLogger(__name__) - - -@router.get("/admin/retrieve/documents", deprecated=True) -def retrieve_documents( - session: SessionDep, - user: CurrentSuperuserDep, - question: str, - chat_engine: str = "default", - top_k: Optional[int] = 5, - similarity_top_k: Optional[int] = None, - oversampling_factor: Optional[int] = 5, - enable_kg_enhance_query_refine: Optional[bool] = True, -) -> List[Document]: - try: - return retrieve_service.chat_engine_retrieve_documents( - session, - question=question, - top_k=top_k, - chat_engine_name=chat_engine, - similarity_top_k=similarity_top_k, - oversampling_factor=oversampling_factor, - enable_kg_enhance_query_refine=enable_kg_enhance_query_refine, - ) - except KBNotFound as e: - raise e 
- except Exception as e: - logger.exception(e) - raise InternalServerError() - - -@router.get("/admin/embedding_retrieve", deprecated=True) -def embedding_retrieve( - session: SessionDep, - user: CurrentSuperuserDep, - question: str, - chat_engine: str = "default", - top_k: Optional[int] = 5, - similarity_top_k: Optional[int] = None, - oversampling_factor: Optional[int] = 5, - enable_kg_enhance_query_refine=False, -) -> List[NodeWithScore]: - try: - nodes = retrieve_service.chat_engine_retrieve_chunks( - session, - question=question, - top_k=top_k, - chat_engine_name=chat_engine, - similarity_top_k=similarity_top_k, - oversampling_factor=oversampling_factor, - enable_kg_enhance_query_refine=enable_kg_enhance_query_refine, - ) - return nodes - except KBNotFound as e: - raise e - except Exception as e: - logger.exception(e) - raise InternalServerError() - - -@router.post("/admin/embedding_retrieve", deprecated=True) -def embedding_search( - session: SessionDep, - user: CurrentSuperuserDep, - request: ChatEngineBasedRetrieveRequest, -) -> List[NodeWithScore]: - try: - return retrieve_service.chat_engine_retrieve_chunks( - session, - request.query, - top_k=request.top_k, - similarity_top_k=request.similarity_top_k, - oversampling_factor=request.oversampling_factor, - enable_kg_enhance_query_refine=request.enable_kg_enhance_query_refine, - ) - except KBNotFound as e: - raise e - except Exception as e: - logger.exception(e) - raise InternalServerError() diff --git a/backend/app/api/admin_routes/semantic_cache.py b/backend/app/api/admin_routes/semantic_cache.py index 73fd0a6bf..0ffea78af 100644 --- a/backend/app/api/admin_routes/semantic_cache.py +++ b/backend/app/api/admin_routes/semantic_cache.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Body from app.api.deps import SessionDep, CurrentSuperuserDep -from app.rag.chat_config import ChatEngineConfig +from app.rag.chat.config import ChatEngineConfig from app.rag.semantic_cache import SemanticCacheManager, SemanticItem router = APIRouter() diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ea7b80fcc..59dc4af5e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -38,9 +38,9 @@ from app.api.admin_routes import ( chat_engine as admin_chat_engine, feedback as admin_feedback, + legacy_retrieve as admin_legacy_retrieve, site_setting as admin_site_settings, upload as admin_upload, - retrieve_old as admin_retrieve_old, stats as admin_stats, semantic_cache as admin_semantic_cache, langfuse as admin_langfuse, @@ -84,7 +84,7 @@ api_router.include_router(admin_embedding_model_router, tags=["admin/embedding_model"]) api_router.include_router(admin_reranker_model_router, tags=["admin/reranker_model"]) api_router.include_router(admin_langfuse.router, tags=["admin/langfuse"]) -api_router.include_router(admin_retrieve_old.router, tags=["admin/retrieve_old"]) +api_router.include_router(admin_legacy_retrieve.router, tags=["admin/retrieve_old"]) api_router.include_router(admin_stats.router, tags=["admin/stats"]) api_router.include_router(admin_semantic_cache.router, tags=["admin/semantic_cache"]) api_router.include_router(admin_evaluation_task.router, tags=["admin/evaluation/task"]) @@ -96,19 +96,3 @@ api_router.include_router( fastapi_users.get_auth_router(auth_backend), prefix="/auth", tags=["auth"] ) - -# api_router.include_router( -# fastapi_users.get_register_router(UserRead, UserCreate), -# prefix="/auth", -# tags=["auth"], -# ) -# api_router.include_router( -# fastapi_users.get_reset_password_router(), -# prefix="/auth", 
-# tags=["auth"], -# ) -# api_router.include_router( -# fastapi_users.get_verify_router(UserRead), -# prefix="/auth", -# tags=["auth"], -# ) diff --git a/backend/app/api/routes/chat.py b/backend/app/api/routes/chat.py index 6a09496ae..81a0ae0e7 100644 --- a/backend/app/api/routes/chat.py +++ b/backend/app/api/routes/chat.py @@ -12,31 +12,21 @@ from fastapi_pagination import Params, Page from app.api.deps import SessionDep, OptionalUserDep, CurrentUserDep +from app.rag.chat.chat_flow import ChatFlow +from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrievalResult from app.repositories import chat_repo from app.models import Chat, ChatUpdate -from app.rag.chat import ( - ChatFlow, - ChatEvent, + +from app.rag.chat.chat_service import ( + get_final_chat_result, user_can_view_chat, user_can_edit_chat, get_chat_message_subgraph, get_chat_message_recommend_questions, remove_chat_message_recommend_questions, ) -from app.rag.types import ( - MessageRole, - ChatMessage, - ChatEventType, - ChatMessageSate, -) -from app.exceptions import ( - ChatNotFound, - KBNotFound, - LLMException, - EmbeddingModelException, - RerankerModelException, - InternalServerError, -) +from app.rag.types import MessageRole, ChatMessage +from app.exceptions import InternalServerError logger = logging.getLogger(__name__) @@ -86,60 +76,23 @@ def chats( chat_messages=chat_request.messages, engine_name=chat_request.chat_engine, ) - except ChatNotFound as e: - raise e - except KBNotFound as e: - raise e - except LLMException as e: - raise e - except EmbeddingModelException as e: - raise e - except RerankerModelException as e: + + if chat_request.stream: + return StreamingResponse( + chat_flow.chat(), + media_type="text/event-stream", + headers={ + "X-Content-Type-Options": "nosniff", + }, + ) + else: + return get_final_chat_result(chat_flow.chat()) + except HTTPException as e: raise e except Exception as e: logger.exception(e) raise InternalServerError() - if chat_request.stream: - return StreamingResponse( - chat_flow.chat(), - media_type="text/event-stream", - headers={ - "X-Content-Type-Options": "nosniff", - }, - ) - else: - trace, sources, content = None, [], "" - chat_id, message_id = None, None - # TODO: maybe we can wrap following code in the chat() method, and using "yield from" to return the values. 
- for m in chat_flow.chat(): - if not isinstance(m, ChatEvent): - continue - if m.event_type == ChatEventType.MESSAGE_ANNOTATIONS_PART: - if m.payload.state == ChatMessageSate.SOURCE_NODES: - sources = m.payload.context - elif m.payload.state == ChatMessageSate.TRACE: - trace = m.payload.context - elif m.event_type == ChatEventType.TEXT_PART: - content += m.payload - elif m.event_type == ChatEventType.DATA_PART: - chat_id = m.payload.chat.id - message_id = m.payload.assistant_message.id - elif m.event_type == ChatEventType.ERROR_PART: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=m.payload, - ) - else: - pass - return { - "chat_id": chat_id, - "message_id": message_id, - "trace": trace, - "sources": sources, - "content": content, - } - @router.get("/chats") def list_chats( @@ -154,9 +107,7 @@ def list_chats( @router.get("/chats/{chat_id}") def get_chat(session: SessionDep, user: OptionalUserDep, chat_id: UUID): - chat = chat_repo.get(session, chat_id) - if not chat: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Chat not found") + chat = chat_repo.must_get(session, chat_id) if not user_can_view_chat(chat, user): raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Access denied") @@ -171,71 +122,99 @@ def get_chat(session: SessionDep, user: OptionalUserDep, chat_id: UUID): def update_chat( session: SessionDep, user: CurrentUserDep, chat_id: UUID, chat_update: ChatUpdate ): - chat = chat_repo.get(session, chat_id) - if not chat: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Chat not found") + try: + chat = chat_repo.must_get(session, chat_id) - if not user_can_edit_chat(chat, user): - raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Access denied") - return chat_repo.update(session, chat, chat_update) + if not user_can_edit_chat(chat, user): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="Access denied" + ) + + return chat_repo.update(session, chat, chat_update) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e, exc_info=True) + raise InternalServerError() @router.delete("/chats/{chat_id}") def delete_chat(session: SessionDep, user: CurrentUserDep, chat_id: UUID): - chat = chat_repo.get(session, chat_id) - if not chat: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Chat not found") - - if not user_can_edit_chat(chat, user): - raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Access denied") - return chat_repo.delete(session, chat) + try: + chat = chat_repo.must_get(session, chat_id) + if not user_can_edit_chat(chat, user): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="Access denied" + ) -class SubgraphResponse(BaseModel): - entities: List[dict] - relationships: List[dict] + return chat_repo.delete(session, chat) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e, exc_info=True) + raise InternalServerError() @router.get( - "/chat-messages/{chat_message_id}/subgraph", response_model=SubgraphResponse + "/chat-messages/{chat_message_id}/subgraph", + response_model=KnowledgeGraphRetrievalResult, ) def get_chat_subgraph(session: SessionDep, user: OptionalUserDep, chat_message_id: int): - chat_message = chat_repo.get_message(session, chat_message_id) - if not chat_message: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="Chat message not found" - ) + try: + chat_message = chat_repo.must_get_message(session, chat_message_id) - if not 
user_can_view_chat(chat_message.chat, user): - raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Access denied") + if not user_can_view_chat(chat_message.chat, user): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="Access denied" + ) - entities, relations = get_chat_message_subgraph(session, chat_message) + result = get_chat_message_subgraph(session, chat_message) + return result.model_dump(exclude_none=True) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e, exc_info=True) + raise InternalServerError() - return SubgraphResponse(entities=entities, relationships=relations) +@router.get("/chat-messages/{chat_message_id}/recommended-questions") +def get_recommended_questions( + session: SessionDep, user: OptionalUserDep, chat_message_id: int +) -> List[str]: + try: + chat_message = chat_repo.must_get_message(session, chat_message_id) -@router.get( - "/chat-messages/{chat_message_id}/recommended-questions", response_model=List[str] -) -def get_recommended_questions(session: SessionDep, chat_message_id: int): - chat_message = chat_repo.get_message(session, chat_message_id) - if not chat_message or len(chat_message.content) == 0: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="Chat message not found" - ) + if not user_can_view_chat(chat_message.chat, user): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="Access denied" + ) - return get_chat_message_recommend_questions(session, chat_message) + return get_chat_message_recommend_questions(session, chat_message) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e, exc_info=True) + raise InternalServerError() -@router.post( - "/chat-messages/{chat_message_id}/recommended-questions", response_model=List[str] -) -def refresh_recommended_questions(session: SessionDep, chat_message_id: int): - chat_message = chat_repo.get_message(session, chat_message_id) - if not chat_message or len(chat_message.content) == 0: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="Chat message not found" - ) +@router.post("/chat-messages/{chat_message_id}/recommended-questions") +def refresh_recommended_questions( + session: SessionDep, user: OptionalUserDep, chat_message_id: int +) -> List[str]: + try: + chat_message = chat_repo.must_get_message(session, chat_message_id) - remove_chat_message_recommend_questions(session, chat_message_id) - return get_chat_message_recommend_questions(session, chat_message) + if not user_can_view_chat(chat_message.chat, user): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="Access denied" + ) + + remove_chat_message_recommend_questions(session, chat_message_id) + + return get_chat_message_recommend_questions(session, chat_message) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e, exc_info=True) + raise InternalServerError() diff --git a/backend/app/api/routes/index.py b/backend/app/api/routes/index.py index e5c20acbf..d32c9edda 100644 --- a/backend/app/api/routes/index.py +++ b/backend/app/api/routes/index.py @@ -4,7 +4,7 @@ from app.api.deps import SessionDep from app.api.routes.models import SystemConfigStatusResponse from app.site_settings import SiteSetting -from app.rag.chat import ( +from app.rag.chat.chat_service import ( check_rag_required_config, check_rag_optional_config, check_rag_config_need_migration, diff --git a/backend/app/api_server.py b/backend/app/api_server.py new file mode 100644 index 000000000..4e4eaa3d9 --- 
/dev/null +++ b/backend/app/api_server.py @@ -0,0 +1,116 @@ +import logging +from logging.config import dictConfig +import sentry_sdk + +from contextlib import asynccontextmanager +from fastapi import FastAPI, Request, Response +from fastapi.routing import APIRoute +from starlette.middleware.cors import CORSMiddleware +from dotenv import load_dotenv +from app.api.main import api_router +from app.core.config import settings, Environment +from app.site_settings import SiteSetting +from app.utils.uuid6 import uuid7 + + +dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "root": { + "level": logging.INFO + if settings.ENVIRONMENT != Environment.LOCAL + else logging.DEBUG, + "handlers": ["console"], + }, + "loggers": { + "uvicorn.error": { + "level": "ERROR", + "handlers": ["console"], + "propagate": False, + }, + "uvicorn.access": { + "level": "INFO", + "handlers": ["console"], + "propagate": False, + }, + }, + } +) + + +logger = logging.getLogger(__name__) + + +load_dotenv() + + +def custom_generate_unique_id(route: APIRoute) -> str: + return f"{route.tags[0]}-{route.name}" + + +if settings.SENTRY_DSN and settings.ENVIRONMENT != "local": + sentry_sdk.init( + dsn=str(settings.SENTRY_DSN), + enable_tracing=True, + traces_sample_rate=settings.SENTRY_TRACES_SAMPLE_RATE, + profiles_sample_rate=settings.SENTRY_PROFILES_SAMPLE_RATE, + ) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + SiteSetting.update_db_cache() + yield + + +app = FastAPI( + title=settings.PROJECT_NAME, + openapi_url=f"{settings.API_V1_STR}/openapi.json", + generate_unique_id_function=custom_generate_unique_id, + lifespan=lifespan, +) + + +# Set all CORS enabled origins +if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=[ + str(origin).strip("/") for origin in settings.BACKEND_CORS_ORIGINS + ], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + +@app.middleware("http") +async def identify_browser(request: Request, call_next): + browser_id = request.cookies.get(settings.BROWSER_ID_COOKIE_NAME) + has_browser_id = bool(browser_id) + if not browser_id: + browser_id = uuid7() + request.state.browser_id = browser_id + response: Response = await call_next(request) + if not has_browser_id: + response.set_cookie( + settings.BROWSER_ID_COOKIE_NAME, + browser_id, + max_age=settings.BROWSER_ID_COOKIE_MAX_AGE, + ) + return response + + +app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py index 9f6e2c7e1..1a3f3b172 100644 --- a/backend/app/exceptions.py +++ b/backend/app/exceptions.py @@ -25,6 +25,13 @@ def __init__(self, chat_id: UUID): self.detail = f"chat #{chat_id} is not found" +class ChatMessageNotFound(ChatException): + status_code = 404 + + def __init__(self, message_id: int): + self.detail = f"chat message #{message_id} is not found" + + # LLM @@ -149,7 +156,11 @@ def __init__(self, kb_id, chat_engines_num: int): # Document -class DocumentNotFound(KBException): +class DocumentException(HTTPException): + pass + + +class DocumentNotFound(DocumentException): status_code = 404 def __init__(self, document_id: int): diff --git a/backend/app/models/knowledge_base.py b/backend/app/models/knowledge_base.py index 1fd255be9..17e2ba18a 100644 --- 
a/backend/app/models/knowledge_base.py +++ b/backend/app/models/knowledge_base.py @@ -13,6 +13,7 @@ SQLModel, ) +from app.api.admin_routes.models import KnowledgeBaseDescriptor from app.exceptions import KBDataSourceNotFound from app.models.auth import User from app.models.data_source import DataSource @@ -109,3 +110,9 @@ def must_get_data_source_by_id(self, data_source_id: int) -> DataSource: if data_source is None: raise KBDataSourceNotFound(self.id, data_source_id) return data_source + + def to_descriptor(self) -> KnowledgeBaseDescriptor: + return KnowledgeBaseDescriptor( + id=self.id, + name=self.name, + ) diff --git a/backend/app/api/admin_routes/knowledge_base/document/__init__.py b/backend/app/rag/chat/__init__.py similarity index 100% rename from backend/app/api/admin_routes/knowledge_base/document/__init__.py rename to backend/app/rag/chat/__init__.py diff --git a/backend/app/rag/chat.py b/backend/app/rag/chat/chat.py similarity index 99% rename from backend/app/rag/chat.py rename to backend/app/rag/chat/chat.py index 028478269..a45e10628 100644 --- a/backend/app/rag/chat.py +++ b/backend/app/rag/chat/chat.py @@ -39,7 +39,7 @@ from app.core.config import settings from app.models.entity import get_kb_entity_model from app.models.recommend_question import RecommendQuestion -from app.rag.chat_stream_protocol import ( +from app.rag.chat.stream_protocol import ( ChatStreamMessagePayload, ChatStreamDataPayload, ChatEvent, @@ -68,7 +68,7 @@ ) from app.rag.knowledge_base.selector import KBSelectMode from app.rag.utils import parse_goal_response_format -from app.rag.chat_config import ( +from app.rag.chat.config import ( ChatEngineConfig, KnowledgeGraphOption, KnowledgeBaseOption, diff --git a/backend/app/rag/chat/chat_flow.py b/backend/app/rag/chat/chat_flow.py new file mode 100644 index 000000000..12cfe8d9e --- /dev/null +++ b/backend/app/rag/chat/chat_flow.py @@ -0,0 +1,790 @@ +import json +import logging +from datetime import datetime, UTC +from typing import List, Optional, Generator, Tuple, Any +from urllib.parse import urljoin +from uuid import UUID + +import requests +from langfuse.llama_index import LlamaIndexInstrumentor +from langfuse.llama_index._context import langfuse_instrumentor_context +from llama_index.core import get_response_synthesizer +from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.schema import NodeWithScore +from sqlmodel import Session +from app.core.config import settings +from app.exceptions import ChatNotFound +from app.models import ( + User, + Chat as DBChat, + ChatVisibility, + ChatMessage as DBChatMessage, +) +from app.rag.chat.config import ChatEngineConfig +from app.rag.chat.retrieve.retrieve_flow import SourceDocument, RetrieveFlow +from app.rag.chat.stream_protocol import ( + ChatEvent, + ChatStreamDataPayload, + ChatStreamMessagePayload, +) +from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrievalResult +from app.rag.types import ChatEventType, MessageRole, ChatMessageSate +from app.rag.utils import parse_goal_response_format +from app.repositories import chat_repo +from app.site_settings import SiteSetting +from app.utils.jinja2 import get_prompt_by_jinja2_template +from app.utils.tracing import LangfuseContextManager + +logger = logging.getLogger(__name__) + + +def parse_chat_messages( + chat_messages: List[ChatMessage], +) -> tuple[str, List[ChatMessage]]: + user_question = chat_messages[-1].content + chat_history = chat_messages[:-1] + return user_question, chat_history + + +class ChatFlow: + 
_trace_manager: LangfuseContextManager + + def __init__( + self, + *, + db_session: Session, + user: User, + browser_id: str, + origin: str, + chat_messages: List[ChatMessage], + engine_name: str = "default", + chat_id: Optional[UUID] = None, + ) -> None: + self.chat_id = chat_id + self.db_session = db_session + self.user = user + self.browser_id = browser_id + self.engine_name = engine_name + + # Load chat engine and chat session. + self.user_question, self.chat_history = parse_chat_messages(chat_messages) + if chat_id: + # FIXME: + # only chat owner or superuser can access the chat, + # anonymous user can only access anonymous chat by track_id + self.db_chat_obj = chat_repo.get(self.db_session, chat_id) + if not self.db_chat_obj: + raise ChatNotFound(chat_id) + try: + self.engine_config = ChatEngineConfig.load_from_db( + db_session, self.db_chat_obj.engine.name + ) + self.db_chat_engine = self.engine_config.get_db_chat_engine() + except Exception as e: + logger.error(f"Failed to load chat engine config: {e}") + self.engine_config = ChatEngineConfig.load_from_db( + db_session, engine_name + ) + self.db_chat_engine = self.engine_config.get_db_chat_engine() + logger.info( + f"ChatService - chat_id: {chat_id}, chat_engine: {self.db_chat_obj.engine.name}" + ) + self.chat_history = [ + ChatMessage(role=m.role, content=m.content, additional_kwargs={}) + for m in chat_repo.get_messages(self.db_session, self.db_chat_obj) + ] + else: + self.engine_config = ChatEngineConfig.load_from_db(db_session, engine_name) + self.db_chat_engine = self.engine_config.get_db_chat_engine() + self.db_chat_obj = chat_repo.create( + self.db_session, + DBChat( + title=self.user_question[:100], + engine_id=self.db_chat_engine.id, + engine_options=self.engine_config.screenshot(), + user_id=self.user.id if self.user else None, + browser_id=self.browser_id, + origin=origin, + visibility=ChatVisibility.PUBLIC + if not self.user + else ChatVisibility.PRIVATE, + ), + ) + chat_id = self.db_chat_obj.id + # slack/discord may create a new chat with history messages + now = datetime.now(UTC) + for i, m in enumerate(self.chat_history): + chat_repo.create_message( + session=self.db_session, + chat=self.db_chat_obj, + chat_message=DBChatMessage( + role=m.role, + content=m.content, + ordinal=i + 1, + created_at=now, + updated_at=now, + finished_at=now, + ), + ) + + # Init Langfuse for tracing. + enable_langfuse = ( + SiteSetting.langfuse_secret_key and SiteSetting.langfuse_public_key + ) + instrumentor = LlamaIndexInstrumentor( + host=SiteSetting.langfuse_host, + secret_key=SiteSetting.langfuse_secret_key, + public_key=SiteSetting.langfuse_public_key, + enabled=enable_langfuse, + ) + self._trace_manager = LangfuseContextManager(instrumentor) + + # Init LLM. + self._llm = self.engine_config.get_llama_llm(self.db_session) + self._fast_llm = self.engine_config.get_fast_llama_llm(self.db_session) + self._fast_dspy_lm = self.engine_config.get_fast_dspy_lm(self.db_session) + + # Load knowledge bases. + self.knowledge_bases = self.engine_config.get_knowledge_bases(self.db_session) + self.knowledge_base_ids = [kb.id for kb in self.knowledge_bases] + + # Init retrieve flow. 
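+        # RetrieveFlow bundles vector search over chunks and knowledge graph retrieval
+        # for the configured knowledge bases; it is reused by _search_knowledge_graph()
+        # and _search_relevance_chunks() below.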
+ self.retrieve_flow = RetrieveFlow( + db_session=self.db_session, + engine_name=self.engine_name, + engine_config=self.engine_config, + llm=self._llm, + fast_llm=self._fast_llm, + knowledge_bases=self.knowledge_bases, + ) + + def chat(self) -> Generator[ChatEvent | str, None, None]: + try: + with self._trace_manager.observe( + trace_name="ChatFlow", + user_id=self.user.email + if self.user + else f"anonymous-{self.browser_id}", + metadata={ + "is_external_engine": self.engine_config.is_external_engine, + "chat_engine_config": self.engine_config.screenshot(), + }, + tags=[f"chat_engine:{self.engine_name}"], + release=settings.ENVIRONMENT, + ) as trace: + trace.update( + input={ + "user_question": self.user_question, + "chat_history": self.chat_history, + } + ) + + if self.engine_config.is_external_engine: + yield from self._external_chat() + else: + response_text, source_documents = yield from self._builtin_chat() + trace.update(output=response_text) + except Exception as e: + logger.exception(e) + yield ChatEvent( + event_type=ChatEventType.ERROR_PART, + payload="Encountered an error while processing the chat. Please try again later.", + ) + + def _builtin_chat( + self, + ) -> Generator[ChatEvent | str, None, Tuple[Optional[str], List[Any]]]: + ctx = langfuse_instrumentor_context.get().copy() + db_user_message, db_assistant_message = yield from self._chat_start() + langfuse_instrumentor_context.get().update(ctx) + + # 1. Retrieve Knowledge graph related to the user question. + ( + knowledge_graph, + knowledge_graph_context, + ) = yield from self._search_knowledge_graph(user_question=self.user_question) + + # 2. Refine the user question using knowledge graph and chat history. + refined_question = yield from self._refine_user_question( + user_question=self.user_question, + chat_history=self.chat_history, + knowledge_graph_context=knowledge_graph_context, + refined_question_prompt=self.engine_config.llm.condense_question_prompt, + ) + + # 3. Check if the question provided enough context information or need to clarify. + if self.engine_config.clarify_question: + need_clarify, need_clarify_response = yield from self._clarify_question( + user_question=refined_question, + chat_history=self.chat_history, + knowledge_graph_context=knowledge_graph_context, + ) + if need_clarify: + yield from self._chat_finish( + db_assistant_message=db_assistant_message, + db_user_message=db_user_message, + response_text=need_clarify_response, + knowledge_graph=knowledge_graph, + ) + return None, [] + + # 4. Use refined question to search for relevant chunks. + relevant_chunks = yield from self._search_relevance_chunks( + user_question=refined_question + ) + + # 5. 
Generate a response using the refined question and related chunks + response_text, source_documents = yield from self._generate_answer( + user_question=refined_question, + knowledge_graph_context=knowledge_graph_context, + relevant_chunks=relevant_chunks, + ) + + yield from self._chat_finish( + db_assistant_message=db_assistant_message, + db_user_message=db_user_message, + response_text=response_text, + knowledge_graph=knowledge_graph, + source_documents=source_documents, + ) + + return response_text, source_documents + + def _chat_start( + self, + ) -> Generator[ChatEvent, None, Tuple[DBChatMessage, DBChatMessage]]: + db_user_message = chat_repo.create_message( + session=self.db_session, + chat=self.db_chat_obj, + chat_message=DBChatMessage( + role=MessageRole.USER.value, + trace_url=self._trace_manager.trace_url, + content=self.user_question, + ), + ) + db_assistant_message = chat_repo.create_message( + session=self.db_session, + chat=self.db_chat_obj, + chat_message=DBChatMessage( + role=MessageRole.ASSISTANT.value, + trace_url=self._trace_manager.trace_url, + content="", + ), + ) + yield ChatEvent( + event_type=ChatEventType.DATA_PART, + payload=ChatStreamDataPayload( + chat=self.db_chat_obj, + user_message=db_user_message, + assistant_message=db_assistant_message, + ), + ) + return db_user_message, db_assistant_message + + def _search_knowledge_graph( + self, + user_question: str, + annotation_silent: bool = False, + ) -> Generator[ChatEvent, None, Tuple[KnowledgeGraphRetrievalResult, str]]: + kg_config = self.engine_config.knowledge_graph + if kg_config is None or kg_config.enabled is False: + return KnowledgeGraphRetrievalResult(), "" + + with self._trace_manager.span( + name="search_knowledge_graph", input=user_question + ) as span: + if not annotation_silent: + if kg_config.using_intent_search: + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.KG_RETRIEVAL, + display="Identifying The Question's Intents and Perform Knowledge Graph Search", + ), + ) + else: + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.KG_RETRIEVAL, + display="Searching the Knowledge Graph for Relevant Context", + ), + ) + + knowledge_graph, knowledge_graph_context = ( + self.retrieve_flow.search_knowledge_graph(user_question) + ) + + span.end( + output={ + "knowledge_graph": knowledge_graph, + "knowledge_graph_context": knowledge_graph_context, + } + ) + + return knowledge_graph, knowledge_graph_context + + def _refine_user_question( + self, + user_question: str, + chat_history: Optional[List[ChatMessage]] = list, + refined_question_prompt: Optional[str] = None, + knowledge_graph_context: str = "", + annotation_silent: bool = False, + ) -> Generator[ChatEvent, None, str]: + with self._trace_manager.span( + name="refine_user_question", + input={ + "user_question": user_question, + "chat_history": chat_history, + "knowledge_graph_context": knowledge_graph_context, + }, + ) as span: + if not annotation_silent: + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.REFINE_QUESTION, + display="Query Rewriting for Enhanced Information Retrieval", + ), + ) + + refined_question = self._fast_llm.predict( + get_prompt_by_jinja2_template( + refined_question_prompt, + graph_knowledges=knowledge_graph_context, + chat_history=chat_history, + question=user_question, + 
current_date=datetime.now().strftime("%Y-%m-%d"), + ), + ) + + if not annotation_silent: + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.REFINE_QUESTION, + message=refined_question, + ), + ) + + span.end(output=refined_question) + + return refined_question + + def _clarify_question( + self, + user_question: str, + chat_history: Optional[List[ChatMessage]] = list, + knowledge_graph_context: str = "", + ) -> Generator[ChatEvent, None, Tuple[bool, str]]: + """ + Check if the question clear and provided enough context information, otherwise, it is necessary to + stop the conversation early and ask the user for the further clarification. + + Args: + user_question: str + knowledge_graph_context: str + + Returns: + bool: Determine whether further clarification of the issue is needed from the user. + str: The content of the questions that require clarification from the user. + """ + with self._trace_manager.span( + name="clarify_question", + input={ + "user_question": user_question, + "knowledge_graph_context": knowledge_graph_context, + }, + ) as span: + clarity_result = ( + self._fast_llm.predict( + prompt=get_prompt_by_jinja2_template( + self.engine_config.llm.clarifying_question_prompt, + graph_knowledges=knowledge_graph_context, + chat_history=chat_history, + question=user_question, + ), + ) + .strip() + .strip(".\"'!") + ) + + need_clarify = clarity_result.lower() != "false" + need_clarify_response = clarity_result if need_clarify else "" + + if need_clarify: + yield ChatEvent( + event_type=ChatEventType.TEXT_PART, + payload=need_clarify_response, + ) + + span.end( + output={ + "need_clarify": need_clarify, + "need_clarify_response": need_clarify_response, + } + ) + + return need_clarify, need_clarify_response + + def _search_relevance_chunks( + self, user_question: str + ) -> Generator[ChatEvent, None, List[NodeWithScore]]: + with self._trace_manager.span( + name="search_relevance_chunks", input=user_question + ) as span: + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.SEARCH_RELATED_DOCUMENTS, + display="Retrieving the Most Relevant Documents", + ), + ) + + relevance_chunks = self.retrieve_flow.search_relevant_chunks(user_question) + + span.end( + output={ + "relevance_chunks": relevance_chunks, + } + ) + + return relevance_chunks + + def _generate_answer( + self, + user_question: str, + knowledge_graph_context: str, + relevant_chunks: List[NodeWithScore], + ) -> Generator[ChatEvent, None, Tuple[str, List[SourceDocument]]]: + with self._trace_manager.span( + name="generate_answer", input=user_question + ) as span: + # Initialize response synthesizer. + text_qa_template = get_prompt_by_jinja2_template( + self.engine_config.llm.text_qa_prompt, + current_date=datetime.now().strftime("%Y-%m-%d"), + graph_knowledges=knowledge_graph_context, + original_question=self.user_question, + ) + response_synthesizer = get_response_synthesizer( + llm=self._llm, text_qa_template=text_qa_template, streaming=True + ) + + # Initialize response. + response = response_synthesizer.synthesize( + query=user_question, + nodes=relevant_chunks, + ) + source_documents = self.retrieve_flow.get_source_documents_from_nodes( + response.source_nodes + ) + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.SOURCE_NODES, + context=source_documents, + ), + ) + + # Generate response. 
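+            # Emit a GENERATE_ANSWER annotation first, then stream the answer to the
+            # client token by token as TEXT_PART events while accumulating the full text.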
+ yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.GENERATE_ANSWER, + display="Generating a Precise Answer with AI", + ), + ) + response_text = "" + for word in response.response_gen: + response_text += word + yield ChatEvent( + event_type=ChatEventType.TEXT_PART, + payload=word, + ) + + if not response_text: + raise Exception("Got empty response from LLM") + + span.end( + output=response_text, + metadata={ + "source_documents": source_documents, + }, + ) + + return response_text, source_documents + + def _post_verification( + self, user_question: str, response_text: str, chat_id: UUID, message_id: int + ) -> Optional[str]: + # post verification to external service, will return the post verification result url + post_verification_url = self.engine_config.post_verification_url + post_verification_token = self.engine_config.post_verification_token + + if not post_verification_url: + return None + + external_request_id = f"{chat_id}_{message_id}" + qa_content = f"User question: {user_question}\n\nAnswer:\n{response_text}" + + with self._trace_manager.span( + name="post_verification", + input={ + "external_request_id": external_request_id, + "qa_content": qa_content, + }, + ) as span: + try: + resp = requests.post( + post_verification_url, + json={ + "external_request_id": external_request_id, + "qa_content": qa_content, + }, + headers={ + "Authorization": f"Bearer {post_verification_token}", + } + if post_verification_token + else {}, + timeout=10, + ) + resp.raise_for_status() + job_id = resp.json()["job_id"] + post_verification_link = urljoin( + f"{post_verification_url}/", str(job_id) + ) + + span.end( + output={ + "post_verification_link": post_verification_link, + } + ) + + return post_verification_link + except Exception as e: + logger.exception("Failed to post verification: %s", e.message) + return None + + def _chat_finish( + self, + db_assistant_message: ChatMessage, + db_user_message: ChatMessage, + response_text: str, + knowledge_graph: KnowledgeGraphRetrievalResult = KnowledgeGraphRetrievalResult(), + source_documents: Optional[List[SourceDocument]] = list, + annotation_silent: bool = False, + ): + if not annotation_silent: + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.FINISHED, + ), + ) + + post_verification_result_url = self._post_verification( + self.user_question, + response_text, + self.db_chat_obj.id, + db_assistant_message.id, + ) + + db_assistant_message.sources = [s.model_dump() for s in source_documents] + db_assistant_message.graph_data = knowledge_graph.to_stored_graph_dict() + db_assistant_message.content = response_text + db_assistant_message.post_verification_result_url = post_verification_result_url + db_assistant_message.updated_at = datetime.now(UTC) + db_assistant_message.finished_at = datetime.now(UTC) + self.db_session.add(db_assistant_message) + + db_user_message.graph_data = knowledge_graph.to_stored_graph_dict() + db_user_message.updated_at = datetime.now(UTC) + db_user_message.finished_at = datetime.now(UTC) + self.db_session.add(db_user_message) + self.db_session.commit() + + yield ChatEvent( + event_type=ChatEventType.DATA_PART, + payload=ChatStreamDataPayload( + chat=self.db_chat_obj, + user_message=db_user_message, + assistant_message=db_assistant_message, + ), + ) + + # TODO: Separate _external_chat() method into another ExternalChatFlow class, but at the same time, we need to + # share 
some common methods through ChatMixin or BaseChatFlow. + def _external_chat(self) -> Generator[ChatEvent | str, None, None]: + db_user_message, db_assistant_message = yield from self._chat_start() + + goal, response_format = self.user_question, {} + try: + # 1. Generate the goal with the user question, knowledge graph and chat history. + goal, response_format = yield from self._generate_goal() + + # 2. Check if the goal provided enough context information or need to clarify. + if self.engine_config.clarify_question: + need_clarify, need_clarify_response = yield from self._clarify_question( + user_question=goal, chat_history=self.chat_history + ) + if need_clarify: + yield from self._chat_finish( + db_assistant_message=db_assistant_message, + db_user_message=db_user_message, + response_text=need_clarify_response, + annotation_silent=True, + ) + return + except Exception as e: + goal = self.user_question + logger.warning( + f"Failed to generate refined goal, fallback to use user question as goal directly: {e}", + exc_info=True, + extra={}, + ) + + cache_messages = None + if settings.ENABLE_QUESTION_CACHE: + try: + logger.info( + f"start to find_recent_assistant_messages_by_goal with goal: {goal}, response_format: {response_format}" + ) + cache_messages = chat_repo.find_recent_assistant_messages_by_goal( + self.db_session, + {"goal": goal, "Lang": response_format.get("Lang", "English")}, + 90, + ) + logger.info( + f"find_recent_assistant_messages_by_goal result {len(cache_messages)} for goal {goal}" + ) + except Exception as e: + logger.error(f"Failed to find recent assistant messages by goal: {e}") + + stream_chat_api_url = ( + self.engine_config.external_engine_config.stream_chat_api_url + ) + if cache_messages and len(cache_messages) > 0: + stackvm_response_text = cache_messages[0].content + task_id = cache_messages[0].meta.get("task_id") + for chunk in stackvm_response_text.split(". "): + if chunk: + if not chunk.endswith("."): + chunk += ". " + yield ChatEvent( + event_type=ChatEventType.TEXT_PART, + payload=chunk, + ) + else: + logger.debug( + f"Chatting with external chat engine (api_url: {stream_chat_api_url}) to answer for user question: {self.user_question}" + ) + chat_params = { + "goal": goal, + "response_format": response_format, + "namespace_name": "Default", + } + res = requests.post(stream_chat_api_url, json=chat_params, stream=True) + + # Notice: External type chat engine doesn't support non-streaming mode for now. + stackvm_response_text = "" + task_id = None + for line in res.iter_lines(): + if not line: + continue + + # Append to final response text. 
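+                # Lines prefixed with "0:" carry answer text; other prefixes (e.g. "8:"
+                # state annotations containing the task_id) are forwarded to the client as-is.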
+ chunk = line.decode("utf-8") + if chunk.startswith("0:"): + word = json.loads(chunk[2:]) + stackvm_response_text += word + yield ChatEvent( + event_type=ChatEventType.TEXT_PART, + payload=word, + ) + else: + yield line + b"\n" + + try: + if chunk.startswith("8:") and task_id is None: + states = json.loads(chunk[2:]) + if len(states) > 0: + # accesss task by http://endpoint/?task_id=$task_id + task_id = states[0].get("task_id") + except Exception as e: + logger.error(f"Failed to get task_id from chunk: {e}") + + response_text = stackvm_response_text + base_url = stream_chat_api_url.replace("/api/stream_execute_vm", "") + try: + post_verification_result_url = self._post_verification( + goal, + response_text, + self.db_chat_obj.id, + db_assistant_message.id, + ) + db_assistant_message.post_verification_result_url = ( + post_verification_result_url + ) + except Exception: + logger.error( + "Specific error occurred during post verification job.", exc_info=True + ) + + trace_url = f"{base_url}?task_id={task_id}" if task_id else "" + message_meta = { + "task_id": task_id, + "goal": goal, + **response_format, + } + + db_assistant_message.content = response_text + db_assistant_message.trace_url = trace_url + db_assistant_message.meta = message_meta + db_assistant_message.updated_at = datetime.now(UTC) + db_assistant_message.finished_at = datetime.now(UTC) + self.db_session.add(db_assistant_message) + + db_user_message.trace_url = trace_url + db_user_message.meta = message_meta + db_user_message.updated_at = datetime.now(UTC) + db_user_message.finished_at = datetime.now(UTC) + self.db_session.add(db_user_message) + self.db_session.commit() + + yield ChatEvent( + event_type=ChatEventType.DATA_PART, + payload=ChatStreamDataPayload( + chat=self.db_chat_obj, + user_message=db_user_message, + assistant_message=db_assistant_message, + ), + ) + + def _generate_goal(self) -> Generator[ChatEvent, None, Tuple[str, dict]]: + try: + refined_question = yield from self._refine_user_question( + user_question=self.user_question, + chat_history=self.chat_history, + refined_question_prompt=self.engine_config.llm.generate_goal_prompt, + annotation_silent=True, + ) + + goal = refined_question.strip() + if goal.startswith("Goal: "): + goal = goal[len("Goal: ") :].strip() + except Exception as e: + logger.error(f"Failed to refine question with related knowledge graph: {e}") + goal = self.user_question + + response_format = {} + try: + clean_goal, response_format = parse_goal_response_format(goal) + logger.info(f"clean goal: {clean_goal}, response_format: {response_format}") + if clean_goal: + goal = clean_goal + except Exception as e: + logger.error(f"Failed to parse goal and response format: {e}") + + return goal, response_format diff --git a/backend/app/rag/chat/chat_service.py b/backend/app/rag/chat/chat_service.py new file mode 100644 index 000000000..e2a1b045d --- /dev/null +++ b/backend/app/rag/chat/chat_service.py @@ -0,0 +1,327 @@ +from http import HTTPStatus +import logging + +from typing import Generator, List, Optional +from fastapi import HTTPException +from pydantic import BaseModel +from sqlalchemy import text, delete +from sqlmodel import Session, select, func + +from app.api.routes.models import ( + RequiredConfigStatus, + OptionalConfigStatus, + NeedMigrationStatus, +) +from app.models import ( + User, + ChatVisibility, + Chat as DBChat, + ChatMessage as DBChatMessage, + KnowledgeBase as DBKnowledgeBase, + RerankerModel as DBRerankerModel, + ChatEngine, +) +from app.models.recommend_question import 
RecommendQuestion +from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow +from app.rag.chat.stream_protocol import ChatEvent +from app.rag.retrievers.knowledge_graph.schema import ( + KnowledgeGraphRetrievalResult, + RetrievedEntity, + StoredKnowledgeGraph, + RetrievedSubGraph, +) +from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store +from app.repositories import chat_engine_repo, knowledge_base_repo + +from app.rag.chat.config import ( + ChatEngineConfig, +) +from app.rag.types import ( + ChatEventType, + ChatMessageSate, +) +from app.repositories import chat_engine_repo +from app.repositories.embedding_model import embed_model_repo +from app.repositories.llm import llm_repo +from app.site_settings import SiteSetting +from app.utils.jinja2 import get_prompt_by_jinja2_template + +logger = logging.getLogger(__name__) + + +class ChatResult(BaseModel): + chat_id: int + message_id: int + trace: str + sources: List[RetrievedEntity] + content: str + + +def get_final_chat_result( + generator: Generator[ChatEvent | str, None, None], +) -> ChatResult: + trace, sources, content = None, [], "" + chat_id, message_id = None, None + for m in generator: + if not isinstance(m, ChatEvent): + continue + if m.event_type == ChatEventType.MESSAGE_ANNOTATIONS_PART: + if m.payload.state == ChatMessageSate.SOURCE_NODES: + sources = m.payload.context + elif m.payload.state == ChatMessageSate.TRACE: + trace = m.payload.context + elif m.event_type == ChatEventType.TEXT_PART: + content += m.payload + elif m.event_type == ChatEventType.DATA_PART: + chat_id = m.payload.chat.id + message_id = m.payload.assistant_message.id + elif m.event_type == ChatEventType.ERROR_PART: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=m.payload, + ) + else: + pass + return ChatResult( + chat_id=chat_id, + message_id=message_id, + trace=trace, + sources=sources, + content=content, + ) + + +def user_can_view_chat(chat: DBChat, user: Optional[User]) -> bool: + # Anonymous or public chat can be accessed by anyone + # Non-anonymous chat can be accessed by owner or superuser + if not chat.user_id or chat.visibility == ChatVisibility.PUBLIC: + return True + return user is not None and (user.is_superuser or chat.user_id == user.id) + + +def user_can_edit_chat(chat: DBChat, user: Optional[User]) -> bool: + if user is None: + return False + if user.is_superuser: + return True + return chat.user_id == user.id + + +def get_graph_data_from_chat_message( + db_session: Session, + chat_message: DBChatMessage, + engine_config: ChatEngineConfig, +) -> Optional[KnowledgeGraphRetrievalResult]: + if not chat_message.graph_data: + return None + + graph_data = chat_message.graph_data + + # For forward compatibility. 
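+    # Legacy graph_data has no "version" field and only stores relationship ids,
+    # resolved against the first linked knowledge base of the current engine config.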
+ if "version" not in graph_data: + kb = engine_config.get_knowledge_bases(db_session)[0] + graph_store = get_kb_tidb_graph_store(db_session, kb) + return graph_store.get_subgraph_by_relationship_ids(graph_data["relationships"]) + + # Stored Knowledge Graph -> Retrieved Knowledge Graph + stored_kg = StoredKnowledgeGraph.model_validate(graph_data) + if stored_kg.knowledge_base_id is not None: + kb = knowledge_base_repo.must_get(db_session, stored_kg.knowledge_base_id) + graph_store = get_kb_tidb_graph_store(db_session, kb) + retrieved_kg = graph_store.get_subgraph_by_relationship_ids( + ids=stored_kg.relationships, query=stored_kg.query + ) + return retrieved_kg + elif stored_kg.knowledge_base_ids is not None: + kg_store_map = {} + knowledge_base_set = set() + relationship_set = set() + entity_set = set() + subgraphs = [] + + for kb_id in stored_kg.knowledge_base_ids: + kb = knowledge_base_repo.must_get(db_session, kb_id) + knowledge_base_set.add(kb.to_descriptor()) + kg_store = get_kb_tidb_graph_store(db_session, kb) + kg_store_map[kb_id] = kg_store + + for stored_subgraph in stored_kg.subgraphs: + kg_store = kg_store_map.get(stored_subgraph.knowledge_base_id) + if kg_store is None: + continue + relationship_ids = stored_subgraph.relationships + subgraph = kg_store.get_subgraph_by_relationship_ids( + ids=relationship_ids, + query=stored_kg.query, + ) + relationship_set.update(subgraph.relationships) + entity_set.update(subgraph.entities) + subgraphs.append( + RetrievedSubGraph( + **subgraph.model_dump(), + ) + ) + + return KnowledgeGraphRetrievalResult( + query=stored_kg.query, + knowledge_bases=list(knowledge_base_set), + relationships=list(relationship_set), + entities=list(entity_set), + subgraphs=subgraphs, + ) + else: + return None + + +def get_chat_message_subgraph( + db_session: Session, chat_message: DBChatMessage +) -> KnowledgeGraphRetrievalResult: + chat_engine: ChatEngine = chat_message.chat.engine + engine_name = chat_engine.name + engine_config = ChatEngineConfig.load_from_db(db_session, chat_engine.name) + + # Try to get subgraph from `chat_message.graph_data`. + try: + knowledge_graph = get_graph_data_from_chat_message( + db_session, chat_message, engine_config + ) + if knowledge_graph is not None: + return knowledge_graph + except Exception as e: + logger.error( + f"Failed to get subgraph from chat_message.graph_data: {e}", exc_info=True + ) + + # Try to get subgraph based on the chat message content. + # Notice: it use current chat engine config, not the snapshot stored in chat_message. + retriever = RetrieveFlow( + db_session=db_session, + engine_name=engine_name, + engine_config=engine_config, + ) + knowledge_graph, _ = retriever.search_knowledge_graph(chat_message.content) + return knowledge_graph + + +def check_rag_required_config(session: Session) -> RequiredConfigStatus: + """ + Check if the required configuration items have been configured, it any of them is + missing, the RAG application can not complete its work. 
+    """
+    has_default_llm = llm_repo.has_default(session)
+    has_default_embedding_model = embed_model_repo.has_default(session)
+    has_default_chat_engine = chat_engine_repo.has_default(session)
+    has_knowledge_base = session.scalar(select(func.count(DBKnowledgeBase.id))) > 0
+
+    return RequiredConfigStatus(
+        default_llm=has_default_llm,
+        default_embedding_model=has_default_embedding_model,
+        default_chat_engine=has_default_chat_engine,
+        knowledge_base=has_knowledge_base,
+    )
+
+
+def check_rag_optional_config(session: Session) -> OptionalConfigStatus:
+    langfuse = bool(
+        SiteSetting.langfuse_host
+        and SiteSetting.langfuse_secret_key
+        and SiteSetting.langfuse_public_key
+    )
+    default_reranker = session.scalar(select(func.count(DBRerankerModel.id))) > 0
+    return OptionalConfigStatus(
+        langfuse=langfuse,
+        default_reranker=default_reranker,
+    )
+
+
+def check_rag_config_need_migration(session: Session) -> NeedMigrationStatus:
+    """
+    Check if any configuration needs to be migrated.
+    """
+    chat_engines_without_kb_configured = session.exec(
+        select(ChatEngine.id)
+        .where(ChatEngine.deleted_at == None)
+        .where(
+            text(
+                "JSON_EXTRACT(engine_options, '$.knowledge_base.linked_knowledge_base') IS NULL"
+            )
+        )
+    )
+
+    return NeedMigrationStatus(
+        chat_engines_without_kb_configured=chat_engines_without_kb_configured,
+    )
+
+
+def remove_chat_message_recommend_questions(
+    db_session: Session,
+    chat_message_id: int,
+) -> None:
+    delete_stmt = delete(RecommendQuestion).where(
+        RecommendQuestion.chat_message_id == chat_message_id
+    )
+    db_session.exec(delete_stmt)
+    db_session.commit()
+
+
+def get_chat_message_recommend_questions(
+    db_session: Session,
+    chat_message: DBChatMessage,
+    engine_name: str = "default",
+) -> List[str]:
+    chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
+    llm = chat_engine_config.get_llama_llm(db_session)
+
+    statement = (
+        select(RecommendQuestion.questions)
+        .where(RecommendQuestion.chat_message_id == chat_message.id)
+        .with_for_update()  # use a write lock in case the same chat message triggers multiple requests
+    )
+
+    questions = db_session.exec(statement).first()
+    if questions is not None:
+        return questions
+
+    recommend_questions = llm.predict(
+        prompt=get_prompt_by_jinja2_template(
+            chat_engine_config.llm.further_questions_prompt,
+            chat_message_content=chat_message.content,
+        ),
+    )
+    recommend_question_list = recommend_questions.splitlines()
+    recommend_question_list = [
+        question.strip() for question in recommend_question_list if question.strip()
+    ]
+
+    longest_question = 0
+    for question in recommend_question_list:
+        longest_question = max(longest_question, len(question))
+
+    # Sanity-check the output format and length: markdown markers or an overly long
+    # question indicate that the output is not a plain question list.
+    if (
+        "##" in recommend_questions
+        or "**" in recommend_questions
+        or longest_question > 500
+    ):
+        regenerate_content = f"""
+        Please note that you are generating a question list. You previously generated it incorrectly; try again.
+ ---------------------------------------- + {chat_message.content} + """ + # with format or too long for per question, it's not a question list, generate again + recommend_questions = llm.predict( + prompt=get_prompt_by_jinja2_template( + chat_engine_config.llm.further_questions_prompt, + chat_message_content=regenerate_content, + ), + ) + + db_session.add( + RecommendQuestion( + chat_message_id=chat_message.id, + questions=recommend_question_list, + ) + ) + db_session.commit() + + return recommend_question_list diff --git a/backend/app/rag/chat_config.py b/backend/app/rag/chat/config.py similarity index 79% rename from backend/app/rag/chat_config.py rename to backend/app/rag/chat/config.py index ebf2f8a81..3d7ab451c 100644 --- a/backend/app/rag/chat_config.py +++ b/backend/app/rag/chat/config.py @@ -1,17 +1,19 @@ import logging import dspy -from typing import Optional, List, Mapping, Any +from typing import Optional, List from pydantic import BaseModel, Field from sqlmodel import Session from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.llms.llm import LLM +from app.rag.postprocessors.metadata_post_filter import MetadataPostFilter +from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig +from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrieverConfig from app.utils.dspy import get_dspy_lm_by_llama_llm from app.rag.llms.resolver import get_default_llm, resolve_llm from app.rag.rerankers.resolver import get_default_reranker_model, resolve_reranker -from app.rag.postprocessors import get_metadata_post_filter, MetadataFilters from app.models import ( LLM as DBLLM, @@ -47,19 +49,13 @@ class LLMOption(BaseModel): generate_goal_prompt: str = DEFAULT_GENERATE_GOAL_PROMPT analyze_question_type_prompt: str = DEFAULT_ANALYZE_QUESTION_TYPE_PROMPT -class VectorSearchOption(BaseModel): - metadata_post_filters: Optional[MetadataFilters] = None +class VectorSearchOption(VectorSearchRetrieverConfig): + pass -class KnowledgeGraphOption(BaseModel): +class KnowledgeGraphOption(KnowledgeGraphRetrieverConfig): enabled: bool = True - depth: int = 2 - include_meta: bool = True - with_degree: bool = False using_intent_search: bool = True - enable_metadata_filter: bool = False - metadata_filters: Optional[Mapping[str, Any]] = None - relationship_meta_filters: Optional[dict] = None # To be deprecated. class ExternalChatEngine(BaseModel): @@ -72,31 +68,40 @@ class LinkedKnowledgeBase(BaseModel): class KnowledgeBaseOption(BaseModel): - linked_knowledge_base: LinkedKnowledgeBase + linked_knowledge_base: LinkedKnowledgeBase = None linked_knowledge_bases: Optional[List[LinkedKnowledgeBase]] = Field( default_factory=list ) class ChatEngineConfig(BaseModel): + external_engine_config: Optional[ExternalChatEngine] = None + llm: LLMOption = LLMOption() - # Notice: Currently knowledge base option is optional, if it is not configured, it will use - # the deprecated chunks / relationships / entities table as the data source. 
- knowledge_base: Optional[KnowledgeBaseOption] = None + + knowledge_base: KnowledgeBaseOption = KnowledgeBaseOption() knowledge_graph: KnowledgeGraphOption = KnowledgeGraphOption() vector_search: VectorSearchOption = VectorSearchOption() + refine_question_with_kg: bool = True + clarify_question: bool = False + post_verification_url: Optional[str] = None post_verification_token: Optional[str] = None - external_engine_config: Optional[ExternalChatEngine] = None hide_sources: bool = False - clarify_question: bool = False _db_chat_engine: Optional[DBChatEngine] = None _db_llm: Optional[DBLLM] = None _db_fast_llm: Optional[DBLLM] = None _db_reranker: Optional[DBRerankerModel] = None + @property + def is_external_engine(self) -> bool: + return ( + self.external_engine_config is not None + and self.external_engine_config.stream_chat_api_url + ) + def get_db_chat_engine(self) -> Optional[DBChatEngine]: return self._db_chat_engine @@ -172,7 +177,23 @@ def get_reranker( ) def get_metadata_filter(self) -> BaseNodePostprocessor: - return get_metadata_post_filter(self.vector_search.metadata_post_filters) + return MetadataPostFilter(self.vector_search.metadata_filters) + + def get_knowledge_bases(self, db_session: Session) -> List[KnowledgeBase]: + if not self.knowledge_base: + return [] + kb_config: KnowledgeBaseOption = self.knowledge_base + linked_knowledge_base_ids = [] + if len(kb_config.linked_knowledge_bases) == 0: + linked_knowledge_base_ids.append(kb_config.linked_knowledge_base.id) + else: + linked_knowledge_base_ids.extend( + [kb.id for kb in kb_config.linked_knowledge_bases] + ) + knowledge_bases = knowledge_base_repo.get_by_ids( + db_session, knowledge_base_ids=linked_knowledge_base_ids + ) + return knowledge_bases def screenshot(self) -> dict: return self.model_dump( diff --git a/backend/app/rag/chat/retrieve/retrieve_flow.py b/backend/app/rag/chat/retrieve/retrieve_flow.py new file mode 100644 index 000000000..1e9b929c9 --- /dev/null +++ b/backend/app/rag/chat/retrieve/retrieve_flow.py @@ -0,0 +1,162 @@ +import logging +from datetime import datetime +from typing import List, Optional, Tuple + +from llama_index.core.instrumentation import get_dispatcher +from llama_index.core.llms import LLM +from llama_index.core.schema import NodeWithScore, QueryBundle +from pydantic import BaseModel +from sqlmodel import Session + +from app.models import ( + Document as DBDocument, + KnowledgeBase, +) +from app.utils.jinja2 import get_prompt_by_jinja2_template +from app.rag.chat.config import ChatEngineConfig +from app.rag.retrievers.knowledge_graph.fusion_retriever import ( + KnowledgeGraphFusionRetriever, +) +from app.rag.retrievers.knowledge_graph.schema import ( + KnowledgeGraphRetrievalResult, + KnowledgeGraphRetrieverConfig, +) +from app.rag.retrievers.chunk.fusion_retriever import ChunkFusionRetriever +from app.repositories import document_repo + +dispatcher = get_dispatcher(__name__) +logger = logging.getLogger(__name__) + + +class SourceDocument(BaseModel): + id: int + name: str + source_uri: Optional[str] = None + + +class RetrieveFlow: + def __init__( + self, + db_session: Session, + engine_name: str = "default", + engine_config: Optional[ChatEngineConfig] = None, + llm: Optional[LLM] = None, + fast_llm: Optional[LLM] = None, + knowledge_bases: Optional[List[KnowledgeBase]] = None, + ): + self.db_session = db_session + self.engine_name = engine_name + self.engine_config = engine_config or ChatEngineConfig.load_from_db( + db_session, engine_name + ) + self.db_chat_engine = 
self.engine_config.get_db_chat_engine() + + # Init LLM. + self._llm = llm or self.engine_config.get_llama_llm(self.db_session) + self._fast_llm = fast_llm or self.engine_config.get_fast_llama_llm( + self.db_session + ) + + # Load knowledge bases. + self.knowledge_bases = ( + knowledge_bases or self.engine_config.get_knowledge_bases(self.db_session) + ) + self.knowledge_base_ids = [kb.id for kb in self.knowledge_bases] + + def retrieve(self, user_question: str) -> List[NodeWithScore]: + if self.engine_config.refine_question_with_kg: + # 1. Retrieve Knowledge graph related to the user question. + _, knowledge_graph_context = self.search_knowledge_graph(user_question) + + # 2. Refine the user question using knowledge graph and chat history. + self._refine_user_question(user_question, knowledge_graph_context) + + # 3. Search relevant chunks based on the user question. + return self.search_relevant_chunks(user_question=user_question) + + def retrieve_documents(self, user_question: str) -> List[DBDocument]: + nodes = self.retrieve(user_question) + return self.get_documents_from_nodes(nodes) + + def search_knowledge_graph( + self, user_question: str + ) -> Tuple[KnowledgeGraphRetrievalResult, str]: + kg_config = self.engine_config.knowledge_graph + knowledge_graph = KnowledgeGraphRetrievalResult() + knowledge_graph_context = "" + if kg_config is not None and kg_config.enabled: + kg_retriever = KnowledgeGraphFusionRetriever( + db_session=self.db_session, + knowledge_base_ids=[kb.id for kb in self.knowledge_bases], + llm=self._llm, + use_query_decompose=kg_config.using_intent_search, + use_async=True, + config=KnowledgeGraphRetrieverConfig.model_validate( + kg_config.model_dump(exclude={"enabled", "using_intent_search"}) + ), + ) + knowledge_graph = kg_retriever.retrieve_knowledge_graph(user_question) + knowledge_graph_context = self._get_knowledge_graph_context(knowledge_graph) + return knowledge_graph, knowledge_graph_context + + def _get_knowledge_graph_context( + self, knowledge_graph: KnowledgeGraphRetrievalResult + ) -> str: + if self.engine_config.knowledge_graph.using_intent_search: + kg_context_template = get_prompt_by_jinja2_template( + self.engine_config.llm.intent_graph_knowledge, + # For forward compatibility considerations. 
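+                # to_subqueries_dict() returns a mapping shaped like
+                # {"<sub query>": {"entities": [...], "relationships": [...]}},
+                # which is the structure the intent_graph_knowledge template expects.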
+ sub_queries=knowledge_graph.to_subqueries_dict(), + ) + return kg_context_template.template + else: + kg_context_template = get_prompt_by_jinja2_template( + self.engine_config.llm.normal_graph_knowledge, + entities=knowledge_graph.entities, + relationships=knowledge_graph.relationships, + ) + return kg_context_template.template + + def _refine_user_question( + self, user_question: str, knowledge_graph_context: str + ) -> str: + return self._fast_llm.predict( + get_prompt_by_jinja2_template( + self.engine_config.llm.condense_question_prompt, + graph_knowledges=knowledge_graph_context, + question=user_question, + current_date=datetime.now().strftime("%Y-%m-%d"), + ), + ) + + def search_relevant_chunks(self, user_question: str) -> List[NodeWithScore]: + retriever = ChunkFusionRetriever( + db_session=self.db_session, + knowledge_base_ids=self.knowledge_base_ids, + llm=self._llm, + config=self.engine_config.vector_search, + use_query_decompose=False, + use_async=True, + ) + return retriever.retrieve(QueryBundle(user_question)) + + def get_documents_from_nodes(self, nodes: List[NodeWithScore]) -> List[DBDocument]: + document_ids = [n.node.metadata["document_id"] for n in nodes] + documents = document_repo.list_full_documents_by_ids( + self.db_session, document_ids + ) + # Keep the original order of document ids, which is sorted by similarity. + return sorted(documents, key=lambda x: document_ids.index(x.id)) + + def get_source_documents_from_nodes( + self, nodes: List[NodeWithScore] + ) -> List[SourceDocument]: + documents = self.get_documents_from_nodes(nodes) + return [ + SourceDocument( + id=doc.id, + name=doc.name, + source_uri=doc.source_uri, + ) + for doc in documents + ] diff --git a/backend/app/rag/chat_stream_protocol.py b/backend/app/rag/chat/stream_protocol.py similarity index 78% rename from backend/app/rag/chat_stream_protocol.py rename to backend/app/rag/chat/stream_protocol.py index 47431718e..0ac982baf 100644 --- a/backend/app/rag/chat_stream_protocol.py +++ b/backend/app/rag/chat/stream_protocol.py @@ -1,6 +1,8 @@ import json from dataclasses import dataclass +from pydantic import BaseModel + from app.models import ChatMessage, Chat from app.rag.types import ChatEventType, ChatMessageSate @@ -30,15 +32,22 @@ def dump(self): class ChatStreamMessagePayload(ChatStreamPayload): state: ChatMessageSate = ChatMessageSate.TRACE display: str = "" - context: dict | list | str = "" + context: dict | list | str | BaseModel | None = None message: str = "" def dump(self): + if isinstance(self.context, list): + context = [c.model_dump() for c in self.context] + elif isinstance(self.context, BaseModel): + context = self.context.model_dump() + else: + context = self.context + return [ { "state": self.state.name, "display": self.display, - "context": self.context, + "context": context, "message": self.message, } ] diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py b/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py index 826afb69d..9fcdea577 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py @@ -1,10 +1,9 @@ from .tidb_graph_store import TiDBGraphStore -from .tidb_graph_editor import TiDBGraphEditor, legacy_tidb_graph_editor +from .tidb_graph_editor import TiDBGraphEditor from .tidb_graph_store import KnowledgeGraphStore __all__ = [ "TiDBGraphStore", "TiDBGraphEditor", - "legacy_tidb_graph_editor", "KnowledgeGraphStore", ] diff --git 
a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py index 45a343df1..784971941 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py @@ -6,7 +6,7 @@ from sqlmodel import Session, select, SQLModel from sqlalchemy.orm import joinedload from sqlalchemy.orm.attributes import flag_modified -from app.models import Relationship, Entity, EntityType +from app.models import EntityType from app.rag.indices.knowledge_graph.schema import Relationship as RelationshipAIModel from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore from app.rag.indices.knowledge_graph.graph_store.helpers import ( @@ -26,11 +26,13 @@ class TiDBGraphEditor: def __init__( self, + knowledge_base_id: int, entity_db_model: Type[SQLModel], relationship_db_model: Type[SQLModel], embed_model: Optional[EmbedType] = None, graph_type: GraphType = GraphType.general, ): + self.knowledge_base_id = knowledge_base_id self._entity_db_model = entity_db_model self._relationship_db_model = relationship_db_model self._graph_type = graph_type @@ -120,32 +122,6 @@ def get_relationship( ) -> Optional[SQLModel]: return session.get(self._relationship_db_model, relationship_id) - def get_relationship_by_ids( - self, session: Session, ids: list[int] - ) -> Tuple[List[SQLModel], List[SQLModel]]: - stmt = ( - select(self._relationship_db_model) - .where(self._relationship_db_model.id.in_(ids)) - .options( - joinedload(self._relationship_db_model.source_entity), - joinedload(self._relationship_db_model.target_entity), - ) - ) - relationships_queryset = session.exec(stmt) - - relationships = [] - entities = [] - entities_set = set() - for relationship in relationships_queryset: - entities_set.add(relationship.source_entity) - entities_set.add(relationship.target_entity) - relationships.append(relationship) - - for entity in entities_set: - entities.append(entity) - - return entities, relationships - def update_relationship( self, session: Session, relationship: SQLModel, new_relationship: dict ) -> SQLModel: @@ -213,6 +189,7 @@ def create_synopsis_entity( ) session.add(synopsis_entity) graph_store = TiDBGraphStore( + knowledge_base=self.knowledge_base_id, dspy_lm=None, session=session, embed_model=self._embed_model, @@ -247,6 +224,3 @@ def create_synopsis_entity( commit=False, ) return synopsis_entity - - -legacy_tidb_graph_editor = TiDBGraphEditor(Entity, Relationship) diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py index e9eb644a2..b931d9b04 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py @@ -35,11 +35,13 @@ from app.rag.retrievers.knowledge_graph.schema import ( RetrievedEntity, RetrievedRelationship, + RetrievedKnowledgeGraph, ) from app.models import ( Entity as DBEntity, Relationship as DBRelationship, Chunk as DBChunk, + KnowledgeBase, ) from app.models import EntityType from app.models.enums import GraphType @@ -84,6 +86,7 @@ def forward(self, entities: List[Entity]): class TiDBGraphStore(KnowledgeGraphStore): def __init__( self, + knowledge_base: KnowledgeBase, dspy_lm: dspy.LM, session: Optional[Session] = None, embed_model: Optional[EmbedType] = None, @@ -93,6 +96,7 @@ def __init__( chunk_db_model: 
Type[SQLModel] = DBChunk, graph_type: GraphType = GraphType.general, ): + self.knowledge_base = knowledge_base self._session = session self._owns_session = session is None if self._session is None: @@ -316,6 +320,58 @@ def create_relationship( else: self._session.flush() + def get_subgraph_by_relationship_ids( + self, ids: list[int], **kwargs + ) -> RetrievedKnowledgeGraph: + stmt = ( + select(self._relationship_model) + .where(self._relationship_model.id.in_(ids)) + .options( + joinedload(self._relationship_model.source_entity), + joinedload(self._relationship_model.target_entity), + ) + ) + relationships_set = self._session.exec(stmt) + entities_set = set() + relationships = [] + entities = [] + + for rel in relationships_set: + entities_set.add(rel.source_entity) + entities_set.add(rel.target_entity) + relationships.append( + RetrievedRelationship( + id=rel.id, + knowledge_base_id=self.knowledge_base.id, + source_entity_id=rel.source_entity_id, + target_entity_id=rel.target_entity_id, + description=rel.description, + rag_description=f"{rel.source_entity.name} -> {rel.description} -> {rel.target_entity.name}", + meta=rel.meta, + weight=rel.weight, + last_modified_at=rel.last_modified_at, + ) + ) + + for entity in entities_set: + entities.append( + RetrievedEntity( + id=entity.id, + knowledge_base_id=self.knowledge_base.id, + name=entity.name, + description=entity.description, + meta=entity.meta, + entity_type=entity.entity_type, + ) + ) + + return RetrievedKnowledgeGraph( + knowledge_base=self.knowledge_base.to_descriptor(), + entities=entities, + relationships=relationships, + **kwargs, + ) + def get_or_create_entity(self, entity: Entity, commit: bool = True) -> SQLModel: # using the cosine distance between the description vectors to determine if the entity already exists entity_type = ( @@ -516,6 +572,7 @@ def retrieve_with_weight( entities = [ RetrievedEntity( id=e.id, + knowledge_base_id=self.knowledge_base.id, name=e.name, description=e.description, meta=e.meta if include_meta else None, @@ -526,10 +583,11 @@ def retrieve_with_weight( relationships = [ RetrievedRelationship( id=r.id, + knowledge_base_id=self.knowledge_base.id, source_entity_id=r.source_entity_id, target_entity_id=r.target_entity_id, - description=r.description, rag_description=f"{r.source_entity.name} -> {r.description} -> {r.target_entity.name}", + description=r.description, meta=r.meta, weight=r.weight, last_modified_at=r.last_modified_at, diff --git a/backend/app/rag/indices/vector_search/vector_store/tidb_vector_store.py b/backend/app/rag/indices/vector_search/vector_store/tidb_vector_store.py index 9c2b4bd9a..68612d97b 100644 --- a/backend/app/rag/indices/vector_search/vector_store/tidb_vector_store.py +++ b/backend/app/rag/indices/vector_search/vector_store/tidb_vector_store.py @@ -31,15 +31,15 @@ def node_to_relation_dict(node: BaseNode) -> dict: - relations = {} + relationships = {} for r_type, node_info in node.relationships.items(): - relations[r_type.name] = { + relationships[r_type.name] = { "node_id": node_info.node_id, "node_type": node_info.node_type.name, "meta": node_info.metadata, "hash": node_info.hash, } - return relations + return relationships class TiDBVectorStore(BasePydanticVectorStore): diff --git a/backend/app/rag/knowledge_base/index_store.py b/backend/app/rag/knowledge_base/index_store.py index 6f9697908..68fa0e152 100644 --- a/backend/app/rag/knowledge_base/index_store.py +++ b/backend/app/rag/knowledge_base/index_store.py @@ -31,6 +31,7 @@ def get_kb_tidb_graph_store(session: Session, 
kb: KnowledgeBase, graph_type: Gra chunk_model = get_kb_chunk_model(kb) graph_store = TiDBGraphStore( + knowledge_base=kb, dspy_lm=dspy_lm, session=session, embed_model=embed_model, @@ -52,4 +53,9 @@ def get_kb_tidb_graph_editor(session: Session, kb: KnowledgeBase) -> TiDBGraphEd entity_db_model = get_kb_entity_model(kb) relationship_db_model = get_kb_relationship_model(kb) embed_model = get_kb_embed_model(session, kb) - return TiDBGraphEditor(entity_db_model, relationship_db_model, embed_model) + return TiDBGraphEditor( + knowledge_base_id=kb.id, + entity_db_model=entity_db_model, + relationship_db_model=relationship_db_model, + embed_model=embed_model, + ) diff --git a/backend/app/rag/postprocessors/__init__.py b/backend/app/rag/postprocessors/__init__.py index c23d2f333..c0eed6a84 100644 --- a/backend/app/rag/postprocessors/__init__.py +++ b/backend/app/rag/postprocessors/__init__.py @@ -1,8 +1,6 @@ from .metadata_post_filter import MetadataPostFilter, MetadataFilters -from .resolver import get_metadata_post_filter __all__ = [ "MetadataPostFilter", "MetadataFilters", - "get_metadata_post_filter", ] diff --git a/backend/app/rag/postprocessors/metadata_post_filter.py b/backend/app/rag/postprocessors/metadata_post_filter.py index f571c4262..432eba9d4 100644 --- a/backend/app/rag/postprocessors/metadata_post_filter.py +++ b/backend/app/rag/postprocessors/metadata_post_filter.py @@ -1,76 +1,49 @@ import logging -from enum import Enum -from typing import List, Optional, Any, Union +from typing import Dict, List, Optional, Any, Union from llama_index.core import QueryBundle from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.core.schema import NodeWithScore -from pydantic import BaseModel - - -class FilterOperator(str, Enum): - """Vector store filter operator.""" - - # TODO add more operators - EQ = "==" # default operator (string, int, float) - GT = ">" # greater than (int, float) - LT = "<" # less than (int, float) - NE = "!=" # not equal to (string, int, float) - GTE = ">=" # greater than or equal to (int, float) - LTE = "<=" # less than or equal to (int, float) - IN = "in" # In array (string or number) - NIN = "nin" # Not in array (string or number) - ANY = "any" # Contains any (array of strings) - ALL = "all" # Contains all (array of strings) - TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase - # within the text field) - CONTAINS = "contains" # metadata array contains value (string or number) - - -class FilterCondition(str, Enum): - AND = "and" - OR = "or" - - -class MetadataFilter(BaseModel): - key: str - value: Union[ - int, - float, - str, - List[int], - List[float], - List[str], - ] - operator: FilterOperator = FilterOperator.EQ - - -# Notice: -# -# llama index is still heavily using pydantic v1 to define data models. Using classes in llama index to define FastAPI -# parameters may cause the following errors: -# -# TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given -# -# See: https://github.com/run-llama/llama_index/issues/14807#issuecomment-2241285940 -class MetadataFilters(BaseModel): - """Metadata filters for vector stores.""" - - # Exact match filters and Advanced filters with operators like >, <, >=, <=, !=, etc. 
- filters: List[Union[MetadataFilter, "MetadataFilters"]] - # and/or such conditions for combining different filters - condition: Optional[FilterCondition] = FilterCondition.AND - - -_logger = logging.getLogger(__name__) +from llama_index.core.schema import BaseNode, NodeWithScore +from llama_index.core.vector_stores.types import ( + MetadataFilter, + MetadataFilters, + FilterOperator, + FilterCondition, +) + + +SimpleMetadataFilter = Dict[str, Any] + + +def simple_filter_to_metadata_filters(filters: SimpleMetadataFilter) -> MetadataFilters: + simple_filters = [] + for key, value in filters.items(): + simple_filters.append( + MetadataFilter( + key=key, + value=value, + operator=FilterOperator.EQ, + ) + ) + return MetadataFilters(filters=simple_filters) + + +logger = logging.getLogger(__name__) class MetadataPostFilter(BaseNodePostprocessor): filters: Optional[MetadataFilters] = None - def __init__(self, filters: Optional[MetadataFilters] = None, **kwargs: Any): + def __init__( + self, + filters: Optional[Union[MetadataFilters, SimpleMetadataFilter]] = None, + **kwargs: Any, + ): super().__init__(**kwargs) - self.filters = filters + if isinstance(filters, MetadataFilters): + self.filters = filters + else: + self.filters = simple_filter_to_metadata_filters(filters) def _postprocess_nodes( self, @@ -87,29 +60,29 @@ def _postprocess_nodes( filtered_nodes.append(node) return filtered_nodes - def match_all_filters(self, node: Any) -> bool: + def match_all_filters(self, node: BaseNode) -> bool: if self.filters is None or not isinstance(self.filters, MetadataFilters): return True if self.filters.condition != FilterCondition.AND: - _logger.warning( + logger.warning( f"Advanced filtering is not supported yet. " f"Filter condition {self.filters.condition} is ignored." ) return True for f in self.filters.filters: - if f.key not in node.extra_info: + if f.key not in node.metadata: return False if f.operator is not None and f.operator != FilterOperator.EQ: - _logger.warning( + logger.warning( f"Advanced filtering is not supported yet. " f"Filter operator {f.operator} is ignored." 
) return True - value = node.extra_info[f.key] + value = node.metadata[f.key] if f.value != value: return False diff --git a/backend/app/rag/postprocessors/resolver.py b/backend/app/rag/postprocessors/resolver.py deleted file mode 100644 index f8596ee89..000000000 --- a/backend/app/rag/postprocessors/resolver.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Mapping, Any -from llama_index.core.postprocessor.types import BaseNodePostprocessor -from .metadata_post_filter import ( - MetadataFilters, - MetadataPostFilter, - MetadataFilter, -) - - -def get_metadata_post_filter( - filters: Mapping[str, Any] = None, -) -> BaseNodePostprocessor: - simple_filters = [] - for key, value in filters.items(): - simple_filters.append( - MetadataFilter( - key=key, - value=value, - ) - ) - return MetadataPostFilter( - MetadataFilters( - filters=simple_filters, - ) - ) diff --git a/backend/app/rag/retrieve.py b/backend/app/rag/retrieve.py deleted file mode 100644 index 7178786c3..000000000 --- a/backend/app/rag/retrieve.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -from typing import List, Optional -from llama_index.core.schema import NodeWithScore -from sqlmodel import Session - -from app.models import ( - Document as DBDocument, -) -from app.rag.retrievers.chat_engine_based import ( - ChatEngineBasedRetriever, -) -from app.repositories.chunk import ChunkRepo -from app.repositories.knowledge_base import knowledge_base_repo -from app.models.chunk import get_kb_chunk_model -from app.rag.chat_config import ChatEngineConfig - - -logger = logging.getLogger(__name__) - - -class ChatEngineBasedRetrieveService: - def chat_engine_retrieve_documents( - self, - db_session: Session, - question: str, - top_k: int = 5, - chat_engine_name: str = "default", - similarity_top_k: Optional[int] = None, - oversampling_factor: Optional[int] = None, - enable_kg_enhance_query_refine: bool = False, - ) -> List[DBDocument]: - chat_engine_config = ChatEngineConfig.load_from_db(db_session, chat_engine_name) - if not chat_engine_config.knowledge_base: - raise Exception("Chat engine does not configured with knowledge base") - - nodes = self.chat_engine_retrieve_chunks( - db_session, - question=question, - top_k=top_k, - chat_engine_name=chat_engine_name, - similarity_top_k=similarity_top_k, - oversampling_factor=oversampling_factor, - enable_kg_enhance_query_refine=enable_kg_enhance_query_refine, - ) - - linked_knowledge_base = chat_engine_config.knowledge_base.linked_knowledge_base - kb = knowledge_base_repo.must_get(db_session, linked_knowledge_base.id) - chunk_model = get_kb_chunk_model(kb) - chunk_repo = ChunkRepo(chunk_model) - chunk_ids = [node.node.node_id for node in nodes] - - return chunk_repo.get_documents_by_chunk_ids(db_session, chunk_ids) - - def chat_engine_retrieve_chunks( - self, - db_session: Session, - question: str, - top_k: int = 5, - chat_engine_name: str = "default", - similarity_top_k: Optional[int] = None, - oversampling_factor: Optional[int] = None, - enable_kg_enhance_query_refine: bool = False, - ) -> List[NodeWithScore]: - retriever = ChatEngineBasedRetriever( - db_session=db_session, - engine_name=chat_engine_name, - top_k=top_k, - similarity_top_k=similarity_top_k, - oversampling_factor=oversampling_factor, - enable_kg_enhance_query_refine=enable_kg_enhance_query_refine, - ) - return retriever.retrieve(question) - - -retrieve_service = ChatEngineBasedRetrieveService() diff --git a/backend/app/rag/retrievers/chat_engine_based.py b/backend/app/rag/retrievers/chat_engine_based.py deleted file mode 
100644 index 28d1a0d15..000000000 --- a/backend/app/rag/retrievers/chat_engine_based.py +++ /dev/null @@ -1,182 +0,0 @@ -import logging -from datetime import datetime -from typing import List, Tuple -from llama_index.core.retrievers import BaseRetriever -from llama_index.core.schema import NodeWithScore, QueryBundle -from sqlmodel import Session - -from app.rag.chat import get_prompt_by_jinja2_template -from app.rag.chat_config import ( - ChatEngineConfig, - KnowledgeGraphOption, - KnowledgeBaseOption, -) -from app.rag.retrievers.knowledge_graph.fusion_retriever import ( - KnowledgeGraphFusionRetriever, -) -from app.rag.retrievers.knowledge_graph.schema import ( - KnowledgeGraphRetrievalResult, - KnowledgeGraphRetrieverConfig, -) -from app.rag.retrievers.chunk.fusion_retriever import ChunkFusionRetriever -from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig -from app.repositories.knowledge_base import knowledge_base_repo - - -logger = logging.getLogger(__name__) - - -class ChatEngineBasedRetriever(BaseRetriever): - """ - Chat engine based retriever, which is dependent on the configuration of the chat engine. - """ - - def __init__( - self, - db_session: Session, - engine_name: str = "default", - chat_engine_config: ChatEngineConfig = None, - top_k: int = 10, - similarity_top_k: int = None, - oversampling_factor: int = 5, - enable_kg_enhance_query_refine: bool = False, - ): - self.db_session = db_session - self.engine_name = engine_name - self.top_k = top_k - self.similarity_top_k = similarity_top_k - self.oversampling_factor = oversampling_factor - self.enable_kg_enhance_query_refine = enable_kg_enhance_query_refine - - self.chat_engine_config = chat_engine_config or ChatEngineConfig.load_from_db( - db_session, engine_name - ) - self.db_chat_engine = self.chat_engine_config.get_db_chat_engine() - - # Init LLM. - self._llm = self.chat_engine_config.get_llama_llm(self.db_session) - self._fast_llm = self.chat_engine_config.get_fast_llama_llm(self.db_session) - self._fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm(self.db_session) - - # Load knowledge bases. - kb_config: KnowledgeBaseOption = self.chat_engine_config.knowledge_base - linked_knowledge_base_ids = [] - if len(kb_config.linked_knowledge_bases) == 0: - linked_knowledge_base_ids.append( - self.chat_engine_config.knowledge_base.linked_knowledge_base.id - ) - else: - linked_knowledge_base_ids.extend( - [kb.id for kb in kb_config.linked_knowledge_bases] - ) - self.knowledge_base_ids = linked_knowledge_base_ids - self.knowledge_bases = knowledge_base_repo.get_by_ids( - self.db_session, knowledge_base_ids=linked_knowledge_base_ids - ) - - def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - if self.enable_kg_enhance_query_refine: - refined_question = self._kg_enhance_query_refine(query_bundle.query_str) - else: - refined_question = query_bundle.query_str - - return self._search_relevant_chunks(refined_question=refined_question) - - def _kg_enhance_query_refine(self, query_str): - # 1. Retrieve Knowledge graph related to the user question. - kg_config = self.chat_engine_config.knowledge_graph - knowledge_graph_context = "" - if kg_config is not None and kg_config.enabled: - _, knowledge_graph_context = self._search_knowledge_graph( - user_question=query_str, kg_config=kg_config - ) - - # 2. Refine the user question using knowledge graph and chat history. 
- refined_question = self._refine_user_question( - user_question=query_str, - knowledge_graph_context=knowledge_graph_context, - refined_question_prompt=self.chat_engine_config.llm.condense_question_prompt, - ) - - return refined_question - - def _search_knowledge_graph( - self, user_question: str, kg_config: KnowledgeGraphOption - ) -> Tuple[KnowledgeGraphRetrievalResult, str]: - # For forward compatibility of chat engine config. - enable_metadata_filter = kg_config.enable_metadata_filter or ( - kg_config.relationship_meta_filters is not None - ) - metadata_filters = ( - kg_config.metadata_filters or kg_config.relationship_meta_filters - ) - - kg_retriever = KnowledgeGraphFusionRetriever( - db_session=self.db_session, - knowledge_base_ids=self.knowledge_base_ids, - llm=self._llm, - use_query_decompose=kg_config.using_intent_search, - use_async=True, - config=KnowledgeGraphRetrieverConfig( - depth=kg_config.depth, - include_metadata=kg_config.include_meta, - with_degree=kg_config.with_degree, - enable_metadata_filter=enable_metadata_filter, - metadata_filters=metadata_filters, - ), - callback_manager=self.callback_manager, - ) - - if kg_config.using_intent_search: - knowledge_graph = kg_retriever.retrieve_knowledge_graph(user_question) - kg_context_template = get_prompt_by_jinja2_template( - self.chat_engine_config.llm.intent_graph_knowledge, - # For forward compatibility considerations. - sub_queries=knowledge_graph.to_subqueries_dict(), - ) - knowledge_graph_context = kg_context_template.template - else: - knowledge_graph = kg_retriever.retrieve_knowledge_graph(user_question) - kg_context_template = get_prompt_by_jinja2_template( - self.chat_engine_config.llm.normal_graph_knowledge, - entities=knowledge_graph.entities, - relationships=knowledge_graph.relationships, - ) - knowledge_graph_context = kg_context_template.template - - return ( - knowledge_graph, - knowledge_graph_context, - ) - - def _refine_user_question( - self, - user_question: str, - refined_question_prompt: str, - knowledge_graph_context: str = "", - ) -> str: - return self._fast_llm.predict( - get_prompt_by_jinja2_template( - refined_question_prompt, - graph_knowledges=knowledge_graph_context, - question=user_question, - current_date=datetime.now().strftime("%Y-%m-%d"), - ), - ) - - def _search_relevant_chunks(self, refined_question: str) -> List[NodeWithScore]: - retriever = ChunkFusionRetriever( - db_session=self.db_session, - knowledge_base_ids=self.knowledge_base_ids, - llm=self._llm, - config=VectorSearchRetrieverConfig( - similarity_top_k=self.similarity_top_k, - oversampling_factor=self.oversampling_factor, - top_k=self.top_k, - ), - use_query_decompose=False, - use_async=True, - callback_manager=self.callback_manager, - ) - - return retriever.retrieve(QueryBundle(refined_question)) diff --git a/backend/app/rag/retrievers/chunk/fusion_retriever.py b/backend/app/rag/retrievers/chunk/fusion_retriever.py index 7ec7f49ef..5a6a65d61 100644 --- a/backend/app/rag/retrievers/chunk/fusion_retriever.py +++ b/backend/app/rag/retrievers/chunk/fusion_retriever.py @@ -65,11 +65,13 @@ def __init__( ) def _fusion( - self, results: Dict[Tuple[str, int], List[NodeWithScore]] + self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]] ) -> List[NodeWithScore]: - return self._simple_fusion(results) + return self._simple_fusion(query, results) - def _simple_fusion(self, results: Dict[Tuple[str, int], List[NodeWithScore]]): + def _simple_fusion( + self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]] + ): 
"""Apply simple fusion.""" # Use a dict to de-duplicate nodes all_nodes: Dict[str, NodeWithScore] = {} diff --git a/backend/app/rag/retrievers/chunk/schema.py b/backend/app/rag/retrievers/chunk/schema.py index eefc630b2..da2257c6c 100644 --- a/backend/app/rag/retrievers/chunk/schema.py +++ b/backend/app/rag/retrievers/chunk/schema.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Optional +from typing import Any, Dict, Optional from pydantic import BaseModel @@ -7,12 +7,14 @@ class RerankerConfig(BaseModel): + enabled: bool = True model_id: int = None top_n: int = 10 class MetadataFilterConfig(BaseModel): - filters: dict = None + enabled: bool = True + filters: Dict[str, Any] = None class VectorSearchRetrieverConfig(BaseModel): @@ -50,7 +52,7 @@ class RetrievedChunk(BaseModel): class ChunksRetrievalResult(BaseModel): chunks: list[RetrievedChunk] - documents: list[Document | RetrievedChunkDocument] + documents: Optional[list[Document | RetrievedChunkDocument]] = None class ChunkRetriever(ABC): diff --git a/backend/app/rag/retrievers/chunk/simple_retriever.py b/backend/app/rag/retrievers/chunk/simple_retriever.py index 8ee20f3c5..e06ae78fc 100644 --- a/backend/app/rag/retrievers/chunk/simple_retriever.py +++ b/backend/app/rag/retrievers/chunk/simple_retriever.py @@ -3,10 +3,13 @@ from typing import List, Optional, Type from llama_index.core.callbacks import CallbackManager +from llama_index.core.indices.utils import log_vector_store_query_result +from llama_index.core.vector_stores import VectorStoreQuery, VectorStoreQueryResult from sqlmodel import Session -from llama_index.core import VectorStoreIndex from llama_index.core.retrievers import BaseRetriever from llama_index.core.schema import NodeWithScore, QueryBundle +import llama_index.core.instrumentation as instrument + from app.models.chunk import get_kb_chunk_model from app.models.patch.sql_model import SQLModel from app.rag.knowledge_base.config import get_kb_embed_model @@ -18,12 +21,15 @@ ) from app.rag.retrievers.chunk.helpers import map_nodes_to_chunks from app.rag.indices.vector_search.vector_store.tidb_vector_store import TiDBVectorStore -from app.rag.postprocessors.resolver import get_metadata_post_filter +from app.rag.postprocessors.metadata_post_filter import MetadataPostFilter from app.repositories import knowledge_base_repo, document_repo logger = logging.getLogger(__name__) +dispatcher = instrument.get_dispatcher(__name__) + + class ChunkSimpleRetriever(BaseRetriever, ChunkRetriever): _chunk_model: Type[SQLModel] @@ -45,44 +51,69 @@ def __init__( self._embed_model = get_kb_embed_model(db_session, self._kb) self._embed_model.callback_manager = callback_manager - # Vector Index - vector_store = TiDBVectorStore( + # Init vector store. + self._vector_store = TiDBVectorStore( session=db_session, chunk_db_model=self._chunk_db_model, oversampling_factor=config.oversampling_factor, - ) - self._vector_index = VectorStoreIndex.from_vector_store( - vector_store, - embed_model=self._embed_model, callback_manager=callback_manager, ) + # Init node postprocessors. 
node_postprocessors = [] # Metadata filter - enable_metadata_filter = config.metadata_filter is not None - if enable_metadata_filter: - metadata_filter = get_metadata_post_filter(config.metadata_filter.filters) + filter_config = config.metadata_filter + if filter_config and filter_config.enabled: + metadata_filter = MetadataPostFilter(filter_config.filters) node_postprocessors.append(metadata_filter) # Reranker - enable_reranker = config.reranker is not None - if enable_reranker: + reranker_config = config.reranker + if reranker_config and reranker_config.enabled: reranker = resolve_reranker_by_id( - db_session, config.reranker.model_id, config.reranker.top_n + db_session, reranker_config.model_id, reranker_config.top_n ) node_postprocessors.append(reranker) - # Vector Index Retrieve Engine - self._retrieve_engine = self._vector_index.as_retriever( - node_postprocessors=node_postprocessors, - similarity_top_k=config.similarity_top_k or config.top_k, - ) + self._node_postprocessors = node_postprocessors + @dispatcher.span def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - nodes = self._retrieve_engine.retrieve(query_bundle) + if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0: + query_bundle.embedding = self._embed_model.get_agg_embedding_from_queries( + query_bundle.embedding_strs + ) + + result = self._vector_store.query( + VectorStoreQuery( + query_str=query_bundle.query_str, + query_embedding=query_bundle.embedding, + similarity_top_k=self._config.similarity_top_k or self._config.top_k, + ) + ) + nodes = self._build_node_list_from_query_result(result) + + for node_postprocessor in self._node_postprocessors: + nodes = node_postprocessor.postprocess_nodes( + nodes, query_bundle=query_bundle + ) + return nodes[: self._config.top_k] + def _build_node_list_from_query_result( + self, query_result: VectorStoreQueryResult + ) -> List[NodeWithScore]: + log_vector_store_query_result(query_result) + node_with_scores: List[NodeWithScore] = [] + for ind, node in enumerate(query_result.nodes): + score: Optional[float] = None + if query_result.similarities is not None: + score = query_result.similarities[ind] + node_with_scores.append(NodeWithScore(node=node, score=score)) + + return node_with_scores + def retrieve_chunks( self, query_str: str, full_document: bool = False ) -> ChunksRetrievalResult: diff --git a/backend/app/rag/retrievers/knowledge_graph/fusion_retriever.py b/backend/app/rag/retrievers/knowledge_graph/fusion_retriever.py index 6c500b933..aa307865d 100644 --- a/backend/app/rag/retrievers/knowledge_graph/fusion_retriever.py +++ b/backend/app/rag/retrievers/knowledge_graph/fusion_retriever.py @@ -8,6 +8,7 @@ from llama_index.core.schema import NodeWithScore from llama_index.core.tools import ToolMetadata +from app.models import KnowledgeBase from app.rag.retrievers.multiple_knowledge_base import MultiKBFusionRetriever from app.rag.knowledge_base.selector import KBSelectMode from app.rag.retrievers.knowledge_graph.simple_retriever import ( @@ -15,7 +16,6 @@ ) from app.rag.retrievers.knowledge_graph.schema import ( KnowledgeGraphRetrieverConfig, - RetrievedRelationship, KnowledgeGraphRetrievalResult, KnowledgeGraphNode, KnowledgeGraphRetriever, @@ -27,6 +27,8 @@ class KnowledgeGraphFusionRetriever(MultiKBFusionRetriever, KnowledgeGraphRetriever): + knowledge_base_map: Dict[int, KnowledgeBase] = {} + def __init__( self, db_session: Session, @@ -39,11 +41,15 @@ def __init__( callback_manager: Optional[CallbackManager] = CallbackManager([]), 
**kwargs, ): + self.use_query_decompose = use_query_decompose + # Prepare knowledge graph retrievers for knowledge bases. retrievers = [] retriever_choices = [] knowledge_bases = knowledge_base_repo.get_by_ids(db_session, knowledge_base_ids) + self.knowledge_bases = knowledge_bases for kb in knowledge_bases: + self.knowledge_base_map[kb.id] = kb retrievers.append( KnowledgeGraphSimpleRetriever( db_session=db_session, @@ -78,75 +84,67 @@ def retrieve_knowledge_graph( if len(nodes_with_score) == 0: return KnowledgeGraphRetrievalResult() node: KnowledgeGraphNode = nodes_with_score[0].node # type:ignore - subqueries = [ - KnowledgeGraphRetrievalResult( - query=subgraph.query, - entities=subgraph.entities, - relationships=subgraph.relationships, - ) - for subgraph in node.subqueries.values() - ] return KnowledgeGraphRetrievalResult( query=node.query, + knowledge_bases=[kb.to_descriptor() for kb in self.knowledge_bases], entities=node.entities, relationships=node.relationships, - subqueries=[ + subgraphs=[ KnowledgeGraphRetrievalResult( - query=sub.query, - entities=sub.entities, - relationships=node.relationships, + query=child_node.query, + knowledge_base=self.knowledge_base_map[ + child_node.knowledge_base_id + ].to_descriptor(), + entities=child_node.entities, + relationships=child_node.relationships, ) - for sub in subqueries + for child_node in node.children ], ) def _fusion( - self, results: Dict[Tuple[str, int], List[NodeWithScore]] + self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]] ) -> List[NodeWithScore]: - return self._knowledge_graph_fusion(results) + return self._knowledge_graph_fusion(query, results) def _knowledge_graph_fusion( - self, results: Dict[Tuple[str, int], List[NodeWithScore]] + self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]] ) -> List[NodeWithScore]: - merged_queries = {} - merged_entities = {} + merged_entities = set() merged_relationships = {} + merged_knowledge_base_ids = set() + merged_children_nodes = [] + for nodes_with_scores in results.values(): if len(nodes_with_scores) == 0: continue node: KnowledgeGraphNode = nodes_with_scores[0].node # type:ignore - merged_queries[node.query] = node + + # Merge knowledge base id. + merged_knowledge_base_ids.add(node.knowledge_base_id) # Merge entities. - for e in node.entities: - if e.id not in merged_entities: - merged_entities[e.id] = e + merged_entities.update(node.entities) # Merge relationships. for r in node.relationships: - key = (r.source_entity_id, r.target_entity_id, r.description) + key = r.rag_description if key not in merged_relationships: - merged_relationships[key] = RetrievedRelationship( - id=r.id, - source_entity_id=r.source_entity_id, - target_entity_id=r.target_entity_id, - description=r.description, - rag_description=r.rag_description, - weight=0, - meta=r.meta, - last_modified_at=r.last_modified_at, - ) + merged_relationships[key] = r else: merged_relationships[key].weight += r.weight + # Merge to children nodes. 
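+            # Keep each per-knowledge-base result node as a child so that
+            # retrieve_knowledge_graph() can rebuild the per-knowledge-base subgraphs later.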
+ merged_children_nodes.append(node) return [ NodeWithScore( node=KnowledgeGraphNode( - query=None, - entities=list(merged_entities.values()), + query=query, + entities=list(merged_entities), relationships=list(merged_relationships.values()), - subqueries=merged_queries, + knowledge_base_ids=merged_knowledge_base_ids, + children=merged_children_nodes, ), score=1, ) diff --git a/backend/app/rag/retrievers/knowledge_graph/schema.py b/backend/app/rag/retrievers/knowledge_graph/schema.py index f7ef39b43..36868e842 100644 --- a/backend/app/rag/retrievers/knowledge_graph/schema.py +++ b/backend/app/rag/retrievers/knowledge_graph/schema.py @@ -1,37 +1,85 @@ import datetime import json from abc import ABC +from enum import Enum from hashlib import sha256 from typing import Optional, Mapping, Any, List from llama_index.core.schema import BaseNode, MetadataMode from pydantic import BaseModel, Field +from app.models.entity import EntityType +from app.api.admin_routes.models import KnowledgeBaseDescriptor + # Retriever Config +class MetadataFilterConfig(BaseModel): + enabled: bool = True + filters: dict[str, Any] = None + + class KnowledgeGraphRetrieverConfig(BaseModel): depth: int = 2 include_meta: bool = False with_degree: bool = False - enable_metadata_filter: bool = False - metadata_filters: Optional[dict] = None + metadata_filter: Optional[MetadataFilterConfig] = None + + +# Stored Knowledge Graph + + +class StoredKnowledgeGraphVersion(int, Enum): + V1 = 1 + + +class StoredSubGraph(BaseModel): + query: Optional[str] = None + knowledge_base_id: Optional[int] = None + entities: Optional[list[int]] = None + relationships: Optional[list[int]] = None + + +class StoredKnowledgeGraph(StoredSubGraph): + """ + StoredKnowledgeGraph represents the structure of the knowledge graph stored in the database. + """ + + # If not provided, it means that the old version of the storage format is used, which only + # stores entities and relationships information. 
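+    # A V1 payload, in contrast, looks roughly like (values are illustrative):
+    #   {"version": 1, "query": "...", "knowledge_base_ids": [1, 2],
+    #    "subgraphs": [{"knowledge_base_id": 1, "relationships": [10, 11]}]}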
+ version: Optional[int] = StoredKnowledgeGraphVersion.V1 + knowledge_base_ids: Optional[list[int]] = [] + subgraphs: Optional[list["StoredSubGraph"]] = None # Retrieved Knowledge Graph class RetrievedEntity(BaseModel): - id: int = Field(description="Unique identifier for the entity") + id: int = Field(description="ID of the entity") + knowledge_base_id: Optional[int] = Field( + description="ID of the knowledge base", default=None + ) + entity_type: Optional[EntityType] = Field( + description="Type of the entity", default=EntityType.original + ) name: str = Field(description="Name of the entity") description: str = Field(description="Description of the entity") meta: Optional[Mapping[str, Any]] = Field(description="Metadata of the entity") + @property + def global_id(self) -> str: + return f"{self.knowledge_base_id or 0}-{self.id}" + + def __hash__(self): + return hash(self.global_id) + class RetrievedRelationship(BaseModel): - id: int = Field(description="Unique identifier for the relationship") - source_entity_id: int = Field(description="Unique identifier for the source entity") - target_entity_id: int = Field(description="Unique identifier for the target entity") + id: int = Field(description="ID of the relationship") + knowledge_base_id: int = Field(description="ID of the knowledge base", default=None) + source_entity_id: int = Field(description="ID of the source entity") + target_entity_id: int = Field(description="ID of the target entity") description: str = Field(description="Description of the relationship") meta: Optional[Mapping[str, Any]] = Field( description="Metadata of the relationship" @@ -41,36 +89,89 @@ class RetrievedRelationship(BaseModel): ) weight: Optional[float] = Field(description="Weight of the relationship") last_modified_at: Optional[datetime.datetime] = Field( - description="Last modified at of the relationship" + description="Last modified at of the relationship", default=None ) + @property + def global_id(self) -> str: + return f"{self.knowledge_base_id or 0}-{self.id}" + + def __hash__(self): + return hash(self.global_id) -class KnowledgeGraphRetrievalResult(BaseModel): - query: Optional[str] = None + +class RetrievedSubGraph(BaseModel): + query: Optional[str | list[str]] = Field( + description="List of queries that are used to retrieve the knowledge graph", + default=None, + ) + knowledge_base: Optional[KnowledgeBaseDescriptor] = Field( + description="The knowledge base that the knowledge graph is retrieved from", + default=None, + ) entities: List[RetrievedEntity] = Field( description="List of entities in the knowledge graph", default_factory=list ) relationships: List[RetrievedRelationship] = Field( description="List of relationships in the knowledge graph", default_factory=list ) - subqueries: Optional[List["KnowledgeGraphRetrievalResult"]] = Field( - description="List of subqueries in the knowledge graph", default_factory=list + + +class RetrievedKnowledgeGraph(RetrievedSubGraph): + """ + RetrievedKnowledgeGraph represents the structure of the knowledge graph retrieved + from the knowledge base. 
+ """ + + knowledge_bases: Optional[List[KnowledgeBaseDescriptor]] = Field( + description="List of knowledge bases that the knowledge graph is retrieved from", + default_factory=list, + ) + + subgraphs: Optional[List["RetrievedSubGraph"]] = Field( + description="List of subgraphs of the knowledge graph", default_factory=list ) def to_subqueries_dict(self) -> dict: - sub_queries = {} - for subquery in self.subqueries: - sub_queries[subquery.query] = { - "entities": [e.model_dump() for e in subquery.entities], - "relationships": [r.model_dump() for r in subquery.relationships], - } - return sub_queries + """ + For forward compatibility, we need to convert the subgraphs to a dictionary + of subqueries and then pass it to the prompt template. + """ + subqueries = {} + for subgraph in self.subgraphs: + if subgraph.query not in subqueries: + subqueries[subgraph.query] = { + "entities": [e.model_dump() for e in subgraph.entities], + "relationships": [r.model_dump() for r in subgraph.relationships], + } + else: + subqueries[subgraph.query]["entities"].extend( + [e.model_dump() for e in subgraph.entities] + ) + subqueries[subgraph.query]["relationships"].extend( + [r.model_dump() for r in subgraph.relationships] + ) + + return subqueries + + def to_stored_graph_dict(self) -> dict: + subgraph = self.to_stored_graph() + return subgraph.model_dump() + + def to_stored_graph(self) -> StoredKnowledgeGraph: + return StoredKnowledgeGraph( + query=self.query, + knowledge_base_id=self.knowledge_base.id if self.knowledge_base else None, + knowledge_base_ids=[kb.id for kb in self.knowledge_bases] + if self.knowledge_bases + else None, + entities=[e.id for e in self.entities], + relationships=[r.id for r in self.relationships], + subgraphs=[s.to_stored_graph() for s in self.subgraphs], + ) + - def to_graph_data_dict(self) -> dict: - return { - "entities": [e.id for e in self.entities], - "relationships": [r.id for r in self.relationships], - } +KnowledgeGraphRetrievalResult = RetrievedKnowledgeGraph class KnowledgeGraphRetriever(ABC): @@ -110,6 +211,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) query: Optional[str] = Field(description="Query of the knowledge graph") + + knowledge_base_id: Optional[int] = Field( + description="The id of the knowledge base that the knowledge graph belongs to", + default=None, + ) + knowledge_base_ids: Optional[List[int]] = Field( + description="List of ids of the knowledge base that the knowledge graph belongs to", + default_factory=list, + ) + entities: List[RetrievedEntity] = Field( description="The list of entities in the knowledge graph", default_factory=list ) @@ -117,11 +228,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: description="The list of relationships in the knowledge graph", default_factory=list, ) - subqueries: Optional[Mapping[str, "KnowledgeGraphNode"]] = Field( - description="Subqueries", - default_factory=dict, + children: Optional[List["KnowledgeGraphNode"]] = Field( + description="The children of the knowledge graph", + default_factory=list, ) + # Template + knowledge_base_template: str = Field( default=DEFAULT_KNOWLEDGE_GRAPH_TMPL, description="The template to render the knowledge graph as string", @@ -172,14 +285,11 @@ def _get_relationships_str(self) -> str: rag_description=relationship.rag_description, weight=relationship.weight, last_modified_at=relationship.last_modified_at, - meta=self._get_metadata_str(relationship.meta), + meta=json.dumps(relationship.meta, indent=2, ensure_ascii=False), 
) ) return "\n\n".join(strs) - def _get_metadata_str(self, meta: Mapping[str, Any]) -> str: - return json.dumps(meta, indent=2, ensure_ascii=False) - def _get_knowledge_graph_str(self) -> str: return self.knowledge_base_template.format( query=self.query, @@ -187,11 +297,23 @@ def _get_knowledge_graph_str(self) -> str: relationships_str=self._get_relationships_str(), ) - def set_content(self, value: KnowledgeGraphRetrievalResult) -> None: - self.query = value.query - self.entities = value.entities - self.relationships = value.relationships - self.subqueries = value.subqueries + def set_content(self, kg: RetrievedKnowledgeGraph): + self.query = kg.query + self.knowledge_base_id = kg.knowledge_base.id if kg.knowledge_base else None + self.knowledge_base_ids = [] + self.entities = kg.entities + self.relationships = kg.relationships + self.children = [ + KnowledgeGraphNode( + query=subgraph.query, + knowledge_base_id=subgraph.knowledge_base.id + if subgraph.knowledge_base + else None, + entities=subgraph.entities, + relationships=subgraph.relationships, + ) + for subgraph in kg.subgraphs + ] @property def hash(self) -> str: diff --git a/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py b/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py index baa35b796..5d5c42d13 100644 --- a/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py +++ b/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py @@ -30,14 +30,17 @@ def __init__( super().__init__(callback_manager, **kwargs) self.config = config self._callback_manager = callback_manager - self.kb = knowledge_base_repo.must_get(db_session, knowledge_base_id) - self.embed_model = get_kb_embed_model(db_session, self.kb) + self.knowledge_base = knowledge_base_repo.must_get( + db_session, knowledge_base_id + ) + self.embed_model = get_kb_embed_model(db_session, self.knowledge_base) self.embed_model.callback_manager = callback_manager - self.entity_db_model = get_kb_entity_model(self.kb) - self.relationship_db_model = get_kb_relationship_model(self.kb) + self.entity_db_model = get_kb_entity_model(self.knowledge_base) + self.relationship_db_model = get_kb_relationship_model(self.knowledge_base) # TODO: remove it - dspy_lm = get_kb_dspy_llm(db_session, self.kb) + dspy_lm = get_kb_dspy_llm(db_session, self.knowledge_base) self._kg_store = TiDBGraphStore( + knowledge_base=self.knowledge_base, dspy_lm=dspy_lm, session=db_session, embed_model=self.embed_model, @@ -46,20 +49,23 @@ def __init__( ) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + metadata_filters = {} + if self.config.metadata_filter and self.config.metadata_filter.enabled: + metadata_filters = self.config.metadata_filter.filters + entities, relationships = self._kg_store.retrieve_with_weight( query_bundle.query_str, embedding=[], depth=self.config.depth, include_meta=self.config.include_meta, with_degree=self.config.with_degree, - relationship_meta_filters=self.config.metadata_filters - if self.config.metadata_filters - else None, + relationship_meta_filters=metadata_filters, ) return [ NodeWithScore( node=KnowledgeGraphNode( query=query_bundle.query_str, + knowledge_base_id=self.knowledge_base.id, entities=entities, relationships=relationships, ), @@ -73,10 +79,11 @@ def retrieve_knowledge_graph( nodes_with_score = self._retrieve(QueryBundle(query_text)) if len(nodes_with_score) == 0: return KnowledgeGraphRetrievalResult() - node = nodes_with_score[0].node + node: KnowledgeGraphNode = nodes_with_score[0].node # type:ignore return 
KnowledgeGraphRetrievalResult( query=node.query, + knowledge_base=self.knowledge_base.to_descriptor(), entities=node.entities, relationships=node.relationships, - subqueries=[], + subgraphs=[], ) diff --git a/backend/app/rag/retrievers/multiple_knowledge_base.py b/backend/app/rag/retrievers/multiple_knowledge_base.py index b9b7497ec..4cb714cf6 100644 --- a/backend/app/rag/retrievers/multiple_knowledge_base.py +++ b/backend/app/rag/retrievers/multiple_knowledge_base.py @@ -7,7 +7,7 @@ from llama_index.core import QueryBundle from llama_index.core.async_utils import run_async_tasks from llama_index.core.base.base_retriever import BaseRetriever -from llama_index.core.callbacks import CallbackManager, EventPayload +from llama_index.core.callbacks import CallbackManager from llama_index.core.llms import LLM from llama_index.core.schema import NodeWithScore from llama_index.core.tools import ToolMetadata @@ -80,19 +80,10 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: else: results = self._run_sync_queries(queries) - return self._fusion(results) + return self._fusion(query_bundle.query_str, results) def _gen_sub_queries(self, query_bundle: QueryBundle) -> List[QueryBundle]: - with self.callback_manager.event( - MyCBEventType.INTENT_DECOMPOSITION, - payload={EventPayload.QUERY_STR: query_bundle.query_str}, - ) as event: - queries = self._query_decomposer.decompose(query_bundle.query_str) - event.on_end( - payload={ - "subqueries": queries, - } - ) + queries = self._query_decomposer.decompose(query_bundle.query_str) return [QueryBundle(r.question) for r in queries.questions] def _run_async_queries( @@ -125,6 +116,6 @@ def _run_sync_queries( @abstractmethod def _fusion( - self, results: Dict[Tuple[str, int], List[NodeWithScore]] + self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]] ) -> List[NodeWithScore]: """fusion method""" diff --git a/backend/app/repositories/chat.py b/backend/app/repositories/chat.py index 1b3065b8a..4b7035321 100644 --- a/backend/app/repositories/chat.py +++ b/backend/app/repositories/chat.py @@ -10,7 +10,7 @@ from app.models import Chat, User, ChatMessage, ChatUpdate from app.repositories.base_repo import BaseRepo -from app.exceptions import ChatNotFound +from app.exceptions import ChatNotFound, ChatMessageNotFound class ChatRepo(BaseRepo): @@ -102,6 +102,16 @@ def get_message( ) ).first() + def must_get_message( + self, + session: Session, + chat_message_id: int, + ): + msg = self.get_message(session, chat_message_id) + if not msg: + raise ChatMessageNotFound(chat_message_id) + return msg + def create_message( self, session: Session, diff --git a/backend/app/tasks/evaluate.py b/backend/app/tasks/evaluate.py index 2a7877d5c..a77cd1f94 100644 --- a/backend/app/tasks/evaluate.py +++ b/backend/app/tasks/evaluate.py @@ -23,8 +23,8 @@ ) from dotenv import load_dotenv -from app.rag.chat import ChatFlow -from app.rag.chat_stream_protocol import ChatEvent +from app.rag.chat.chat_flow import ChatFlow +from app.rag.chat.stream_protocol import ChatEvent from app.rag.types import ChatEventType, ChatMessageSate load_dotenv() diff --git a/backend/app/utils/tracing.py b/backend/app/utils/tracing.py new file mode 100644 index 000000000..061e411f2 --- /dev/null +++ b/backend/app/utils/tracing.py @@ -0,0 +1,70 @@ +from contextlib import contextmanager +from typing import Optional, Generator +from langfuse.client import StatefulSpanClient, StatefulClient +from langfuse.llama_index import LlamaIndexInstrumentor +from langfuse.llama_index._context 
import langfuse_instrumentor_context + + +class LangfuseContextManager: + langfuse_client: Optional[StatefulSpanClient] = None + + def __init__(self, instrumentor: LlamaIndexInstrumentor): + self.instrumentor = instrumentor + + @contextmanager + def observe(self, **kwargs): + try: + self.instrumentor.start() + with self.instrumentor.observe(**kwargs) as trace_client: + trace_client.update(name=kwargs.get("trace_name"), **kwargs) + self.langfuse_client = trace_client + yield trace_client + except Exception: + raise + finally: + self.instrumentor.flush() + self.instrumentor.stop() + + @contextmanager + def span( + self, parent_client: Optional[StatefulClient] = None, **kwargs + ) -> Generator["StatefulSpanClient", None, None]: + if parent_client: + client = parent_client + else: + client = self.langfuse_client + span = client.span(**kwargs) + + ctx = langfuse_instrumentor_context.get().copy() + old_parent_observation_id = ctx.get("parent_observation_id") + langfuse_instrumentor_context.get().update( + { + "parent_observation_id": span.id, + } + ) + + try: + yield span + except Exception: + raise + finally: + ctx.update( + { + "parent_observation_id": old_parent_observation_id, + } + ) + langfuse_instrumentor_context.get().update(ctx) + + @property + def trace_id(self) -> Optional[str]: + if self.langfuse_client: + return self.langfuse_client.trace_id + else: + return None + + @property + def trace_url(self) -> Optional[str]: + if self.langfuse_client: + return self.langfuse_client.get_trace_url() + else: + return None diff --git a/backend/bootstrap.py b/backend/bootstrap.py index b8eb89354..659f90443 100644 --- a/backend/bootstrap.py +++ b/backend/bootstrap.py @@ -63,7 +63,7 @@ async def reset_admin_password( async def ensure_default_chat_engine(session: AsyncSession) -> None: result = await session.scalar(func.count(ChatEngine.id)) if result == 0: - from app.rag.chat_config import ChatEngineConfig + from app.rag.chat.config import ChatEngineConfig chat_engine = ChatEngine( name="default", diff --git a/backend/main.py b/backend/main.py index 3392336b8..e0f62f399 100644 --- a/backend/main.py +++ b/backend/main.py @@ -127,11 +127,19 @@ def cli(): @cli.command() @click.option("--host", default="127.0.0.1", help="Host, default=127.0.0.1") @click.option("--port", default=3000, help="Port, default=3000") -def runserver(host, port): +@click.option("--workers", default=4) +def runserver(host, port, workers): warnings.warn( "This command will start the server in development mode, do not use it in production." 
) - uvicorn.run("main:app", host=host, port=port, reload=True, log_level="debug") + uvicorn.run( + "main:app", + host=host, + port=port, + reload=True, + log_level="debug", + workers=workers, + ) @cli.command() diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f0f1bf94..dad6db1d1 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "python-dotenv>=1.0.1", "sentry-sdk>=2.5.1", "dspy-ai>=2.4.9", - "langfuse>=2.48.0", + "langfuse>=2.59.1", "langchain-openai>=0.2.9", "ragas>=0.2.6", "deepeval>=0.21.73", @@ -45,7 +45,7 @@ dependencies = [ "python-docx>=1.1.2", "python-pptx>=1.0.2", "openpyxl>=3.1.5", - "llama-index>=0.12.10", + "llama-index>=0.12.16", "llama-index-llms-openai>=0.3.13", "llama-index-llms-openai-like>=0.3.3", "llama-index-llms-bedrock>=0.3.3", diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index aa22eeb01..8db21fa2b 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -316,20 +316,20 @@ langchain-openai==0.2.9 # via ragas langchain-text-splitters==0.3.2 # via langchain -langfuse==2.48.0 +langfuse==2.59.1 langsmith==0.1.143 # via langchain # via langchain-community # via langchain-core llama-cloud==0.1.6 # via llama-index-indices-managed-llama-cloud -llama-index==0.12.10 +llama-index==0.12.16 llama-index-agent-openai==0.4.1 # via llama-index # via llama-index-program-openai llama-index-cli==0.4.0 # via llama-index -llama-index-core==0.12.10.post1 +llama-index-core==0.12.16.post1 # via llama-index # via llama-index-agent-openai # via llama-index-cli @@ -652,6 +652,7 @@ requests==2.32.3 # via huggingface-hub # via langchain # via langchain-community + # via langfuse # via langsmith # via llama-index-core # via msal diff --git a/backend/requirements.lock b/backend/requirements.lock index 82d7aef7f..69ff74872 100644 --- a/backend/requirements.lock +++ b/backend/requirements.lock @@ -309,20 +309,20 @@ langchain-openai==0.2.9 # via ragas langchain-text-splitters==0.3.2 # via langchain -langfuse==2.48.0 +langfuse==2.59.1 langsmith==0.1.143 # via langchain # via langchain-community # via langchain-core llama-cloud==0.1.6 # via llama-index-indices-managed-llama-cloud -llama-index==0.12.10 +llama-index==0.12.16 llama-index-agent-openai==0.4.1 # via llama-index # via llama-index-program-openai llama-index-cli==0.4.0 # via llama-index -llama-index-core==0.12.10.post1 +llama-index-core==0.12.16.post1 # via llama-index # via llama-index-agent-openai # via llama-index-cli @@ -637,6 +637,7 @@ requests==2.32.3 # via huggingface-hub # via langchain # via langchain-community + # via langfuse # via langsmith # via llama-index-core # via msal diff --git a/backend/tests/rag/workflow/test_chat_app_workflow.py b/backend/tests/rag/workflow/test_chat_app_workflow.py deleted file mode 100644 index a972ef391..000000000 --- a/backend/tests/rag/workflow/test_chat_app_workflow.py +++ /dev/null @@ -1,26 +0,0 @@ -import unittest - -import pytest -from sqlmodel import Session - - -from app.core.db import engine -from app.rag.chat_config import ChatEngineConfig -from app.rag.workflows.chat_flow.workflow import ChatFlow - - -@pytest.mark.asyncio -async def test_something(): - with Session(engine) as db_session: - engine_config = ChatEngineConfig.load_from_db(db_session, "default") - flow = ChatFlow(db_session, engine_config) - result = await flow.run( - user_question="What is TiDB?", - chat_history=[], - db_session=db_session, - ) - print(result) - - -if __name__ == "__main__": - unittest.main()
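Editor's note on the knowledge-graph schema change (not part of the patch): `RetrievedKnowledgeGraph` now carries a flat `subgraphs` list instead of the old `subqueries` mapping, so `to_subqueries_dict()` has to re-group subgraphs that share the same query before the result is rendered into the prompt template. A minimal, self-contained sketch of that grouping using plain dicts (the sample data is hypothetical; the real code operates on `RetrievedSubGraph` models):

    from collections import defaultdict

    # Hypothetical subgraphs; in the patch these are RetrievedSubGraph instances.
    subgraphs = [
        {"query": "What is TiDB?", "entities": ["TiDB"], "relationships": ["TiDB -> is -> distributed SQL database"]},
        {"query": "What is TiDB?", "entities": ["TiKV"], "relationships": ["TiDB -> uses -> TiKV"]},
        {"query": "What is TiKV?", "entities": ["TiKV"], "relationships": []},
    ]

    # Group by query and merge the entities/relationships of subgraphs that share
    # a query, mirroring RetrievedKnowledgeGraph.to_subqueries_dict().
    subqueries = defaultdict(lambda: {"entities": [], "relationships": []})
    for sg in subgraphs:
        subqueries[sg["query"]]["entities"].extend(sg["entities"])
        subqueries[sg["query"]]["relationships"].extend(sg["relationships"])

    assert len(subqueries) == 2
    assert subqueries["What is TiDB?"]["entities"] == ["TiDB", "TiKV"]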
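Editor's note on the new `backend/app/utils/tracing.py` (not part of the patch): `LangfuseContextManager` ships without a docstring; from its implementation, the intended pattern appears to be wrapping a whole request in `observe()` and nesting manual spans with `span()`, which temporarily redirects the instrumentor's `parent_observation_id` so auto-instrumented LlamaIndex calls attach under the manual span. A hedged usage sketch, assuming Langfuse credentials come from environment variables and that callers pass `trace_name` (the only kwarg `observe()` explicitly reads); names such as "chat_flow" and "retrieve" are illustrative only:

    import os
    from langfuse.llama_index import LlamaIndexInstrumentor
    from app.utils.tracing import LangfuseContextManager

    # Assumption: the instrumentor is built from environment variables; the patch
    # does not show how it is constructed in the calling code.
    instrumentor = LlamaIndexInstrumentor(
        public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
        secret_key=os.environ["LANGFUSE_SECRET_KEY"],
        host=os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com"),
    )
    ctx = LangfuseContextManager(instrumentor)

    # observe() starts the instrumentor, opens a trace, and flushes/stops it on exit.
    with ctx.observe(trace_name="chat_flow") as trace:
        # span() opens a child span and points parent_observation_id at it, so
        # LlamaIndex calls made here are nested under "retrieve" in the trace.
        with ctx.span(name="retrieve") as span:
            ...  # retrieval / LLM calls go here

    print(ctx.trace_url)  # the trace URL remains available after observe() exits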