Feature/tree recombination #504

Merged
merged 15 commits, Sep 8, 2024
Changes from 7 commits
6 changes: 6 additions & 0 deletions .cspell/custom_misc.txt
@@ -3,6 +3,7 @@ ANNchor
approximator
approximators
archiveprefix
binom
cand
coreax
coreset
@@ -25,18 +26,23 @@ kernelised
kernelized
KSD
linewidth
mapsto
ml.p3.8xlarge
ndmin
parsable
PCIMQ
phdthesis
PMLR
primaryclass
PRNG
rcond
recomb
refs
regulariser
RKHS
RPCHOLESKY
sigmas
subseteq
supp
TLDR
typecheck
5 changes: 5 additions & 0 deletions .cspell/library_terms.txt
@@ -22,6 +22,7 @@ brightgreen
bysource
cmap
color
Concretization
dispname
dollarmath
dtype
@@ -30,8 +31,10 @@ dunder
eigh
emibcn
eprint
errstate
figsize
finfo
flatnonzero
fori_loop
fracv
furo
@@ -112,6 +115,7 @@ sklearn
softplus
sphinxcontrib
sphobjinv
tensordot
texttt
toctree
tomli
@@ -125,6 +129,7 @@ unsrt
vars
viewcode
vmap
vmaps
writebytes
xlabel
ylabel
5 changes: 5 additions & 0 deletions .cspell/people.txt
@@ -1,4 +1,5 @@
Benard
Caratheodory
Chatalic
Duvenaud
Epperly
@@ -13,6 +14,8 @@ Jaehoon
Jiaxin
Jitkrittum
Kanagawa
Litterer
Lyons
Martinsson
Motonobu
Nystr
@@ -26,6 +29,8 @@ Schreuder
Smirnov
Smola
Staber
Tchakaloff
Tchernychova
Tropp
Veiga
Wittawat
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added Kernelised Stein Discrepancy divergence in `coreax.metrics.KSD`.
- Added the `coreax.solvers.recombination` module, which provides the following new solvers:
- `RecombinationSolver`: an abstract base class for recombination solvers.
- `CaratheodoryRecombination`: a simple deterministic approach to solving recombination problems.
- `TreeRecombination`: an advanced deterministic approach that utilises `CaratheodoryRecombination`,
but provides superior performance for solving all but the smallest recombination problems.
- Added supervised coreset construction algorithm in `coreax.solvers.GreedyKernelPoints`
- Added `coreax.kernels.PowerKernel` to replace repeated calls of `coreax.kernels.ProductKernel`
within the `**` magic method of `coreax.kernel.ScalarValuedKernel`
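The recombination solvers added above solve a classical problem: given `n` weighted points, find a much smaller non-negatively weighted subset whose low-order moments match those of the original measure. The following numpy sketch of Carathéodory recombination illustrates the idea — it is an independent toy implementation under my own assumptions, not the code in `coreax.solvers.recombination`, and preserves only the total mass and the weighted mean (moments of degree at most one):

```python
import numpy as np

def caratheodory_recombination(x, w):
    """Illustrative sketch: thin weighted points (x, w) down to at most
    d + 1 points with non-negative weights, preserving the total weight
    and the weighted mean. Not the coreax implementation."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float).copy()
    _, d = x.shape
    while np.count_nonzero(w) > d + 1:
        idx = np.flatnonzero(w)
        # Test-function matrix: rows are the constant 1 and each coordinate,
        # so A @ w[idx] stacks the total mass and the unnormalised mean.
        A = np.vstack([np.ones(idx.size), x[idx].T])
        # A has more columns than rows, so it has a non-trivial null vector.
        phi = np.linalg.svd(A)[2][-1]
        if not np.any(phi > 0):
            phi = -phi  # ensure some positive entries to shrink against
        pos = np.flatnonzero(phi > 0)
        k = pos[np.argmin(w[idx][pos] / phi[pos])]
        alpha = w[idx][k] / phi[k]
        # Shifting along the null vector leaves A @ w unchanged while
        # driving at least one weight exactly to zero.
        w[idx] = np.maximum(w[idx] - alpha * phi, 0.0)
        w[idx[k]] = 0.0
    keep = np.flatnonzero(w)
    return keep, w[keep]
```

Each iteration eliminates at least one support point, so for example 60 points in three dimensions reduce to at most four points carrying the same mass and mean. Tree recombination accelerates this by recursing over clustered chunks rather than eliminating one point at a time.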
15 changes: 3 additions & 12 deletions coreax/coreset.py
@@ -7,7 +7,7 @@
from jaxtyping import Array, Shaped
from typing_extensions import Self

-from coreax.data import Data, SupervisedData, as_data
+from coreax.data import Data, as_data
from coreax.metrics import Metric
from coreax.weights import WeightsOptimiser

@@ -122,17 +122,8 @@ class Coresubset(Coreset[_Data], Generic[_Data]):
@property
def coreset(self) -> Data:
"""Materialise the coresubset from the indices and original data."""
-        coreset_data = self.pre_coreset_data.data[self.unweighted_indices]
-        if isinstance(self.pre_coreset_data, SupervisedData):
-            coreset_supervision = self.pre_coreset_data.supervision[
-                self.unweighted_indices
-            ]
-            return SupervisedData(
-                data=coreset_data,
-                supervision=coreset_supervision,
-                weights=self.nodes.weights,
-            )
-        return Data(data=coreset_data, weights=self.nodes.weights)
+        coreset_data = self.pre_coreset_data[self.unweighted_indices]
+        return eqx.tree_at(lambda x: x.weights, coreset_data, self.nodes.weights)

@property
def unweighted_indices(self) -> Shaped[Array, " n"]:
Expand Down
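The rewritten `coreset` property replaces the explicit `isinstance` branch with a single out-of-place leaf update: `eqx.tree_at` returns a copy of `coreset_data` with only the `weights` leaf swapped, so the same two lines handle `Data`, `SupervisedData`, or any future subclass uniformly. A minimal stand-in for this pattern using only the standard library — illustrative classes, not coreax's actual `Data` types:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Data:
    data: tuple
    weights: tuple

@dataclass(frozen=True)
class SupervisedData(Data):
    supervision: tuple = ()

def with_weights(d: Data, new_weights: tuple) -> Data:
    # Out-of-place update of a single field; the concrete subclass and all
    # other fields (e.g. supervision) are preserved automatically.
    return replace(d, weights=new_weights)
```

Like `eqx.tree_at`, `dataclasses.replace` reconstructs the object rather than mutating it, which is why the subclass-specific `supervision` handling in the old code becomes unnecessary.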
8 changes: 8 additions & 0 deletions coreax/solvers/__init__.py
@@ -32,6 +32,11 @@
RPCholeskyState,
SteinThinning,
)
from coreax.solvers.recombination import (
CaratheodoryRecombination,
RecombinationSolver,
TreeRecombination,
)

__all__ = [
"CompositeSolver",
@@ -49,4 +54,7 @@
"PaddingInvariantSolver",
"GreedyKernelPoints",
"GreedyKernelPointsState",
"RecombinationSolver",
"CaratheodoryRecombination",
"TreeRecombination",
]
2 changes: 1 addition & 1 deletion coreax/solvers/base.py
@@ -112,7 +112,7 @@ class PaddingInvariantSolver(Solver):
A :class:`Solver` whose results are invariant to zero weighted data.

In some cases, such as in :class:`coreax.solvers.MapReduce`, there is a need to pad
-    data to ensure shape stability. In some cases, we may assign zero weight to the
+    data to ensure shape stability. In these cases, we may assign zero weight to the
padded data points, which allows certain 'padding invariant' solvers to return the
same values on a call to :meth:`~coreax.solvers.Solver.reduce` as would have been
returned if no padding were present.
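The padding invariance described in this docstring is easy to check concretely: any statistic that scales each point by its weight is unchanged by zero-weighted padding. A small numpy illustration of the property (not the coreax test suite):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
w = np.array([0.2, 0.3, 0.5])

# Pad to a fixed shape with zero-weighted dummy points, as MapReduce might.
x_pad = np.concatenate([x, np.zeros(2)])
w_pad = np.concatenate([w, np.zeros(2)])

# A weighted statistic ignores the padding entirely.
assert np.isclose(np.average(x, weights=w), np.average(x_pad, weights=w_pad))
```

A padding-invariant solver is one whose `reduce` output depends on the data only through such weighted quantities, so the padded and unpadded calls agree.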