
Update PyExecutionSessionBase.cpp to get right result
When I use pyruntime, every element of the output tensor comes back with the same value. I believe this is because the stride is not applied to the pointer when computing each element's address; the pointer should advance by the stride each time an element is read.
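In other words, the address of element (i, j) must include the per-dimension stride offsets; reading through the bare base pointer returns element (0, 0) every time. Below is a minimal sketch of the two behaviors, using illustrative stand-in names (readElement, base, strides, typesize) rather than the actual onnx-mlir API:

#include <cstdint>
#include <vector>

// Illustrative stand-ins only, not the onnx-mlir API: base points at the
// tensor buffer, strides holds per-dimension strides in element counts,
// and typesize is the element size in bytes.
float readElement(const char *base, const std::vector<int64_t> &strides,
    int64_t typesize, int64_t i, int64_t j) {
  // Buggy variant: dereferencing base for every (i, j) yields the same
  // element each time.
  //   return *reinterpret_cast<const float *>(base);

  // Fixed variant: move the pointer by stride * element size along each
  // dimension before dereferencing.
  int64_t byteOffset = (i * strides[0] + j * strides[1]) * typesize;
  return *reinterpret_cast<const float *>(base + byteOffset);
}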

Signed-off-by: Haoxiang Meng <[email protected]>
fairytale0828 authored Jan 10, 2025
1 parent 0183ad9 commit c19d72c
Showing 1 changed file with 112 additions and 45 deletions.
157 changes: 112 additions & 45 deletions src/Runtime/python/PyExecutionSessionBase.cpp
@@ -232,52 +232,113 @@ std::vector<py::array> PyExecutionSessionBase::pyRun(

// https://numpy.org/devdocs/user/basics.types.html
py::dtype dtype;
+  // switch (omTensorGetDataType(omt)) {
+  // case (OM_DATA_TYPE)onnx::TensorProto::FLOAT:
+  //   dtype = py::dtype("float32");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::UINT8:
+  //   dtype = py::dtype("uint8");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::INT8:
+  //   dtype = py::dtype("int8");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::UINT16:
+  //   dtype = py::dtype("uint16");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::INT16:
+  //   dtype = py::dtype("int16");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::INT32:
+  //   dtype = py::dtype("int32");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::INT64:
+  //   dtype = py::dtype("int64");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::STRING:
+  //   dtype = py::dtype("str");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::BOOL:
+  //   dtype = py::dtype("bool_");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16:
+  //   dtype = py::dtype("float16");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE:
+  //   dtype = py::dtype("float64");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::UINT32:
+  //   dtype = py::dtype("uint32");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::UINT64:
+  //   dtype = py::dtype("uint64");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64:
+  //   dtype = py::dtype("csingle");
+  //   break;
+  // case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128:
+  //   dtype = py::dtype("cdouble");
+  //   break;
switch (omTensorGetDataType(omt)) {
-  case (OM_DATA_TYPE)onnx::TensorProto::FLOAT:
-    dtype = py::dtype("float32");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::UINT8:
-    dtype = py::dtype("uint8");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::INT8:
-    dtype = py::dtype("int8");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::UINT16:
-    dtype = py::dtype("uint16");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::INT16:
-    dtype = py::dtype("int16");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::INT32:
-    dtype = py::dtype("int32");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::INT64:
-    dtype = py::dtype("int64");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::STRING:
-    dtype = py::dtype("str");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::BOOL:
-    dtype = py::dtype("bool_");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16:
-    dtype = py::dtype("float16");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE:
-    dtype = py::dtype("float64");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::UINT32:
-    dtype = py::dtype("uint32");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::UINT64:
-    dtype = py::dtype("uint64");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64:
-    dtype = py::dtype("csingle");
-    break;
-  case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128:
-    dtype = py::dtype("cdouble");
-    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::FLOAT:
+    dtype = py::dtype("float32");
+    typesize = 4;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT8:
+    dtype = py::dtype("uint8");
+    typesize = 1;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT8:
+    dtype = py::dtype("int8");
+    typesize = 1;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT16:
+    dtype = py::dtype("uint16");
+    typesize = 2;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT16:
+    dtype = py::dtype("int16");
+    typesize = 2;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT32:
+    dtype = py::dtype("int32");
+    typesize = 4;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT64:
+    dtype = py::dtype("int64");
+    typesize = 8;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::STRING:
+    dtype = py::dtype("str");
+    typesize = sizeof(char*);
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::BOOL:
+    dtype = py::dtype("bool_");
+    typesize = 1;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16:
+    dtype = py::dtype("float16");
+    typesize = 2;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE:
+    dtype = py::dtype("float64");
+    typesize = 8;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT32:
+    dtype = py::dtype("uint32");
+    typesize = 4;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT64:
+    dtype = py::dtype("uint64");
+    typesize = 8;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64:
+    dtype = py::dtype("csingle");
+    typesize = 8;
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128:
+    dtype = py::dtype("cdouble");
+    typesize = 16;
+    break;
default: {
std::stringstream errStr;
errStr << "Unsupported ONNX type in OMTensor: "
@@ -286,6 +347,12 @@ std::vector<py::array> PyExecutionSessionBase::pyRun(
throw std::runtime_error(reportPythonError(errStr.str()));
}
}
+  // Convert tensor strides from element counts to byte offsets by multiplying by typesize.
+  auto strides = std::vector<int64_t>(
+      omTensorGetStrides(omt), (omTensorGetStrides(omt) + omTensorGetRank(omt)));
+  for (auto& stride : strides) {
+    stride *= typesize;
+  }

outputPyArrays.emplace_back(
py::array(dtype, shape, omTensorGetDataPtr(omt)));
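NumPy (and therefore pybind11's py::array) expects strides in bytes, while omTensorGetStrides reports them in element counts; that mismatch is what the typesize scaling above corrects. The hunk shown here computes the byte strides but the visible py::array call does not yet pass them, so the following is only a hedged sketch of how such strides would typically be consumed, with wrapBuffer as a hypothetical helper, not code from this commit:

#include <pybind11/numpy.h>
#include <vector>

namespace py = pybind11;

// Hypothetical helper (not part of this commit): expose an existing buffer
// to Python, converting element-count strides to the byte strides that
// NumPy expects.
py::array wrapBuffer(void *data, const std::vector<int64_t> &shape,
    std::vector<int64_t> elemStrides, py::dtype dtype) {
  for (auto &stride : elemStrides)
    stride *= dtype.itemsize(); // element count -> byte offset
  // With no owning "base" handle supplied, pybind11 copies the buffer into
  // the new array, so the result stays valid even if "data" is freed later.
  return py::array(dtype, shape, elemStrides, data);
}

As a side note, py::dtype::itemsize() already yields the element size for each dtype built in the switch above, which could replace the hand-maintained typesize assignments.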
