Skip to content

Commit

Permalink
[NeuralChat] Add ut (intel#1083)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liangyx2 authored Jan 10, 2024
1 parent ccd8759 commit 938a3f0
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
# limitations under the License.

"""The wrapper for Child-Parent retriever based on langchain"""
from langchain.retrievers import MultiVectorRetriever
from langchain_core.vectorstores import VectorStore
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
from langchain_core.pydantic_v1 import Field
from enum import Enum
from typing import List
from langchain_core.documents import Document
from langchain.callbacks.manager import CallbackManagerForRetrieverRun


class SearchType(str, Enum):
"""Enumerator of the types of search to perform."""
Expand All @@ -30,15 +34,17 @@ class SearchType(str, Enum):
"""Maximal Marginal Relevance reranking of similarity search."""


class ChildParentRetriever(MultiVectorRetriever):
class ChildParentRetriever(BaseRetriever):
"""Retrieve from a set of multiple embeddings for the same document."""

vectorstore: VectorStore
"""The underlying vectorstore to use to store small chunks
and their embedding vectors"""
parentstore: VectorStore

def get_context(self, query:str, *, run_manager: CallbackManagerForRetrieverRun):
id_key: str = "doc_id"
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the search function."""
search_type: SearchType = SearchType.similarity
"""Type of search to perform (similarity / mmr)"""

def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
Expand All @@ -52,15 +58,20 @@ def get_context(self, query:str, *, run_manager: CallbackManagerForRetrieverRun)
)
else:
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

ids = []
for d in sub_docs:
if d.metadata['doc_id'] not in ids:
ids.append(d.metadata['doc_id'])
if d.metadata["identify_id"] not in ids:
ids.append(d.metadata['identify_id'])
retrieved_documents = self.parentstore.get(ids)
return retrieved_documents

def get_context(self, query):
    """Build a single context string plus source links for *query*.

    Runs the child-parent retrieval (small child chunks are searched, their
    parent documents are returned) and flattens the result for prompting.

    Args:
        query: Natural-language question to retrieve context for.

    Returns:
        Tuple ``(context, links)`` where ``context`` is the concatenation of
        the retrieved parent document texts and ``links`` is the list of
        their 'source' metadata entries.
    """
    context = ''
    links = []
    # NOTE(review): get_relevant_documents() ultimately returns
    # parentstore.get(ids) — a Chroma-style mapping with 'documents'
    # (raw texts) and 'metadatas' (per-document metadata dicts), not a
    # List[Document]; the indexing below relies on that shape.
    retrieved_documents = self.get_relevant_documents(query)
    for doc in retrieved_documents['documents']:
        context = context + doc + " "
    for meta in retrieved_documents['metadatas']:
        links.append(meta['source'])
    return context.strip(), links
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ def document_transfer(data_collection):
documents = []
for data, meta in data_collection:
doc_id = str(uuid.uuid4())
metadata = {"source": meta, "doc_id":doc_id}
metadata = {"source": meta, "identify_id":doc_id}
doc = Document(page_content=data, metadata=metadata)
documents.append(doc)
return documents

def document_append_id(documents):
    """Mirror each document's "identify_id" metadata under the "doc_id" key.

    Mutates the documents in place so downstream stores that expect a
    "doc_id" field find one.

    Args:
        documents: Iterable of document objects whose ``metadata`` dict
            carries an "identify_id" entry.

    Returns:
        The same collection, with "doc_id" populated on every document.
    """
    for entry in documents:
        meta = entry.metadata
        meta["doc_id"] = meta["identify_id"]
    return documents


class Agent_QA():
"""
Expand Down Expand Up @@ -152,6 +157,7 @@ def __init__(self,
knowledge_base.client.close()
elif self.retrieval_type == "child_parent": # Using child-parent store retriever
child_documents = self.splitter.split_documents(langchain_documents)
langchain_documents = document_append_id(langchain_documents)
if append:
knowledge_base = self.database.from_documents(documents=langchain_documents, embedding=self.embeddings,
**kwargs)
Expand Down Expand Up @@ -196,13 +202,17 @@ def create(self, input_path, **kwargs):
"""
data_collection = self.document_parser.load(input=input_path, **kwargs)
langchain_documents = document_transfer(data_collection)
knowledge_base = self.database.from_documents(documents=langchain_documents, \
embedding=self.embeddings, **kwargs)

if self.retrieval_type == 'default':
knowledge_base = self.database.from_documents(documents=langchain_documents, \
embedding=self.embeddings, **kwargs)
self.retriever = RetrieverAdapter(retrieval_type=self.retrieval_type, document_store=knowledge_base, \
**kwargs).retriever
elif self.retrieval_type == "child_parent":
child_documents = self.splitter.split_documents(langchain_documents)
langchain_documents = document_append_id(langchain_documents)
knowledge_base = self.database.from_documents(documents=langchain_documents, \
embedding=self.embeddings, **kwargs)
child_knowledge_base = self.database.from_documents(documents=child_documents, sign='child', \
embedding=self.embeddings, **kwargs)
self.retriever = RetrieverAdapter(retrieval_type=self.retrieval_type, document_store=knowledge_base, \
Expand All @@ -215,13 +225,18 @@ def append_localdb(self, append_path, **kwargs):

data_collection = self.document_parser.load(input=append_path, **kwargs)
langchain_documents = document_transfer(data_collection)
knowledge_base = self.database.from_documents(documents=langchain_documents, \
embedding=self.embeddings, **kwargs)

if self.retrieval_type == 'default':
knowledge_base = self.database.from_documents(documents=langchain_documents, \
embedding=self.embeddings, **kwargs)
self.retriever = RetrieverAdapter(retrieval_type=self.retrieval_type, \
document_store=knowledge_base, **kwargs).retriever
elif self.retrieval_type == "child_parent":
child_knowledge_base = self.database.from_documents(documents=langchain_documents, sign = 'child', \
child_documents = self.splitter.split_documents(langchain_documents)
langchain_documents = document_append_id(langchain_documents)
knowledge_base = self.database.from_documents(documents=langchain_documents, \
embedding=self.embeddings, **kwargs)
child_knowledge_base = self.database.from_documents(documents=child_documents, sign = 'child', \
embedding=self.embeddings, **kwargs)
self.retriever = RetrieverAdapter(retrieval_type=self.retrieval_type, document_store=knowledge_base, \
child_document_store=child_knowledge_base, **kwargs).retriever
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""The wrapper for Retriever based on langchain"""
from intel_extension_for_transformers.langchain.retrievers import VectorStoreRetriever, ChildParentRetriever
import logging

logging.basicConfig(
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%M-%Y %H:%M:%S",
Expand All @@ -33,7 +34,7 @@ def __init__(self, retrieval_type='default', document_store=None, child_document
if self.retrieval_type == "default":
self.retriever = VectorStoreRetriever(vectorstore = document_store, **kwargs)
elif self.retrieval_type == "child_parent":
self.retriever = ChildParentRetriever(vectorstore=child_document_store, parentstore=document_store, \
**kwargs)
self.retriever = ChildParentRetriever(parentstore=document_store, \
vectorstore=child_document_store, **kwargs) # pylint: disable=abstract-class-instantiated
else:
logging.error('The chosen retrieval type remains outside the supported scope.')
188 changes: 185 additions & 3 deletions intel_extension_for_transformers/neural_chat/tests/ci/api/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@
# All UT cases use 'facebook/opt-125m' to reduce test time.
class TestChatbotBuilder(unittest.TestCase):
def setUp(self):
    """Remove any leftover accuracy-test index so the run starts clean."""
    # rmtree with ignore_errors is a no-op when the directory is absent.
    shutil.rmtree("test_for_accuracy", ignore_errors=True)
    return super().setUp()

def tearDown(self) -> None:
    """Delete every directory the accuracy test may have written."""
    if os.path.exists("output"):
        shutil.rmtree("output")
    # Absent directory is fine here; errors are deliberately ignored.
    shutil.rmtree("test_for_accuracy", ignore_errors=True)
    return super().tearDown()

def test_retrieval_accuracy(self):
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/sample.txt"
plugins.retrieval.args["input_path"] = "../assets/docs/"
plugins.retrieval.args["persist_directory"] = "./test_for_accuracy"
plugins.retrieval.args["retrieval_type"] = 'default'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
Expand All @@ -45,6 +48,185 @@ def test_retrieval_accuracy(self):
self.assertIsNotNone(response)
plugins.retrieval.enable = False

class TestChatbotBuilder_txt(unittest.TestCase):
    """RAG retrieval smoke test over a plain-text knowledge file."""

    def setUp(self):
        # Start from a clean persist directory; rmtree on a missing path
        # is a no-op thanks to ignore_errors.
        shutil.rmtree("test_txt", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_txt", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_txt(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample.txt"
        args["persist_directory"] = "./test_txt"
        args["retrieval_type"] = 'default'
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
        print(response)
        # Point the shared plugin config back at the default location.
        args["persist_directory"] = "./output"
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

class TestChatbotBuilder_docx(unittest.TestCase):
    """RAG retrieval smoke test over a .docx knowledge file."""

    def setUp(self):
        # Clean slate; ignore_errors makes this safe when nothing exists yet.
        shutil.rmtree("test_docx", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_docx", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_docx(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample.docx"
        args["persist_directory"] = "./test_docx"
        args["retrieval_type"] = 'default'
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
        print(response)
        # Point the shared plugin config back at the default location.
        args["persist_directory"] = "./output"
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

class TestChatbotBuilder_xlsx(unittest.TestCase):
    """RAG retrieval smoke test over an .xlsx knowledge file."""

    def setUp(self):
        # Clean slate; ignore_errors makes this safe when nothing exists yet.
        shutil.rmtree("test_xlsx", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_xlsx", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_xlsx(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample.xlsx"
        args["persist_directory"] = "./test_xlsx"
        args["retrieval_type"] = 'default'
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("Who is the CEO of Intel?")
        print(response)
        # Point the shared plugin config back at the default location.
        args["persist_directory"] = "./output"
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

class TestChatbotBuilder_xlsx_1(unittest.TestCase):
    """RAG retrieval smoke test over the second .xlsx sample layout."""

    def setUp(self):
        # Clean slate; ignore_errors makes this safe when nothing exists yet.
        shutil.rmtree("test_xlsx_1", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_xlsx_1", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_xlsx_1(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample_1.xlsx"
        args["persist_directory"] = "./test_xlsx_1"
        args["retrieval_type"] = 'default'
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("Who is the CEO of Intel?")
        print(response)
        # Point the shared plugin config back at the default location.
        args["persist_directory"] = "./output"
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

class TestChatbotBuilder_xlsx_2(unittest.TestCase):
    """RAG retrieval smoke test over the third .xlsx sample layout."""

    def setUp(self):
        # Clean slate; ignore_errors makes this safe when nothing exists yet.
        shutil.rmtree("test_xlsx_2", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_xlsx_2", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_xlsx_2(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample_2.xlsx"
        args["persist_directory"] = "./test_xlsx_2"
        args["retrieval_type"] = 'default'
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("Who is the CEO of Intel?")
        print(response)
        # Point the shared plugin config back at the default location.
        args["persist_directory"] = "./output"
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

class TestChatbotBuilder_jsonl(unittest.TestCase):
    """RAG retrieval smoke test over a .jsonl knowledge file."""

    def setUp(self):
        # Clean slate; ignore_errors makes this safe when nothing exists yet.
        shutil.rmtree("test_jsonl", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_jsonl", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_jsonl(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample.jsonl"
        args["persist_directory"] = "./test_jsonl"
        args["retrieval_type"] = 'default'
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("What does this blog talk about?")
        print(response)
        # Point the shared plugin config back at the default location.
        args["persist_directory"] = "./output"
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

class TestChatbotBuilder_child_parent(unittest.TestCase):
    """RAG smoke test exercising the child-parent retrieval pipeline."""

    def setUp(self):
        # Clear both the parent and child persist directories before running.
        shutil.rmtree("test_rag", ignore_errors=True)
        shutil.rmtree("test_rag_child", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_rag", ignore_errors=True)
        shutil.rmtree("test_rag_child", ignore_errors=True)
        return super().tearDown()

    def test_retrieval_child_parent(self):
        plugins.retrieval.enable = True
        args = plugins.retrieval.args
        args["input_path"] = "../assets/docs/sample.txt"
        args["persist_directory"] = "./test_rag"
        args["retrieval_type"] = "child_parent"
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
        print(response)
        # Restore the shared plugin defaults so later suites are unaffected.
        args["persist_directory"] = "./output"
        args["retrieval_type"] = 'default'
        self.assertIsNotNone(response)
        plugins.retrieval.enable = False

# Run the retrieval unit-test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ def test_html_loader(self):
self.assertIsNotNone(vectordb)

if __name__ == "__main__":
unittest.main()
unittest.main()
Loading

0 comments on commit 938a3f0

Please sign in to comment.