diff --git a/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako b/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako index bea48c97b..a8dad7e94 100644 --- a/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako +++ b/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako @@ -62,11 +62,14 @@ class ${class_name}: grpc_channel_pool: An optional gRPC channel pool. """ - self._initialization_lock = threading.Lock() + self._initialization_lock = threading.RLock() self._service_class = ${service_class | repr} self._grpc_channel_pool = grpc_channel_pool self._discovery_client = discovery_client self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None + self._measure_response: Optional[ + Generator[v2_measurement_service_pb2.MeasureResponse, None, None] + ] = None self._configuration_metadata = ${configuration_metadata} self._output_metadata = ${output_metadata} if grpc_channel is not None: @@ -142,6 +145,7 @@ class ${class_name}: configuration_parameters=serialized_configuration ) + % if output_metadata: def _deserialize_response( self, response: v2_measurement_service_pb2.MeasureResponse ) -> Outputs: @@ -157,6 +161,7 @@ class ${class_name}: result[k - 1] = v return Outputs._make(result) + % endif def measure( self, ${configuration_parameters_with_type_and_default_values} @@ -166,12 +171,16 @@ class ${class_name}: Returns: Measurement outputs. """ - parameter_values = [${measure_api_parameters}] - request = self._create_measure_request(parameter_values) - - for response in self._get_stub().Measure(request): - result = self._deserialize_response(response) + stream_measure_response = self.stream_measure( + ${measure_api_parameters} + ) + for response in stream_measure_response: + % if output_metadata: + result = response return result + % else: + pass + % endif def stream_measure( self, @@ -183,7 +192,33 @@ class ${class_name}: Stream of measurement outputs. """ parameter_values = [${measure_api_parameters}] - request = self._create_measure_request(parameter_values) - - for response in self._get_stub().Measure(request): - yield self._deserialize_response(response) + with self._initialization_lock: + if self._measure_response is not None: + raise RuntimeError( + "A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance." + ) + request = self._create_measure_request(parameter_values) + self._measure_response = self._get_stub().Measure(request) + + try: + for response in self._measure_response: + % if output_metadata: + yield self._deserialize_response(response) + % else: + yield + % endif + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.CANCELLED: + _logger.debug("The measurement is canceled.") + raise + finally: + with self._initialization_lock: + self._measure_response = None + + def cancel(self) -> bool: + """Cancels the active measurement call.""" + with self._initialization_lock: + if self._measure_response: + return self._measure_response.cancel() + else: + return False diff --git a/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py b/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py index 5d085f366..12ea1f509 100644 --- a/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py +++ b/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py @@ -15,7 +15,7 @@ def test___measurement_plugin_client___measure___returns_output( measurement_plugin_client_module: ModuleType, ) -> None: test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") - output_type = getattr(measurement_plugin_client_module, "Output") + output_type = getattr(measurement_plugin_client_module, "Outputs") expected_output = output_type( float_out=0.05999999865889549, double_array_out=[0.1, 0.2, 0.3], @@ -40,7 +40,7 @@ def test___measurement_plugin_client___stream_measure___returns_output( measurement_plugin_client_module: ModuleType, ) -> None: test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") - output_type = getattr(measurement_plugin_client_module, "Output") + output_type = getattr(measurement_plugin_client_module, "Outputs") expected_output = output_type( float_out=0.05999999865889549, double_array_out=[0.1, 0.2, 0.3], diff --git a/packages/generator/tests/acceptance/test_streaming_measurement_client.py b/packages/generator/tests/acceptance/test_streaming_measurement_client.py index ad08b084b..0dfa194f1 100644 --- a/packages/generator/tests/acceptance/test_streaming_measurement_client.py +++ b/packages/generator/tests/acceptance/test_streaming_measurement_client.py @@ -1,8 +1,10 @@ +import concurrent.futures import importlib.util import pathlib from types import ModuleType from typing import Generator +import grpc import pytest from ni_measurement_plugin_sdk_service.measurement.service import MeasurementService @@ -15,7 +17,7 @@ def test___measurement_plugin_client___measure___returns_output( measurement_plugin_client_module: ModuleType, ) -> None: test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") - output_type = getattr(measurement_plugin_client_module, "Output") + output_type = getattr(measurement_plugin_client_module, "Outputs") expected_output = output_type( name="", index=9, @@ -32,7 +34,7 @@ def test___measurement_plugin_client___stream_measure___returns_output( measurement_plugin_client_module: ModuleType, ) -> None: test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") - output_type = getattr(measurement_plugin_client_module, "Output") + output_type = getattr(measurement_plugin_client_module, "Outputs") measurement_plugin_client = test_measurement_client_type() response_iterator = measurement_plugin_client.stream_measure() @@ -49,6 +51,64 @@ def test___measurement_plugin_client___stream_measure___returns_output( assert responses == expected_output +def test___measurement_plugin_client___invoke_measure_from_two_threads___initiates_first_measure_and_rejects_second_measure( + measurement_plugin_client_module: ModuleType, +) -> None: + test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") + measurement_plugin_client = test_measurement_client_type() + + with pytest.raises(RuntimeError) as exc_info: + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + future_measure_1 = executor.submit(measurement_plugin_client.measure) + future_measure_2 = executor.submit(measurement_plugin_client.measure) + future_measure_1.result() + future_measure_2.result() + + expected_error_message = "A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance." + assert expected_error_message in exc_info.value.args[0] + + +def test___non_streaming_measurement_execution___cancel___cancels_measurement( + measurement_plugin_client_module: ModuleType, +) -> None: + test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") + measurement_plugin_client = test_measurement_client_type() + + with pytest.raises(grpc.RpcError) as exc_info: + with concurrent.futures.ThreadPoolExecutor() as executor: + measure = executor.submit(measurement_plugin_client.measure) + measurement_plugin_client.cancel() + measure.result() + + assert exc_info.value.code() == grpc.StatusCode.CANCELLED + + +def test___streaming_measurement_execution___cancel___cancels_measurement( + measurement_plugin_client_module: ModuleType, +) -> None: + test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") + measurement_plugin_client = test_measurement_client_type() + + with pytest.raises(grpc.RpcError) as exc_info: + with concurrent.futures.ThreadPoolExecutor() as executor: + measure = executor.submit(lambda: list(measurement_plugin_client.stream_measure())) + measurement_plugin_client.cancel() + measure.result() + + assert exc_info.value.code() == grpc.StatusCode.CANCELLED + + +def test___measurement_client___cancel_without_measure___returns_false( + measurement_plugin_client_module: ModuleType, +) -> None: + test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") + measurement_plugin_client = test_measurement_client_type() + + is_canceled = measurement_plugin_client.cancel() + + assert not is_canceled + + @pytest.fixture(scope="module") def measurement_client_directory( tmp_path_factory: pytest.TempPathFactory, diff --git a/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py b/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py index 3a568508d..0bd336e8f 100644 --- a/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py +++ b/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py @@ -1,4 +1,4 @@ -"""Python Measurement Plug-In Client.""" +"""Generated client API for the 'Non-Streaming Data Measurement (Py)' measurement plug-in.""" import logging import threading @@ -29,8 +29,8 @@ _V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService" -class Output(NamedTuple): - """Measurement result container.""" +class Outputs(NamedTuple): + """Outputs for the 'Non-Streaming Data Measurement (Py)' measurement plug-in.""" float_out: float double_array_out: List[float] @@ -46,7 +46,7 @@ class Output(NamedTuple): class NonStreamingDataMeasurementClient: - """Client to interact with the measurement plug-in.""" + """Client for the 'Non-Streaming Data Measurement (Py)' measurement plug-in.""" def __init__( self, @@ -64,11 +64,14 @@ def __init__( grpc_channel_pool: An optional gRPC channel pool. """ - self._initialization_lock = threading.Lock() + self._initialization_lock = threading.RLock() self._service_class = "ni.tests.NonStreamingDataMeasurement_Python" self._grpc_channel_pool = grpc_channel_pool self._discovery_client = discovery_client self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None + self._measure_response: Optional[ + Generator[v2_measurement_service_pb2.MeasureResponse, None, None] + ] = None self._configuration_metadata = { 1: ParameterMetadata( display_name="Float In", @@ -366,7 +369,9 @@ def _create_measure_request( configuration_parameters=serialized_configuration ) - def _deserialize_response(self, response: v2_measurement_service_pb2.MeasureResponse) -> Output: + def _deserialize_response( + self, response: v2_measurement_service_pb2.MeasureResponse + ) -> Outputs: if self._output_metadata: result = [None] * max(self._output_metadata.keys()) else: @@ -377,7 +382,7 @@ def _deserialize_response(self, response: v2_measurement_service_pb2.MeasureResp for k, v in output_values.items(): result[k - 1] = v - return Output._make(result) + return Outputs._make(result) def measure( self, @@ -391,13 +396,13 @@ def measure( io_in: str = "resource", io_array_in: List[str] = ["resource1", "resource2"], integer_in: int = 10, - ) -> Output: - """Executes the Non-Streaming Data Measurement (Py). + ) -> Outputs: + """Perform a single measurement. Returns: - Measurement output. + Measurement outputs. """ - parameter_values = [ + stream_measure_response = self.stream_measure( float_in, double_array_in, bool_in, @@ -408,11 +413,9 @@ def measure( io_in, io_array_in, integer_in, - ] - request = self._create_measure_request(parameter_values) - - for response in self._get_stub().Measure(request): - result = self._deserialize_response(response) + ) + for response in stream_measure_response: + result = response return result def stream_measure( @@ -427,11 +430,11 @@ def stream_measure( io_in: str = "resource", io_array_in: List[str] = ["resource1", "resource2"], integer_in: int = 10, - ) -> Generator[Output, None, None]: - """Executes the Non-Streaming Data Measurement (Py). + ) -> Generator[Outputs, None, None]: + """Perform a streaming measurement. Returns: - Stream of measurement output. + Stream of measurement outputs. """ parameter_values = [ float_in, @@ -445,7 +448,28 @@ def stream_measure( io_array_in, integer_in, ] - request = self._create_measure_request(parameter_values) + with self._initialization_lock: + if self._measure_response is not None: + raise RuntimeError( + "A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance." + ) + request = self._create_measure_request(parameter_values) + self._measure_response = self._get_stub().Measure(request) + try: + for response in self._measure_response: + yield self._deserialize_response(response) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.CANCELLED: + _logger.debug("The measurement is canceled.") + raise + finally: + with self._initialization_lock: + self._measure_response = None - for response in self._get_stub().Measure(request): - yield self._deserialize_response(response) + def cancel(self) -> bool: + """Cancels the active measurement call.""" + with self._initialization_lock: + if self._measure_response: + return self._measure_response.cancel() + else: + return False