diff --git a/ros2param/ros2param/api/__init__.py b/ros2param/ros2param/api/__init__.py index ef03a2456..60f46e7e0 100644 --- a/ros2param/ros2param/api/__init__.py +++ b/ros2param/ros2param/api/__init__.py @@ -50,6 +50,10 @@ def load_parameter_file(*, node, node_name, parameter_file, use_wildcard): parameters = list(parameter_dict_from_yaml_file(parameter_file, use_wildcard).values()) rclpy.spin_until_future_complete(node, future) response = future.result() + if response is None: + raise RuntimeError('Exception while calling service of node ' + f'{node_name}: {future.exception()}') + assert len(response.results) == len(parameters), 'Not all parameters set' for i in range(0, len(response.results)): result = response.results[i] @@ -66,6 +70,26 @@ def load_parameter_file(*, node, node_name, parameter_file, use_wildcard): print(msg, file=sys.stderr) +def load_parameter_file_atomically(*, node, node_name, parameter_file, use_wildcard): + client = AsyncParameterClient(node, node_name) + ready = client.wait_for_services(timeout_sec=5.0) + if not ready: + raise RuntimeError('Wait for service timed out') + future = client.load_parameter_file_atomically(parameter_file, use_wildcard) + parameters = list(parameter_dict_from_yaml_file(parameter_file, use_wildcard).values()) + rclpy.spin_until_future_complete(node, future) + response = future.result() + if response is None: + raise RuntimeError('Exception while calling service of node ' + f'{node_name}: {future.exception()}') + + if response.result.successful: + msg = 'Set parameters {} successful'.format(' '.join([i.name for i in parameters])) + if response.result.reason: + msg += ': ' + response.result.reason + print(msg) + + def call_describe_parameters(*, node, node_name, parameter_names=None): client = AsyncParameterClient(node, node_name) ready = client.wait_for_services(timeout_sec=5.0) @@ -96,6 +120,18 @@ def call_get_parameters(*, node, node_name, parameter_names): return response +def call_set_parameters_atomically(*, node, node_name, parameters): + client = AsyncParameterClient(node, node_name) + client.wait_for_services(timeout_sec=5.0) + future = client.set_parameters_atomically(parameters) + rclpy.spin_until_future_complete(node, future) + response = future.result() + if response is None: + raise RuntimeError('Exception while calling service of node ' + f'{node_name}: {future.exception()}') + return response + + def call_set_parameters(*, node, node_name, parameters): client = AsyncParameterClient(node, node_name) ready = client.wait_for_services(timeout_sec=5.0) diff --git a/ros2param/ros2param/verb/load.py b/ros2param/ros2param/verb/load.py index e260f81f6..56041aa28 100644 --- a/ros2param/ros2param/verb/load.py +++ b/ros2param/ros2param/verb/load.py @@ -19,6 +19,7 @@ from ros2node.api import get_node_names from ros2node.api import NodeNameCompleter from ros2param.api import load_parameter_file +from ros2param.api import load_parameter_file_atomically from ros2param.verb import VerbExtension @@ -39,6 +40,9 @@ def add_arguments(self, parser, cli_name): # noqa: D102 parser.add_argument( '--no-use-wildcard', action='store_true', help="Do not load parameters in the '/**' namespace into the node") + parser.add_argument( + '--atomic', action='store_true', + help='Load parameters atomically') def main(self, *, args): # noqa: D102 with NodeStrategy(args) as node: @@ -50,5 +54,11 @@ def main(self, *, args): # noqa: D102 return 'Node not found' with DirectNode(args) as node: - load_parameter_file(node=node, node_name=node_name, parameter_file=args.parameter_file, - use_wildcard=not args.no_use_wildcard) + if args.atomic: + load_parameter_file_atomically(node=node, node_name=node_name, + parameter_file=args.parameter_file, + use_wildcard=not args.no_use_wildcard) + else: + load_parameter_file(node=node, node_name=node_name, + parameter_file=args.parameter_file, + use_wildcard=not args.no_use_wildcard) diff --git a/ros2param/ros2param/verb/set.py b/ros2param/ros2param/verb/set.py index 86c29fa06..dbce9dd57 100644 --- a/ros2param/ros2param/verb/set.py +++ b/ros2param/ros2param/verb/set.py @@ -24,6 +24,7 @@ from ros2node.api import NodeNameCompleter from ros2param.api import call_set_parameters +from ros2param.api import call_set_parameters_atomically from ros2param.api import ParameterNameCompleter from ros2param.verb import VerbExtension @@ -37,14 +38,33 @@ def add_arguments(self, parser, cli_name): # noqa: D102 'node_name', help='Name of the ROS node') arg.completer = NodeNameCompleter( include_hidden_nodes_key='include_hidden_nodes') + + arg = parser.add_argument( + 'parameters', nargs='*', + help='List of parameter name and value pairs i.e. "int_param 1 str_param hello_world"') + arg.completer = ParameterNameCompleter() + parser.add_argument( '--include-hidden-nodes', action='store_true', help='Consider hidden nodes as well') - arg = parser.add_argument( - 'parameter_name', help='Name of the parameter') - arg.completer = ParameterNameCompleter() + parser.add_argument( - 'value', help='Value of the parameter') + '--atomic', action='store_true', + help='Set parameters atomically') + + def build_parameters(self, params): + parameters = [] + if len(params) % 2: + raise RuntimeError('Must pass list of parameter name and value pairs') + + params = [(params[i], params[i+1]) for i in range(0, len(params), 2)] + for param_str in params: + parameter = Parameter() + parameter.name = param_str[0] + parameter.value = get_parameter_value(string_value=param_str[1]) + parameters.append(parameter) + + return parameters def main(self, *, args): # noqa: D102 with NodeStrategy(args) as node: @@ -56,23 +76,24 @@ def main(self, *, args): # noqa: D102 return 'Node not found' with DirectNode(args) as node: - parameter = Parameter() - Parameter.name = args.parameter_name - parameter.value = get_parameter_value(string_value=args.value) - - response = call_set_parameters( - node=node, node_name=args.node_name, parameters=[parameter]) - - # output response - assert len(response.results) == 1 - result = response.results[0] - if result.successful: - msg = 'Set parameter successful' - if result.reason: - msg += ': ' + result.reason - print(msg) + parameters = self.build_parameters(args.parameters) + if args.atomic: + response = call_set_parameters_atomically(node=node, node_name=args.node_name, + parameters=parameters) + results = [response.result] else: - msg = 'Setting parameter failed' - if result.reason: - msg += ': ' + result.reason - print(msg, file=sys.stderr) + response = call_set_parameters(node=node, node_name=args.node_name, + parameters=parameters) + results = response.results + + for result in results: + if result.successful: + msg = 'Set parameter successful' + if result.reason: + msg += ': ' + result.reason + print(msg) + else: + msg = 'Setting parameter failed' + if result.reason: + msg += ': ' + result.reason + print(msg, file=sys.stderr) diff --git a/ros2param/test/test_verb_load.py b/ros2param/test/test_verb_load.py index e5619399a..a8ec307cc 100644 --- a/ros2param/test/test_verb_load.py +++ b/ros2param/test/test_verb_load.py @@ -294,6 +294,31 @@ def test_verb_load(self): strict=True ) + def test_verb_load_atomic(self): + with tempfile.TemporaryDirectory() as tmpdir: + filepath = self._write_param_file(tmpdir, 'params.yaml') + with self.launch_param_load_command( + arguments=[f'{TEST_NAMESPACE}/{TEST_NODE}', filepath, '--atomic'] + ) as param_load_command: + assert param_load_command.wait_for_shutdown(timeout=TEST_TIMEOUT) + assert param_load_command.exit_code == launch_testing.asserts.EXIT_OK + assert launch_testing.tools.expect_output( + expected_lines=[''], + text=param_load_command.output, + strict=True + ) + # Dump with ros2 param dump and compare that output matches input file + with self.launch_param_dump_command( + arguments=[f'{TEST_NAMESPACE}/{TEST_NODE}'] + ) as param_dump_command: + assert param_dump_command.wait_for_shutdown(timeout=TEST_TIMEOUT) + assert param_dump_command.exit_code == launch_testing.asserts.EXIT_OK + assert launch_testing.tools.expect_output( + expected_text=INPUT_PARAMETER_FILE + '\n', + text=param_dump_command.output, + strict=True + ) + def test_verb_load_wildcard(self): with tempfile.TemporaryDirectory() as tmpdir: # Try param file with only wildcard