Skip to content

Commit

Permalink
initial logic update
Browse files Browse the repository at this point in the history
Signed-off-by: Shashank Mittal <[email protected]>
  • Loading branch information
shashank-iitbhu committed Jan 30, 2025
1 parent c9ff7d8 commit f263b26
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions pkg/suggestion/v1beta1/optuna/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import optuna

from pkg.apis.manager.v1beta1.python import api_pb2
from pkg.suggestion.v1beta1.internal.constant import (
CATEGORICAL,
DISCRETE,
Expand Down Expand Up @@ -110,13 +111,37 @@ def _get_optuna_search_space(self):
search_space = {}
for param in self.search_space.params:
if param.type == INTEGER:
search_space[param.name] = optuna.distributions.IntDistribution(
int(param.min), int(param.max)
)
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
if param.step:
search_space[param.name] = optuna.distributions.IntDistribution(
int(param.min), int(param.max), False, param.step
)
else:
search_space[param.name] = optuna.distributions.IntDistribution(
int(param.min), int(param.max)
)
if param.distribution == api_pb2.LOG_UNIFORM:
search_space[param.name] = optuna.distributions.IntDistribution(
int(param.min), int(param.max), True, param.step
)
elif param.type == DOUBLE:
search_space[param.name] = optuna.distributions.FloatDistribution(
float(param.min), float(param.max)
)
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
if param.step:
search_space[param.name] = (
optuna.distributions.FloatDistribution(
int(param.min), int(param.max), False, param.step
)
)
else:
search_space[param.name] = (
optuna.distributions.FloatDistribution(
int(param.min), int(param.max)
)
)
if param.distribution == api_pb2.LOG_UNIFORM:
search_space[param.name] = optuna.distributions.FloatDistribution(
int(param.min), int(param.max), True, param.step
)
elif param.type == CATEGORICAL or param.type == DISCRETE:
search_space[param.name] = optuna.distributions.CategoricalDistribution(
param.list
Expand Down

0 comments on commit f263b26

Please sign in to comment.