Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions ddtrace/internal/datadog/profiling/code_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing as t

from ddtrace.internal import gitmetadata
from ddtrace.internal.packages import Distribution
from ddtrace.internal.packages import _package_for_root_module_mapping


Expand All @@ -17,19 +18,19 @@ def __init__(
name: str,
version: str,
paths: t.Set[str],
):
) -> None:
self.kind = kind
self.name = name
self.version = version
self.paths = paths

def to_dict(self):
def to_dict(self) -> t.Dict[str, t.Any]:
return {"kind": self.kind, "name": self.name, "version": self.version, "paths": list(self.paths)}


class CodeProvenance:
def __init__(self):
self.libraries = []
def __init__(self) -> None:
self.libraries: t.List[Library] = []

python_stdlib = Library(
kind="standard library",
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(self):

self.libraries.append(python_stdlib)

module_to_distribution = _package_for_root_module_mapping()
module_to_distribution: t.Dict[str, Distribution] = _package_for_root_module_mapping() or {}

libraries: t.Dict[str, Library] = {}

Expand Down Expand Up @@ -98,10 +99,10 @@ def __init__(self):

self.libraries.extend(libraries.values())

def to_dict(self):
def to_dict(self) -> t.Dict[str, t.Any]:
return {"v1": [lib.to_dict() for lib in self.libraries]}


def json_str_to_export():
def json_str_to_export() -> str:
cp = CodeProvenance()
return json.dumps(cp.to_dict())
4 changes: 4 additions & 0 deletions ddtrace/internal/datadog/profiling/stack_v2/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import asyncio

def register_thread(id: int, native_id: int, name: str) -> None: ... # noqa: A002
def unregister_thread(name: str) -> None: ...
def track_asyncio_loop(thread_id: int, loop: asyncio.AbstractEventLoop) -> None: ...
def link_tasks(parent: asyncio.AbstractEventLoop, child: asyncio.Task) -> None: ...

is_available: bool
failure_msg: str
25 changes: 18 additions & 7 deletions ddtrace/profiling/_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# -*- encoding: utf-8 -*-
from functools import partial
import sys
from types import ModuleType # noqa:F401
import typing # noqa:F401
from types import ModuleType # noqa: F401
import typing


if typing.TYPE_CHECKING:
import asyncio

from ddtrace.internal._unpatched import _threading as ddtrace_threading
from ddtrace.internal.datadog.profiling import stack_v2
Expand All @@ -17,15 +21,17 @@
THREAD_LINK = None # type: typing.Optional[_threading._ThreadLink]


def current_task(loop=None):
def current_task(loop: typing.Union["asyncio.AbstractEventLoop", None] = None) -> typing.Union["asyncio.Task", None]:
return None


def all_tasks(loop=None):
def all_tasks(
loop: typing.Union["asyncio.AbstractEventLoop", None] = None,
) -> typing.Union[typing.List["asyncio.Task"], None]:
return []


def _task_get_name(task):
def _task_get_name(task: "asyncio.Task") -> str:
return "Task-%d" % id(task)


Expand Down Expand Up @@ -55,12 +61,13 @@ def _(asyncio):

@partial(wrap, sys.modules["asyncio.events"].BaseDefaultEventLoopPolicy.set_event_loop)
def _(f, args, kwargs):
loop = get_argument_value(args, kwargs, 1, "loop")
loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 1, "loop"))
try:
if init_stack_v2:
stack_v2.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop)
return f(*args, **kwargs)
finally:
assert THREAD_LINK is not None # nosec: assert is used for typing
THREAD_LINK.clear_threads(set(sys._current_frames().keys()))
if loop is not None:
THREAD_LINK.link_object(loop)
Expand All @@ -73,10 +80,14 @@ def _(f, args, kwargs):
return f(*args, **kwargs)
finally:
children = get_argument_value(args, kwargs, 1, "children")
assert children is not None # nosec: assert is used for typing

# Pass an invalid positional index for 'loop'
loop = get_argument_value(args, kwargs, -1, "loop")

# Link the parent gathering task to the gathered children
parent = globals()["current_task"](loop)

for child in children:
stack_v2.link_tasks(parent, child)

Expand All @@ -90,7 +101,7 @@ def _(f, args, kwargs):
stack_v2.init_asyncio(asyncio.tasks._current_tasks, scheduled_tasks, eager_tasks) # type: ignore[attr-defined]


def get_event_loop_for_thread(thread_id):
def get_event_loop_for_thread(thread_id: int) -> typing.Union["asyncio.AbstractEventLoop", None]:
global THREAD_LINK

return THREAD_LINK.get_object(thread_id) if THREAD_LINK is not None else None
4 changes: 2 additions & 2 deletions ddtrace/profiling/collector/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def __init__(
class MLProfilerCollector(collector.CaptureSamplerCollector):
"""Record ML framework (i.e. pytorch) profiler usage."""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.tracer = None
self.tracer: typing.Union[Tracer, None] = None
# Holds the pytorch profiler object which is wrapped by this class
self._original: typing.Any = None

Expand Down
Loading