Skip to content

Commit

Permalink
feat: Add support for custom (or offline) Mermaid.ink server and supp…
Browse files Browse the repository at this point in the history
…ort all parameters (#8799)

* compress graph data to support pako endpoint

* support mermaid.ink parameters and custom servers

* don't try to resolve conflicts with the github web ui...

* avoid double graph copy

* fixing typing, improving docstrings and release notes

* reverting type

* nit - force type checker no cache

* nit - force type checker no cache

---------

Co-authored-by: Ulises M <[email protected]>
Co-authored-by: Ulises M <[email protected]>
  • Loading branch information
3 people authored Feb 3, 2025
1 parent 503d275 commit f165212
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 27 deletions.
63 changes: 54 additions & 9 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

DEFAULT_MARSHALLER = YamlMarshaller()

# We use a generic type to annotate the return value of classmethods,
# We use a generic type to annotate the return value of class methods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")
Expand Down Expand Up @@ -619,31 +619,76 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di
}
return outputs

def show(self) -> None:
def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
"""
If running in a Jupyter notebook, display an image representing this `Pipeline`.
Display an image representing this `Pipeline` in a Jupyter notebook.
This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
the notebook.
:param server_url:
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
info on how to set up your own Mermaid server.
:param params:
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
Supported keys:
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
- width: Width of the output image (integer).
- height: Height of the output image (integer).
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
:raises PipelineDrawingError:
If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
"""
if is_in_jupyter():
from IPython.display import Image, display # type: ignore

image_data = _to_mermaid_image(self.graph)

image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
display(Image(image_data))
else:
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
raise PipelineDrawingError(msg)

def draw(self, path: Path) -> None:
def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
"""
Save an image representing this `Pipeline` to `path`.
Save an image representing this `Pipeline` to the specified file path.
This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.
:param path:
The path to save the image to.
The file path where the generated image will be saved.
:param server_url:
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
info on how to set up your own Mermaid server.
:param params:
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
Supported keys:
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
- width: Width of the output image (integer).
- height: Height of the output image (integer).
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
:raises PipelineDrawingError:
If there is an issue with rendering or saving the image.
"""
# The graph is edited slightly before drawing. To avoid modifying the original,
# which is used for running the pipeline, a copy is made first.
image_data = _to_mermaid_image(self.graph)
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
Path(path).write_bytes(image_data)

def walk(self) -> Iterator[Tuple[str, Component]]:
Expand Down
133 changes: 119 additions & 14 deletions haystack/core/pipeline/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import base64
import json
import zlib
from typing import Any, Dict, Optional

import networkx # type:ignore
import requests
Expand Down Expand Up @@ -54,7 +55,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
ARROWHEAD_MANDATORY = "-->"
ARROWHEAD_OPTIONAL = ".->"
MERMAID_STYLED_TEMPLATE = """
%%{{ init: {{'theme': 'neutral' }} }}%%
%%{{ init: {params} }}%%
graph TD;
Expand All @@ -64,27 +65,133 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
"""


def _to_mermaid_image(graph: networkx.MultiDiGraph):
def _validate_mermaid_params(params: Dict[str, Any]) -> None:
"""
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
Validates and sets default values for Mermaid parameters.
:param params:
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
Supported keys:
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
- width: Width of the output image (integer).
- height: Height of the output image (integer).
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
:raises ValueError:
If any parameter is invalid or does not match the expected format.
"""
valid_img_types = {"jpeg", "png", "webp"}
valid_themes = {"default", "neutral", "dark", "forest"}
valid_formats = {"img", "svg", "pdf"}

params.setdefault("format", "img")
params.setdefault("type", "png")
params.setdefault("theme", "neutral")

if params["format"] not in valid_formats:
raise ValueError(f"Invalid image format: {params['format']}. Valid options are: {valid_formats}.")

if params["format"] == "img" and params["type"] not in valid_img_types:
raise ValueError(f"Invalid image type: {params['type']}. Valid options are: {valid_img_types}.")

if params["theme"] not in valid_themes:
raise ValueError(f"Invalid theme: {params['theme']}. Valid options are: {valid_themes}.")

if "width" in params and not isinstance(params["width"], int):
raise ValueError("Width must be an integer.")
if "height" in params and not isinstance(params["height"], int):
raise ValueError("Height must be an integer.")

if "scale" in params and not 1 <= params["scale"] <= 3:
raise ValueError("Scale must be a number between 1 and 3.")
if "scale" in params and not ("width" in params or "height" in params):
raise ValueError("Scale is only allowed when width or height is set.")

if "bgColor" in params and not isinstance(params["bgColor"], str):
raise ValueError("Background color must be a string.")

# PDF specific parameters
if params["format"] == "pdf":
if "fit" in params and not isinstance(params["fit"], bool):
raise ValueError("Fit must be a boolean.")
if "paper" in params and not isinstance(params["paper"], str):
raise ValueError("Paper size must be a string (e.g., 'a4', 'a3').")
if "landscape" in params and not isinstance(params["landscape"], bool):
raise ValueError("Landscape must be a boolean.")
if "fit" in params and ("paper" in params or "landscape" in params):
logger.warning("`fit` overrides `paper` and `landscape` for PDFs. Ignoring `paper` and `landscape`.")


def _to_mermaid_image(
graph: networkx.MultiDiGraph, server_url: str = "https://mermaid.ink", params: Optional[dict] = None
) -> bytes:
"""
Renders a pipeline using a Mermaid server.
:param graph:
The graph to render as a Mermaid pipeline.
:param server_url:
Base URL of the Mermaid server (default: 'https://mermaid.ink').
:param params:
Dictionary of customization parameters. See `_validate_mermaid_params` for valid keys.
:returns:
The image, SVG, or PDF data returned by the Mermaid server as bytes.
:raises ValueError:
If any parameter is invalid or does not match the expected format.
:raises PipelineDrawingError:
If there is an issue connecting to the Mermaid server or the server returns an error.
"""

if params is None:
params = {}

_validate_mermaid_params(params)

theme = params.get("theme")
init_params = json.dumps({"theme": theme})

# Copy the graph to avoid modifying the original
graph_styled = _to_mermaid_text(graph.copy())
graph_styled = _to_mermaid_text(graph.copy(), init_params)
json_string = json.dumps({"code": graph_styled})

# Uses the DEFLATE algorithm at the highest level for smallest size
compressor = zlib.compressobj(level=9)
# Compress the JSON string with zlib (RFC 1950)
compressor = zlib.compressobj(level=9, wbits=15)
compressed_data = compressor.compress(json_string.encode("utf-8")) + compressor.flush()
compressed_url_safe_base64 = base64.urlsafe_b64encode(compressed_data).decode("utf-8").strip()

url = f"https://mermaid.ink/img/pako:{compressed_url_safe_base64}?type=png"
# Determine the correct endpoint
endpoint_format = params.get("format", "img") # Default to /img endpoint
if endpoint_format not in {"img", "svg", "pdf"}:
raise ValueError(f"Invalid format: {endpoint_format}. Valid options are 'img', 'svg', or 'pdf'.")

# Construct the URL without query parameters
url = f"{server_url}/{endpoint_format}/pako:{compressed_url_safe_base64}"

# Add query parameters adhering to mermaid.ink documentation
query_params = []
for key, value in params.items():
if key not in {"theme", "format"}: # Exclude theme (handled in init_params) and format (endpoint-specific)
if value is True:
query_params.append(f"{key}")
else:
query_params.append(f"{key}={value}")

if query_params:
url += "?" + "&".join(query_params)

logger.debug("Rendering graph at {url}", url=url)
try:
resp = requests.get(url, timeout=10)
if resp.status_code >= 400:
logger.warning(
"Failed to draw the pipeline: https://mermaid.ink/img/ returned status {status_code}",
"Failed to draw the pipeline: {server_url} returned status {status_code}",
server_url=server_url,
status_code=resp.status_code,
)
logger.info("Exact URL requested: {url}", url=url)
Expand All @@ -93,18 +200,16 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):

except Exception as exc: # pylint: disable=broad-except
logger.warning(
"Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ ({error})", error=exc
"Failed to draw the pipeline: could not connect to {server_url} ({error})", server_url=server_url, error=exc
)
logger.info("Exact URL requested: {url}", url=url)
logger.warning("No pipeline diagram will be saved.")
raise PipelineDrawingError(
"There was an issue with https://mermaid.ink/, see the stacktrace for details."
) from exc
raise PipelineDrawingError(f"There was an issue with {server_url}, see the stacktrace for details.") from exc

return resp.content


def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
"""
Converts a Networkx graph into Mermaid syntax.
Expand Down Expand Up @@ -153,7 +258,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
]
connections = "\n".join(connections_list + input_connections + output_connections)

graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections)
logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled)

return graph_styled
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---

features:
- |
Drawing pipelines, i.e., calls to draw() or show(), can now be done using a custom Mermaid server and additional parameters. This allows for more flexibility in how pipelines are rendered. See Mermaid.ink's [documentation](https://github.com/jihchi/mermaid.ink) for more information on how to set up a custom server.
Loading

0 comments on commit f165212

Please sign in to comment.