diff --git a/docs/data/doxygenInputs/lens_img640x480.png b/docs/data/doxygenInputs/lens_img640x480.png new file mode 100644 index 000000000..897955d77 Binary files /dev/null and b/docs/data/doxygenInputs/lens_img640x480.png differ diff --git a/docs/data/doxygenOutputs/geometric_augmentations_lens_correction_img_640x480.png b/docs/data/doxygenOutputs/geometric_augmentations_lens_correction_img_640x480.png new file mode 100644 index 000000000..63a52819d Binary files /dev/null and b/docs/data/doxygenOutputs/geometric_augmentations_lens_correction_img_640x480.png differ diff --git a/include/rppt_tensor_geometric_augmentations.h b/include/rppt_tensor_geometric_augmentations.h index aa369111e..884127a71 100644 --- a/include/rppt_tensor_geometric_augmentations.h +++ b/include/rppt_tensor_geometric_augmentations.h @@ -634,6 +634,60 @@ RppStatus rppt_remap_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstP RppStatus rppt_remap_gpu(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32f *rowRemapTable, Rpp32f *colRemapTable, RpptDescPtr tableDescPtr, RpptInterpolationType interpolationType, RpptROIPtr roiTensorPtrSrc, RpptRoiType roiType, rppHandle_t rppHandle); #endif // GPU_SUPPORT +/*! \brief Lens correction transformation on HOST backend for a NCHW/NHWC layout tensor + * \details Performs lens correction transforms on an image to compensate barrel lens distortion of RGB(3 channel) / greyscale(1 channel) images with an NHWC/NCHW tensor layout.
+ * - srcPtr depth ranges - Rpp8u (0 to 255), Rpp16f (0 to 1), Rpp32f (0 to 1), Rpp8s (-128 to 127). + * - dstPtr depth ranges - Will be same depth as srcPtr. + * Note: Returns a black image if the passed camera matrix has a 0 determinant + * \image html lens_img640x480.png Sample Input + * \image html geometric_augmentations_lens_correction_img_640x480.png Sample Output + * \param [in] srcPtr source tensor in HOST memory + * \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 4, offsetInBytes >= 0, dataType = U8/F16/F32/I8, layout = NCHW/NHWC, c = 1/3) + * \param [out] dstPtr destination tensor in HOST memory + * \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 4, offsetInBytes >= 0, dataType = U8/F16/F32/I8, layout = NCHW/NHWC, c = same as that of srcDescPtr) + * \param [in] rowRemapTable Rpp32f row numbers in HOST memory for every pixel in the input batch of images (1D tensor of size width * height * batchSize) + * \param [in] colRemapTable Rpp32f column numbers in HOST memory for every pixel in the input batch of images (1D tensor of size width * height * batchSize) + * \param [in] tableDescPtr table tensor descriptor (Restrictions - numDims = 4, offsetInBytes >= 0, dataType = F32, layout = NHWC, c = 1) + * \param [in] cameraMatrixTensor contains camera intrinsic parameters required to compute lens corrected image. (1D tensor of size 9 * batchSize) + * \param [in] distortionCoeffsTensor contains distortion coefficients required to compute lens corrected image. (1D tensor of size 8 * batchSize) + * \param [in] roiTensorSrc ROI data in HOST memory, for each image in source tensor (2D tensor of size batchSize * 4, in either format - XYWH(xy.x, xy.y, roiWidth, roiHeight) or LTRB(lt.x, lt.y, rb.x, rb.y)) + * \param [in] roiType ROI type used (RpptRoiType::XYWH or RpptRoiType::LTRB) + * \param [in] rppHandle RPP HOST handle created with \ref rppCreateWithBatchSize() + * \return A \ref RppStatus enumeration. + * \retval RPP_SUCCESS Successful completion. + * \retval RPP_ERROR* Unsuccessful completion. + * \ingroup group_tensor_geometric + */ +RppStatus rppt_lens_correction_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32f *rowRemapTable, Rpp32f *colRemapTable, RpptDescPtr tableDescPtr, Rpp32f *cameraMatrixTensor, Rpp32f *distortionCoeffsTensor, RpptROIPtr roiTensorPtrSrc, RpptRoiType roiType, rppHandle_t rppHandle); + +#ifdef GPU_SUPPORT +/*! \brief Lens correction transformation on HIP backend for a NCHW/NHWC layout tensor + * \details Performs lens correction transforms on an image to compensate barrel lens distortion of RGB(3 channel) / greyscale(1 channel) images with an NHWC/NCHW tensor layout.
+ * - srcPtr depth ranges - Rpp8u (0 to 255), Rpp16f (0 to 1), Rpp32f (0 to 1), Rpp8s (-128 to 127). + * - dstPtr depth ranges - Will be same depth as srcPtr. + * Note: Returns a black image if the passed camera matrix has a 0 determinant + * \image html lens_img640x480.png Sample Input + * \image html geometric_augmentations_lens_correction_img_640x480.png Sample Output + * \param [in] srcPtr source tensor in HIP memory + * \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 4, offsetInBytes >= 0, dataType = U8/F16/F32/I8, layout = NCHW/NHWC, c = 1/3) + * \param [out] dstPtr destination tensor in HIP memory + * \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 4, offsetInBytes >= 0, dataType = U8/F16/F32/I8, layout = NCHW/NHWC, c = same as that of srcDescPtr) + * \param [in] rowRemapTable Rpp32f row numbers in HIP memory for every pixel in the input batch of images (1D tensor of size width * height * batchSize) + * \param [in] colRemapTable Rpp32f column numbers in HIP memory for every pixel in the input batch of images (1D tensor of size width * height * batchSize) + * \param [in] tableDescPtr table tensor descriptor (Restrictions - numDims = 4, offsetInBytes >= 0, dataType = F32, layout = NHWC, c = 1) + * \param [in] cameraMatrixTensor contains camera intrinsic parameters required to compute lens corrected image. (1D tensor of size 9 * batchSize) + * \param [in] distortionCoeffsTensor contains distortion coefficients required to compute lens corrected image. (1D tensor of size 8 * batchSize) + * \param [in] roiTensorSrc ROI data in HIP memory, for each image in source tensor (2D tensor of size batchSize * 4, in either format - XYWH(xy.x, xy.y, roiWidth, roiHeight) or LTRB(lt.x, lt.y, rb.x, rb.y)) + * \param [in] roiType ROI type used (RpptRoiType::XYWH or RpptRoiType::LTRB) + * \param [in] rppHandle RPP HIP handle created with \ref rppCreateWithStreamAndBatchSize() + * \return A \ref RppStatus enumeration. + * \retval RPP_SUCCESS Successful completion. + * \retval RPP_ERROR* Unsuccessful completion. + * \ingroup group_tensor_geometric + */ +RppStatus rppt_lens_correction_gpu(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32f *rowRemapTable, Rpp32f *colRemapTable, RpptDescPtr tableDescPtr, Rpp32f *cameraMatrixTensor, Rpp32f *distortionCoeffsTensor, RpptROIPtr roiTensorPtrSrc, RpptRoiType roiType, rppHandle_t rppHandle); +#endif // GPU_SUPPORT + /*! \brief Transpose Generic augmentation on HOST backend * \details The transpose augmentation performs an input-permutation based transpose on a generic ND Tensor. * \param [in] srcPtr source tensor in HOST memory @@ -674,4 +728,4 @@ RppStatus rppt_transpose_gpu(RppPtr_t srcPtr, RpptGenericDescPtr srcGenericDescP #ifdef __cplusplus } #endif -#endif // RPPT_TENSOR_GEOMETRIC_AUGMENTATIONS_H +#endif // RPPT_TENSOR_GEOMETRIC_AUGMENTATIONS_H \ No newline at end of file diff --git a/src/include/hip/rpp_hip_common.hpp b/src/include/hip/rpp_hip_common.hpp index e6e2ab986..16e3a2765 100644 --- a/src/include/hip/rpp_hip_common.hpp +++ b/src/include/hip/rpp_hip_common.hpp @@ -55,7 +55,7 @@ typedef union { float f1[5]; typedef union { float f1[6]; float2 f2[3]; } d_float6; typedef union { float f1[7]; } d_float7; typedef union { float f1[8]; float2 f2[4]; float4 f4[2]; } d_float8; -typedef union { float f1[9]; } d_float9; +typedef union { float f1[9]; float3 f3[3]; } d_float9; typedef union { float f1[12]; float4 f4[3]; } d_float12; typedef union { float f1[16]; float4 f4[4]; d_float8 f8[2]; } d_float16; typedef union { float f1[24]; float2 f2[12]; float3 f3[8]; float4 f4[6]; d_float8 f8[3]; } d_float24; @@ -1776,6 +1776,22 @@ __device__ __forceinline__ void rpp_hip_math_multiply24_const(d_float24 *src_f24 dst_f24->f4[5] = src_f24->f4[5] * multiplier_f4; } +// d_float8 divide + +__device__ __forceinline__ void rpp_hip_math_divide8(d_float8 *src1Ptr_f8, d_float8 *src2Ptr_f8, d_float8 *dstPtr_f8) +{ + dstPtr_f8->f4[0] = src1Ptr_f8->f4[0] / src2Ptr_f8->f4[0]; + dstPtr_f8->f4[1] = src1Ptr_f8->f4[1] / src2Ptr_f8->f4[1]; +} + +// d_float8 divide with constant + +__device__ __forceinline__ void rpp_hip_math_divide8_const(d_float8 *src_f8, d_float8 *dst_f8, float4 divisor_f4) +{ + dst_f8->f4[0] = divisor_f4 / src_f8->f4[0]; + dst_f8->f4[1] = divisor_f4 / src_f8->f4[1]; +} + // d_float8 bitwiseAND __device__ __forceinline__ void rpp_hip_math_bitwiseAnd8(d_float8 *src1_f8, d_float8 *src2_f8, d_float8 *dst_f8) diff --git a/src/modules/cpu/host_tensor_geometric_augmentations.hpp b/src/modules/cpu/host_tensor_geometric_augmentations.hpp index c29d56c0d..9facb0d78 100644 --- a/src/modules/cpu/host_tensor_geometric_augmentations.hpp +++ b/src/modules/cpu/host_tensor_geometric_augmentations.hpp @@ -35,6 +35,7 @@ SOFTWARE. #include "kernel/warp_affine.hpp" #include "kernel/phase.hpp" #include "kernel/slice.hpp" +#include "kernel/lens_correction.hpp" #include "kernel/transpose.hpp" #include "kernel/crop_and_patch.hpp" #include "kernel/flip_voxel.hpp" diff --git a/src/modules/cpu/kernel/lens_correction.hpp b/src/modules/cpu/kernel/lens_correction.hpp new file mode 100644 index 000000000..1632568a5 --- /dev/null +++ b/src/modules/cpu/kernel/lens_correction.hpp @@ -0,0 +1,178 @@ +/* +MIT License + +Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#include "rppdefs.h" +#include "rpp_cpu_simd.hpp" +#include "rpp_cpu_common.hpp" +#include + +// Compute Inverse matrix (3x3) +inline void get_inverse(float *mat, float *invMat) +{ + float det = mat[0] * (mat[4] * mat[8] - mat[7] * mat[5]) - mat[1] * (mat[3] * mat[8] - mat[5] * mat[6]) + mat[2] * (mat[3] * mat[7] - mat[4] * mat[6]); + if(det != 0) + { + float invDet = 1 / det; + invMat[0] = (mat[4] * mat[8] - mat[7] * mat[5]) * invDet; + invMat[1] = (mat[2] * mat[7] - mat[1] * mat[8]) * invDet; + invMat[2] = (mat[1] * mat[5] - mat[2] * mat[4]) * invDet; + invMat[3] = (mat[5] * mat[6] - mat[3] * mat[8]) * invDet; + invMat[4] = (mat[0] * mat[8] - mat[2] * mat[6]) * invDet; + invMat[5] = (mat[3] * mat[2] - mat[0] * mat[5]) * invDet; + invMat[6] = (mat[3] * mat[7] - mat[6] * mat[4]) * invDet; + invMat[7] = (mat[6] * mat[1] - mat[0] * mat[7]) * invDet; + invMat[8] = (mat[0] * mat[4] - mat[3] * mat[1]) * invDet; + } +} + +inline void compute_lens_correction_remap_tables_host_tensor(RpptDescPtr srcDescPtr, + Rpp32f *rowRemapTable, + Rpp32f *colRemapTable, + RpptDescPtr tableDescPtr, + Rpp32f *cameraMatrixTensor, + Rpp32f *distortionCoeffsTensor, + RpptROIPtr roiTensorPtrSrc, + rpp::Handle& handle) +{ + Rpp32u numThreads = handle.GetNumThreads(); + omp_set_dynamic(0); +#pragma omp parallel for num_threads(numThreads) + for(int batchCount = 0; batchCount < srcDescPtr->n; batchCount++) + { + Rpp32f *rowRemapTableTemp, *colRemapTableTemp; + rowRemapTableTemp = rowRemapTable + batchCount * tableDescPtr->strides.nStride; + colRemapTableTemp = colRemapTable + batchCount * tableDescPtr->strides.nStride; + + // cameraMatrix is a 3x3 matrix thus increment by 9 to iterate from one tensor in a batch to another + Rpp32f *cameraMatrix = cameraMatrixTensor + batchCount * 9; + Rpp32f *distortionCoeffs = distortionCoeffsTensor + batchCount * 8; + Rpp32s height = roiTensorPtrSrc[batchCount].xywhROI.roiHeight; + Rpp32s width = roiTensorPtrSrc[batchCount].xywhROI.roiWidth; + Rpp32u alignedLength = width & ~7; + Rpp32s vectorIncrement = 8; + + Rpp32f invCameraMatrix[9]; + std::fill(invCameraMatrix, invCameraMatrix + 9, 0.0f); // initialize all values in invCameraMatrix to zero + get_inverse(cameraMatrix, invCameraMatrix); + Rpp32f *invMat = &invCameraMatrix[0]; + + // Get radial and tangential distortion coefficients + Rpp32f rCoeff[6] = { distortionCoeffs[0], distortionCoeffs[1], distortionCoeffs[4], distortionCoeffs[5], distortionCoeffs[6], distortionCoeffs[7] }; + Rpp32f tCoeff[2] = { distortionCoeffs[2], distortionCoeffs[3] }; + + __m256 pRCoeff[6], pTCoeff[2]; + pRCoeff[0] = _mm256_set1_ps(rCoeff[0]); + pRCoeff[1] = _mm256_set1_ps(rCoeff[1]); + pRCoeff[2] = _mm256_set1_ps(rCoeff[2]); + pRCoeff[3] = _mm256_set1_ps(rCoeff[3]); + pRCoeff[4] = _mm256_set1_ps(rCoeff[4]); + pRCoeff[5] = _mm256_set1_ps(rCoeff[5]); + pTCoeff[0] = _mm256_set1_ps(tCoeff[0]); + pTCoeff[1] = _mm256_set1_ps(tCoeff[1]); + + Rpp32f u0 = cameraMatrix[2], v0 = cameraMatrix[5]; + Rpp32f fx = cameraMatrix[0], fy = cameraMatrix[4]; + __m256 pFx, pFy, pU0, pV0; + pFx = _mm256_set1_ps(fx); + pFy = _mm256_set1_ps(fy); + pU0 = _mm256_set1_ps(u0); + pV0 = _mm256_set1_ps(v0); + + __m256 pInvMat0, pInvMat3, pInvMat6; + pInvMat0 = _mm256_set1_ps(invMat[0]); + pInvMat3 = _mm256_set1_ps(invMat[3]); + pInvMat6 = _mm256_set1_ps(invMat[6]); + + __m256 pXCameraInit, pYCameraInit, pZCameraInit; + __m256 pXCameraIncrement, pYCameraIncrement, pZCameraIncrement; + pXCameraInit = _mm256_mul_ps(avx_pDstLocInit, pInvMat0); + pYCameraInit = _mm256_mul_ps(avx_pDstLocInit, pInvMat3); + pZCameraInit = _mm256_mul_ps(avx_pDstLocInit, pInvMat6); + pXCameraIncrement = _mm256_mul_ps(pInvMat0, avx_p8); + pYCameraIncrement = _mm256_mul_ps(pInvMat3, avx_p8); + pZCameraIncrement = _mm256_mul_ps(pInvMat6, avx_p8); + for(int i = 0; i < height; i++) + { + Rpp32f *rowRemapTableRow = rowRemapTableTemp + i * tableDescPtr->strides.hStride; + Rpp32f *colRemapTableRow = colRemapTableTemp + i * tableDescPtr->strides.hStride; + Rpp32f xCamera = i * invMat[1] + invMat[2]; + Rpp32f yCamera = i * invMat[4] + invMat[5]; + Rpp32f zCamera = i * invMat[7] + invMat[8]; + __m256 pXCamera = _mm256_add_ps(_mm256_set1_ps(xCamera), pXCameraInit); + __m256 pYCamera = _mm256_add_ps(_mm256_set1_ps(yCamera), pYCameraInit); + __m256 pZCamera = _mm256_add_ps(_mm256_set1_ps(zCamera), pZCameraInit); + int vectorLoopCount = 0; + for(; vectorLoopCount < alignedLength; vectorLoopCount += vectorIncrement) + { + // float z = 1./zCamera, x = xCamera*z, y = yCamera*z; + __m256 pZ = _mm256_div_ps(avx_p1, pZCamera); + __m256 pX = _mm256_mul_ps(pXCamera, pZ); + __m256 pY = _mm256_mul_ps(pYCamera, pZ); + + // float xSquare = x*x, ySquare = y*y, r2 = xSquare + ySquare; + __m256 pXSquare = _mm256_mul_ps(pX, pX); + __m256 pYSquare = _mm256_mul_ps(pY, pY); + __m256 pR2 = _mm256_add_ps(pXSquare, pYSquare); + + // float xyMul2 = 2*x*y; + __m256 p2xy = _mm256_mul_ps(avx_p2, _mm256_mul_ps(pX, pY)); + + // float kr = std::fmaf(std::fmaf(std::fmaf(rCoeff[2], r2, rCoeff[1]), r2, rCoeff[0]), r2, 1) / std::fmaf(std::fmaf(std::fmaf(rCoeff[5], r2, rCoeff[4]), r2, rCoeff[3]), r2, 1); + __m256 pNum = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(pRCoeff[2], pR2, pRCoeff[1]), pR2, pRCoeff[0]), pR2, avx_p1); + __m256 pDen = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(pRCoeff[5], pR2, pRCoeff[4]), pR2, pRCoeff[3]), pR2, avx_p1); + __m256 pKR = _mm256_div_ps(pNum, pDen); + + // float colLoc = std::fmaf(fx, (std::fmaf(tCoeff[1], (std::fmaf(2, xSquare, r2)), std::fmaf(x, kr, (tCoeff[0] * xyMul2)))), u0); + __m256 pColLoc = _mm256_fmadd_ps(pFx, _mm256_fmadd_ps(pTCoeff[1], _mm256_fmadd_ps(avx_p2, pXSquare, pR2), _mm256_fmadd_ps(pX, pKR, _mm256_mul_ps(pTCoeff[0], p2xy))), pU0); + + // float rowLoc = std::fmaf(fy, (std::fmaf(tCoeff[0], (std::fmaf(2, ySquare, r2)), std::fmaf(y, kr, (tCoeff[1] * xyMul2)))), v0); + __m256 pRowLoc = _mm256_fmadd_ps(pFy, _mm256_fmadd_ps(pTCoeff[0], _mm256_fmadd_ps(avx_p2, pYSquare, pR2), _mm256_fmadd_ps(pY, pKR, _mm256_mul_ps(pTCoeff[1], p2xy))), pV0); + + _mm256_storeu_ps(rowRemapTableRow, pRowLoc); + _mm256_storeu_ps(colRemapTableRow, pColLoc); + rowRemapTableRow += vectorIncrement; + colRemapTableRow += vectorIncrement; + + // xCamera += invMat[0], yCamera += invMat[3], zCamera += invMat[6] + pXCamera = _mm256_add_ps(pXCamera, pXCameraIncrement); + pYCamera = _mm256_add_ps(pYCamera, pYCameraIncrement); + pZCamera = _mm256_add_ps(pZCamera, pZCameraIncrement); + } + for(; vectorLoopCount < width; vectorLoopCount++) + { + Rpp32f z = 1./zCamera, x = xCamera * z, y = yCamera * z; + Rpp32f xSquare = x * x, ySquare = y * y, r2 = xSquare + ySquare; + Rpp32f xyMul2 = 2 * x * y; + Rpp32f kr = std::fmaf(std::fmaf(std::fmaf(rCoeff[2], r2, rCoeff[1]), r2, rCoeff[0]), r2, 1) / std::fmaf(std::fmaf(std::fmaf(rCoeff[5], r2, rCoeff[4]), r2, rCoeff[3]), r2, 1); + Rpp32f colLoc = std::fmaf(fx, (std::fmaf(tCoeff[1], (std::fmaf(2, xSquare, r2)), std::fmaf(x, kr, (tCoeff[0] * xyMul2)))), u0); + Rpp32f rowLoc = std::fmaf(fy, (std::fmaf(tCoeff[0], (std::fmaf(2, ySquare, r2)), std::fmaf(y, kr, (tCoeff[1] * xyMul2)))), v0); + *rowRemapTableRow++ = rowLoc; + *colRemapTableRow++ = colLoc; + xCamera += invMat[0]; + yCamera += invMat[3]; + zCamera += invMat[6]; + } + } + } +} \ No newline at end of file diff --git a/src/modules/hip/hip_tensor_geometric_augmentations.hpp b/src/modules/hip/hip_tensor_geometric_augmentations.hpp index 12cc5592d..102e7d686 100644 --- a/src/modules/hip/hip_tensor_geometric_augmentations.hpp +++ b/src/modules/hip/hip_tensor_geometric_augmentations.hpp @@ -35,6 +35,7 @@ SOFTWARE. #include "kernel/resize_crop_mirror.hpp" #include "kernel/phase.hpp" #include "kernel/slice.hpp" +#include "kernel/lens_correction.hpp" #include "kernel/transpose.hpp" #include "kernel/crop_and_patch.hpp" #include "kernel/flip_voxel.hpp" diff --git a/src/modules/hip/kernel/lens_correction.hpp b/src/modules/hip/kernel/lens_correction.hpp new file mode 100644 index 000000000..0d53db7e1 --- /dev/null +++ b/src/modules/hip/kernel/lens_correction.hpp @@ -0,0 +1,183 @@ +#include +#include "rpp_hip_common.hpp" + +// -------------------- Set 0 - lens_correction device helpers -------------------- + +__device__ __forceinline__ void camera_coordinates_hip_compute(d_float8 *cameraCoords_f8, int id_y, d_float8 *locDst_f8x, float3 *inverseMatrix) +{ + float4 inverseCoord1_f4 = static_cast(id_y * inverseMatrix->y + inverseMatrix->z); + float4 inverseCoord2_f4 = static_cast(inverseMatrix->x); + cameraCoords_f8->f4[0] = inverseCoord1_f4 + locDst_f8x->f4[0] * inverseCoord2_f4; + cameraCoords_f8->f4[1] = inverseCoord1_f4 + locDst_f8x->f4[1] * inverseCoord2_f4; +} + +// -------------------- Set 1 - lens_correction kernels -------------------- + +// compute inverse of 3x3 camera matrix +__global__ void compute_inverse_matrix_hip_tensor(d_float9 *matTensor, d_float9 *invMatTensor) +{ + int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z; + d_float9 *mat_f9 = &matTensor[id_z]; + d_float9 *invMat_f9 = &invMatTensor[id_z]; + + // initialize all values in invMat_f9 to zero + invMat_f9->f3[0] = static_cast(0.0f); + invMat_f9->f3[1] = invMat_f9->f3[0]; + invMat_f9->f3[2] = invMat_f9->f3[0]; + + // compute determinant mat_f9 + float det = (mat_f9->f1[0] * ((mat_f9->f1[4] * mat_f9->f1[8]) - (mat_f9->f1[7] * mat_f9->f1[5]))) + - (mat_f9->f1[1] * ((mat_f9->f1[3] * mat_f9->f1[8]) - (mat_f9->f1[5] * mat_f9->f1[6]))) + + (mat_f9->f1[2] * ((mat_f9->f1[3] * mat_f9->f1[7]) - (mat_f9->f1[4] * mat_f9->f1[6]))); + if(det != 0) + { + float invDet = 1 / det; + invMat_f9->f1[0] = (mat_f9->f1[4] * mat_f9->f1[8] - mat_f9->f1[7] * mat_f9->f1[5]) * invDet; + invMat_f9->f1[1] = (mat_f9->f1[2] * mat_f9->f1[7] - mat_f9->f1[1] * mat_f9->f1[8]) * invDet; + invMat_f9->f1[2] = (mat_f9->f1[1] * mat_f9->f1[5] - mat_f9->f1[2] * mat_f9->f1[4]) * invDet; + invMat_f9->f1[3] = (mat_f9->f1[5] * mat_f9->f1[6] - mat_f9->f1[3] * mat_f9->f1[8]) * invDet; + invMat_f9->f1[4] = (mat_f9->f1[0] * mat_f9->f1[8] - mat_f9->f1[2] * mat_f9->f1[6]) * invDet; + invMat_f9->f1[5] = (mat_f9->f1[3] * mat_f9->f1[2] - mat_f9->f1[0] * mat_f9->f1[5]) * invDet; + invMat_f9->f1[6] = (mat_f9->f1[3] * mat_f9->f1[7] - mat_f9->f1[6] * mat_f9->f1[4]) * invDet; + invMat_f9->f1[7] = (mat_f9->f1[6] * mat_f9->f1[1] - mat_f9->f1[0] * mat_f9->f1[7]) * invDet; + invMat_f9->f1[8] = (mat_f9->f1[0] * mat_f9->f1[4] - mat_f9->f1[3] * mat_f9->f1[1]) * invDet; + } +} + +// compute remap tables from the camera matrix and distortion coefficients +__global__ void compute_remap_tables_hip_tensor(float *rowRemapTable, + float *colRemapTable, + d_float9 *cameraMatrixTensor, + d_float9 *inverseMatrixTensor, + d_float8 *distortionCoeffsTensor, + uint2 remapTableStridesNH, + RpptROIPtr roiTensorPtrSrc) +{ + int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8; + int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; + int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z; + + if ((id_y >= roiTensorPtrSrc[id_z].xywhROI.roiHeight) || (id_x >= roiTensorPtrSrc[id_z].xywhROI.roiWidth)) + return; + + d_float9 cameraMatrix_f9 = cameraMatrixTensor[id_z]; + d_float9 inverseMatrix_f9 = inverseMatrixTensor[id_z]; + d_float8 distortionCoeffs_f8 = distortionCoeffsTensor[id_z]; + + // Get radial and tangential distortion coefficients + float radialCoeff[6] = {distortionCoeffs_f8.f1[0], distortionCoeffs_f8.f1[1], distortionCoeffs_f8.f1[4], distortionCoeffs_f8.f1[5], distortionCoeffs_f8.f1[6], distortionCoeffs_f8.f1[7]}; + float tangentialCoeff[2] = {distortionCoeffs_f8.f1[2], distortionCoeffs_f8.f1[3]}; + + uint dstIdx = id_z * remapTableStridesNH.x + id_y * remapTableStridesNH.y + id_x; + d_float8 locDst_f8x; + locDst_f8x.f4[0] = static_cast(id_x) + make_float4(0, 1, 2, 3); + locDst_f8x.f4[1] = static_cast(id_x) + make_float4(4, 5, 6, 7); + + float4 one_f4 = static_cast(1.0f); + float4 two_f4 = static_cast(2.0f); + d_float8 z_f8, y_f8, x_f8; + camera_coordinates_hip_compute(&z_f8, id_y, &locDst_f8x, &inverseMatrix_f9.f3[2]); // float zCamera = id_y * inverseMatrix.f1[7] + inverseMatrix.f1[8] + id_x * inverseMatrix.f1[6] + camera_coordinates_hip_compute(&y_f8, id_y, &locDst_f8x, &inverseMatrix_f9.f3[1]); // float yCamera = id_y * inverseMatrix.f1[4] + inverseMatrix.f1[5] + id_x * inverseMatrix.f1[3] + camera_coordinates_hip_compute(&x_f8, id_y, &locDst_f8x, &inverseMatrix_f9.f3[0]); // float xCamera = id_y * inverseMatrix.f1[1] + inverseMatrix.f1[2] + id_x * inverseMatrix.f1[0] + rpp_hip_math_divide8_const(&z_f8, &z_f8, one_f4); // float z = 1./zCamera + rpp_hip_math_multiply8(&y_f8, &z_f8, &y_f8); // float y = yCamera * z; + rpp_hip_math_multiply8(&x_f8, &z_f8, &x_f8); // float x = xCamera * z; + + d_float8 ySquare_f8, xSquare_f8; + rpp_hip_math_multiply8(&y_f8, &y_f8, &ySquare_f8); // float ySquare = x * x + rpp_hip_math_multiply8(&x_f8, &x_f8, &xSquare_f8); // float xSquare = x * x + + d_float8 r2_f8, kr_f8, kr1_f8, kr2_f8; + rpp_hip_math_add8(&xSquare_f8, &ySquare_f8, &r2_f8); // float r2 = xSquare + ySquare + + d_float8 r2Cube_f8, r2Square_f8; + rpp_hip_math_multiply8(&r2_f8, &r2_f8, &r2Square_f8); // float r2Square = r2 * r2; + rpp_hip_math_multiply8(&r2Square_f8, &r2_f8, &r2Cube_f8); // float r2Cube = r2Square * r2; + + d_float24 radialCoeff_f24; + radialCoeff_f24.f4[0] = static_cast(radialCoeff[0]); + radialCoeff_f24.f4[1] = static_cast(radialCoeff[1]); + radialCoeff_f24.f4[2] = static_cast(radialCoeff[2]); + radialCoeff_f24.f4[3] = static_cast(radialCoeff[3]); + radialCoeff_f24.f4[4] = static_cast(radialCoeff[4]); + radialCoeff_f24.f4[5] = static_cast(radialCoeff[5]); + + // float kr = (1 + (radialCoeff[2] * r2Cube) + (radialCoeff[1] * r2Square) + (radialCoeff[0]) * r2)) / (1 + (radialCoeff[5] * r2Cube) + (radialCoeff[4] * r2Square) + (radialCoeff[3]) *r2)) + kr1_f8.f4[0] = (one_f4 + (radialCoeff_f24.f4[2] * r2Cube_f8.f4[0]) + (radialCoeff_f24.f4[1] * r2Square_f8.f4[0]) + (radialCoeff_f24.f4[0] * r2_f8.f4[0])); + kr1_f8.f4[1] = (one_f4 + (radialCoeff_f24.f4[2] * r2Cube_f8.f4[1]) + (radialCoeff_f24.f4[1] * r2Square_f8.f4[1]) + (radialCoeff_f24.f4[0] * r2_f8.f4[1])); + kr2_f8.f4[0] = (one_f4 + (radialCoeff_f24.f4[5] * r2Cube_f8.f4[0]) + (radialCoeff_f24.f4[4] * r2Square_f8.f4[0]) + (radialCoeff_f24.f4[3] * r2_f8.f4[0])); + kr2_f8.f4[1] = (one_f4 + (radialCoeff_f24.f4[5] * r2Cube_f8.f4[1]) + (radialCoeff_f24.f4[4] * r2Square_f8.f4[1]) + (radialCoeff_f24.f4[3] * r2_f8.f4[1])); + rpp_hip_math_divide8(&kr1_f8, &kr2_f8, &kr_f8); + + d_float8 xyMul2_f8; + rpp_hip_math_multiply8(&x_f8, &y_f8, &xyMul2_f8); + rpp_hip_math_multiply8_const(&xyMul2_f8, &xyMul2_f8, two_f4); // float xyMul2 = 2 * x * y + + d_float8 colLoc_f8, rowLoc_f8; + rpp_hip_math_multiply8_const(&xSquare_f8, &xSquare_f8, two_f4); // xSquare = xSquare * 2; + rpp_hip_math_multiply8_const(&ySquare_f8, &ySquare_f8, two_f4); // ySquare = ySquare * 2; + + d_float16 cameraMatrix_f16; + cameraMatrix_f16.f4[0] = static_cast(cameraMatrix_f9.f1[0]); + cameraMatrix_f16.f4[1] = static_cast(cameraMatrix_f9.f1[2]); + cameraMatrix_f16.f4[2] = static_cast(cameraMatrix_f9.f1[4]); + cameraMatrix_f16.f4[3] = static_cast(cameraMatrix_f9.f1[5]); + + d_float8 tangentialCoeff_f8; + tangentialCoeff_f8.f4[0] = static_cast(tangentialCoeff[0]); + tangentialCoeff_f8.f4[1] = static_cast(tangentialCoeff[1]); + + // float colLoc = cameraMatrix[0] * (x * kr + tangentialCoeff[0] * xyMul2 + tangentialCoeff[1] * (r2 + 2 * xSquare)) + cameraMatrix[2]; + colLoc_f8.f4[0] = cameraMatrix_f16.f4[0] * ((x_f8.f4[0] * kr_f8.f4[0]) + (tangentialCoeff_f8.f4[0] * xyMul2_f8.f4[0]) + (tangentialCoeff_f8.f4[1] * (r2_f8.f4[0] + xSquare_f8.f4[0]))) + cameraMatrix_f16.f4[1]; + colLoc_f8.f4[1] = cameraMatrix_f16.f4[0] * ((x_f8.f4[1] * kr_f8.f4[1]) + (tangentialCoeff_f8.f4[0] * xyMul2_f8.f4[1]) + (tangentialCoeff_f8.f4[1] * (r2_f8.f4[1] + xSquare_f8.f4[1]))) + cameraMatrix_f16.f4[1]; + + // float rowLoc = cameraMatrix[4] * (y * kr + tangentialCoeff[1] * xyMul2 + tangentialCoeff[0] * (r2 + 2 * ySquare)) + cameraMatrix[4]; + rowLoc_f8.f4[0] = cameraMatrix_f16.f4[2] * ((y_f8.f4[0] * kr_f8.f4[0]) + (tangentialCoeff_f8.f4[1] * xyMul2_f8.f4[0]) + (tangentialCoeff_f8.f4[0] * (r2_f8.f4[0] + ySquare_f8.f4[0]))) + cameraMatrix_f16.f4[3]; + rowLoc_f8.f4[1] = cameraMatrix_f16.f4[2] * ((y_f8.f4[1] * kr_f8.f4[1]) + (tangentialCoeff_f8.f4[1] * xyMul2_f8.f4[1]) + (tangentialCoeff_f8.f4[0] * (r2_f8.f4[1] + ySquare_f8.f4[1]))) + cameraMatrix_f16.f4[3]; + + rpp_hip_pack_float8_and_store8(colRemapTable + dstIdx, &colLoc_f8); + rpp_hip_pack_float8_and_store8(rowRemapTable + dstIdx, &rowLoc_f8); +} + +// -------------------- Set 2 - Kernel Executors -------------------- + +RppStatus hip_exec_lens_correction_tensor(RpptDescPtr dstDescPtr, + Rpp32f *rowRemapTable, + Rpp32f *colRemapTable, + RpptDescPtr remapTableDescPtr, + Rpp32f *cameraMatrix, + Rpp32f *distanceCoeffs, + RpptROIPtr roiTensorPtrSrc, + RpptRoiType roiType, + rpp::Handle& handle) +{ + if (roiType == RpptRoiType::LTRB) + hip_exec_roi_converison_ltrb_to_xywh(roiTensorPtrSrc, handle); + + int globalThreads_x = (dstDescPtr->w + 7) >> 3; + int globalThreads_y = dstDescPtr->h; + int globalThreads_z = dstDescPtr->n; + + float *inverseMatrix = handle.GetInitHandle()->mem.mgpu.scratchBufferHip.floatmem; + hipLaunchKernelGGL(compute_inverse_matrix_hip_tensor, + dim3(1, 1, ceil((float)globalThreads_z/LOCAL_THREADS_Z)), + dim3(1, 1, LOCAL_THREADS_Z), + 0, + handle.GetStream(), + reinterpret_cast(cameraMatrix), + reinterpret_cast(inverseMatrix)); + hipLaunchKernelGGL(compute_remap_tables_hip_tensor, + dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)), + dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), + 0, + handle.GetStream(), + rowRemapTable, + colRemapTable, + reinterpret_cast(cameraMatrix), + reinterpret_cast(inverseMatrix), + reinterpret_cast(distanceCoeffs), + make_uint2(remapTableDescPtr->strides.nStride, remapTableDescPtr->strides.hStride), + roiTensorPtrSrc); + + return RPP_SUCCESS; +} \ No newline at end of file diff --git a/src/modules/rppt_tensor_geometric_augmentations.cpp b/src/modules/rppt_tensor_geometric_augmentations.cpp index 32bfea84f..325881c54 100644 --- a/src/modules/rppt_tensor_geometric_augmentations.cpp +++ b/src/modules/rppt_tensor_geometric_augmentations.cpp @@ -1300,6 +1300,91 @@ RppStatus rppt_remap_host(RppPtr_t srcPtr, return RPP_SUCCESS; } +/******************** lens_correction ********************/ + +RppStatus rppt_lens_correction_host(RppPtr_t srcPtr, + RpptDescPtr srcDescPtr, + RppPtr_t dstPtr, + RpptDescPtr dstDescPtr, + Rpp32f *rowRemapTable, + Rpp32f *colRemapTable, + RpptDescPtr tableDescPtr, + Rpp32f *cameraMatrixTensor, + Rpp32f *distortionCoeffsTensor, + RpptROIPtr roiTensorPtrSrc, + RpptRoiType roiType, + rppHandle_t rppHandle) +{ + RppLayoutParams layoutParams = get_layout_params(srcDescPtr->layout, srcDescPtr->c); + compute_lens_correction_remap_tables_host_tensor(srcDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + cameraMatrixTensor, + distortionCoeffsTensor, + roiTensorPtrSrc, + rpp::deref(rppHandle)); + + if ((srcDescPtr->dataType == RpptDataType::U8) && (dstDescPtr->dataType == RpptDataType::U8)) + { + remap_bilinear_u8_u8_host_tensor(static_cast(srcPtr) + srcDescPtr->offsetInBytes, + srcDescPtr, + static_cast(dstPtr) + dstDescPtr->offsetInBytes, + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + roiTensorPtrSrc, + roiType, + layoutParams, + rpp::deref(rppHandle)); + } + else if ((srcDescPtr->dataType == RpptDataType::F16) && (dstDescPtr->dataType == RpptDataType::F16)) + { + remap_bilinear_f16_f16_host_tensor(reinterpret_cast(static_cast(srcPtr) + srcDescPtr->offsetInBytes), + srcDescPtr, + reinterpret_cast(static_cast(dstPtr) + dstDescPtr->offsetInBytes), + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + roiTensorPtrSrc, + roiType, + layoutParams, + rpp::deref(rppHandle)); + } + else if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32)) + { + remap_bilinear_f32_f32_host_tensor(reinterpret_cast(static_cast(srcPtr) + srcDescPtr->offsetInBytes), + srcDescPtr, + reinterpret_cast(static_cast(dstPtr) + dstDescPtr->offsetInBytes), + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + roiTensorPtrSrc, + roiType, + layoutParams, + rpp::deref(rppHandle)); + } + else if ((srcDescPtr->dataType == RpptDataType::I8) && (dstDescPtr->dataType == RpptDataType::I8)) + { + remap_bilinear_i8_i8_host_tensor(static_cast(srcPtr) + srcDescPtr->offsetInBytes, + srcDescPtr, + static_cast(dstPtr) + dstDescPtr->offsetInBytes, + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + roiTensorPtrSrc, + roiType, + layoutParams, + rpp::deref(rppHandle)); + } + + return RPP_SUCCESS; +} + /******************** transpose ********************/ RppStatus rppt_transpose_host(RppPtr_t srcPtr, @@ -2303,6 +2388,94 @@ RppStatus rppt_remap_gpu(RppPtr_t srcPtr, #endif // backend } +/******************** lens_correction ********************/ + +RppStatus rppt_lens_correction_gpu(RppPtr_t srcPtr, + RpptDescPtr srcDescPtr, + RppPtr_t dstPtr, + RpptDescPtr dstDescPtr, + Rpp32f *rowRemapTable, + Rpp32f *colRemapTable, + RpptDescPtr tableDescPtr, + Rpp32f *cameraMatrixTensor, + Rpp32f *distortionCoeffsTensor, + RpptROIPtr roiTensorPtrSrc, + RpptRoiType roiType, + rppHandle_t rppHandle) +{ +#ifdef HIP_COMPILE + hip_exec_lens_correction_tensor(dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + cameraMatrixTensor, + distortionCoeffsTensor, + roiTensorPtrSrc, + roiType, + rpp::deref(rppHandle)); + + if ((srcDescPtr->dataType == RpptDataType::U8) && (dstDescPtr->dataType == RpptDataType::U8)) + { + hip_exec_remap_tensor(static_cast(srcPtr) + srcDescPtr->offsetInBytes, + srcDescPtr, + static_cast(dstPtr) + dstDescPtr->offsetInBytes, + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + RpptInterpolationType::BILINEAR, + roiTensorPtrSrc, + roiType, + rpp::deref(rppHandle)); + } + else if ((srcDescPtr->dataType == RpptDataType::F16) && (dstDescPtr->dataType == RpptDataType::F16)) + { + hip_exec_remap_tensor(reinterpret_cast(static_cast(srcPtr) + srcDescPtr->offsetInBytes), + srcDescPtr, + reinterpret_cast(static_cast(dstPtr) + dstDescPtr->offsetInBytes), + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + RpptInterpolationType::BILINEAR, + roiTensorPtrSrc, + roiType, + rpp::deref(rppHandle)); + } + else if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32)) + { + hip_exec_remap_tensor(reinterpret_cast(static_cast(srcPtr) + srcDescPtr->offsetInBytes), + srcDescPtr, + reinterpret_cast(static_cast(dstPtr) + dstDescPtr->offsetInBytes), + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + RpptInterpolationType::BILINEAR, + roiTensorPtrSrc, + roiType, + rpp::deref(rppHandle)); + } + else if ((srcDescPtr->dataType == RpptDataType::I8) && (dstDescPtr->dataType == RpptDataType::I8)) + { + hip_exec_remap_tensor(static_cast(srcPtr) + srcDescPtr->offsetInBytes, + srcDescPtr, + static_cast(dstPtr) + dstDescPtr->offsetInBytes, + dstDescPtr, + rowRemapTable, + colRemapTable, + tableDescPtr, + RpptInterpolationType::BILINEAR, + roiTensorPtrSrc, + roiType, + rpp::deref(rppHandle)); + } + return RPP_SUCCESS; +#elif defined(OCL_COMPILE) + return RPP_ERROR_NOT_IMPLEMENTED; +#endif // backend +} + /******************** transpose ********************/ RppStatus rppt_transpose_gpu(RppPtr_t srcPtr, diff --git a/utilities/test_suite/HIP/Tensor_hip.cpp b/utilities/test_suite/HIP/Tensor_hip.cpp index 025dd712c..aad78241e 100644 --- a/utilities/test_suite/HIP/Tensor_hip.cpp +++ b/utilities/test_suite/HIP/Tensor_hip.cpp @@ -366,10 +366,19 @@ int main(int argc, char **argv) CHECK_RETURN_STATUS(hipHostMalloc(&roiPtrInputCropRegion, 4 * sizeof(RpptROI))); void *d_rowRemapTable, *d_colRemapTable; - if(testCase == 79) + if(testCase == 26 || testCase == 79) { CHECK_RETURN_STATUS(hipMalloc(&d_rowRemapTable, ioBufferSize * sizeof(Rpp32u))); CHECK_RETURN_STATUS(hipMalloc(&d_colRemapTable, ioBufferSize * sizeof(Rpp32u))); + CHECK_RETURN_STATUS(hipMemset(d_rowRemapTable, 0, ioBufferSize * sizeof(Rpp32u))); + CHECK_RETURN_STATUS(hipMemset(d_colRemapTable, 0, ioBufferSize * sizeof(Rpp32u))); + } + + Rpp32f *cameraMatrix, *distortionCoeffs; + if(testCase == 26) + { + CHECK_RETURN_STATUS(hipHostMalloc(&cameraMatrix, batchSize * 9 * sizeof(Rpp32f))); + CHECK_RETURN_STATUS(hipHostMalloc(&distortionCoeffs, batchSize * 8 * sizeof(Rpp32f))); } Rpp32u boxesInEachImage = 3; @@ -700,6 +709,22 @@ int main(int argc, char **argv) break; } + case 26: + { + testCaseName = "lens_correction"; + + RpptDesc tableDesc = srcDesc; + RpptDescPtr tableDescPtr = &tableDesc; + init_lens_correction(batchSize, srcDescPtr, cameraMatrix, distortionCoeffs, tableDescPtr); + + startWallTime = omp_get_wtime(); + if (inputBitDepth == 0 || inputBitDepth == 1 || inputBitDepth == 2 || inputBitDepth == 5) + rppt_lens_correction_gpu(d_input, srcDescPtr, d_output, dstDescPtr, static_cast(d_rowRemapTable), static_cast(d_colRemapTable), tableDescPtr, cameraMatrix, distortionCoeffs, roiTensorPtrSrc, roiTypeSrc, handle); + else + missingFuncFlag = 1; + + break; + } case 29: { testCaseName = "water"; @@ -1546,6 +1571,18 @@ int main(int argc, char **argv) CHECK_RETURN_STATUS(hipHostFree(cropRoi)); CHECK_RETURN_STATUS(hipHostFree(patchRoi)); } + if(testCase == 26) + { + CHECK_RETURN_STATUS(hipHostFree(cameraMatrix)); + CHECK_RETURN_STATUS(hipHostFree(distortionCoeffs)); + } + if(testCase == 79) + { + free(rowRemapTable); + free(colRemapTable); + CHECK_RETURN_STATUS(hipFree(d_rowRemapTable)); + CHECK_RETURN_STATUS(hipFree(d_colRemapTable)); + } if(testCase == 35) CHECK_RETURN_STATUS(hipHostFree(rgbOffsets)); if (reductionTypeCase) @@ -1572,13 +1609,6 @@ int main(int argc, char **argv) free(inputu8); free(inputu8Second); free(outputu8); - if(testCase == 79) - { - free(rowRemapTable); - free(colRemapTable); - CHECK_RETURN_STATUS(hipFree(d_rowRemapTable)); - CHECK_RETURN_STATUS(hipFree(d_colRemapTable)); - } CHECK_RETURN_STATUS(hipFree(d_input)); if(dualInputCase) CHECK_RETURN_STATUS(hipFree(d_input_second)); diff --git a/utilities/test_suite/HIP/runTests.py b/utilities/test_suite/HIP/runTests.py index 629da6ae1..01da79c8d 100644 --- a/utilities/test_suite/HIP/runTests.py +++ b/utilities/test_suite/HIP/runTests.py @@ -35,6 +35,7 @@ inFilePath1 = scriptPath + "/../TEST_IMAGES/three_images_mixed_src1" inFilePath2 = scriptPath + "/../TEST_IMAGES/three_images_mixed_src2" ricapInFilePath = scriptPath + "/../TEST_IMAGES/three_images_150x150_src1" +lensCorrectionInFilePath = scriptPath + "/../TEST_IMAGES/lens_distortion" qaInputFile = scriptPath + "/../TEST_IMAGES/three_images_mixed_src1" outFolderPath = os.getcwd() buildFolderPath = os.getcwd() @@ -275,7 +276,7 @@ def rpp_test_suite_parser_and_validator(): subprocess.run(["make", "-j16"], cwd=".") # nosec # List of cases supported -supportedCaseList = ['0', '1', '2', '4', '8', '13', '20', '21', '23', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '45', '46', '54', '61', '63', '65', '68', '70', '79', '80', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92'] +supportedCaseList = ['0', '1', '2', '4', '8', '13', '20', '21', '23', '26', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '45', '46', '54', '61', '63', '65', '68', '70', '79', '80', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92'] # Create folders based on testType and profilingOption if testType == 1 and profilingOption == "YES": @@ -295,8 +296,11 @@ def rpp_test_suite_parser_and_validator(): if case == "82" and (("--input_path1" not in sys.argv and "--input_path2" not in sys.argv) or qaMode == 1): srcPath1 = ricapInFilePath srcPath2 = ricapInFilePath + if case == "26" and (("--input_path1" not in sys.argv and "--input_path2" not in sys.argv) or qaMode == 1): + srcPath1 = lensCorrectionInFilePath + srcPath2 = lensCorrectionInFilePath # if QA mode is enabled overwrite the input folders with the folders used for generating golden outputs - if qaMode == 1 and case != "82": + if qaMode == 1 and (case != "82" and case != "26"): srcPath1 = inFilePath1 srcPath2 = inFilePath2 for layout in range(3): @@ -319,6 +323,9 @@ def rpp_test_suite_parser_and_validator(): if case == "82" and "--input_path1" not in sys.argv and "--input_path2" not in sys.argv: srcPath1 = ricapInFilePath srcPath2 = ricapInFilePath + if case == "26" and "--input_path1" not in sys.argv and "--input_path2" not in sys.argv: + srcPath1 = lensCorrectionInFilePath + srcPath2 = lensCorrectionInFilePath for layout in range(3): dstPathTemp, logFileLayout = process_layout(layout, qaMode, case, dstPath, "hip", func_group_finder) @@ -333,6 +340,9 @@ def rpp_test_suite_parser_and_validator(): if case == "82" and "--input_path1" not in sys.argv and "--input_path2" not in sys.argv: srcPath1 = ricapInFilePath srcPath2 = ricapInFilePath + if case == "26" and "--input_path1" not in sys.argv and "--input_path2" not in sys.argv: + srcPath1 = lensCorrectionInFilePath + srcPath2 = lensCorrectionInFilePath for layout in range(3): dstPathTemp, logFileLayout = process_layout(layout, qaMode, case, dstPath, "hip", func_group_finder) diff --git a/utilities/test_suite/HOST/Tensor_host.cpp b/utilities/test_suite/HOST/Tensor_host.cpp index 8f484c863..4c3d4f0e8 100644 --- a/utilities/test_suite/HOST/Tensor_host.cpp +++ b/utilities/test_suite/HOST/Tensor_host.cpp @@ -317,7 +317,7 @@ int main(int argc, char **argv) output = static_cast(calloc(outputBufferSize, 1)); Rpp32f *rowRemapTable, *colRemapTable; - if(testCase == 79) + if(testCase == 79 || testCase == 26) { rowRemapTable = static_cast(calloc(ioBufferSize, sizeof(Rpp32f))); colRemapTable = static_cast(calloc(ioBufferSize, sizeof(Rpp32f))); @@ -672,6 +672,25 @@ int main(int argc, char **argv) break; } + case 26: + { + testCaseName = "lens_correction"; + + Rpp32f cameraMatrix[9 * batchSize]; + Rpp32f distortionCoeffs[8 * batchSize]; + RpptDesc tableDesc = srcDesc; + RpptDescPtr tableDescPtr = &tableDesc; + init_lens_correction(batchSize, srcDescPtr, cameraMatrix, distortionCoeffs, tableDescPtr); + + startWallTime = omp_get_wtime(); + startCpuTime = clock(); + if (inputBitDepth == 0 || inputBitDepth == 1 || inputBitDepth == 2 || inputBitDepth == 5) + rppt_lens_correction_host(input, srcDescPtr, output, dstDescPtr, rowRemapTable, colRemapTable, tableDescPtr, cameraMatrix, distortionCoeffs, roiTensorPtrSrc, roiTypeSrc, handle); + else + missingFuncFlag = 1; + + break; + } case 29: { testCaseName = "water"; diff --git a/utilities/test_suite/HOST/runTests.py b/utilities/test_suite/HOST/runTests.py index b38a08757..93cd64713 100644 --- a/utilities/test_suite/HOST/runTests.py +++ b/utilities/test_suite/HOST/runTests.py @@ -34,6 +34,7 @@ inFilePath1 = scriptPath + "/../TEST_IMAGES/three_images_mixed_src1" inFilePath2 = scriptPath + "/../TEST_IMAGES/three_images_mixed_src2" ricapInFilePath = scriptPath + "/../TEST_IMAGES/three_images_150x150_src1" +lensCorrectionInFilePath = scriptPath + "/../TEST_IMAGES/lens_distortion" qaInputFile = scriptPath + "/../TEST_IMAGES/three_images_mixed_src1" perfQaInputFile = scriptPath + "/../TEST_IMAGES/eight_images_mixed_src1" outFolderPath = os.getcwd() @@ -257,7 +258,7 @@ def rpp_test_suite_parser_and_validator(): subprocess.run(["make", "-j16"], cwd=".") # nosec # List of cases supported -supportedCaseList = ['0', '1', '2', '4', '8', '13', '20', '21', '23', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '45', '46', '54', '61', '63', '65', '68', '70', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92'] +supportedCaseList = ['0', '1', '2', '4', '8', '13', '20', '21', '23', '26', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '45', '46', '54', '61', '63', '65', '68', '70', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92'] print("\n\n\n\n\n") print("##########################################################################################") @@ -271,8 +272,11 @@ def rpp_test_suite_parser_and_validator(): if case == "82" and (("--input_path1" not in sys.argv and "--input_path2" not in sys.argv) or qaMode == 1): srcPath1 = ricapInFilePath srcPath2 = ricapInFilePath + if case == "26" and (("--input_path1" not in sys.argv and "--input_path2" not in sys.argv) or qaMode == 1): + srcPath1 = lensCorrectionInFilePath + srcPath2 = lensCorrectionInFilePath # if QA mode is enabled overwrite the input folders with the folders used for generating golden outputs - if qaMode == 1 and case != "82": + if qaMode == 1 and (case != "82" and case != "26"): srcPath1 = inFilePath1 srcPath2 = inFilePath2 for layout in range(3): @@ -297,6 +301,9 @@ def rpp_test_suite_parser_and_validator(): if case == "82" and "--input_path1" not in sys.argv and "--input_path2" not in sys.argv: srcPath1 = ricapInFilePath srcPath2 = ricapInFilePath + if case == "26" and "--input_path1" not in sys.argv and "--input_path2" not in sys.argv: + srcPath1 = lensCorrectionInFilePath + srcPath2 = lensCorrectionInFilePath for layout in range(3): dstPathTemp, logFileLayout = process_layout(layout, qaMode, case, dstPath, "host", func_group_finder) run_performance_test(loggingFolder, logFileLayout, srcPath1, srcPath2, dstPath, case, numRuns, testType, layout, qaMode, decoderType, batchSize, roiList) diff --git a/utilities/test_suite/REFERENCE_OUTPUT/lens_correction/lens_correction_u8_Tensor.bin b/utilities/test_suite/REFERENCE_OUTPUT/lens_correction/lens_correction_u8_Tensor.bin new file mode 100644 index 000000000..e79550932 Binary files /dev/null and b/utilities/test_suite/REFERENCE_OUTPUT/lens_correction/lens_correction_u8_Tensor.bin differ diff --git a/utilities/test_suite/TEST_IMAGES/lens_distortion/sample1.jpg b/utilities/test_suite/TEST_IMAGES/lens_distortion/sample1.jpg new file mode 100644 index 000000000..0fe764603 Binary files /dev/null and b/utilities/test_suite/TEST_IMAGES/lens_distortion/sample1.jpg differ diff --git a/utilities/test_suite/TEST_IMAGES/lens_distortion/sample2.jpg b/utilities/test_suite/TEST_IMAGES/lens_distortion/sample2.jpg new file mode 100644 index 000000000..5d17e9572 Binary files /dev/null and b/utilities/test_suite/TEST_IMAGES/lens_distortion/sample2.jpg differ diff --git a/utilities/test_suite/TEST_IMAGES/lens_distortion/sample3.jpg b/utilities/test_suite/TEST_IMAGES/lens_distortion/sample3.jpg new file mode 100644 index 000000000..897955d77 Binary files /dev/null and b/utilities/test_suite/TEST_IMAGES/lens_distortion/sample3.jpg differ diff --git a/utilities/test_suite/common.py b/utilities/test_suite/common.py index e32aac98f..527b40ead 100644 --- a/utilities/test_suite/common.py +++ b/utilities/test_suite/common.py @@ -332,7 +332,7 @@ def get_layout_name(layout): # Prints entire case list if user asks for help def print_case_list(imageAugmentationMap, backendType, parser): - if '--help' or '-h' in sys.argv: + if '--help' in sys.argv or '-h' in sys.argv: parser.print_help() print("\n" + "="*30) print("Functionality Reference List") diff --git a/utilities/test_suite/rpp_test_suite_common.h b/utilities/test_suite/rpp_test_suite_common.h index 46965e9c3..71ca9fb34 100644 --- a/utilities/test_suite/rpp_test_suite_common.h +++ b/utilities/test_suite/rpp_test_suite_common.h @@ -58,6 +58,8 @@ using namespace std; #define MAX_BATCH_SIZE 512 #define GOLDEN_OUTPUT_MAX_HEIGHT 150 // Golden outputs are generated with MAX_HEIGHT set to 150. Changing this constant will result in QA test failures #define GOLDEN_OUTPUT_MAX_WIDTH 150 // Golden outputs are generated with MAX_WIDTH set to 150. Changing this constant will result in QA test failures +#define LENS_CORRECTION_GOLDEN_OUTPUT_MAX_HEIGHT 480 // Lens correction golden outputs are generated with MAX_HEIGHT set to 480. Changing this constant will result in QA test failures +#define LENS_CORRECTION_GOLDEN_OUTPUT_MAX_WIDTH 640 // Lens correction golden outputs are generated with MAX_WIDTH set to 640. Changing this constant will result in QA test failures #define CHECK_RETURN_STATUS(x) do { \ int retval = (x); \ @@ -78,6 +80,7 @@ std::map augmentationMap = {20, "flip"}, {21, "resize"}, {23, "rotate"}, + {26, "lens_correction"}, {29, "water"}, {30, "non_linear_blend"}, {31, "color_cast"}, @@ -1093,8 +1096,17 @@ inline void compare_output(T* output, string funcName, RpptDescPtr srcDescPtr, R { string func = funcName; string refFile = ""; - int refOutputWidth = ((GOLDEN_OUTPUT_MAX_WIDTH / 8) * 8) + 8; // obtain next multiple of 8 after GOLDEN_OUTPUT_MAX_WIDTH - int refOutputHeight = GOLDEN_OUTPUT_MAX_HEIGHT; + int refOutputWidth, refOutputHeight; + if(testCase == 26) + { + refOutputWidth = ((LENS_CORRECTION_GOLDEN_OUTPUT_MAX_WIDTH / 8) * 8) + 8; // obtain next multiple of 8 after GOLDEN_OUTPUT_MAX_WIDTH + refOutputHeight = LENS_CORRECTION_GOLDEN_OUTPUT_MAX_HEIGHT; + } + else + { + refOutputWidth = ((GOLDEN_OUTPUT_MAX_WIDTH / 8) * 8) + 8; // obtain next multiple of 8 after GOLDEN_OUTPUT_MAX_WIDTH + refOutputHeight = GOLDEN_OUTPUT_MAX_HEIGHT; + } int refOutputSize = refOutputHeight * refOutputWidth * dstDescPtr->c; Rpp64u binOutputSize = refOutputHeight * refOutputWidth * dstDescPtr->n * 4; int pln1RefStride = dstDescPtr->strides.nStride * dstDescPtr->n * 3; @@ -1527,3 +1539,24 @@ void inline init_erase(int batchSize, int boxesInEachImage, Rpp32u* numOfBoxes, } } } + +// Lens correction initializer for unit and performance testing +void inline init_lens_correction(int batchSize, RpptDescPtr srcDescPtr, Rpp32f *cameraMatrix, Rpp32f *distortionCoeffs, RpptDescPtr tableDescPtr) +{ + typedef struct { Rpp32f data[9]; } Rpp32f9; + typedef struct { Rpp32f data[8]; } Rpp32f8; + Rpp32f9 *cameraMatrix_f9 = reinterpret_cast(cameraMatrix); + Rpp32f8 *distortionCoeffs_f8 = reinterpret_cast(distortionCoeffs); + Rpp32f9 sampleCameraMatrix = {534.07088364, 0, 341.53407554, 0, 534.11914595, 232.94565259, 0, 0, 1}; + Rpp32f8 sampleDistortionCoeffs = {-0.29297164, 0.10770696, 0.00131038, -0.0000311, 0.0434798, 0, 0, 0}; + for (int i = 0; i < batchSize; i++) + { + cameraMatrix_f9[i] = sampleCameraMatrix; + distortionCoeffs_f8[i] = sampleDistortionCoeffs; + } + + tableDescPtr->c = 1; + tableDescPtr->strides.nStride = srcDescPtr->h * srcDescPtr->w; + tableDescPtr->strides.hStride = srcDescPtr->w; + tableDescPtr->strides.wStride = tableDescPtr->strides.cStride = 1; +}