diff --git a/doc/sinter_api.md b/doc/sinter_api.md
index 5cb64f821..ac7e10006 100644
--- a/doc/sinter_api.md
+++ b/doc/sinter_api.md
@@ -42,6 +42,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi
- [`sinter.iter_collect`](#sinter.iter_collect)
- [`sinter.log_binomial`](#sinter.log_binomial)
- [`sinter.log_factorial`](#sinter.log_factorial)
+- [`sinter.plot_custom`](#sinter.plot_custom)
- [`sinter.plot_discard_rate`](#sinter.plot_discard_rate)
- [`sinter.plot_error_rate`](#sinter.plot_error_rate)
- [`sinter.post_selection_mask_from_4th_coord`](#sinter.post_selection_mask_from_4th_coord)
@@ -1452,6 +1453,70 @@ def log_factorial(
"""
```
+
+```python
+# sinter.plot_custom
+
+# (at top-level in the sinter module)
+def plot_custom(
+ *,
+ ax: 'plt.Axes',
+ stats: 'Iterable[sinter.TaskStats]',
+ x_func: Callable[[sinter.TaskStats], Any],
+ y_func: Callable[[sinter.TaskStats], Union[sinter.Fit, float, int]],
+ group_func: Callable[[sinter.TaskStats], ~TCurveId] = lambda _: None,
+ point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None,
+ filter_func: Callable[[sinter.TaskStats], Any] = lambda _: True,
+ plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
+ line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
+) -> None:
+ """Plots error rates in curves with uncertainty highlights.
+
+ Args:
+ ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
+ stats: The collected statistics to plot.
+ x_func: The X coordinate to use for each stat's data point. For example, this could be
+ `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
+ y_func: The Y value to use for each stat's data point. This can be a float or it can be a
+ sinter.Fit value, in which case the curve will follow the fit.best value and a
+ highlighted area will be shown from fit.low to fit.high.
+ group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
+ The statistics are grouped into curves based on whether or not they get the same result
+ out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
+ If the result of the function is a dictionary, then optional keys in the dictionary will
+ also control the plotting of each curve. Available keys are:
+ 'label': the label added to the legend for the curve
+ 'color': the color used for plotting the curve
+ 'marker': the marker used for the curve
+ 'linestyle': the linestyle used for the curve
+ 'sort': the order in which the curves will be plotted and added to the legend
+ e.g. if two curves (with different resulting dictionaries from group_func) share the same
+ value for key 'marker', they will be plotted with the same marker.
+ Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
+ point_label_func: Optional. Specifies text to draw next to data points.
+ filter_func: Optional. When specified, some curves will not be plotted.
+ The statistics are filtered and only plotted if filter_func(stat) returns True.
+ For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
+ where the saved metadata indicates the basis was 'x'.
+ plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
+ `plot` and `fill_between` used to do the actual plotting. For example, this can be used
+ to specify markers and colors. Takes the index of the curve in sorted order and also a
+ curve_id (these will be 0 and None respectively if group_func is not specified). For example,
+ this could be:
+
+ plot_args_func=lambda index, group_key, group_stats: {
+ 'color': (
+ 'red'
+ if group_key == 'decoder=pymatching p=0.001'
+ else 'blue'
+ ),
+ }
+ line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
+ fit to every curve. The scales determine how to transform the coordinates before
+ performing the fit, and can be set to 'linear', 'sqrt', or 'log'.
+ """
+```
+
```python
# sinter.plot_discard_rate
diff --git a/glue/sample/src/sinter/__init__.py b/glue/sample/src/sinter/__init__.py
index 47f702793..9046657f2 100644
--- a/glue/sample/src/sinter/__init__.py
+++ b/glue/sample/src/sinter/__init__.py
@@ -37,6 +37,7 @@
better_sorted_str_terms,
plot_discard_rate,
plot_error_rate,
+ plot_custom,
group_by,
)
from sinter._predict import (