Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,31 @@ def __call__(
)()


def create_hv_scalarization(
num_scalarizations: int, labels: types.PaddedArray, rng: jax.Array
):
"""Creates a HyperVolumeScalarization with random weights.

Args:
num_scalarizations: The number of scalarizations to create.
labels: The labels used to create the reference point.
rng: The random key to use for sampling the weights.

Returns:
A HyperVolumeScalarization with random weights.
"""
weights = jax.random.normal(
rng,
shape=(num_scalarizations, labels.shape[1]),
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
ref_point = (
get_reference_point(labels, scale=0.01) if labels.shape[0] > 0 else None
)
return scalarization.HyperVolumeScalarization(weights, ref_point)


# TODO: What do we end up jitting? If we end up directly jitting this call
# then we should make it `eqx.Module` and set
# `reduction_fn=eqx.field(static=True)` instead.
Expand Down
18 changes: 4 additions & 14 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.designers import scalarization
from vizier._src.algorithms.designers.gp import acquisitions as acq_lib
from vizier._src.algorithms.designers.gp import gp_models
from vizier._src.algorithms.designers.gp import output_warpers
Expand Down Expand Up @@ -202,27 +201,18 @@ def __attrs_post_init__(self):
# Multi-objective overrides.
m_info = self._problem.metric_information
if not m_info.is_single_objective:
num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE))

# Create scalarization weights.
self._rng, weights_rng = jax.random.split(self._rng)
weights = jax.random.normal(
weights_rng, shape=(self._num_scalarizations, num_obj)
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
ref_point = (
acq_lib.get_reference_point(data.labels, self._ref_scaling)
if has_labels
else None
scalarizer = acq_lib.create_hv_scalarization(
self._num_scalarizations, data.labels, weights_rng
)
scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)

labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
max_scalarized = None
if has_labels:
max_scalarized = jnp.max(scalarizer(labels_array), axis=-1)
Expand Down
Loading
Loading