-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ERNIEBot Researcher] update llama_index (#311)
* update llama_index * update llama_index * update llama_index * update llama_index * update llama_index * update llama_index * update llama_index
- Loading branch information
1 parent
c2284b3
commit 885addb
Showing
7 changed files
with
161 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ WeasyPrint==52.5 | |
openai>1.0 | ||
langchain_openai | ||
zhon | ||
llama_index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ types-beautifulsoup4 | |
lxml | ||
langchain-community | ||
arxiv | ||
responses | ||
responses | ||
llama_index |
29 changes: 29 additions & 0 deletions
29
erniebot-agent/tests/unit_tests/tools/test_langchain_retrieval_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pytest | ||
from langchain.docstore.document import Document | ||
|
||
from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool | ||
|
||
|
||
class FakeSearch: | ||
def similarity_search_with_relevance_scores(self, query: str, top_k: int = 10, **kwargs): | ||
doc = (Document(page_content="电动汽车的品牌有哪些?各有什么特点?"), 0.5) | ||
retrieval_results = [doc] | ||
|
||
return retrieval_results | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def tool(): | ||
db = FakeSearch() | ||
return LangChainRetrievalTool(db) | ||
|
||
|
||
def test_schema(tool): | ||
function_call_schema = tool.function_call_schema() | ||
assert function_call_schema["description"] == LangChainRetrievalTool.description | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_tool(tool): | ||
results = await tool(query="This is a test query") | ||
assert results == {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]} |
33 changes: 33 additions & 0 deletions
33
erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import pytest | ||
from llama_index.schema import NodeWithScore, TextNode | ||
|
||
from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool | ||
|
||
|
||
class FakeRetrieval: | ||
def retrieve(self, query: str): | ||
doc = NodeWithScore(node=TextNode(text="电动汽车的品牌有哪些?各有什么特点?"), score=0.5) | ||
retrieval_results = [doc] | ||
return retrieval_results | ||
|
||
|
||
class FakeSearch: | ||
def as_retriever(self, similarity_top_k: int = 10, **kwargs): | ||
return FakeRetrieval() | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def tool(): | ||
db = FakeSearch() | ||
return LlamaIndexRetrievalTool(db) | ||
|
||
|
||
def test_schema(tool): | ||
function_call_schema = tool.function_call_schema() | ||
assert function_call_schema["description"] == LlamaIndexRetrievalTool.description | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_tool(tool): | ||
results = await tool(query="This is a test query") | ||
assert results == {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]} |