Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ dependencies = [
"tqdm>=4.66.5",
]

[project.urls]
Documentation = "https://rascalsoftware.github.io/RAT/"
Repository = "https://github.com/RascalSoftware/python-RAT"

[project.optional-dependencies]
dev = [
"pytest>=7.4.0",
Expand Down
47 changes: 17 additions & 30 deletions ratapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
for j in range(len(sld)):
results["sld"][-1].append([sld[j][:, 0], sld[j][:, 1]])

results["sld_resample"].append([])
if event_data.resample[i] == 1 or event_data.modelType == "custom xy":
layers = event_data.resampledLayers[i][0]
results["sld_resample"].append([])
for j in range(len(event_data.resampledLayers[i])):
layer = event_data.resampledLayers[i][j]
if layers.shape[1] == 4:
Expand Down Expand Up @@ -198,7 +198,7 @@ def plot_ref_sld_helper(
sld_min, sld_max = confidence_intervals["sld"][i][j]
sld_plot.fill_between(plot_data["sld"][i][j][0], sld_min, sld_max, alpha=0.6, color="grey")

if plot_data["sld_resample"]:
if plot_data["sld_resample"] and plot_data["sld_resample"][i]:
for j in range(len(plot_data["sld_resample"][i])):
sld_plot.plot(
plot_data["sld_resample"][i][j][0],
Expand Down Expand Up @@ -544,7 +544,7 @@ def update_foreground(self, data):
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
i += 1

if plot_data["sld_resample"]:
if plot_data["sld_resample"] and plot_data["sld_resample"][j]:
for resampled in plot_data["sld_resample"][j]:
self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1])
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
Expand Down Expand Up @@ -982,10 +982,7 @@ def plot_contour(


def panel_plot_helper(
plot_func: Callable,
indices: list[int],
fig: matplotlib.figure.Figure | None = None,
progress_callback: Callable[[int, int], None] | None = None,
plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None
) -> matplotlib.figure.Figure:
"""Generate a panel-based plot from a single plot function.

Expand All @@ -997,9 +994,6 @@ def panel_plot_helper(
The list of indices to pass into ``plot_func``.
fig : matplotlib.figure.Figure, optional
The figure object to use for plot.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots

Returns
-------
Expand All @@ -1011,19 +1005,21 @@ def panel_plot_helper(
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))

if fig is None:
fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0]
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
else:
fig.clf()
fig.subplots(nrows, ncols, subplot_kw={"visible": False})
fig.subplots(nrows, ncols)
axs = fig.get_axes()
for index, plot_num in enumerate(indices):
axs[index].tick_params(which="both", labelsize="medium")
axs[index].xaxis.offsetText.set_fontsize("small")
axs[index].yaxis.offsetText.set_fontsize("small")
axs[index].set_visible(True)
plot_func(axs[index], plot_num)
if progress_callback is not None:
progress_callback(index, nplots)

for plot_num, index in enumerate(indices):
axs[plot_num].tick_params(which="both", labelsize="medium")
axs[plot_num].xaxis.offsetText.set_fontsize("small")
axs[plot_num].yaxis.offsetText.set_fontsize("small")
plot_func(axs[plot_num], index)

# blank unused plots
for i in range(nplots, len(axs)):
axs[i].set_visible(False)

fig.tight_layout()
return fig
Expand All @@ -1040,7 +1036,6 @@ def plot_hists(
block: bool = False,
fig: matplotlib.figure.Figure | None = None,
return_fig: bool = False,
progress_callback: Callable[[int, int], None] | None = None,
**hist_settings,
):
"""Plot marginalised posteriors for several parameters from a Bayesian analysis.
Expand Down Expand Up @@ -1077,9 +1072,6 @@ def plot_hists(
The figure object to use for plot.
return_fig: bool, default False
If True, return the figure as an object instead of showing it.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots
hist_settings :
Settings passed to `np.histogram`. By default, the settings
passed are `bins = 25` and `density = True`.
Expand Down Expand Up @@ -1138,7 +1130,6 @@ def validate_dens_type(dens_type: str | None, param: str):
),
params,
fig,
progress_callback,
)
if return_fig:
return fig
Expand All @@ -1153,7 +1144,6 @@ def plot_chain(
block: bool = False,
fig: matplotlib.figure.Figure | None = None,
return_fig: bool = False,
progress_callback: Callable[[int, int], None] | None = None,
):
"""Plot the MCMC chain for each parameter of a Bayesian analysis.

Expand All @@ -1172,9 +1162,6 @@ def plot_chain(
The figure object to use for plot.
return_fig: bool, default False
If True, return the figure as an object instead of showing it.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots

Returns
-------
Expand All @@ -1200,7 +1187,7 @@ def plot_one_chain(axes: Axes, i: int):
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
axes.set_title(results.fitNames[i], fontsize="small")

fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback)
fig = panel_plot_helper(plot_one_chain, params, fig=fig)
if return_fig:
return fig
plt.show(block=block)
Expand Down
6 changes: 0 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,6 @@ def build_libraries(self, libraries):

setup(
name=PACKAGE_NAME,
author="",
author_email="",
url="https://github.com/RascalSoftware/python-RAT",
description="Python extension for the Reflectivity Analysis Toolbox (RAT)",
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown",
packages=find_packages(exclude=("tests",)),
include_package_data=True,
package_data={"": [get_shared_object_name(libevent[0])], "ratapi.examples": ["data/*.dat"]},
Expand Down