Skip to content

Commit 80c0310

Browse files
chore(profiling): improve typing
1 parent 5fa0681 commit 80c0310

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

ddtrace/internal/datadog/profiling/code_provenance.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import typing as t
88

99
from ddtrace.internal import gitmetadata
10+
from ddtrace.internal.packages import Distribution
1011
from ddtrace.internal.packages import _package_for_root_module_mapping
1112

1213

@@ -17,19 +18,19 @@ def __init__(
1718
name: str,
1819
version: str,
1920
paths: t.Set[str],
20-
):
21+
) -> None:
2122
self.kind = kind
2223
self.name = name
2324
self.version = version
2425
self.paths = paths
2526

26-
def to_dict(self):
27+
def to_dict(self) -> t.Dict[str, t.Any]:
2728
return {"kind": self.kind, "name": self.name, "version": self.version, "paths": list(self.paths)}
2829

2930

3031
class CodeProvenance:
31-
def __init__(self):
32-
self.libraries = []
32+
def __init__(self) -> None:
33+
self.libraries: t.List[Library] = []
3334

3435
python_stdlib = Library(
3536
kind="standard library",
@@ -65,7 +66,7 @@ def __init__(self):
6566

6667
self.libraries.append(python_stdlib)
6768

68-
module_to_distribution = _package_for_root_module_mapping()
69+
module_to_distribution: t.Dict[str, Distribution] = _package_for_root_module_mapping() or {}
6970

7071
libraries: t.Dict[str, Library] = {}
7172

@@ -98,10 +99,10 @@ def __init__(self):
9899

99100
self.libraries.extend(libraries.values())
100101

101-
def to_dict(self):
102+
def to_dict(self) -> t.Dict[str, t.Any]:
102103
return {"v1": [lib.to_dict() for lib in self.libraries]}
103104

104105

105-
def json_str_to_export():
106+
def json_str_to_export() -> str:
106107
cp = CodeProvenance()
107108
return json.dumps(cp.to_dict())
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import asyncio
2+
13
def register_thread(id: int, native_id: int, name: str) -> None: ... # noqa: A002
24
def unregister_thread(name: str) -> None: ...
5+
def track_asyncio_loop(thread_id: int, loop: asyncio.AbstractEventLoop) -> None: ...
6+
def link_tasks(parent: asyncio.AbstractEventLoop, child: asyncio.Task) -> None: ...
37

48
is_available: bool
59
failure_msg: str

ddtrace/profiling/_asyncio.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# -*- encoding: utf-8 -*-
22
from functools import partial
33
import sys
4-
from types import ModuleType # noqa:F401
5-
import typing # noqa:F401
4+
from types import ModuleType # noqa: F401
5+
import typing
6+
7+
8+
if typing.TYPE_CHECKING:
9+
import asyncio
610

711
from ddtrace.internal._unpatched import _threading as ddtrace_threading
812
from ddtrace.internal.datadog.profiling import stack_v2
@@ -17,15 +21,17 @@
1721
THREAD_LINK = None # type: typing.Optional[_threading._ThreadLink]
1822

1923

20-
def current_task(loop=None):
24+
def current_task(loop: typing.Union["asyncio.AbstractEventLoop", None] = None) -> typing.Union["asyncio.Task", None]:
2125
return None
2226

2327

24-
def all_tasks(loop=None):
28+
def all_tasks(
29+
loop: typing.Union["asyncio.AbstractEventLoop", None] = None,
30+
) -> typing.Union[typing.List["asyncio.Task"], None]:
2531
return []
2632

2733

28-
def _task_get_name(task):
34+
def _task_get_name(task: "asyncio.Task") -> str:
2935
return "Task-%d" % id(task)
3036

3137

@@ -55,12 +61,13 @@ def _(asyncio):
5561

5662
@partial(wrap, sys.modules["asyncio.events"].BaseDefaultEventLoopPolicy.set_event_loop)
5763
def _(f, args, kwargs):
58-
loop = get_argument_value(args, kwargs, 1, "loop")
64+
loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 1, "loop"))
5965
try:
6066
if init_stack_v2:
6167
stack_v2.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop)
6268
return f(*args, **kwargs)
6369
finally:
70+
assert THREAD_LINK is not None # nosec: assert is used for typing
6471
THREAD_LINK.clear_threads(set(sys._current_frames().keys()))
6572
if loop is not None:
6673
THREAD_LINK.link_object(loop)
@@ -73,10 +80,14 @@ def _(f, args, kwargs):
7380
return f(*args, **kwargs)
7481
finally:
7582
children = get_argument_value(args, kwargs, 1, "children")
83+
assert children is not None # nosec: assert is used for typing
84+
7685
# Pass an invalid positional index for 'loop'
7786
loop = get_argument_value(args, kwargs, -1, "loop")
87+
7888
# Link the parent gathering task to the gathered children
7989
parent = globals()["current_task"](loop)
90+
8091
for child in children:
8192
stack_v2.link_tasks(parent, child)
8293

@@ -90,7 +101,7 @@ def _(f, args, kwargs):
90101
stack_v2.init_asyncio(asyncio.tasks._current_tasks, scheduled_tasks, eager_tasks) # type: ignore[attr-defined]
91102

92103

93-
def get_event_loop_for_thread(thread_id):
104+
def get_event_loop_for_thread(thread_id: int) -> typing.Union["asyncio.AbstractEventLoop", None]:
94105
global THREAD_LINK
95106

96107
return THREAD_LINK.get_object(thread_id) if THREAD_LINK is not None else None

ddtrace/profiling/collector/pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def __init__(
3131
class MLProfilerCollector(collector.CaptureSamplerCollector):
3232
"""Record ML framework (i.e. pytorch) profiler usage."""
3333

34-
def __init__(self):
34+
def __init__(self) -> None:
3535
super().__init__()
36-
self.tracer = None
36+
self.tracer: typing.Union[Tracer, None] = None
3737
# Holds the pytorch profiler object which is wrapped by this class
3838
self._original: typing.Any = None
3939

0 commit comments

Comments
 (0)