4
4
from abc import ABC
5
5
from asyncio import Semaphore , Task , create_task
6
6
from concurrent .futures import Executor , ProcessPoolExecutor , ThreadPoolExecutor
7
- from typing import List , Optional
7
+ from typing import Optional
8
8
9
9
10
10
class BoundedExecutor (ABC ):
@@ -31,27 +31,24 @@ def _get_default_max_workers():
31
31
return min (32 , (os .cpu_count () or 1 ) + 4 )
32
32
33
33
async def __aenter__ (self ):
34
- self ._tasks : List [Task ] = []
35
34
self ._semaphore = Semaphore (self ._semaphore_size )
36
35
return self
37
36
38
37
async def __aexit__ (self , exc_type , exc_value , traceback ):
39
- while self ._tasks :
40
- await self ._tasks .pop ()
38
+ # make sure no tasks being executed
39
+ for _ in range (self ._semaphore_size ):
40
+ await self ._semaphore .acquire ()
41
41
42
42
async def _acquire (self ):
43
43
await self ._semaphore .acquire ()
44
44
45
45
def _release (self , fut ):
46
- if fut in self ._tasks :
47
- self ._tasks .remove (fut )
48
46
self ._semaphore .release ()
49
47
50
48
async def submit (self , coro , * args , ** kwargs ) -> Task :
51
49
await self ._acquire ()
52
50
task = create_task (coro (* args , ** kwargs ))
53
51
task .add_done_callback (self ._release )
54
- self ._tasks .append (task )
55
52
return task
56
53
57
54
0 commit comments