diff --git a/ddtrace/internal/datadog/profiling/code_provenance.py b/ddtrace/internal/datadog/profiling/code_provenance.py index 9d4405e2b04..031310db28a 100644 --- a/ddtrace/internal/datadog/profiling/code_provenance.py +++ b/ddtrace/internal/datadog/profiling/code_provenance.py @@ -7,6 +7,7 @@ import typing as t from ddtrace.internal import gitmetadata +from ddtrace.internal.packages import Distribution from ddtrace.internal.packages import _package_for_root_module_mapping @@ -17,19 +18,19 @@ def __init__( name: str, version: str, paths: t.Set[str], - ): + ) -> None: self.kind = kind self.name = name self.version = version self.paths = paths - def to_dict(self): + def to_dict(self) -> t.Dict[str, t.Any]: return {"kind": self.kind, "name": self.name, "version": self.version, "paths": list(self.paths)} class CodeProvenance: - def __init__(self): - self.libraries = [] + def __init__(self) -> None: + self.libraries: t.List[Library] = [] python_stdlib = Library( kind="standard library", @@ -65,7 +66,7 @@ def __init__(self): self.libraries.append(python_stdlib) - module_to_distribution = _package_for_root_module_mapping() + module_to_distribution: t.Dict[str, Distribution] = _package_for_root_module_mapping() or {} libraries: t.Dict[str, Library] = {} @@ -98,10 +99,10 @@ def __init__(self): self.libraries.extend(libraries.values()) - def to_dict(self): + def to_dict(self) -> t.Dict[str, t.Any]: return {"v1": [lib.to_dict() for lib in self.libraries]} -def json_str_to_export(): +def json_str_to_export() -> str: cp = CodeProvenance() return json.dumps(cp.to_dict()) diff --git a/ddtrace/internal/datadog/profiling/stack_v2/__init__.pyi b/ddtrace/internal/datadog/profiling/stack_v2/__init__.pyi index 87bd4598fb4..0a9d09493f1 100644 --- a/ddtrace/internal/datadog/profiling/stack_v2/__init__.pyi +++ b/ddtrace/internal/datadog/profiling/stack_v2/__init__.pyi @@ -1,5 +1,9 @@ +import asyncio + def register_thread(id: int, native_id: int, name: str) -> None: ... # noqa: A002 def unregister_thread(name: str) -> None: ... +def track_asyncio_loop(thread_id: int, loop: asyncio.AbstractEventLoop) -> None: ... +def link_tasks(parent: asyncio.AbstractEventLoop, child: asyncio.Task) -> None: ... is_available: bool failure_msg: str diff --git a/ddtrace/profiling/_asyncio.py b/ddtrace/profiling/_asyncio.py index 3d6d18c11bb..66d44f7fc45 100644 --- a/ddtrace/profiling/_asyncio.py +++ b/ddtrace/profiling/_asyncio.py @@ -1,8 +1,12 @@ # -*- encoding: utf-8 -*- from functools import partial import sys -from types import ModuleType # noqa:F401 -import typing # noqa:F401 +from types import ModuleType # noqa: F401 +import typing + + +if typing.TYPE_CHECKING: + import asyncio from ddtrace.internal._unpatched import _threading as ddtrace_threading from ddtrace.internal.datadog.profiling import stack_v2 @@ -17,15 +21,17 @@ THREAD_LINK = None # type: typing.Optional[_threading._ThreadLink] -def current_task(loop=None): +def current_task(loop: typing.Union["asyncio.AbstractEventLoop", None] = None) -> typing.Union["asyncio.Task", None]: return None -def all_tasks(loop=None): +def all_tasks( + loop: typing.Union["asyncio.AbstractEventLoop", None] = None, +) -> typing.Union[typing.List["asyncio.Task"], None]: return [] -def _task_get_name(task): +def _task_get_name(task: "asyncio.Task") -> str: return "Task-%d" % id(task) @@ -55,12 +61,13 @@ def _(asyncio): @partial(wrap, sys.modules["asyncio.events"].BaseDefaultEventLoopPolicy.set_event_loop) def _(f, args, kwargs): - loop = get_argument_value(args, kwargs, 1, "loop") + loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 1, "loop")) try: if init_stack_v2: stack_v2.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop) return f(*args, **kwargs) finally: + assert THREAD_LINK is not None # nosec: assert is used for typing THREAD_LINK.clear_threads(set(sys._current_frames().keys())) if loop is not None: THREAD_LINK.link_object(loop) @@ -73,10 +80,14 @@ def _(f, args, kwargs): return f(*args, **kwargs) finally: children = get_argument_value(args, kwargs, 1, "children") + assert children is not None # nosec: assert is used for typing + # Pass an invalid positional index for 'loop' loop = get_argument_value(args, kwargs, -1, "loop") + # Link the parent gathering task to the gathered children parent = globals()["current_task"](loop) + for child in children: stack_v2.link_tasks(parent, child) @@ -90,7 +101,7 @@ def _(f, args, kwargs): stack_v2.init_asyncio(asyncio.tasks._current_tasks, scheduled_tasks, eager_tasks) # type: ignore[attr-defined] -def get_event_loop_for_thread(thread_id): +def get_event_loop_for_thread(thread_id: int) -> typing.Union["asyncio.AbstractEventLoop", None]: global THREAD_LINK return THREAD_LINK.get_object(thread_id) if THREAD_LINK is not None else None diff --git a/ddtrace/profiling/collector/pytorch.py b/ddtrace/profiling/collector/pytorch.py index 187b252629a..731c92ebb24 100644 --- a/ddtrace/profiling/collector/pytorch.py +++ b/ddtrace/profiling/collector/pytorch.py @@ -31,9 +31,9 @@ def __init__( class MLProfilerCollector(collector.CaptureSamplerCollector): """Record ML framework (i.e. pytorch) profiler usage.""" - def __init__(self): + def __init__(self) -> None: super().__init__() - self.tracer = None + self.tracer: typing.Union[Tracer, None] = None # Holds the pytorch profiler object which is wrapped by this class self._original: typing.Any = None