[ERNIEBot Researcher] update langchain #305

Merged: 13 commits, Jan 24, 2024
41 changes: 33 additions & 8 deletions erniebot-agent/applications/erniebot_researcher/README.md
@@ -63,10 +63,35 @@ pip install -r requirements.txt
```
wget https://paddlenlp.bj.bcebos.com/pipelines/fonts/SimSun.ttf
```
> Step 4: Create the index
First, register and log in at the [AI Studio星河社区](https://aistudio.baidu.com/index), then obtain an `Access Token` on the AI Studio [access token page](https://aistudio.baidu.com/index/accessToken), and finally set the environment variables:
```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
export AZURE_OPENAI_API_KEY=<openai-api-token>
export OPENAI_API_KEY=<openai-api-token>
export AZURE_OPENAI_ENDPOINT=<openai-endpoint>
export OPENAI_API_VERSION=<openai-api-version>
```

> Step 4: Run the preprocessing script

If you have URL links, you can pass in a txt or json file that stores them.
In a txt file, each line stores a file path and its corresponding URL, e.g. `https://zhuanlan.zhihu.com/p/659457816 data/Ai_Agent的起源.md`.
In a json file, each key is a file path and its value is that file's URL.
If you do not pass in a URL file, each file's path is used as its URL by default.

You can also pass in your own storage path for the file abstracts. The abstracts must be stored in a json file that holds a list of dicts; each dict has three key-value pairs: "page_content" stores the file's abstract, "url" is the file's URL, and "name" is the article's name.
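As a concrete illustration of the abstract file format described above, the sketch below writes such a json file with Python. The file name `abstract.json` and the example contents are invented for illustration; only the three keys are prescribed by the README.

```python
import json

# Hypothetical example of the abstract file: a json list of dicts,
# each carrying exactly the three keys the preprocessing step expects.
abstracts = [
    {
        "page_content": "A short summary of the article goes here.",
        "url": "https://zhuanlan.zhihu.com/p/659457816",
        "name": "Ai_Agent的起源",
    },
]

# Write the file with non-ASCII text preserved.
with open("abstract.json", "w", encoding="utf-8") as f:
    json.dump(abstracts, f, ensure_ascii=False, indent=2)
```

The resulting path could then be passed to `preprocessing.py` via `--path_abstract`.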

```
python ./tools/preprocessing.py \
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text> \
--path_full_text <the folder path of your full text> \
--url_path <the path of your url text> \
--path_abstract <the json path of your abstract text>
```

> Step 5: Run

First, register and log in at the [AI Studio星河社区](https://aistudio.baidu.com/index), then obtain an `Access Token` on the AI Studio [access token page](https://aistudio.baidu.com/index/accessToken), and finally set the environment variables:

```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
```

@@ -78,23 +103,23 @@

Run the Base version example:

```
python sample_report_example.py --num_research_agent 2 \
-    --index_name_full_text <your full text> \
-    --index_name_abstract <your abstract text>
+    --index_name_full_text <the index name of your full text> \
+    --index_name_abstract <the index name of your abstract text>
```

Run the Base version WebUI:

```
python ui.py --num_research_agent 2 \
-    --index_name_full_text <your full text> \
-    --index_name_abstract <your abstract text>
+    --index_name_full_text <the index name of your full text> \
+    --index_name_abstract <the index name of your abstract text>
```

Run the advanced multi-agent automatic orchestration example script:

```
-python sample_group_agent.py --index_name_full_text <your full text> \
-    --index_name_abstract <your abstract text>
+python sample_group_agent.py --index_name_full_text <the index name of your full text> \
+    --index_name_abstract <the index name of your abstract text>
```

## Reference
@@ -9,6 +9,7 @@
from erniebot_agent.prompt import PromptTemplate

logger = logging.getLogger(__name__)
+MAX_RETRY = 10
PLAN_VERIFICATIONS_PROMPT = """
为了验证给出的内容中数字性表述的准确性,您需要创建一系列验证问题,
用于测试原始基线响应中的事实主张。例如,如果长格式响应的一部分包含
@@ -154,8 +155,8 @@ async def verifications(self, facts_problems: List[dict]):
        for item in facts_problems:
            question = item["question"]
            claim = item["fact"]
-            context = self.retriever_db.search(question)
-            context = [i["content"] for i in context]
+            context = await self.retriever_db(question)
+            context = [i["content"] for i in context["documents"]]
            item["evidence"] = context
            anwser = await self.generate_anwser(question, context)
            item["anwser"] = anwser
@@ -222,11 +223,24 @@ async def report_fact(self, report: str):
            messages: List[Message] = [
                HumanMessage(content=self.prompt_plan_verifications.format(base_context=item))
            ]
-            responese = await self.llm.chat(messages)
-            result: List[dict] = self.parse_json(responese.content)
-            fact_check_result: List[dict] = await self.verifications(result)
-            new_item: str = await self.generate_final_response(item, fact_check_result)
-            text.append(new_item)
+            retry_count = 0
+            while True:
+                try:
+                    responese = await self.llm.chat(messages)
+                    result: List[dict] = self.parse_json(responese.content)
+                    fact_check_result: List[dict] = await self.verifications(result)
+                    new_item: str = await self.generate_final_response(item, fact_check_result)
+                    text.append(new_item)
+                    break
+                except Exception as e:
+                    retry_count += 1
+                    logger.error(e)
+                    if retry_count > MAX_RETRY:
+                        raise Exception(
+                            f"Failed to edit research for {report} after {MAX_RETRY} times."
+                        )
+                    continue

        else:
            text.append(item)
        return "\n\n".join(text)
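The retry loop introduced above is a generic retry-with-cap pattern. A minimal self-contained sketch of the same idea, with invented names and a synchronous callable standing in for the agent's async LLM call:

```python
import logging

logger = logging.getLogger(__name__)
MAX_RETRY = 10


def call_with_retry(fn, max_retry=MAX_RETRY):
    """Call fn until it succeeds, giving up after max_retry failures."""
    retry_count = 0
    while True:
        try:
            return fn()
        except Exception as e:
            # Log the failure and retry until the cap is exceeded.
            retry_count += 1
            logger.error(e)
            if retry_count > max_retry:
                raise Exception(f"Failed after {max_retry} retries.") from e
```

One design note: counting failures rather than attempts means the callable runs at most `max_retry + 1` times, matching the `retry_count > MAX_RETRY` check in the diff.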
12 changes: 11 additions & 1 deletion erniebot-agent/applications/erniebot_researcher/polish_agent.py
@@ -38,6 +38,8 @@ def __init__(
        citation_index_name: str,
        dir_path: str,
        report_type: str,
+        build_index_function: Any,
+        search_tool: Any,
        system_message: Optional[SystemMessage] = None,
        callbacks=None,
    ):
@@ -58,6 +60,8 @@ def __init__(
        self.prompt_template_polish = PromptTemplate(
            template=self.template_polish, input_variables=["content"]
        )
+        self.build_index_function = build_index_function
+        self.search_tool = search_tool
        if callbacks is None:
            self._callback_manager = ReportCallbackHandler()
        else:
@@ -143,7 +147,13 @@ async def _run(self, report, summarize=None):
        final_report = await self.polish_paragraph(report, abstract, key)
        await self._callback_manager.on_tool_start(self, tool=self.citation_tool, input_args=final_report)
        if summarize is not None:
-            citation_search = add_citation(summarize, self.citation_index_name, self.embeddings)
+            citation_search = add_citation(
+                summarize,
+                self.citation_index_name,
+                self.embeddings,
+                self.build_index_function,
+                self.search_tool,
+            )
            final_report, path = await self.citation_tool(
                report=final_report,
                agent_name=self.name,
@@ -4,3 +4,5 @@ langchain
scikit-learn
markdown
WeasyPrint==52.5
+openai
+langchain_openai
18 changes: 7 additions & 11 deletions erniebot-agent/applications/erniebot_researcher/research_agent.py
@@ -75,25 +75,21 @@ def __init__(

    async def run_search_summary(self, query: str):
        responses = []
-        url_dict = {}
-        results = self.retriever_fulltext_db.search(query, top_k=3)
+        results = await self.retriever_fulltext_db(query, top_k=3)
        length_limit = 0
        await self._callback_manager.on_tool_start(agent=self, tool=self.summarize_tool, input_args=query)
-        for doc in results:
+        for doc in results["documents"]:
            res = await self.summarize_tool(doc["content"], query)
            # Add reference to avoid hallucination
-            data = {"summary": res, "url": doc["url"], "name": doc["title"]}
+            data = {"summary": res, "url": doc["meta"]["url"], "name": doc["meta"]["name"]}
            length_limit += len(res)
            if length_limit < SUMMARIZE_MAX_LENGTH:
                responses.append(data)
-                key = doc["title"]
-                value = doc["url"]
-                url_dict[key] = value
            else:
                logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
                break
        await self._callback_manager.on_tool_end(self, tool=self.summarize_tool, response=responses)
-        return responses, url_dict
+        return responses
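The change above switches from a synchronous `retriever.search(query)` call to awaiting the retriever itself, which now returns a dict whose `documents` list carries metadata under a `meta` key. Below is a runnable sketch of consuming that shape; the stub retriever and its contents are invented for illustration and stand in for the real retrieval tool:

```python
import asyncio


# Invented stand-in for the awaitable retriever interface:
# calling it returns {"documents": [{"content": ..., "meta": {...}}]}.
async def fake_retriever(query: str, top_k: int = 3):
    docs = [
        {
            "content": "Full text of a retrieved document.",
            "meta": {"url": "https://example.com/doc1", "name": "Doc One"},
        },
    ]
    return {"documents": docs[:top_k]}


async def collect(query: str):
    # Mirror run_search_summary: read content plus the meta fields.
    results = await fake_retriever(query, top_k=3)
    return [
        {"summary": doc["content"], "url": doc["meta"]["url"], "name": doc["meta"]["name"]}
        for doc in results["documents"]
    ]


data = asyncio.run(collect("origin of ai agents"))
```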

async def run(self, query: str):
"""
@@ -117,8 +113,8 @@ async def run(self, query: str):

        if self.use_context_planning:
            sub_queries = []
-            res = self.retriever_abstract_db.search(query, top_k=3)
-            context = [item["content"] for item in res]
+            res = await self.retriever_abstract_db(query, top_k=3)
+            context = [item["content"] for item in res["documents"]]
            context_content = ""
            await self._callback_manager.on_tool_start(
                agent=self, tool=self.task_planning_tool, input_args=query
@@ -157,7 +153,7 @@ async def run(self, query: str):
        # Run Sub-Queries
        paragraphs_item = []
        for sub_query in sub_queries:
-            research_result, url_dict = await self.run_search_summary(sub_query)
+            research_result = await self.run_search_summary(sub_query)
            paragraphs_item.extend(research_result)

paragraphs = []
@@ -6,19 +6,19 @@

from editor_actor_agent import EditorActorAgent
from group_agent import GroupChat, GroupChatManager
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
from reviser_actor_agent import ReviserActorAgent
from tools.intent_detection_tool import IntentDetectionTool
from tools.outline_generation_tool import OutlineGenerationTool
+from tools.preprocessing import get_retriver_by_type
from tools.ranking_tool import TextRankingTool
from tools.report_writing_tool import ReportWritingTool
from tools.semantic_citation_tool import SemanticCitationTool
from tools.summarization_tool import TextSummarizationTool
from tools.task_planning_tool import TaskPlanningTool
-from tools.utils import FaissSearch, build_index

from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
@@ -71,22 +71,29 @@
    default="openai_embedding",
    help="['openai_embedding','baizhong','ernie_embedding']",
)
+parser.add_argument(
+    "--use_frame",
+    type=str,
+    default="langchain",
+    choices=["langchain", "llama_index"],
+    help="['langchain','llama_index']",
+)
args = parser.parse_args()


-def get_retrievers():
+def get_retrievers(build_index_function, retrieval_tool):
    if args.embedding_type == "openai_embedding":
-        embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
+        retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
    elif args.embedding_type == "ernie_embedding":
        embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
+        retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
    elif args.embedding_type == "baizhong":
        embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
        retriever_search = BaizhongSearch(
@@ -102,7 +109,9 @@ def get_retrievers():
return {"full_text": retriever_search, "abstract": abstract_search, "embeddings": embeddings}


-def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
+def get_agents(
+    retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
+):
    research_actor = ResearchAgent(
        name="generate_report",
        system_message=SystemMessage("你是一个报告生成助手。你可以根据用户的指定内容生成一份报告手稿"),
@@ -134,6 +143,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
        dir_path=target_path,
        report_type=args.report_type,
        citation_tool=tool_sets["semantic_citation"],
+        build_index_function=build_index_function,
+        search_tool=retrieval_tool,
    )
    return {
        "research_agents": research_actor,
@@ -171,10 +182,12 @@ def main(query):
    os.makedirs(target_path, exist_ok=True)
    llm_long = ERNIEBot(model="ernie-longtext")
    llm = ERNIEBot(model="ernie-4.0")
-
-    retriever_sets = get_retrievers()
+    build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
+    retriever_sets = get_retrievers(build_index_function, retrieval_tool)
    tool_sets = get_tools(llm, llm_long)
-    agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path)
+    agent_sets = get_agents(
+        retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
+    )
    research_actor = agent_sets["research_agents"]
    report = asyncio.run(research_actor.run(query))
    report = {"report": report[0], "paragraphs": report[1]}