From e6b9a36e8efd793eb5b1878c79c203f09977eab8 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 19 Feb 2025 20:50:45 +0000 Subject: [PATCH 1/7] draft --- olive/passes/onnx/conversion.py | 42 ++++++++----------- requirements.txt | 1 + test/requirements-test.txt | 2 - test/unit_test/model/test_hf_model.py | 6 +-- test/unit_test/model/test_pytorch_model.py | 6 +-- test/unit_test/passes/onnx/test_conversion.py | 30 ++++++------- 6 files changed, 40 insertions(+), 47 deletions(-) diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index b766539f4..93bae4e52 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -247,7 +247,7 @@ def _export_pytorch_model( # dynamic_shapes has nested complexity, and it can't be validated multiple # times like others, we validate it here. io_config.dynamic_shapes = _validate_dynamic_shapes(io_config.dynamic_shapes, the_input) - io_config.dynamic_shapes = _convert_dynamic_shapes_to_torch_export_dims(io_config.dynamic_shapes) + # io_config.dynamic_shapes = _convert_dynamic_shapes_to_torch_export_dims(io_config.dynamic_shapes) # there might be multiple files created during export, so we need to track the dir # if there are other processes writing to the same dir, we might end up deleting files created by @@ -601,30 +601,24 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs): from torch.utils import _pytree - def is_dict_axes(x) -> bool: - return isinstance(x, dict) and all( - isinstance(key, str) - and len(key) == 1 - and isinstance(value, list) - and len(value) == 3 - and isinstance(value[0], str) - and isinstance(value[1], int) - and isinstance(value[2], int) - for key, value in x.items() - ) + def is_axes(x) -> bool: + # axes can be either a dict or a list/tuple + # dict: {int: str} + # list/tuple: [str] + return ( + isinstance(x, dict) + and all(isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items()) + ) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x)) + + flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes) + + # dict: {axis: axis_name} -> {int(axis): axis_name} + # list/tuple: [axis_name] -> [axis_name] + new_dynamic_shapes = [ + {int(k): v for k, v in axes.items()} if isinstance(axes, dict) else axes for axes in flat_dynamic_shapes + ] - flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes) - new_dynamic_shapes = [] - for axes in flat_dynamic_shapes: - if axes is None: - new_dynamic_shapes.append(axes) - continue - new_axes = {} - for axis, dynamic_shape in axes.items(): - new_axes[int(axis)] = dynamic_shape - new_dynamic_shapes.append(new_axes) - - _, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes) + _, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_axes) return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure) diff --git a/requirements.txt b/requirements.txt index 786904514..54b97185f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ numpy onnx +onnxscript optuna pandas protobuf<4.0.0 diff --git a/test/requirements-test.txt b/test/requirements-test.txt index a0d24429f..e05fe8368 100644 --- a/test/requirements-test.txt +++ b/test/requirements-test.txt @@ -30,8 +30,6 @@ onnxconverter_common onnxmltools onnxoptimizer onnxruntime_extensions -# TODO(titaiwai): Add onnxscript to requirements.txt once it's released -onnxscript openvino==2023.2.0 optimum>=1.17.0 pandas diff --git 
a/test/unit_test/model/test_hf_model.py b/test/unit_test/model/test_hf_model.py index bd8ad17b6..05fa3cba1 100644 --- a/test/unit_test/model/test_hf_model.py +++ b/test/unit_test/model/test_hf_model.py @@ -145,9 +145,9 @@ def setup(self): "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, "dynamic_shapes": { - "input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, + "input_ids": {"0": "batch_size", "1": ["seq_length", 1, 256]}, + "attention_mask": {"0": "batch_size", "1": "seq_length"}, + "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, } diff --git a/test/unit_test/model/test_pytorch_model.py b/test/unit_test/model/test_pytorch_model.py index bee910886..26cf7acb6 100644 --- a/test/unit_test/model/test_pytorch_model.py +++ b/test/unit_test/model/test_pytorch_model.py @@ -23,9 +23,9 @@ def io_config_fixture(): "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, "dynamic_shapes": { - "input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, + "input_ids": {"0": "batch_size", "1": ["seq_length", 1, 256]}, + "attention_mask": {"0": "batch_size", "1": "seq_length"}, + "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, } diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py index 92a384415..e47e475f1 100644 --- a/test/unit_test/passes/onnx/test_conversion.py +++ b/test/unit_test/passes/onnx/test_conversion.py @@ -189,10 +189,10 @@ def mock_onnx_export_func(*args, **kwargs): @pytest.mark.parametrize( "dynamic_shapes", [ - [{"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]}], + [{"0": "axis_batch", "1": "x_axis"}, {"0": "axis_batch", "1": "y_axis"}], { - "input_x": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, - "input_y": {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]}, + "input_x": {"0": "axis_batch", "1": "x_axis"}, + "input_y": {"0": "axis_batch", "1": "y_axis"}, }, ], ) @@ -224,30 +224,30 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False): [ ( [ - {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, - [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}], - {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}}, + {"0": "axis_batch", "1": "x_axis"}, + [{"1": "x_axis"}, {"0": "axis_batch"}], + {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}}, None, ], ( - {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]}, - ({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}), - {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}}, + {0: "axis_batch", 1: "x_axis"}, + ({1: "x_axis"}, {0: "axis_batch"}), + {"a": {0: "axis_batch"}, "b": {1: "x_axis"}}, None, ), _get_simulate_torch_float_tensor_inputs(return_tuple=True), ), ( { - "w": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, - "x": [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}], - "y": {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}}, + "w": {"0": "axis_batch", "1": "x_axis"}, + "x": [{"1": "x_axis"}, {"0": "axis_batch"}], + "y": {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}}, "z": None, }, { - "w": {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]}, - "x": ({1: 
["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}), - "y": {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}}, + "w": {0: "axis_batch", 1: "x_axis"}, + "x": ({1: "x_axis"}, {0: "axis_batch"}), + "y": {"a": {0: "axis_batch"}, "b": {1: "x_axis"}}, "z": None, }, _get_simulate_torch_float_tensor_inputs(return_tuple=False), From 16fe4f0589091b450b414596fa9cd5cec0db0ef7 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 20 Feb 2025 00:50:21 +0000 Subject: [PATCH 2/7] fully support dynamic_shapes str mode --- olive/common/config_utils.py | 30 ++++- olive/common/hf/model_io.py | 60 ++++++++- olive/model/config/io_config.py | 35 ++++-- olive/model/config/kv_cache_config.py | 8 ++ olive/passes/onnx/conversion.py | 118 ++++++------------ olive/passes/pytorch/common.py | 16 ++- test/unit_test/common/test_hf.py | 10 +- test/unit_test/model/test_hf_model.py | 11 ++ test/unit_test/passes/onnx/test_conversion.py | 6 +- 9 files changed, 189 insertions(+), 105 deletions(-) diff --git a/olive/common/config_utils.py b/olive/common/config_utils.py index 21001c3c1..3695ca1a8 100644 --- a/olive/common/config_utils.py +++ b/olive/common/config_utils.py @@ -8,7 +8,7 @@ from functools import partial from pathlib import Path from types import FunctionType, MethodType -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import yaml @@ -358,3 +358,31 @@ def convert_configs_to_dicts(config: Any) -> Any: if isinstance(config, list): return [convert_configs_to_dicts(v) for v in config] return config + + +def get_the_flattened_and_tree_spec( + dynamic_shapes: Union[Dict[str, Any], List[Any]], leave_is_str: bool = False +) -> Tuple[List[Any], Any]: + """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree.""" + # More info: https://github.com/pytorch/pytorch/blob/48203bec636692e1a9140fe7f23ba1323b19550d/torch/utils/_pytree.py#L985 + from torch.utils import _pytree + + def is_axes_with_str_key(x) -> bool: + # axes can be either a dict or a list/tuple + # dict: {str: str} + # list/tuple: [str] + return ( + isinstance(x, dict) + and all(isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items()) + ) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x)) + + def is_axes_with_int_key(x) -> bool: + # axes can be either a dict or a list/tuple + # dict: {int: str} + # list/tuple: [str] + return ( + isinstance(x, dict) + and all(isinstance(k, int) and (v is None or isinstance(v, (str, int))) for k, v in x.items()) + ) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x)) + + return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes_with_str_key if leave_is_str else is_axes_with_int_key) diff --git a/olive/common/hf/model_io.py b/olive/common/hf/model_io.py index 5a786458a..39c87ffdb 100644 --- a/olive/common/hf/model_io.py +++ b/olive/common/hf/model_io.py @@ -3,8 +3,9 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- import logging +import warnings from itertools import chain -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from olive.common.hf.mlflow import get_pretrained_name_or_path from olive.common.hf.peft import is_peft_model @@ -124,10 +125,59 @@ def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", ** for axis, axis_name in value.items(): if axis_name == "past_sequence_length + 1": value[axis] = "past_sequence_length + sequence_length" - # NOTE: Due to the complexity of dynamic_shapes, we don't provide it here. - # torch-onnx converter has a naive approach to auto-gen dynamic shapes based on input and - # dynamic_axes, so we don't need to provide dynamic shapes here. - return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes} + # dynamic_shapes should follow input order and format + dynamic_shapes = _unflatten_past_key_values_with_check(inputs) + return { + "input_names": input_names, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "dynamic_shapes": dynamic_shapes, + } + + +def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> Dict[str, Any]: + max_idx = -1 + past_key_value_count = 0 # Track number of key-value pairs + + # Find the max index for generating unflatten past_key_values later + # and record the total number of past_key_values entries for validation + for input_name in flattened_inputs: + if input_name.startswith("past_key_values"): + # From Optimum: past_key_values.0.key, past_key_values.0.value, + # past_key_values.1.key, past_key_values.1.value, ... + idx = int(input_name.split(".")[1]) + max_idx = max(max_idx, idx) + past_key_value_count += 1 + + # Check if we have exactly 2 * (max_idx + 1) key-value pairs + expected_count = 2 * (max_idx + 1) + if past_key_value_count != expected_count or past_key_value_count % 2 != 0: + warnings.warn( + f"Expected {expected_count} past_key_values entries, but found {past_key_value_count} from Optimum inputs." + "Giving up generating dynamic_shapes from Optimum inputs." 
+ "Olive will use dynamic_axes instead.", + UserWarning, + stacklevel=3, + ) + return {} + # No past_key_values found + if max_idx == -1: + return flattened_inputs + # Keep all inputs except past_key_values + unflattened = { + input_name: dynamic_shapes + for input_name, dynamic_shapes in flattened_inputs.items() + if not input_name.startswith("past_key_values") + } + # Based on Optimum's implementation: + # https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629 + # past_key_values is a list of lists, and it locates at the end of the input list/dict + # Generate the past_key_values list using the max index + unflattened["past_key_values"] = [ + [flattened_inputs[f"past_key_values.{idx}.key"], flattened_inputs[f"past_key_values.{idx}.value"]] + for idx in range(max_idx + 1) + ] + return unflattened def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]: diff --git a/olive/model/config/io_config.py b/olive/model/config/io_config.py index 925609d4b..852115240 100644 --- a/olive/model/config/io_config.py +++ b/olive/model/config/io_config.py @@ -5,7 +5,7 @@ from copy import deepcopy from typing import Any, Dict, List, Union -from olive.common.config_utils import ConfigBase +from olive.common.config_utils import ConfigBase, get_the_flattened_and_tree_spec from olive.common.hf.wrapper import ModelWrapper from olive.common.pydantic_v1 import validator from olive.model.config.kv_cache_config import KVCacheConfig @@ -23,10 +23,10 @@ class IoConfig(ConfigBase): "images": { "0": "batch", "1": "height", "2": "width", "3": "channels" } }, "dynamic_shapes": { - "clip_input": { "0": ["batch", 1, 512], "1": ["channels", 0, 3], - "2": ["height", 0, 512], "3": ["width", 0, 512] }, - "images": { "0": ["batch", 1, 512], "1": ["height", 0, 512], - "2": ["width", 0, 512], "3": ["channels", 0, 3] } + "clip_input": { "0": "batch", "1": "channels", + "2": "height", "3": "width" }, + "images": { "0": "batch", "1": "height", + "2": "width", "3": "channels" } }, "kv_cache": None } @@ -40,11 +40,9 @@ class IoConfig(ConfigBase): output_shapes: List[List[int]] = None output_types: List[str] = None dynamic_axes: Dict[str, Dict[int, str]] = None - # Please check `dynamic_shapes` in torch.export.export - # https://pytorch.org/docs/stable/export.html#torch.export.export - # NOTE: JSON does not support torch.export.Dim, so we use List[str, int, int] here. - # for example, {"input_ids": {0: torch.export.Dim("batch", min=2, max=1024)}} - # -> {"input_ids": {0: ["batch", 2, 1024]}} + # dynamic_shapes is different from dynamic_axes, it is nested. + # We need to post-process its keys to int under onnx/conversion.py + # for example, {"input_ids": {"0": "batch"}} dynamic_shapes: Union[List[Any], Dict[str, Any]] = None # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference # even though we want the dimension to be a constant int. 
@@ -88,6 +86,20 @@ def convert_dynamic_axes(cls, v): dynamic_axes[k] = {int(kk): vv for kk, vv in value.items()} return dynamic_axes + @validator("dynamic_shapes") + def convert_dynamic_shapes(cls, v): + if not v: + return v + + flattened, tree_spec = get_the_flattened_and_tree_spec(v, leave_is_str=True) + new_flattened = [] + for axes in flattened: + if isinstance(axes, dict): + new_flattened.append({int(kk): vv for kk, vv in axes.items()}) + else: + new_flattened.append(axes) + return tree_spec.unflatten(new_flattened) + @validator("string_to_int_dim_params") def check_string_to_int_dim_params(cls, v): if not v: @@ -172,6 +184,8 @@ def extend_io_config_with_kv_cache(io_config, kv_cache_config: KVCacheConfig): output_names = kv_cache_config.get_output_names() dynamic_axes = deepcopy(io_config.dynamic_axes or {}) dynamic_axes.update(kv_cache_config.get_dynamic_axes()) + dynamic_shapes = deepcopy(io_config.dynamic_shapes or {}) + dynamic_shapes.update(kv_cache_config.get_dynamic_shapes()) return IoConfig( input_names=(io_config.input_names or []) + kv_names, input_shapes=(io_config.input_shapes or []) + kv_shapes, @@ -180,6 +194,7 @@ def extend_io_config_with_kv_cache(io_config, kv_cache_config: KVCacheConfig): output_shapes=io_config.output_shapes, # ignore kv_cache output shapes output_types=io_config.output_types, # ignore kv_cache output types dynamic_axes=dynamic_axes, + dynamic_shapes=dynamic_shapes, kv_cache=kv_cache_config, ) diff --git a/olive/model/config/kv_cache_config.py b/olive/model/config/kv_cache_config.py index 98eb2f693..87af5e892 100644 --- a/olive/model/config/kv_cache_config.py +++ b/olive/model/config/kv_cache_config.py @@ -106,3 +106,11 @@ def get_dynamic_axes(self): for present_name in self.get_ort_present_kv_names(): dynamic_axis[present_name] = self.present_kv_dynamic_axis return dynamic_axis + + def get_dynamic_shapes(self): + dynamic_shapes = {} + past_kv_names = self.get_ort_past_kv_names() + dynamic_shapes["past_key_values"] = [ + [self.past_kv_dynamic_axis, self.past_kv_dynamic_axis] for _ in range(0, len(past_kv_names), 2) + ] + return dynamic_shapes diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index 93bae4e52..5a8d03767 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -2,18 +2,21 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +import collections +import inspect import logging import multiprocessing import tempfile +import warnings from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Type, Union import onnx import torch from packaging import version -from olive.common.config_utils import validate_config +from olive.common.config_utils import get_the_flattened_and_tree_spec, validate_config from olive.common.utils import find_submodules, resolve_torch_dtype, tensor_data_to_device, tensor_data_to_dtype from olive.hardware import AcceleratorSpec from olive.model import ( @@ -214,15 +217,22 @@ def _export_pytorch_model( if config.use_dynamo_exporter: # Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0 torch_version = version.parse(torch.__version__).release + if torch_version < version.parse("2.7.0").release: + warnings.warn( + "Dynamic shape support in torch.onnx.export(..., dynamo=True) requires " + "PyTorch version 2.7.0 or later. " + "Please upgrade to PyTorch 2.7.0 or newer if you need dynamic shapes.", + UserWarning, + stacklevel=3, + ) # The "legacy dynamo" is the torch.onnx_dynamo_export API legacy_dynamo_supported_version = version.parse("2.2.0").release # The new "dynamo" api is torch.onnx.export with dynamo=True - # TODO(#1478): Change 2.6.0 back to 2.5.0 when dynamic_shapes are supported in Olive dynamo_supported_version = version.parse("2.6.0").release if torch_version < legacy_dynamo_supported_version: raise ImportError( - f"torch.onnx.dynamo_export is not available for torch version {torch_version}. " - "Please upgrade your torch version to 2.5.0 or above." + f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. " + "Please upgrade your torch version to 2.6.0 or above." ) from torch._dynamo import config as dynamo_config @@ -246,8 +256,9 @@ def _export_pytorch_model( # NOTE: Usually validation is done in io_config.py, but because # dynamic_shapes has nested complexity, and it can't be validated multiple # times like others, we validate it here. - io_config.dynamic_shapes = _validate_dynamic_shapes(io_config.dynamic_shapes, the_input) - # io_config.dynamic_shapes = _convert_dynamic_shapes_to_torch_export_dims(io_config.dynamic_shapes) + io_config.dynamic_shapes, dummy_kwargs = _validate_dynamic_shapes( + io_config.dynamic_shapes, the_input, pytorch_model + ) # there might be multiple files created during export, so we need to track the dir # if there are other processes writing to the same dir, we might end up deleting files created by @@ -269,7 +280,7 @@ def _export_pytorch_model( dynamo=True, fallback=True, optimize=config.optimize, - report=logger.isEnabledFor(logging.DEBUG), + report=True, ) assert onnx_program is not None onnx_model = onnx_program.model_proto @@ -582,7 +593,7 @@ def _run_for_config( return model_proto_to_olive_model(converted_model_proto, output_model_path, config) -def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs): +def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, model): """Validate dynamic_shapes. 
This function validates two things: @@ -601,16 +612,7 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs): from torch.utils import _pytree - def is_axes(x) -> bool: - # axes can be either a dict or a list/tuple - # dict: {int: str} - # list/tuple: [str] - return ( - isinstance(x, dict) - and all(isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items()) - ) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x)) - - flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes) + flat_dynamic_shapes, _ = get_the_flattened_and_tree_spec(dynamic_shapes) # dict: {axis: axis_name} -> {int(axis): axis_name} # list/tuple: [axis_name] -> [axis_name] @@ -618,65 +620,19 @@ def is_axes(x) -> bool: {int(k): v for k, v in axes.items()} if isinstance(axes, dict) else axes for axes in flat_dynamic_shapes ] - _, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_axes) - return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure) - - -def _convert_dynamic_shapes_to_torch_export_dims( - dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]] -) -> Dict[str, Dict[int, torch.export.Dim]]: - """Convert dynamic_shapes to torch export dims. - - torch.onnx.export takes the exported program (fx graph) from - torch.export.export, which requires the dynamic_shapes to be in the format - of using torch.export.Dim(name, min=min, max=max). This function converts - the dynamic_shapes to the format that torch.export.export requires. - - For a single axis: - - before: ["axis_name", min_value, max_value] - after: torch.export.Dim("axis_name", min=min_value, max=max_value) - - # Please check `dynamic_shapes` in torch.export.export - # https://pytorch.org/docs/stable/export.html#torch.export.export - - :param dynamic_shapes: the dynamic_shapes to convert - :return: the converted dynamic_shapes - """ - if dynamic_shapes is None: - return None - - # If the axes has the same name, they should be the same torch.export.Dim - torch_export_dim_farm: Dict[str, torch.export.Dim] = {} - - # dynamic_shapes follows input format, which could be nested - def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]: - if isinstance(data, dict): - for key, value in data.items(): - data[key] = _from_tuple_to_dim(value) - # TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format? - # JSON foramt does not accept tuple. - elif isinstance(data, (tuple, list)): - # We assume the tuple/list is in the format of (name, min, max) - # TODO(titaiwang): This format could potentially be used as model - # inputs (would string be used as model input?) - if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int): - if data[0] in torch_export_dim_farm: - if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]: - return torch_export_dim_farm[data[0]] - raise ValueError( - f"Found different boundary for the same axis name {data[0]}. " - f"Previous min: {torch_export_dim_farm[data[0]].min} and " - f"max: {torch_export_dim_farm[data[0]].max}. " - f"Current min: {data[1]} and max: {data[2]}." 
- ) - dim = torch.export.Dim(data[0], min=data[1], max=data[2]) - torch_export_dim_farm[data[0]] = dim - return dim - if isinstance(data, tuple): - return tuple(_from_tuple_to_dim(item) for item in data) - if isinstance(data, list): - return [_from_tuple_to_dim(item) for item in data] - return data - - return _from_tuple_to_dim(dynamic_shapes) + # reconstruct the dynamic_shapes to the same tree structure as dummy_inputs + _, tree_structure = get_the_flattened_and_tree_spec(dummy_inputs, leave_is_str=False) + unflatten_dynamic_shapes = _pytree.tree_unflatten(new_dynamic_shapes, tree_structure) + + # NOTE: dynamic_shapes need to follow the same model.forward signature when it's referring to kwargs. + # TODO(titaiwang): How to fix this ordering!? + if isinstance(unflatten_dynamic_shapes, dict): + param_order = list(inspect.signature(model.forward).parameters) + # Sort io_config.dynamic_shapes based on this order + unflatten_dynamic_shapes = collections.OrderedDict( + sorted(unflatten_dynamic_shapes.items(), key=lambda item: param_order.index(item[0])) + ) + dummy_inputs = collections.OrderedDict( + sorted(dummy_inputs.items(), key=lambda item: param_order.index(item[0])) + ) + return unflatten_dynamic_shapes, dummy_inputs diff --git a/olive/passes/pytorch/common.py b/olive/passes/pytorch/common.py index 246635c98..e9f860a2c 100644 --- a/olive/passes/pytorch/common.py +++ b/olive/passes/pytorch/common.py @@ -58,11 +58,7 @@ def inherit_pytorch_from_hf( for k, v in hf_io_config.get("dynamic_axes", {}).items() if not k.startswith(("present", "past_key_values")) } - else: - # TODO(titaiwang): fix this when we have a better way to handle dynamic_shapes - # If the dynamic_shapes is a list, we don't inherit it since - # we do not know the exact index of the past_key_values in the list - dynamic_shapes = {} + # kv cache will be handled by the kv_cache flag in io_config io_config = { "input_names": [i for i in hf_io_config.get("input_names", []) if not i.startswith("past_key_values")], @@ -84,6 +80,16 @@ def inherit_pytorch_from_hf( if io_config and not io_config.get("kv_cache") and model.task.endswith("-with-past"): io_config["kv_cache"] = True + # dynamic_shapes deals with kv_cache here. 
If kv_cache is False, + # we remove past_key_values from dynamic_shapes + if not io_config.get("kv_cache", False): + dynamic_shapes = { + k: v for k, v in hf_io_config.get("dynamic_shapes", {}).items() if not k.startswith("past_key_values") + } + else: + dynamic_shapes = hf_io_config.get("dynamic_shapes", {}) + io_config["dynamic_shapes"] = dynamic_shapes + return PyTorchModelHandler( model_path=model_path, model_file_format=model_file_format, diff --git a/test/unit_test/common/test_hf.py b/test/unit_test/common/test_hf.py index 64346ea80..0139d4d16 100644 --- a/test/unit_test/common/test_hf.py +++ b/test/unit_test/common/test_hf.py @@ -98,7 +98,7 @@ def test_get_model_io_config(with_past): model_name, task = get_model_name_task(with_past) model = load_model_from_task(task, model_name) io_config = get_model_io_config(model_name, task, model) - expected_keys = ["input_names", "output_names", "dynamic_axes"] + expected_keys = ["input_names", "output_names", "dynamic_axes", "dynamic_shapes"] assert set(io_config.keys()) == set(expected_keys) expected_input_names = ["input_ids", "attention_mask", "position_ids"] expected_output_names = ["logits"] @@ -109,3 +109,11 @@ def test_get_model_io_config(with_past): assert io_config["input_names"] == expected_input_names assert io_config["output_names"] == expected_output_names assert set(io_config["dynamic_axes"].keys()) == set(expected_input_names + expected_output_names) + # dynamic_shapes has nested past_key_values and only includes input names + if with_past: + assert ( + len(expected_input_names) + == len(io_config["dynamic_shapes"]) - 1 + len(io_config["dynamic_shapes"]["past_key_values"]) * 2 + ) + else: + assert len(expected_input_names) == len(io_config["dynamic_shapes"]) diff --git a/test/unit_test/model/test_hf_model.py b/test/unit_test/model/test_hf_model.py index 05fa3cba1..fd359f0e9 100644 --- a/test/unit_test/model/test_hf_model.py +++ b/test/unit_test/model/test_hf_model.py @@ -169,6 +169,17 @@ def test_dummy_input_with_kv_cache_dict(self): assert len(dummy_inputs) == 3 + 5 * 2 assert list(dummy_inputs["past_key_values.0.key"].shape) == [1, 4, 0, 8] + def test_dynamic_shapes_is_generated_when_kv_cache_is_true(self): + io_config = self.io_config + io_config["kv_cache"] = True + olive_model = HfModelHandler(model_path=self.model_name, task=self.task, io_config=io_config) + io_config = olive_model.io_config + assert "dynamic_shapes" in io_config + assert "past_key_values" in io_config["dynamic_shapes"] + assert len(io_config["dynamic_shapes"]["past_key_values"]) == 5 + assert len(io_config["dynamic_shapes"]["past_key_values"][0]) == 2 + assert io_config["dynamic_shapes"]["past_key_values"][0][0] == {0: "batch_size", 2: "past_sequence_length"} + def test_dict_io_config(self): olive_model = HfModelHandler(model_path=self.model_name, task=self.task, io_config=self.io_config) # get io config diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py index e47e475f1..0a201c182 100644 --- a/test/unit_test/passes/onnx/test_conversion.py +++ b/test/unit_test/passes/onnx/test_conversion.py @@ -21,10 +21,12 @@ from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion -@pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.") +# @pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.") @pytest.mark.parametrize( ("input_model", "use_dynamo_exporter"), - [(get_pytorch_model(), True), 
(get_hf_model(), True), (get_pytorch_model(), False), (get_hf_model(), False)], + [ + (get_hf_model(), True), + ], ) def test_onnx_conversion_pass_with_exporters(input_model, use_dynamo_exporter, tmp_path): # setup From 2b688c63bbae8bdf9e01dd016ff5d1076bca96a2 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 20 Feb 2025 19:18:40 +0000 Subject: [PATCH 3/7] fix tests --- olive/passes/onnx/conversion.py | 13 +++--- test/unit_test/passes/onnx/test_conversion.py | 44 +++++++++++++------ 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index 5a8d03767..2538840e6 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -228,11 +228,11 @@ def _export_pytorch_model( # The "legacy dynamo" is the torch.onnx_dynamo_export API legacy_dynamo_supported_version = version.parse("2.2.0").release # The new "dynamo" api is torch.onnx.export with dynamo=True - dynamo_supported_version = version.parse("2.6.0").release + dynamo_supported_version = version.parse("2.7.0").release if torch_version < legacy_dynamo_supported_version: raise ImportError( f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. " - "Please upgrade your torch version to 2.6.0 or above." + "Please upgrade your torch version to 2.7.0 or above." ) from torch._dynamo import config as dynamo_config @@ -256,7 +256,7 @@ def _export_pytorch_model( # NOTE: Usually validation is done in io_config.py, but because # dynamic_shapes has nested complexity, and it can't be validated multiple # times like others, we validate it here. - io_config.dynamic_shapes, dummy_kwargs = _validate_dynamic_shapes( + io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes( io_config.dynamic_shapes, the_input, pytorch_model ) @@ -625,7 +625,6 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, model): unflatten_dynamic_shapes = _pytree.tree_unflatten(new_dynamic_shapes, tree_structure) # NOTE: dynamic_shapes need to follow the same model.forward signature when it's referring to kwargs. - # TODO(titaiwang): How to fix this ordering!? if isinstance(unflatten_dynamic_shapes, dict): param_order = list(inspect.signature(model.forward).parameters) # Sort io_config.dynamic_shapes based on this order @@ -635,4 +634,8 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, model): dummy_inputs = collections.OrderedDict( sorted(dummy_inputs.items(), key=lambda item: param_order.index(item[0])) ) - return unflatten_dynamic_shapes, dummy_inputs + # dummy_inputs is kwargs + return unflatten_dynamic_shapes, (), dummy_inputs + # If dynami_shapes and dummy_inputs are both list/tuple, we don't need to do anything. 
+ # dummy_inputs is args + return unflatten_dynamic_shapes, dummy_inputs, {} diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py index 0a201c182..d3d21d8ec 100644 --- a/test/unit_test/passes/onnx/test_conversion.py +++ b/test/unit_test/passes/onnx/test_conversion.py @@ -21,11 +21,14 @@ from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion -# @pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.") +@pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.") @pytest.mark.parametrize( ("input_model", "use_dynamo_exporter"), [ (get_hf_model(), True), + (get_hf_model(), False), + (get_pytorch_model(), True), + (get_pytorch_model(), False), ], ) def test_onnx_conversion_pass_with_exporters(input_model, use_dynamo_exporter, tmp_path): @@ -191,10 +194,10 @@ def mock_onnx_export_func(*args, **kwargs): @pytest.mark.parametrize( "dynamic_shapes", [ - [{"0": "axis_batch", "1": "x_axis"}, {"0": "axis_batch", "1": "y_axis"}], + [{0: "axis_batch", 1: "x_axis"}, {0: "axis_batch", 1: "y_axis"}], { - "input_x": {"0": "axis_batch", "1": "x_axis"}, - "input_y": {"0": "axis_batch", "1": "y_axis"}, + "input_x": {0: "axis_batch", 1: "x_axis"}, + "input_y": {0: "axis_batch", 1: "y_axis"}, }, ], ) @@ -214,21 +217,32 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False): torch.ones(4), ) return { + "y": {"a": torch.zeros(5), "b": torch.ones(5)}, "w": torch.ones(5), "x": (torch.zeros(5), torch.ones(5)), - "y": {"a": torch.zeros(5), "b": torch.ones(5)}, "z": torch.ones(4), } +class SingnatureOnlyModel(torch.nn.Module): + def forward( + self, + w: torch.Tensor, + x: tuple[torch.Tensor, torch.Tensor], + y: dict[str, torch.Tensor], + z: torch.Tensor, + ): + pass + + @pytest.mark.parametrize( ("dynamic_shapes", "expected_dynamic_shapes", "inputs"), [ ( [ - {"0": "axis_batch", "1": "x_axis"}, - [{"1": "x_axis"}, {"0": "axis_batch"}], - {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}}, + {0: "axis_batch", 1: "x_axis"}, + [{1: "x_axis"}, {0: "axis_batch"}], + {"a": {0: "axis_batch"}, "b": {1: "x_axis"}}, None, ], ( @@ -240,10 +254,12 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False): _get_simulate_torch_float_tensor_inputs(return_tuple=True), ), ( + # We mess up the order of inputs and dynamic shapes from the model signature + # to test that the validation can order it back. 
{ - "w": {"0": "axis_batch", "1": "x_axis"}, - "x": [{"1": "x_axis"}, {"0": "axis_batch"}], - "y": {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}}, + "y": {"a": {0: "axis_batch"}, "b": {1: "x_axis"}}, + "w": {0: "axis_batch", 1: "x_axis"}, + "x": [{1: "x_axis"}, {0: "axis_batch"}], "z": None, }, { @@ -257,8 +273,10 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False): ], ids=["in_nested_tuple_inputs", "in_nested_dict_format"], ) -def test___validate_dynamic_shapes_follow_input_format(dynamic_shapes, expected_dynamic_shapes, inputs): +def test___validate_dynamic_shapes_follow_input_format_and_follow_order_of_model_sig( + dynamic_shapes, expected_dynamic_shapes, inputs +): from olive.passes.onnx.conversion import _validate_dynamic_shapes - converted_dynamic_shapes = _validate_dynamic_shapes(dynamic_shapes, inputs) + converted_dynamic_shapes, _, _ = _validate_dynamic_shapes(dynamic_shapes, inputs, SingnatureOnlyModel()) assert converted_dynamic_shapes == expected_dynamic_shapes From 5a4e062aec69524d821eb37b1d327795d30b896b Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 20 Feb 2025 19:30:07 +0000 Subject: [PATCH 4/7] debug to report --- olive/passes/onnx/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index 2538840e6..f4e0ec71a 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -280,7 +280,7 @@ def _export_pytorch_model( dynamo=True, fallback=True, optimize=config.optimize, - report=True, + report=logger.isEnabledFor(logging.DEBUG), ) assert onnx_program is not None onnx_model = onnx_program.model_proto From 97649c8a78874ff61210db626b150ec9f912aa30 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 20 Feb 2025 19:43:01 +0000 Subject: [PATCH 5/7] fix typing --- test/unit_test/passes/onnx/test_conversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py index d3d21d8ec..8270ba212 100644 --- a/test/unit_test/passes/onnx/test_conversion.py +++ b/test/unit_test/passes/onnx/test_conversion.py @@ -8,6 +8,7 @@ from itertools import chain from pathlib import Path from test.unit_test.utils import ONNX_MODEL_PATH, get_hf_model, get_onnx_model, get_pytorch_model, pytorch_model_loader +from types import Dict, Tuple from unittest.mock import patch import pytest @@ -228,8 +229,8 @@ class SingnatureOnlyModel(torch.nn.Module): def forward( self, w: torch.Tensor, - x: tuple[torch.Tensor, torch.Tensor], - y: dict[str, torch.Tensor], + x: Tuple[torch.Tensor, torch.Tensor], + y: Dict[str, torch.Tensor], z: torch.Tensor, ): pass From a8230a4ad9fd1cd6bad586f4ec5cc87eb3f3dec8 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 20 Feb 2025 22:13:21 +0000 Subject: [PATCH 6/7] fix ci - typing import --- test/unit_test/passes/onnx/test_conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py index 8270ba212..84ccc3446 100644 --- a/test/unit_test/passes/onnx/test_conversion.py +++ b/test/unit_test/passes/onnx/test_conversion.py @@ -8,7 +8,7 @@ from itertools import chain from pathlib import Path from test.unit_test.utils import ONNX_MODEL_PATH, get_hf_model, get_onnx_model, get_pytorch_model, pytorch_model_loader -from types import Dict, Tuple +from typing import Dict, Tuple from unittest.mock import patch import pytest From 
d58a008e19ff1a168cd2eb5df75b2d110c4f8d46 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 21 Feb 2025 18:39:11 +0000 Subject: [PATCH 7/7] address reviews --- olive/common/hf/model_io.py | 9 ++++----- olive/passes/onnx/conversion.py | 9 +++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/olive/common/hf/model_io.py b/olive/common/hf/model_io.py index 39c87ffdb..b9fdcd990 100644 --- a/olive/common/hf/model_io.py +++ b/olive/common/hf/model_io.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -import warnings from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Optional @@ -152,12 +151,12 @@ def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> D # Check if we have exactly 2 * (max_idx + 1) key-value pairs expected_count = 2 * (max_idx + 1) if past_key_value_count != expected_count or past_key_value_count % 2 != 0: - warnings.warn( - f"Expected {expected_count} past_key_values entries, but found {past_key_value_count} from Optimum inputs." + logger.warning( + "Expected %d past_key_values entries, but found %d from Optimum inputs." "Giving up generating dynamic_shapes from Optimum inputs." "Olive will use dynamic_axes instead.", - UserWarning, - stacklevel=3, + expected_count, + past_key_value_count, ) return {} # No past_key_values found diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index f4e0ec71a..98d12d3a3 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -7,7 +7,6 @@ import logging import multiprocessing import tempfile -import warnings from copy import deepcopy from pathlib import Path from typing import Dict, Optional, Tuple, Type, Union @@ -217,13 +216,11 @@ def _export_pytorch_model( if config.use_dynamo_exporter: # Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0 torch_version = version.parse(torch.__version__).release - if torch_version < version.parse("2.7.0").release: - warnings.warn( + if torch_version < version.parse("2.7.0").release and io_config.dynamic_shapes is not None: + logger.warning( "Dynamic shape support in torch.onnx.export(..., dynamo=True) requires " "PyTorch version 2.7.0 or later. " "Please upgrade to PyTorch 2.7.0 or newer if you need dynamic shapes.", - UserWarning, - stacklevel=3, ) # The "legacy dynamo" is the torch.onnx_dynamo_export API legacy_dynamo_supported_version = version.parse("2.2.0").release @@ -636,6 +633,6 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, model): ) # dummy_inputs is kwargs return unflatten_dynamic_shapes, (), dummy_inputs - # If dynami_shapes and dummy_inputs are both list/tuple, we don't need to do anything. + # If dynamic_shapes and dummy_inputs are both list/tuple, we don't need to do anything. # dummy_inputs is args return unflatten_dynamic_shapes, dummy_inputs, {}
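
Since the series leans heavily on torch's pytree utilities, a runnable sketch of the round-trip performed by `get_the_flattened_and_tree_spec` and `_validate_dynamic_shapes` may help reviewers: the `dynamic_shapes` config is flattened with a custom `is_leaf` so each per-input axes spec stays intact as a single leaf, string axis keys are cast to int, and the tree structure of the dummy inputs is reused to unflatten. This is a minimal sketch; the input names and shapes are invented for illustration and are not from the patches.

```python
import torch
from torch.utils import _pytree


def is_axes(x) -> bool:
    # An axes spec is either {axis: name} or [name, ...]; values may be None.
    # This mirrors the is_axes_with_str_key predicate added in patch 2/7.
    return (
        isinstance(x, dict)
        and all(isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items())
    ) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x))


dynamic_shapes = {"input_x": {"0": "batch", "1": "seq"}, "input_y": None}
dummy_inputs = {"input_x": torch.zeros(2, 8), "input_y": torch.zeros(2)}

# Flatten with the custom leaf check so {"0": "batch", "1": "seq"} is one leaf,
# then normalize the JSON string keys ("0") to the int keys torch.export wants.
flat, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes)
flat = [{int(k): v for k, v in a.items()} if isinstance(a, dict) else a for a in flat]

# Reuse the tree structure of the dummy inputs so shapes line up with inputs 1:1.
_, spec = _pytree.tree_flatten(dummy_inputs, is_leaf=is_axes)
print(_pytree.tree_unflatten(flat, spec))
# {'input_x': {0: 'batch', 1: 'seq'}, 'input_y': None}
```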
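
The past_key_values regrouping in patch 2/7 is easy to get lost in, so here is a stripped-down sketch of the core move: the flat Optimum-style names (`past_key_values.<i>.key` / `past_key_values.<i>.value`) become the nested `[[key_axes, value_axes], ...]` list, while every other input passes through. Unlike the real `_unflatten_past_key_values_with_check`, this sketch skips the entry-count validation and the fallback warning path.

```python
def unflatten_past_key_values(flat: dict) -> dict:
    # Collect the layer indices present in the flat Optimum input names.
    idxs = sorted({int(name.split(".")[1]) for name in flat if name.startswith("past_key_values")})
    # Keep all non-kv inputs as-is.
    out = {k: v for k, v in flat.items() if not k.startswith("past_key_values")}
    if idxs:
        # past_key_values goes last, matching Optimum's input ordering.
        out["past_key_values"] = [
            [flat[f"past_key_values.{i}.key"], flat[f"past_key_values.{i}.value"]] for i in idxs
        ]
    return out


axes = {0: "batch_size", 2: "past_sequence_length"}
flat = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "past_key_values.0.key": axes,
    "past_key_values.0.value": axes,
    "past_key_values.1.key": axes,
    "past_key_values.1.value": axes,
}
print(unflatten_past_key_values(flat))
# {'input_ids': {...}, 'past_key_values': [[axes, axes], [axes, axes]]}
```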
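
Patch 3/7's ordering fix can also be demonstrated standalone: when `dynamic_shapes` arrives as a dict of kwargs, torch.export matches it against `model.forward`'s parameter order, so `_validate_dynamic_shapes` re-sorts both the shapes and the dummy kwargs via `inspect.signature`. The two-input `Model` below is a hypothetical stand-in, not code from the patches.

```python
import collections
import inspect

import torch


class Model(torch.nn.Module):
    def forward(self, input_x: torch.Tensor, input_y: torch.Tensor) -> torch.Tensor:
        return input_x + input_y


# Keys deliberately out of order relative to the forward() signature.
dynamic_shapes = {"input_y": {0: "batch"}, "input_x": {0: "batch"}}

# inspect.signature on the bound method excludes self, so this is exactly
# the kwargs order torch.export expects.
param_order = list(inspect.signature(Model().forward).parameters)
ordered = collections.OrderedDict(
    sorted(dynamic_shapes.items(), key=lambda item: param_order.index(item[0]))
)
print(list(ordered))  # ['input_x', 'input_y']
```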
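
Finally, the end state of the series is that Olive hands string-named dynamic shapes straight to the dynamo exporter, with no `torch.export.Dim(...)` construction on Olive's side (the removed `_convert_dynamic_shapes_to_torch_export_dims`). A minimal sketch of that call is below; `TinyModel` and the axis names are invented for illustration, and per the version check added in patch 2/7 the string form assumes torch >= 2.7.

```python
import torch


class TinyModel(torch.nn.Module):
    # Hypothetical toy model; not code from the patches.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


# String axis names (the "str mode" this series enables) are resolved to
# export dims inside torch, so they can be passed through as plain strings.
onnx_program = torch.onnx.export(
    TinyModel(),
    (torch.zeros(2, 8),),
    dynamic_shapes=({0: "batch", 1: "seq"},),
    dynamo=True,
)
assert onnx_program is not None
```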