diff --git a/src/avx/math.rs b/src/avx/math.rs
index 627d3f6..6c27de4 100644
--- a/src/avx/math.rs
+++ b/src/avx/math.rs
@@ -3,20 +3,17 @@ use std::arch::x86::*;
 #[cfg(target_arch = "x86_64")]
 use std::arch::x86_64::*;

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 pub unsafe fn _mm256_cube_ps(x: __m256) -> __m256 {
     _mm256_mul_ps(_mm256_mul_ps(x, x), x)
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[cfg(not(target_feature = "fma"))]
 #[inline(always)]
 pub unsafe fn _mm256_prefer_fma_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
     return _mm256_add_ps(_mm256_mul_ps(b, c), a);
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[cfg(target_feature = "fma")]
 #[inline(always)]
 pub unsafe fn _mm256_prefer_fma_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
@@ -82,14 +79,12 @@ pub unsafe fn _mm256_select_ps(mask: __m256, true_vals: __m256, false_vals: __m2
     _mm256_blendv_ps(false_vals, true_vals, mask)
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_selecti_ps(mask: __m256i, true_vals: __m256, false_vals: __m256) -> __m256 {
     _mm256_blendv_ps(false_vals, true_vals, _mm256_castsi256_ps(mask))
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_select_si256(
@@ -202,21 +197,18 @@ unsafe fn _mm256_exp_ps_impl(x: __m256) -> __m256 {
     }
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_pow_ps(x: __m256, n: __m256) -> __m256 {
     _mm256_exp_ps(_mm256_mul_ps(n, _mm256_log_ps(x)))
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_pow_n_ps(x: __m256, n: f32) -> __m256 {
     _mm256_exp_ps(_mm256_mul_ps(_mm256_set1_ps(n), _mm256_log_ps(x)))
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_signbit_ps(f: __m256) -> __m256i {
@@ -226,7 +218,6 @@ pub unsafe fn _mm256_signbit_ps(f: __m256) -> __m256i {
     );
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_mulsign_ps(x: __m256, y: __m256) -> __m256 {
@@ -236,7 +227,6 @@ pub unsafe fn _mm256_mulsign_ps(x: __m256, y: __m256) -> __m256 {
     ));
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_pow2i_ps(q: __m256i) -> __m256 {
@@ -246,7 +236,6 @@ pub unsafe fn _mm256_pow2i_ps(q: __m256i) -> __m256 {
     )));
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_vldexp2_ps(d: __m256, e: __m256i) -> __m256 {
@@ -256,7 +245,6 @@ pub unsafe fn _mm256_vldexp2_ps(d: __m256, e: __m256i) -> __m256 {
     );
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm256_vilogbk_ps(d: __m256) -> __m256i {
@@ -281,14 +269,12 @@ pub unsafe fn _mm256_vilogbk_ps(d: __m256) -> __m256i {
     return q;
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm256_fmaf_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
     _mm256_prefer_fma_ps(c, b, a)
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm256_abs_ps(x: __m256) -> __m256 {
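The sign helpers touched in the hunks above (`_mm256_signbit_ps`, `_mm256_mulsign_ps`, `_mm256_abs_ps`) are plain IEEE-754 bit manipulation. A minimal scalar sketch of the same tricks, for orientation only (the function names here are illustrative and not part of the patch):

```rust
/// Scalar equivalents of the AVX sign-bit helpers above (illustrative only).
fn signbit(f: f32) -> u32 {
    f.to_bits() & 0x8000_0000 // isolate the sign bit, i.e. AND with the bits of -0.0
}

fn mulsign(x: f32, y: f32) -> f32 {
    f32::from_bits(x.to_bits() ^ signbit(y)) // flip x's sign exactly when y is negative
}

fn fabs(x: f32) -> f32 {
    f32::from_bits(x.to_bits() & !0x8000_0000) // clear the sign bit (ANDNOT with -0.0)
}
```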
@@ -296,7 +282,6 @@ pub(crate) unsafe fn _mm256_abs_ps(x: __m256) -> __m256 {
     return _mm256_andnot_ps(sign_mask, x);
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm256_neg_epi32(x: __m256i) -> __m256i {
@@ -309,7 +294,76 @@ pub(crate) unsafe fn _mm256_neg_epi32(x: __m256i) -> __m256i {
 /// it is also precise however due to of inexact nature of power 1/3 result slightly differ
 /// from real cbrt with about ULP 3-4, but this is almost 2 times faster than cbrt with real ULP 3.5
 pub unsafe fn _mm256_cbrt_ps(d: __m256) -> __m256 {
-    _mm256_pow_n_ps(d, 1f32 / 3f32)
+    _mm_cbrtq_f32_ulp2::<false>(d)
+}
+
+#[inline(always)]
+pub unsafe fn _mm256_cmpge_epi32(a: __m256i, b: __m256i) -> __m256i {
+    let gt = _mm256_cmpgt_epi32(a, b);
+    let eq = _mm256_cmpeq_epi32(a, b);
+    return _mm256_or_si256(gt, eq);
+}
+
+#[inline(always)]
+pub unsafe fn _mm256_cmplt_epi32(a: __m256i, b: __m256i) -> __m256i {
+    return _mm256_cmpgt_epi32(b, a);
+}
+
+#[inline(always)]
+/// Precise version of Cube Root with ULP 2
+pub unsafe fn _mm_cbrtq_f32_ulp2<const HANDLE_NAN: bool>(x: __m256) -> __m256 {
+    let x1p24 = _mm256_castsi256_ps(_mm256_set1_epi32(0x4b800000)); // 0x1p24f === 2 ^ 24
+
+    let mut ui = _mm256_castps_si256(x);
+    let hx = _mm256_and_si256(ui, _mm256_set1_epi32(0x7fffffff));
+
+    let nan_mask = _mm256_cmpge_epi32(hx, _mm256_set1_epi32(0x7f800000));
+    let is_zero_mask = _mm256_cmpeq_epi32(hx, _mm256_setzero_si256());
+
+    let lo_mask = _mm256_cmplt_epi32(hx, _mm256_set1_epi32(0x00800000));
+    let hi_ui_f = _mm256_castps_si256(_mm256_mul_ps(x, x1p24));
+    let mut lo_hx = _mm256_and_si256(hi_ui_f, _mm256_set1_epi32(0x7fffffff));
+    let recpreq_3 = _mm256_set1_ps(1f32 / 3f32);
+    lo_hx = _mm256_add_epi32(
+        _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(lo_hx), recpreq_3)),
+        _mm256_set1_epi32(642849266),
+    );
+    let hi_hx = _mm256_add_epi32(
+        _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(hx), recpreq_3)),
+        _mm256_set1_epi32(709958130),
+    );
+    let hx = _mm256_select_si256(lo_mask, lo_hx, hi_hx);
+
+    ui = _mm256_select_si256(lo_mask, hi_ui_f, ui);
+    ui = _mm256_and_si256(ui, _mm256_set1_epi32(-2147483648i32));
+    ui = _mm256_or_si256(ui, hx);
+
+    let mut t = _mm256_castsi256_ps(ui);
+    let mut r = _mm256_mul_ps(_mm256_mul_ps(t, t), t);
+
+    let sum_x = _mm256_add_ps(x, x);
+
+    t = _mm256_mul_ps(
+        _mm256_div_ps(
+            _mm256_add_ps(sum_x, r),
+            _mm256_add_ps(_mm256_add_ps(r, r), x),
+        ),
+        t,
+    );
+
+    r = _mm256_mul_ps(_mm256_mul_ps(t, t), t);
+    t = _mm256_mul_ps(
+        _mm256_div_ps(
+            _mm256_add_ps(sum_x, r),
+            _mm256_add_ps(_mm256_add_ps(r, r), x),
+        ),
+        t,
+    );
+    if HANDLE_NAN {
+        t = _mm256_selecti_ps(nan_mask, _mm256_set1_ps(f32::NAN), t);
+        t = _mm256_selecti_ps(is_zero_mask, _mm256_setzero_ps(), t);
+    }
+    t
 }

 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
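The new `_mm_cbrtq_f32_ulp2` is the classic bit-manipulation cube root (musl-style `cbrtf`): an initial guess built by dividing the magnitude bits by three and re-biasing, then two Halley refinements. A rough scalar sketch of the normal-input path, for comparison only (it omits the zero/subnormal/NaN handling that the vector code covers with `lo_mask`, `is_zero_mask` and `HANDLE_NAN`; the function name is illustrative):

```rust
/// Scalar sketch of the algorithm vectorized above; normal, non-zero inputs only.
fn cbrt_sketch(x: f32) -> f32 {
    const B1: u32 = 709958130; // same re-bias constant as the SIMD code
    // Initial guess: magnitude bits / 3 + B1, with x's sign bit re-attached.
    let hx = (x.to_bits() & 0x7fff_ffff) / 3 + B1;
    let mut t = f32::from_bits((x.to_bits() & 0x8000_0000) | hx);
    // Two Halley iterations: t <- t * (2x + t^3) / (2t^3 + x).
    for _ in 0..2 {
        let r = t * t * t;
        t *= (x + x + r) / (r + r + x);
    }
    t
}
```

The vector code approximates the integer division by three with a float multiply by `1f32 / 3f32` (`recpreq_3`), which slightly perturbs the initial guess but keeps the whole routine in SIMD registers.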
diff --git a/src/neon/math.rs b/src/neon/math.rs
index 48522d7..abd73a1 100644
--- a/src/neon/math.rs
+++ b/src/neon/math.rs
@@ -1,10 +1,6 @@
 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
 use std::arch::aarch64::*;

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn vfmodq_f32(a: float32x4_t, b: float32x4_t) -> float32x4_t {
@@ -17,10 +13,6 @@ pub(crate) unsafe fn vfmodq_f32(a: float32x4_t, b: float32x4_t) -> float32x4_t {
     remainder
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn prefer_vfmaq_f32(
@@ -38,10 +30,6 @@ pub(crate) unsafe fn prefer_vfmaq_f32(
     }
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 unsafe fn vtaylor_polyq_f32(
@@ -65,10 +53,6 @@ unsafe fn vtaylor_polyq_f32(
     return res;
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vrintq_s32(d: float32x4_t) -> int32x4_t {
@@ -260,20 +244,12 @@ pub unsafe fn vlogq_f32(x: float32x4_t) -> float32x4_t {
     return poly;
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn visnanq_f32(x: float32x4_t) -> uint32x4_t {
     return vmvnq_u32(vceqq_f32(x, x));
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vispinfq_f32(d: float32x4_t) -> uint32x4_t {
@@ -328,10 +304,6 @@ pub unsafe fn vpowq_n_f32(t: float32x4_t, power: f32) -> float32x4_t {
     return vpowq_f32(t, vdupq_n_f32(power));
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vilogbk_vi2_vf(d: float32x4_t) -> int32x4_t {
@@ -345,20 +317,12 @@ pub unsafe fn vilogbk_vi2_vf(d: float32x4_t) -> int32x4_t {
     return q;
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vpow2i(q: int32x4_t) -> float32x4_t {
     return vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(q, vdupq_n_s32(0x7f))));
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vldexp2q_f32(d: float32x4_t, e: int32x4_t) -> float32x4_t {
@@ -368,10 +332,6 @@ pub unsafe fn vldexp2q_f32(d: float32x4_t, e: int32x4_t) -> float32x4_t {
     );
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vsignbit_vm_vf(f: float32x4_t) -> uint32x4_t {
@@ -381,20 +341,12 @@ pub unsafe fn vsignbit_vm_vf(f: float32x4_t) -> uint32x4_t {
     );
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn vmulsignq_f32(x: float32x4_t, y: float32x4_t) -> float32x4_t {
     return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(x), vsignbit_vm_vf(y)));
 }

-#[cfg(all(
-    any(target_arch = "aarch64", target_arch = "arm"),
-    target_feature = "neon"
-))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn vmlafq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
@@ -402,7 +354,6 @@ pub(crate) unsafe fn vmlafq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t)
 }

 #[inline(always)]
-#[allow(dead_code)]
 pub unsafe fn vcbrtq_f32(d: float32x4_t) -> float32x4_t {
     vcbrtq_f32_ulp2::<false>(d)
 }
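For orientation, a hypothetical caller of the NEON entry point; the wrapper name and the slice handling are illustrative and not part of this patch:

```rust
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe fn cbrt4_neon(values: &mut [f32; 4]) {
    use std::arch::aarch64::{vld1q_f32, vst1q_f32};
    let v = vld1q_f32(values.as_ptr());
    let r = vcbrtq_f32(v); // the helper from src/neon/math.rs above
    vst1q_f32(values.as_mut_ptr(), r);
}
```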
diff --git a/src/sse/math.rs b/src/sse/math.rs
index 2a7121e..2aaeea6 100644
--- a/src/sse/math.rs
+++ b/src/sse/math.rs
@@ -3,7 +3,6 @@ use std::arch::x86::*;
 #[cfg(target_arch = "x86_64")]
 use std::arch::x86_64::*;
-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 pub unsafe fn _mm_cube_ps(x: __m128) -> __m128 {
     _mm_mul_ps(_mm_mul_ps(x, x), x)
 }
@@ -73,14 +72,12 @@ pub unsafe fn _mm_select_ps(mask: __m128, true_vals: __m128, false_vals: __m128)
     _mm_blendv_ps(false_vals, true_vals, mask)
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_selecti_ps(mask: __m128i, true_vals: __m128, false_vals: __m128) -> __m128 {
     _mm_blendv_ps(false_vals, true_vals, _mm_castsi128_ps(mask))
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_select_si128(mask: __m128i, true_vals: __m128i, false_vals: __m128i) -> __m128i {
@@ -195,28 +192,24 @@ pub unsafe fn _mm_pow_n_ps(x: __m128, n: f32) -> __m128 {
     _mm_exp_ps(_mm_mul_ps(_mm_set1_ps(n), _mm_log_ps(x)))
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_signbit_ps(f: __m128) -> __m128i {
     return _mm_and_si128(_mm_castps_si128(f), _mm_castps_si128(_mm_set1_ps(-0.0f32)));
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_mulsign_ps(x: __m128, y: __m128) -> __m128 {
     return _mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(x), _mm_signbit_ps(y)));
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_pow2i_ps(q: __m128i) -> __m128 {
     return _mm_castsi128_ps(_mm_slli_epi32::<23>(_mm_add_epi32(q, _mm_set1_epi32(0x7f))));
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_vldexp2_ps(d: __m128, e: __m128i) -> __m128 {
@@ -226,7 +219,6 @@ pub unsafe fn _mm_vldexp2_ps(d: __m128, e: __m128i) -> __m128 {
     );
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_vilogbk_ps(d: __m128) -> __m128i {
@@ -247,14 +239,12 @@ pub unsafe fn _mm_vilogbk_ps(d: __m128) -> __m128i {
     return q;
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm_fmaf_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
     _mm_prefer_fma_ps(c, b, a)
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm_abs_ps(x: __m128) -> __m128 {
@@ -262,7 +252,6 @@ pub(crate) unsafe fn _mm_abs_ps(x: __m128) -> __m128 {
     return _mm_andnot_ps(sign_mask, x);
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm_neg_epi32(x: __m128i) -> __m128i {
@@ -281,7 +270,7 @@ pub(crate) unsafe fn _mm_neg_ps(x: __m128) -> __m128 {
 /// it is also precise however due to of inexact nature of power 1/3 result slightly differ
 /// from real cbrt with about ULP 3-4, but this is almost 2 times faster than cbrt with real ULP 3.5
 pub unsafe fn _mm_cbrt_ps(d: __m128) -> __m128 {
-    _mm_pow_n_ps(d, 1f32 / 3f32)
+    _mm_cbrt_ulp2_ps::<false>(d)
 }

 #[inline(always)]
@@ -334,7 +323,69 @@ pub unsafe fn _mm_cbrt_ps_ulp35(d: __m128) -> __m128 {
     return y;
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+#[inline(always)]
+pub unsafe fn _mm_cmpge_epi32(a: __m128i, b: __m128i) -> __m128i {
+    let gt = _mm_cmpgt_epi32(a, b);
+    let eq = _mm_cmpeq_epi32(a, b);
+    return _mm_or_si128(gt, eq);
+}
+
+#[inline(always)]
+pub unsafe fn _mm_cmplt_epi32(a: __m128i, b: __m128i) -> __m128i {
+    return _mm_cmpgt_epi32(b, a);
+}
+
+#[inline(always)]
+/// Precise version of Cube Root with ULP 2
+pub unsafe fn _mm_cbrt_ulp2_ps<const HANDLE_NAN: bool>(x: __m128) -> __m128 {
+    let x1p24 = _mm_castsi128_ps(_mm_set1_epi32(0x4b800000)); // 0x1p24f === 2 ^ 24
+
+    let mut ui = _mm_castps_si128(x);
+    let hx = _mm_and_si128(ui, _mm_set1_epi32(0x7fffffff));
+
+    let nan_mask = _mm_cmpge_epi32(hx, _mm_set1_epi32(0x7f800000));
+    let is_zero_mask = _mm_cmpeq_epi32(hx, _mm_setzero_si128());
+
+    let lo_mask = _mm_cmplt_epi32(hx, _mm_set1_epi32(0x00800000));
+    let hi_ui_f = _mm_castps_si128(_mm_mul_ps(x, x1p24));
+    let mut lo_hx = _mm_and_si128(hi_ui_f, _mm_set1_epi32(0x7fffffff));
+    let recpeq_3 = _mm_set1_ps(1f32 / 3f32);
+    lo_hx = _mm_add_epi32(
+        _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(lo_hx), recpeq_3)),
+        _mm_set1_epi32(642849266),
+    );
+    let hi_hx = _mm_add_epi32(
+        _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(hx), recpeq_3)),
+        _mm_set1_epi32(709958130),
+    );
+    let hx = _mm_select_si128(lo_mask, lo_hx, hi_hx);
+
+    ui = _mm_select_si128(lo_mask, hi_ui_f, ui);
+    ui = _mm_and_si128(ui, _mm_set1_epi32(-2147483648i32));
+    ui = _mm_or_si128(ui, hx);
+
+    let mut t = _mm_castsi128_ps(ui);
+    let mut r = _mm_mul_ps(_mm_mul_ps(t, t), t);
+
+    let sum_x = _mm_add_ps(x, x);
+
+    t = _mm_mul_ps(
+        _mm_div_ps(_mm_add_ps(sum_x, r), _mm_add_ps(_mm_add_ps(r, r), x)),
+        t,
+    );
+
+    r = _mm_mul_ps(_mm_mul_ps(t, t), t);
+    t = _mm_mul_ps(
+        _mm_div_ps(_mm_add_ps(sum_x, r), _mm_add_ps(_mm_add_ps(r, r), x)),
+        t,
+    );
+    if HANDLE_NAN {
+        t = _mm_selecti_ps(nan_mask, _mm_set1_ps(f32::NAN), t);
+        t = _mm_selecti_ps(is_zero_mask, _mm_setzero_ps(), t);
+    }
+    t
+}
+
 #[inline(always)]
 #[allow(dead_code)]
 pub unsafe fn _mm_color_matrix_ps(
@@ -357,7 +408,6 @@ pub unsafe fn _mm_color_matrix_ps(
     (new_r, new_g, new_b)
 }

-#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[inline(always)]
 #[allow(dead_code)]
 pub(crate) unsafe fn _mm_fmod_ps(a: __m128, b: __m128) -> __m128 {
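And a hypothetical caller for the x86 path; the wrapper name, the scalar fallback, and the SSE4.1 gate are assumptions (the `_mm_blendv_ps`-based selects in this module need at least SSE4.1), not part of the patch:

```rust
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn cbrt4_x86(values: &mut [f32; 4]) {
    if is_x86_feature_detected!("sse4.1") {
        unsafe {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::{_mm_loadu_ps, _mm_storeu_ps};
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::{_mm_loadu_ps, _mm_storeu_ps};
            let v = _mm_loadu_ps(values.as_ptr());
            let r = _mm_cbrt_ps(v); // the helper from src/sse/math.rs above
            _mm_storeu_ps(values.as_mut_ptr(), r);
        }
    } else {
        for v in values.iter_mut() {
            *v = v.cbrt(); // scalar fallback
        }
    }
}
```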