diff --git a/.cspell/custom_misc.txt b/.cspell/custom_misc.txt
index f85e84c8..0a8eef55 100644
--- a/.cspell/custom_misc.txt
+++ b/.cspell/custom_misc.txt
@@ -3,6 +3,7 @@ ANNchor
 approximator
 approximators
 archiveprefix
+binom
 cand
 coreax
 coreset
@@ -25,19 +26,24 @@ kernelised
 kernelized
 KSD
 linewidth
+mapsto
 Matern
 ml.p3.8xlarge
 ndmin
 parsable
 PCIMQ
+phdthesis
 PMLR
 primaryclass
 PRNG
+rcond
+recomb
 refs
 regulariser
 RKHS
 RPCHOLESKY
 sigmas
+subseteq
 supp
 TLDR
 typecheck
diff --git a/.cspell/library_terms.txt b/.cspell/library_terms.txt
index 0576c395..d5815575 100644
--- a/.cspell/library_terms.txt
+++ b/.cspell/library_terms.txt
@@ -22,6 +22,7 @@ brightgreen
 bysource
 cmap
 color
+Concretization
 dispname
 dollarmath
 dtype
@@ -30,8 +31,10 @@ dunder
 eigh
 emibcn
 eprint
+errstate
 figsize
 finfo
+flatnonzero
 fori_loop
 fracv
 furo
@@ -116,6 +119,7 @@ sklearn
 softplus
 sphinxcontrib
 sphobjinv
+tensordot
 texttt
 toctree
 tomli
@@ -130,6 +134,7 @@ unsrt
 vars
 viewcode
 vmap
+vmaps
 writebytes
 xlabel
 ylabel
diff --git a/.cspell/people.txt b/.cspell/people.txt
index 604627c9..01e7bbfb 100644
--- a/.cspell/people.txt
+++ b/.cspell/people.txt
@@ -1,4 +1,5 @@
 Benard
+Caratheodory
 Chatalic
 Duvenaud
 Epperly
@@ -7,15 +8,22 @@ Ferenc
 Frobenius
 Garg
 Garreau
+Goaoc
 Halko
+Helly
 Huszar
 Jaehoon
 Jiaxin
 Jitkrittum
 Kanagawa
+Litterer
+Loera
+Lyons
 Martinsson
 Matérn
+Meunier
 Motonobu
+Nabil
 Nystr
 Nystrom
 Qiang
@@ -26,8 +34,13 @@ Sahaj
 Schreuder
 Smirnov
 Smola
+Sperner
 Staber
+Tchakaloff
+Tchernychova
+Teichmann
 Tropp
+Tverberg
 Veiga
 Wittawat
 Yifan
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c59bf83d..d03a183b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added Kernelised Stein Discrepancy divergence in `coreax.metrics.KSD`.
+- Added the `coreax.solvers.recombination` module, which provides the following new solvers:
+  - `RecombinationSolver`: an abstract base class for recombination solvers.
+  - `CaratheodoryRecombination`: a simple deterministic approach to solving recombination problems.
+  - `TreeRecombination`: an advanced deterministic approach that utilises `CaratheodoryRecombination`,
+    but is faster for solving all but the smallest recombination problems.
 - Added supervised coreset construction algorithm in `coreax.solvers.GreedyKernelPoints`
 - Added `coreax.kernels.PowerKernel` to replace repeated calls of `coreax.kernels.ProductKernel`
   within the `**` magic method of `coreax.kernel.ScalarValuedKernel`
diff --git a/coreax/solvers/__init__.py b/coreax/solvers/__init__.py
index f2c14af7..0559b3f6 100644
--- a/coreax/solvers/__init__.py
+++ b/coreax/solvers/__init__.py
@@ -32,6 +32,11 @@
     RPCholeskyState,
     SteinThinning,
 )
+from coreax.solvers.recombination import (
+    CaratheodoryRecombination,
+    RecombinationSolver,
+    TreeRecombination,
+)
 
 __all__ = [
     "CompositeSolver",
@@ -49,4 +54,7 @@
     "PaddingInvariantSolver",
     "GreedyKernelPoints",
     "GreedyKernelPointsState",
+    "RecombinationSolver",
+    "CaratheodoryRecombination",
+    "TreeRecombination",
 ]
diff --git a/coreax/solvers/base.py b/coreax/solvers/base.py
index 360837bb..fd2f3c4c 100644
--- a/coreax/solvers/base.py
+++ b/coreax/solvers/base.py
@@ -114,7 +114,7 @@ class PaddingInvariantSolver(Solver):
     A :class:`Solver` whose results are invariant to zero weighted data.
 
     In some cases, such as in :class:`coreax.solvers.MapReduce`, there is a need to pad
-    data to ensure shape stability. In some cases, we may assign zero weight to the
+    data to ensure shape stability. In these cases, we may assign zero weight to the
     padded data points, which allows certain 'padding invariant' solvers to return the
     same values on a call to :meth:`~coreax.solvers.Solver.reduce` as would have been
     returned if no padding were present.
diff --git a/coreax/solvers/recombination.py b/coreax/solvers/recombination.py
new file mode 100644
index 00000000..e5a86997
--- /dev/null
+++ b/coreax/solvers/recombination.py
@@ -0,0 +1,609 @@
+# © Crown Copyright GCHQ
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Recombination solvers.
+
+Take a dataset :math:`\{(x_i, w_i)\}_{i=1}^n`, where each node :math:`x_i \in \Omega`
+is paired with a non-negative weight :math:`w_i \ge 0`, and the weights sum to one,
+:math:`\sum_{i=1}^n w_i = 1` (a strict requirement for a probability measure). The
+dataset defines the discrete measure :math:`\nu_n := \sum_{i=1}^n w_i \delta_{x_i}`.
+
+.. note::
+    Given any weighted dataset, we can use normalisation to satisfy the sum to
+    one condition, provided :math:`\sum_{i=1}^n w_i \neq 0`.
+
+Combined with :math:`m-1` test-functions :math:`\Phi^\prime = \{\phi_i\}_{i=1}^{m-1}`,
+where :math:`\phi_i \colon \Omega \to \mathbb{R}`, that parametrise a set of :math:`m`
+test-functions :math:`\Phi = \{x \mapsto 1\} \cup \Phi^\prime`, there exists a dataset
+push-forward measure :math:`\mu_n := \Phi_* \nu_n`.
+
+A recombination solver attempts to find a reduced measure (a coresubset)
+:math:`\hat{\mu}_{m^\prime}`, which is given as a basic feasible solution (BFS) to the
+following linear-programming problem (with trivial objective)
+
+.. math::
+    \begin{align}
+        \mathbf{Y} \mathbf{\hat{w}} &= \text{CoM}(\mu_n),\\
+        \mathbf{\hat{w}} &\ge 0,
+    \end{align}
+
+where the system variables and "centre-of-mass" are defined as
+
+.. math::
+    \begin{gather}
+        \mathbf{Y} := \left[\Phi(x_1), \dots, \Phi(x_n)\right]
+        \in \mathbb{R}^{m \times n}, \quad
+        \mathbf{\hat{w}} \in \mathbb{R}^n_{\ge 0},\\
+        \text{CoM}(\mu_n) := \sum_{i=1}^n w_i \Phi(x_i)
+        = \left[ \sum_{i=1}^n w_i \phi_j(x_i) \right]_{j=1}^m \in \mathbb{R}^m.
+    \end{gather}
+
+.. note::
+    The source dataset is, by definition, a solution to the linear-program that is not
+    necessarily a BFS. Hence, one may consider the fundamental problem of recombination
+    as that of finding a BFS given a solution that is not a BFS.
+
+Basic feasible solutions to the linear-program above are, up to permutation, of the
+form :math:`\mathbf{\hat{w}} = (\hat{w}_1, \dots, \hat{w}_{m^\prime}, 0, \dots, 0)`;
+i.e. BFSs are feasible solutions with :math:`n-m^\prime` weights equal to zero. Given
+a BFS, the reduced measure (the coresubset) can be constructed by explicitly removing
+the nodes associated with each zero-valued (implicitly removed) weight
+
+.. math::
+    \begin{gather}
+        \hat{\nu}_{m^\prime} = \sum_{i \in I} \hat{w}_i \delta_{x_i},\\
+        I = \{i \in \{1, \dots, n\} \mid \hat{w}_i \neq 0\}.
+    \end{gather}
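+
+For intuition, consider a small illustrative example (not drawn from the cited
+references): take :math:`n = 3` nodes :math:`x_i \in \{0, 1, 2\}` with uniform
+weights :math:`w_i = 1/3`, and let :math:`\Phi^\prime` be the identity map, so
+that :math:`\Phi(x) = (1, x)` and :math:`m = 2`. Then
+:math:`\text{CoM}(\mu_3) = (1, 1)`, and one BFS is
+:math:`\mathbf{\hat{w}} = (1/2, 0, 1/2)`, retaining only the nodes :math:`x = 0`
+and :math:`x = 2`, since :math:`1/2 + 1/2 = 1` and
+:math:`0 \cdot 1/2 + 2 \cdot 1/2 = 1`.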
+
+Due to Tchakaloff's theorem :cite:`tchakaloff1957,bayer2006tchakaloff`, which follows
+from Caratheodory's convex hull theorem :cite:`caratheodory1907,loera2018caratheodory`,
+we know there always exists a basic feasible solution to the linear-program, with at
+most :math:`\text{dim}(\text{span}(\Phi)) \le m` non-zero weights. Hence, we have an
+upper bound on the size of a coresubset, controlled by the choice of test-functions.
+
+.. note::
+    A basic feasible solution (coresubset produced by recombination) is non-unique. In
+    fact, there may be as many as :math:`\binom{n}{m^\prime}` basic feasible solutions
+    (coresubsets) for the described linear-program. In the context of Coreax, this
+    means that a :class:`RecombinationSolver` is unlikely to ever be truly invariant
+    to the presence of padding (see :class:`~coreax.solvers.PaddingInvariantSolver`);
+    i.e. the padded problem may have an equivalent, but different, BFS than the
+    unpadded problem.
+
+Canonically, recombination is used for reducing the support of a quadrature/cubature
+measure, against which integration of any function :math:`f \in \text{span}(\Phi)`
+is identical to integration against a "target" (potentially continuous) measure
+:math:`\mu`.
+"""
+
+import math
+from collections.abc import Callable
+from typing import Generic, Literal, NamedTuple, Optional, TypeVar, Union
+
+import jax
+import jax.numpy as jnp
+import jax.scipy as jsp
+import jax.tree_util as jtu
+from jaxtyping import Array, Bool, DTypeLike, Float, Integer, Real, Shaped
+from typing_extensions import override
+
+from coreax import Coresubset, Data
+from coreax.solvers.base import CoresubsetSolver
+
+_Data = TypeVar("_Data", bound=Data)
+_State = TypeVar("_State")
+
+
+class RecombinationSolver(CoresubsetSolver[_Data, _State], Generic[_Data, _State]):
+    r"""
+    Solver which returns a :class:`~coreax.coreset.Coresubset` via recombination.
+
+    Given :math:`m-1` explicitly provided test-functions :math:`\Phi^\prime`, a
+    recombination solver finds a coresubset with :math:`m^\prime \le m` points, whose
+    push-forward :math:`\hat{\mu}_{m^\prime}` has the same "centre-of-mass" (CoM) as
+    the dataset push-forward :math:`\mu_n := \Phi_* \nu_n`.
+
+    :param test_functions: A callable that applies a set of specified test-functions
+        :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
+        :math:`\phi_i \colon \Omega\to\mathbb{R}`; a value of :data:`None` implies the
+        identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes
+        that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
+    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
+        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted
+        (implicitly removed) points; 'implicit' explicitly removes no points, yielding
+        a coreset of size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly
+        removed) points; 'explicit' explicitly removes :math:`n - m^\prime` points,
+        yielding a coreset of size :math:`m^\prime`, but unlike the other methods is
+        not JIT compatible as the coreset size :math:`m^\prime` is unknown at compile
+        time.
+ """ + + test_functions: Optional[Callable[[Array], Real[Array, " m-1"]]] = None + mode: Literal["implicit-explicit", "implicit", "explicit"] = "implicit-explicit" + + def __check_init__(self): + """Ensure a valid `self.mode` is specified.""" + if self.mode not in {"implicit-explicit", "implicit", "explicit"}: + raise ValueError( + "Invalid mode, expected 'implicit-explicit', 'implicit' or 'explicit'." + ) + + +class _EliminationState(NamedTuple): + weights: Shaped[Array, " n"] + nodes: Shaped[Array, "n m"] + iteration: int + + +class CaratheodoryRecombination(RecombinationSolver[Data, None]): + r""" + Recombination via Caratheodory measure reduction (Gaussian-Elimination). + + Proposed in :cite:`tchernychova2016recombination` (see Chapter 1.3.3.3) as an + alternative to the Simplex algorithm for solving the recombination problem. + + Unlike the Simplex method, with time complexity :math:`\mathcal{O}(m^3 n + m n^2)`, + Caratheodory recombination has time complexity of only :math:`\mathcal{O}(m n^2)`. + + .. note:: + Given :math:`n = cm`, for a rational constant :math:`c`, the above complexities + can be alternatively represented as :math:`\mathcal{O}(m^4)` for the Simplex + method and :math:`\mathcal{O}(m^3)` for Caratheodory recombination. + + :param test_functions: A callable that applies a set of specified test-functions + :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map + :math:`\phi_i \colon \Omega \to \mathbb{R}`; a value of non implies the identity + map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes that + :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}` + :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding + a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly + removed) points; 'implicit' explicitly removes no points, yielding a coreset of + size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly removed) + points; 'explicit' explicitly removes :math:`n - m^\prime` points, yielding a + coreset of size :math:`m^\prime`, but unlike the other methods is not JIT + compatible as the coreset size :math:`m^\prime` is unknown at compile time. + :param rcond: A relative condition number; any singular value :math:`s` below the + threshold :math:`\text{rcond} * \text{max}(s)` is treated as equal to zero; if + rcond is :data:`None`, it defaults to `floating point eps * max(n, d)` + """ + + rcond: Optional[float] = None + + @override + def reduce( + self, dataset: Data, solver_state: None = None + ) -> tuple[Coresubset, None]: + nodes, weights = jtu.tree_leaves(dataset.normalize(preserve_zeros=True)) + push_forward_nodes = _push_forward(nodes, self.test_functions) + # Handle pre-existing zero-weighted nodes (not handled by the base algorithm + # described in :cite:`tchernychova2016recombination`) + safe_push_forward_nodes, safe_weights, indices = _co_linearize( + push_forward_nodes, weights + ) + largest_null_space_basis, null_space_rank = _resolve_null_basis( + safe_push_forward_nodes, self.rcond + ) + + def _eliminate_cond(state: _EliminationState) -> Bool[Array, ""]: + """ + If to continue the iterative Gaussian-Elimination procedure. + + On each iteration, we eliminate a basis vector from the left null space. We + repeat until all basis vectors have been eliminated (the dimension of the + null space is zero); once the number of iterations is the same as the rank + of the original null space. + + .. 
+    """
+
+    rcond: Optional[float] = None
+
+    @override
+    def reduce(
+        self, dataset: Data, solver_state: None = None
+    ) -> tuple[Coresubset, None]:
+        nodes, weights = jtu.tree_leaves(dataset.normalize(preserve_zeros=True))
+        push_forward_nodes = _push_forward(nodes, self.test_functions)
+        # Handle pre-existing zero-weighted nodes (not handled by the base algorithm
+        # described in :cite:`tchernychova2016recombination`).
+        safe_push_forward_nodes, safe_weights, indices = _co_linearize(
+            push_forward_nodes, weights
+        )
+        largest_null_space_basis, null_space_rank = _resolve_null_basis(
+            safe_push_forward_nodes, self.rcond
+        )
+
+        def _eliminate_cond(state: _EliminationState) -> Bool[Array, ""]:
+            """
+            Decide whether to continue the iterative Gaussian-Elimination procedure.
+
+            On each iteration, we eliminate a basis vector from the left null space.
+            We repeat until all basis vectors have been eliminated (the dimension of
+            the null space is zero); that is, until the number of iterations equals
+            the rank of the original null space.
+
+            .. note::
+                The reason for using a while loop, rather than scanning over the basis
+                vectors, is that the dimension of the null space is unknown at JIT
+                compile time, preventing us from slicing the left singular vectors
+                down to only those which form a basis for the left null space.
+
+            :param state: Elimination state information
+            :return: Boolean indicating whether to continue the elimination loop.
+            """
+            *_, basis_index = state
+            return basis_index < null_space_rank
+
+        def _eliminate(state: _EliminationState) -> _EliminationState:
+            """
+            Eliminate a basis vector from the left null space.
+
+            At least one weight is zeroed (implicitly removed from the dataset), and
+            one left null space basis vector is eliminated on each iteration. The
+            mass that is "lost" in weight zeroing/elimination is redistributed among
+            the remaining non-zero weights to preserve the total mass/weight sum.
+
+            If the procedure is repeated until all the left null space basis vectors
+            are eliminated, the resulting weights (when combined with the original
+            nodes) are a BFS to the recombination problem/linear-program.
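+
+            In symbols: with weights :math:`w`, selected basis vector :math:`v`, and
+            :math:`\alpha = \min_{i : v_i > 0} w_i / v_i`, the update is
+            :math:`w \leftarrow w - \alpha v`. Since :math:`\mathbf{Y} v = 0`, the
+            centre-of-mass :math:`\mathbf{Y} w` is unchanged, while the weight
+            attaining the minimum becomes exactly zero and all remaining weights
+            stay non-negative.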
+
+            :param state: Elimination state information
+            :return: Updated `state` information resulting from the elimination step.
+            """
+            # Algorithm 6 - Chapter 3.3 of :cite:`tchernychova2016recombination`
+            # Our Notation -> Their Notation
+            # - `basis_index` (loop iteration) -> i
+            # - `elimination_index` -> k^{(i)}
+            # - `elimination_rescaling_factor` -> \alpha_{(i)}
+            # - `updated_weights` -> \underline\Beta^{(i)}
+            # - `null_space_basis_update` -> d_{l+1}^{(i)}\phi_1^{(i-1)}
+            # - `updated_null_space_basis` -> \Psi^{(i)}
+            _weights, null_space_basis, basis_index = state
+            basis_vector = null_space_basis[basis_index]
+            # Equation 3: Select the weight to eliminate.
+            elimination_condition = jnp.where(
+                basis_vector > 0, _weights / basis_vector, jnp.inf
+            )
+            elimination_index = jnp.argmin(elimination_condition)
+            elimination_rescaling_factor = elimination_condition[elimination_index]
+            # Equation 4: Eliminate the selected weight and redistribute its mass.
+            # NOTE: Equation 5 is implicit from Equation 4 and is performed outside
+            # of `_eliminate` via `_coresubset_nodes`.
+            updated_weights = _weights - elimination_rescaling_factor * basis_vector
+            updated_weights = updated_weights.at[elimination_index].set(0)
+            # Equations 6, 7 and 8: Update the null space basis.
+            null_space_basis_update = jnp.tensordot(
+                null_space_basis[:, elimination_index],
+                basis_vector / basis_vector[elimination_index],
+                axes=0,
+            )
+            updated_null_space_basis = null_space_basis - null_space_basis_update
+            updated_null_space_basis = updated_null_space_basis.at[basis_index].set(0)
+            return _EliminationState(
+                updated_weights,
+                updated_null_space_basis,
+                basis_index + 1,
+            )
+
+        in_state = _EliminationState(safe_weights, largest_null_space_basis, 0)
+        out_weights, *_ = jax.lax.while_loop(_eliminate_cond, _eliminate, in_state)
+        coresubset_nodes = _coresubset_nodes(
+            safe_push_forward_nodes,
+            out_weights,
+            indices,
+            self.mode,
+            is_affine_augmented=True,
+        )
+        return Coresubset(coresubset_nodes, dataset), solver_state
+
+
+def _push_forward(
+    nodes: Shaped[Array, " n"],
+    test_functions: Optional[Callable[[Array], Real[Array, " m-1"]]],
+    augment: bool = True,
+) -> Shaped[Array, "n m"]:
+    r"""
+    Push the 'nodes' forward through the 'test_functions'.
+
+    :param nodes: The nodes to push-forward through the test-functions
+    :param test_functions: A callable that applies a set of specified test-functions
+        :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
+        :math:`\phi_i \colon \Omega \to \mathbb{R}`; a value of :data:`None` implies
+        the identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily
+        assumes that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
+    :param augment: Whether to prepend the affine-augmentation test-function
+        :math:`\{x \mapsto 1\}` to the explicitly pushed forward nodes
+        :math:`\Phi^\prime(x)`, to yield :math:`\Phi(x)`; the default behaviour
+        prepends the affine-augmentation function
+    :return: The pushed-forward nodes.
+    """
+    if test_functions is None:
+        push_forward_nodes = nodes
+    else:
+        push_forward_nodes = jax.vmap(test_functions, in_axes=0)(nodes)
+    if augment:
+        shape, dtype = push_forward_nodes.shape[0], push_forward_nodes.dtype
+        affine_augmentation = jnp.ones((shape,), dtype)
+        push_forward_nodes = jnp.c_[affine_augmentation, push_forward_nodes]
+    return push_forward_nodes
+
+
+def _co_linearize(
+    nodes: Shaped[Array, "n m"], weights: Shaped[Array, " n"]
+) -> tuple[Shaped[Array, "n m"], Shaped[Array, " n"], Shaped[Array, " n"]]:
+    """
+    Make zero-weighted nodes co-linear with the maximum weighted node.
+
+    Due to the static shape requirements imposed by JAX, we implicitly remove nodes by
+    setting their corresponding weight to zero. This is sufficient in the recombination
+    algorithm for all but one scenario: the computation of the null space basis.
+    Because the zero-weighted nodes still exist in the node matrix, they influence the
+    SVD and yield an erroneous null space basis.
+
+    We ameliorate this problem by setting the zero-weighted nodes equal (co-linear) to
+    the largest weighted node (an arbitrary but consistent choice). Because these nodes
+    are now co-linear with each other and the largest weighted node, we know that all
+    but at most one of them can be safely eliminated by the recombination procedure.
+    Thus, the nodes become effectively "invisible" to the elimination procedure.
+
+    The only caveat is that we don't know which of the equal nodes will be retained
+    post elimination. To handle this, we keep an index (reference) from the
+    zero-weighted nodes to the largest weighted node, and we redistribute the largest
+    weight equally over all the "co-linearized" nodes (preserving the CoM and allowing
+    any node to be eliminated).
+
+    :param nodes: The nodes to co-linearize
+    :param weights: The weights to apply the co-linearization correction to
+    :return: The co-linearized nodes, corrected weights, and co-linearized-to-original
+        reference indices.
+    """
+    max_index = jnp.argmax(weights)
+    non_zero_weights_mask = weights > 0
+    zero_weights_mask = 1 - non_zero_weights_mask
+    n_zeros = zero_weights_mask.sum()
+    # Create a new set of indices that replace the zero-weighted node indices with the
+    # maximum weighted node's index.
+    indices = jnp.arange(weights.shape[0])
+    indices *= non_zero_weights_mask
+    indices += zero_weights_mask * max_index
+    # Renormalize the maximum weight; ensures the weight sum is preserved under the new
+    # (co-linearized) indices; prevents co-linearization from changing the weight sum.
+    weights = weights.at[max_index].divide(n_zeros + 1)
+    return nodes[indices], weights[indices], indices
+
+
+# pylint: disable=line-too-long
+# Credit: https://github.com/patrick-kidger/lineax/blob/9b923c8df6556551fedc7adeea7979b5c7b3ffb0/lineax/_solver/svd.py#L67  # noqa: E501
+# for the rank determination code.
+# pylint: enable=line-too-long
+def _resolve_null_basis(
+    nodes: Shaped[Array, "n m"],
+    rcond: Union[float, None] = None,
+) -> tuple[Shaped[Array, "n n"], Integer[Array, ""]]:
+    r"""
+    Resolve the largest left null space basis, and its rank, for the passed node matrix.
+
+    By largest left null space basis, we mean the null space basis under the assumption
+    that the rank of the null space is maximal (assumed to be ``n``). If the rank is not
+    maximal, then only the first :math:`n - m^\prime` basis vectors will be actual basis
+    vectors for the null space (where :math:`m^\prime` is the rank of the node matrix).
+    The remaining "basis" vectors can, and should, be ignored in upstream computations
+    by using the left null space rank value as a cut-off index.
+
+    :param nodes: Matrix of nodes (m-vectors) whose null space is to be determined
+    :param rcond: The relative condition number of the matrix of nodes
+    :return: The largest left null space basis and its rank, for the passed node matrix.
+    """
+    q, s, _ = jsp.linalg.svd(nodes, full_matrices=True)
+    _rcond = _resolve_rcond(nodes.shape, s.dtype, rcond)
+    if s.size > 0:
+        _rcond *= jnp.max(s[0])
+    mask = s > _rcond
+    matrix_rank = sum(mask)
+    null_space_rank = jnp.maximum(0, nodes.shape[0] - matrix_rank)
+    largest_null_space_basis = q.T[::-1]
+    return largest_null_space_basis, null_space_rank
+
+
+# pylint: disable=line-too-long
+# Credit: https://github.com/patrick-kidger/lineax/blob/9b923c8df6556551fedc7adeea7979b5c7b3ffb0/lineax/_misc.py#L34  # noqa: E501
+# pylint: enable=line-too-long
+def _resolve_rcond(
+    shape: tuple[int, ...], dtype: DTypeLike, rcond: Optional[float] = None
+) -> Float[Array, ""]:
+    """
+    Resolve the relative condition number (rcond).
+
+    :param shape: The shape of the matrix whose relative condition number is to be
+        resolved
+    :param dtype: The element dtype of the matrix whose rcond is to be resolved
+    :param rcond: The relative condition number of a given matrix; if ``None``,
+        ``rcond = dtype_floating_point_eps * max(shape)``; else if negative,
+        ``rcond = dtype_floating_point_eps``
+    :return: The resolved relative condition number (rcond)
+    """
+    epsilon = jnp.asarray(jnp.finfo(dtype).eps, dtype)
+    if rcond is None:
+        return epsilon * max(shape)
+    return jnp.where(rcond < jnp.asarray(0), epsilon, rcond)
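+
+
+# NOTE (editorial illustration): `_coresubset_nodes` below relies on the static
+# `size` argument of `jnp.flatnonzero` to keep output shapes JIT-stable; e.g. for
+# `w = jnp.array([0.0, 0.5, 0.0, 0.5])`:
+#
+#   jnp.flatnonzero(w, size=3, fill_value=jnp.argmin(w))  # -> Array([1, 3, 0])
+#
+# where `fill_value` pads the result with the index of a zero weight.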
+def _coresubset_nodes(
+    push_forward_nodes: Shaped[Array, "n m"],
+    weights: Shaped[Array, " n"],
+    indices: Shaped[Array, " n"],
+    mode: Literal["implicit-explicit", "implicit", "explicit"],
+    is_affine_augmented: bool = False,
+) -> Data:
+    r"""
+    Determine the coresubset nodes based on the 'mode'.
+
+    :param push_forward_nodes: The dataset push-forward nodes
+    :param weights: The coresubset weights
+    :param indices: The reference indices from the (co-linearized) nodes back to the
+        original dataset nodes
+    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
+        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted
+        (implicitly removed) points; 'implicit' explicitly removes no points, yielding
+        a coreset of size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly
+        removed) points; 'explicit' explicitly removes :math:`n - m^\prime` points,
+        yielding a coreset of size :math:`m^\prime`, but unlike the other methods is
+        not JIT compatible as the coreset size :math:`m^\prime` is unknown at compile
+        time.
+    :param is_affine_augmented: Whether the 'push_forward_nodes' include the
+        :math:`\phi_1` affine-augmentation map
+    :return: The coresubset nodes as defined by the 'mode'.
+    """
+    n, m = push_forward_nodes.shape
+    m = m if is_affine_augmented else m + 1
+    if mode == "implicit-explicit":
+        # Inside the JIT context we cannot explicitly remove the zero-weighted
+        # points, because we don't know how many non-zero weights there will be a
+        # priori (`m^\prime` is unknown until after the singular value decomposition
+        # is performed). However, we do have an upper bound on the number of non-zero
+        # points, `min(n, m) \ge m^\prime`. Thus, we need only return `min(n, m)`
+        # points, where `min(n, m) - m^\prime` of these may be zero-weighted
+        # (implicitly removed). The fill value is set to `argmin(weights)` to ensure
+        # the padding always indexes a zero-weighted data point whenever a weight is
+        # zero.
+        idx = jnp.flatnonzero(weights, size=min(n, m), fill_value=jnp.argmin(weights))
+    elif mode == "implicit":
+        idx = jnp.flatnonzero(weights, size=n, fill_value=jnp.argmin(weights))
+    elif mode == "explicit":
+        # Explicit mode is JIT incompatible.
+        try:
+            idx = jnp.flatnonzero(weights)
+        except jax.errors.ConcretizationTypeError as err:
+            raise ValueError(
+                "'explicit' mode is incompatible with transformations such as 'jax.jit'"
+            ) from err
+    else:
+        # Should only get here if `__check_init__` has been skipped/avoided, or if
+        # this function is called from an unexpected place.
+        raise ValueError(
+            "Invalid mode, expected 'implicit-explicit', 'implicit' or 'explicit'."
+        )
+    return Data(indices[idx], weights[idx])
+
+
+class TreeRecombination(RecombinationSolver[Data, None]):
+    r"""
+    Tree recombination based coresubset solver.
+
+    Based on Algorithm 7, Chapter 3.3 of :cite:`tchernychova2016recombination`, which
+    is an order of magnitude more efficient than Algorithm 5 in Chapter 3.2, originally
+    introduced in :cite:`litterer2012recombination`.
+
+    The time complexity is of order :math:`\mathcal{O}(\log_2(\frac{n}{c_r m}) m^3)`,
+    where :math:`c_r` is the `tree_reduction_factor`. The time complexity can be
+    equivalently expressed as :math:`\mathcal{O}(m^3)`, using the same arguments as
+    used in :class:`CaratheodoryRecombination`.
+
+    .. note::
+        As the ratio :math:`n / m` grows, the constant factor for the time complexity
+        of :class:`TreeRecombination` increases at a logarithmic rate, rather than at
+        the quadratic rate of plain :class:`CaratheodoryRecombination`. Hence, in
+        general, we would expect :class:`TreeRecombination` to be the more efficient
+        choice for all but the smallest values of :math:`n / m`.
+
+    :param test_functions: A callable that applies a set of specified test-functions
+        :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
+        :math:`\phi_i \colon \Omega \to \mathbb{R}`; a value of :data:`None` implies
+        the identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily
+        assumes that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
+    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
+        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted
+        (implicitly removed) points; 'implicit' explicitly removes no points, yielding
+        a coreset of size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly
+        removed) points; 'explicit' explicitly removes :math:`n - m^\prime` points,
+        yielding a coreset of size :math:`m^\prime`, but unlike the other methods is
+        not JIT compatible as the coreset size :math:`m^\prime` is unknown at compile
+        time.
+    :param rcond: A relative condition number; any singular value :math:`s` below the
+        threshold :math:`\text{rcond} * \text{max}(s)` is treated as equal to zero; if
+        :code:`rcond is None`, it defaults to ``floating point eps * max(n, d)``
+    :param tree_reduction_factor: The factor by which each tree reduction step reduces
+        the number of non-zero points; the remaining number of non-zero nodes, after
+        performing recombination, is equal to `n_nodes / tree_reduction_factor`.
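+
+    Example (an illustrative sketch; with the default identity test-functions, the
+    weighted mean of a dataset in :math:`\mathbb{R}^d` is preserved):
+
+    .. code-block:: python
+
+        nodes = jax.random.uniform(jax.random.key(0), (1000, 4))
+        dataset = Data(nodes, jnp.full((1000,), 1 / 1000))
+        coresubset, _ = TreeRecombination(test_functions=None).reduce(dataset)
+        # At most m = d + 1 = 5 points carry non-zero weight, and the weighted
+        # mean of 'dataset' is preserved by 'coresubset.coreset'.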
+    """
+
+    rcond: Union[float, None] = None
+    tree_reduction_factor: int = 2
+
+    @override
+    def reduce(
+        self, dataset: Data, solver_state: None = None
+    ) -> tuple[Coresubset, None]:
+        nodes, weights = jtu.tree_leaves(dataset.normalize(preserve_zeros=True))
+        # Push the nodes forward through the test-functions \Phi^\prime.
+        push_forward_nodes = _push_forward(nodes, self.test_functions, augment=False)
+        n, m = push_forward_nodes.shape
+        # We don't apply the affine-augmentation test-function \phi_1 here, instead
+        # deferring it to `CaratheodoryRecombination.reduce`. Thus, we have to manually
+        # correct the value for `m`.
+        padding, count, depth = _prepare_tree(n, m + 1, self.tree_reduction_factor)
+        car_recomb_solver = CaratheodoryRecombination(rcond=self.rcond, mode="implicit")
+
+        def _tree_reduce(_, state: tuple[Array, Array]) -> tuple[Array, Array]:
+            """
+            Apply tree-based Caratheodory recombination (Gaussian-Elimination).
+
+            Partitions the dataset into 'count' clusters of size 'n / count' and then
+            computes the cluster centroids. Caratheodory recombination is then
+            performed on these centroids (rather than on the full dataset), with every
+            node in the eliminated centroids' clusters being implicitly removed (given
+            zero weight).
+
+            There are 'tree_reduction_factor * m' clusters, with each step reducing
+            the number of remaining clusters down to 'm'. We can repeat the process
+            until each cluster contains, at most, a single non-zero weighted point (at
+            which point the recombination problem has been solved).
+
+            :param _: Not used
+            :param state: Tuple of node weights and indices; the indices are passed to
+                keep a correspondence between the original data indices and the
+                shuffled indices used to construct the balanced centroids
+            :return: Updated tuple of node weights and indices; weights are zeroed
+                (implicitly removed) where appropriate; indices are shuffled to ensure
+                balanced centroids in subsequent iterations (centroids are balanced
+                when they are all constructed from subsets with as near to an equal
+                number of non-zero weighted nodes as possible).
+            """
+            _weights, _indices = state
+            # Index weights to a centroid; argsort ensures that centroids are balanced.
+            centroid_indices = jnp.argsort(_weights).reshape(count, -1, order="F")
+            centroid_nodes, centroid_weights = _centroid(
+                push_forward_nodes[_indices[centroid_indices]],
+                _weights[centroid_indices],
+            )
+            centroid_dataset = Data(centroid_nodes, centroid_weights)
+            # Solve the measure reduction problem on the centroid dataset.
+            centroid_coresubset, _ = car_recomb_solver.reduce(centroid_dataset)
+            coresubset_indices = centroid_coresubset.unweighted_indices
+            coresubset_weights = centroid_coresubset.coreset.weights
+            # Propagate centroid coresubset weights to the underlying weights for each
+            # centroid, as defined by `centroid_indices`.
+            weight_update_indices = centroid_indices[coresubset_indices]
+            weight_update = coresubset_weights / centroid_weights[coresubset_indices]
+            updated_weights = _weights[weight_update_indices] * weight_update[..., None]
+            # Maintain a correspondence between the original data indices and the
+            # sorted indices, used to construct the balanced centroids.
+            updated_indices = _indices[weight_update_indices.reshape(-1, order="F")]
+            return updated_weights.reshape(-1, order="F"), updated_indices
+
+        in_state = (jnp.pad(weights, (0, padding)), jnp.arange(n + padding))
+        out_weights, indices = jax.lax.fori_loop(0, depth, _tree_reduce, in_state)
+        coresubset_nodes = _coresubset_nodes(
+            push_forward_nodes, out_weights, indices, self.mode
+        )
+        return Coresubset(coresubset_nodes, dataset), solver_state
+
+
+def _prepare_tree(
+    n: int, m: int, tree_reduction_factor: int = 2
+) -> tuple[int, int, int]:
+    r"""
+    Compute the dataset padding, tree count, and maximum tree depth.
+
+    :param n: Number of nodes
+    :param m: Number of test-functions
+    :param tree_reduction_factor: The factor by which each tree reduction step reduces
+        the number of non-zero points; the remaining number of non-zero nodes, after
+        performing recombination, is equal to `n_nodes / tree_reduction_factor`
+    :return: The required amount of padding (to allow reshaping of the nodes into
+        equal sized clusters), the tree count (number of clusters), and the maximum
+        tree depth (number of tree reduction iterations required to complete tree
+        recombination)
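+
+    Example (illustrative): for ``n = 100`` nodes, ``m = 5`` test-functions and
+    ``tree_reduction_factor = 2``, the maximum tree depth is
+    :math:`\lceil \log_2(100 / 5) \rceil = 5`, the required padding is
+    :math:`5 \cdot 2^5 - 100 = 60`, and the tree count is :math:`2 \cdot 5 = 10`.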
+    """
+    tree_count = tree_reduction_factor * m
+    max_tree_depth = math.ceil(math.log(n / m, tree_reduction_factor))
+    padding = m * tree_reduction_factor**max_tree_depth - n
+    return padding, tree_count, max_tree_depth
+
+
+@jax.vmap
+def _centroid(
+    nodes: Shaped[Array, "tree_count n/tree_count m"],
+    weights: Shaped[Array, "tree_count n/tree_count"],
+) -> tuple[Shaped[Array, "n/tree_count m"], Shaped[Array, " n/tree_count"]]:
+    """
+    Compute the cluster centroid (centre-of-mass) and total cluster mass.
+
+    :param nodes: A set of clustered nodes where the leading axis indexes each cluster,
+        which this function vmaps over, and the middle axis indexes each node within a
+        given cluster.
+    :param weights: A set of clustered weights associated with each node; has the same
+        index layout as the nodes.
+    :return: Cluster centroid (centre-of-mass) and total cluster mass for all clusters
+    """
+    centroid_nodes = jnp.nan_to_num(jnp.average(nodes, 0, weights))
+    centroid_weights = jnp.sum(weights)
+    return centroid_nodes, centroid_weights
diff --git a/documentation/source/conf.py b/documentation/source/conf.py
index 15e1985b..c37eea67 100644
--- a/documentation/source/conf.py
+++ b/documentation/source/conf.py
@@ -168,6 +168,15 @@
     ("py:class", "Array"),
     ("py:class", "typing.Self"),
     ("py:class", "jaxtyping.Shaped"),
+    ("py:class", "jaxtyping.Shaped[Array, 'n *d']"),
+    ("py:class", "jaxtyping.Shaped[ndarray, 'n *d']"),
+    ("py:class", "jaxtyping.Shaped[Array, 'n d']"),
+    ("py:class", "jaxtyping.Real[Array, 'm-1']"),
+    ("py:class", "jaxtyping.Shaped[ndarray, 'n d']"),
+    ("py:class", "jaxtyping.Shaped[Array, 'n *p']"),
+    ("py:class", "jaxtyping.Shaped[Array, 'n p']"),
+    ("py:class", "jaxtyping.Shaped[Array, 'n']"),
+    ("py:class", "jaxtyping.Shaped[ndarray, 'n']"),
     ("py:class", "jax._src.typing.SupportsDType"),
     ("py:class", "'n d'"),
     ("py:class", "'n p'"),
@@ -188,6 +197,8 @@
     ("py:obj", "coreax.solvers.coresubset._Data"),
     ("py:obj", "coreax.solvers.coresubset._State"),
     ("py:obj", "coreax.solvers.coresubset._Coreset"),
+    ("py:obj", "coreax.solvers.recombination._Data"),
+    ("py:obj", "coreax.solvers.recombination._State"),
     ("py:obj", "coreax.weights._Data"),
     ("py:obj", "coreax.metrics._Data"),
     ("py:obj", "coreax.solvers.coresubset._SupervisedData"),
diff --git a/documentation/source/coreax/solvers.rst b/documentation/source/coreax/solvers.rst
index 15ef1d12..651aecad 100644
--- a/documentation/source/coreax/solvers.rst
+++ b/documentation/source/coreax/solvers.rst
@@ -2,3 +2,5 @@ Solvers
 ========
 
 .. automodule:: coreax.solvers
+    :no-private-members:
+    :no-undoc-members:
diff --git a/documentation/source/references.bib b/documentation/source/references.bib
index da7078c4..e08e1ed6 100644
--- a/documentation/source/references.bib
+++ b/documentation/source/references.bib
@@ -52,19 +52,6 @@ @misc{liu2016kernelized
   primaryclass = {stat.ML}
 }
 
-@article{litterer2012recombination,
-  title={High order recombination and an application to cubature on Wiener space},
-  volume={22},
-  ISSN={1050-5164},
-  url={http://dx.doi.org/10.1214/11-AAP786},
-  DOI={10.1214/11-aap786},
-  number={4},
-  journal={The Annals of Applied Probability},
-  publisher={Institute of Mathematical Statistics},
-  author={Litterer, C. and Lyons, T.},
-  year={2012},
-  month=aug
-}
 
 @misc{chen2023randomly,
   title = {{Randomly pivoted Cholesky: Practical approximation of a kernel matrix with few entry evaluations}},
@@ -118,3 +105,71 @@ @misc{nguyen2021meta
   archivePrefix={arXiv},
   primaryClass={cs.LG}
 }
+
+@article{litterer2012recombination,
+  title={High order recombination and an application to cubature on Wiener space},
+  volume={22},
+  ISSN={1050-5164},
+  url={http://dx.doi.org/10.1214/11-AAP786},
+  DOI={10.1214/11-aap786},
+  number={4},
+  journal={The Annals of Applied Probability},
+  publisher={Institute of Mathematical Statistics},
+  author={Litterer, C. and Lyons, T.},
+  year={2012},
+  month=aug
+}
+
+@phdthesis{tchernychova2016recombination,
+  publisher = {University of Oxford},
+  school = {University of Oxford},
+  title = {Caratheodory cubature measures},
+  author = {Tchernychova, M.},
+  year = {2016},
+  url = {https://ora.ox.ac.uk/objects/uuid:a3a10980-d35d-467b-b3c0-d10d2e491f2d}
+}
+
+// cSpell:disable - French
+@article{tchakaloff1957,
+  title = {Formules de cubatures mécaniques à coefficients non négatifs},
+  author = {Tchakaloff, V.},
+  year = {1957},
+  journal = {Bulletin des Sciences Mathématiques},
+  number = {2},
+  volume = {81},
+  pages = {123--134}
+}
+// cSpell:enable
+
+@article{bayer2006tchakaloff,
+  title = {The proof of Tchakaloff's Theorem},
+  author = {Bayer, C. and Teichmann, J.},
+  year = {2006},
+  journal = {Proceedings of the American Mathematical Society},
+  volume = {134},
+  pages = {3035--3040},
+  url = {https://doi.org/10.1090/S0002-9939-06-08249-9}
+}
+
+// cSpell:disable - German
+@article{caratheodory1907,
+  title = {Über den Variabilitätsbereich der Koeffizienten von Potenzreihen, die gegebene Werte nicht annehmen},
+  author = {Carathéodory, C.},
+  year = {1907},
+  journal = {Mathematische Annalen},
+  volume = {64},
+  issue = {1},
+  pages = {95--115},
+  url = {https://doi.org/10.1007/BF01449883}
+}
+// cSpell:enable
+
+@misc{loera2018caratheodory,
+  title = {The discrete yet ubiquitous theorems of Carathéodory, Helly, Sperner, Tucker, and Tverberg},
+  author = {Jesus A. De Loera and Xavier Goaoc and Frédéric Meunier and Nabil Mustafa},
+  year = {2018},
+  eprint = {1706.05975},
+  archivePrefix = {arXiv},
+  primaryClass = {math.CO},
+  url = {https://arxiv.org/abs/1706.05975}
+}
diff --git a/pyproject.toml b/pyproject.toml
index a469343f..d9fe5017 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "flax",
     "jax",
    "jaxopt",
+    "jaxtyping",
     "optax",
     "scikit-learn",
     "tqdm",
@@ -110,6 +111,8 @@ ignore = [
    # Incompatible with other pydocstyle rules
    "D203", # one-blank-line-before-class
    "D212", # multi-line-summary-first-line
+    # Incompatible with jaxtyping
+    "F722", # forward-annotation-syntax-error
    # Opinionated ignores
    "PLR6301", # no-self-use (opinionated)
    # Incompatible with jaxtyping
diff --git a/tests/unit/test_coreset.py b/tests/unit/test_coreset.py
index 9eef41ea..f62fa90f 100644
--- a/tests/unit/test_coreset.py
+++ b/tests/unit/test_coreset.py
@@ -14,58 +14,63 @@
 """Tests for coreset data-structures."""
 
-from unittest.mock import MagicMock, Mock
+from unittest.mock import MagicMock
 
 import equinox as eqx
 import jax.numpy as jnp
 import pytest
 
 from coreax.coreset import Coreset, Coresubset
-from coreax.data import Data
+from coreax.data import Data, SupervisedData
 from coreax.metrics import Metric
 from coreax.weights import WeightsOptimiser
 
-NODES = Data(jnp.arange(5, dtype=jnp.int32)[..., None])
+DATA = Data(jnp.arange(5, dtype=jnp.int32)[..., None])
+SUPERVISED_DATA = SupervisedData(
+    jnp.arange(5, dtype=jnp.int32)[..., None], jnp.arange(5, dtype=jnp.int32)[..., None]
+)
 PRE_CORESET_DATA = Data(jnp.arange(10)[..., None])
 
 
 @pytest.mark.parametrize("coreset_type", [Coreset, Coresubset])
+@pytest.mark.parametrize("data", [DATA, SUPERVISED_DATA])
 class TestCoresetCommon:
     """Common tests for `coreax.coreset.Coreset` and `coreax.coreset.Coresubset`."""
 
-    def test_init_array_conversion(self, coreset_type):
+    def test_init_array_conversion(self, coreset_type, data):
         """
         Test the initialisation behaviour.
 
         The nodes can be passed as an 'Array' or as a 'Data' instance.
         In the former case, we expect this array to be automatically converted to a
         'Data' instance.
         """
-        array_nodes = NODES.data
+        array_nodes = data.data
+        data_obj = Data(data.data, data.weights)
         coreset_array_nodes = coreset_type(array_nodes, PRE_CORESET_DATA)
-        coreset_data_nodes = coreset_type(NODES, PRE_CORESET_DATA)
+        coreset_data_nodes = coreset_type(data_obj, PRE_CORESET_DATA)
         assert coreset_array_nodes == coreset_data_nodes
 
-    def test_materialization(self, coreset_type):
+    def test_materialization(self, coreset_type, data):
         """Test the coreset materialisation behaviour."""
-        coreset = coreset_type(NODES, PRE_CORESET_DATA)
+        coreset = coreset_type(data, PRE_CORESET_DATA)
         expected_materialization = coreset.nodes
         if isinstance(coreset, Coresubset):
-            materialized_nodes = PRE_CORESET_DATA.data[NODES.data.squeeze()]
+            materialized_nodes = PRE_CORESET_DATA.data[data.data.squeeze()]
             expected_materialization = Data(materialized_nodes)
         assert expected_materialization == coreset.coreset
 
-    def test_len(self, coreset_type):
+    def test_len(self, coreset_type, data):
         """Test the coreset length."""
-        coreset = coreset_type(NODES, PRE_CORESET_DATA)
-        assert len(coreset) == len(NODES.data)
+        coreset = coreset_type(data, PRE_CORESET_DATA)
+        assert len(coreset) == len(data.data)
 
-    def test_solve_weights(self, coreset_type):
+    def test_solve_weights(self, coreset_type, data):
         """Test the weights solving convenience interface."""
         solver = MagicMock(WeightsOptimiser)
-        solved_weights = jnp.full_like(jnp.asarray(NODES), 123)
+        solved_weights = jnp.full_like(jnp.asarray(data), 123)
         solver.solve.return_value = solved_weights
-        re_weighted_nodes = eqx.tree_at(lambda x: x.weights, NODES, solved_weights)
-        coreset = coreset_type(NODES, PRE_CORESET_DATA)
+        re_weighted_nodes = eqx.tree_at(lambda x: x.weights, data, solved_weights)
+        coreset = coreset_type(data, PRE_CORESET_DATA)
         coreset_expected = coreset_type(re_weighted_nodes, PRE_CORESET_DATA)
         kwargs = {"test": None}
         coreset_solved_weights = coreset.solve_weights(solver, **kwargs)
@@ -74,12 +79,12 @@
             coreset.pre_coreset_data, coreset.coreset, **kwargs
         )
 
-    def test_compute_metric(self, coreset_type):
+    def test_compute_metric(self, coreset_type, data):
         """Test the metric computation convenience interface."""
         metric = MagicMock(spec=Metric)
         expected_metric = jnp.asarray(123)
-        metric.compute = Mock(return_value=expected_metric)
-        coreset = coreset_type(NODES, PRE_CORESET_DATA)
+        metric.compute.return_value = expected_metric
+        coreset = coreset_type(data, PRE_CORESET_DATA)
         kwargs = {"test": None}
         coreset_metric = coreset.compute_metric(metric, **kwargs)
         assert eqx.tree_equal(coreset_metric, expected_metric)
@@ -93,6 +98,6 @@ class TestCoresubset:
 
     def test_unweighted_indices(self):
         """Test the coresubset 'unweighted_indices' property."""
-        coresubset = Coresubset(NODES, PRE_CORESET_DATA)
-        expected_indices = NODES.data.squeeze()
+        coresubset = Coresubset(DATA, PRE_CORESET_DATA)
+        expected_indices = DATA.data.squeeze()
         assert eqx.tree_equal(expected_indices, coresubset.unweighted_indices)
diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py
index b5a0d8ca..733ada99 100644
--- a/tests/unit/test_solvers.py
+++ b/tests/unit/test_solvers.py
@@ -21,7 +21,7 @@
     AbstractContextManager,
     nullcontext as does_not_raise,
 )
-from typing import NamedTuple, Optional, Union, cast
+from typing import Literal, NamedTuple, Optional, Union, cast
 from unittest.mock import MagicMock
 
 import equinox as eqx
@@ -38,6 +38,7 @@
 from coreax.kernels import PCIMQKernel, ScalarValuedKernel, SquaredExponentialKernel
 from coreax.least_squares import RandomisedEigendecompositionSolver
 from coreax.solvers import (
+    CaratheodoryRecombination,
     GreedyKernelPoints,
     GreedyKernelPointsState,
     HerdingState,
@@ -48,6 +49,7 @@
     RPCholeskyState,
     Solver,
     SteinThinning,
+    TreeRecombination,
 )
 from coreax.solvers.base import (
     ExplicitSizeSolver,
@@ -109,13 +111,14 @@ def check_solution_invariants(
 
         1. Check 'coreset.pre_coreset_data' is equal to 'dataset'
         2. Check 'coreset' is equal to 'expected_coreset' (if expected is not 'None')
-        3. If 'isinstance(coreset, Coresubset)', check coreset is a subset of 'dataset'
-        4. If 'not hasattr(solver, random_key))', check that the
+        3. If 'isinstance(coreset, Coresubset)', check coreset is a subset of 'dataset';
+            note: 'coreset.weights' doesn't need to be a subset of 'dataset.weights'
+        4. If 'isinstance(problem.solver, PaddingInvariantSolver)', check that the
             addition of zero weighted data-points to the leading axis of the input
             'dataset' does not modify the resulting coreset when the solver is a
             'PaddingInvariantSolver'.
         """
-        dataset, _, expected_coreset = problem
+        dataset, solver, expected_coreset = problem
         if isinstance(problem, _RefineProblem):
             dataset = problem.initial_coresubset.pre_coreset_data
         assert eqx.tree_equal(coreset.pre_coreset_data, dataset)
@@ -123,10 +126,10 @@
             assert isinstance(coreset, type(expected_coreset))
             assert eqx.tree_equal(coreset, expected_coreset)
         if isinstance(coreset, Coresubset):
-            membership = jtu.tree_map(jnp.isin, coreset.coreset, dataset)
+            membership = jtu.tree_map(jnp.isin, coreset.coreset.data, dataset.data)
             all_membership = jtu.tree_map(jnp.all, membership)
             assert jtu.tree_all(all_membership)
-        if isinstance(problem.solver, PaddingInvariantSolver):
+        if isinstance(solver, PaddingInvariantSolver):
             padded_dataset = tree_zero_pad_leading_axis(dataset, len(dataset))
             if isinstance(problem, _RefineProblem):
                 padded_initial_coreset = eqx.tree_at(
@@ -134,9 +137,9 @@
                     problem.initial_coresubset,
                     padded_dataset,
                 )
-                coreset_from_padded, _ = problem.solver.refine(padded_initial_coreset)
+                coreset_from_padded, _ = solver.refine(padded_initial_coreset)
             else:
-                coreset_from_padded, _ = problem.solver.reduce(padded_dataset)
+                coreset_from_padded, _ = solver.reduce(padded_dataset)
             assert eqx.tree_equal(coreset_from_padded.coreset, coreset.coreset)
 
     @pytest.mark.parametrize("use_cached_state", (False, True))
@@ -164,6 +167,163 @@
         self.check_solution_invariants(coreset, reduce_problem)
 
 
+class RecombinationSolverTest(SolverTest):
+    """Test cases for coresubset solvers that perform recombination."""
+
+    @override
+    @pytest.fixture(
+        params=["random", "partial-null", "null", "full_rank", "rank_deficient"],
+        scope="class",
+    )
+    def reduce_problem(
+        self, request: pytest.FixtureRequest, solver_factory: Union[Solver, jtu.Partial]
+    ) -> _ReduceProblem:
+        node_key, weight_key, rng_key = jr.split(self.random_key, num=3)
+        nodes = jr.uniform(node_key, self.shape)
+        weights = jr.uniform(weight_key, (self.shape[0],))
+        expected_coreset = None
+        if request.param == "random":
+            test_functions = None
+        elif request.param == "partial-null":
+            zero_weights = jr.choice(rng_key, self.shape[0], (self.shape[0] // 2,))
+            weights = weights.at[zero_weights].set(0)
+            test_functions = None
+        elif request.param == "null":
+
+            def test_functions(x):
+                return jnp.zeros(x.shape)
+
+        elif request.param == "full_rank":
+
+            def test_functions(x):
+                norm_x = jnp.linalg.norm(x)
+                return jnp.array([norm_x, norm_x**2, norm_x**3])
+
+        elif request.param == "rank_deficient":
+
+            def test_functions(x):
+                norm_x = jnp.linalg.norm(x)
+                return jnp.array([norm_x, 2 * norm_x, 2 + norm_x])
+
+        else:
+            raise ValueError("Invalid fixture parametrization")
+        solver_factory.keywords["test_functions"] = test_functions
+        solver = solver_factory()
+        return _ReduceProblem(Data(nodes, weights), solver, expected_coreset)
+
+    @override
+    def check_solution_invariants(
+        self, coreset: Coreset, problem: Union[_RefineProblem, _ReduceProblem]
+    ) -> None:
+        r"""
+        Check that a coreset obeys certain expected invariant properties.
+
+        In addition to the standard checks in the parent class we also check:
+        1. Check 'sum(coreset.weights)' is one.
+        2. Check 'len(coreset)' is less than or equal to the upper bound `m`.
+        3. Check 'len(coreset[idx]) where idx = jnp.nonzero(coreset.weights)' is less
+            than or equal to the rank, :math:`m^\prime`, of the pushed forward nodes.
+        4. Check the push-forward of the coreset preserves the "centre-of-mass" (CoM)
+            of the pushed-forward dataset (with implicit and explicit zero weight
+            removal).
+        5. Check the default value of 'test_functions' is the identity map.
+        """
+        super().check_solution_invariants(coreset, problem)
+        dataset, solver, _ = problem
+        coreset_nodes, coreset_weights = coreset.coreset.data, coreset.coreset.weights
+        assert eqx.tree_equal(jnp.sum(coreset_weights), jnp.asarray(1.0), rtol=5e-5)
+        if solver.test_functions is None:
+            solver = eqx.tree_at(
+                lambda x: x.test_functions,
+                solver,
+                lambda x: x,
+                is_leaf=lambda x: x is None,
+            )
+            expected_default_coreset, _ = solver.reduce(dataset)
+            assert eqx.tree_equal(coreset, expected_default_coreset)
+
+        vmap_test_functions = jax.vmap(solver.test_functions)
+        pushed_forward_nodes = vmap_test_functions(dataset.data)
+        augmented_pushed_forward_nodes = jnp.c_[
+            jnp.ones_like(dataset.weights), pushed_forward_nodes
+        ]
+        rank = jnp.linalg.matrix_rank(augmented_pushed_forward_nodes)
+        max_rank = augmented_pushed_forward_nodes.shape[-1]
+        assert rank <= max_rank
+        non_zero = jnp.flatnonzero(coreset_weights)
+        if solver.mode == "implicit-explicit":
+            assert len(coreset) <= max_rank
+            assert len(non_zero) <= len(coreset) - (max_rank - rank)
+        if solver.mode == "implicit":
+            assert len(coreset) == len(augmented_pushed_forward_nodes)
+            assert len(non_zero) <= len(coreset) - (max_rank - rank)
+        if solver.mode == "explicit":
+            assert len(non_zero) == len(coreset)
+            assert len(coreset) <= rank
+        pushed_forward_com = jnp.average(
+            pushed_forward_nodes, 0, weights=dataset.weights
+        )
+        pushed_forward_coreset_nodes = vmap_test_functions(
+            jnp.atleast_2d(coreset_nodes)
+        )
+        coreset_pushed_forward_com = jnp.average(
+            pushed_forward_coreset_nodes, 0, weights=coreset_weights
+        )
+        assert eqx.tree_equal(pushed_forward_com, coreset_pushed_forward_com, rtol=1e-5)
+        explicit_coreset_pushed_forward_com = jnp.average(
+            pushed_forward_coreset_nodes[non_zero], 0, weights=coreset_weights[non_zero]
+        )
+        assert eqx.tree_equal(
+            coreset_pushed_forward_com, explicit_coreset_pushed_forward_com, rtol=1e-5
+        )
+
+    @override
+    @pytest.mark.parametrize("use_cached_state", (False, True))
+    @pytest.mark.parametrize(
+        "recombination_mode, context",
+        (
+            ("implicit-explicit", does_not_raise()),
+            ("implicit", does_not_raise()),
+            (
+                "explicit",
+                pytest.raises(ValueError, match="'explicit' mode is incompatible"),
+            ),
+            (None, pytest.raises(ValueError, match="Invalid mode")),
+        ),
+    )
+    # We don't care too much that the arguments differ, as this is required to
+    # override the parametrization. Nevertheless, this should probably be revisited
+    # in the future.
+    def test_reduce(  # pylint: disable=arguments-differ
+        self,
+        jit_variant: Callable[[Callable], Callable],
+        reduce_problem: _ReduceProblem,
+        use_cached_state: bool,
+        recombination_mode: Literal["implicit-explicit", "implicit", "explicit"],
+        context: AbstractContextManager,
+    ) -> None:
+        """
+        Check 'reduce' raises no errors and is invariant to the resultant 'solver_state'.
+
+        Overrides the default implementation to provide handling of the different
+        modes of recombination.
+
+        By resultant 'solver_state' invariance we mean the following procedure
+        succeeds:
+        1. Call 'reduce' with the default 'solver_state' to get the resultant state.
+        2. Call 'reduce' again, this time passing the 'solver_state' from the previous
+            run, and keeping all other arguments the same.
+        3. Check the two calls to 'reduce' yield the same result.
+        """
+        dataset, base_solver, expected_coreset = reduce_problem
+        solver = eqx.tree_at(lambda x: x.mode, base_solver, recombination_mode)
+        updated_problem = _ReduceProblem(dataset, solver, expected_coreset)
+        # Explicit mode should only raise if jit_variant is eqx.filter_jit (or
+        # jax.jit).
+        if jit_variant is not eqx.filter_jit and recombination_mode == "explicit":
+            context = does_not_raise()
+        with context:
+            coreset, state = jit_variant(solver.reduce)(dataset)
+            if use_cached_state:
+                coreset_with_state, recycled_state = solver.reduce(dataset, state)
+                assert eqx.tree_equal(recycled_state, state)
+                assert eqx.tree_equal(coreset_with_state, coreset)
+            self.check_solution_invariants(coreset, updated_problem)
+
+
 class RefinementSolverTest(SolverTest):
     """Test cases for coresubset solvers that provide a 'refine' method."""
 
@@ -689,3 +849,23 @@ def test_base_solver(
         solver_factory.keywords["leaf_size"] = self.leaf_size
         solver_factory.keywords["base_solver"] = base_solver
         solver_factory()
+
+
+class TestCaratheodoryRecombination(RecombinationSolverTest):
+    """Tests for :class:`coreax.solvers.recombination.CaratheodoryRecombination`."""
+
+    @override
+    @pytest.fixture(scope="class")
+    def solver_factory(self) -> Union[Solver, jtu.Partial]:
+        return jtu.Partial(CaratheodoryRecombination, test_functions=None, rcond=None)
+
+
+class TestTreeRecombination(RecombinationSolverTest):
+    """Tests for :class:`coreax.solvers.recombination.TreeRecombination`."""
+
+    @override
+    @pytest.fixture(scope="class")
+    def solver_factory(self) -> Union[Solver, jtu.Partial]:
+        return jtu.Partial(
+            TreeRecombination, test_functions=None, rcond=None, tree_reduction_factor=3
+        )
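
For readers wanting to try the new solvers end-to-end, the following is a minimal
editorial sketch (not part of the diff) mirroring the centre-of-mass invariants
asserted by `RecombinationSolverTest` above:

    import jax.numpy as jnp
    import jax.random as jr

    from coreax import Data
    from coreax.solvers import TreeRecombination

    nodes = jr.uniform(jr.key(0), (100, 3))
    weights = jnp.full((100,), 1 / 100)
    dataset = Data(nodes, weights)
    coresubset, _ = TreeRecombination(test_functions=None).reduce(dataset)
    coreset = coresubset.coreset
    # The coreset weights still sum to one...
    assert jnp.isclose(jnp.sum(coreset.weights), 1.0)
    # ...and the weighted mean (the CoM under the identity test-functions) matches.
    assert jnp.allclose(
        jnp.average(nodes, 0, weights),
        jnp.average(coreset.data, 0, coreset.weights),
        atol=1e-5,
    )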