diff --git a/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py b/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py index 978f683c1..87b39c40e 100644 --- a/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py +++ b/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py @@ -400,6 +400,15 @@ def init_extra_arg_cases(self) -> list: def define_invalid_arg_cases(self) -> None: super().define_invalid_arg_cases() self.define_rejected_missing_eventlogs() + self.rejected['Missing Platform argument'] = { + 'valid': False, + 'callable': partial(self.raise_validation_exception, + 'Cannot run tool cmd without platform argument. Re-run the command ' + 'providing the platform argument.'), + 'cases': [ + [ArgValueCase.UNDEFINED, ArgValueCase.IGNORE, ArgValueCase.IGNORE] + ] + } self.rejected['Cluster By Name Without Platform Hints'] = { 'valid': False, 'callable': partial(self.raise_validation_exception, diff --git a/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py b/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py index 300751242..ff4bfff35 100644 --- a/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py +++ b/user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py @@ -131,13 +131,9 @@ def test_with_platform_with_eventlogs(self, get_ut_data_dir, tool_name, csp): cost_savings_enabled=False, expected_platform=csp) - # should pass: platform not provided; event logs are provided - tool_args = self.create_tool_args_should_pass(tool_name, - eventlogs=f'{get_ut_data_dir}/eventlogs') - # for qualification, cost savings should be disabled because cluster is not provided - self.validate_tool_args(tool_name=tool_name, tool_args=tool_args, - cost_savings_enabled=False, - expected_platform=CspEnv.ONPREM) + # should fail: platform must be provided + self.create_tool_args_should_fail(tool_name, + eventlogs=f'{get_ut_data_dir}/eventlogs') @pytest.mark.parametrize('tool_name', ['qualification', 'profiling']) @pytest.mark.parametrize('csp', all_csps) @@ -150,17 +146,19 @@ def test_with_platform_with_eventlogs_with_jar_files(self, get_ut_data_dir, tool tools_jar=f'{get_ut_data_dir}/tools_mock.jar') assert tool_args['toolsJar'] == f'{get_ut_data_dir}/tools_mock.jar' - # should pass: tools_jar is correct - tool_args = self.create_tool_args_should_pass(tool_name, eventlogs=f'{get_ut_data_dir}/eventlogs', - tools_jar=f'{get_ut_data_dir}/tools_mock.jar') - assert tool_args['toolsJar'] == f'{get_ut_data_dir}/tools_mock.jar' + # should fail: platform must be provided + self.create_tool_args_should_fail(tool_name, + eventlogs=f'{get_ut_data_dir}/eventlogs', + tools_jar=f'{get_ut_data_dir}/tools_mock.jar') # should fail: tools_jar does not exist - self.create_tool_args_should_fail(tool_name, eventlogs=f'{get_ut_data_dir}/eventlogs', + self.create_tool_args_should_fail(tool_name, platform=csp, + eventlogs=f'{get_ut_data_dir}/eventlogs', tools_jar=f'{get_ut_data_dir}/tools_mock.txt') # should fail: tools_jar is not .jar extension - self.create_tool_args_should_fail(tool_name, eventlogs=f'{get_ut_data_dir}/eventlogs', + self.create_tool_args_should_fail(tool_name, platform=csp, + eventlogs=f'{get_ut_data_dir}/eventlogs', tools_jar=f'{get_ut_data_dir}/worker_info.yaml') @pytest.mark.parametrize('tool_name', ['qualification', 'profiling']) @@ -230,25 +228,15 @@ def test_with_platform_with_cluster_props(self, get_ut_data_dir, tool_name, csp, self.validate_tool_args(tool_name=tool_name, tool_args=tool_args, cost_savings_enabled=True, expected_platform=csp) - - # should pass: platform not provided; missing eventlogs should be accepted for all CSPs (except onPrem) - # because the eventlogs can be retrieved from the cluster properties - tool_args = self.create_tool_args_should_pass(tool_name, - cluster=cluster_prop_file) - # for qualification, cost savings should be enabled because cluster is provided - self.validate_tool_args(tool_name=tool_name, tool_args=tool_args, - cost_savings_enabled=True, - expected_platform=csp) else: # should fail: onprem platform cannot retrieve eventlogs from cluster properties self.create_tool_args_should_fail(tool_name, platform=csp, cluster=cluster_prop_file) - # should fail: platform not provided; defaults platform to onprem, cannot retrieve eventlogs from - # cluster properties - self.create_tool_args_should_fail(tool_name, - cluster=cluster_prop_file) + # should fail: platform must be provided for all CSPs as well as onprem + self.create_tool_args_should_fail(tool_name, + cluster=cluster_prop_file) @pytest.mark.parametrize('tool_name', ['qualification', 'profiling']) @pytest.mark.parametrize('csp,prop_path', all_cpu_cluster_props) @@ -266,14 +254,10 @@ def test_with_platform_with_cluster_props_with_eventlogs(self, get_ut_data_dir, cost_savings_enabled=CspEnv(csp) != CspEnv.ONPREM, expected_platform=csp) - # should pass: platform not provided; cluster properties and eventlogs are provided - tool_args = self.create_tool_args_should_pass(tool_name, - cluster=cluster_prop_file, - eventlogs=f'{get_ut_data_dir}/eventlogs') - # for qualification, cost savings should be enabled because cluster is provided (except for onprem) - self.validate_tool_args(tool_name=tool_name, tool_args=tool_args, - cost_savings_enabled=CspEnv(csp) != CspEnv.ONPREM, - expected_platform=csp) + # should fail: platform must be provided + self.create_tool_args_should_fail(tool_name, + cluster=cluster_prop_file, + eventlogs=f'{get_ut_data_dir}/eventlogs') @pytest.mark.parametrize('tool_name', ['profiling']) @pytest.mark.parametrize('csp', all_csps) @@ -308,18 +292,15 @@ def test_with_platform_with_autotuner_with_eventlogs(self, get_ut_data_dir, tool cost_savings_enabled=False, expected_platform=csp) - # should pass: platform not provided; autotuner properties and eventlogs are provided - tool_args = self.create_tool_args_should_pass(tool_name, - cluster=autotuner_prop_file, - eventlogs=f'{get_ut_data_dir}/eventlogs') - # cost savings should be disabled for profiling - self.validate_tool_args(tool_name=tool_name, tool_args=tool_args, - cost_savings_enabled=False, - expected_platform=CspEnv.ONPREM) + # should fail: platform must be provided + self.create_tool_args_should_fail(tool_name, + cluster=autotuner_prop_file, + eventlogs=f'{get_ut_data_dir}/eventlogs') @pytest.mark.parametrize('prop_path', [autotuner_prop_path]) def test_profiler_with_driverlog(self, get_ut_data_dir, prop_path): prof_args = AbsToolUserArgModel.create_tool_args('profiling', + platform=CspEnv.get_default(), driverlog=f'{get_ut_data_dir}/{prop_path}') assert not prof_args['requiresEventlogs'] assert prof_args['rapidOptions']['driverlog'] == f'{get_ut_data_dir}/{prop_path}'