Skip to content

Commit 1095196

Browse files
committed
opentelemetry-instrumentation-aiohttp-client: add type checking support
1 parent f730a63 commit 1095196

File tree

4 files changed

+114
-67
lines changed

4 files changed

+114
-67
lines changed

instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/__init__.py

Lines changed: 111 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,20 @@ def response_hook(span: Span, params: typing.Union[
102102
---
103103
"""
104104

105+
from __future__ import annotations
106+
105107
import types
106108
import typing
107109
from timeit import default_timer
108-
from typing import Collection
110+
from typing import (
111+
Any,
112+
Collection,
113+
Union,
114+
TYPE_CHECKING,
115+
TypedDict,
116+
Callable,
117+
cast,
118+
)
109119
from urllib.parse import urlparse
110120

111121
import aiohttp
@@ -119,18 +129,18 @@ def response_hook(span: Span, params: typing.Union[
119129
HTTP_DURATION_HISTOGRAM_BUCKETS_OLD,
120130
_client_duration_attrs_new,
121131
_client_duration_attrs_old,
122-
_filter_semconv_duration_attrs,
132+
_filter_semconv_duration_attrs, # type: ignore[reportUnknownVariableType]
123133
_get_schema_url,
124134
_OpenTelemetrySemanticConventionStability,
125135
_OpenTelemetryStabilitySignalType,
126136
_report_new,
127137
_report_old,
128-
_set_http_host_client,
129-
_set_http_method,
130-
_set_http_net_peer_name_client,
131-
_set_http_peer_port_client,
132-
_set_http_url,
133-
_set_status,
138+
_set_http_host_client, # type: ignore[reportUnknownVariableType]
139+
_set_http_method, # type: ignore[reportUnknownVariableType]
140+
_set_http_net_peer_name_client, # type: ignore[reportUnknownVariableType]
141+
_set_http_peer_port_client, # type: ignore[reportUnknownVariableType]
142+
_set_http_url, # type: ignore[reportUnknownVariableType]
143+
_set_status, # type: ignore[reportUnknownVariableType]
134144
_StabilityMode,
135145
)
136146
from opentelemetry.instrumentation.aiohttp_client.package import _instruments
@@ -143,7 +153,7 @@ def response_hook(span: Span, params: typing.Union[
143153
from opentelemetry.metrics import MeterProvider, get_meter
144154
from opentelemetry.propagate import inject
145155
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
146-
from opentelemetry.semconv.metrics import MetricInstruments
156+
from opentelemetry.semconv.metrics import MetricInstruments # type: ignore[reportDeprecated]
147157
from opentelemetry.semconv.metrics.http_metrics import (
148158
HTTP_CLIENT_REQUEST_DURATION,
149159
)
@@ -155,22 +165,39 @@ def response_hook(span: Span, params: typing.Union[
155165
sanitize_method,
156166
)
157167

158-
_UrlFilterT = typing.Optional[typing.Callable[[yarl.URL], str]]
159-
_RequestHookT = typing.Optional[
160-
typing.Callable[[Span, aiohttp.TraceRequestStartParams], None]
161-
]
162-
_ResponseHookT = typing.Optional[
163-
typing.Callable[
164-
[
165-
Span,
166-
typing.Union[
167-
aiohttp.TraceRequestEndParams,
168-
aiohttp.TraceRequestExceptionParams,
168+
if TYPE_CHECKING:
169+
from typing_extensions import Unpack
170+
171+
UrlFilterT = typing.Optional[typing.Callable[[yarl.URL], str]]
172+
RequestHookT = typing.Optional[
173+
typing.Callable[[Span, aiohttp.TraceRequestStartParams], None]
174+
]
175+
ResponseHookT = typing.Optional[
176+
typing.Callable[
177+
[
178+
Span,
179+
typing.Union[
180+
aiohttp.TraceRequestEndParams,
181+
aiohttp.TraceRequestExceptionParams,
182+
],
169183
],
170-
],
171-
None,
184+
None,
185+
]
172186
]
173-
]
187+
188+
class ClientSessionInitKwargs(TypedDict, total=False):
189+
trace_configs: typing.Sequence[aiohttp.TraceConfig]
190+
191+
class InstrumentKwargs(TypedDict, total=False):
192+
tracer_provider: trace.TracerProvider
193+
meter_provider: MeterProvider
194+
url_filter: UrlFilterT
195+
request_hook: RequestHookT
196+
response_hook: ResponseHookT
197+
trace_configs: typing.Sequence[aiohttp.TraceConfig]
198+
199+
class UninstrumentKwargs(TypedDict, total=False):
200+
pass
174201

175202

176203
def _get_span_name(method: str) -> str:
@@ -181,10 +208,10 @@ def _get_span_name(method: str) -> str:
181208

182209

183210
def _set_http_status_code_attribute(
184-
span,
185-
status_code,
186-
metric_attributes=None,
187-
sem_conv_opt_in_mode=_StabilityMode.DEFAULT,
211+
span: Span,
212+
status_code: int,
213+
metric_attributes: Union[dict[str, Any], None] = None,
214+
sem_conv_opt_in_mode: _StabilityMode = _StabilityMode.DEFAULT,
188215
):
189216
status_code_str = str(status_code)
190217
try:
@@ -209,11 +236,11 @@ def _set_http_status_code_attribute(
209236
# pylint: disable=too-many-locals
210237
# pylint: disable=too-many-statements
211238
def create_trace_config(
212-
url_filter: _UrlFilterT = None,
213-
request_hook: _RequestHookT = None,
214-
response_hook: _ResponseHookT = None,
215-
tracer_provider: TracerProvider = None,
216-
meter_provider: MeterProvider = None,
239+
url_filter: UrlFilterT = None,
240+
request_hook: RequestHookT = None,
241+
response_hook: ResponseHookT = None,
242+
tracer_provider: Union[TracerProvider, None] = None,
243+
meter_provider: Union[MeterProvider, None] = None,
217244
sem_conv_opt_in_mode: _StabilityMode = _StabilityMode.DEFAULT,
218245
) -> aiohttp.TraceConfig:
219246
"""Create an aiohttp-compatible trace configuration.
@@ -268,12 +295,10 @@ def create_trace_config(
268295
schema_url,
269296
)
270297

271-
start_time = 0
272-
273298
duration_histogram_old = None
274299
if _report_old(sem_conv_opt_in_mode):
275300
duration_histogram_old = meter.create_histogram(
276-
name=MetricInstruments.HTTP_CLIENT_DURATION,
301+
name=MetricInstruments.HTTP_CLIENT_DURATION, # type: ignore[reportDeprecated]
277302
unit="ms",
278303
description="measures the duration of the outbound HTTP request",
279304
explicit_bucket_boundaries_advisory=HTTP_DURATION_HISTOGRAM_BUCKETS_OLD,
@@ -293,52 +318,62 @@ def _end_trace(trace_config_ctx: types.SimpleNamespace):
293318
elapsed_time = max(default_timer() - trace_config_ctx.start_time, 0)
294319
if trace_config_ctx.token:
295320
context_api.detach(trace_config_ctx.token)
296-
trace_config_ctx.span.end()
321+
if trace_config_ctx.span:
322+
trace_config_ctx.span.end()
297323

298324
if trace_config_ctx.duration_histogram_old is not None:
299-
duration_attrs_old = _filter_semconv_duration_attrs(
300-
trace_config_ctx.metric_attributes,
301-
_client_duration_attrs_old,
302-
_client_duration_attrs_new,
303-
_StabilityMode.DEFAULT,
325+
duration_attrs_old = cast(
326+
dict[str, Any],
327+
_filter_semconv_duration_attrs(
328+
trace_config_ctx.metric_attributes,
329+
_client_duration_attrs_old,
330+
_client_duration_attrs_new,
331+
_StabilityMode.DEFAULT,
332+
),
304333
)
305334
trace_config_ctx.duration_histogram_old.record(
306335
max(round(elapsed_time * 1000), 0),
307336
attributes=duration_attrs_old,
308337
)
309338
if trace_config_ctx.duration_histogram_new is not None:
310-
duration_attrs_new = _filter_semconv_duration_attrs(
311-
trace_config_ctx.metric_attributes,
312-
_client_duration_attrs_old,
313-
_client_duration_attrs_new,
314-
_StabilityMode.HTTP,
339+
duration_attrs_new = cast(
340+
dict[str, Any],
341+
_filter_semconv_duration_attrs(
342+
trace_config_ctx.metric_attributes,
343+
_client_duration_attrs_old,
344+
_client_duration_attrs_new,
345+
_StabilityMode.HTTP,
346+
),
315347
)
316348
trace_config_ctx.duration_histogram_new.record(
317349
elapsed_time, attributes=duration_attrs_new
318350
)
319351

320352
async def on_request_start(
321-
unused_session: aiohttp.ClientSession,
353+
_session: aiohttp.ClientSession,
322354
trace_config_ctx: types.SimpleNamespace,
323355
params: aiohttp.TraceRequestStartParams,
324356
):
325357
if (
326358
not is_http_instrumentation_enabled()
327359
or trace_config_ctx.excluded_urls.url_disabled(str(params.url))
328360
):
329-
trace_config_ctx.span = None
330361
return
331362

332363
trace_config_ctx.start_time = default_timer()
333364
method = params.method
334365
request_span_name = _get_span_name(method)
335366
request_url = (
336-
redact_url(trace_config_ctx.url_filter(params.url))
367+
redact_url(
368+
cast(Callable[[yarl.URL], str], trace_config_ctx.url_filter)(
369+
params.url
370+
)
371+
)
337372
if callable(trace_config_ctx.url_filter)
338373
else redact_url(str(params.url))
339374
)
340375

341-
span_attributes = {}
376+
span_attributes: dict[str, Any] = {}
342377
_set_http_method(
343378
span_attributes,
344379
method,
@@ -399,7 +434,7 @@ async def on_request_start(
399434
inject(params.headers)
400435

401436
async def on_request_end(
402-
unused_session: aiohttp.ClientSession,
437+
_session: aiohttp.ClientSession,
403438
trace_config_ctx: types.SimpleNamespace,
404439
params: aiohttp.TraceRequestEndParams,
405440
):
@@ -418,7 +453,7 @@ async def on_request_end(
418453
_end_trace(trace_config_ctx)
419454

420455
async def on_request_exception(
421-
unused_session: aiohttp.ClientSession,
456+
_session: aiohttp.ClientSession,
422457
trace_config_ctx: types.SimpleNamespace,
423458
params: aiohttp.TraceRequestExceptionParams,
424459
):
@@ -441,21 +476,25 @@ async def on_request_exception(
441476

442477
_end_trace(trace_config_ctx)
443478

444-
def _trace_config_ctx_factory(**kwargs):
479+
def _trace_config_ctx_factory(**kwargs: Any) -> types.SimpleNamespace:
445480
kwargs.setdefault("trace_request_ctx", {})
446481
return types.SimpleNamespace(
447482
tracer=tracer,
448-
url_filter=url_filter,
449-
start_time=start_time,
483+
span=None,
484+
token=None,
450485
duration_histogram_old=duration_histogram_old,
451486
duration_histogram_new=duration_histogram_new,
452-
excluded_urls=excluded_urls,
453487
metric_attributes={},
488+
url_filter=url_filter,
489+
excluded_urls=excluded_urls,
490+
start_time=0,
454491
**kwargs,
455492
)
456493

457494
trace_config = aiohttp.TraceConfig(
458-
trace_config_ctx_factory=_trace_config_ctx_factory
495+
trace_config_ctx_factory=cast(
496+
type[types.SimpleNamespace], _trace_config_ctx_factory
497+
)
459498
)
460499

461500
trace_config.on_request_start.append(on_request_start)
@@ -466,11 +505,11 @@ def _trace_config_ctx_factory(**kwargs):
466505

467506

468507
def _instrument(
469-
tracer_provider: TracerProvider = None,
470-
meter_provider: MeterProvider = None,
471-
url_filter: _UrlFilterT = None,
472-
request_hook: _RequestHookT = None,
473-
response_hook: _ResponseHookT = None,
508+
tracer_provider: Union[TracerProvider, None] = None,
509+
meter_provider: Union[MeterProvider, None] = None,
510+
url_filter: UrlFilterT = None,
511+
request_hook: RequestHookT = None,
512+
response_hook: ResponseHookT = None,
474513
trace_configs: typing.Optional[
475514
typing.Sequence[aiohttp.TraceConfig]
476515
] = None,
@@ -485,7 +524,12 @@ def _instrument(
485524
trace_configs = trace_configs or ()
486525

487526
# pylint:disable=unused-argument
488-
def instrumented_init(wrapped, instance, args, kwargs):
527+
def instrumented_init(
528+
wrapped: Callable[..., None],
529+
_instance: aiohttp.ClientSession,
530+
args: tuple[Any, ...],
531+
kwargs: ClientSessionInitKwargs,
532+
):
489533
client_trace_configs = list(kwargs.get("trace_configs") or [])
490534
client_trace_configs.extend(trace_configs)
491535

@@ -497,13 +541,13 @@ def instrumented_init(wrapped, instance, args, kwargs):
497541
meter_provider=meter_provider,
498542
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
499543
)
500-
trace_config._is_instrumented_by_opentelemetry = True
544+
setattr(trace_config, "_is_instrumented_by_opentelemetry", True)
501545
client_trace_configs.append(trace_config)
502546

503547
kwargs["trace_configs"] = client_trace_configs
504548
return wrapped(*args, **kwargs)
505549

506-
wrapt.wrap_function_wrapper(
550+
wrapt.wrap_function_wrapper( # type: ignore[reportUnknownVariableType]
507551
aiohttp.ClientSession, "__init__", instrumented_init
508552
)
509553

@@ -533,7 +577,7 @@ class AioHttpClientInstrumentor(BaseInstrumentor):
533577
def instrumentation_dependencies(self) -> Collection[str]:
534578
return _instruments
535579

536-
def _instrument(self, **kwargs):
580+
def _instrument(self, **kwargs: Unpack[InstrumentKwargs]):
537581
"""Instruments aiohttp ClientSession
538582
539583
Args:
@@ -562,7 +606,7 @@ def _instrument(self, **kwargs):
562606
sem_conv_opt_in_mode=_sem_conv_opt_in_mode,
563607
)
564608

565-
def _uninstrument(self, **kwargs):
609+
def _uninstrument(self, **kwargs: Unpack[UninstrumentKwargs]):
566610
_uninstrument()
567611

568612
@staticmethod

instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/py.typed

Whitespace-only changes.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ include = [
207207
"instrumentation-genai/opentelemetry-instrumentation-weaviate",
208208
"util/opentelemetry-util-genai",
209209
"exporter/opentelemetry-exporter-credential-provider-gcp",
210+
"instrumentation/opentelemetry-instrumentation-aiohttp-client",
210211
]
211212
# We should also add type hints to the test suite - It helps on finding bugs.
212213
# We are excluding for now because it's easier, and more important to add to the instrumentation packages.
@@ -223,6 +224,7 @@ exclude = [
223224
"instrumentation-genai/opentelemetry-instrumentation-weaviate/tests/**/*.py",
224225
"instrumentation-genai/opentelemetry-instrumentation-weaviate/examples/**/*.py",
225226
"util/opentelemetry-util-genai/tests/**/*.py",
227+
"instrumentation/opentelemetry-instrumentation-aiohttp-client/tests/**/*.py",
226228
]
227229

228230
[dependency-groups]

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,7 @@ deps =
11061106
{toxinidir}/instrumentation/opentelemetry-instrumentation-aiokafka[instruments]
11071107
{toxinidir}/instrumentation/opentelemetry-instrumentation-asyncclick[instruments]
11081108
{toxinidir}/exporter/opentelemetry-exporter-credential-provider-gcp
1109+
{toxinidir}/instrumentation/opentelemetry-instrumentation-aiohttp-client[instruments]
11091110

11101111
commands =
11111112
pyright

0 commit comments

Comments
 (0)