diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fb464b673b7..177999a50cd5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -224,24 +224,6 @@ if(USE_CUDA) add_subdirectory(${PROJECT_SOURCE_DIR}/gputreeshap) find_package(CUDAToolkit REQUIRED) - find_package(CCCL CONFIG) - if(NOT CCCL_FOUND) - message(STATUS "Standalone CCCL not found. Attempting to use CCCL from CUDA Toolkit...") - find_package(CCCL CONFIG - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - if(NOT CCCL_FOUND) - message(STATUS "Could not locate CCCL from CUDA Toolkit. Using Thrust and CUB from CUDA Toolkit...") - find_package(libcudacxx CONFIG REQUIRED - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - find_package(CUB CONFIG REQUIRED - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - find_package(Thrust CONFIG REQUIRED - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - thrust_create_target(Thrust HOST CPP DEVICE CUDA) - add_library(CCCL::CCCL INTERFACE IMPORTED GLOBAL) - target_link_libraries(CCCL::CCCL INTERFACE libcudacxx::libcudacxx CUB::CUB Thrust) - endif() - endif() endif() if(FORCE_COLORED_OUTPUT AND (CMAKE_GENERATOR STREQUAL "Ninja") AND @@ -327,6 +309,28 @@ if(PLUGIN_RMM) list(REMOVE_ITEM rmm_link_libs CUDA::cudart) list(APPEND rmm_link_libs CUDA::cudart_static) set_target_properties(rmm::rmm PROPERTIES INTERFACE_LINK_LIBRARIES "${rmm_link_libs}") + + # Pick up patched CCCL from RMM +elseif(USE_CUDA) + # If using CUDA and not RMM, search for CCCL. + find_package(CCCL CONFIG) + if(NOT CCCL_FOUND) + message(STATUS "Standalone CCCL not found. Attempting to use CCCL from CUDA Toolkit...") + find_package(CCCL CONFIG + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + if(NOT CCCL_FOUND) + message(STATUS "Could not locate CCCL from CUDA Toolkit. Using Thrust and CUB from CUDA Toolkit...") + find_package(libcudacxx CONFIG REQUIRED + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + find_package(CUB CONFIG REQUIRED + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + find_package(Thrust CONFIG REQUIRED + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + thrust_create_target(Thrust HOST CPP DEVICE CUDA) + add_library(CCCL::CCCL INTERFACE IMPORTED GLOBAL) + target_link_libraries(CCCL::CCCL INTERFACE libcudacxx::libcudacxx CUB::CUB Thrust) + endif() + endif() endif() if(PLUGIN_SYCL)