[dynamo exporter] Support string in dynamic_shapes #1631

Closed

Changes from 6 commits
30 changes: 29 additions & 1 deletion olive/common/config_utils.py
@@ -8,7 +8,7 @@
from functools import partial
from pathlib import Path
from types import FunctionType, MethodType
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import yaml

@@ -358,3 +358,31 @@ def convert_configs_to_dicts(config: Any) -> Any:
if isinstance(config, list):
return [convert_configs_to_dicts(v) for v in config]
return config


def get_the_flattened_and_tree_spec(
dynamic_shapes: Union[Dict[str, Any], List[Any]], leave_is_str: bool = False
) -> Tuple[List[Any], Any]:
"""Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree."""
# More info: https://github.com/pytorch/pytorch/blob/48203bec636692e1a9140fe7f23ba1323b19550d/torch/utils/_pytree.py#L985
from torch.utils import _pytree

def is_axes_with_str_key(x) -> bool:
# axes can be either a dict or a list/tuple
# dict: {str: str | int | None}
# list/tuple: [str | int | None]
return (
isinstance(x, dict)
and all(isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items())
) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x))

def is_axes_with_int_key(x) -> bool:
# axes can be either a dict or a list/tuple
# dict: {int: str | int | None}
# list/tuple: [str | int | None]
return (
isinstance(x, dict)
and all(isinstance(k, int) and (v is None or isinstance(v, (str, int))) for k, v in x.items())
) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x))

return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes_with_str_key if leave_is_str else is_axes_with_int_key)
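A brief usage sketch of the helper above (illustrative, not part of the diff; the example dynamic_shapes values are made up): with leave_is_str=True, each per-input axes dict with string keys is treated as a pytree leaf, and the returned TreeSpec can rebuild the original nesting after the leaves are post-processed.

from olive.common.config_utils import get_the_flattened_and_tree_spec

dynamic_shapes = {
    "input_ids": {"0": "batch_size", "1": "sequence_length"},
    "past_key_values": [[{"0": "batch_size"}, {"0": "batch_size"}]],
}

# With leave_is_str=True, each {"0": "batch_size"}-style dict is kept intact as a leaf.
leaves, tree_spec = get_the_flattened_and_tree_spec(dynamic_shapes, leave_is_str=True)
# leaves -> [{"0": "batch_size", "1": "sequence_length"}, {"0": "batch_size"}, {"0": "batch_size"}]

# The TreeSpec rebuilds the original nesting, e.g. after converting axis keys to int.
rebuilt = tree_spec.unflatten([{int(k): v for k, v in leaf.items()} for leaf in leaves])
# rebuilt -> {"input_ids": {0: ...}, "past_key_values": [[{0: ...}, {0: ...}]]}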
60 changes: 55 additions & 5 deletions olive/common/hf/model_io.py
@@ -3,8 +3,9 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
from itertools import chain
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from olive.common.hf.mlflow import get_pretrained_name_or_path
from olive.common.hf.peft import is_peft_model
@@ -124,10 +125,59 @@ def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", **
for axis, axis_name in value.items():
if axis_name == "past_sequence_length + 1":
value[axis] = "past_sequence_length + sequence_length"
# NOTE: Due to the complexity of dynamic_shapes, we don't provide it here.
# torch-onnx converter has a naive approach to auto-gen dynamic shapes based on input and
# dynamic_axes, so we don't need to provide dynamic shapes here.
return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes}
# dynamic_shapes should follow input order and format
dynamic_shapes = _unflatten_past_key_values_with_check(inputs)
return {
"input_names": input_names,
"output_names": output_names,
"dynamic_axes": dynamic_axes,
"dynamic_shapes": dynamic_shapes,
}


def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> Dict[str, Any]:
max_idx = -1
past_key_value_count = 0  # Track the number of past_key_values entries

# Find the max index for generating the unflattened past_key_values later
# and record the total number of past_key_values entries for validation
for input_name in flattened_inputs:
if input_name.startswith("past_key_values"):
# From Optimum: past_key_values.0.key, past_key_values.0.value,
# past_key_values.1.key, past_key_values.1.value, ...
idx = int(input_name.split(".")[1])
max_idx = max(max_idx, idx)
past_key_value_count += 1

# Check that there are exactly 2 * (max_idx + 1) past_key_values entries (one key and one value per layer)
expected_count = 2 * (max_idx + 1)
if past_key_value_count != expected_count or past_key_value_count % 2 != 0:
warnings.warn(
f"Expected {expected_count} past_key_values entries, but found {past_key_value_count} from Optimum inputs."
"Giving up generating dynamic_shapes from Optimum inputs."
"Olive will use dynamic_axes instead.",
UserWarning,
stacklevel=3,
)
return {}
# No past_key_values found
if max_idx == -1:
return flattened_inputs
# Keep all inputs except past_key_values
unflattened = {
input_name: dynamic_shapes
for input_name, dynamic_shapes in flattened_inputs.items()
if not input_name.startswith("past_key_values")
}
# Based on Optimum's implementation:
# https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629
# past_key_values is a list of lists, and it is located at the end of the input list/dict
# Generate the past_key_values list using the max index
unflattened["past_key_values"] = [
[flattened_inputs[f"past_key_values.{idx}.key"], flattened_inputs[f"past_key_values.{idx}.value"]]
for idx in range(max_idx + 1)
]
return unflattened
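A hedged before/after illustration of the helper above (the input names follow the Optimum convention described in the comments; the axis names and layer count are hypothetical):

# Hypothetical Optimum-style flattened inputs for a 2-layer decoder model.
flattened = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
    "past_key_values.0.key": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.0.value": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.1.key": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.1.value": {0: "batch_size", 2: "past_sequence_length"},
}
# _unflatten_past_key_values_with_check(flattened) would then return roughly:
# {
#     "input_ids": {0: "batch_size", 1: "sequence_length"},
#     "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
#     "past_key_values": [
#         [{0: "batch_size", 2: "past_sequence_length"}, {0: "batch_size", 2: "past_sequence_length"}],  # layer 0: [key, value]
#         [{0: "batch_size", 2: "past_sequence_length"}, {0: "batch_size", 2: "past_sequence_length"}],  # layer 1
#     ],
# }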


def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]:
35 changes: 25 additions & 10 deletions olive/model/config/io_config.py
@@ -5,7 +5,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Union

from olive.common.config_utils import ConfigBase
from olive.common.config_utils import ConfigBase, get_the_flattened_and_tree_spec
from olive.common.hf.wrapper import ModelWrapper
from olive.common.pydantic_v1 import validator
from olive.model.config.kv_cache_config import KVCacheConfig
@@ -23,10 +23,10 @@ class IoConfig(ConfigBase):
"images": { "0": "batch", "1": "height", "2": "width", "3": "channels" }
},
"dynamic_shapes": {
"clip_input": { "0": ["batch", 1, 512], "1": ["channels", 0, 3],
"2": ["height", 0, 512], "3": ["width", 0, 512] },
"images": { "0": ["batch", 1, 512], "1": ["height", 0, 512],
"2": ["width", 0, 512], "3": ["channels", 0, 3] }
"clip_input": { "0": "batch", "1": "channels",
"2": "height", "3": "width" },
"images": { "0": "batch", "1": "height",
"2": "width", "3": "channels" }
},
"kv_cache": None
}
@@ -40,11 +40,9 @@
output_shapes: List[List[int]] = None
output_types: List[str] = None
dynamic_axes: Dict[str, Dict[int, str]] = None
# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export
# NOTE: JSON does not support torch.export.Dim, so we use List[str, int, int] here.
# for example, {"input_ids": {0: torch.export.Dim("batch", min=2, max=1024)}}
# -> {"input_ids": {0: ["batch", 2, 1024]}}
# dynamic_shapes is different from dynamic_axes: it can be nested.
# Its string axis keys are post-processed to int under onnx/conversion.py,
# for example, {"input_ids": {"0": "batch"}}
dynamic_shapes: Union[List[Any], Dict[str, Any]] = None
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference
# even though we want the dimension to be a constant int.
@@ -88,6 +86,20 @@ def convert_dynamic_axes(cls, v):
dynamic_axes[k] = {int(kk): vv for kk, vv in value.items()}
return dynamic_axes

@validator("dynamic_shapes")
def convert_dynamic_shapes(cls, v):
if not v:
return v

flattened, tree_spec = get_the_flattened_and_tree_spec(v, leave_is_str=True)
new_flattened = []
for axes in flattened:
if isinstance(axes, dict):
new_flattened.append({int(kk): vv for kk, vv in axes.items()})
else:
new_flattened.append(axes)
return tree_spec.unflatten(new_flattened)
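A minimal sketch of what this validator does to a JSON-style config (the input and axis names below are illustrative): string axis keys in dict-form axes become ints, while list-form axes pass through unchanged.

raw = {
    "input_ids": {"0": "batch_size", "1": "sequence_length"},  # dict-form axes: keys converted to int
    "attention_mask": ["batch_size", "sequence_length"],  # list-form axes: passed through as-is
}
# After convert_dynamic_shapes, the field would hold roughly:
# {"input_ids": {0: "batch_size", 1: "sequence_length"},
#  "attention_mask": ["batch_size", "sequence_length"]}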

@validator("string_to_int_dim_params")
def check_string_to_int_dim_params(cls, v):
if not v:
@@ -172,6 +184,8 @@ def extend_io_config_with_kv_cache(io_config, kv_cache_config: KVCacheConfig):
output_names = kv_cache_config.get_output_names()
dynamic_axes = deepcopy(io_config.dynamic_axes or {})
dynamic_axes.update(kv_cache_config.get_dynamic_axes())
dynamic_shapes = deepcopy(io_config.dynamic_shapes or {})
dynamic_shapes.update(kv_cache_config.get_dynamic_shapes())
return IoConfig(
input_names=(io_config.input_names or []) + kv_names,
input_shapes=(io_config.input_shapes or []) + kv_shapes,
@@ -180,6 +194,7 @@ def extend_io_config_with_kv_cache(io_config, kv_cache_config: KVCacheConfig):
output_shapes=io_config.output_shapes, # ignore kv_cache output shapes
output_types=io_config.output_types, # ignore kv_cache output types
dynamic_axes=dynamic_axes,
dynamic_shapes=dynamic_shapes,
kv_cache=kv_cache_config,
)

8 changes: 8 additions & 0 deletions olive/model/config/kv_cache_config.py
@@ -106,3 +106,11 @@ def get_dynamic_axes(self):
for present_name in self.get_ort_present_kv_names():
dynamic_axis[present_name] = self.present_kv_dynamic_axis
return dynamic_axis

def get_dynamic_shapes(self):
dynamic_shapes = {}
past_kv_names = self.get_ort_past_kv_names()
dynamic_shapes["past_key_values"] = [
[self.past_kv_dynamic_axis, self.past_kv_dynamic_axis] for _ in range(0, len(past_kv_names), 2)
]
return dynamic_shapes
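A standalone sketch of the pairing logic above. The names and the axis mapping below are assumed example values (the real ones come from get_ort_past_kv_names and past_kv_dynamic_axis on the config); it only shows how consecutive key/value entries collapse into one [key_axes, value_axes] pair per layer.

past_kv_names = [
    "past_key_values.0.key", "past_key_values.0.value",
    "past_key_values.1.key", "past_key_values.1.value",
]
past_kv_dynamic_axis = {0: "batch_size", 2: "past_sequence_length"}  # assumed example value

dynamic_shapes = {
    "past_key_values": [
        [past_kv_dynamic_axis, past_kv_dynamic_axis] for _ in range(0, len(past_kv_names), 2)
    ]
}
print(len(dynamic_shapes["past_key_values"]))  # 2 -> one [key_axes, value_axes] pair per layer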
135 changes: 44 additions & 91 deletions olive/passes/onnx/conversion.py
@@ -2,18 +2,21 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import collections
import inspect
import logging
import multiprocessing
import tempfile
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Dict, Optional, Tuple, Type, Union

import onnx
import torch
from packaging import version

from olive.common.config_utils import validate_config
from olive.common.config_utils import get_the_flattened_and_tree_spec, validate_config
from olive.common.utils import find_submodules, resolve_torch_dtype, tensor_data_to_device, tensor_data_to_dtype
from olive.hardware import AcceleratorSpec
from olive.model import (
@@ -214,15 +217,22 @@ def _export_pytorch_model(
if config.use_dynamo_exporter:
# Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0
torch_version = version.parse(torch.__version__).release
if torch_version < version.parse("2.7.0").release:
warnings.warn(
"Dynamic shape support in torch.onnx.export(..., dynamo=True) requires "
"PyTorch version 2.7.0 or later. "
"Please upgrade to PyTorch 2.7.0 or newer if you need dynamic shapes.",
UserWarning,
stacklevel=3,
)
# The "legacy dynamo" is the torch.onnx_dynamo_export API
legacy_dynamo_supported_version = version.parse("2.2.0").release
# The new "dynamo" api is torch.onnx.export with dynamo=True
# TODO(#1478): Change 2.6.0 back to 2.5.0 when dynamic_shapes are supported in Olive
dynamo_supported_version = version.parse("2.6.0").release
dynamo_supported_version = version.parse("2.7.0").release
if torch_version < legacy_dynamo_supported_version:
raise ImportError(
f"torch.onnx.dynamo_export is not available for torch version {torch_version}. "
"Please upgrade your torch version to 2.5.0 or above."
f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. "
"Please upgrade your torch version to 2.7.0 or above."
)
from torch._dynamo import config as dynamo_config

@@ -246,8 +256,9 @@
# NOTE: Validation is usually done in io_config.py, but because
# dynamic_shapes is nested and cannot be validated repeatedly like the
# other fields, we validate it here.
io_config.dynamic_shapes = _validate_dynamic_shapes(io_config.dynamic_shapes, the_input)
io_config.dynamic_shapes = _convert_dynamic_shapes_to_torch_export_dims(io_config.dynamic_shapes)
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes(
io_config.dynamic_shapes, the_input, pytorch_model
)

# there might be multiple files created during export, so we need to track the dir
# if there are other processes writing to the same dir, we might end up deleting files created by
@@ -582,7 +593,7 @@ def _run_for_config(
return model_proto_to_olive_model(converted_model_proto, output_model_path, config)


def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):
def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, model):
"""Validate dynamic_shapes.

This function validates two things:
@@ -601,88 +612,30 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):

from torch.utils import _pytree

def is_dict_axes(x) -> bool:
return isinstance(x, dict) and all(
isinstance(key, str)
and len(key) == 1
and isinstance(value, list)
and len(value) == 3
and isinstance(value[0], str)
and isinstance(value[1], int)
and isinstance(value[2], int)
for key, value in x.items()
)

flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes)
new_dynamic_shapes = []
for axes in flat_dynamic_shapes:
if axes is None:
new_dynamic_shapes.append(axes)
continue
new_axes = {}
for axis, dynamic_shape in axes.items():
new_axes[int(axis)] = dynamic_shape
new_dynamic_shapes.append(new_axes)

_, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes)
return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)


def _convert_dynamic_shapes_to_torch_export_dims(
dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]]
) -> Dict[str, Dict[int, torch.export.Dim]]:
"""Convert dynamic_shapes to torch export dims.
flat_dynamic_shapes, _ = get_the_flattened_and_tree_spec(dynamic_shapes)

torch.onnx.export takes the exported program (fx graph) from
torch.export.export, which requires the dynamic_shapes to be in the format
of using torch.export.Dim(name, min=min, max=max). This function converts
the dynamic_shapes to the format that torch.export.export requires.
# dict: {axis: axis_name} -> {int(axis): axis_name}
# list/tuple: [axis_name] -> [axis_name]
new_dynamic_shapes = [
{int(k): v for k, v in axes.items()} if isinstance(axes, dict) else axes for axes in flat_dynamic_shapes
]

For a single axis:
# reconstruct the dynamic_shapes to the same tree structure as dummy_inputs
_, tree_structure = get_the_flattened_and_tree_spec(dummy_inputs, leave_is_str=False)
unflatten_dynamic_shapes = _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)

before: ["axis_name", min_value, max_value]
after: torch.export.Dim("axis_name", min=min_value, max=max_value)

# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export

:param dynamic_shapes: the dynamic_shapes to convert
:return: the converted dynamic_shapes
"""
if dynamic_shapes is None:
return None

# If the axes has the same name, they should be the same torch.export.Dim
torch_export_dim_farm: Dict[str, torch.export.Dim] = {}

# dynamic_shapes follows input format, which could be nested
def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]:
if isinstance(data, dict):
for key, value in data.items():
data[key] = _from_tuple_to_dim(value)
# TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format?
# JSON foramt does not accept tuple.
elif isinstance(data, (tuple, list)):
# We assume the tuple/list is in the format of (name, min, max)
# TODO(titaiwang): This format could potentially be used as model
# inputs (would string be used as model input?)
if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int):
if data[0] in torch_export_dim_farm:
if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]:
return torch_export_dim_farm[data[0]]
raise ValueError(
f"Found different boundary for the same axis name {data[0]}. "
f"Previous min: {torch_export_dim_farm[data[0]].min} and "
f"max: {torch_export_dim_farm[data[0]].max}. "
f"Current min: {data[1]} and max: {data[2]}."
)
dim = torch.export.Dim(data[0], min=data[1], max=data[2])
torch_export_dim_farm[data[0]] = dim
return dim
if isinstance(data, tuple):
return tuple(_from_tuple_to_dim(item) for item in data)
if isinstance(data, list):
return [_from_tuple_to_dim(item) for item in data]
return data

return _from_tuple_to_dim(dynamic_shapes)
# NOTE: dynamic_shapes needs to follow the model.forward signature order when it refers to kwargs.
if isinstance(unflatten_dynamic_shapes, dict):
param_order = list(inspect.signature(model.forward).parameters)
# Sort io_config.dynamic_shapes based on this order
unflatten_dynamic_shapes = collections.OrderedDict(
sorted(unflatten_dynamic_shapes.items(), key=lambda item: param_order.index(item[0]))
)
dummy_inputs = collections.OrderedDict(
sorted(dummy_inputs.items(), key=lambda item: param_order.index(item[0]))
)
# dummy_inputs is kwargs
return unflatten_dynamic_shapes, (), dummy_inputs
# If dynamic_shapes and dummy_inputs are both list/tuple, no reordering is needed.
# dummy_inputs is args
return unflatten_dynamic_shapes, dummy_inputs, {}
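The kwargs branch above depends on reordering dynamic_shapes to match the model.forward parameter order; below is a self-contained sketch of just that reordering step (the model class and axis names are made up for illustration):

import collections
import inspect


class TinyModel:  # stand-in: only the forward signature matters here
    def forward(self, input_ids, attention_mask):
        return input_ids, attention_mask


model = TinyModel()
dynamic_shapes = {
    # deliberately out of order relative to forward(input_ids, attention_mask)
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
    "input_ids": {0: "batch_size", 1: "sequence_length"},
}
param_order = list(inspect.signature(model.forward).parameters)  # ['input_ids', 'attention_mask']
ordered = collections.OrderedDict(
    sorted(dynamic_shapes.items(), key=lambda item: param_order.index(item[0]))
)
print(list(ordered))  # ['input_ids', 'attention_mask']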