Skip to content

[rocm] Add rnnt loss feature for rocm #3938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/hipify_torch"]
path = third_party/hipify_torch
url = https://github.com/ROCmSoftwarePlatform/hipify_torch
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@ if(USE_CUDA AND USE_ROCM)
# Abort configuration: FATAL_ERROR is the message() mode that stops processing.
# (A bare "FATAL" is not a mode keyword and would only be printed as text.)
message(FATAL_ERROR "CUDA and ROCm are mutually exclusive")
endif()

find_package(Torch REQUIRED)

# ROCm build: enable the HIP language and probe for a HIP installation.
if(USE_ROCM)

# NOTE(review): HIP as a first-class language needs a recent CMake (3.21+) —
# confirm the project's cmake_minimum_required covers this.
enable_language(HIP)

# Find the HIP package, set the HIP paths, load the HIP CMake.
include(cmake/LoadHIP.cmake)
# PYTORCH_FOUND_HIP is presumably set by LoadHIP.cmake — verify. When it is
# false the ROCm build is switched off rather than failing configuration.
if(NOT PYTORCH_FOUND_HIP)
set(USE_ROCM OFF)
#set(USE_ROCM OFF)
endif()
endif()

Expand Down
57 changes: 54 additions & 3 deletions src/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
################################################################################
# libtorchaudio
################################################################################

# ROCm/HIP toolchain setup for libtorchaudio: locate the HIP package, load the
# hipify_torch CMake helpers, and point the C++ compiler at hipcc.
if(USE_ROCM)
  # Resolve the ROCm install root instead of hard-coding it: an already-set
  # ROCM_PATH variable wins, then the ROCM_PATH environment variable, and
  # finally the conventional /opt/rocm default (the previous fixed value).
  if(ROCM_PATH)
    set(torchaudio_rocm_root "${ROCM_PATH}")
  elseif(NOT "$ENV{ROCM_PATH}" STREQUAL "")
    set(torchaudio_rocm_root "$ENV{ROCM_PATH}")
  else()
    set(torchaudio_rocm_root "/opt/rocm")
  endif()

  list(APPEND CMAKE_PREFIX_PATH "${torchaudio_rocm_root}/hip" "${torchaudio_rocm_root}")
  find_package(HIP REQUIRED)
  # find_package() reports its result in <PackageName>_FOUND, i.e. HIP_FOUND.
  message(STATUS "hip found ${HIP_FOUND}")

  # Make the Hipify module shipped in the hipify_torch submodule includable.
  list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/third_party/hipify_torch/cmake")
  include(Hipify)

  # NOTE(review): changing CMAKE_CXX_COMPILER after the language has been
  # enabled is not a supported way to switch compilers — confirm this is
  # intended, or pass the compiler via the toolchain / command line instead.
  set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE})
  # NOTE(review): CMAKE_CXX_LINKER does not appear to be a variable CMake
  # itself reads — verify whether anything in the project consumes it.
  set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE})

  # Keep linked library directories in the installed RPATH and add the ROCm
  # LLVM runtime directory to it.
  set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
  list(APPEND CMAKE_INSTALL_RPATH "${torchaudio_rocm_root}/llvm/lib")
endif()


set(
sources
lfilter.cpp
Expand Down Expand Up @@ -39,6 +56,23 @@ if(BUILD_RNNT)
rnnt/gpu/compute.cu
)
endif()

# ROCm build of the RNN-T GPU kernels: run hipify() over the CUDA source
# directory (output directory rnnt/hip) and add the .hip files to `sources`.
# NOTE(review): hipify() comes from the hipify_torch helpers; it presumably
# generates the rnnt/hip files listed below — confirm, and note that the
# output directory lies inside the source tree.
if(USE_ROCM)
  hipify(
    CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/libtorchaudio/rnnt/gpu
    HIP_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/libtorchaudio/rnnt/hip
  )

  if(NOT HIP_ADD_LIBRARY_FOUND)
    # Resolve the ROCm install root instead of hard-coding it: ROCM_PATH
    # variable first, then the ROCM_PATH environment variable, then the
    # conventional /opt/rocm default (the previous fixed value).
    if(ROCM_PATH)
      set(torchaudio_rnnt_rocm_root "${ROCM_PATH}")
    elseif(NOT "$ENV{ROCM_PATH}" STREQUAL "")
      set(torchaudio_rnnt_rocm_root "$ENV{ROCM_PATH}")
    else()
      set(torchaudio_rnnt_rocm_root "/opt/rocm")
    endif()
    list(APPEND CMAKE_MODULE_PATH "${torchaudio_rnnt_rocm_root}/hip/cmake")
    find_package(HIP REQUIRED)
  endif()

  # HIP counterparts of the CUDA RNN-T translation units.
  list(
    APPEND
    sources
    rnnt/hip/compute_alphas.hip
    rnnt/hip/compute_betas.hip
    rnnt/hip/compute.hip
  )
endif()

endif()

if(BUILD_RIR)
Expand Down Expand Up @@ -76,12 +110,29 @@ if(USE_CUDA)
)
endif()

if(OpenMP_CXX_FOUND)
if(USE_ROCM)
list(
APPEND
additional_libs
OpenMP::OpenMP_CXX
additional_libs
hip::host
hip::device
)
list(
APPEND
compile_definitions
USE_ROCM
)
endif()


if(USE_CUDA)
if(OpenMP_CXX_FOUND)
list(
APPEND
additional_libs
OpenMP::OpenMP_CXX
)
endif()
endif()

#------------------------------------------------------------------------------#
Expand Down
6 changes: 5 additions & 1 deletion src/libtorchaudio/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <c10/cuda/CUDAStream.h>
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#include <torch/types.h>
#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/gpu_transducer_hip.h>
#else
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
6 changes: 5 additions & 1 deletion src/libtorchaudio/rnnt/gpu/compute_alphas.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <c10/cuda/CUDAStream.h>
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#include <torch/types.h>
#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/gpu_transducer_hip.h>
#else
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
6 changes: 5 additions & 1 deletion src/libtorchaudio/rnnt/gpu/compute_betas.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <c10/cuda/CUDAStream.h>
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#include <torch/types.h>
#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/gpu_transducer_hip.h>
#else
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
4 changes: 4 additions & 0 deletions src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

#ifdef USE_CUDA

#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/math_hip.cuh>
#else
#include <libtorchaudio/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
6 changes: 6 additions & 0 deletions src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@

#include <cassert>

#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/kernel_utils.h>
#include <libtorchaudio/rnnt/hip/kernels.h>
#include <libtorchaudio/rnnt/hip/math_hip.cuh>
#else
#include <libtorchaudio/rnnt/gpu/kernel_utils.h>
#include <libtorchaudio/rnnt/gpu/kernels.h>
#include <libtorchaudio/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
5 changes: 5 additions & 0 deletions src/libtorchaudio/rnnt/gpu/gpu_transducer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
#ifdef USE_CUDA

#include <libtorchaudio/rnnt/workspace.h>
#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/gpu_kernel_utils_hip.cuh>
#include <libtorchaudio/rnnt/hip/gpu_kernels_hip.cuh>
#else
#include <libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh>
#include <libtorchaudio/rnnt/gpu/gpu_kernels.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
4 changes: 4 additions & 0 deletions src/libtorchaudio/rnnt/gpu/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

#include <cassert>

#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/math_hip.cuh>
#else
#include <libtorchaudio/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
5 changes: 5 additions & 0 deletions src/libtorchaudio/rnnt/gpu/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

#include <cassert>

#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/hip/kernel_utils.h>
#include <libtorchaudio/rnnt/hip/math_hip.cuh>
#else
#include <libtorchaudio/rnnt/gpu/kernel_utils.h>
#include <libtorchaudio/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
8 changes: 8 additions & 0 deletions src/libtorchaudio/rnnt/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
#define FORCE_INLINE __forceinline__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#elif USE_ROCM
#define WARP_SIZE 32
#define MAX_THREADS_PER_BLOCK 1024
#define REDUCE_THREADS 256
#define HOST_AND_DEVICE __host__ __device__
#define FORCE_INLINE __forceinline__
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#else
#define HOST_AND_DEVICE
#define FORCE_INLINE inline
Expand Down
9 changes: 7 additions & 2 deletions src/libtorchaudio/rnnt/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

#ifdef USE_CUDA
#include <cuda_runtime.h>
typedef cudaStream_t gpuStream_t;
#endif // USE_CUDA
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
typedef hipStream_t gpuStream_t;
#endif // USE_ROCM

#include <libtorchaudio/rnnt/types.h>
#include <ostream>
Expand All @@ -13,9 +18,9 @@ namespace rnnt {
struct Options {
// the device to compute transducer loss.
device_t device_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// the stream to launch kernels in when using GPU.
cudaStream_t stream_;
gpuStream_t stream_;
#endif
// The maximum number of threads that can be used.
int numThreads_;
Expand Down
16 changes: 14 additions & 2 deletions src/libtorchaudio/rnnt/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,22 @@ class IntWorkspace {
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_CUDA
#ifdef USE_ROCM
if (data_ != nullptr && options_.device_ == GPU) {
hipMemset(
GetPointerToAlphaCounters(),
0,
ComputeSizeForAlphaCounters(options_) * sizeof(int));
hipMemset(
GetPointerToBetaCounters(),
0,
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_ROCM
}

static int ComputeSizeForAlphaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
if (options.device_ == GPU) {
return options.BU();
} else {
Expand All @@ -147,7 +159,7 @@ class IntWorkspace {
#endif // USE_CUDA
}
static int ComputeSizeForBetaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
if (options.device_ == GPU) {
return options.BU();
} else {
Expand Down
1 change: 1 addition & 0 deletions third_party/hipify_torch
Submodule hipify_torch added at a4337c