diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index c2f8133a..e0d05d31 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -19,6 +19,7 @@ from datatree import DataTree from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from scanpy._settings import settings as sc_settings from spatialdata import get_extent from spatialdata.models import PointsModel, get_table_keys @@ -624,8 +625,10 @@ def _render_images( _ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder) if legend_params.colorbar: + ax_divider = make_axes_locatable(ax) sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm) - fig_params.fig.colorbar(sm, ax=ax) + cax = ax_divider.append_axes("right", size="7%", pad="2%") + fig_params.fig.colorbar(sm, ax=ax, cax=cax) # 2) Image has any number of channels but 1 else: diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 33a103da..5e22aa21 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -39,6 +39,7 @@ from matplotlib.gridspec import GridSpec from matplotlib.transforms import CompositeGenericTransform from matplotlib_scalebar.scalebar import ScaleBar +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from numpy.ma.core import MaskedArray from numpy.random import default_rng from pandas.api.types import CategoricalDtype @@ -1016,7 +1017,9 @@ def _decorate_axs( ) elif colorbar: # TODO: na_in_legend should have some effect here - cb = plt.colorbar(cax, ax=ax, pad=0.01, fraction=0.08, aspect=30) + ax_divider = make_axes_locatable(ax) + cax2 = ax_divider.append_axes("right", size="5%", pad="2%") + cb = fig_params.fig.colorbar(cax, ax=ax, cax=cax2) cb.solids.set_alpha(alpha) if isinstance(scalebar_dx, list) and isinstance(scalebar_units, list):