From c19d72cd28d32cf0567ada4715d1aa5b6e2e6a3b Mon Sep 17 00:00:00 2001 From: Haoxiang Meng <52441206+fairytale0828@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:46:18 +0800 Subject: [PATCH] Update PyExecutionSessionBase.cpp to get the right result When I use pyruntime, all elements of the output tensor contain the same value. I believe this is because the pointer is not advanced by the stride when computing each element's address. The pointer should move by the stride each time an element is read. Signed-off-by: Haoxiang Meng <52441206+fairytale0828@users.noreply.github.com> --- src/Runtime/python/PyExecutionSessionBase.cpp | 157 +++++++++++++----- 1 file changed, 112 insertions(+), 45 deletions(-) diff --git a/src/Runtime/python/PyExecutionSessionBase.cpp b/src/Runtime/python/PyExecutionSessionBase.cpp index b98fb0de23..7cd9c66840 100644 --- a/src/Runtime/python/PyExecutionSessionBase.cpp +++ b/src/Runtime/python/PyExecutionSessionBase.cpp @@ -232,52 +232,113 @@ std::vector PyExecutionSessionBase::pyRun( // https://numpy.org/devdocs/user/basics.types.html py::dtype dtype; + // switch (omTensorGetDataType(omt)) { + // case (OM_DATA_TYPE)onnx::TensorProto::FLOAT: + // dtype = py::dtype("float32"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::UINT8: + // dtype = py::dtype("uint8"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::INT8: + // dtype = py::dtype("int8"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::UINT16: + // dtype = py::dtype("uint16"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::INT16: + // dtype = py::dtype("int16"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::INT32: + // dtype = py::dtype("int32"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::INT64: + // dtype = py::dtype("int64"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::STRING: + // dtype = py::dtype("str"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::BOOL: + // dtype = py::dtype("bool_"); + // break; + // case 
(OM_DATA_TYPE)onnx::TensorProto::FLOAT16: + // dtype = py::dtype("float16"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE: + // dtype = py::dtype("float64"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::UINT32: + // dtype = py::dtype("uint32"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::UINT64: + // dtype = py::dtype("uint64"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64: + // dtype = py::dtype("csingle"); + // break; + // case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128: + // dtype = py::dtype("cdouble"); + // break; switch (omTensorGetDataType(omt)) { - case (OM_DATA_TYPE)onnx::TensorProto::FLOAT: - dtype = py::dtype("float32"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::UINT8: - dtype = py::dtype("uint8"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::INT8: - dtype = py::dtype("int8"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::UINT16: - dtype = py::dtype("uint16"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::INT16: - dtype = py::dtype("int16"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::INT32: - dtype = py::dtype("int32"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::INT64: - dtype = py::dtype("int64"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::STRING: - dtype = py::dtype("str"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::BOOL: - dtype = py::dtype("bool_"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16: - dtype = py::dtype("float16"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE: - dtype = py::dtype("float64"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::UINT32: - dtype = py::dtype("uint32"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::UINT64: - dtype = py::dtype("uint64"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64: - dtype = py::dtype("csingle"); - break; - case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128: - dtype = py::dtype("cdouble"); - break; + case (OM_DATA_TYPE)onnx::TensorProto::FLOAT: + dtype = 
py::dtype("float32"); + typesize = 4; + break; + case (OM_DATA_TYPE)onnx::TensorProto::UINT8: + dtype = py::dtype("uint8"); + typesize = 1; + break; + case (OM_DATA_TYPE)onnx::TensorProto::INT8: + dtype = py::dtype("int8"); + typesize = 1; + break; + case (OM_DATA_TYPE)onnx::TensorProto::UINT16: + dtype = py::dtype("uint16"); + typesize = 2; + break; + case (OM_DATA_TYPE)onnx::TensorProto::INT16: + dtype = py::dtype("int16"); + typesize = 2; + break; + case (OM_DATA_TYPE)onnx::TensorProto::INT32: + dtype = py::dtype("int32"); + typesize = 4; + break; + case (OM_DATA_TYPE)onnx::TensorProto::INT64: + dtype = py::dtype("int64"); + typesize = 8; + break; + case (OM_DATA_TYPE)onnx::TensorProto::STRING: + dtype = py::dtype("str"); + typesize = sizeof(char*); + break; + case (OM_DATA_TYPE)onnx::TensorProto::BOOL: + dtype = py::dtype("bool_"); + typesize = 1; + break; + case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16: + dtype = py::dtype("float16"); + typesize = 2; + break; + case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE: + dtype = py::dtype("float64"); + typesize = 8; + break; + case (OM_DATA_TYPE)onnx::TensorProto::UINT32: + dtype = py::dtype("uint32"); + typesize = 4; + break; + case (OM_DATA_TYPE)onnx::TensorProto::UINT64: + dtype = py::dtype("uint64"); + typesize = 8; + break; + case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64: + dtype = py::dtype("csingle"); + typesize = 8; + break; + case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128: + dtype = py::dtype("cdouble"); + typesize = 16; + break; default: { std::stringstream errStr; errStr << "Unsupported ONNX type in OMTensor: " @@ -286,6 +347,12 @@ std::vector PyExecutionSessionBase::pyRun( throw std::runtime_error(reportPythonError(errStr.str())); } } + // Convert tensor strides from element count to byte offset by multiplying with typesize + auto strides = std::vector( + omTensorGetStrides(omt), (omTensorGetStrides(omt) + omTensorGetRank(omt))); + for (auto& stride : strides) { + stride *= typesize; + } 
outputPyArrays.emplace_back( py::array(dtype, shape, omTensorGetDataPtr(omt)));