11# -*- encoding: utf-8 -*-
22from functools import partial
33import sys
4- from types import ModuleType # noqa:F401
5- import typing # noqa:F401
4+ from types import ModuleType # noqa: F401
5+ import typing
6+
7+
8+ if typing .TYPE_CHECKING :
9+ import asyncio
610
711from ddtrace .internal ._unpatched import _threading as ddtrace_threading
812from ddtrace .internal .datadog .profiling import stack_v2
1721THREAD_LINK = None # type: typing.Optional[_threading._ThreadLink]
1822
1923
20- def current_task (loop = None ):
24+ def current_task (loop : typing . Union [ "asyncio.AbstractEventLoop" , None ] = None ) -> typing . Union [ "asyncio.Task" , None ] :
2125 return None
2226
2327
24- def all_tasks (loop = None ):
28+ def all_tasks (
29+ loop : typing .Union ["asyncio.AbstractEventLoop" , None ] = None ,
30+ ) -> typing .Union [typing .List ["asyncio.Task" ], None ]:
2531 return []
2632
2733
28- def _task_get_name (task ) :
34+ def _task_get_name (task : "asyncio.Task" ) -> str :
2935 return "Task-%d" % id (task )
3036
3137
@@ -55,12 +61,13 @@ def _(asyncio):
5561
5662 @partial (wrap , sys .modules ["asyncio.events" ].BaseDefaultEventLoopPolicy .set_event_loop )
5763 def _ (f , args , kwargs ):
58- loop = get_argument_value (args , kwargs , 1 , "loop" )
64+ loop = typing . cast ( "asyncio.AbstractEventLoop" , get_argument_value (args , kwargs , 1 , "loop" ) )
5965 try :
6066 if init_stack_v2 :
6167 stack_v2 .track_asyncio_loop (typing .cast (int , ddtrace_threading .current_thread ().ident ), loop )
6268 return f (* args , ** kwargs )
6369 finally :
70+ assert THREAD_LINK is not None # nosec: assert is used for typing
6471 THREAD_LINK .clear_threads (set (sys ._current_frames ().keys ()))
6572 if loop is not None :
6673 THREAD_LINK .link_object (loop )
@@ -73,10 +80,14 @@ def _(f, args, kwargs):
7380 return f (* args , ** kwargs )
7481 finally :
7582 children = get_argument_value (args , kwargs , 1 , "children" )
83+ assert children is not None # nosec: assert is used for typing
84+
7685 # Pass an invalid positional index for 'loop'
7786 loop = get_argument_value (args , kwargs , - 1 , "loop" )
87+
7888 # Link the parent gathering task to the gathered children
7989 parent = globals ()["current_task" ](loop )
90+
8091 for child in children :
8192 stack_v2 .link_tasks (parent , child )
8293
@@ -90,7 +101,7 @@ def _(f, args, kwargs):
90101 stack_v2 .init_asyncio (asyncio .tasks ._current_tasks , scheduled_tasks , eager_tasks ) # type: ignore[attr-defined]
91102
92103
93- def get_event_loop_for_thread (thread_id ) :
104+ def get_event_loop_for_thread (thread_id : int ) -> typing . Union [ "asyncio.AbstractEventLoop" , None ] :
94105 global THREAD_LINK
95106
96107 return THREAD_LINK .get_object (thread_id ) if THREAD_LINK is not None else None
0 commit comments