Skip to content

Commit 5ad2aa2

Browse files
authored
multivariate bugfixes (0.3.0 release) (#7)
* MVSamples: fix to pass samples.T to stats.kde_gaussian * MVSamples.to_mvhistogram: use samples directly if N>=nsamples * temporary fix: accept unit/as_quantity kwargs to MV*.sample
1 parent 2d885e8 commit 5ad2aa2

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

distl/distl.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
else:
4646
_has_dill = True
4747

48-
__version__ = '0.2.0'
48+
__version__ = '0.3.0'
4949
version = __version__
5050

5151
_math_symbols = {'__mul__': '*', '__add__': '+', '__sub__': '-',
@@ -3255,7 +3255,8 @@ def uncertainties(self, sigma=1, tex=False, dimension=None, samples=None):
32553255
else:
32563256
return qs_per_dim
32573257

3258-
def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
3258+
def sample(self, size=None, dimension=None, seed=None, cache_sample=True,
3259+
unit=None, as_quantity=False):
32593260
"""
32603261
Sample from the distribution.
32613262
@@ -3270,13 +3271,18 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
32703271
prior to sampling.
32713272
* `cache_sample` (bool, optional, default=True): whether to override the
32723273
existing <<class>.cached_sample>.
3274+
* `unit` (None): NOT YET IMPLEMENTED will raise error if not None
3275+
* `as_quantity` (False): NOT YET IMPLEMENTED will raise error if not False
32733276
32743277
Returns
32753278
---------
32763279
* float or array: float if `size=None`, otherwise a numpy array with
32773280
shape defined by `size`.
32783281
"""
32793282

3283+
if unit is not None or as_quantity:
3284+
raise NotImplementedError("unit and quantities not yet supported for multivariate distributions")
3285+
32803286
# TODO: add support for per-dimension unit, wrap_at, as_quantity (and pass in to_mvhistogram)
32813287
# TODO: add support for seed
32823288
if isinstance(seed, dict):
@@ -3714,14 +3720,16 @@ def uncertainties(self, sigma=1, tex=False):
37143720

37153721
### SAMPLING & PLOTTING
37163722

3717-
def sample(self, size=None, wrap_at=None, seed=None, cache_sample=True):
3723+
def sample(self, size=None, wrap_at=None, seed=None, cache_sample=True,
3724+
unit=None, as_quantity=False):
37183725
"""
37193726
Sample the underlying <<class>.multivariate> distribution in the dimension
37203727
defined in <<class>.dimension>.
37213728
"""
37223729

37233730
# TODO: support unit, wrap_at, as_quantity
3724-
return self.multivariate.sample(size=size, seed=seed, dimension=self.dimension, cache_sample=cache_sample)
3731+
return self.multivariate.sample(size=size, seed=seed, dimension=self.dimension, cache_sample=cache_sample,
3732+
unit=unit, as_quantity=as_quantity)
37253733

37263734
def plot_sample(self, *args, **kwargs):
37273735
if hasattr(self, 'bins'):
@@ -3958,7 +3966,7 @@ def get_distributions_with_values(self, values=None, as_univariates=False):
39583966
if not as_univariates and isinstance(dist_orig, BaseMultivariateSliceDistribution):
39593967
d = dist_orig.multivariate
39603968
else:
3961-
d = dist_orig
3969+
d = dist_orig #.to_univariate()?
39623970

39633971
# if as_univariates then we want MVSlices with the same parent MV to be treated separately
39643972
take_dimensions = not as_univariates and isinstance(dist_orig, BaseMultivariateSliceDistribution)
@@ -4044,7 +4052,7 @@ def logpdf(self, values=None, as_univariates=False):
40444052
samples are available, a ValueError will be raised.
40454053
* `as_univariates` (bool, optional, default=False): whether `values` corresponds
40464054
to the passed distributions (<DistributionCollection.distributions>)
4047-
or the underlying unpacked distributions (<DistributionCollection.distributions_unpacked>).
4055+
or the underlying unpacked distributions (<DistributionCollection.dists_unpacked>).
40484056
If the former (`as_univariates=False`), covariances will be respected
40494057
from any underlying multivariate distributions. If the latter
40504058
(`as_univariates=True`) covariances will be ignored.
@@ -7245,7 +7253,8 @@ def take_dimensions(self, dimensions):
72457253
labels=[self.labels[d] for d in dimensions] if self.labels is not None else None,
72467254
wrap_ats=[self.wrap_ats[d] for d in dimensions] if self.wrap_ats is not None else None)
72477255

7248-
def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
7256+
def sample(self, size=None, dimension=None, seed=None, cache_sample=True,
7257+
unit=None, as_quantity=False):
72497258
"""
72507259
72517260
Arguments
@@ -7256,6 +7265,8 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
72567265
prior to sampling.
72577266
* `cache_sample` (bool, optional, default=True): whether to override the
72587267
existing <<class>.cached_sample>.
7268+
* `unit` (None): NOT YET IMPLEMENTED will raise error if not None
7269+
* `as_quantity` (False): NOT YET IMPLEMENTED will raise error if not False
72597270
72607271
"""
72617272
# if dimension is not None:
@@ -7266,6 +7277,10 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
72667277
# bins = self.bins
72677278
# density = self.density
72687279

7280+
if unit is not None or as_quantity:
7281+
raise NotImplementedError("unit and quantities not yet supported for multivariate distributions")
7282+
7283+
72697284
if isinstance(seed, dict):
72707285
seed = seed.get(self.uniqueid, None)
72717286

@@ -7535,7 +7550,7 @@ def __init__(self, samples, weights=None, bw_method=None, units=None,
75357550
75367551
Arguments
75377552
--------------
7538-
* `samples` (np.array object with shape (nsamples, <MVSamples.ndimensions>)):
7553+
* `samples` (np.array object with shape (<MVSamples.nsamples>, <MVSamples.ndimensions>)):
75397554
the samples.
75407555
* `weights` (np.array object with shape (nsamples) or None, optional, default=None):
75417556
weights for each entry in `samples`. NOTE: only supported with scipy
@@ -7560,6 +7575,8 @@ def __init__(self, samples, weights=None, bw_method=None, units=None,
75607575
--------
75617576
* an <MVSamples> object
75627577
"""
7578+
# NOTE: the passed samples need to be transposed, so see the override
7579+
# in dist_constructor_args
75637580
super(MVSamples, self).__init__(units, labels, labels_latex, wrap_ats,
75647581
_stats.gaussian_kde, ('samples', 'bw_method') if StrictVersion(_scipy_version) < StrictVersion("1.2.0") else ('samples', 'bw_method', 'weights'),
75657582
samples=samples, weights=weights, bw_method=bw_method,
@@ -7580,7 +7597,7 @@ def samples(self, value):
75807597
@property
75817598
def weights(self):
75827599
"""
7583-
weights for each entry in <Samples.samples>
7600+
weights for each sample in <Samples.samples> (nsamples)
75847601
"""
75857602
return self._weights
75867603

@@ -7608,6 +7625,25 @@ def bw_method(self, value):
76087625
self._bw_method = is_float(value)
76097626
self._dist_constructor_object_clear_cache()
76107627

7628+
@property
7629+
def dist_constructor_args(self):
7630+
"""
7631+
Return the arguments to pass to the the underlying distribution
7632+
constructor (often the scipy.stats random variable generator function)
7633+
7634+
<MVSamples.samples> is transposed before passing on to gaussian_kde
7635+
7636+
See also:
7637+
7638+
* <<class>.dist_constructor_func>
7639+
* <<class>.dist_constructor_object>
7640+
7641+
Returns
7642+
-------
7643+
* tuple
7644+
"""
7645+
return [getattr(self, a).T if a=='samples' else getattr(self,a) for a in self.dist_constructor_argnames]
7646+
76117647
@property
76127648
def ndimensions(self):
76137649
"""
@@ -7733,7 +7769,8 @@ def interval(self, *args, **kwargs):
77337769
# TODO: manual implementation
77347770
raise NotImplementedError()
77357771

7736-
def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
7772+
def sample(self, size=None, dimension=None, seed=None, cache_sample=True,
7773+
unit=None, as_quantity=False):
77377774
"""
77387775
Sample from the samples (<MVSamples.samples> if <MVSamples.weights>
77397776
is not provided, otherwise <MVSamples.samples_weighted>)
@@ -7746,9 +7783,16 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
77467783
prior to sampling.
77477784
* `cache_sample` (bool, optional, default=True): whether to override the
77487785
existing <<class>.cached_sample>.
7786+
* `unit` (None): NOT YET IMPLEMENTED will raise error if not None
7787+
* `as_quantity` (False): NOT YET IMPLEMENTED will raise error if not False
7788+
77497789
77507790
"""
77517791

7792+
if unit is not None or as_quantity:
7793+
raise NotImplementedError("unit and quantities not yet supported for multivariate distributions")
7794+
7795+
77527796
if isinstance(seed, dict):
77537797
seed = seed.get(self.uniqueid, None)
77547798

@@ -7928,7 +7972,8 @@ def to_mvhistogram(self, N=1e6, bins=15, range=None):
79287972
Arguments
79297973
-----------
79307974
* `N` (int, optional, default=1e6): number of samples to use for
7931-
the histogram.
7975+
the histogram. If N>=<MVSamples.nsamples>, <MVSamples.samples>
7976+
will be passed directly.
79327977
* `bins` (int, optional, default=15): number of bins to use for the
79337978
histogram.
79347979
* `range` (tuple or None): range to use for the histogram.
@@ -7938,7 +7983,7 @@ def to_mvhistogram(self, N=1e6, bins=15, range=None):
79387983
* an <MVHistogram> object
79397984
"""
79407985
# TODO: if sample is updated to take wrap_at/wrap_ats... pass wrap_at=False here
7941-
return MVHistogram.from_data(self.sample(size=int(N), cache_sample=False),
7986+
return MVHistogram.from_data(self.samples if N >= self.nsamples else self.sample(size=int(N), cache_sample=False),
79427987
bins=bins, range=range,
79437988
units=self.units,
79447989
labels=self.labels, labels_latex=self._labels_latex,
@@ -8015,6 +8060,18 @@ def ppf(self, q, unit=None, as_quantity=False, wrap_at=None):
80158060
"""
80168061
return Samples(samples=self.samples, weights=self.weights, bw_method=self.bw_method, unit=self.unit).ppf(q, unit=unit, as_quantity=as_quantity, wrap_at=wrap_at)
80178062

8063+
# def pdf(self, x, unit=None, as_quantity=False, wrap_at=None):
8064+
# """
8065+
# See <Samples.pdf>
8066+
# """
8067+
# return Samples(samples=self.samples, weights=self.weights, bw_method=self.bw_method, unit=self.unit).pdf(x, unit=unit, as_quantity=as_quantity, wrap_at=wrap_at)
8068+
#
8069+
# def logpdf(self, x, unit=None, as_quantity=False, wrap_at=None):
8070+
# """
8071+
# See <Samples.logpdf>
8072+
# """
8073+
# return Samples(samples=self.samples, weights=self.weights, bw_method=self.bw_method, unit=self.unit).logpdf(x, unit=unit, as_quantity=as_quantity, wrap_at=wrap_at)
8074+
80188075
def interval(self, alpha, unit=None, as_quantity=False, wrap_at=None):
80198076
"""
80208077
See <Samples.interval>

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
long_description = fh.read()
77

88
setup(name='distl',
9-
version='0.2.0',
9+
version='0.3.0',
1010
description='Simple Distributions: math operations, serializing, covariances',
1111
long_description=long_description,
1212
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)