Why is the output of the ONNX MatMul node never the same as what PyTorch gives? #23792

Open
JNR000 opened this issue Feb 23, 2025 · 2 comments
Labels
performance (issues related to performance regressions)

Comments

JNR000 commented Feb 23, 2025


Question

The output of the ONNX MatMul node is never the same as what PyTorch gives, whether during CPU or GPU inference. I've tried several different approaches to narrow down the situation.

Environment & Packages

Ubuntu           22.04.4
Python           3.12.3
numpy            2.1.3
onnx             1.17.0
onnxruntime-gpu  1.20.1
onnxscript       0.2.0
openpyxl         3.1.5
torch            2.5.1+cu124
CUDA             12.4

Further information

I wrote a simple linear layer in PyTorch and exported it as an ONNX model. After that, I could never get the same output from ONNX as what my PyTorch model gives.

import torch
import torch.nn as nn
import torch.onnx
import numpy as np
import onnx
import onnxruntime as ort


def comp(a, b):
    # True only if every element of a and b is exactly equal (bitwise float equality)
    return np.equal(np.array(a), np.array(b)).all()


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(5, 7, bias=False)

    def forward(self, x):
        return self.fc(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for dtype in [torch.float32, torch.float64]:
    model = SimpleModel().to(dtype).to(device)

    # torch output
    x = torch.randn(2, 3, 5).to(dtype).to(device)
    model.eval()
    with torch.no_grad():
        outtorch = model(x)

    # numpy matmul
    with torch.no_grad():
        for name, param in model.named_parameters():
            numpy_weight = param.data
    out_numpy = np.matmul(np.array(x.cpu()), np.array(numpy_weight.cpu()).T)

    # export onnx model
    onnx_model_path = "simple_model.onnx"
    torch.onnx.export(
        model,
        x,
        onnx_model_path,
        verbose=False,
        input_names=["input"],
        output_names=["output"],
        opset_version=13,
    )
    # load onnx model
    onnx_model = onnx.load(onnx_model_path)
    # check
    for initializer in onnx_model.graph.initializer:
        onnx_dtype = onnx.TensorProto.DataType.Name(initializer.data_type)
        print(f"Dtype: {onnx_dtype}")

    onnx_session_cpu = ort.InferenceSession(onnx_model_path)

    # prepare numpy input for ONNX Runtime
    input_data = np.array(x.cpu())
    print(input_data.dtype)

    input_name = onnx_session_cpu.get_inputs()[0].name
    output_name = onnx_session_cpu.get_outputs()[0].name
    # onnx cpu inference
    onnx_output_cpu = onnx_session_cpu.run([output_name], {input_name: input_data})

    # onnx gpu inference
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    options.intra_op_num_threads = 1

    cuda_provider_options = {
        "arena_extend_strategy": "kSameAsRequested",
    }

    ort_session_gpu = ort.InferenceSession(
        onnx_model_path,
        sess_options=options,
        providers=[("CUDAExecutionProvider", cuda_provider_options)],
    )

    onnx_output_gpu = ort_session_gpu.run([output_name], {input_name: input_data})

    print("onnx cpu vs onnx gpu: ", comp(onnx_output_cpu, onnx_output_gpu))
    print("onnx cpu vs torch (gpu): ", comp(onnx_output_cpu, outtorch.cpu()))
    print("onnx gpu vs torch (gpu): ", comp(outtorch.cpu(), onnx_output_gpu))
    print("numpy matmul vs torch (gpu)", comp(outtorch.cpu(), out_numpy))

The results are:

Dtype: FLOAT
float32
onnx cpu vs onnx gpu:  False
onnx cpu vs torch (gpu):  True
onnx gpu vs torch (gpu):  False
numpy matmul vs torch (gpu) True
Dtype: DOUBLE
float64
onnx cpu vs onnx gpu:  True
onnx cpu vs torch (gpu):  True
onnx gpu vs torch (gpu):  True
numpy matmul vs torch (gpu) True

What's more, the output of an nn.Conv layer is the same as the corresponding ONNX model's output. So far the problem only occurs with the Linear layer, which is quite strange, because Conv and Linear are not that different as operations.
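
For reference, a minimal sketch of the kind of Conv check described above (an assumption of mine, not the exact script used; "simple_conv.onnx" is a placeholder file name):

import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort

# Hypothetical Conv2d counterpart of the Linear test above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
conv = nn.Conv2d(3, 4, kernel_size=3, bias=False).to(device).eval()
x = torch.randn(1, 3, 8, 8, device=device)
with torch.no_grad():
    out_torch = conv(x)

torch.onnx.export(conv, x, "simple_conv.onnx",
                  input_names=["input"], output_names=["output"], opset_version=13)
# Use the CUDA provider when available, falling back to CPU otherwise.
sess = ort.InferenceSession(
    "simple_conv.onnx",
    providers=[("CUDAExecutionProvider", {}), "CPUExecutionProvider"],
)
out_onnx = sess.run(["output"], {"input": x.cpu().numpy()})[0]
print(np.equal(out_torch.cpu().numpy(), out_onnx).all())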

To reproduce

Run the code above

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04.4

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

Python

Architecture

X86

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.4

Model File

No response

Is this a quantized model?

No

JNR000 added the performance label Feb 23, 2025
tianleiwu (Contributor) commented Feb 23, 2025

Floating-point arithmetic is not associative, meaning that the order of operations affects the result. Different devices or software stacks may execute operations in different orders, especially when parallelizing work, which leads to slight discrepancies in the outputs. That could be one of the reasons.
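
A tiny illustration (my own example, not taken from the model above) of how the summation order alone can change a float32 result:

import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(0.1)

# Mathematically both sums are 0.1, but float32 rounding makes them differ.
left = (a + b) + c   # (a + b) is exactly 0, so the result is 0.1
right = a + (b + c)  # (b + c) rounds back to -1e8, so the 0.1 is lost
print(left, right, left == right)  # 0.1 0.0 False

This is also why outputs from different backends are usually compared with a tolerance (e.g. np.allclose) rather than exact equality.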

If your NVIDIA GPU is Ampere or later, try the following:

providers = [("CUDAExecutionProvider", {"use_tf32": 0})]
sess_options = ort.SessionOptions()
sess = ort.InferenceSession("model.onnx", sess_options=sess_options, providers=providers)

By default, PyTorch disables TF32 in MatMul, while ORT enables TF32 by default. See https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#use_tf32
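
For completeness, the PyTorch-side counterpart (a small sketch of mine, not quoted from the docs) that makes PyTorch match ORT's default instead:

import torch

# Enable TF32 for CUDA matmuls so PyTorch behaves like ORT's default on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
# The cuDNN flag governs convolutions and already defaults to True in PyTorch.
torch.backends.cudnn.allow_tf32 = True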

JNR000 (Author) commented Feb 24, 2025

@tianleiwu Thank you very much for your kind help!
I corrected the comparison of the results; with my original settings, the float32 comparison gives:

Dtype: FLOAT
float32
onnx cpu vs onnx gpu:  False
onnx cpu vs torch (gpu):  True
onnx gpu vs torch (gpu):  False
numpy matmul vs torch (gpu) True

Setting “use_tf32”: 0 as you mentioned seems to make no change to the result. However, keeping use_tf32 at its default and instead setting torch.backends.cuda.matmul.allow_tf32 = True, as you said earlier, does help. The output now is:

Dtype: FLOAT
float32
onnx cpu vs onnx gpu:  False
onnx cpu vs torch (gpu):  False
onnx gpu vs torch (gpu):  True
numpy matmul vs torch (gpu): False

which seems more reasonable.

I will explore more about how PyTorch and ONNX Runtime handle the MatMul node. Setting allow_tf32 from False to True in PyTorch changes the result from one that matches onnx_cpu (and numpy) to one that matches onnx_gpu, which seems reasonable. However, why does setting ORT's “use_tf32” from 1 to 0 not give me a result equal to onnx_cpu (or numpy)? What does onnxruntime do in the MatMul node with the default use_tf32?
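
One way to explore this could be a probe like the following (a sketch only, not something I have run; "tf32_probe.onnx" is a placeholder file name and the shapes mirror the reproduction script above):

import itertools
import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort

device = torch.device("cuda")
model = nn.Linear(5, 7, bias=False).to(device).eval()
x = torch.randn(2, 3, 5, device=device)
torch.onnx.export(model, x, "tf32_probe.onnx",
                  input_names=["input"], output_names=["output"], opset_version=13)

# Compare every combination of the PyTorch and ORT TF32 switches on float32 data.
for torch_tf32, ort_tf32 in itertools.product([False, True], [0, 1]):
    torch.backends.cuda.matmul.allow_tf32 = torch_tf32
    with torch.no_grad():
        out_torch = model(x).cpu().numpy()
    sess = ort.InferenceSession(
        "tf32_probe.onnx",
        providers=[("CUDAExecutionProvider", {"use_tf32": ort_tf32})],
    )
    out_ort = sess.run(["output"], {"input": x.cpu().numpy()})[0]
    print(f"torch_tf32={torch_tf32} ort_use_tf32={ort_tf32} "
          f"exact_match={bool(np.equal(out_torch, out_ort).all())}")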

I would be very thankful if you have any further suggestions.
