diff --git a/CHANGELOG.md b/CHANGELOG.md
index 47876ed51..050e4d5ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,7 +24,7 @@
 disable tqdm progress bar terminal output. Defaults to disabled (`False`).
 
 ### Fixed
 
-
+- `MapReduce` in `coreax.solvers.composite` now keeps track of the indices of the selected points.
 
 ### Changed
diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py
index d8c8b2974..931e0ee07 100644
--- a/coreax/solvers/composite.py
+++ b/coreax/solvers/composite.py
@@ -23,18 +23,20 @@
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
+from jax import Array
 from sklearn.neighbors import BallTree, KDTree
 from typing_extensions import TypeAlias, override
 
 from coreax.coreset import Coreset, Coresubset
 from coreax.data import Data
 from coreax.solvers.base import ExplicitSizeSolver, PaddingInvariantSolver, Solver
-from coreax.util import tree_zero_pad_leading_axis
+from coreax.util import ArrayLike, tree_zero_pad_leading_axis
 
 BinaryTree: TypeAlias = Union[KDTree, BallTree]
 _Data = TypeVar("_Data", bound=Data)
 _Coreset = TypeVar("_Coreset", Coreset, Coresubset)
 _State = TypeVar("_State")
+_Indices = TypeVar("_Indices", ArrayLike, None)
 
 
 class CompositeSolver(
@@ -125,22 +127,56 @@ def reduce(
         # There is no obvious way to use state information here.
         del solver_state
 
-        def _reduce_coreset(data: _Data) -> tuple[_Coreset, _State]:
+        def _reduce_coreset(
+            data: _Data, _indices: Optional[_Indices] = None
+        ) -> tuple[_Coreset, _State, _Indices]:
             if len(data) <= self.leaf_size:
-                return self.base_solver.reduce(data)
-            partitioned_dataset = _jit_tree(data, self.leaf_size, self.tree_type)
-            coreset_ensemble, _ = jax.vmap(self.base_solver.reduce)(partitioned_dataset)
+                coreset, state = self.base_solver.reduce(data)
+                if _indices is not None:
+                    _indices = _indices[coreset.nodes.data]
+                return coreset, state, _indices
+
+            def wrapper(partition: _Data) -> tuple[_Data, Array]:
+                """
+                Apply the base solver's ``reduce`` method to a single partition.
+
+                The data is partitioned with ``_jit_tree()``; this wrapper is then
+                mapped over the partitions with ``jax.vmap()`` to reduce each one.
+ """ + x, _ = self.base_solver.reduce(partition) + return x.coreset, x.nodes.data + + partitioned_dataset, partitioned_indices = _jit_tree( + data, self.leaf_size, self.tree_type + ) + # Reduce each partition and get indices from each + coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset) + # Calculate the indices with respect to the original data + concatenated_indices = jax.vmap(lambda x, index: x[index])( + partitioned_indices, ensemble_indices + ) + concatenated_indices = jnp.ravel(concatenated_indices) _coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble) - return _reduce_coreset(_coreset.coreset) - coreset_wrong_pre_coreset_data, output_solver_state = _reduce_coreset(dataset) - coreset = eqx.tree_at( - lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset - ) + if _indices is not None: + final_indices = _indices[concatenated_indices] + else: + final_indices = concatenated_indices + return _reduce_coreset(_coreset, final_indices) + + (pre_coreset, output_solver_state, _indices) = _reduce_coreset(dataset) + # Correct the pre-coreset data and the indices + coreset = eqx.tree_at(lambda x: x.pre_coreset_data, pre_coreset, dataset) + if _indices is not None: + if isinstance(coreset, Coresubset): + coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices) return coreset, output_solver_state -def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data: +def _jit_tree( + dataset: _Data, leaf_size: int, tree_type: type[BinaryTree] +) -> tuple[_Data, _Indices]: """ Return JIT compatible BinaryTree partitioning of 'dataset'. @@ -183,4 +219,4 @@ def _binary_tree(_input_data: Data) -> np.ndarray: return node_indices.reshape(n_leaves, -1).astype(np.int32) indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset) - return dataset[indices] + return padded_dataset[indices], jnp.arange(len(dataset))[indices] diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index 9d9b9ef25..94c1e1f07 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -1326,6 +1326,164 @@ def test_base_solver( solver_factory.keywords["base_solver"] = base_solver solver_factory() + def test_map_reduce_diverse_selection(self): + """Check if MapReduce returns indices from multiple partitions.""" + dataset_size = 40 + data_dim = 5 + coreset_size = 6 + leaf_size = 12 + + key = jr.PRNGKey(0) + dataset = jr.normal(key, shape=(dataset_size, data_dim)) + + kernel = SquaredExponentialKernel() + base_solver = KernelHerding(coreset_size=coreset_size, kernel=kernel) + + solver = MapReduce(base_solver=base_solver, leaf_size=leaf_size) + coreset, _ = solver.reduce(Data(dataset)) + selected_indices = coreset.nodes.data + + assert jnp.any( + selected_indices >= coreset_size + ), "MapReduce should select points beyond the first few" + + # Check if there are indices from different partitions + partitions_represented = jnp.unique(selected_indices // leaf_size) + assert ( + len(partitions_represented) > 1 + ), "MapReduce should select points from multiple partitions" + + def test_map_reduce_analytic(self): + r""" + Test ``MapReduce`` on an analytical example, enforcing a unique coreset. + + In this example, we start with the original dataset + :math:`[10, 20, 30, 210, 40, 60, 180, 90, 150, 70, 120, + 200, 50, 140, 80, 170, 100, 190, 110, 160, 130]`. + + Suppose we want a subset size of 3, and we want maximum leaf size of 6. + + We can see that we have a dataset of size 21. 
+        The partitioning scheme only allows :math:`n` partitions,
+        where :math:`n` is a power of 2. Therefore, we can partition into:
+
+        1. 1 partition of size 21
+        2. 2 partitions of size :math:`\lceil 10.5 \rceil = 11` each (with 1 padded zero)
+        3. 4 partitions of size :math:`\lceil 5.25 \rceil = 6` each (with 3 padded zeros)
+        4. 8 partitions of size :math:`\lceil 2.625 \rceil = 3` each (with 3 padded zeros)
+
+        Since we set the maximum leaf size :math:`m = 6`, we choose the largest
+        partition size that is less than or equal to 6. Thus, we have 4 partitions,
+        each of size 6.
+
+        This results in the following 4 partitions (note that the data within
+        each partition is in ascending order):
+
+        1. :math:`[0, 0, 0, 10, 20, 30]`
+        2. :math:`[40, 50, 60, 70, 80, 90]`
+        3. :math:`[100, 110, 120, 130, 140, 150]`
+        4. :math:`[160, 170, 180, 190, 200, 210]`
+
+        Now we reduce each partition with our ``interleaved_base_solver``, which
+        is designed to choose the first, last, second, second-to-last, third,
+        third-to-last elements, and so on, until a coreset of the correct size
+        is formed. Hence, we obtain:
+
+        1. :math:`[0, 30, 0]`
+        2. :math:`[40, 90, 50]`
+        3. :math:`[100, 150, 110]`
+        4. :math:`[160, 210, 170]`
+
+        Concatenating, we obtain
+        :math:`[0, 30, 0, 40, 90, 50, 100, 150, 110, 160, 210, 170]`.
+        We repeat the process, checking into how many partitions to divide this
+        intermediate dataset (of size 12). Recall that the number of partitions
+        must be a power of 2. Our options are:
+
+        1. 1 partition of size 12
+        2. 2 partitions of size 6
+        3. 4 partitions of size 3
+        4. 8 partitions of size :math:`\lceil 1.5 \rceil = 2` each
+
+        Given our maximum leaf size :math:`m = 6`, we choose the largest partition
+        size that is less than or equal to 6. Therefore, we select 2 partitions of
+        size 6. This time no padding is necessary. The two partitions resulting
+        from this step are (note that the data is again in ascending order):
+
+        1. :math:`[0, 0, 30, 40, 50, 90]`
+        2. :math:`[100, 110, 150, 160, 170, 210]`
+
+        Applying our ``interleaved_base_solver`` with ``coreset_size`` 3 to each
+        partition, we obtain:
+
+        1. :math:`[0, 90, 0]`
+        2. :math:`[100, 210, 110]`
+
+        We concatenate the two subsets to obtain :math:`[0, 90, 0, 100, 210, 110]`.
+        This dataset has size 6, so no further partitioning is necessary.
+
+        Applying ``interleaved_base_solver`` one last time, we obtain the final
+        coreset: :math:`[0, 110, 90]`.
+ """ + interleaved_base_solver = MagicMock(_ExplicitPaddingInvariantSolver) + interleaved_base_solver.coreset_size = 3 + + def interleaved_mock_reduce( + dataset: Data, solver_state: None = None + ) -> tuple[Coreset[Data], None]: + half_size = interleaved_base_solver.coreset_size // 2 + indices = jnp.arange(interleaved_base_solver.coreset_size) + forward_indices = indices[:half_size] + backward_indices = -(indices[:half_size] + 1) + interleaved_indices = jnp.stack( + [forward_indices, backward_indices], axis=1 + ).ravel() + + if interleaved_base_solver.coreset_size % 2 != 0: + interleaved_indices = jnp.append(interleaved_indices, half_size) + return Coreset(dataset[interleaved_indices], dataset), solver_state + + interleaved_base_solver.reduce = interleaved_mock_reduce + + original_data = Data( + jnp.array( + [ + 10, + 20, + 30, + 210, + 40, + 60, + 180, + 90, + 150, + 70, + 120, + 200, + 50, + 140, + 80, + 170, + 100, + 190, + 110, + 160, + 130, + ] + ) + ) + expected_coreset_data = Data(jnp.array([0, 110, 90])) + + coreset, _ = MapReduce(base_solver=interleaved_base_solver, leaf_size=6).reduce( + original_data + ) + assert eqx.tree_equal(coreset.coreset.data == expected_coreset_data.data) + class TestCaratheodoryRecombination(RecombinationSolverTest): """Tests for :class:`coreax.solvers.recombination.CaratheodoryRecombination`."""