diff --git a/contrib/opencensus-ext-httpx/CHANGELOG.md b/contrib/opencensus-ext-httpx/CHANGELOG.md index 755e63048..1ae57f6a3 100644 --- a/contrib/opencensus-ext-httpx/CHANGELOG.md +++ b/contrib/opencensus-ext-httpx/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Allow tracing for `Advanced Usage` of httpx ([#1186](https://github.com/census-instrumentation/opencensus-python/pull/1186)) + ## 0.1.0 Released 2023-01-18 diff --git a/contrib/opencensus-ext-httpx/opencensus/ext/httpx/trace.py b/contrib/opencensus-ext-httpx/opencensus/ext/httpx/trace.py index b77f29630..2c1008383 100644 --- a/contrib/opencensus-ext-httpx/opencensus/ext/httpx/trace.py +++ b/contrib/opencensus-ext-httpx/opencensus/ext/httpx/trace.py @@ -41,33 +41,33 @@ def trace_integration(tracer=None): execution_context.set_opencensus_tracer(tracer) wrapt.wrap_function_wrapper( - MODULE_NAME, "Client.request", wrap_client_request + MODULE_NAME, "Client.send", wrap_client_send ) # pylint: disable=protected-access integrations.add_integration(integrations._Integrations.HTTPX) -def wrap_client_request(wrapped, instance, args, kwargs): +def wrap_client_send(wrapped, instance, args, kwargs): """Wrap the session function to trace it.""" # Check if request was sent from an exporter. If so, do not wrap. if execution_context.is_exporter(): return wrapped(*args, **kwargs) - method = kwargs.get("method") or args[0] - url = kwargs.get("url") or args[1] - + request: httpx.Request = kwargs.get("request") or args[0] + method = request.method + url: httpx.URL = request.url + excludelist_hostnames = execution_context.get_opencensus_attr( "excludelist_hostnames" ) - parsed_url = urlparse(url) - if parsed_url.port is None: - dest_url = parsed_url.hostname + if url.port is None: + dest_url = url.host else: - dest_url = "{}:{}".format(parsed_url.hostname, parsed_url.port) + dest_url = "{}:{}".format(url.host, url.port) if utils.disable_tracing_hostname(dest_url, excludelist_hostnames): return wrapped(*args, **kwargs) - path = parsed_url.path if parsed_url.path else "/" + path = url.path if url.path else "/" _tracer = execution_context.get_opencensus_tracer() _span = _tracer.start_span() @@ -77,7 +77,7 @@ def wrap_client_request(wrapped, instance, args, kwargs): try: tracer_headers = _tracer.propagator.to_headers(_tracer.span_context) - kwargs.setdefault("headers", {}).update(tracer_headers) + request.headers.update(tracer_headers) except Exception: # pragma: NO COVER pass diff --git a/contrib/opencensus-ext-httpx/tests/test_httpx_trace.py b/contrib/opencensus-ext-httpx/tests/test_httpx_trace.py index ddb9cb87f..d13812ec9 100644 --- a/contrib/opencensus-ext-httpx/tests/test_httpx_trace.py +++ b/contrib/opencensus-ext-httpx/tests/test_httpx_trace.py @@ -37,7 +37,7 @@ def test_trace_integration(self): noop_tracer.NoopTracer, ) mock_wrap.assert_called_once_with( - trace.MODULE_NAME, "Client.request", trace.wrap_client_request + trace.MODULE_NAME, "Client.send", trace.wrap_client_send ) def test_trace_integration_set_tracer(self): @@ -54,7 +54,7 @@ class TmpTracer(noop_tracer.NoopTracer): execution_context.get_opencensus_tracer(), TmpTracer ) mock_wrap.assert_called_once_with( - trace.MODULE_NAME, "Client.request", trace.wrap_client_request + trace.MODULE_NAME, "Client.send", trace.wrap_client_send ) def test_wrap_client_request(self): @@ -77,11 +77,11 @@ def test_wrap_client_request(self): url = "http://localhost:8080/test" request_method = "POST" - kwargs = {} + request = httpx.Request(request_method, url) with patch, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) expected_attributes = { @@ -101,7 +101,7 @@ def test_wrap_client_request(self): self.assertEqual( expected_attributes, mock_tracer.current_span.attributes ) - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") self.assertEqual(expected_name, mock_tracer.current_span.name) self.assertEqual( expected_status.__dict__, mock_tracer.current_span.status.__dict__ @@ -134,10 +134,11 @@ def wrapped(*args, **kwargs): url = "http://localhost/" request_method = "POST" + request = httpx.Request(request_method, url) with patch_tracer, patch_attr, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), {} + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) expected_name = "/" @@ -167,10 +168,11 @@ def wrapped(*args, **kwargs): url = "http://localhost:8080" request_method = "POST" + request = httpx.Request(request_method, url) with patch_tracer, patch_attr, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), {} + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) self.assertEqual(None, mock_tracer.current_span) @@ -199,10 +201,11 @@ def wrapped(*args, **kwargs): url = "http://localhost:8080" request_method = "POST" + request = httpx.Request(request_method, url) with patch_tracer, patch_attr, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), {} + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) self.assertEqual(None, mock_tracer.current_span) @@ -226,14 +229,14 @@ def test_header_is_passed_in(self): url = "http://localhost:8080" request_method = "POST" - kwargs = {} + request = httpx.Request(request_method, url) with patch, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") def test_headers_are_preserved(self): wrapped = mock.Mock(return_value=mock.Mock(status_code=200)) @@ -255,14 +258,15 @@ def test_headers_are_preserved(self): url = "http://localhost:8080" request_method = "POST" kwargs = {"headers": {"key": "value"}} + request = httpx.Request(request_method, url, **kwargs) with patch, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) self.assertEqual(kwargs["headers"]["key"], "value") - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") def test_tracer_headers_are_overwritten(self): wrapped = mock.Mock(return_value=mock.Mock(status_code=200)) @@ -285,13 +289,14 @@ def test_tracer_headers_are_overwritten(self): url = "http://localhost:8080" request_method = "POST" kwargs = {"headers": {"x-trace": "original-value"}} + request = httpx.Request(request_method, url, **kwargs) with patch, patch_thread: - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") def test_wrap_client_request_timeout(self): wrapped = mock.Mock(return_value=mock.Mock(status_code=200)) @@ -314,13 +319,13 @@ def test_wrap_client_request_timeout(self): url = "http://localhost:8080/test" request_method = "POST" - kwargs = {} + request = httpx.Request(request_method, url) with patch, patch_thread: with self.assertRaises(httpx.TimeoutException): # breakpoint() - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) expected_attributes = { @@ -339,7 +344,7 @@ def test_wrap_client_request_timeout(self): self.assertEqual( expected_attributes, mock_tracer.current_span.attributes ) - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") self.assertEqual(expected_name, mock_tracer.current_span.name) self.assertEqual( expected_status.__dict__, mock_tracer.current_span.status.__dict__ @@ -366,12 +371,12 @@ def test_wrap_client_request_invalid_url(self): url = "http://localhost:8080/test" request_method = "POST" - kwargs = {} + request = httpx.Request(request_method, url) with patch, patch_thread: with self.assertRaises(httpx.InvalidURL): - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) expected_attributes = { @@ -390,7 +395,7 @@ def test_wrap_client_request_invalid_url(self): self.assertEqual( expected_attributes, mock_tracer.current_span.attributes ) - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") self.assertEqual(expected_name, mock_tracer.current_span.name) self.assertEqual( expected_status.__dict__, mock_tracer.current_span.status.__dict__ @@ -417,12 +422,12 @@ def test_wrap_client_request_exception(self): url = "http://localhost:8080/test" request_method = "POST" - kwargs = {} + request = httpx.Request(request_method, url) with patch, patch_thread: with self.assertRaises(httpx.TooManyRedirects): - trace.wrap_client_request( - wrapped, "Client.request", (request_method, url), kwargs + trace.wrap_client_send( + wrapped, "Client.send", (request,), {} ) expected_attributes = { @@ -441,7 +446,7 @@ def test_wrap_client_request_exception(self): self.assertEqual( expected_attributes, mock_tracer.current_span.attributes ) - self.assertEqual(kwargs["headers"]["x-trace"], "some-value") + self.assertEqual(request.headers["x-trace"], "some-value") self.assertEqual(expected_name, mock_tracer.current_span.name) self.assertEqual( expected_status.__dict__, mock_tracer.current_span.status.__dict__