diff --git a/include/matx/core/half.h b/include/matx/core/half.h index 9fbd9e03..adc4a1e4 100644 --- a/include/matx/core/half.h +++ b/include/matx/core/half.h @@ -701,7 +701,11 @@ __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf rsqrt(const matxHalf(::rsqrt(static_cast(x.x))); + #ifdef __CUDACC__ + return static_cast<__nv_bfloat16>(::rsqrt(static_cast(x.x))); + #else + return static_cast<__nv_bfloat16>(1.f / cuda::std::sqrt(static_cast(x.x))); + #endif #endif } @@ -719,7 +723,11 @@ rsqrt(const matxHalf<__nv_bfloat16> &x) #if __CUDA_ARCH__ >= 800 return hrsqrt(x.x); #else - return static_cast<__nv_bfloat16>(::rsqrt(static_cast(x.x))); + #ifdef __CUDACC__ + return static_cast<__nv_bfloat16>(::rsqrt(static_cast(x.x))); + #else + return static_cast<__nv_bfloat16>(1.f / cuda::std::sqrt(static_cast(x.x))); + #endif #endif } diff --git a/include/matx/core/storage.h b/include/matx/core/storage.h index 112b5d8f..06e6f3c6 100644 --- a/include/matx/core/storage.h +++ b/include/matx/core/storage.h @@ -282,7 +282,7 @@ namespace matx * @brief Default construct a smart_pointer_buffer. This should only be used when temporarily * creating an empty tensor for construction later. */ - smart_pointer_buffer() {}; + smart_pointer_buffer() {}; /** * @brief Construct a new smart pointer buffer from an existing object @@ -290,7 +290,7 @@ namespace matx * @param ptr Smart poiner object * @param size Size of allocation */ - smart_pointer_buffer(T &&ptr, size_t size) : data_(std::forward(ptr)), size_(size) { + smart_pointer_buffer(T &&ptr, size_t size) : data_(std::forward(ptr)), size_(size) { static_assert(is_smart_ptr_v); } @@ -328,7 +328,7 @@ namespace matx * * @param rhs Object to move from */ - smart_pointer_buffer(smart_pointer_buffer &&rhs) { + smart_pointer_buffer(smart_pointer_buffer &&rhs) { size_ = rhs.size_; data_ = std::move(rhs.data_); } @@ -337,7 +337,7 @@ namespace matx * @brief Default destructor * */ - ~smart_pointer_buffer() = default; + ~smart_pointer_buffer() = default; /** * @brief Get underlying data pointer diff --git a/include/matx/operators/scalar_ops.h b/include/matx/operators/scalar_ops.h index fa7e2716..ac743094 100644 --- a/include/matx/operators/scalar_ops.h +++ b/include/matx/operators/scalar_ops.h @@ -162,7 +162,11 @@ static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_rsqrt(T v1) return rsqrt(v1); } else { +#ifdef __CUDACC__ return ::rsqrt(v1); +#else + return static_cast(1) / sqrt(v1); +#endif } } template struct RSqrtF {