Fix callback manager concurrent unsafe issue (#177)
Co-authored-by: wd0517 <[email protected]>
IANTHEREAL and wd0517 authored Jul 22, 2024
1 parent f8b4104 commit e9aebb0
Showing 2 changed files with 30 additions and 29 deletions.
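
For context (not part of the commit): before this change, `_chat` stored its Langfuse handler in the process-wide `Settings.callback_manager` and reused LLM/embedding instances cached on `self`, so two chat requests served concurrently could overwrite each other's tracing state. The following minimal sketch reproduces that hazard with plain strings standing in for handlers; `Settings` here is a stub, and `handle_request_unsafe` / `handle_request_safe` are illustrative names, not code from the repository.

# Minimal sketch of the concurrency hazard this commit addresses: a shared,
# process-wide "callback manager" slot is mutated by every request, so one
# request can end up reading the handler that belongs to another request.
import threading
import time


class Settings:  # stand-in for llama_index's global Settings object
    callback_manager = None


def handle_request_unsafe(request_id: str) -> str:
    # Old pattern: publish this request's handler on the shared global.
    Settings.callback_manager = f"handler-for-{request_id}"
    time.sleep(0.01)  # simulate work; another request may reassign the global here
    return Settings.callback_manager  # may now belong to a different request


def handle_request_safe(request_id: str) -> str:
    # New pattern: keep the handler request-local and pass it explicitly.
    callback_manager = f"handler-for-{request_id}"
    time.sleep(0.01)
    return callback_manager  # always this request's own handler


def run(handler) -> dict:
    results = {}

    def worker(request_id: str) -> None:
        results[request_id] = handler(request_id)

    threads = [threading.Thread(target=worker, args=(f"req-{i}",)) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print(run(handle_request_unsafe))  # some values usually end up swapped
    print(run(handle_request_safe))    # every request gets its own handler back

The commit applies the same fix to the real objects: `_get_langfuse_callback_manager()` returns a fresh `CallbackManager` per call and attaches it to LLM and embedding instances created inside `_chat`, instead of assigning to the shared `Settings.callback_manager`.
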
52 changes: 30 additions & 22 deletions backend/app/rag/chat.py
@@ -6,7 +6,7 @@

import jinja2
from sqlmodel import Session, select
-from llama_index.core import VectorStoreIndex, Settings
+from llama_index.core import VectorStoreIndex, ServiceContext
from llama_index.core.base.llms.base import ChatMessage
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.base.response.schema import StreamingResponse
@@ -56,10 +56,6 @@ def __init__(

self.chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
self.db_chat_engine = self.chat_engine_config.get_db_chat_engine()
-self._llm = self.chat_engine_config.get_llama_llm()
-self._dspy_lm = self.chat_engine_config.get_dspy_lm()
-self._fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm()
-self._embed_model = self.chat_engine_config.get_embedding_model()
self._reranker = self.chat_engine_config.get_reranker()

def chat(
@@ -153,7 +149,11 @@ def _chat(
),
)

-def _set_langfuse_callback_manager():
+_llm = self.chat_engine_config.get_llama_llm()
+_embed_model = self.chat_engine_config.get_embedding_model()
+_dspy_lm = self.chat_engine_config.get_dspy_lm()
+
+def _get_langfuse_callback_manager():
# Why we don't use high-level decorator `observe()` as \
# `https://langfuse.com/docs/integrations/llama-index/get-started` suggested?
# track:
@@ -162,8 +162,10 @@ def _set_langfuse_callback_manager():
observation = langfuse.trace(id=trace_id)
langfuse_handler = LlamaIndexCallbackHandler()
langfuse_handler.set_root(observation)
-Settings.callback_manager = CallbackManager([langfuse_handler])
-self._llm.callback_manager = Settings.callback_manager
+callback_manager = CallbackManager([langfuse_handler])
+_llm.callback_manager = callback_manager
+_embed_model.callback_manager = callback_manager
+return callback_manager

# Frontend requires the empty event to start the chat
yield ChatEvent(
@@ -188,23 +190,23 @@
)

# 1. Retrieve entities, relations, and chunks from the knowledge graph
-_set_langfuse_callback_manager()
+callback_manager = _get_langfuse_callback_manager()
kg_config = self.chat_engine_config.knowledge_graph
if kg_config.enabled:
graph_store = TiDBGraphStore(
-dspy_lm=self._fast_dspy_lm,
+dspy_lm=_dspy_lm,
session=self.db_session,
-embed_model=self._embed_model,
+embed_model=_embed_model,
)
graph_index: KnowledgeGraphIndex = KnowledgeGraphIndex.from_existing(
-dspy_lm=self._fast_dspy_lm,
+dspy_lm=_dspy_lm,
kg_store=graph_store,
-callback_manager=Settings.callback_manager,
+callback_manager=callback_manager,
)

if kg_config.using_intent_search:
-with Settings.callback_manager.as_trace("retrieve_with_weight"):
-with Settings.callback_manager.event(
+with callback_manager.as_trace("retrieve_with_weight"):
+with callback_manager.event(
MyCBEventType.RETRIEVE_FROM_GRAPH,
payload={
EventPayload.QUERY_STR: {
@@ -253,13 +255,13 @@ def _set_langfuse_callback_manager():
display="Refine the user question ...",
),
)
-_set_langfuse_callback_manager()
-with Settings.callback_manager.as_trace("condense_question"):
-with Settings.callback_manager.event(
+callback_manager = _get_langfuse_callback_manager()
+with callback_manager.as_trace("condense_question"):
+with callback_manager.event(
MyCBEventType.CONDENSE_QUESTION,
payload={EventPayload.QUERY_STR: user_question},
) as event:
-refined_question = self._llm.predict(
+refined_question = _llm.predict(
get_prompt_by_jinja2_template(
self.chat_engine_config.llm.condense_question_prompt,
graph_knowledges=graph_knowledges_context,
@@ -279,7 +281,7 @@ def _set_langfuse_callback_manager():
display="Search related documents ...",
),
)
-_set_langfuse_callback_manager()
+callback_manager = _get_langfuse_callback_manager()
text_qa_template = get_prompt_by_jinja2_template(
self.chat_engine_config.llm.text_qa_prompt,
graph_knowledges=graph_knowledges_context,
@@ -288,18 +290,24 @@ def _set_langfuse_callback_manager():
self.chat_engine_config.llm.refine_prompt,
graph_knowledges=graph_knowledges_context,
)
+service_context = ServiceContext.from_defaults(
+llm=_llm,
+embed_model=_embed_model,
+callback_manager=callback_manager,
+)
vector_store = TiDBVectorStore(session=self.db_session)
vector_index = VectorStoreIndex.from_vector_store(
vector_store,
-embed_model=self._embed_model,
+service_context=service_context,
)
query_engine = vector_index.as_query_engine(
-llm=self._llm,
+llm=_llm,
node_postprocessors=[self._reranker],
streaming=True,
text_qa_template=text_qa_template,
refine_template=refine_template,
similarity_top_k=100,
+service_context=service_context,
)
response: StreamingResponse = query_engine.query(refined_question)
source_documents = self._get_source_documents(response)
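
To show the per-request pattern the chat.py diff adopts in an isolated, runnable form: build one `CallbackManager` per call and drive `as_trace` / `event` through it, instead of going through the global `Settings`. This is a hedged sketch, not the repository's code: `LlamaDebugHandler` stands in for the Langfuse `LlamaIndexCallbackHandler` (which needs external credentials), and `get_request_callback_manager` / `answer` are hypothetical helper names.

# Sketch of per-request callback wiring, with a local debug handler
# standing in for the Langfuse handler.
from llama_index.core.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    LlamaDebugHandler,
)


def get_request_callback_manager() -> CallbackManager:
    # One handler + manager per request; nothing global is mutated.
    handler = LlamaDebugHandler(print_trace_on_end=False)
    return CallbackManager([handler])


def answer(question: str) -> str:
    callback_manager = get_request_callback_manager()
    with callback_manager.as_trace("condense_question"):
        with callback_manager.event(
            CBEventType.QUERY,
            payload={EventPayload.QUERY_STR: question},
        ) as event:
            refined = question.strip()  # placeholder for the real _llm.predict(...)
            event.on_end(payload={EventPayload.RESPONSE: refined})
    return refined

Because the manager lives only in the scope of one call, concurrent `answer()` invocations cannot clobber each other's handlers, which is the property the diff restores for `_chat`.
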
7 changes: 0 additions & 7 deletions backend/app/rag/chat_config.py
@@ -116,13 +116,6 @@ def get_dspy_lm(self) -> dspy.LM:
llama_llm = self.get_llama_llm()
return get_dspy_lm_by_llama_llm(llama_llm)

-def get_fast_llama_llm(self) -> Optional[LLM]:
-return Gemini(model=self.llm.gemini_chat_model.value)
-
-def get_fast_dspy_lm(self) -> Optional[dspy.LM]:
-llama_llm = self.get_fast_llama_llm()
-return get_dspy_lm_by_llama_llm(llama_llm)
-
def get_embedding_model(self) -> BaseEmbedding:
# The embedding model should remain the same for both building and chatting,
# currently we do not support dynamic configuration of embedding model
