1
+ import atexit
1
2
import inspect
2
3
import threading
3
4
import asyncio
7
8
from .StageFunction import StageFunctionMixin
8
9
9
10
class BaseStage :
10
- def __init__ (self , max_workers = 5 , max_concurrent_tasks = None , on_error = None , close_when_exception = False ):
11
- self ._max_workers = max_workers
11
+ _global_executor = None
12
+ _global_max_workers = 5
13
+ _executor_lock = threading .RLock ()
14
+
15
+ @staticmethod
16
+ def _get_global_executor ():
17
+ if BaseStage ._global_executor is None :
18
+ with BaseStage ._executor_lock :
19
+ BaseStage ._global_executor = ThreadPoolExecutor (max_workers = BaseStage ._global_max_workers )
20
+ return BaseStage ._global_executor
21
+
22
+ @staticmethod
23
+ def set_global_max_workers (global_max_workers ):
24
+ BaseStage ._global_max_workers = global_max_workers
25
+ with BaseStage ._executor_lock :
26
+ if BaseStage ._global_executor :
27
+ BaseStage ._global_executor .shutdown (wait = True )
28
+ BaseStage ._global_executor = ThreadPoolExecutor (max_workers = BaseStage ._global_max_workers )
29
+ atexit .register (BaseStage ._global_executor .shutdown )
30
+
31
+ def __init__ (self , private_max_workers = 5 , max_concurrent_tasks = None , on_error = None , is_daemon = False ):
32
+ self ._private_max_workers = private_max_workers
12
33
self ._max_concurrent_tasks = max_concurrent_tasks
13
34
self ._on_error = on_error
35
+ self ._is_daemon = is_daemon
14
36
self ._semaphore = None
15
37
self ._loop_thread = None
16
38
self ._loop = None
17
- self ._executor = None
18
- self ._responses = []
19
- #self._initialize()
20
-
21
- def _initialize (self ):
39
+ self ._current_executor = None
22
40
self ._loop_ready = threading .Event ()
23
- self ._loop_thread = threading .Thread (target = self ._start_loop )
24
- self ._loop_thread .start ()
25
- self ._executor = ThreadPoolExecutor (max_workers = self ._max_workers )
26
- self ._loop_ready .wait ()
27
- del self ._loop_ready
41
+ self ._responses = set ()
42
+ self ._closed = False
43
+ if self ._is_daemon :
44
+ atexit .register (self .close )
45
+ self ._initialize ()
46
+
47
+ def __enter__ (self ):
48
+ self ._initialize ()
49
+ return self
50
+
51
+ def __exit__ (self , type , value , traceback ):
52
+ self .close ()
53
+ if type is not None and self ._on_error is not None :
54
+ self ._on_error (value )
55
+ return False
28
56
29
- def _loop_exception_handler (self , loop , context ):
30
- if self ._on_error is not None :
31
- loop .call_soon_threadsafe (self ._on_error , context ["exception" ])
57
+ @property
58
+ def _executor (self ):
59
+ if self ._current_executor is not None :
60
+ return self ._current_executor
61
+ if self ._private_max_workers :
62
+ self ._current_executor = ThreadPoolExecutor (max_workers = self ._private_max_workers )
63
+ return self ._current_executor
32
64
else :
33
- raise context ["exception" ]
65
+ self ._current_executor = BaseStage ._get_global_executor ()
66
+ return self ._current_executor
34
67
68
+ def _initialize (self ):
69
+ self ._closed = False
70
+ if (
71
+ not self ._loop_thread
72
+ or not self ._loop_thread .is_alive ()
73
+ or not self ._loop
74
+ or not self ._loop .is_running ()
75
+ ):
76
+ self ._loop_thread = threading .Thread (target = self ._start_loop , daemon = self ._is_daemon )
77
+ self ._loop_thread .start ()
78
+ self ._loop_ready .wait ()
79
+
35
80
def _start_loop (self ):
36
81
self ._loop = asyncio .new_event_loop ()
37
82
self ._loop .set_exception_handler (self ._loop_exception_handler )
38
- asyncio .set_event_loop (self ._loop )
39
83
if self ._max_concurrent_tasks :
40
84
self ._semaphore = asyncio .Semaphore (self ._max_concurrent_tasks )
41
- self ._loop_ready .set ()
85
+ asyncio .set_event_loop (self ._loop )
86
+ self ._loop .call_soon (lambda : self ._loop_ready .set ())
42
87
self ._loop .run_forever ()
43
88
89
+ def _loop_exception_handler (self , loop , context ):
90
+ if self ._on_error is not None :
91
+ loop .call_soon_threadsafe (self ._on_error , context ["exception" ])
92
+ else :
93
+ raise context ["exception" ]
94
+
44
95
def go (self , task , * args , on_success = None , on_error = None , lazy = False , async_gen_interval = 0.1 , ** kwargs ):
45
- if not self ._executor or not self . _loop or not self ._loop .is_running ():
96
+ if not self ._loop or self ._loop .is_running ():
46
97
self ._initialize ()
47
98
response_kwargs = {
48
99
"on_success" : on_success ,
@@ -93,7 +144,7 @@ async def async_generator():
93
144
elif inspect .isfunction (task ) or inspect .ismethod (task ):
94
145
return StageResponse (self , self ._loop .run_in_executor (self ._executor , lambda : task (* args , ** kwargs )), ** response_kwargs )
95
146
else :
96
- return task
147
+ raise TypeError ( f"Task seems like a value or an executed function not an executable task: { task } " )
97
148
98
149
def go_all (self , * task_list ):
99
150
response_list = []
@@ -134,7 +185,11 @@ def on_error(self, handler):
134
185
self ._on_error = handler
135
186
136
187
def close (self ):
137
- for response in self ._responses :
188
+ if self ._closed :
189
+ return
190
+ self ._closed = True
191
+
192
+ for response in self ._responses .copy ():
138
193
response ._result_ready .wait ()
139
194
140
195
if self ._loop and self ._loop .is_running ():
@@ -143,15 +198,15 @@ def close(self):
143
198
pending = asyncio .all_tasks (self ._loop )
144
199
if pending :
145
200
self ._loop .run_until_complete (asyncio .gather (* pending , return_exceptions = True ))
201
+ if self ._private_max_workers and self ._current_executor is not None :
202
+ self ._current_executor .shutdown (wait = True )
203
+ self ._current_executor = None
146
204
if self ._loop_thread and self ._loop_thread .is_alive ():
147
205
self ._loop_thread .join ()
148
206
self ._loop_thread = None
149
- if self ._loop and not self ._loop .is_closed :
207
+ if self ._loop and not self ._loop .is_closed () :
150
208
self ._loop .close ()
151
209
self ._loop = None
152
- if self ._executor :
153
- self ._executor .shutdown (wait = True )
154
- self ._executor = None
155
210
156
211
class Stage (BaseStage , StageFunctionMixin ):
157
212
pass
0 commit comments