Why is the output of the ONNX MatMul node never the same as what PyTorch gives?
Question
The output of the ONNX MatMul node is never the same as what PyTorch gives, whether during CPU or GPU inference. I've tried a number of different approaches to narrow down the situation.
I wrote a simple linear layer in torch and exported it as an ONNX model. After that, I could never get the same output from ONNX as what my torch model gives.
import torch
import torch.nn as nn
import torch.onnx
import numpy as np
import onnx
import onnxruntime as ort

def comp(a, b):
    # True only if every element of a and b is exactly equal
    return ~((~np.equal(np.array(a), np.array(b))).any())

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(5, 7, bias=False)

    def forward(self, x):
        return self.fc(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for type in [torch.float32, torch.float64]:
    model = SimpleModel().to(type).to(device)

    # torch output
    x = torch.randn(2, 3, 5).to(type).to(device)
    model.eval()
    with torch.no_grad():
        outtorch = model(x)

    # numpy matmul with the same weight
    with torch.no_grad():
        for name, param in model.named_parameters():
            numpy_weight = param.data
    out_numpy = np.matmul(np.array(x.cpu()), np.array(numpy_weight.cpu()).T)

    # export onnx model
    onnx_model_path = "simple_model.onnx"
    torch.onnx.export(
        model,
        x,
        onnx_model_path,
        verbose=False,
        input_names=["input"],
        output_names=["output"],
        opset_version=13,
    )

    # load onnx model
    onnx_model = onnx.load(onnx_model_path)

    # check the dtype of the stored weight
    for initializer in onnx_model.graph.initializer:
        onnx_dtype = onnx.TensorProto.DataType.Name(initializer.data_type)
        print(f"Dtype: {onnx_dtype}")

    onnx_session_cpu = ort.InferenceSession(onnx_model_path)

    # format the input
    input_data = np.array(x.cpu())
    print(input_data.dtype)
    input_name = onnx_session_cpu.get_inputs()[0].name
    output_name = onnx_session_cpu.get_outputs()[0].name

    # onnx cpu inference
    onnx_output_cpu = onnx_session_cpu.run([output_name], {input_name: input_data})

    # onnx gpu inference
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    options.intra_op_num_threads = 1
    cuda_provider_options = {
        "arena_extend_strategy": "kSameAsRequested",
    }
    ort_session_gpu = ort.InferenceSession(
        onnx_model_path,
        sess_options=options,
        providers=[("CUDAExecutionProvider", cuda_provider_options)],
    )
    onnx_output_gpu = ort_session_gpu.run([output_name], {input_name: input_data})

    print("onnx cpu vs onnx gpu: ", comp(onnx_output_cpu, onnx_output_gpu))
    print("onnx cpu vs torch (gpu): ", comp(onnx_output_cpu, outtorch.cpu()))
    print("onnx gpu vs torch (gpu): ", comp(outtorch.cpu(), onnx_output_gpu))
    print("numpy matmul vs torch (gpu)", comp(outtorch.cpu(), out_numpy))
The results are:
Dtype: FLOAT
float32
onnx cpu vs onnx gpu: False
onnx cpu vs torch (gpu): True
onnx gpu vs torch (gpu): False
numpy matmul vs torch (gpu) True
Dtype: DOUBLE
float64
onnx cpu vs onnx gpu: True
onnx cpu vs torch (gpu): True
onnx gpu vs torch (gpu): True
numpy matmul vs torch (gpu) True
What's more, the output of an nn.Conv layer does match the corresponding ONNX model's output exactly. So far the problem only occurs with the Linear layer, which is quite strange, because Conv and Linear are not that different as operations.
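The Conv check isn't included above; a minimal sketch of what such a comparison might look like (the layer sizes, input shape, and file name here are made up for illustration, and a CUDA device is assumed as in the original report):

import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort

conv = nn.Conv2d(3, 8, kernel_size=3, bias=False).eval().cuda()
x = torch.randn(1, 3, 16, 16, device="cuda")
with torch.no_grad():
    out_torch = conv(x)

torch.onnx.export(conv, x, "conv.onnx", input_names=["input"],
                  output_names=["output"], opset_version=13)

sess = ort.InferenceSession("conv.onnx", providers=["CUDAExecutionProvider"])
out_onnx = sess.run(None, {"input": x.cpu().numpy()})[0]

# exact element-wise equality, same idea as comp() above
print(np.equal(out_torch.cpu().numpy(), out_onnx).all())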
To reproduce
Run the code above
Urgency
No response
Platform
Linux
OS Version
Ubuntu 22.04.4
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.20.1
ONNX Runtime API
Python
Architecture
X86
Execution Provider
CUDA
Execution Provider Library Version
CUDA 12.4
Model File
No response
Is this a quantized model?
No
Floating-point arithmetic is not associative, meaning that the order of operations affects the result. Different devices or software may execute operations in different orders, especially when parallelizing work, which leads to slight discrepancies in the outcomes. That could be one of the reasons.
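For example, in float32 the grouping of a sum already changes the result:

import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0 -- c is absorbed by the large b before a cancels it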
If your NVIDIA GPU is Ampere or later, try the following:
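# Sketch of the suggested settings (the exact snippet is not preserved in this thread,
# reconstructed from the follow-up below):
# 1) disable TF32 in onnxruntime's CUDA execution provider, so it does full-precision FP32 GEMM
cuda_provider_options = {
    "arena_extend_strategy": "kSameAsRequested",
    "use_tf32": 0,
}
ort_session_gpu = ort.InferenceSession(
    onnx_model_path,
    sess_options=options,
    providers=[("CUDAExecutionProvider", cuda_provider_options)],
)

# 2) or, the other way around, let PyTorch's CUDA matmul use TF32 as well,
# so both sides round the same way
torch.backends.cuda.matmul.allow_tf32 = True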
@tianleiwu Thank you very much for your kind help!
I corrected the comparison of the results; with my original settings, the float32 comparison gives:
Dtype: FLOAT
float32
onnx cpu vs onnx gpu: False
onnx cpu vs torch (gpu): True
onnx gpu vs torch (gpu): False
numpy matmul vs torch (gpu) True
Setting "use_tf32": 0 as you mentioned seems to make no difference to the result. However, keeping use_tf32 at its default and instead setting torch.backends.cuda.matmul.allow_tf32 = True, as you said earlier, does help. The output now is:
Dtype: FLOAT
float32
onnx cpu vs onnx gpu: False
onnx cpu vs torch (gpu): False
onnx gpu vs torch (gpu): True
numpy matmul vs torch (gpu): False
which seems more reasonable.
I will explore more about how PyTorch and onnxruntime handle the MatMul node. Setting allow_tf32 from False to True in PyTorch changes the result from matching onnx_cpu (and numpy) to matching onnx_gpu, which seems reasonable. However, why does setting ort's "use_tf32" from 1 to 0 not give me a result equal to onnx_cpu (or numpy)? What does onnxruntime do in the MatMul node with the default use_tf32?
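As a side check (reusing onnx_output_cpu and onnx_output_gpu from the script above), comparing with a tolerance instead of exact equality shows whether the remaining CPU/GPU mismatch is only rounding-level:

import numpy as np

out_cpu = np.array(onnx_output_cpu[0])
out_gpu = np.array(onnx_output_gpu[0])
# rtol around 1e-3 is roughly TF32 precision (10-bit mantissa); exact equality is far stricter
print(np.allclose(out_cpu, out_gpu, rtol=1e-3, atol=1e-5))
print(np.abs(out_cpu - out_gpu).max())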
I would be very thankful if you have any further suggestions.