From 2520b8d482a36e39fcb2287614197a981ecb6b79 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 21 Dec 2023 22:16:23 +0800 Subject: [PATCH] [search agent] Add search info (#195) * Add search info * Up format * restore questions * Update retrieval agents * Update functional_agent_with_retrieval.py --------- Co-authored-by: Sijun He --- ...functional_agent_with_retrieval_example.py | 5 +++- .../agents/functional_agent_with_retrieval.py | 26 ++++++++++++++++++- .../src/erniebot_agent/tools/remote_tool.py | 2 +- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/erniebot-agent/examples/functional_agent_with_retrieval_example.py b/erniebot-agent/examples/functional_agent_with_retrieval_example.py index ce877d3cc..330da95c5 100644 --- a/erniebot-agent/examples/functional_agent_with_retrieval_example.py +++ b/erniebot-agent/examples/functional_agent_with_retrieval_example.py @@ -71,7 +71,9 @@ def offline_ann(data_path, baizhong_db): res = offline_ann(args.data_path, baizhong_db) print(res) - llm = ERNIEBot(model="ernie-3.5", api_type="custom") + llm = ERNIEBot( + model="ernie-bot", api_type="aistudio", enable_multi_step_tool_call=True, enable_citation=True + ) retrieval_tool = BaizhongSearchTool( description="Use Baizhong Search to retrieve documents.", db=baizhong_db, threshold=0.1 @@ -90,6 +92,7 @@ def offline_ann(data_path, baizhong_db): # "abcabc", # ] queries = [ + "今天百度美股的股价是多少?", "心血管科,高血压可以蒸桑拿吗?", "量化交易", "城市景观照明中有过度照明的规定是什么?", diff --git a/erniebot-agent/src/erniebot_agent/agents/functional_agent_with_retrieval.py b/erniebot-agent/src/erniebot_agent/agents/functional_agent_with_retrieval.py index a155e2551..10fe82c68 100644 --- a/erniebot-agent/src/erniebot_agent/agents/functional_agent_with_retrieval.py +++ b/erniebot-agent/src/erniebot_agent/agents/functional_agent_with_retrieval.py @@ -13,7 +13,13 @@ ToolResponse, ) from erniebot_agent.file_io.base import File -from erniebot_agent.messages import AIMessage, FunctionMessage, HumanMessage, Message +from erniebot_agent.messages import ( + AIMessage, + FunctionMessage, + HumanMessage, + Message, + SearchInfo, +) from erniebot_agent.prompt import PromptTemplate from erniebot_agent.retrieval import BaizhongSearch from erniebot_agent.tools.base import Tool @@ -100,6 +106,7 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A actions_taken: List[AgentAction] = [] files_involved: List[AgentFile] = [] actions_taken.append(AgentAction(tool_name=self.search_tool.tool_name, tool_args=tool_args)) + tool_ret_json = json.dumps(results, ensure_ascii=False) tool_resp = ToolResponse(json=tool_ret_json, files=[]) llm_resp = await self._async_run_llm_without_hooks( @@ -108,6 +115,23 @@ async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> A system=self.system_message.content if self.system_message is not None else None, ) output_message = llm_resp.message + if output_message.search_info is None: + search_info = SearchInfo(results=[]) + for index, item in enumerate(docs): + search_info["results"].append( + { + "index": index + 1, + "url": "", + "title": item["title"], + } + ) + output_message.search_info = search_info + else: + cur_index = len(output_message.search_info["results"]) + for index, item in enumerate(docs): + output_message.search_info["results"].append( + {"index": cur_index + index + 1, "url": "", "title": item["title"]} + ) chat_history.append(output_message) # Using on_tool_error here since retrieval is formatted as a tool except (Exception, KeyboardInterrupt) as e: diff --git a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py index 355f35178..f36b1d265 100644 --- a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py @@ -1,10 +1,10 @@ from __future__ import annotations import base64 +import json from copy import deepcopy from typing import Any, Dict, List, Optional, Type -import json import requests from erniebot_agent.file_io.file_manager import FileManager