-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3a2d529
commit 80dc5ad
Showing
2 changed files
with
60 additions
and
0 deletions.
There are no files selected for viewing
29 changes: 29 additions & 0 deletions
29
erniebot-agent/tests/unit_tests/tools/test_langchain_retrieval_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pytest | ||
from langchain.docstore.document import Document | ||
|
||
from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool | ||
|
||
|
||
class FakeSearch: | ||
def similarity_search_with_relevance_scores(self, query: str, top_k: int = 10, **kwargs): | ||
doc = (Document(page_content="电动汽车的品牌有哪些?各有什么特点?"), 0.5) | ||
retrieval_results = [doc] | ||
|
||
return retrieval_results | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def tool(): | ||
db = FakeSearch() | ||
return LangChainRetrievalTool(db) | ||
|
||
|
||
def test_schema(tool): | ||
function_call_schema = tool.function_call_schema() | ||
assert function_call_schema["description"] == LangChainRetrievalTool.description | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_tool(tool): | ||
results = await tool(query="This is a test query") | ||
assert results == {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]} |
31 changes: 31 additions & 0 deletions
31
erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import pytest | ||
from llama_index import Document | ||
|
||
from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool | ||
|
||
|
||
class FakeSearch: | ||
def as_retriever(self, similarity_top_k: int = 10, **kwargs): | ||
def retrieve(query: str): | ||
doc = (Document(text="电动汽车的品牌有哪些?各有什么特点?"), 0.5) | ||
retrieval_results = [doc] | ||
return retrieval_results | ||
|
||
return retrieve | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def tool(): | ||
db = FakeSearch() | ||
return LlamaIndexRetrievalTool(db) | ||
|
||
|
||
def test_schema(tool): | ||
function_call_schema = tool.function_call_schema() | ||
assert function_call_schema["description"] == LlamaIndexRetrievalTool.description | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_tool(tool): | ||
results = await tool(query="This is a test query") | ||
assert results == {"documents": [{"content": "电动汽车的品牌有哪些?各有什么特点?", "score": 0.5, "meta": {}}]} |