make callback manager local
IANTHEREAL committed Jul 19, 2024
1 parent f8b4104 commit 7536bc3
Showing 1 changed file with 29 additions and 18 deletions.
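
This commit makes the Langfuse callback manager local to each chat request instead of installing it on the global llama-index `Settings` singleton: `_set_langfuse_callback_manager()` becomes `_get_langfuse_callback_manager()`, which builds a fresh `CallbackManager`, attaches it to request-local `_llm` and `_embed_model` instances, and returns it for callers to pass around explicitly. Because `Settings` is process-wide state, concurrent chats could previously overwrite each other's handler and interleave spans across Langfuse traces; a request-scoped manager keeps each span in its own trace.

A minimal sketch of the resulting pattern, simplified from the diff below (the helper name and `trace_id` parameter are illustrative, not taken verbatim from the repository):

    from langfuse import Langfuse
    from langfuse.llama_index import LlamaIndexCallbackHandler
    from llama_index.core.callbacks import CallbackManager

    def build_request_callback_manager(trace_id: str) -> CallbackManager:
        # Attach the Langfuse handler to a single trace and return a manager
        # scoped to this request, rather than assigning it to the global
        # Settings.callback_manager.
        observation = Langfuse().trace(id=trace_id)
        handler = LlamaIndexCallbackHandler()
        handler.set_root(observation)
        return CallbackManager([handler])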
backend/app/rag/chat.py: 47 changes (29 additions & 18 deletions)
@@ -6,7 +6,7 @@
 
 import jinja2
 from sqlmodel import Session, select
-from llama_index.core import VectorStoreIndex, Settings
+from llama_index.core import VectorStoreIndex, ServiceContext
 from llama_index.core.base.llms.base import ChatMessage
 from llama_index.core.prompts.base import PromptTemplate
 from llama_index.core.base.response.schema import StreamingResponse
@@ -56,10 +56,10 @@ def __init__(
 
         self.chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
         self.db_chat_engine = self.chat_engine_config.get_db_chat_engine()
-        self._llm = self.chat_engine_config.get_llama_llm()
+        # self._llm = self.chat_engine_config.get_llama_llm()
         self._dspy_lm = self.chat_engine_config.get_dspy_lm()
         self._fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm()
-        self._embed_model = self.chat_engine_config.get_embedding_model()
+        # self._embed_model = self.chat_engine_config.get_embedding_model()
         self._reranker = self.chat_engine_config.get_reranker()
 
     def chat(
@@ -153,7 +153,10 @@ def _chat(
             ),
         )
 
-        def _set_langfuse_callback_manager():
+        _llm = self.chat_engine_config.get_llama_llm()
+        _embed_model = self.chat_engine_config.get_embedding_model()
+
+        def _get_langfuse_callback_manager():
             # Why we don't use high-level decorator `observe()` as \
             # `https://langfuse.com/docs/integrations/llama-index/get-started` suggested?
             # track:
@@ -162,8 +165,10 @@ def _set_langfuse_callback_manager():
             observation = langfuse.trace(id=trace_id)
             langfuse_handler = LlamaIndexCallbackHandler()
             langfuse_handler.set_root(observation)
-            Settings.callback_manager = CallbackManager([langfuse_handler])
-            self._llm.callback_manager = Settings.callback_manager
+            callback_manager = CallbackManager([langfuse_handler])
+            _llm.callback_manager = callback_manager
+            _embed_model.callback_manager = callback_manager
+            return callback_manager
 
         # Frontend requires the empty event to start the chat
         yield ChatEvent(
@@ -188,23 +193,23 @@
         )
 
         # 1. Retrieve entities, relations, and chunks from the knowledge graph
-        _set_langfuse_callback_manager()
+        callback_manager = _get_langfuse_callback_manager()
         kg_config = self.chat_engine_config.knowledge_graph
         if kg_config.enabled:
             graph_store = TiDBGraphStore(
                 dspy_lm=self._fast_dspy_lm,
                 session=self.db_session,
-                embed_model=self._embed_model,
+                embed_model=_embed_model,
             )
             graph_index: KnowledgeGraphIndex = KnowledgeGraphIndex.from_existing(
                 dspy_lm=self._fast_dspy_lm,
                 kg_store=graph_store,
-                callback_manager=Settings.callback_manager,
+                callback_manager=callback_manager,
             )
 
             if kg_config.using_intent_search:
-                with Settings.callback_manager.as_trace("retrieve_with_weight"):
-                    with Settings.callback_manager.event(
+                with callback_manager.as_trace("retrieve_with_weight"):
+                    with callback_manager.event(
                         MyCBEventType.RETRIEVE_FROM_GRAPH,
                         payload={
                             EventPayload.QUERY_STR: {
@@ -253,13 +258,13 @@ def _set_langfuse_callback_manager():
                 display="Refine the user question ...",
             ),
         )
-        _set_langfuse_callback_manager()
-        with Settings.callback_manager.as_trace("condense_question"):
-            with Settings.callback_manager.event(
+        callback_manager = _get_langfuse_callback_manager()
+        with callback_manager.as_trace("condense_question"):
+            with callback_manager.event(
                 MyCBEventType.CONDENSE_QUESTION,
                 payload={EventPayload.QUERY_STR: user_question},
            ) as event:
-                refined_question = self._llm.predict(
+                refined_question = _llm.predict(
                    get_prompt_by_jinja2_template(
                        self.chat_engine_config.llm.condense_question_prompt,
                        graph_knowledges=graph_knowledges_context,
@@ -279,7 +284,7 @@ def _set_langfuse_callback_manager():
                 display="Search related documents ...",
             ),
         )
-        _set_langfuse_callback_manager()
+        callback_manager = _get_langfuse_callback_manager()
         text_qa_template = get_prompt_by_jinja2_template(
             self.chat_engine_config.llm.text_qa_prompt,
             graph_knowledges=graph_knowledges_context,
@@ -288,18 +293,24 @@ def _set_langfuse_callback_manager():
             self.chat_engine_config.llm.refine_prompt,
             graph_knowledges=graph_knowledges_context,
         )
+        service_context = ServiceContext.from_defaults(
+            llm=_llm,
+            embed_model=_embed_model,
+            callback_manager=callback_manager,
+        )
         vector_store = TiDBVectorStore(session=self.db_session)
         vector_index = VectorStoreIndex.from_vector_store(
             vector_store,
-            embed_model=self._embed_model,
+            service_context=service_context,
         )
         query_engine = vector_index.as_query_engine(
-            llm=self._llm,
+            llm=_llm,
             node_postprocessors=[self._reranker],
             streaming=True,
             text_qa_template=text_qa_template,
             refine_template=refine_template,
             similarity_top_k=100,
+            service_context=service_context,
         )
         response: StreamingResponse = query_engine.query(refined_question)
         source_documents = self._get_source_documents(response)
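
For completeness, a hedged sketch of how the request-scoped manager is threaded through the vector query path, mirroring the shape of the last hunk (`build_request_callback_manager` comes from the sketch above; `llm`, `embed_model`, and `vector_store` are stand-ins for the engine's request-local objects):

    from llama_index.core import ServiceContext, VectorStoreIndex

    def answer(question: str, trace_id: str, llm, embed_model, vector_store):
        callback_manager = build_request_callback_manager(trace_id)
        # Bundle the request-local LLM, embedding model, and callback manager
        # so every component in the query path reports into the same trace.
        service_context = ServiceContext.from_defaults(
            llm=llm,
            embed_model=embed_model,
            callback_manager=callback_manager,
        )
        index = VectorStoreIndex.from_vector_store(
            vector_store, service_context=service_context
        )
        query_engine = index.as_query_engine(
            llm=llm, streaming=True, service_context=service_context
        )
        return query_engine.query(question)

Passing the same `service_context` at both index construction and query time keeps retrieval (embedding calls) and synthesis (LLM calls) under one Langfuse trace.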