Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ERNIEBot Researcher] update llama_index #311

Merged
merged 7 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ WeasyPrint==52.5
openai>1.0
langchain_openai
zhon
llama_index
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import json
import os
from typing import Any, List
from typing import Any, Callable, List

import faiss
import jsonlines
import spacy
from langchain.docstore.document import Document
from langchain.text_splitter import SpacyTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader
from llama_index import SimpleDirectoryReader
from llama_index import (
qingzhong1 marked this conversation as resolved.
Show resolved Hide resolved
ServiceContext,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from llama_index.vector_stores.faiss import FaissVectorStore

from erniebot_agent.memory import HumanMessage, Message
from erniebot_agent.prompt import PromptTemplate
Expand All @@ -22,6 +32,19 @@
"""


def split_by_sentence_spacy(
    pipeline: str = "zh_core_web_sm", max_length: int = 1_000_000
) -> Callable[[str], List[str]]:
    """Build a sentence-splitting callable backed by a spaCy pipeline.

    Args:
        pipeline: Name of the spaCy pipeline to load (defaults to the small
            Chinese model).
        max_length: Maximum number of characters spaCy will process without
            raising; raised from spaCy's default to handle long documents.

    Returns:
        A function that maps a text string to a list of sentence strings.
    """
    # NER and the tagger are not needed for sentence segmentation; excluding
    # them makes loading faster and lighter.
    sentencizer = spacy.load(pipeline, exclude=["ner", "tagger"])
    sentencizer.max_length = max_length

    def split(text: str) -> List[str]:
        # Build the list directly instead of wrapping a generator in a
        # redundant pass-through comprehension.
        return [sent.text for sent in sentencizer(text).sents]

    return split


class GenerateAbstract:
def __init__(self, llm, chunk_size: int = 1500, chunk_overlap=0, path="./abstract.json"):
self.chunk_size = chunk_size
Expand Down Expand Up @@ -175,8 +198,66 @@ def build_index_langchain(


def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None):
    """Build, or load from disk, a FAISS-backed llama_index vector index.

    If a persisted index already exists at ``index_name`` it is loaded and
    returned as-is. Otherwise a new index is built from one of three sources
    (checked in this order): raw files under ``path`` (and optionally
    ``url_path``), abstracts extracted from ``path`` when ``abstract`` is
    True, or pre-loaded langchain documents passed via ``origin_data``.
    The freshly built index is persisted to ``index_name`` before returning.

    Args:
        index_name: Directory used to persist/load the index.
        embeddings: Embedding model; its ``model`` attribute selects the
            FAISS vector dimensionality.
        path: Source directory for documents/abstracts.
        url_path: Optional URL source forwarded to ``preprocess``.
        abstract: When True, index abstract nodes instead of full documents.
        origin_data: Optional pre-loaded langchain documents to index.

    Returns:
        A llama_index ``VectorStoreIndex`` (or the loaded persisted index).

    Raises:
        ValueError: If the embedding model is not one of the supported ones.
    """
    # Vector dimensionality must match the embedding model's output size.
    if embeddings.model == "text-embedding-ada-002":
        d = 1536
    elif embeddings.model == "ernie-text-embedding":
        d = 384
    else:
        raise ValueError(f"model {embeddings.model} not support")

    faiss_index = faiss.IndexFlatIP(d)
    vector_store = FaissVectorStore(faiss_index=faiss_index)

    # Fast path: reuse a previously persisted index when one exists on disk.
    if os.path.exists(index_name):
        vector_store = FaissVectorStore.from_persist_dir(persist_dir=index_name)
        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=index_name)
        service_context = ServiceContext.from_defaults(embed_model=embeddings)
        return load_index_from_storage(storage_context=storage_context, service_context=service_context)

    # All three build paths share the same splitter/storage/service setup;
    # create them once instead of duplicating per branch.
    text_splitter = SentenceSplitter(
        chunking_tokenizer_fn=split_by_sentence_spacy(), chunk_size=1024, chunk_overlap=0
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)

    if not abstract and not origin_data:
        # Index full documents read from disk (and optionally URLs).
        documents = preprocess(path, url_path=url_path, use_langchain=False)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            show_progress=True,
            service_context=service_context,
        )
    elif abstract:
        # Index pre-built abstract nodes.
        nodes = get_abstract_data(path, use_langchain=False)
        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True,
            service_context=service_context,
        )
    else:
        # origin_data supplied: wrap langchain documents as llama_index nodes.
        nodes = [TextNode(text=item.page_content, metadata=item.metadata) for item in origin_data]
        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True,
            service_context=service_context,
        )

    index.storage_context.persist(persist_dir=index_name)
    return index


def get_retriver_by_type(frame_type):
Expand Down
2 changes: 1 addition & 1 deletion erniebot-agent/applications/erniebot_researcher/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def generate_report(query, history=[]):
agent_sets = get_agents(
retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
)
team_actor = ResearchTeam(**agent_sets, use_reflection=False)
team_actor = ResearchTeam(**agent_sets, use_reflection=True)
report, path = asyncio.run(team_actor.run(query, args.iterations))
return report, path

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,13 @@ def __init__(
self.threshold = threshold

async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
    """Retrieve documents relevant to *query* from the llama_index store.

    Only results scoring strictly above ``self.threshold`` are kept; each
    kept result is reported as ``{"content": ..., "score": ...}`` with an
    optional ``"meta"`` entry when ``self.return_meta_data`` is set.
    NOTE(review): ``filters`` is currently accepted but unused.
    """
    retriever = self.db.as_retriever(similarity_top_k=top_k)
    scored_nodes = retriever.retrieve(query)

    documents = []
    for scored in scored_nodes:
        if scored.score > self.threshold:
            entry = {"content": scored.node.text, "score": scored.score}
            if self.return_meta_data:
                entry["meta"] = scored.metadata
            documents.append(entry)
    return {"documents": documents}
3 changes: 2 additions & 1 deletion erniebot-agent/tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ types-beautifulsoup4
lxml
langchain-community
arxiv
responses
responses
llama_index
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from langchain.docstore.document import Document

from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool


class FakeSearch:
    """Stub langchain vector store returning one fixed (document, score) pair."""

    def similarity_search_with_relevance_scores(self, query: str, top_k: int = 10, **kwargs):
        # Every query yields the same canned document with relevance 0.5.
        fixed_doc = Document(page_content="电动汽车的品牌有哪些?各有什么特点?")
        return [(fixed_doc, 0.5)]


@pytest.fixture(scope="module")
def tool():
    """Provide a LangChainRetrievalTool backed by the fake store, shared module-wide."""
    return LangChainRetrievalTool(FakeSearch())


def test_schema(tool):
    """The generated function-call schema should carry the tool's description."""
    schema = tool.function_call_schema()
    assert schema["description"] == LangChainRetrievalTool.description


@pytest.mark.asyncio
async def test_tool(tool):
    """Calling the tool should surface the stubbed document, score, and empty meta."""
    expected = {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]}
    assert await tool(query="This is a test query") == expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from llama_index.schema import NodeWithScore, TextNode

from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool


class FakeRetrieval:
    """Stub retriever yielding a single fixed scored text node."""

    def retrieve(self, query: str):
        # Every query returns the same node scored at 0.5.
        node = TextNode(text="电动汽车的品牌有哪些?各有什么特点?")
        return [NodeWithScore(node=node, score=0.5)]


class FakeSearch:
    """Stub llama_index store whose retriever is always a FakeRetrieval."""

    def as_retriever(self, similarity_top_k: int = 10, **kwargs):
        return FakeRetrieval()


@pytest.fixture(scope="module")
def tool():
    """Provide a LlamaIndexRetrievalTool backed by the fake store, shared module-wide."""
    return LlamaIndexRetrievalTool(FakeSearch())


def test_schema(tool):
    """The generated function-call schema should carry the tool's description."""
    schema = tool.function_call_schema()
    assert schema["description"] == LlamaIndexRetrievalTool.description


@pytest.mark.asyncio
async def test_tool(tool):
    """Calling the tool should surface the stubbed node text, score, and empty meta."""
    expected = {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]}
    assert await tool(query="This is a test query") == expected