[dynamo exporter] Support string in dynamic_shapes (#1639)
## Describe your changes

With pytorch/pytorch#146321, `torch.onnx.export(..., dynamo=True)` can now accept strings in
`dynamic_shapes`, which fits better with Olive's configuration-driven workflow.
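For context, a hedged sketch of how a string-valued `dynamic_shapes` entry can look in an Olive io_config after this change (the input names are illustrative, not taken from this PR; the structure mirrors the example in `io_config.py` below):

```python
# Illustrative io_config fragment (hypothetical input names). With this PR the
# axis values are plain strings instead of [name, min, max] triples.
io_config = {
    "input_names": ["input_ids", "attention_mask"],
    "output_names": ["logits"],
    "dynamic_shapes": {
        "input_ids": {"0": "batch_size", "1": "sequence_length"},
        "attention_mask": {"0": "batch_size", "1": "sequence_length"},
    },
}
```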

Major changes:
- Add support for strings in `dynamic_shapes`.
- Move `dynamic_shapes` pre-processing to io_config.py (like `dynamic_axes`).
- Remove the post-processing that turned `[str, int, int]` into
`torch.export.Dim(str, min=int, max=int)`; `torch.onnx.export(...,
dynamo=True)` can now take strings directly.
- Leverage Optimum to auto-generate `dynamic_shapes` when an Optimum model
is requested.
    - KV cache support

Pitfall:
- When `dynamic_shapes` targets kwargs, both `dynamic_shapes` and the kwargs need to follow the
order of the `model.forward` signature. onnx/conversion.py provides a naive
approach to sort them, but users should be aware of this (see the sketch below).
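A minimal sketch of the ordering constraint, using a hypothetical `forward(self, input_ids, attention_mask)` signature (not a model from this PR):

```python
import torch

# Both the keyword dummy inputs and dynamic_shapes should follow the parameter
# order of model.forward, i.e. input_ids before attention_mask.
dummy_kwargs = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.int64),
}
dynamic_shapes = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
}
```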

Minor changes:
- Move onnxscript (released) to the official requirements.txt.
- Strings in `dynamic_shapes` are supported since torch 2.7 (a call sketch follows below).
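As a rough usage sketch of the new behavior (assumes torch >= 2.7; the tiny model and file name are made up, the argument names follow the public `torch.onnx.export` API):

```python
import torch


class TinyModel(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2


# String axis names are passed straight to the dynamo-based exporter;
# no [str, int, int] -> torch.export.Dim post-processing is needed anymore.
torch.onnx.export(
    TinyModel(),
    (torch.ones(2, 8, dtype=torch.int64),),
    "tiny.onnx",
    dynamic_shapes={"input_ids": {0: "batch_size", 1: "sequence_length"}},
    dynamo=True,
)
```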

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [x] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
titaiwangms authored Feb 21, 2025
1 parent 51d2c8a commit ba7e187
Showing 12 changed files with 235 additions and 140 deletions.
30 changes: 29 additions & 1 deletion olive/common/config_utils.py
@@ -8,7 +8,7 @@
from functools import partial
from pathlib import Path
from types import FunctionType, MethodType
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import yaml

@@ -358,3 +358,31 @@ def convert_configs_to_dicts(config: Any) -> Any:
if isinstance(config, list):
return [convert_configs_to_dicts(v) for v in config]
return config


def get_the_flattened_and_tree_spec(
dynamic_shapes: Union[Dict[str, Any], List[Any]], leave_is_str: bool = False
) -> Tuple[List[Any], Any]:
"""Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree."""
# More info: https://github.com/pytorch/pytorch/blob/48203bec636692e1a9140fe7f23ba1323b19550d/torch/utils/_pytree.py#L985
from torch.utils import _pytree

def is_axes_with_str_key(x) -> bool:
# axes can be either a dict or a list/tuple
# dict: {str: str}
# list/tuple: [str]
return (
isinstance(x, dict)
and all(isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items())
) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x))

def is_axes_with_int_key(x) -> bool:
# axes can be either a dict or a list/tuple
# dict: {int: str}
# list/tuple: [str]
return (
isinstance(x, dict)
and all(isinstance(k, int) and (v is None or isinstance(v, (str, int))) for k, v in x.items())
) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x))

return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes_with_str_key if leave_is_str else is_axes_with_int_key)
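A brief usage sketch of this helper (illustrative values; this mirrors how the `dynamic_shapes` validator in `io_config.py` below uses it):

```python
from olive.common.config_utils import get_the_flattened_and_tree_spec

# Nested dynamic_shapes as it would arrive from a JSON config: string axis keys.
dynamic_shapes = {
    "input_ids": {"0": "batch_size"},
    "past_key_values": [[{"0": "batch_size", "2": "past_sequence_length"}] * 2],
}
leaves, spec = get_the_flattened_and_tree_spec(dynamic_shapes, leave_is_str=True)
# leaves holds the per-input axes dicts; spec can rebuild the same nesting.
converted = [{int(k): v for k, v in axes.items()} for axes in leaves]
rebuilt = spec.unflatten(converted)  # same structure, now with int axis keys
```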
59 changes: 54 additions & 5 deletions olive/common/hf/model_io.py
@@ -4,7 +4,7 @@
# --------------------------------------------------------------------------
import logging
from itertools import chain
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from olive.common.hf.mlflow import get_pretrained_name_or_path
from olive.common.hf.peft import is_peft_model
@@ -124,10 +124,59 @@ def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", **
for axis, axis_name in value.items():
if axis_name == "past_sequence_length + 1":
value[axis] = "past_sequence_length + sequence_length"
# NOTE: Due to the complexity of dynamic_shapes, we don't provide it here.
# torch-onnx converter has a naive approach to auto-gen dynamic shapes based on input and
# dynamic_axes, so we don't need to provide dynamic shapes here.
return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes}
# dynamic_shapes should follow input order and format
dynamic_shapes = _unflatten_past_key_values_with_check(inputs)
return {
"input_names": input_names,
"output_names": output_names,
"dynamic_axes": dynamic_axes,
"dynamic_shapes": dynamic_shapes,
}


def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> Dict[str, Any]:
max_idx = -1
past_key_value_count = 0 # Track number of key-value pairs

# Find the max index for generating unflatten past_key_values later
# and record the total number of past_key_values entries for validation
for input_name in flattened_inputs:
if input_name.startswith("past_key_values"):
# From Optimum: past_key_values.0.key, past_key_values.0.value,
# past_key_values.1.key, past_key_values.1.value, ...
idx = int(input_name.split(".")[1])
max_idx = max(max_idx, idx)
past_key_value_count += 1

# Check if we have exactly 2 * (max_idx + 1) key-value pairs
expected_count = 2 * (max_idx + 1)
if past_key_value_count != expected_count or past_key_value_count % 2 != 0:
logger.warning(
"Expected %d past_key_values entries, but found %d from Optimum inputs."
"Giving up generating dynamic_shapes from Optimum inputs."
"Olive will use dynamic_axes instead.",
expected_count,
past_key_value_count,
)
return {}
# No past_key_values found
if max_idx == -1:
return flattened_inputs
# Keep all inputs except past_key_values
unflattened = {
input_name: dynamic_shapes
for input_name, dynamic_shapes in flattened_inputs.items()
if not input_name.startswith("past_key_values")
}
# Based on Optimum's implementation:
# https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629
# past_key_values is a list of lists, and it locates at the end of the input list/dict
# Generate the past_key_values list using the max index
unflattened["past_key_values"] = [
[flattened_inputs[f"past_key_values.{idx}.key"], flattened_inputs[f"past_key_values.{idx}.value"]]
for idx in range(max_idx + 1)
]
return unflattened


def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]:
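To illustrate what `_unflatten_past_key_values_with_check` produces, a hedged before/after sketch with two layers (the axis dicts are illustrative; real entries come from Optimum):

```python
# Flattened Optimum-style inputs (what get_model_io_config starts from)...
flattened = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "past_key_values.0.key": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.0.value": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.1.key": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.1.value": {0: "batch_size", 2: "past_sequence_length"},
}
# ...becomes nested dynamic_shapes that mirror the forward() argument:
# {
#     "input_ids": {0: "batch_size", 1: "sequence_length"},
#     "past_key_values": [[key_axes_0, value_axes_0], [key_axes_1, value_axes_1]],
# }
```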
35 changes: 25 additions & 10 deletions olive/model/config/io_config.py
@@ -5,7 +5,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Union

from olive.common.config_utils import ConfigBase
from olive.common.config_utils import ConfigBase, get_the_flattened_and_tree_spec
from olive.common.hf.wrapper import ModelWrapper
from olive.common.pydantic_v1 import validator
from olive.model.config.kv_cache_config import KVCacheConfig
@@ -23,10 +23,10 @@ class IoConfig(ConfigBase):
"images": { "0": "batch", "1": "height", "2": "width", "3": "channels" }
},
"dynamic_shapes": {
"clip_input": { "0": ["batch", 1, 512], "1": ["channels", 0, 3],
"2": ["height", 0, 512], "3": ["width", 0, 512] },
"images": { "0": ["batch", 1, 512], "1": ["height", 0, 512],
"2": ["width", 0, 512], "3": ["channels", 0, 3] }
"clip_input": { "0": "batch", "1": "channels",
"2": "height", "3": "width" },
"images": { "0": "batch", "1": "height",
"2": "width", "3": "channels" }
},
"kv_cache": None
}
Expand All @@ -40,11 +40,9 @@ class IoConfig(ConfigBase):
output_shapes: List[List[int]] = None
output_types: List[str] = None
dynamic_axes: Dict[str, Dict[int, str]] = None
# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export
# NOTE: JSON does not support torch.export.Dim, so we use List[str, int, int] here.
# for example, {"input_ids": {0: torch.export.Dim("batch", min=2, max=1024)}}
# -> {"input_ids": {0: ["batch", 2, 1024]}}
# dynamic_shapes is different from dynamic_axes, it is nested.
# We need to post-process its keys to int under onnx/conversion.py
# for example, {"input_ids": {"0": "batch"}}
dynamic_shapes: Union[List[Any], Dict[str, Any]] = None
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference
# even though we want the dimension to be a constant int.
@@ -88,6 +86,20 @@ def convert_dynamic_axes(cls, v):
dynamic_axes[k] = {int(kk): vv for kk, vv in value.items()}
return dynamic_axes

@validator("dynamic_shapes")
def convert_dynamic_shapes(cls, v):
if not v:
return v

flattened, tree_spec = get_the_flattened_and_tree_spec(v, leave_is_str=True)
new_flattened = []
for axes in flattened:
if isinstance(axes, dict):
new_flattened.append({int(kk): vv for kk, vv in axes.items()})
else:
new_flattened.append(axes)
return tree_spec.unflatten(new_flattened)

@validator("string_to_int_dim_params")
def check_string_to_int_dim_params(cls, v):
if not v:
@@ -172,6 +184,8 @@ def extend_io_config_with_kv_cache(io_config, kv_cache_config: KVCacheConfig):
output_names = kv_cache_config.get_output_names()
dynamic_axes = deepcopy(io_config.dynamic_axes or {})
dynamic_axes.update(kv_cache_config.get_dynamic_axes())
dynamic_shapes = deepcopy(io_config.dynamic_shapes or {})
dynamic_shapes.update(kv_cache_config.get_dynamic_shapes())
return IoConfig(
input_names=(io_config.input_names or []) + kv_names,
input_shapes=(io_config.input_shapes or []) + kv_shapes,
@@ -180,6 +194,7 @@ def extend_io_config_with_kv_cache(io_config, kv_cache_config: KVCacheConfig):
output_shapes=io_config.output_shapes, # ignore kv_cache output shapes
output_types=io_config.output_types, # ignore kv_cache output types
dynamic_axes=dynamic_axes,
dynamic_shapes=dynamic_shapes,
kv_cache=kv_cache_config,
)

8 changes: 8 additions & 0 deletions olive/model/config/kv_cache_config.py
@@ -106,3 +106,11 @@ def get_dynamic_axes(self):
for present_name in self.get_ort_present_kv_names():
dynamic_axis[present_name] = self.present_kv_dynamic_axis
return dynamic_axis

def get_dynamic_shapes(self):
dynamic_shapes = {}
past_kv_names = self.get_ort_past_kv_names()
dynamic_shapes["past_key_values"] = [
[self.past_kv_dynamic_axis, self.past_kv_dynamic_axis] for _ in range(0, len(past_kv_names), 2)
]
return dynamic_shapes
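For reference, a minimal sketch of the output, assuming two layers (i.e. four past-kv input names) and `past_kv_dynamic_axis = {0: "batch_size", 2: "past_sequence_length"}`; the step of 2 in the loop above yields one `[key_axes, value_axes]` pair per layer:

```python
# Sketch of get_dynamic_shapes() output for a hypothetical 2-layer KV cache.
dynamic_shapes = {
    "past_key_values": [
        [{0: "batch_size", 2: "past_sequence_length"}, {0: "batch_size", 2: "past_sequence_length"}],
        [{0: "batch_size", 2: "past_sequence_length"}, {0: "batch_size", 2: "past_sequence_length"}],
    ]
}
```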
132 changes: 41 additions & 91 deletions olive/passes/onnx/conversion.py
@@ -2,18 +2,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import collections
import inspect
import logging
import multiprocessing
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Dict, Optional, Tuple, Type, Union

import onnx
import torch
from packaging import version

from olive.common.config_utils import validate_config
from olive.common.config_utils import get_the_flattened_and_tree_spec, validate_config
from olive.common.utils import find_submodules, resolve_torch_dtype, tensor_data_to_device, tensor_data_to_dtype
from olive.hardware import AcceleratorSpec
from olive.model import (
@@ -214,15 +216,20 @@ def _export_pytorch_model(
if config.use_dynamo_exporter:
# Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0
torch_version = version.parse(torch.__version__).release
if torch_version < version.parse("2.7.0").release and io_config.dynamic_shapes is not None:
logger.warning(
"Dynamic shape support in torch.onnx.export(..., dynamo=True) requires "
"PyTorch version 2.7.0 or later. "
"Please upgrade to PyTorch 2.7.0 or newer if you need dynamic shapes.",
)
# The "legacy dynamo" is the torch.onnx_dynamo_export API
legacy_dynamo_supported_version = version.parse("2.2.0").release
# The new "dynamo" api is torch.onnx.export with dynamo=True
# TODO(#1478): Change 2.6.0 back to 2.5.0 when dynamic_shapes are supported in Olive
dynamo_supported_version = version.parse("2.6.0").release
dynamo_supported_version = version.parse("2.7.0").release
if torch_version < legacy_dynamo_supported_version:
raise ImportError(
f"torch.onnx.dynamo_export is not available for torch version {torch_version}. "
"Please upgrade your torch version to 2.5.0 or above."
f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. "
"Please upgrade your torch version to 2.7.0 or above."
)
from torch._dynamo import config as dynamo_config

@@ -246,8 +253,9 @@
# NOTE: Usually validation is done in io_config.py, but because
# dynamic_shapes has nested complexity, and it can't be validated multiple
# times like others, we validate it here.
io_config.dynamic_shapes = _validate_dynamic_shapes(io_config.dynamic_shapes, the_input)
io_config.dynamic_shapes = _convert_dynamic_shapes_to_torch_export_dims(io_config.dynamic_shapes)
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes(
io_config.dynamic_shapes, the_input, pytorch_model
)

# there might be multiple files created during export, so we need to track the dir
# if there are other processes writing to the same dir, we might end up deleting files created by
@@ -582,7 +590,7 @@ def _run_for_config(
return model_proto_to_olive_model(converted_model_proto, output_model_path, config)


def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):
def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, model):
"""Validate dynamic_shapes.
This function validates two things:
@@ -601,88 +609,30 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):

from torch.utils import _pytree

def is_dict_axes(x) -> bool:
return isinstance(x, dict) and all(
isinstance(key, str)
and len(key) == 1
and isinstance(value, list)
and len(value) == 3
and isinstance(value[0], str)
and isinstance(value[1], int)
and isinstance(value[2], int)
for key, value in x.items()
)

flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes)
new_dynamic_shapes = []
for axes in flat_dynamic_shapes:
if axes is None:
new_dynamic_shapes.append(axes)
continue
new_axes = {}
for axis, dynamic_shape in axes.items():
new_axes[int(axis)] = dynamic_shape
new_dynamic_shapes.append(new_axes)

_, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes)
return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)


def _convert_dynamic_shapes_to_torch_export_dims(
dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]]
) -> Dict[str, Dict[int, torch.export.Dim]]:
"""Convert dynamic_shapes to torch export dims.
flat_dynamic_shapes, _ = get_the_flattened_and_tree_spec(dynamic_shapes)

torch.onnx.export takes the exported program (fx graph) from
torch.export.export, which requires the dynamic_shapes to be in the format
of using torch.export.Dim(name, min=min, max=max). This function converts
the dynamic_shapes to the format that torch.export.export requires.
# dict: {axis: axis_name} -> {int(axis): axis_name}
# list/tuple: [axis_name] -> [axis_name]
new_dynamic_shapes = [
{int(k): v for k, v in axes.items()} if isinstance(axes, dict) else axes for axes in flat_dynamic_shapes
]

For a single axis:
# reconstruct the dynamic_shapes to the same tree structure as dummy_inputs
_, tree_structure = get_the_flattened_and_tree_spec(dummy_inputs, leave_is_str=False)
unflatten_dynamic_shapes = _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)

before: ["axis_name", min_value, max_value]
after: torch.export.Dim("axis_name", min=min_value, max=max_value)
# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export
:param dynamic_shapes: the dynamic_shapes to convert
:return: the converted dynamic_shapes
"""
if dynamic_shapes is None:
return None

# If the axes has the same name, they should be the same torch.export.Dim
torch_export_dim_farm: Dict[str, torch.export.Dim] = {}

# dynamic_shapes follows input format, which could be nested
def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]:
if isinstance(data, dict):
for key, value in data.items():
data[key] = _from_tuple_to_dim(value)
# TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format?
# JSON foramt does not accept tuple.
elif isinstance(data, (tuple, list)):
# We assume the tuple/list is in the format of (name, min, max)
# TODO(titaiwang): This format could potentially be used as model
# inputs (would string be used as model input?)
if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int):
if data[0] in torch_export_dim_farm:
if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]:
return torch_export_dim_farm[data[0]]
raise ValueError(
f"Found different boundary for the same axis name {data[0]}. "
f"Previous min: {torch_export_dim_farm[data[0]].min} and "
f"max: {torch_export_dim_farm[data[0]].max}. "
f"Current min: {data[1]} and max: {data[2]}."
)
dim = torch.export.Dim(data[0], min=data[1], max=data[2])
torch_export_dim_farm[data[0]] = dim
return dim
if isinstance(data, tuple):
return tuple(_from_tuple_to_dim(item) for item in data)
if isinstance(data, list):
return [_from_tuple_to_dim(item) for item in data]
return data

return _from_tuple_to_dim(dynamic_shapes)
# NOTE: dynamic_shapes need to follow the same model.forward signature when it's referring to kwargs.
if isinstance(unflatten_dynamic_shapes, dict):
param_order = list(inspect.signature(model.forward).parameters)
# Sort io_config.dynamic_shapes based on this order
unflatten_dynamic_shapes = collections.OrderedDict(
sorted(unflatten_dynamic_shapes.items(), key=lambda item: param_order.index(item[0]))
)
dummy_inputs = collections.OrderedDict(
sorted(dummy_inputs.items(), key=lambda item: param_order.index(item[0]))
)
# dummy_inputs is kwargs
return unflatten_dynamic_shapes, (), dummy_inputs
# If dynamic_shapes and dummy_inputs are both list/tuple, we don't need to do anything.
# dummy_inputs is args
return unflatten_dynamic_shapes, dummy_inputs, {}
