diff --git a/docs/libraries.rst b/docs/libraries.rst index 7a34af9a..b4b03671 100644 --- a/docs/libraries.rst +++ b/docs/libraries.rst @@ -221,6 +221,16 @@ In the background, this action uses `wait_for_data() None: + rclpy.init() + self.parser = OpenScenario2Parser(Logger('test', False)) + self.scenario_execution = ScenarioExecution(debug=False, + log_model=False, + live_tree=False, + scenario_file='test', + output_dir='') + + def tearDown(self): + rclpy.try_shutdown() + + def test_success(self): + scenario_content = """ +import osc.standard.base + +struct test: + def get_float(min_val: float, max_val: float) -> float is external scenario_execution.external_methods.random.get_float() + +scenario test_param: + do serial: + wait elapsed(test.get_float(min_val: 1, max_val: 3)) +""" + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + scenarios = create_py_tree(model, self.parser.logger, False) + self.scenario_execution.scenarios = scenarios + self.scenario_execution.run() + self.assertTrue(self.scenario_execution.process_results()) + + def test_not_external(self): + scenario_content = """ +import osc.standard.base + +struct test: + def get_float(min_val: float, max_val: float) -> float is undefined + +scenario test_param: + do serial: + wait elapsed(test.get_float(min_val: 1, max_val: 3)) +""" + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + self.assertRaises(ValueError, self.parser.create_internal_model, parsed_tree, "test.osc") + + def test_defined_parameters_differ(self): + scenario_content = """ +import osc.standard.base + +struct test: + def get_float(min_val: float) -> float is external scenario_execution.external_methods.random.get_float() + +scenario test_param: + do serial: + wait elapsed(test.get_float(min_val: 1)) +""" + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + self.parser.create_internal_model(parsed_tree, "test.osc", False) + self.assertFalse(self.scenario_execution.process_results()) + + def test_call_parameters_differ(self): + scenario_content = """ +import osc.standard.base + +struct test: + def get_float(min_val: float, max_val: float) -> float is external scenario_execution.external_methods.random.get_float() + +scenario test_param: + do serial: + wait elapsed(test.get_float(min_val: 1)) +""" + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + self.parser.create_internal_model(parsed_tree, "test.osc", False) + self.assertFalse(self.scenario_execution.process_results()) + + def test_call_parameters_differ_2(self): + scenario_content = """ +import osc.standard.base + +struct test: + def get_float(min_val: float, max_val: float) -> float is external scenario_execution.external_methods.random.get_float() + +scenario test_param: + do serial: + wait elapsed(test.get_float(min_val: 1, UNKNOWN: 3)) +""" + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + self.parser.create_internal_model(parsed_tree, "test.osc", False) + self.assertFalse(self.scenario_execution.process_results()) + + def test_call_parameters_differ_3(self): + scenario_content = """ +import osc.standard.base + +struct test: + def get_float(min_val: float) -> float is external scenario_execution.external_methods.random.get_float() + +scenario test_param: + do serial: + wait elapsed(test.get_float(min_val: 1, UNKNOWN: 3)) +""" + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + self.parser.create_internal_model(parsed_tree, "test.osc", False) + self.assertFalse(self.scenario_execution.process_results()) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_message_count.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_message_count.py new file mode 100644 index 00000000..01d2dbc4 --- /dev/null +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_message_count.py @@ -0,0 +1,79 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import py_trees +import rclpy +from py_trees.common import Status +from scenario_execution_ros.actions.conversions import get_qos_preset_profile +import importlib + + +class RosWaitForMessageCount(py_trees.behaviour.Behaviour): + """ + Class for waiting until a certain amount of messages was received + """ + + def __init__(self, name, topic_name: str, topic_type: str, qos_profile, count: int): + super().__init__(name) + + self.subscriber = None + self.node = None + self.topic_name = topic_name + self.expected_count = count + self.current_count = 0 + self.last_count_reported = 0 + + datatype_in_list = topic_type.split(".") + self.topic_type = getattr( + importlib.import_module(".".join(datatype_in_list[0:-1])), + datatype_in_list[-1] + ) + + self.qos_profile = get_qos_preset_profile(qos_profile) + self.qos_profile.depth = self.expected_count + + def setup(self, **kwargs): + """ + Setup the subscriber + """ + try: + self.node: rclpy.Node = kwargs['node'] + except KeyError as e: + error_message = "didn't find 'node' in setup's kwargs [{}][{}]".format( + self.name, self.__class__.__name__) + raise KeyError(error_message) from e + + self.subscriber = self.node.create_subscription( # pylint: disable= attribute-defined-outside-init + msg_type=self.topic_type, + topic=self.topic_name, + callback=self._callback, + qos_profile=self.qos_profile, + callback_group=rclpy.callback_groups.MutuallyExclusiveCallbackGroup() + ) + self.feedback_message = f"Waiting for messages on {self.topic_name}" # pylint: disable= attribute-defined-outside-init + + def update(self) -> py_trees.common.Status: + if self.current_count >= self.expected_count: + self.feedback_message = f"Received expected ({self.expected_count}) messages." # pylint: disable= attribute-defined-outside-init + return Status.SUCCESS + else: + if self.last_count_reported != self.current_count: + self.feedback_message = f"Received {self.current_count} of expected {self.expected_count}..." # pylint: disable= attribute-defined-outside-init + self.last_count_reported = self.current_count + return Status.RUNNING + + def _callback(self, _): + self.current_count += 1 diff --git a/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc b/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc index f45da696..7a9217db 100644 --- a/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc +++ b/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc @@ -144,6 +144,13 @@ action wait_for_data: qos_profile: qos_preset_profiles = qos_preset_profiles!system_default # qos profile for the subscriber clearing_policy: clearing_policy = clearing_policy!on_initialise # when to clear the data +action wait_for_message_count: + # Wait until a certain amount of messages was received. + topic_name: string # name of the topic to connect to + topic_type: string # class of the message type (e.g. std_msgs.msg.String) + qos_profile: qos_preset_profiles = qos_preset_profiles!system_default # qos profile for the subscriber + count: uint # amount of messages to wait for + action wait_for_topics: # wait for topics to get available topics: list of topics diff --git a/scenario_execution_ros/setup.py b/scenario_execution_ros/setup.py index 13b2dfb6..06fd743b 100644 --- a/scenario_execution_ros/setup.py +++ b/scenario_execution_ros/setup.py @@ -64,6 +64,7 @@ 'assert_topic_latency = scenario_execution_ros.actions.assert_topic_latency:AssertTopicLatency', 'assert_tf_moving = scenario_execution_ros.actions.assert_tf_moving:AssertTfMoving', 'assert_lifecycle_state = scenario_execution_ros.actions.assert_lifecycle_state:AssertLifecycleState', + 'wait_for_message_count = scenario_execution_ros.actions.ros_topic_wait_for_message_count:RosWaitForMessageCount', ], 'scenario_execution.osc_libraries': [ 'ros = ' diff --git a/scenario_execution_ros/test/test_wait_for_message_count.py b/scenario_execution_ros/test/test_wait_for_message_count.py new file mode 100644 index 00000000..66b93fbe --- /dev/null +++ b/scenario_execution_ros/test/test_wait_for_message_count.py @@ -0,0 +1,101 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from ament_index_python.packages import get_package_share_directory + +import rclpy + +from scenario_execution_ros import ROSScenarioExecution +from scenario_execution.model.osc2_parser import OpenScenario2Parser +from scenario_execution.model.model_to_py_tree import create_py_tree +from scenario_execution.utils.logging import Logger + +from antlr4.InputStream import InputStream + + +class TestWaitForMessageCount(unittest.TestCase): + # pylint: disable=missing-function-docstring,missing-class-docstring + + def setUp(self) -> None: + rclpy.init() + self.parser = OpenScenario2Parser(Logger('test', False)) + self.scenario_execution_ros = ROSScenarioExecution() + + self.scenario_dir = get_package_share_directory('scenario_execution_ros') + + def tearDown(self): + rclpy.try_shutdown() + + def test_sucess(self): + scenario_content = """ +import osc.ros + +scenario test: + do parallel: + test: serial: + wait elapsed(1s) + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + receive: serial: + wait_for_message_count( + topic_name: '/bla', + topic_type: 'std_msgs.msg.Bool', + count: 5) + emit end + time_out: serial: + wait elapsed(10s) + emit fail +""" + + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + scenarios = create_py_tree(model, self.parser.logger, False) + self.scenario_execution_ros.scenarios = scenarios + self.scenario_execution_ros.run() + self.assertTrue(self.scenario_execution_ros.process_results()) + + def test_failure_1(self): + scenario_content = """ +import osc.ros + +scenario test: + do parallel: + test: serial: + wait elapsed(1s) + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + topic_publish('/bla', 'std_msgs.msg.Bool', value: '{\\\"data\\\": True}') + receive: serial: + wait_for_message_count( + topic_name: '/bla', + topic_type: 'std_msgs.msg.Bool', + count: 5) + emit end + time_out: serial: + wait elapsed(5s) + emit fail +""" + + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + scenarios = create_py_tree(model, self.parser.logger, False) + self.scenario_execution_ros.scenarios = scenarios + self.scenario_execution_ros.run() + self.assertFalse(self.scenario_execution_ros.process_results())