diff --git a/tests/unit/test_ask_tell_optimization.py b/tests/unit/test_ask_tell_optimization.py index a68d2e4ff6..e75c777c38 100644 --- a/tests/unit/test_ask_tell_optimization.py +++ b/tests/unit/test_ask_tell_optimization.py @@ -945,6 +945,7 @@ def test_ask_tell_optimizer_dataset_len_variables( dataset = init_dataset assert AskTellOptimizer.dataset_len({"tag": dataset}) == 2 + assert AskTellOptimizer.dataset_len({"tag1": dataset, "tag2": dataset}) == 2 def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets( diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index f3b9977144..e5e4579c0a 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -442,8 +442,8 @@ def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int: for tag, dataset in datasets.items() if not LocalizedTag.from_tag(tag).is_local ] - unique_lens, unique_idxs = tf.unique(dataset_lens) - if len(unique_idxs) == 1: + unique_lens, _ = tf.unique(dataset_lens) + if len(unique_lens) == 1: return int(unique_lens[0]) else: raise ValueError(f"Expected unique global dataset size, got {unique_lens}")