Skip to content

Commit

Permalink
[search agent] Add search info (#195)
Browse files Browse the repository at this point in the history
* Add search info

* Up format

* restore questions

* Update retrieval agents

* Update functional_agent_with_retrieval.py

---------

Co-authored-by: Sijun He <[email protected]>
  • Loading branch information
w5688414 and sijunhe committed Dec 21, 2023
1 parent bbecff8 commit 2520b8d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def offline_ann(data_path, baizhong_db):
res = offline_ann(args.data_path, baizhong_db)
print(res)

llm = ERNIEBot(model="ernie-3.5", api_type="custom")
llm = ERNIEBot(
model="ernie-bot", api_type="aistudio", enable_multi_step_tool_call=True, enable_citation=True
)

retrieval_tool = BaizhongSearchTool(
description="Use Baizhong Search to retrieve documents.", db=baizhong_db, threshold=0.1
Expand All @@ -90,6 +92,7 @@ def offline_ann(data_path, baizhong_db):
# "abcabc",
# ]
queries = [
"今天百度美股的股价是多少?",
"心血管科,高血压可以蒸桑拿吗?",
"量化交易",
"城市景观照明中有过度照明的规定是什么?",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
ToolResponse,
)
from erniebot_agent.file_io.base import File
from erniebot_agent.messages import AIMessage, FunctionMessage, HumanMessage, Message
from erniebot_agent.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
Message,
SearchInfo,
)
from erniebot_agent.prompt import PromptTemplate
from erniebot_agent.retrieval import BaizhongSearch
from erniebot_agent.tools.base import Tool
Expand Down Expand Up @@ -100,6 +106,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A
actions_taken: List[AgentAction] = []
files_involved: List[AgentFile] = []
actions_taken.append(AgentAction(tool_name=self.search_tool.tool_name, tool_args=tool_args))

tool_ret_json = json.dumps(results, ensure_ascii=False)
tool_resp = ToolResponse(json=tool_ret_json, files=[])
llm_resp = await self._async_run_llm_without_hooks(
Expand All @@ -108,6 +115,23 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A
system=self.system_message.content if self.system_message is not None else None,
)
output_message = llm_resp.message
if output_message.search_info is None:
search_info = SearchInfo(results=[])
for index, item in enumerate(docs):
search_info["results"].append(
{
"index": index + 1,
"url": "",
"title": item["title"],
}
)
output_message.search_info = search_info
else:
cur_index = len(output_message.search_info["results"])
for index, item in enumerate(docs):
output_message.search_info["results"].append(
{"index": cur_index + index + 1, "url": "", "title": item["title"]}
)
chat_history.append(output_message)
# Using on_tool_error here since retrieval is formatted as a tool
except (Exception, KeyboardInterrupt) as e:
Expand Down
2 changes: 1 addition & 1 deletion erniebot-agent/src/erniebot_agent/tools/remote_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import base64
import json
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type

import json
import requests

from erniebot_agent.file_io.file_manager import FileManager
Expand Down

0 comments on commit 2520b8d

Please sign in to comment.