Skip to content

Commit

Permalink
Merge pull request #600 from c-bata/plotly-user-defined-graph-objects
Browse files Browse the repository at this point in the history
Support user-defined plotly figures
  • Loading branch information
c-bata committed Sep 7, 2023
2 parents 436afe7 + 167a4de commit 94ef540
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ General APIs
optuna_dashboard.wsgi
optuna_dashboard.set_objective_names
optuna_dashboard.save_note
optuna_dashboard.save_plotly_graph_object

Human-in-the-loop
-----------------
Expand Down
1 change: 1 addition & 0 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._app import run_server # noqa
from ._app import wsgi # noqa
from ._custom_plot_data import save_plotly_graph_object # noqa
from ._form_widget import ChoiceWidget # noqa
from ._form_widget import dict_to_form_widget # noqa
from ._form_widget import ObjectiveChoiceWidget # noqa
Expand Down
4 changes: 4 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ._bottle_util import BottleViewReturn
from ._bottle_util import json_api_view
from ._cached_extra_study_property import get_cached_extra_study_property
from ._custom_plot_data import get_plotly_graph_objects
from ._importance import get_param_importance_from_trials_cache
from ._pareto_front import get_pareto_front_trials
from ._preferential_history import NewHistory
Expand Down Expand Up @@ -214,6 +215,8 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union_user_attrs,
has_intermediate_values,
) = get_cached_extra_study_property(study_id, trials)

plotly_graph_objects = get_plotly_graph_objects(system_attrs)
return serialize_study_detail(
summary,
best_trials,
Expand All @@ -222,6 +225,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union,
union_user_attrs,
has_intermediate_values,
plotly_graph_objects,
)

@app.get("/api/studies/<study_id:int>/param_importances")
Expand Down
134 changes: 134 additions & 0 deletions optuna_dashboard/_custom_plot_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING
import uuid

from optuna import Study


if TYPE_CHECKING:
from typing import Any

from optuna.storages import BaseStorage
import plotly.graph_objs as go


SYSTEM_ATTR_PLOT_DATA = "dashboard:plot_data:"
SYSTEM_ATTR_MAX_LENGTH = 2045


def save_plotly_graph_object(
study: Study, figure: go.Figure, *, graph_object_id: str | None = None
) -> str:
"""Save the user-defined plotly's graph object to the study.
Example:
.. code-block:: python
import optuna
from optuna_dashboard import save_plotly_graph_object
def objective(trial):
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x**2 + y
study = optuna.create_study()
study.optimize(objective, n_trials=100)
figure = optuna.visualization.plot_optimization_history(study)
save_plotly_graph_object(study, figure)
Args:
study:
Target study object.
plot_data:
The plotly's graph object to save.
graph_object_id:
Unique identifier of the graph object. If specified, the graph object is overwritten.
This must be a valid HTML id attribute value.
Returns:
The graph object ID.
"""
if graph_object_id is not None and not is_valid_graph_object_id(graph_object_id):
raise ValueError("graph_object_id must be a valid HTML id attribute value.")

storage = study._storage
study_id = study._study_id

graph_object_id = graph_object_id or str(uuid.uuid4())
key = SYSTEM_ATTR_PLOT_DATA + graph_object_id + ":"
plot_data_json_str = figure.to_json()
save_graph_object_json(storage, study_id, key, plot_data_json_str)
return graph_object_id


def save_graph_object_json(
storage: BaseStorage, study_id: int, key_prefix: str, plot_data_json_str: str
) -> None:
plot_data_system_attrs = split_plot_data(plot_data_json_str, key_prefix)
for k, v in plot_data_system_attrs.items():
storage.set_study_system_attr(study_id, k, v)

# Clear previous graph object attributes
study_system_attrs = storage.get_study_system_attrs(study_id)
all_plot_data_system_attrs = [k for k in study_system_attrs if k.startswith(key_prefix)]
if len(all_plot_data_system_attrs) > len(plot_data_system_attrs):
for i in range(len(plot_data_system_attrs), len(all_plot_data_system_attrs)):
storage.set_study_system_attr(study_id, f"{key_prefix}{i}", "")


def list_graph_object_ids(system_attrs: dict[str, Any]) -> list[str]:
titles = set()
for key in system_attrs:
if not key.startswith(SYSTEM_ATTR_PLOT_DATA):
continue

s = key.split(":", maxsplit=2) # e.g. ["dashboard", "plot_data", "Optimization History:1"]
if len(s) != 3:
continue
# Please note that title may contain ":".
title = s[2].rsplit(":", maxsplit=1)[0]
titles.add(title)
return list(titles)


def get_plotly_graph_objects(system_attrs: dict[str, Any]) -> dict[str, str]:
graph_objects = {}
for title in list_graph_object_ids(system_attrs):
key_prefix = SYSTEM_ATTR_PLOT_DATA + title + ":"
plot_data_attrs = {k: v for k, v in system_attrs.items() if k.startswith(key_prefix)}
graph_objects[title] = concat_plot_data(plot_data_attrs, key_prefix)
return graph_objects


def split_plot_data(plot_data_str: str, key_prefix: str) -> dict[str, str]:
plot_data_len = len(plot_data_str)
attrs = {}
for i in range(math.ceil(plot_data_len / SYSTEM_ATTR_MAX_LENGTH)):
start = i * SYSTEM_ATTR_MAX_LENGTH
end = min((i + 1) * SYSTEM_ATTR_MAX_LENGTH, plot_data_len)
attrs[f"{key_prefix}{i}"] = plot_data_str[start:end]
return attrs


def concat_plot_data(plot_data_attrs: dict[str, str], key_prefix: str) -> str:
return "".join(plot_data_attrs[f"{key_prefix}{i}"] for i in range(len(plot_data_attrs)))


def is_valid_graph_object_id(graph_object_id: str) -> bool:
if len(graph_object_id) == 0:
return False

# Can only contain letters [A-Za-z], numbers [0-9], hyphens ("-"), underscores ("_"),
# colons, and periods.
if not all(
"a" <= c <= "z" or "A" <= c <= "Z" or "0" <= c <= "9" or c in ("-", "_", ":", ".")
for c in graph_object_id[1:]
):
return False
# Unlike HTML id attribute, graph object id can begin with a letter [A-Za-z]
return True
5 changes: 5 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def serialize_study_detail(
union: list[tuple[str, BaseDistribution]],
union_user_attrs: list[tuple[str, bool]],
has_intermediate_values: bool,
plotly_graph_objects: dict[str, str],
) -> dict[str, Any]:
serialized: dict[str, Any] = {
"name": summary.study_name,
Expand Down Expand Up @@ -162,6 +163,10 @@ def serialize_study_detail(
serialized["form_widgets"] = form_widgets
if serialized["is_preferential"]:
serialized["preference_history"] = serialize_preference_history(system_attrs)
serialized["plotly_graph_objects"] = [
{"id": id_, "graph_object": graph_object}
for id_, graph_object in plotly_graph_objects.items()
]
return serialized


Expand Down
2 changes: 2 additions & 0 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ interface StudyDetailResponse {
objective_names?: string[]
form_widgets?: FormWidgets
preference_history?: PreferenceHistoryResponce[]
plotly_graph_objects: PlotlyGraphObject[]
}

export const getStudyDetailAPI = (
Expand Down Expand Up @@ -131,6 +132,7 @@ export const getStudyDetailAPI = (
preference_history: res.data.preference_history?.map(
convertPreferenceHistory
),
plotly_graph_objects: res.data.plotly_graph_objects,
}
})
}
Expand Down
11 changes: 11 additions & 0 deletions optuna_dashboard/ts/components/StudyHistory.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { GraphIntermediateValues } from "./GraphIntermediateValues"
import Grid2 from "@mui/material/Unstable_Grid2"
import { DataGrid, DataGridColumn } from "./DataGrid"
import { GraphHyperparameterImportance } from "./GraphHyperparameterImportances"
import { UserDefinedPlot } from "./UserDefinedPlot"
import { BestTrialsCard } from "./BestTrialsCard"
import {
useStudyDetailValue,
Expand Down Expand Up @@ -124,6 +125,16 @@ export const StudyHistory: FC<{ studyId: number }> = ({ studyId }) => {
<Grid2 xs={6}>
<GraphTimeline study={studyDetail} />
</Grid2>
{studyDetail !== null &&
studyDetail.plotly_graph_objects.map((go) => (
<Grid2 xs={6} key={go.id}>
<Card>
<CardContent>
<UserDefinedPlot graphObject={go} />
</CardContent>
</Card>
</Grid2>
))}
<Grid2 xs={6} spacing={2}>
<BestTrialsCard studyDetail={studyDetail} />
</Grid2>
Expand Down
21 changes: 21 additions & 0 deletions optuna_dashboard/ts/components/UserDefinedPlot.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import * as plotly from "plotly.js-dist-min"
import React, { FC, useEffect } from "react"
import { Box } from "@mui/material"

export const UserDefinedPlot: FC<{
graphObject: PlotlyGraphObject
}> = ({ graphObject }) => {
const plotDomId = `user-defined-plot:${graphObject.id}`

useEffect(() => {
try {
const parsed = JSON.parse(graphObject.graph_object)
plotly.react(plotDomId, parsed.data, parsed.layout)
} catch (e) {
// Avoid to crash the whole page when given invalid grpah objects.
console.error(e)
}
}, [graphObject])

return <Box id={plotDomId} sx={{ height: "450px" }} />
}
6 changes: 6 additions & 0 deletions optuna_dashboard/ts/types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ type FormWidgets =
widgets: UserAttrFormWidget[]
}

type PlotlyGraphObject = {
id: string
graph_object: string
}

type StudyDetail = {
id: number
name: string
Expand All @@ -199,6 +204,7 @@ type StudyDetail = {
objective_names?: string[]
form_widgets?: FormWidgets
preference_history?: PreferenceHistory[]
plotly_graph_objects: PlotlyGraphObject[]
}

type StudyDetails = {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ docs = [

test = [
"coverage",
"plotly",
"pytest",
"moto[s3]",
]
Expand Down
86 changes: 86 additions & 0 deletions python_tests/test_custom_plot_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import optuna
from optuna_dashboard import _custom_plot_data as custom_plot_data
from optuna_dashboard import save_plotly_graph_object
import pytest


def get_dummy_study() -> optuna.Study:
def objective(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x**2 + y

study = optuna.create_study()
optuna.logging.set_verbosity(optuna.logging.ERROR)
study.optimize(objective, n_trials=100)
return study


def test_save_plotly_graph_object() -> None:
# Save history plot
dummy_study = get_dummy_study()
plot_data = optuna.visualization.plot_optimization_history(dummy_study)
graph_object_id = save_plotly_graph_object(dummy_study, plot_data)

study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id)
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs)
assert len(plot_data_dict) == 1
assert plot_data_dict[graph_object_id] == plot_data.to_json()

# Save parallel coordinate plot
plot_data = optuna.visualization.plot_parallel_coordinate(dummy_study)
graph_object_id = save_plotly_graph_object(dummy_study, plot_data)

study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id)
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs)
assert len(plot_data_dict) == 2
assert plot_data_dict[graph_object_id] == plot_data.to_json()


def test_update_plotly_graph_object() -> None:
# Save history plot
dummy_study = get_dummy_study()
plot_data = optuna.visualization.plot_optimization_history(dummy_study)
graph_object_id = save_plotly_graph_object(dummy_study, plot_data)

study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id)
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs)
assert len(plot_data_dict) == 1
assert plot_data_dict[graph_object_id] == plot_data.to_json()

# Save parallel coordinate plot
plot_data = optuna.visualization.plot_parallel_coordinate(dummy_study)
graph_object_id = save_plotly_graph_object(
dummy_study, plot_data, graph_object_id=graph_object_id
)

study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id)
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs)
assert len(plot_data_dict) == 1
assert plot_data_dict[graph_object_id] == plot_data.to_json()


@pytest.mark.parametrize(
"name",
[
"0",
"a",
"a1-:_.",
],
)
def test_is_valid_graph_object_id(name: str) -> None:
assert custom_plot_data.is_valid_graph_object_id(name)


@pytest.mark.parametrize(
"name",
[
"a,",
"a b",
"aあいうえお",
],
)
def test_is_invalid_graph_object_id(name: str) -> None:
assert not custom_plot_data.is_valid_graph_object_id(name)
4 changes: 2 additions & 2 deletions python_tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_get_study_detail_is_preferential() -> None:
assert len(study_summaries) == 1

study_summary = study_summaries[0]
study_detail = serialize_study_detail(study_summary, [], study.trials, [], [], [], False)
study_detail = serialize_study_detail(study_summary, [], study.trials, [], [], [], False, {})
assert study_detail["is_preferential"]


Expand All @@ -40,7 +40,7 @@ def test_get_study_detail_is_not_preferential() -> None:
assert len(study_summaries) == 1

study_summary = study_summaries[0]
study_detail = serialize_study_detail(study_summary, [], study.trials, [], [], [], False)
study_detail = serialize_study_detail(study_summary, [], study.trials, [], [], [], False, {})
assert not study_detail["is_preferential"]


Expand Down

0 comments on commit 94ef540

Please sign in to comment.