From da1ffceadb5bfa1a9c101ac802a0d422ddb8979d Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 20 Sep 2024 17:56:24 -0500 Subject: [PATCH] Update ROCM version detection code This fixes SWDEV-486455 --- cmake/CMakeLists.txt | 63 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0d582533faf21..d47b0bd66f396 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -329,20 +329,55 @@ if (onnxruntime_USE_ROCM) endif() # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # https://github.com/pytorch/pytorch/blob/1a10751731784942dcbb9c0524c1369a29d45244/cmake/public/LoadHIP.cmake#L45-L109 # with modification - if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") - file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - endif() - - if (ROCM_VERSION_MATCH) + set(ROCM_INCLUDE_DIRS "${onnxruntime_ROCM_HOME}/include") + set(PROJECT_RANDOM_BINARY_DIR "${CMAKE_BINARY_DIR}") + set(file "${CMAKE_BINARY_DIR}/detect_rocm_version.cc") + + # Find ROCM version for checks + # ROCM 5.0 and later will have header api for version management + if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h) + file(WRITE ${file} "" + "#include \n" + ) + elseif(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) + file(WRITE ${file} "" + "#include \n" + ) + else() + message(FATAL_ERROR "********************* rocm_version.h couldnt be found ******************\n") + endif() + + file(APPEND ${file} "" + "#include \n" + + "#ifndef ROCM_VERSION_PATCH\n" + "#define ROCM_VERSION_PATCH 0\n" + "#endif\n" + "#define STRINGIFYHELPER(x) #x\n" + "#define STRINGIFY(x) STRINGIFYHELPER(x)\n" + "int main() {\n" + " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n" + " return 0;\n" + "}\n" + ) + + try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + RUN_OUTPUT_VARIABLE rocm_version_from_header + COMPILE_OUTPUT_VARIABLE output_var + ) + # We expect the compile to be successful if the include directory exists. + if(NOT compile_result) + message(FATAL_ERROR "ROCM: Couldn't determine version from header: " ${output_var}) + endif() + message(STATUS "ROCM: Header version is: " ${rocm_version_from_header}) + set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header}) + + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) + + if (ROCM_VERSION_DEV_MATCH) set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) @@ -351,7 +386,7 @@ if (onnxruntime_USE_ROCM) else() message(FATAL_ERROR "Cannot determine ROCm version string") endif() - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") + message("\n***** ROCm version from rocm_version.h ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")