
Inconsistent Results for Output v1_0 After ONNX Runtime Optimization (Flaky Test) #23143

Open
Thrsu opened this issue Dec 18, 2024 · 3 comments
Labels
model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.


Thrsu commented Dec 18, 2024

Describe the issue

I encountered an issue where the results of the optimized ONNX model are inconsistent with those of the original, unoptimized model. Specifically, the output v1_0 differs, and the discrepancy occurs intermittently (a flaky test). It appears with several optimization levels (opt_level=0, 1, 2, 99), but not on every run.

The following error is reported when comparing the results from the optimized and original models:

AssertionError:
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 5 / 5 (100%)
Max absolute difference: 0.79532735
Max relative difference: 1.
 x: array([6.910885e-310, 6.910885e-310, 6.910885e-310, 6.910885e-310,
           6.910885e-310])
 y: array([0.795327, 0.75308 , 0.59723 , 0.711406, 0.667502])
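In this message, numpy labels the first argument of assert_allclose as x. So x is the output of the unoptimized session (ORT_DISABLE_ALL), and y is the output of the optimized one.

Every element of x is about 6.9e-310, far below the smallest normal float64 (about 2.2e-308). One plausible reading, not confirmed in this thread, is that x holds memory that was never written rather than a slightly wrong computation. A 64-bit value that looks like a typical x86-64 Linux user-space address (0x00007f...), reinterpreted as a double, lands in this same range.

The sketch below checks an output array for nonzero subnormal values. subnormal_fraction is a hypothetical helper, not part of ONNX Runtime or numpy.

```python
import numpy as np

# Smallest positive *normal* float64 (about 2.2e-308); nonzero values below it are subnormal.
TINY = np.finfo(np.float64).tiny


def subnormal_fraction(arr):
    """Hypothetical helper: fraction of elements that are nonzero subnormal float64 values."""
    a = np.asarray(arr, dtype=np.float64)
    if a.size == 0:
        return 0.0
    mask = (a != 0) & (np.abs(a) < TINY)
    return float(mask.mean())


# The "x" values from the assertion message above.
x = np.array([6.910885e-310] * 5)
print(subnormal_fraction(x))  # 1.0

# An address-like 64-bit pattern reinterpreted as a double falls in the same range.
addr_bits = np.array([0x00007F0000000000], dtype=np.uint64)
print(addr_bits.view(np.float64)[0])  # roughly 6.9e-310
```

A fraction close to 1.0 on one side of the comparison points away from numerical tolerance issues and toward a data-movement or buffer problem.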

Could anyone explain why the v1_0 output becomes inconsistent after optimization? In particular, I'd like to understand the cause of this intermittent discrepancy and whether any optimization settings could be adjusted to make the results consistent.

To reproduce

  1. Download the ONNX model.
  2. Run the following script:
import numpy as np
import onnxruntime as ort
from onnxruntime.transformers import optimizer

model_path = "inconsis3.onnx"
optimized_model_path = "./opt.onnx"

# Reference run: all graph optimizations disabled.
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
this_provider_list = ort.get_available_providers()

original_session = ort.InferenceSession(model_path, sess_options, providers=this_provider_list)
input_data = {"v0_0": np.random.rand(5).astype(np.float64)}
output_names = [output.name for output in original_session.get_outputs()]
original_result = original_session.run(output_names, input_data)

# Optimize with the transformers optimizer, then run with default session options.
optimized_model = optimizer.optimize_model(model_path, opt_level=1, use_gpu=True)
optimized_model.save_model_to_file(optimized_model_path)
optimized_session = ort.InferenceSession(optimized_model_path, providers=this_provider_list)
optimized_result = optimized_session.run(output_names, input_data)

# Compare every output of the two sessions.
for r1, r2 in zip(original_result, optimized_result):
    np.testing.assert_allclose(r1, r2, atol=1e-3, rtol=1e-3)
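Because the mismatch is intermittent, a single run says little. The sketch below quantifies it by repeating the comparison over many random inputs and counting the trials that fail. It assumes both sessions are wrapped as callables. count_mismatches is a hypothetical helper, and its default tolerances mirror the ones in the script above.

```python
import numpy as np


def count_mismatches(run_reference, run_candidate, make_feeds, trials=20,
                     rtol=1e-3, atol=1e-3, seed=0):
    """Hypothetical helper: count trials where two inference callables disagree.

    run_reference / run_candidate: callables mapping a feed dict to a list of arrays.
    make_feeds: callable taking a numpy Generator and returning a feed dict.
    """
    rng = np.random.default_rng(seed)
    failures = 0
    for _ in range(trials):
        feeds = make_feeds(rng)
        ref_outs = run_reference(feeds)
        cand_outs = run_candidate(feeds)
        # A trial passes only if the output counts match and every pair is close.
        ok = len(ref_outs) == len(cand_outs) and all(
            np.allclose(a, b, rtol=rtol, atol=atol) for a, b in zip(ref_outs, cand_outs)
        )
        if not ok:
            failures += 1
    return failures


# Usage with the sessions from the script above (assumed to be already created):
# failures = count_mismatches(
#     lambda f: original_session.run(None, f),
#     lambda f: optimized_session.run(None, f),
#     lambda rng: {"v0_0": rng.random(5)},
# )
```

Passing None as the output names to session.run returns all outputs. A nonzero but non-total failure count reproduces the flaky behaviour described above.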

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

5c1b7cc

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

github-actions bot added the model:transformer label on Dec 18, 2024.
xadupre (Member) commented Dec 19, 2024

It looks like a bug. I tried to simplify your model to keep only the failing part, but the simplified version works fine.

[Image not preserved]

And with yours:

[Image not preserved]

However, the optimizations do not seem to be involved: your model also fails for me with all optimizations disabled (model below). The optimizer inserts Memcpy nodes to copy data from/to the host, which should not really be needed here. Because some constants are very small, the optimizer may choose to process them on CPU, so execution switches between devices many times. That does not change the fact that it should work. I recommend using only the CPU for this model (providers=["CPUExecutionProvider"]); it should be faster anyway for such a small model.

[Image not preserved]

My code:

import os

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import onnxruntime as ort
from onnx.reference import ReferenceEvaluator

# Simplified model keeping only the failing pattern: two
# Cast -> Flatten -> Concat -> Reshape branches feeding a MatMul.
proto_simple = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Cast", ["v0_0"], ["x1"], to=onnx.TensorProto.FLOAT),
            oh.make_node("Cast", ["v0_0"], ["x2"], to=onnx.TensorProto.FLOAT),
            oh.make_node("Flatten", ["x1"], ["f1"], axis=0),
            oh.make_node("Flatten", ["x2"], ["f2"], axis=0),
            oh.make_node("Concat", ["f1", "i1"], ["c1"], axis=1),
            oh.make_node("Concat", ["f2", "i2"], ["c2"], axis=1),
            oh.make_node("Reshape", ["c1", "s1"], ["m1"]),
            oh.make_node("Reshape", ["c2", "s2"], ["m2"]),
            oh.make_node("MatMul", ["m1", "m2"], ["mm"]),
            oh.make_node("Identity", ["mm"], ["output"]),
        ],
        "nd",
        [oh.make_tensor_value_info("v0_0", onnx.TensorProto.DOUBLE, [5])],
        [oh.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [2, 3, 3, 3])],
        [
            onh.from_array(np.zeros((1, 49)).astype(np.float32), name="i1"),
            onh.from_array(np.zeros((1, 4)).astype(np.float32), name="i2"),
            onh.from_array(np.array([2, 3, 3, 3], dtype=np.int64), name="s1"),
            onh.from_array(np.array([3, 3], dtype=np.int64), name="s2"),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

input_data = {"v0_0": np.arange(5).astype(np.float64)}
proto_issue = onnx.load(
    os.path.join(os.path.dirname(__file__), "data", "inconsis3.onnx")
)
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
output_names = ["output"]

for i, proto in enumerate([proto_simple, proto_issue]):
    # Not optimized: all graph optimizations disabled.
    sessopts = ort.SessionOptions()
    sessopts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    sessopts.optimized_model_filepath = f"test_ort_optimization_disabled_{i}.onnx"
    original_session = ort.InferenceSession(
        proto.SerializeToString(), sessopts, providers=providers
    )
    original_result = original_session.run(output_names, input_data)

    # Optimized: all graph optimizations enabled.
    sessopts2 = ort.SessionOptions()
    sessopts2.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sessopts2.optimized_model_filepath = f"test_ort_optimization_enabled_{i}.onnx"
    optimized_session = ort.InferenceSession(
        proto.SerializeToString(), sessopts2, providers=providers
    )
    optimized_result = optimized_session.run(output_names, input_data)

    # Expected values from the ONNX reference implementation.
    ref = ReferenceEvaluator(proto, verbose=10)
    onnx_results = ref.run(output_names, input_data)
    # Fails here for the issue model, even with optimizations disabled.
    np.testing.assert_allclose(onnx_results[0], original_result[0])
    np.testing.assert_allclose(onnx_results[0], optimized_result[0])

yuslepukhin (Member) commented:

Reshape and Squeeze should not be wrapped into Memcpy nodes.

Thrsu (Author) commented Dec 20, 2024

Thank you for your analysis and explanation. I will try using only the CPU and look forward to the fix.
