Skip to content

Commit

Permalink
refactor: Workflow execution logic
Browse files Browse the repository at this point in the history
  • Loading branch information
shaohuzhang1 committed Dec 20, 2024
1 parent c000ee4 commit 8337c8a
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from setting.models import Model
from setting.models_provider import get_model_credential

executor = ThreadPoolExecutor(max_workers=50)
executor = ThreadPoolExecutor(max_workers=200)


class Edge:
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
self.current_result = None
self.answer = ""
self.answer_list = ['']
self.status = 0
self.status = 200
self.base_to_response = base_to_response
self.chat_record = chat_record
self.await_future_map = {}
Expand Down Expand Up @@ -384,8 +384,23 @@ def await_result(self, result):
'', True, message_tokens, answer_tokens, {})

def run_chain_async(self, current_node, node_result_future):
future = executor.submit(self.run_chain, current_node, node_result_future)
return future
return executor.submit(self.run_chain_manage, current_node, node_result_future)

def run_chain_manage(self, current_node, node_result_future):
if current_node is None:
start_node = self.get_start_node()
current_node = get_node(start_node.type)(start_node, self.params, self)
result = self.run_chain(current_node, node_result_future)
node_list = self.get_next_node_list(current_node, result)
if len(node_list) == 1:
self.run_chain_manage(node_list[0], None)
elif len(node_list) > 1:

# 获取到可执行的子节点
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
node_list]
self.set_await_map(result_list)
[r.get('future').result() for r in result_list]

def set_await_map(self, node_run_list):
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
Expand All @@ -395,9 +410,6 @@ def set_await_map(self, node_run_list):
for i in range(index)]

def run_chain(self, current_node, node_result_future=None):
if current_node is None:
start_node = self.get_start_node()
current_node = get_node(start_node.type)(start_node, self.params, self)
if node_result_future is None:
node_result_future = self.run_node_future(current_node)
try:
Expand All @@ -409,18 +421,10 @@ def run_chain(self, current_node, node_result_future=None):
result = self.hand_event_node_result(current_node,
node_result_future) if is_stream else self.hand_node_result(
current_node, node_result_future)
with self.lock:
if current_node.status == 500:
return
node_list = self.get_next_node_list(current_node, result)
# 获取到可执行的子节点
result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
self.set_await_map(result_list)
[r.get('future').result() for r in result_list]
if self.status == 0:
self.status = 200
return result
except Exception as e:
traceback.print_exc()
return []

def hand_node_result(self, current_node, node_result_future):
try:
Expand Down

0 comments on commit 8337c8a

Please sign in to comment.