From 8337c8a7ff9b1ce5c6157cd4b1eb64470f1a1c6c Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Fri, 20 Dec 2024 20:27:27 +0800 Subject: [PATCH] refactor: Workflow execution logic --- apps/application/flow/workflow_manage.py | 38 +++++++++++++----------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 02397ec4216..f8e1931ebec 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -29,7 +29,7 @@ from setting.models import Model from setting.models_provider import get_model_credential -executor = ThreadPoolExecutor(max_workers=50) +executor = ThreadPoolExecutor(max_workers=200) class Edge: @@ -271,7 +271,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl self.current_result = None self.answer = "" self.answer_list = [''] - self.status = 0 + self.status = 200 self.base_to_response = base_to_response self.chat_record = chat_record self.await_future_map = {} @@ -384,8 +384,23 @@ def await_result(self, result): '', True, message_tokens, answer_tokens, {}) def run_chain_async(self, current_node, node_result_future): - future = executor.submit(self.run_chain, current_node, node_result_future) - return future + return executor.submit(self.run_chain_manage, current_node, node_result_future) + + def run_chain_manage(self, current_node, node_result_future): + if current_node is None: + start_node = self.get_start_node() + current_node = get_node(start_node.type)(start_node, self.params, self) + result = self.run_chain(current_node, node_result_future) + node_list = self.get_next_node_list(current_node, result) + if len(node_list) == 1: + self.run_chain_manage(node_list[0], None) + elif len(node_list) > 1: + + # 获取到可执行的子节点 + result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in + node_list] + self.set_await_map(result_list) + [r.get('future').result() for r in result_list] def set_await_map(self, node_run_list): sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y) @@ -395,9 +410,6 @@ def set_await_map(self, node_run_list): for i in range(index)] def run_chain(self, current_node, node_result_future=None): - if current_node is None: - start_node = self.get_start_node() - current_node = get_node(start_node.type)(start_node, self.params, self) if node_result_future is None: node_result_future = self.run_node_future(current_node) try: @@ -409,18 +421,10 @@ def run_chain(self, current_node, node_result_future=None): result = self.hand_event_node_result(current_node, node_result_future) if is_stream else self.hand_node_result( current_node, node_result_future) - with self.lock: - if current_node.status == 500: - return - node_list = self.get_next_node_list(current_node, result) - # 获取到可执行的子节点 - result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list] - self.set_await_map(result_list) - [r.get('future').result() for r in result_list] - if self.status == 0: - self.status = 200 + return result except Exception as e: traceback.print_exc() + return [] def hand_node_result(self, current_node, node_result_future): try: