Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpsampler #2995

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions plugins/hydra_optuna_sweeper/NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
1.4.0 (2024-12-01)
======================

### Features

- Updated to be compatible with Optuna 3.6.0+
- Added GPSampler support for Gaussian Process based optimization ([Optuna GPSampler](https://optuna.readthedocs.io/en/stable/reference/samplers/generated/optuna.samplers.GPSampler.html))

1.2.0 (2022-05-17)
======================

Expand Down
5 changes: 5 additions & 0 deletions plugins/hydra_optuna_sweeper/example/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import hydra
from omegaconf import DictConfig

import warnings
from optuna.exceptions import ExperimentalWarning

warnings.filterwarnings("ignore", category=ExperimentalWarning)


@hydra.main(version_base=None, config_path="conf", config_name="config")
def sphere(cfg: DictConfig) -> float:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

__version__ = "1.4.0.dev0"
__version__ = "1.4.0.dev1"
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@
BaseDistribution,
CategoricalChoiceType,
CategoricalDistribution,
DiscreteUniformDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
LogUniformDistribution,
UniformDistribution,
FloatDistribution,
IntDistribution,
)
from optuna.trial import Trial

Expand All @@ -62,17 +59,17 @@ def create_optuna_distribution_from_config(
assert param.low is not None
assert param.high is not None
if param.log:
return IntLogUniformDistribution(int(param.low), int(param.high))
return IntDistribution(int(param.low), int(param.high), log=True)
step = int(param.step) if param.step is not None else 1
return IntUniformDistribution(int(param.low), int(param.high), step=step)
return IntDistribution(int(param.low), int(param.high), step=step)
if param.type == DistributionType.float:
assert param.low is not None
assert param.high is not None
if param.log:
return LogUniformDistribution(param.low, param.high)
return FloatDistribution(param.low, param.high, log=True)
if param.step is not None:
return DiscreteUniformDistribution(param.low, param.high, param.step)
return UniformDistribution(param.low, param.high)
return FloatDistribution(param.low, param.high, step=param.step)
return FloatDistribution(param.low, param.high)
raise NotImplementedError(f"{param.type} is not supported by Optuna sweeper.")


Expand Down Expand Up @@ -107,23 +104,21 @@ def create_optuna_distribution_from_override(override: Override) -> Any:
or isinstance(value.stop, float)
or isinstance(value.step, float)
):
return DiscreteUniformDistribution(value.start, value.stop, value.step)
return IntUniformDistribution(
int(value.start), int(value.stop), step=int(value.step)
)
return FloatDistribution(value.start, value.stop, step=value.step)
return IntDistribution(int(value.start), int(value.stop), step=int(value.step))

if override.is_interval_sweep():
assert isinstance(value, IntervalSweep)
assert value.start is not None
assert value.end is not None
if "log" in value.tags:
if isinstance(value.start, int) and isinstance(value.end, int):
return IntLogUniformDistribution(int(value.start), int(value.end))
return LogUniformDistribution(value.start, value.end)
return IntDistribution(int(value.start), int(value.end), log=True)
return FloatDistribution(value.start, value.end, log=True)
else:
if isinstance(value.start, int) and isinstance(value.end, int):
return IntUniformDistribution(value.start, value.end)
return UniformDistribution(value.start, value.end)
return IntDistribution(value.start, value.end)
return FloatDistribution(value.start, value.end)

raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")

Expand Down Expand Up @@ -266,15 +261,18 @@ def _parse_sweeper_params_config(self) -> List[str]:
def _to_grid_sampler_choices(self, distribution: BaseDistribution) -> Any:
if isinstance(distribution, CategoricalDistribution):
return distribution.choices
elif isinstance(distribution, IntUniformDistribution):
elif isinstance(distribution, IntDistribution):
assert (
distribution.step is not None
), "`step` of IntUniformDistribution must be a positive integer."
), "`step` of IntDistribution must be a positive integer."
n_items = (distribution.high - distribution.low) // distribution.step
return [distribution.low + i * distribution.step for i in range(n_items)]
elif isinstance(distribution, DiscreteUniformDistribution):
n_items = int((distribution.high - distribution.low) // distribution.q)
return [distribution.low + i * distribution.q for i in range(n_items)]
elif (
isinstance(distribution, FloatDistribution)
and distribution.step is not None
):
n_items = int((distribution.high - distribution.low) // distribution.step)
return [distribution.low + i * distribution.step for i in range(n_items)]
else:
raise ValueError("GridSampler only supports discrete distributions.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ class MOTPESamplerConfig(SamplerConfig):
n_ehvi_candidates: int = 24


@dataclass
class GPSamplerConfig(SamplerConfig):
"""
https://optuna.readthedocs.io/en/stable/reference/samplers/generated/optuna.samplers.GPSampler.html
"""

_target_: str = "optuna.samplers.GPSampler"
seed: Optional[int] = None

n_startup_trials: int = 10
deterministic_objective = False


@dataclass
class DistributionConfig:
# Type of distribution. "int", "float" or "categorical"
Expand Down Expand Up @@ -234,3 +247,10 @@ class OptunaSweeperConf:
node=GridSamplerConfig,
provider="optuna_sweeper",
)

ConfigStore.instance().store(
group="hydra/sweeper/sampler",
name="gp",
node=GPSamplerConfig,
provider="optuna_sweeper",
)
5 changes: 3 additions & 2 deletions plugins/hydra_optuna_sweeper/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
],
install_requires=[
"hydra-core>=1.1.0.dev7",
"optuna>=2.10.0,<3.0.0",
"sqlalchemy~=1.3.0", # TODO: Unpin when upgrading to optuna v3.0
"optuna>=3.6.0",
"torch",
"scipy",
],
include_package_data=True,
)
40 changes: 19 additions & 21 deletions plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
from optuna.distributions import (
BaseDistribution,
CategoricalDistribution,
DiscreteUniformDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
LogUniformDistribution,
UniformDistribution,
IntDistribution,
FloatDistribution,
)
from optuna.samplers import RandomSampler
from pytest import mark, warns
Expand Down Expand Up @@ -59,24 +56,24 @@ def check_distribution(expected: BaseDistribution, actual: BaseDistribution) ->
{"type": "categorical", "choices": [1, 2, 3]},
CategoricalDistribution([1, 2, 3]),
),
({"type": "int", "low": 0, "high": 10}, IntUniformDistribution(0, 10)),
({"type": "int", "low": 0, "high": 10}, IntDistribution(0, 10)),
(
{"type": "int", "low": 0, "high": 10, "step": 2},
IntUniformDistribution(0, 10, step=2),
IntDistribution(0, 10, step=2),
),
({"type": "int", "low": 0, "high": 5}, IntUniformDistribution(0, 5)),
({"type": "int", "low": 0, "high": 5}, IntDistribution(0, 5)),
(
{"type": "int", "low": 1, "high": 100, "log": True},
IntLogUniformDistribution(1, 100),
IntDistribution(1, 100, log=True),
),
({"type": "float", "low": 0, "high": 1}, UniformDistribution(0, 1)),
({"type": "float", "low": 0, "high": 1}, FloatDistribution(0, 1)),
(
{"type": "float", "low": 0, "high": 10, "step": 2},
DiscreteUniformDistribution(0, 10, 2),
FloatDistribution(0, 10, step=2),
),
(
{"type": "float", "low": 1, "high": 100, "log": True},
LogUniformDistribution(1, 100),
FloatDistribution(1, 100, log=True),
),
],
)
Expand All @@ -92,12 +89,12 @@ def test_create_optuna_distribution_from_config(input: Any, expected: Any) -> No
("key=choice(true, false)", CategoricalDistribution([True, False])),
("key=choice('hello', 'world')", CategoricalDistribution(["hello", "world"])),
("key=shuffle(range(1,3))", CategoricalDistribution((1, 2))),
("key=range(1,3)", IntUniformDistribution(1, 3)),
("key=interval(1, 5)", UniformDistribution(1, 5)),
("key=int(interval(1, 5))", IntUniformDistribution(1, 5)),
("key=tag(log, interval(1, 5))", LogUniformDistribution(1, 5)),
("key=tag(log, int(interval(1, 5)))", IntLogUniformDistribution(1, 5)),
("key=range(0.5, 5.5, step=1)", DiscreteUniformDistribution(0.5, 5.5, 1)),
("key=range(1,3)", IntDistribution(1, 3)),
("key=interval(1, 5)", FloatDistribution(1, 5)),
("key=int(interval(1, 5))", IntDistribution(1, 5)),
("key=tag(log, interval(1, 5))", FloatDistribution(1, 5, log=True)),
("key=tag(log, int(interval(1, 5)))", IntDistribution(1, 5, log=True)),
("key=range(0.5, 5.5, step=1)", FloatDistribution(0.5, 5.5, step=1)),
],
)
def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None:
Expand All @@ -121,7 +118,7 @@ def test_create_optuna_distribution_from_override(input: Any, expected: Any) ->
(
{
"key1": CategoricalDistribution([1, 2]),
"key3": IntUniformDistribution(1, 3),
"key3": IntDistribution(1, 3),
},
{"key2": "5"},
),
Expand Down Expand Up @@ -152,7 +149,8 @@ def test_launch_jobs(hydra_sweep_runner: TSweepRunner) -> None:


@mark.parametrize("with_commandline", (True, False))
def test_optuna_example(with_commandline: bool, tmpdir: Path) -> None:
@mark.parametrize("sampler", ("tpe", "gp"))
def test_optuna_example(with_commandline: bool, sampler: str, tmpdir: Path) -> None:
storage = "sqlite:///" + os.path.join(str(tmpdir), "test.db")
study_name = "test-optuna-example"
cmd = [
Expand All @@ -164,7 +162,7 @@ def test_optuna_example(with_commandline: bool, tmpdir: Path) -> None:
"hydra.sweeper.n_jobs=1",
f"hydra.sweeper.storage={storage}",
f"hydra.sweeper.study_name={study_name}",
"hydra/sweeper/sampler=tpe",
f"hydra/sweeper/sampler={sampler}",
"hydra.sweeper.sampler.seed=123",
"~z",
]
Expand Down
17 changes: 14 additions & 3 deletions website/docs/plugins/optuna_sweeper.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ best_value: 0.0
This plugin supports Optuna's [samplers](https://optuna.readthedocs.io/en/stable/reference/samplers.html).
You can change the sampler by overriding `hydra/sweeper/sampler` or change sampler settings within `hydra.sweeper.sampler`.

### Experimental GPSampler

If you want to use the GPSampler, you need to suppress the warnings from Optuna. You can do this by adding the following code to your script:

```python
import warnings
from optuna.exceptions import ExperimentalWarning

warnings.filterwarnings("ignore", category=ExperimentalWarning)
```

## Search space configuration

This plugin supports Optuna's [distributions](https://optuna.readthedocs.io/en/stable/reference/distributions.html) to configure search spaces. They can be defined either through commandline override or config file.
Expand All @@ -119,7 +130,7 @@ Hydra provides a override parser that support rich syntax. Please refer to [Over

#### Interval override

By default, `interval` is converted to [`UniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html). You can use [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html), [`LogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html) or [`IntLogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html) by casting the interval to `int` and tagging it with `log`.
By default, `interval` is converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html). You can use [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html) by casting the interval to `int`.

<details><summary>Example for interval override</summary>

Expand Down Expand Up @@ -147,8 +158,8 @@ The output is as follows:

#### Range override

`range` is converted to [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead.
If any of `range`'s start, stop or step is of type float, it will be converted to [`DiscreteUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.DiscreteUniformDistribution.html)
`range` is converted to [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead.
If any of `range`'s start, stop or step is of type float, it will be converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html)

<details><summary>Example for range override</summary>

Expand Down