diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index 9a3c90ea84d..ceacb8df80f 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -283,6 +283,30 @@ def tune( if max_failed_trial_count is not None: experiment.spec.max_failed_trial_count = max_failed_trial_count + # Iterate over input parameters. + input_params = {} + experiment_params = [] + trial_params = [] + base_image = constants.BASE_IMAGE_TENSORFLOW, + + for p_name, p_value in parameters.items(): + # If input parameter value is Katib Experiment parameter sample. + if isinstance(p_value, models.V1beta1ParameterSpec): + # Wrap value for the function input. + input_params[p_name] = f"${{trialParameters.{p_name}}}" + + # Add value to the Katib Experiment parameters. + p_value.name = p_name + experiment_params.append(p_value) + + # Add value to the Katib Experiment's Trial parameters. + trial_params.append( + models.V1beta1TrialParameterSpec(name=p_name, reference=p_name) + ) + else: + # Otherwise, add value to the function input. + input_params[p_name] = p_value + # Handle different types of objective input if callable(objective): # Validate objective function. @@ -295,29 +319,6 @@ def tune( # (e.g. in another function). We need to dedent the function code. objective_code = textwrap.dedent(objective_code) - # Iterate over input parameters. - input_params = {} - experiment_params = [] - trial_params = [] - base_image = constants.BASE_IMAGE_TENSORFLOW, - for p_name, p_value in parameters.items(): - # If input parameter value is Katib Experiment parameter sample. - if isinstance(p_value, models.V1beta1ParameterSpec): - # Wrap value for the function input. - input_params[p_name] = f"${{trialParameters.{p_name}}}" - - # Add value to the Katib Experiment parameters. - p_value.name = p_name - experiment_params.append(p_value) - - # Add value to the Katib Experiment's Trial parameters. - trial_params.append( - models.V1beta1TrialParameterSpec(name=p_name, reference=p_name) - ) - else: - # Otherwise, add value to the function input. - input_params[p_name] = p_value - # Wrap objective function to execute it from the file. For example # def objective(parameters): # print(f'Parameters are {parameters}') @@ -407,12 +408,12 @@ def tune( trial_template = models.V1beta1TrialTemplate( primary_container_name=constants.DEFAULT_PRIMARY_CONTAINER_NAME, retain=retain_trials, - trial_parameters=trial_params if callable(objective) else [], + trial_parameters=trial_params, trial_spec=trial_spec, ) # Add parameters to the Katib Experiment. - experiment.spec.parameters = experiment_params if callable(objective) else [] + experiment.spec.parameters = experiment_params # Add Trial template to the Katib Experiment. experiment.spec.trial_template = trial_template