77
77
from trieste .observer import OBJECTIVE
78
78
from trieste .space import (
79
79
Box ,
80
+ CategoricalSearchSpace ,
80
81
DiscreteSearchSpace ,
81
82
SearchSpace ,
82
83
TaggedMultiSearchSpace ,
@@ -2057,29 +2058,41 @@ def discrete_search_space() -> DiscreteSearchSpace:
2057
2058
return DiscreteSearchSpace (points )
2058
2059
2059
2060
2061
+ @pytest .fixture
2062
+ def categorical_search_space () -> CategoricalSearchSpace :
2063
+ return CategoricalSearchSpace ([10 , 3 ])
2064
+
2065
+
2060
2066
@pytest .fixture
2061
2067
def continuous_search_space () -> Box :
2062
2068
return Box ([0.0 ], [1.0 ])
2063
2069
2064
2070
2071
+ @pytest .mark .parametrize ("space_fixture" , ["discrete_search_space" , "categorical_search_space" ])
2065
2072
@pytest .mark .parametrize ("with_initialize" , [True , False ])
2066
2073
def test_fixed_trust_region_discrete_initialize (
2067
- discrete_search_space : DiscreteSearchSpace , with_initialize : bool
2074
+ space_fixture : str ,
2075
+ with_initialize : bool ,
2076
+ request : Any ,
2068
2077
) -> None :
2069
2078
"""Check that FixedTrustRegionDiscrete inits correctly by picking a single point from the global
2070
2079
search space."""
2071
- tr = FixedPointTrustRegionDiscrete (discrete_search_space )
2080
+ search_space = request .getfixturevalue (space_fixture )
2081
+ tr = FixedPointTrustRegionDiscrete (search_space )
2072
2082
if with_initialize :
2073
2083
tr .initialize ()
2074
2084
assert tr .location .shape == (2 ,)
2075
- assert tr .location in discrete_search_space
2085
+ assert tr .location in search_space
2076
2086
2077
2087
2088
+ @pytest .mark .parametrize ("space_fixture" , ["discrete_search_space" , "categorical_search_space" ])
2078
2089
def test_fixed_trust_region_discrete_update (
2079
- discrete_search_space : DiscreteSearchSpace ,
2090
+ space_fixture : str ,
2091
+ request : Any ,
2080
2092
) -> None :
2081
2093
"""Update call should not change the location of the region."""
2082
- tr = FixedPointTrustRegionDiscrete (discrete_search_space )
2094
+ search_space = request .getfixturevalue (space_fixture )
2095
+ tr = FixedPointTrustRegionDiscrete (search_space )
2083
2096
tr .initialize ()
2084
2097
orig_location = tr .location .numpy ()
2085
2098
assert not tr .requires_initialization
@@ -2103,13 +2116,16 @@ def test_trust_region_discrete_get_dataset_min_raises_if_dataset_is_faulty(
2103
2116
tr .get_dataset_min (datasets )
2104
2117
2105
2118
2119
+ @pytest .mark .parametrize ("space_fixture" , ["discrete_search_space" , "categorical_search_space" ])
2106
2120
def test_trust_region_discrete_raises_on_location_not_found (
2107
- discrete_search_space : DiscreteSearchSpace ,
2121
+ space_fixture : str ,
2122
+ request : Any ,
2108
2123
) -> None :
2109
2124
"""Check that an error is raised if the location is not found in the global search space."""
2110
- tr = SingleObjectiveTrustRegionDiscrete (discrete_search_space )
2125
+ search_space = request .getfixturevalue (space_fixture )
2126
+ tr = SingleObjectiveTrustRegionDiscrete (search_space )
2111
2127
with pytest .raises (ValueError , match = "location .* not found in the global search space" ):
2112
- tr .location = tf .constant ([0.0 , 0.0 ], dtype = tf .float64 )
2128
+ tr .location = tf .constant ([0.1 , 0.0 ], dtype = tf .float64 )
2113
2129
2114
2130
2115
2131
def test_trust_region_discrete_get_dataset_min (discrete_search_space : DiscreteSearchSpace ) -> None :
@@ -2172,6 +2188,24 @@ def test_trust_region_discrete_initialize(
2172
2188
npt .assert_array_equal (tr ._y_min , tf .constant ([np .inf ], dtype = tf .float64 ))
2173
2189
2174
2190
2191
+ def test_trust_region_categorical_initialize (
2192
+ categorical_search_space : CategoricalSearchSpace ,
2193
+ ) -> None :
2194
+ """Check initialize sets the region to a random location, and sets the eps and y_min values."""
2195
+ datasets = {
2196
+ OBJECTIVE : Dataset ( # Points outside the search space should be ignored.
2197
+ tf .constant ([[0 , 1 , 2 , 0 ], [4 , - 4 , - 5 , 3 ]], dtype = tf .float64 ),
2198
+ tf .constant ([[0.7 ], [0.9 ]], dtype = tf .float64 ),
2199
+ )
2200
+ }
2201
+ tr = SingleObjectiveTrustRegionDiscrete (categorical_search_space , input_active_dims = [1 , 2 ])
2202
+ tr .initialize (datasets = datasets )
2203
+
2204
+ npt .assert_array_equal (tr .eps , 1 )
2205
+ assert tr .location in categorical_search_space
2206
+ npt .assert_array_equal (tr ._y_min , tf .constant ([np .inf ], dtype = tf .float64 ))
2207
+
2208
+
2175
2209
def test_trust_region_discrete_requires_initialization (
2176
2210
discrete_search_space : DiscreteSearchSpace ,
2177
2211
) -> None :
@@ -2223,20 +2257,28 @@ def test_trust_region_discrete_update_no_initialize(
2223
2257
2224
2258
@pytest .mark .parametrize ("dtype" , [tf .float32 , tf .float64 ])
2225
2259
@pytest .mark .parametrize ("success" , [True , False ])
2260
+ @pytest .mark .parametrize ("space_fixture" , ["discrete_search_space" , "categorical_search_space" ])
2226
2261
def test_trust_region_discrete_update_size (
2227
- dtype : tf .DType , success : bool , discrete_search_space : DiscreteSearchSpace
2262
+ dtype : tf .DType , success : bool , space_fixture : str , request : Any
2228
2263
) -> None :
2229
- discrete_search_space = DiscreteSearchSpace ( # Convert to the correct dtype.
2230
- tf .cast (discrete_search_space .points , dtype = dtype )
2231
- )
2264
+ search_space = request .getfixturevalue (space_fixture )
2265
+ categorical = isinstance (search_space , CategoricalSearchSpace )
2266
+
2267
+ # Convert to the correct dtype.
2268
+ if isinstance (search_space , DiscreteSearchSpace ):
2269
+ search_space = DiscreteSearchSpace (tf .cast (search_space .points , dtype = dtype ))
2270
+ else :
2271
+ assert isinstance (search_space , CategoricalSearchSpace )
2272
+ search_space = CategoricalSearchSpace (search_space .tags , dtype = dtype )
2273
+
2232
2274
"""Check that update shrinks/expands region on successful/unsuccessful step."""
2233
2275
datasets = {
2234
2276
OBJECTIVE : Dataset (
2235
2277
tf .constant ([[5 , 4 ], [0 , 1 ], [1 , 1 ]], dtype = dtype ),
2236
2278
tf .constant ([[0.5 ], [0.3 ], [1.0 ]], dtype = dtype ),
2237
2279
)
2238
2280
}
2239
- tr = SingleObjectiveTrustRegionDiscrete (discrete_search_space , min_eps = 0.1 )
2281
+ tr = SingleObjectiveTrustRegionDiscrete (search_space , min_eps = 0.1 )
2240
2282
tr .initialize (datasets = datasets )
2241
2283
2242
2284
# Ensure there is at least one point captured in the region.
@@ -2252,11 +2294,17 @@ def test_trust_region_discrete_update_size(
2252
2294
eps = tr .eps
2253
2295
2254
2296
if success :
2255
- # Sample a point from the region.
2256
- new_point = tr .sample (1 )
2297
+ # Sample a point from the region. For categorical spaces ensure that
2298
+ # it's a different point to tr.location (this must exist)
2299
+ for _ in range (10 ):
2300
+ new_point = tr .sample (1 )
2301
+ if not (categorical and tf .reduce_all (new_point [0 ] == tr .location )):
2302
+ break
2303
+ else :
2304
+ assert False , "TR contains just one point"
2257
2305
else :
2258
2306
# Pick point outside the region.
2259
- new_point = tf .constant ([[1 , 2 ]], dtype = dtype )
2307
+ new_point = tf .constant ([[10 , 1 ]], dtype = dtype )
2260
2308
2261
2309
# Add a new min point to the dataset.
2262
2310
assert not tr .requires_initialization
@@ -2269,28 +2317,33 @@ def test_trust_region_discrete_update_size(
2269
2317
tr .update (datasets = datasets )
2270
2318
2271
2319
assert tr .location .dtype == dtype
2272
- assert tr .eps .dtype == dtype
2320
+ assert tr .eps == 1 if categorical else tr . eps .dtype == dtype
2273
2321
assert tr .points .dtype == dtype
2274
2322
2275
2323
if success :
2276
2324
# Check that the location is the new min point.
2277
2325
new_point = np .squeeze (new_point )
2278
2326
npt .assert_array_equal (new_point , tr .location )
2279
2327
npt .assert_allclose (new_min , tr ._y_min )
2280
- # Check that the region is larger by beta.
2281
- npt .assert_allclose (eps / tr ._beta , tr .eps )
2328
+ # Check that the region is larger by beta (except for categorical)
2329
+ npt .assert_allclose (1 if categorical else eps / tr ._beta , tr .eps )
2282
2330
else :
2283
2331
# Check that the location is the old min point.
2284
2332
orig_point = np .squeeze (orig_point )
2285
2333
npt .assert_array_equal (orig_point , tr .location )
2286
2334
npt .assert_allclose (orig_min , tr ._y_min )
2287
- # Check that the region is smaller by beta.
2288
- npt .assert_allclose (eps * tr ._beta , tr .eps )
2335
+ # Check that the region is smaller by beta (except for categorical)
2336
+ npt .assert_allclose (1 if categorical else eps * tr ._beta , tr .eps )
2289
2337
2290
2338
# Check the new set of neighbors.
2291
- neighbors_mask = tf .abs (discrete_search_space .points - tr .location ) <= tr .eps
2292
- neighbors_mask = tf .reduce_all (neighbors_mask , axis = - 1 )
2293
- neighbors = tf .boolean_mask (discrete_search_space .points , neighbors_mask )
2339
+ if categorical :
2340
+ # Hamming distance
2341
+ neighbors_mask = tf .where (search_space .points != tr .location , 1 , 0 )
2342
+ neighbors_mask = tf .reduce_sum (neighbors_mask , axis = - 1 ) <= tr .eps
2343
+ else :
2344
+ neighbors_mask = tf .abs (search_space .points - tr .location ) <= tr .eps
2345
+ neighbors_mask = tf .reduce_all (neighbors_mask , axis = - 1 )
2346
+ neighbors = tf .boolean_mask (search_space .points , neighbors_mask )
2294
2347
npt .assert_array_equal (tr .points , neighbors )
2295
2348
2296
2349
0 commit comments