diff --git a/contrib/opencensus-ext-threading/opencensus/ext/threading/trace.py b/contrib/opencensus-ext-threading/opencensus/ext/threading/trace.py index c2824aa4f..b7b74851f 100644 --- a/contrib/opencensus-ext-threading/opencensus/ext/threading/trace.py +++ b/contrib/opencensus-ext-threading/opencensus/ext/threading/trace.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import logging import threading from concurrent import futures @@ -88,8 +89,14 @@ def wrap_apply_async(apply_async_func): that will be called and wrap it then add the opencensus context.""" def call(self, func, args=(), kwds={}, **kwargs): - wrapped_func = wrap_task_func(func) _tracer = execution_context.get_opencensus_tracer() + + from opencensus.trace.tracers.noop_tracer import NoopTracer + + if isinstance(_tracer, NoopTracer): + return apply_async_func(self, func, args=args, kwds={}, **kwargs) + + wrapped_func = wrap_task_func(func) propagator = binary_format.BinaryFormatPropagator() wrapped_kwargs = {} @@ -113,14 +120,21 @@ def wrap_submit(submit_func): that will be called and wrap it then add the opencensus context.""" def call(self, func, *args, **kwargs): - wrapped_func = wrap_task_func(func) _tracer = execution_context.get_opencensus_tracer() + + from opencensus.trace.tracers.noop_tracer import NoopTracer + + if isinstance(_tracer, NoopTracer): + return submit_func(self, func, *args, **kwargs) + + wrapped_func = wrap_task_func(func) propagator = binary_format.BinaryFormatPropagator() wrapped_kwargs = {} wrapped_kwargs["span_context_binary"] = propagator.to_header( _tracer.span_context ) + wrapped_kwargs["kwds"] = kwargs wrapped_kwargs["sampler"] = _tracer.sampler wrapped_kwargs["exporter"] = _tracer.exporter diff --git a/contrib/opencensus-ext-threading/tests/test_noop_tracer.py b/contrib/opencensus-ext-threading/tests/test_noop_tracer.py new file mode 100644 index 000000000..2e587169f --- /dev/null +++ b/contrib/opencensus-ext-threading/tests/test_noop_tracer.py @@ -0,0 +1,50 @@ +import unittest +from unittest.mock import patch, MagicMock + +from opencensus.trace.tracers.noop_tracer import NoopTracer +from opencensus.ext.threading.trace import wrap_submit, wrap_apply_async + + +class TestNoopTracer(unittest.TestCase): + """ + In case no OpenCensus context is present (i.e. we have a NoopTracer), do _not_ pass down tracer in apply_async + and submit; instead invoke function directly. + """ + + @patch("opencensus.ext.threading.trace.wrap_task_func") + @patch("opencensus.trace.execution_context.get_opencensus_tracer") + def test_noop_tracer_apply_async( + self, get_opencensus_tracer_mock: MagicMock, wrap_task_func_mock: MagicMock + ): + mock_tracer = NoopTracer() + get_opencensus_tracer_mock.return_value = mock_tracer + submission_function_mock = MagicMock() + original_function_mock = MagicMock() + + wrap_apply_async(submission_function_mock)(None, original_function_mock) + + # check whether invocation of original function _has_ happened + submission_function_mock.assert_called_once_with( + None, original_function_mock, args=(), kwds={} + ) + + # ensure that the function has _not_ been wrapped + wrap_task_func_mock.assert_not_called() + + @patch("opencensus.ext.threading.trace.wrap_task_func") + @patch("opencensus.trace.execution_context.get_opencensus_tracer") + def test_noop_tracer_wrap_submit( + self, get_opencensus_tracer_mock: MagicMock, wrap_task_func_mock: MagicMock + ): + mock_tracer = NoopTracer() + get_opencensus_tracer_mock.return_value = mock_tracer + submission_function_mock = MagicMock() + original_function_mock = MagicMock() + + wrap_submit(submission_function_mock)(None, original_function_mock) + + # check whether invocation of original function _has_ happened + submission_function_mock.assert_called_once_with(None, original_function_mock) + + # ensure that the function has _not_ been wrapped + wrap_task_func_mock.assert_not_called() diff --git a/contrib/opencensus-ext-threading/tests/test_tracer.py b/contrib/opencensus-ext-threading/tests/test_tracer.py new file mode 100644 index 000000000..7396ba6a1 --- /dev/null +++ b/contrib/opencensus-ext-threading/tests/test_tracer.py @@ -0,0 +1,72 @@ +import unittest +from unittest.mock import patch, MagicMock +from opencensus.ext.threading.trace import wrap_submit, wrap_apply_async + + +class TestTracer(unittest.TestCase): + """ + Ensures that sampler, exporter, propagator are passed through + in case global tracer is present. + """ + + @patch("opencensus.trace.propagation.binary_format.BinaryFormatPropagator") + @patch("opencensus.ext.threading.trace.wrap_task_func") + @patch("opencensus.trace.execution_context.get_opencensus_tracer") + def test_apply_async_context_passed( + self, + get_opencensus_tracer_mock: MagicMock, + wrap_task_func_mock: MagicMock, + binary_format_propagator_mock: MagicMock, + ): + mock_tracer = NoNoopTracerMock() + # ensure that unique object is generated + mock_tracer.sampler = MagicMock() + mock_tracer.exporter = MagicMock() + mock_tracer.propagator = MagicMock() + + get_opencensus_tracer_mock.return_value = mock_tracer + + submission_function_mock = MagicMock() + original_function_mock = MagicMock() + + wrap_apply_async(submission_function_mock)(None, original_function_mock) + + # check whether invocation of original function _has_ happened + call = submission_function_mock.call_args_list[0].kwargs + + self.assertEqual(id(call["kwds"]["sampler"]), id(mock_tracer.sampler)) + self.assertEqual(id(call["kwds"]["exporter"]), id(mock_tracer.exporter)) + self.assertEqual(id(call["kwds"]["propagator"]), id(mock_tracer.propagator)) + + @patch("opencensus.trace.propagation.binary_format.BinaryFormatPropagator") + @patch("opencensus.ext.threading.trace.wrap_task_func") + @patch("opencensus.trace.execution_context.get_opencensus_tracer") + def test_wrap_submit_context_passed( + self, + get_opencensus_tracer_mock: MagicMock, + wrap_task_func_mock: MagicMock, + binary_format_propagator_mock: MagicMock, + ): + mock_tracer = NoNoopTracerMock() + # ensure that unique object is generated + mock_tracer.sampler = MagicMock() + mock_tracer.exporter = MagicMock() + mock_tracer.propagator = MagicMock() + + get_opencensus_tracer_mock.return_value = mock_tracer + + submission_function_mock = MagicMock() + original_function_mock = MagicMock() + + wrap_submit(submission_function_mock)(None, original_function_mock) + + # check whether invocation of original function _has_ happened + call = submission_function_mock.call_args_list[0].kwargs + + self.assertEqual(id(call["sampler"]), id(mock_tracer.sampler)) + self.assertEqual(id(call["exporter"]), id(mock_tracer.exporter)) + self.assertEqual(id(call["propagator"]), id(mock_tracer.propagator)) + + +class NoNoopTracerMock(MagicMock): + pass