From a170c4dc78e72ec35e31e18a845e0bfabc715171 Mon Sep 17 00:00:00 2001 From: "sarju.ladwa@salesforce.com" Date: Mon, 1 May 2023 19:35:02 +0100 Subject: [PATCH] Add support for ConversationalRetrievalChain. ChatVectorDBChain is depricated. --- query_data.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/query_data.py b/query_data.py index c0028317f..1a137ae7f 100644 --- a/query_data.py +++ b/query_data.py @@ -1,7 +1,7 @@ """Create a ChatVectorDBChain for question/answering.""" from langchain.callbacks.base import AsyncCallbackManager from langchain.callbacks.tracers import LangChainTracer -from langchain.chains import ChatVectorDBChain +from langchain.chains import ConversationalRetrievalChain from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT, QA_PROMPT) from langchain.chains.llm import LLMChain @@ -12,9 +12,9 @@ def get_chain( vectorstore: VectorStore, question_handler, stream_handler, tracing: bool = False -) -> ChatVectorDBChain: - """Create a ChatVectorDBChain for question/answering.""" - # Construct a ChatVectorDBChain with a streaming llm for combine docs +) -> ConversationalRetrievalChain: + """Create a ConversationalRetrievalChain for question/answering.""" + # Construct a ConversationalRetrievalChain with a streaming llm for combine docs # and a separate, non-streaming llm for question generation manager = AsyncCallbackManager([]) question_manager = AsyncCallbackManager([question_handler]) @@ -45,10 +45,11 @@ def get_chain( streaming_llm, chain_type="stuff", prompt=QA_PROMPT, callback_manager=manager ) - qa = ChatVectorDBChain( - vectorstore=vectorstore, + qa = ConversationalRetrievalChain( + retriever=vectorstore.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator, callback_manager=manager, + verbose=True ) return qa