diff --git a/erniebot-agent/applications/erniebot_researcher/sample_report_example.py b/erniebot-agent/applications/erniebot_researcher/sample_report_example.py index 151d401c3..d4d1eb127 100644 --- a/erniebot-agent/applications/erniebot_researcher/sample_report_example.py +++ b/erniebot-agent/applications/erniebot_researcher/sample_report_example.py @@ -67,8 +67,8 @@ parser.add_argument( "--embedding_type", type=str, - default="open_embedding", - help="['open_embedding','baizhong','ernie_embedding']", + default="openai_embedding", + help="['openai_embedding','baizhong','ernie_embedding']", ) args = parser.parse_args() @@ -77,7 +77,7 @@ def get_retrievers(): - if args.embedding_type == "open_embedding": + if args.embedding_type == "openai_embedding": embeddings = OpenAIEmbeddings(deployment="text-embedding-ada") paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings) abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings) diff --git a/erniebot-agent/applications/erniebot_researcher/tools/utils.py b/erniebot-agent/applications/erniebot_researcher/tools/utils.py index 182ff6eb2..c14e28096 100644 --- a/erniebot-agent/applications/erniebot_researcher/tools/utils.py +++ b/erniebot-agent/applications/erniebot_researcher/tools/utils.py @@ -17,7 +17,9 @@ from erniebot_agent.agents.callback import LoggingHandler from erniebot_agent.agents.schema import ToolResponse from erniebot_agent.tools.base import BaseTool +from erniebot_agent.utils import config_from_environ as C from erniebot_agent.utils.json import to_pretty_json +from erniebot_agent.utils.logging import ColorFormatter, set_role_color from erniebot_agent.utils.output_style import ColoredContent default_logger = logging.getLogger(__name__) @@ -216,3 +218,22 @@ def parse_json(self, json_str, start_indicator: str = "{", end_indicator: str = corrected_data = json_str[start_idx : end_idx + 1] response = json.loads(corrected_data) return response + + +def setup_logging(log_file_path: str): + logger = logging.getLogger("generate_report") + verbosity = C.get_logging_level() + if verbosity: + numeric_level = getattr(logging, verbosity.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid logging level: {verbosity}") + logger.setLevel(numeric_level) + logger.propagate = False + console_handler = logging.StreamHandler() + console_handler.setFormatter(ColorFormatter("%(levelname)s - %(message)s")) + logger.addHandler(console_handler) + set_role_color() + file_handler = logging.FileHandler(log_file_path) + file_handler.setFormatter(ColorFormatter("%(levelname)s - %(message)s")) + logger.addHandler(file_handler) + return logger diff --git a/erniebot-agent/applications/erniebot_researcher/ui.py b/erniebot-agent/applications/erniebot_researcher/ui.py index 8944f9514..79942440c 100644 --- a/erniebot-agent/applications/erniebot_researcher/ui.py +++ b/erniebot-agent/applications/erniebot_researcher/ui.py @@ -18,13 +18,12 @@ from tools.semantic_citation_tool import SemanticCitationTool from tools.summarization_tool import TextSummarizationTool from tools.task_planning_tool import TaskPlanningTool -from tools.utils import FaissSearch, ReportCallbackHandler, build_index +from tools.utils import FaissSearch, ReportCallbackHandler, build_index, setup_logging from erniebot_agent.chat_models import ERNIEBot from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings from erniebot_agent.memory import SystemMessage from erniebot_agent.retrieval import BaizhongSearch -from erniebot_agent.utils.logging import setup_logging parser = argparse.ArgumentParser() parser.add_argument("--api_type", type=str, default="aistudio") @@ -79,10 +78,10 @@ args = parser.parse_args() os.environ["api_type"] = args.api_type access_token = os.environ.get("EB_AGENT_ACCESS_TOKEN", None) -os.environ["EB_AGENT_LOGGING_FILE"] = args.log_path +# os.environ["EB_AGENT_LOGGING_FILE"] = args.log_path # sh = logging.StreamHandler() # logging.basicConfig(filename=args.log_path, level=logging.INFO) -logger = setup_logging(use_fileformatter=False) +logger = setup_logging(args.log_path) def get_logs(path=args.log_path): @@ -148,8 +147,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path): system_message=SystemMessage("你是一个报告生成助手。你可以根据用户的指定内容生成一份报告手稿"), dir_path=dir_path, report_type=args.report_type, - retriever_abstract_tool=retriever_sets["abstract"], - retriever_tool=retriever_sets["full_text"], + retriever_abstract_db=retriever_sets["abstract"], + retriever_fulltext_db=retriever_sets["full_text"], intent_detection_tool=tool_sets["intent_detection"], task_planning_tool=tool_sets["task_planning"], report_writing_tool=tool_sets["report_writing"], @@ -176,22 +175,20 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path): name="polish", llm=llm, llm_long=llm_long, - faiss_name_citation=args.index_name_citation, + citation_index_name=args.index_name_citation, embeddings=retriever_sets["embeddings"], dir_path=target_path, report_type=args.report_type, citation_tool=tool_sets["semantic_citation"], callbacks=ReportCallbackHandler(logger=logger), ) - team_actor = ResearchTeam( - ranker_actor=ranker_actor, - research_actor=research_actor, - editor_actor=editor_actor, - reviser_actor=reviser_actor, - polish_actor=polish_actor, - use_reflection=True, - ) - return team_actor + return { + "research_actor": research_actor, + "editor_actor": editor_actor, + "reviser_actor": reviser_actor, + "ranker_actor": ranker_actor, + "polish_actor": polish_actor, + } def generate_report(query, history=[]): @@ -203,7 +200,8 @@ def generate_report(query, history=[]): llm_long = ERNIEBot(model="ernie-longtext") retriever_sets = get_retrievers() tool_sets = get_tools(llm, llm_long) - team_actor = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path) + agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path) + team_actor = ResearchTeam(**agent_sets, use_reflection=True) report, path = asyncio.run(team_actor.run(query, args.iterations)) return report, path diff --git a/erniebot-agent/src/erniebot_agent/utils/logging.py b/erniebot-agent/src/erniebot_agent/utils/logging.py index 781804c75..7f4bf8485 100644 --- a/erniebot-agent/src/erniebot_agent/utils/logging.py +++ b/erniebot-agent/src/erniebot_agent/utils/logging.py @@ -158,7 +158,6 @@ def setup_logging( use_standard_format: bool = True, use_file_handler: bool = False, max_log_length: int = 100, - use_fileformatter: bool = True, ): """Configures logging for the ERNIE Bot Agent library. @@ -195,11 +194,7 @@ def setup_logging( log_file_path = "erniebot-agent.log" if use_file_handler or log_file_path: file_handler = logging.FileHandler(log_file_path) - if use_fileformatter: - file_handler.setFormatter(FileFormatter("%(message)s")) - else: - file_handler.setFormatter(ColorFormatter("%(levelname)s - %(message)s")) + file_handler.setFormatter(FileFormatter("%(message)s")) logger.addHandler(file_handler) ColoredContent.set_global_max_length(max_log_length) - return logger