Skip to content

Commit

Permalink
[ERNIEBot Researcher] update llama_index (#311)
Browse files Browse the repository at this point in the history
* update llama_index

* update llama_index

* update llama_index

* update llama_index

* update llama_index

* update llama_index

* update llama_index
  • Loading branch information
qingzhong1 committed Jan 24, 2024
1 parent c2284b3 commit 885addb
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ WeasyPrint==52.5
openai>1.0
langchain_openai
zhon
llama_index
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import json
import os
from typing import Any, List
from typing import Any, Callable, List

import faiss
import jsonlines
import spacy
from langchain.docstore.document import Document
from langchain.text_splitter import SpacyTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader
from llama_index import SimpleDirectoryReader
from llama_index import (
ServiceContext,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from llama_index.vector_stores.faiss import FaissVectorStore

from erniebot_agent.memory import HumanMessage, Message
from erniebot_agent.prompt import PromptTemplate
Expand All @@ -22,6 +32,19 @@
"""


def split_by_sentence_spacy(
    pipeline: str = "zh_core_web_sm", max_length: int = 1_000_000
) -> Callable[[str], List[str]]:
    """Build a sentence-splitting callable backed by a spaCy pipeline.

    Args:
        pipeline: Name of the spaCy model to load (defaults to the small
            Chinese pipeline).
        max_length: Maximum text length (in characters) spaCy will accept.

    Returns:
        A function mapping a text string to the list of its sentence strings.
    """
    # Only sentence segmentation is needed, so exclude NER and the tagger
    # to keep model loading and per-call processing cheap.
    sentencizer = spacy.load(pipeline, exclude=["ner", "tagger"])
    sentencizer.max_length = max_length

    def split(text: str) -> List[str]:
        # Collect sentence texts directly instead of wrapping a generator
        # in a pass-through comprehension.
        return [sent.text for sent in sentencizer(text).sents]

    return split


class GenerateAbstract:
def __init__(self, llm, chunk_size: int = 1500, chunk_overlap=0, path="./abstract.json"):
self.chunk_size = chunk_size
Expand Down Expand Up @@ -175,8 +198,66 @@ def build_index_langchain(


def _embedding_dim(embeddings) -> int:
    """Map a known embedding model name to its FAISS vector dimension.

    Raises:
        ValueError: If ``embeddings.model`` is not a supported model name.
    """
    if embeddings.model == "text-embedding-ada-002":
        return 1536
    if embeddings.model == "ernie-text-embedding":
        return 384
    raise ValueError(f"model {embeddings.model} not support")


def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None):
    """Build (or load from disk) a llama_index vector index backed by FAISS.

    Args:
        index_name: Directory used to persist and reload the index.
        embeddings: Embedding model; its ``model`` attribute selects the
            FAISS vector dimension.
        path: Source data location, used when building from raw documents
            or abstracts.
        url_path: Optional URL metadata forwarded to ``preprocess``.
        abstract: If True, index nodes produced by ``get_abstract_data``.
        origin_data: Pre-chunked langchain documents to index directly.

    Returns:
        A ``VectorStoreIndex`` — freshly built and persisted to
        ``index_name``, or loaded from there when it already exists.

    Raises:
        ValueError: If the embedding model is not supported.
    """
    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(_embedding_dim(embeddings)))

    if os.path.exists(index_name):
        # A persisted index already exists: load it instead of re-embedding.
        vector_store = FaissVectorStore.from_persist_dir(persist_dir=index_name)
        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=index_name)
        service_context = ServiceContext.from_defaults(embed_model=embeddings)
        return load_index_from_storage(storage_context=storage_context, service_context=service_context)

    # Shared build-time setup — previously duplicated verbatim in each of
    # the three build branches below.
    text_splitter = SentenceSplitter(
        chunking_tokenizer_fn=split_by_sentence_spacy(), chunk_size=1024, chunk_overlap=0
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)

    if not abstract and not origin_data:
        documents = preprocess(path, url_path=url_path, use_langchain=False)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            show_progress=True,
            service_context=service_context,
        )
    else:
        if abstract:
            nodes = get_abstract_data(path, use_langchain=False)
        else:
            # Convert pre-chunked langchain documents into llama_index nodes.
            nodes = [TextNode(text=item.page_content, metadata=item.metadata) for item in origin_data]
        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True,
            service_context=service_context,
        )
    index.storage_context.persist(persist_dir=index_name)
    return index


def get_retriver_by_type(frame_type):
Expand Down
2 changes: 1 addition & 1 deletion erniebot-agent/applications/erniebot_researcher/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def generate_report(query, history=[]):
agent_sets = get_agents(
retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
)
team_actor = ResearchTeam(**agent_sets, use_reflection=False)
team_actor = ResearchTeam(**agent_sets, use_reflection=True)
report, path = asyncio.run(team_actor.run(query, args.iterations))
return report, path

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,13 @@ def __init__(
self.threshold = threshold

async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
    """Retrieve documents for *query*, keeping only hits above the score threshold.

    Returns a dict of the form ``{"documents": [{"content": ..., "score": ...,
    ["meta": ...]}, ...]}``; ``meta`` is included only when
    ``self.return_meta_data`` is set.
    """
    results = self.db.as_retriever(similarity_top_k=top_k).retrieve(query)
    documents = []
    for hit in results:
        # Drop low-confidence hits early.
        if not hit.score > self.threshold:
            continue
        entry = {"content": hit.node.text, "score": hit.score}
        if self.return_meta_data:
            entry["meta"] = hit.metadata
        documents.append(entry)
    return {"documents": documents}
3 changes: 2 additions & 1 deletion erniebot-agent/tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ types-beautifulsoup4
lxml
langchain-community
arxiv
responses
responses
llama_index
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from langchain.docstore.document import Document

from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool


class FakeSearch:
    """Stub vector store that always returns one fixed scored document."""

    def similarity_search_with_relevance_scores(self, query: str, top_k: int = 10, **kwargs):
        # Ignore the query entirely and hand back a single canned pair.
        canned = Document(page_content="电动汽车的品牌有哪些?各有什么特点?")
        return [(canned, 0.5)]


@pytest.fixture(scope="module")
def tool():
    """Module-scoped LangChainRetrievalTool wired to the fake store."""
    return LangChainRetrievalTool(FakeSearch())


def test_schema(tool):
    """The generated schema must carry the tool class's description."""
    schema = tool.function_call_schema()
    assert schema["description"] == LangChainRetrievalTool.description


@pytest.mark.asyncio
async def test_tool(tool):
    """Calling the tool surfaces the fake store's single scored document."""
    expected = {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]}
    assert await tool(query="This is a test query") == expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from llama_index.schema import NodeWithScore, TextNode

from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool


class FakeRetrieval:
    """Stub retriever that always yields one fixed scored node."""

    def retrieve(self, query: str):
        # The query is ignored; a single canned node is always returned.
        canned = NodeWithScore(node=TextNode(text="电动汽车的品牌有哪些?各有什么特点?"), score=0.5)
        return [canned]


class FakeSearch:
    """Stub vector store whose retriever is always a FakeRetrieval."""

    def as_retriever(self, similarity_top_k: int = 10, **kwargs):
        # All arguments are accepted and ignored.
        return FakeRetrieval()


@pytest.fixture(scope="module")
def tool():
    """Module-scoped LlamaIndexRetrievalTool wired to the fake store."""
    return LlamaIndexRetrievalTool(FakeSearch())


def test_schema(tool):
    """The generated schema must carry the tool class's description."""
    schema = tool.function_call_schema()
    assert schema["description"] == LlamaIndexRetrievalTool.description


@pytest.mark.asyncio
async def test_tool(tool):
    """Calling the tool surfaces the fake retriever's single scored node."""
    expected = {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]}
    assert await tool(query="This is a test query") == expected

0 comments on commit 885addb

Please sign in to comment.