[ERNIEBot Researcher] update langchain (#305)
* update langchain
qingzhong1 committed Jan 24, 2024
1 parent e7bfe4f commit c2284b3
Showing 13 changed files with 548 additions and 162 deletions.
46 changes: 39 additions & 7 deletions erniebot-agent/applications/erniebot_researcher/README.md
@@ -64,10 +64,42 @@ pip install -r requirements.txt
wget https://paddlenlp.bj.bcebos.com/pipelines/fonts/SimSun.ttf
```

> Step 4: Run
> Step 4: Create the index
Download the example data:

```
wget https://paddlenlp.bj.bcebos.com/pipelines/erniebot_researcher_example.tar.gz
tar xvf erniebot_researcher_example.tar.gz
```

First, register and log in at the [AI Studio 星河社区](https://aistudio.baidu.com/index), then obtain an `Access Token` from AI Studio's [access token page](https://aistudio.baidu.com/index/accessToken), and finally set the environment variables:

```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
```

If you have URLs for your files, you can pass in a txt file that stores them.
Each line of the txt file holds a URL and the path of the corresponding file, for example:
'https://zhuanlan.zhihu.com/p/659457816 erniebot_researcher_example/Ai_Agent的起源.md'

If no URL file is passed in, each file's own path is used as its URL by default.

You can also pass in your own path to a file of document abstracts. The abstracts must be stored as a JSON file holding a list of dictionaries, where each dictionary has three key-value pairs: "page_content" stores the document's abstract, "url" is the document's URL, and "name" is the document's title. For example:
[{"page_content": "article abstract", "url": "https://zhuanlan.zhihu.com/p/659457816", "name": "Ai_Agent的起源"}, ...]
```
python ./tools/preprocessing.py \
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text> \
--path_full_text <the folder path of your full text> \
--url_path <the path of your url text> \
--path_abstract <the json path of your abstract text>
```

> Step 5: Run

```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
@@ -78,23 +110,23 @@ Running the Base version example:

```
python sample_report_example.py --num_research_agent 2 \
--index_name_full_text <your full text> \
--index_name_abstract <your abstract text>
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text>
```

Running the Base version WebUI:

```
python ui.py --num_research_agent 2 \
--index_name_full_text <your full text> \
--index_name_abstract <your abstract text>
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text>
```

Running the advanced multi-agent automatic orchestration example script:

```
python sample_group_agent.py --index_name_full_text <your full text> \
--index_name_abstract <your abstract text>
python sample_group_agent.py --index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text>
```

## Reference
@@ -9,6 +9,7 @@
from erniebot_agent.prompt import PromptTemplate

logger = logging.getLogger(__name__)
MAX_RETRY = 10
PLAN_VERIFICATIONS_PROMPT = """
为了验证给出的内容中数字性表述的准确性,您需要创建一系列验证问题,
用于测试原始基线响应中的事实主张。例如,如果长格式响应的一部分包含
@@ -154,8 +155,8 @@ async def verifications(self, facts_problems: List[dict]):
for item in facts_problems:
question = item["question"]
claim = item["fact"]
context = self.retriever_db.search(question)
context = [i["content"] for i in context]
context = await self.retriever_db(question)
context = [i["content"] for i in context["documents"]]
item["evidence"] = context
anwser = await self.generate_anwser(question, context)
item["anwser"] = anwser
@@ -222,11 +223,24 @@ async def report_fact(self, report: str):
messages: List[Message] = [
HumanMessage(content=self.prompt_plan_verifications.format(base_context=item))
]
responese = await self.llm.chat(messages)
result: List[dict] = self.parse_json(responese.content)
fact_check_result: List[dict] = await self.verifications(result)
new_item: str = await self.generate_final_response(item, fact_check_result)
text.append(new_item)
retry_count = 0
while True:
try:
responese = await self.llm.chat(messages)
result: List[dict] = self.parse_json(responese.content)
fact_check_result: List[dict] = await self.verifications(result)
new_item: str = await self.generate_final_response(item, fact_check_result)
text.append(new_item)
break
except Exception as e:
retry_count += 1
logger.error(e)
if retry_count > MAX_RETRY:
raise Exception(
f"Failed to edit research for {report} after {MAX_RETRY} times."
)
continue

else:
text.append(item)
return "\n\n".join(text)
14 changes: 12 additions & 2 deletions erniebot-agent/applications/erniebot_researcher/polish_agent.py
@@ -38,6 +38,8 @@ def __init__(
citation_index_name: str,
dir_path: str,
report_type: str,
build_index_function: Any,
search_tool: Any,
system_message: Optional[SystemMessage] = None,
callbacks=None,
):
@@ -58,6 +60,8 @@ def __init__(
self.prompt_template_polish = PromptTemplate(
template=self.template_polish, input_variables=["content"]
)
self.build_index_function = build_index_function
self.search_tool = search_tool
if callbacks is None:
self._callback_manager = ReportCallbackHandler()
else:
@@ -143,13 +147,19 @@ async def _run(self, report, summarize=None):
final_report = await self.polish_paragraph(report, abstract, key)
await self._callback_manager.on_tool_start(self, tool=self.citation_tool, input_args=final_report)
if summarize is not None:
citation_search = add_citation(summarize, self.citation_index_name, self.embeddings)
citation_search = add_citation(
summarize,
self.citation_index_name,
self.embeddings,
self.build_index_function,
self.search_tool,
)
final_report, path = await self.citation_tool(
report=final_report,
agent_name=self.name,
report_type=self.report_type,
dir_path=self.dir_path,
citation_faiss_research=citation_search,
citation_research=citation_search,
)
path = write_md_to_pdf(self.report_type, self.dir_path, final_report)
await self._callback_manager.on_tool_end(
@@ -4,3 +4,6 @@ langchain
scikit-learn
markdown
WeasyPrint==52.5
openai>1.0
langchain_openai
zhon
18 changes: 7 additions & 11 deletions erniebot-agent/applications/erniebot_researcher/research_agent.py
@@ -75,25 +75,21 @@

async def run_search_summary(self, query: str):
responses = []
url_dict = {}
results = self.retriever_fulltext_db.search(query, top_k=3)
results = await self.retriever_fulltext_db(query, top_k=3)
length_limit = 0
await self._callback_manager.on_tool_start(agent=self, tool=self.summarize_tool, input_args=query)
for doc in results:
for doc in results["documents"]:
res = await self.summarize_tool(doc["content"], query)
# Add reference to avoid hallucination
data = {"summary": res, "url": doc["url"], "name": doc["title"]}
data = {"summary": res, "url": doc["meta"]["url"], "name": doc["meta"]["name"]}
length_limit += len(res)
if length_limit < SUMMARIZE_MAX_LENGTH:
responses.append(data)
key = doc["title"]
value = doc["url"]
url_dict[key] = value
else:
logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
break
await self._callback_manager.on_tool_end(self, tool=self.summarize_tool, response=responses)
return responses, url_dict
return responses

async def run(self, query: str):
"""
@@ -117,8 +113,8 @@ async def run(self, query: str):

if self.use_context_planning:
sub_queries = []
res = self.retriever_abstract_db.search(query, top_k=3)
context = [item["content"] for item in res]
res = await self.retriever_abstract_db(query, top_k=3)
context = [item["content"] for item in res["documents"]]
context_content = ""
await self._callback_manager.on_tool_start(
agent=self, tool=self.task_planning_tool, input_args=query
@@ -157,7 +153,7 @@ async def run(self, query: str):
# Run Sub-Queries
paragraphs_item = []
for sub_query in sub_queries:
research_result, url_dict = await self.run_search_summary(sub_query)
research_result = await self.run_search_summary(sub_query)
paragraphs_item.extend(research_result)

paragraphs = []
@@ -6,19 +6,19 @@

from editor_actor_agent import EditorActorAgent
from group_agent import GroupChat, GroupChatManager
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
from reviser_actor_agent import ReviserActorAgent
from tools.intent_detection_tool import IntentDetectionTool
from tools.outline_generation_tool import OutlineGenerationTool
from tools.preprocessing import get_retriver_by_type
from tools.ranking_tool import TextRankingTool
from tools.report_writing_tool import ReportWritingTool
from tools.semantic_citation_tool import SemanticCitationTool
from tools.summarization_tool import TextSummarizationTool
from tools.task_planning_tool import TaskPlanningTool
from tools.utils import FaissSearch, build_index

from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
@@ -71,22 +71,29 @@
default="openai_embedding",
help="['openai_embedding','baizhong','ernie_embedding']",
)
parser.add_argument(
"--framework",
type=str,
default="langchain",
choices=["langchain", "llama_index"],
help="['langchain','llama_index']",
)
args = parser.parse_args()


def get_retrievers():
def get_retrievers(build_index_function, retrieval_tool):
if args.embedding_type == "openai_embedding":
embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
retriever_search = FaissSearch(paper_db, embeddings=embeddings)
embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
fulltext_db = build_index_function(index_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index_function(index_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
retriever_search = retrieval_tool(fulltext_db, embeddings=embeddings)
elif args.embedding_type == "ernie_embedding":
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
retriever_search = FaissSearch(paper_db, embeddings=embeddings)
fulltext_db = build_index_function(index_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index_function(index_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
retriever_search = retrieval_tool(fulltext_db, embeddings=embeddings)
elif args.embedding_type == "baizhong":
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
retriever_search = BaizhongSearch(
@@ -102,7 +109,9 @@ def get_retrievers():
return {"full_text": retriever_search, "abstract": abstract_search, "embeddings": embeddings}


def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
def get_agents(
retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
):
research_actor = ResearchAgent(
name="generate_report",
system_message=SystemMessage("你是一个报告生成助手。你可以根据用户的指定内容生成一份报告手稿"),
@@ -134,6 +143,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
dir_path=target_path,
report_type=args.report_type,
citation_tool=tool_sets["semantic_citation"],
build_index_function=build_index_function,
search_tool=retrieval_tool,
)
return {
"research_agents": research_actor,
@@ -171,10 +182,12 @@ def main(query):
os.makedirs(target_path, exist_ok=True)
llm_long = ERNIEBot(model="ernie-longtext")
llm = ERNIEBot(model="ernie-4.0")

retriever_sets = get_retrievers()
build_index_function, retrieval_tool = get_retriver_by_type(args.framework)
retriever_sets = get_retrievers(build_index_function, retrieval_tool)
tool_sets = get_tools(llm, llm_long)
agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path)
agent_sets = get_agents(
retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
)
research_actor = agent_sets["research_agents"]
report = asyncio.run(research_actor.run(query))
report = {"report": report[0], "paragraphs": report[1]}