Skip to content

Commit

Permalink
Merge branch 'master' into tension
Browse files Browse the repository at this point in the history
  • Loading branch information
williamjameshandley committed Sep 19, 2024
2 parents 7a10785 + d367a4b commit 970f431
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 28 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install hdf5 for macOS
if: ${{ matrix.os == 'macos-latest' }}
run: brew install hdf5 c-blosc

- name: Install dependencies
run: |
Expand Down Expand Up @@ -107,9 +110,12 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install hdf5 for macOS
if: ${{ matrix.os == 'macos-latest' }}
run: brew install hdf5 c-blosc

- name: Install dependencies
shell: bash -l {0}
Expand Down
4 changes: 4 additions & 0 deletions anesthetic/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,12 @@ def _set_logticks(self):
if ax is not None:
if x in self._logx:
ax.xaxis.set_major_locator(LogLocator(numticks=3))
if x != y:
ax.set_xlim(ax.dataLim.intervalx)
if y in self._logy:
ax.yaxis.set_major_locator(LogLocator(numticks=3))
if y != x:
ax.set_ylim(ax.dataLim.intervaly)

@staticmethod
def _set_labels(axes, labels, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions anesthetic/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def _make_plot(self, fig):
ax = self._get_ax(0) # another one of these hard-coded 0s

kwds = self.kwds.copy()
if self.color is not None:
kwds["color"] = self.color
label = pprint_thing(self.label)
kwds["label"] = label

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def plot_2d(self, axes=None, *args, **kwargs):
if np.isinf(self[x]).any():
warnings.warn(f"column {y} has inf values.")
selfxy = self[[x, y]]
selfxy = self.replace([-np.inf, np.inf], np.nan)
selfxy = selfxy.replace([-np.inf, np.inf], np.nan)
selfxy = selfxy.dropna(axis=0)
selfxy.plot(x, y, ax=ax, xlabel=xlabel,
logx=x in logx, logy=y in logy,
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ readme = "README.rst"
license = {file = "LICENSE"}
requires-python = ">=3.8"
dependencies = [
"scipy",
"numpy",
"scipy<2.0.0",
"numpy<2.0.0",
"pandas~=2.2.0",
"matplotlib>=3.6.1,<3.9.0",
"matplotlib>=3.6.1,<3.10.0",
]
classifiers = [
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -68,7 +68,7 @@ astropy = ["astropy"]
fastkde = ["fastkde"]
getdist = ["getdist"]
hdf5 = ["tables==3.8.0"]
all = ["h5py", "astropy", "fastkde", "getdist", "tables==3.8.0"]
all = ["h5py", "astropy", "fastkde", "getdist", "tables"]

[project.scripts]
anesthetic = "anesthetic.scripts:gui"
Expand Down
53 changes: 31 additions & 22 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,31 +308,29 @@ def test_plot_2d_colours(kind):
kinds = {'diagonal': kind + '_1d',
'lower': kind + '_2d',
'upper': 'scatter_2d'}
gd.plot_2d(axes, kind=kinds, label="gd")
pc.plot_2d(axes, kind=kinds, label="pc")
mn.plot_2d(axes, kind=kinds, label="mn")
gd_colors = []
pc_colors = []
mn_colors = []
gd.plot_2d(axes, kind=kinds, label="A")
pc.plot_2d(axes, kind=kinds, label="B")
mn.plot_2d(axes, kind=kinds, label="C")
gd.plot_2d(axes, kind=kinds, label="D", color='C7')
pc.plot_2d(axes, kind=kinds, label="E", color='C6')
mn.plot_2d(axes, kind=kinds, label="F", color='C5')

from collections import defaultdict
d = defaultdict(set)

for y, rows in axes.iterrows():
for x, ax in rows.items():
handles, labels = ax.get_legend_handles_labels()
for handle, label in zip(handles, labels):
if isinstance(handle, Rectangle):
color = to_hex(handle.get_facecolor())
color = handle.get_facecolor()
else:
color = handle.get_color()
color = to_hex(color)
d[label].add(color)

if label == 'gd':
gd_colors.append(color)
elif label == 'pc':
pc_colors.append(color)
elif label == 'mn':
mn_colors.append(color)

assert len(set(gd_colors)) == 1
assert len(set(mn_colors)) == 1
assert len(set(pc_colors)) == 1
for v in d.values():
assert len(v) == 1


@pytest.mark.parametrize('kwargs', [dict(color='r', alpha=0.5, ls=':', lw=1),
Expand Down Expand Up @@ -526,20 +524,31 @@ def test_plot_logscale_2d(kind):
def test_logscale_ticks():
np.random.seed(42)
ndim = 5
data = np.exp(10 * np.random.randn(200, ndim))
data1 = np.exp(10 * np.random.randn(200, ndim))
data2 = np.exp(10 * np.random.randn(200, ndim) - 50)
params = [f'a{i}' for i in range(ndim)]
fig, axes = make_2d_axes(params, logx=params, logy=params, upper=False)
samples = Samples(data, columns=params)
samples.plot_2d(axes)
for _, col in axes.iterrows():
for _, ax in col.items():
samples1 = Samples(data1, columns=params)
samples2 = Samples(data2, columns=params)
samples1.plot_2d(axes)
samples2.plot_2d(axes)
for y, col in axes.iterrows():
for x, ax in col.items():
if ax is not None:
xlims = ax.get_xlim()
xticks = ax.get_xticks()
assert np.sum((xticks > xlims[0]) & (xticks < xlims[1])) > 1
ylims = ax.get_ylim()
yticks = ax.get_yticks()
assert np.sum((yticks > ylims[0]) & (yticks < ylims[1])) > 1
if x == y:
data_min = ax.twin.dataLim.intervalx[0]
data_max = ax.twin.dataLim.intervalx[1]
assert xlims[0] == pytest.approx(data_min, rel=1e-14)
assert xlims[1] == pytest.approx(data_max, rel=1e-14)
else:
assert_array_equal(xlims, ax.dataLim.intervalx)
assert_array_equal(ylims, ax.dataLim.intervaly)


@pytest.mark.parametrize('k', ['hist_1d', 'hist'])
Expand Down

0 comments on commit 970f431

Please sign in to comment.