Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPP Audio Support HIP - Non silent region #395

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
c33af22
Bump rocm-docs-core[api_reference] from 0.35.0 to 0.35.1 in /docs/sph…
dependabot[bot] Mar 6, 2024
7adaa5f
added initial skeleton code for the NSR HIP kernel
sampath1117 Mar 7, 2024
8e13762
added test suite support for audio in HIP
sampath1117 Mar 8, 2024
1348130
initial commit for working NSR kernel with batch size 1
sampath1117 Mar 11, 2024
3c0c3eb
added max reduction kernel for finding max value in MMS buffer
sampath1117 Mar 12, 2024
14f6334
Bump rocm-docs-core[api_reference] from 0.35.1 to 0.36.0 in /docs/sph…
dependabot[bot] Mar 12, 2024
95c3272
Merge branch 'master' into develop
kiritigowda Mar 12, 2024
edf0e99
optimized find region kernel
sampath1117 Mar 17, 2024
f3d105f
added profiler support for hip test suite
sampath1117 Mar 18, 2024
701b9fc
modified kernel launch configuration for moving_mean_square_hip_tenso…
sampath1117 Mar 19, 2024
3102d38
changed the pinned memory for mmsArr to HIP memoryy
sampath1117 Mar 19, 2024
e5d469c
modified the datatype for NSR HIP kernel outputs from float to int
sampath1117 Mar 19, 2024
980dade
modify NSR HOST kernel outputs to int
sampath1117 Mar 14, 2024
0c4c787
change shm_pos to smem_pos
sampath1117 Mar 19, 2024
99ccead
minor change
sampath1117 Mar 20, 2024
641f653
Docs - Bump rocm-docs-core[api_reference] from 0.36.0 to 0.37.0 in /d…
dependabot[bot] Mar 20, 2024
5568573
Link cleanup (#326)
LisaDelaney Mar 20, 2024
a6749ba
Update notes
LisaDelaney Mar 20, 2024
a255906
Docs - Bump rocm-docs-core[api_reference] from 0.37.0 to 0.37.1 in /d…
dependabot[bot] Mar 22, 2024
d3df761
RPP Voxel Flip on HIP and HOST (#285)
r-abishek Mar 23, 2024
ebecb42
RPP Vignette Tensor on HOST and HIP (#311)
r-abishek Mar 23, 2024
fc1410b
Bump rocm-docs-core[api_reference] from 0.37.1 to 0.38.0 in /docs/sph…
dependabot[bot] Mar 27, 2024
cf91791
Merge branch 'develop' into sr/nsr_hip
sampath1117 Mar 29, 2024
64ee621
minor code cleanup
sampath1117 Mar 29, 2024
bc66023
changed the declaration of shared memory
sampath1117 Mar 29, 2024
3ebd7c3
RPP Tensor Audio Support - Resample (#310)
r-abishek Apr 3, 2024
76f31df
Docs - Missing input and output images for Doxygen (#331)
r-abishek Apr 3, 2024
b83f910
Scratch buffers rename for HOST and HIP (#324)
r-abishek Apr 3, 2024
ebeb131
Update CMakeLists.txt
kiritigowda Apr 3, 2024
2e40516
removed f16 includes since not needed for audio
sampath1117 Apr 4, 2024
fbd85c6
Merge remote-tracking branch 'develop' into sr/nsr_hip
sampath1117 Apr 4, 2024
46299d5
modified scratch buffer name used in NSR hip kernel
sampath1117 Apr 4, 2024
29efdb7
moved gridStride as param from kernel launch
sampath1117 Apr 4, 2024
3f5319b
restructured python test suite
sampath1117 Apr 4, 2024
44c29b1
minor change
sampath1117 Apr 4, 2024
6bdc00e
removed gridStride based processing in moving_mean_square_hip_tensor …
sampath1117 Apr 10, 2024
39fb985
build fix
sampath1117 Apr 10, 2024
bd28ffd
add comment for smem_pos function
sampath1117 Apr 10, 2024
cf85008
fixed spacing in Doxygen
sampath1117 Apr 11, 2024
1147bfe
Update CMakeLists.txt
kiritigowda Apr 12, 2024
201c5d6
Merge remote-tracking branch 'develop' into sr/nsr_hip
sampath1117 Apr 16, 2024
5e3fc7a
Bump rocm-docs-core[api_reference] from 0.38.1 to 1.0.0 in /docs/sphi…
dependabot[bot] Apr 18, 2024
b6b7cc5
Bump rocm-docs-core[api_reference] from 1.0.0 to 1.1.0 in /docs/sphin…
dependabot[bot] Apr 25, 2024
e16ad7a
RPP Gaussian Noise Voxel Tensor on HOST and HIP (#323)
r-abishek Apr 26, 2024
b5568c4
Merge branch 'develop' into sr/nsr_hip
sampath1117 May 2, 2024
6f16fc3
modify CHECK to CHECK_RETURN_STATUS in hip audio test suite
sampath1117 May 2, 2024
1d7377d
Merge branch 'develop' into sr/nsr_hip
sampath1117 May 2, 2024
c44a014
removed additional line added in merge
sampath1117 May 2, 2024
77e14ef
Minor common-fixes for HIP (#345)
r-abishek May 7, 2024
34f3f6d
Readme Updates: --usecase=rocm (#349)
kiritigowda May 8, 2024
ab52683
RPP Tensor Audio Support - Spectrogram (#312)
r-abishek May 8, 2024
ee0d6fe
Update CHANGELOG.md (#352)
r-abishek May 8, 2024
2decd32
RPP Tensor Audio Support - Slice (#325)
r-abishek May 8, 2024
30ce1d6
RPP Tensor Audio Support - MelFilterBank (#332)
r-abishek May 8, 2024
64ae74f
RPP Tensor Normalize ND on HOST and HIP (#335)
r-abishek May 9, 2024
1a3015c
SWDEV-459739 - Remove the package obsolete setting (#353)
raramakr May 9, 2024
7d233bc
Merge branch 'develop' into sr/nsr_hip
sampath1117 May 9, 2024
550e981
modified code to use globalThreads_z from description pointer instead…
sampath1117 May 9, 2024
4cb8d4b
Audio support merge commit fixes (#354)
r-abishek May 9, 2024
9bd6566
Merge branch 'develop' into sr/nsr_hip
sampath1117 May 17, 2024
69e985c
renamed instances of tensor_hip_audio to tensor_audio_hip
sampath1117 May 17, 2024
5498e3a
made helper functions inline
sampath1117 May 17, 2024
bde44ba
added comments for prev_pow2 and next_pow2 functions
sampath1117 May 17, 2024
067a425
change max reduction kernel block size from 512 to 256
sampath1117 May 17, 2024
995bde5
change base types to RPP types in hip executor
sampath1117 May 17, 2024
07126db
added error codes for tile length and shared memory
sampath1117 May 17, 2024
13efda2
minor change
sampath1117 May 17, 2024
d4219f8
Bump rocm-docs-core[api_reference] from 1.1.1 to 1.1.2 in /docs/sphin…
dependabot[bot] May 17, 2024
264392b
Docker updates (#356)
LakshmiKumar23 May 17, 2024
9907a49
Version Updates (#359)
LakshmiKumar23 May 17, 2024
d705d1b
remove square helper function in HIP kernel
sampath1117 May 20, 2024
04e1c75
optimize index computation in prefix_sim function
sampath1117 May 20, 2024
2256d19
add more comments in code for better readability
sampath1117 May 20, 2024
faa34e6
Merge branch 'develop' into sr/nsr_hip
sampath1117 May 24, 2024
98c8412
Merge branch 'develop' into sr/nsr_hip
sampath1117 May 29, 2024
8aa674c
Merge branch 'develop' into sr/nsr_hip
sampath1117 Jun 18, 2024
63a655d
removed print statement for invalid cases in hip executor
sampath1117 Jun 18, 2024
8acb1ad
vectorized computations in compute_prefix_sum function
sampath1117 Jun 18, 2024
fed7828
increased scratch memory allocation and removed the hipMalloc and hip…
sampath1117 Jun 20, 2024
a47f70d
Merge branch 'develop' of https://github.com/r-abishek/rpp into sr/ns…
r-abishek Jun 20, 2024
d7333e8
added explanation for the scratch memory allocated in HIP backend
sampath1117 Jun 25, 2024
cf6eb41
added comments for compute_prefix_sum helper function
sampath1117 Jun 25, 2024
ca66f60
Merge branch 'develop' into sr/nsr_hip
sampath1117 Jun 25, 2024
3ed52ea
Merge branch 'develop' into sr/nsr_hip
sampath1117 Jun 26, 2024
9bcee81
added audio flag changes in HIP test suite cmake
sampath1117 Jun 26, 2024
0684b56
removed empty blank line
sampath1117 Jun 26, 2024
76fd42e
Merge pull request #250 from sampath1117/sr/nsr_hip
r-abishek Jun 28, 2024
7c7324a
added const variable for maximum MMS buffer size
sampath1117 Jul 12, 2024
8924c31
added few more comments for moving mean sqaure hip kernel
sampath1117 Jul 12, 2024
0789ed3
Merge pull request #289 from sampath1117/sr/nsr_hip_pr_changes
r-abishek Jul 12, 2024
5ef9dc5
minor change to comment
sampath1117 Jul 16, 2024
ea1b3eb
Merge pull request #292 from sampath1117/sr/nsr_hip_pr_changes
r-abishek Jul 16, 2024
61b45e8
Merge branch 'develop' into ar/audio_support_1_non_silent_region_hip
r-abishek Jul 16, 2024
5d4bafa
audio test suite changes for python 2 compatibility
sampath1117 Jul 17, 2024
60b6966
Merge pull request #293 from sampath1117/sr/nsr_hip_pr_changes
r-abishek Jul 17, 2024
761a49b
Merge branch 'develop' into ar/audio_support_1_non_silent_region_hip
r-abishek Jul 23, 2024
cd8ae4d
Merge branch 'develop' into ar/audio_support_1_non_silent_region_hip
kiritigowda Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/rppdefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ SOFTWARE.
const float ONE_OVER_6 = 1.0f / 6;
const float ONE_OVER_3 = 1.0f / 3;
const float ONE_OVER_255 = 1.0f / 255;
const uint MMS_MAX_SCRATCH_MEMORY = 76800000; // maximum scratch memory size (number of floats) needed for MMS buffer in RNNT training

/******************** RPP typedefs ********************/

Expand Down Expand Up @@ -136,7 +137,13 @@ typedef enum
/*! \brief src and dst layout mismatch \ingroup group_rppdefs */
RPP_ERROR_LAYOUT_MISMATCH = -18,
/*! \brief Number of channels is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_INVALID_CHANNELS = -19
RPP_ERROR_INVALID_CHANNELS = -19,
/*! \brief Invalid output tile length (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_INVALID_OUTPUT_TILE_LENGTH = -20,
/*! \brief Shared memory size needed is beyond the bounds (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_OUT_OF_BOUND_SHARED_MEMORY_SIZE = -21,
/*! \brief Scratch memory size needed is beyond the bounds (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_OUT_OF_BOUND_SCRATCH_MEMORY_SIZE = -22,
} RppStatus;

/*! \brief RPP rppStatus_t type enums
Expand Down
108 changes: 65 additions & 43 deletions include/rppt_tensor_audio_augmentations.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,55 @@ extern "C" {
* \details Non Silent Region Detection augmentation for 1D audio buffer
\n Finds the starting index and length of non silent region in the audio buffer by comparing the
calculated short-term power with cutoff value passed
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param[out] detectedIndexTensor beginning index of non silent region (1D tensor in HOST memory, of size batchSize)
* \param[out] detectionLengthTensor length of non silent region (1D tensor in HOST memory, of size batchSize)
* \param[in] cutOffDB cutOff in dB below which the signal is considered silent
* \param[in] windowLength window length used for computing short-term power of the signal
* \param[in] referencePower reference power that is used to convert the signal to dB
* \param[in] resetInterval number of samples after which the moving mean average is recalculated to avoid precision loss
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param [out] detectedIndexTensor beginning index of non silent region (1D tensor in HOST memory, of size batchSize)
* \param [out] detectionLengthTensor length of non silent region (1D tensor in HOST memory, of size batchSize)
* \param [in] cutOffDB cutOff in dB below which the signal is considered silent
* \param [in] windowLength window length used for computing short-term power of the signal
* \param [in] referencePower reference power that is used to convert the signal to dB
* \param [in] resetInterval number of samples after which the moving mean average is recalculated to avoid precision loss
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
*/
RppStatus rppt_non_silent_region_detection_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, Rpp32s *srcLengthTensor, Rpp32s *detectedIndexTensor, Rpp32s *detectionLengthTensor, Rpp32f cutOffDB, Rpp32s windowLength, Rpp32f referencePower, Rpp32s resetInterval, rppHandle_t rppHandle);

#ifdef GPU_SUPPORT
/*! \brief Non Silent Region Detection augmentation on HIP backend
* \details Non Silent Region Detection augmentation for 1D audio buffer
\n Finds the starting index and length of non silent region in the audio buffer by comparing the
calculated short-term power with cutoff value passed
* \param [in] srcPtr source tensor in HIP memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcLengthTensor source audio buffer length (1D tensor in Pinned/HIP memory, of size batchSize)
* \param [out] detectedIndexTensor beginning index of non silent region (1D tensor in Pinned/HIP memory, of size batchSize)
* \param [out] detectionLengthTensor length of non silent region (1D tensor in Pinned/HIP memory, of size batchSize)
* \param [in] cutOffDB cutOff in dB below which the signal is considered silent
* \param [in] windowLength window length used for computing short-term power of the signal
* \param [in] referencePower reference power that is used to convert the signal to dB
* \param [in] resetInterval number of samples after which the moving mean average is recalculated to avoid precision loss
* \param [in] rppHandle RPP HIP handle created with <tt>\ref rppCreateWithStreamAndBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
*/
RppStatus rppt_non_silent_region_detection_gpu(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, Rpp32s *srcLengthTensor, Rpp32s *detectedIndexTensor, Rpp32s *detectionLengthTensor, Rpp32f cutOffDB, Rpp32s windowLength, Rpp32f referencePower, Rpp32s resetInterval, rppHandle_t rppHandle);
#endif // GPU_SUPPORT

/*! \brief To Decibels augmentation on HOST backend
* \details To Decibels augmentation for 1D audio buffer converts magnitude values to decibel values
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcDims source tensor sizes for each element in batch (2D tensor in HOST memory, of size batchSize * 2)
* \param[in] cutOffDB minimum or cut-off ratio in dB
* \param[in] multiplier factor by which the logarithm is multiplied
* \param[in] referenceMagnitude Reference magnitude if not provided maximum value of input used as reference
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcDims source tensor sizes for each element in batch (2D tensor in HOST memory, of size batchSize * 2)
* \param [in] cutOffDB minimum or cut-off ratio in dB
* \param [in] multiplier factor by which the logarithm is multiplied
* \param [in] referenceMagnitude Reference magnitude if not provided maximum value of input used as reference
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand All @@ -83,14 +105,14 @@ RppStatus rppt_to_decibels_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_

/*! \brief Pre Emphasis Filter augmentation on HOST backend
* \details Pre Emphasis Filter augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param[in] coeffTensor preemphasis coefficient (1D tensor in HOST memory, of size batchSize)
* \param[in] borderType border value policy
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param [in] coeffTensor preemphasis coefficient (1D tensor in HOST memory, of size batchSize)
* \param [in] borderType border value policy
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand All @@ -99,13 +121,13 @@ RppStatus rppt_pre_emphasis_filter_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr,

/*! \brief Down Mixing augmentation on HOST backend
* \details Down Mixing augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param[in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param [in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand Down Expand Up @@ -155,15 +177,15 @@ RppStatus rppt_mel_filter_bank_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, Rpp

/*! \brief Resample augmentation on HOST backend
* \details Resample augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] inRate Input sampling rate (1D tensor in HOST memory, of size batchSize)
* \param[in] outRate Output sampling rate (1D tensor in HOST memory, of size batchSize)
* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param[in] window Resampling window (struct of type RpptRpptResamplingWindow)
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] inRate Input sampling rate (1D tensor in HOST memory, of size batchSize)
* \param [in] outRate Output sampling rate (1D tensor in HOST memory, of size batchSize)
* \param [in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param [in] window Resampling window (struct of type RpptRpptResamplingWindow)
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand Down
7 changes: 6 additions & 1 deletion src/modules/hip/handlehip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,12 @@ struct HandleImpl
}

hipMalloc(&(this->initHandle->mem.mgpu.rgbArr.rgbmem), sizeof(RpptRGB) * this->nBatchSize);
hipMalloc(&(this->initHandle->mem.mgpu.scratchBufferHip.floatmem), sizeof(Rpp32f) * 8294400); // 3840 x 2160

/* (600000 + 293 + 128) * 128 - Maximum scratch memory required for Non Silent Region Detection HIP kernel used in RNNT training (uses a batchsize 128)
- 600000 is the maximum size that will be required for MMS buffer based on Librispeech dataset
- 293 is the size required for storing reduction outputs for 600000 size sample
- 128 is the size required for storing cutOffDB values for batch size 128 */
hipMalloc(&(this->initHandle->mem.mgpu.scratchBufferHip.floatmem), sizeof(Rpp32f) * 76853888);
}
};

Expand Down
30 changes: 30 additions & 0 deletions src/modules/hip/hip_tensor_audio_augmentations.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
MIT License

Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#ifndef HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP
#define HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP

#include "kernel/non_silent_region_detection.hpp"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is only including one header file, you don't really need a separate header file. You can directly include kernel/non_silent_region_detection.hpp in where it is referrenced

Copy link
Contributor

@sampath1117 sampath1117 Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Rajy,
This header file is used for including all audio HIP kernels
Currently since this is the first audio HIP PR, this has only included only 1 HIP kernel
Once few more audio HIP kernels are merged, this header will have more included like it is for HOST
https://github.com/ROCm/rpp/blob/develop/src/modules/cpu/host_tensor_audio_augmentations.hpp


#endif // HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP
Loading