Cbrt ulp2 for x86
awxkee committed Jun 16, 2024
1 parent 329cc05 commit b26bda0
Showing 3 changed files with 134 additions and 79 deletions.
86 changes: 70 additions & 16 deletions src/avx/math.rs
@@ -3,20 +3,17 @@ use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
pub unsafe fn _mm256_cube_ps(x: __m256) -> __m256 {
_mm256_mul_ps(_mm256_mul_ps(x, x), x)
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[cfg(not(target_feature = "fma"))]
#[inline(always)]
pub unsafe fn _mm256_prefer_fma_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
return _mm256_add_ps(_mm256_mul_ps(b, c), a);
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn _mm256_prefer_fma_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
@@ -82,14 +79,12 @@ pub unsafe fn _mm256_select_ps(mask: __m256, true_vals: __m256, false_vals: __m256) -> __m256 {
_mm256_blendv_ps(false_vals, true_vals, mask)
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_selecti_ps(mask: __m256i, true_vals: __m256, false_vals: __m256) -> __m256 {
_mm256_blendv_ps(false_vals, true_vals, _mm256_castsi256_ps(mask))
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_select_si256(
@@ -202,21 +197,18 @@ unsafe fn _mm256_exp_ps_impl<const PROCESS_NAN: bool>(x: __m256) -> __m256 {
}
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_pow_ps(x: __m256, n: __m256) -> __m256 {
_mm256_exp_ps(_mm256_mul_ps(n, _mm256_log_ps(x)))
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_pow_n_ps(x: __m256, n: f32) -> __m256 {
_mm256_exp_ps(_mm256_mul_ps(_mm256_set1_ps(n), _mm256_log_ps(x)))
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_signbit_ps(f: __m256) -> __m256i {
@@ -226,7 +218,6 @@ pub unsafe fn _mm256_signbit_ps(f: __m256) -> __m256i {
);
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_mulsign_ps(x: __m256, y: __m256) -> __m256 {
@@ -236,7 +227,6 @@ pub unsafe fn _mm256_mulsign_ps(x: __m256, y: __m256) -> __m256 {
));
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_pow2i_ps(q: __m256i) -> __m256 {
@@ -246,7 +236,6 @@ pub unsafe fn _mm256_pow2i_ps(q: __m256i) -> __m256 {
)));
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_vldexp2_ps(d: __m256, e: __m256i) -> __m256 {
@@ -256,7 +245,6 @@ pub unsafe fn _mm256_vldexp2_ps(d: __m256, e: __m256i) -> __m256 {
);
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn _mm256_vilogbk_ps(d: __m256) -> __m256i {
@@ -281,22 +269,19 @@ pub unsafe fn _mm256_vilogbk_ps(d: __m256) -> __m256i {
return q;
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn _mm256_fmaf_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
_mm256_prefer_fma_ps(c, b, a)
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn _mm256_abs_ps(x: __m256) -> __m256 {
let sign_mask = _mm256_set1_ps(-0f32);
return _mm256_andnot_ps(sign_mask, x);
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn _mm256_neg_epi32(x: __m256i) -> __m256i {
@@ -309,7 +294,76 @@ pub(crate) unsafe fn _mm256_neg_epi32(x: __m256i) -> __m256i {
/// it is also precise, however due to the inexact nature of the power 1/3 the result differs slightly
/// from the real cbrt, by about 3-4 ULP, but this is almost 2 times faster than a cbrt with a true 3.5 ULP
pub unsafe fn _mm256_cbrt_ps(d: __m256) -> __m256 {
_mm256_pow_n_ps(d, 1f32 / 3f32)
_mm_cbrtq_f32_ulp2::<false>(d)
}
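A minimal usage sketch (an illustration, not code from this repository): it assumes the helper lives in the same module as _mm256_cbrt_ps and that AVX2 is enabled for the build (e.g. RUSTFLAGS="-C target-feature=+avx2").

// Hypothetical example: take the cube root of eight f32 lanes at once.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
unsafe fn cbrt_eight(values: [f32; 8]) -> [f32; 8] {
    let v = _mm256_loadu_ps(values.as_ptr()); // load 8 f32 lanes
    let r = _mm256_cbrt_ps(v);                // ~2 ULP cube root per lane
    let mut out = [0f32; 8];
    _mm256_storeu_ps(out.as_mut_ptr(), r);
    out
}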

#[inline(always)]
pub unsafe fn _mm256_cmpge_epi32(a: __m256i, b: __m256i) -> __m256i {
    let gt = _mm256_cmpgt_epi32(a, b);
    let eq = _mm256_cmpeq_epi32(a, b);
    return _mm256_or_si256(gt, eq);
}

#[inline(always)]
pub unsafe fn _mm256_cmplt_epi32(a: __m256i, b: __m256i) -> __m256i {
    return _mm256_cmpgt_epi32(b, a);
}

#[inline(always)]
/// Precise version of Cube Root with ULP 2
pub unsafe fn _mm_cbrtq_f32_ulp2<const HANDLE_NAN: bool>(x: __m256) -> __m256 {
    let x1p24 = _mm256_castsi256_ps(_mm256_set1_epi32(0x4b800000)); // 0x1p24f == 2^24

    // Reinterpret the float bits (as in scalar cbrtf), then drop the sign bit.
    let mut ui = _mm256_castps_si256(x);
    let hx = _mm256_and_si256(ui, _mm256_set1_epi32(0x7fffffff));

    let nan_mask = _mm256_cmpge_epi32(hx, _mm256_set1_epi32(0x7f800000));
    let is_zero_mask = _mm256_cmpeq_epi32(hx, _mm256_setzero_si256());

    // Rough initial estimate: hx/3 + magic bias, with the division by 3 approximated
    // by a float multiply; subnormal inputs are first scaled by 2^24 and use a different bias.
    let lo_mask = _mm256_cmplt_epi32(hx, _mm256_set1_epi32(0x00800000));
    let hi_ui_f = _mm256_castps_si256(_mm256_mul_ps(x, x1p24));
    let mut lo_hx = _mm256_and_si256(hi_ui_f, _mm256_set1_epi32(0x7fffffff));
    let recpreq_3 = _mm256_set1_ps(1f32 / 3f32);
    lo_hx = _mm256_add_epi32(
        _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(lo_hx), recpreq_3)),
        _mm256_set1_epi32(642849266),
    );
    let hi_hx = _mm256_add_epi32(
        _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(hx), recpreq_3)),
        _mm256_set1_epi32(709958130),
    );
    let hx = _mm256_select_si256(lo_mask, lo_hx, hi_hx);

    ui = _mm256_select_si256(lo_mask, hi_ui_f, ui);
    ui = _mm256_and_si256(ui, _mm256_set1_epi32(-2147483648i32));
    ui = _mm256_or_si256(ui, hx);

    // Two refinement steps: t <- t * (2x + t^3) / (x + 2*t^3).
    let mut t = _mm256_castsi256_ps(ui);
    let mut r = _mm256_mul_ps(_mm256_mul_ps(t, t), t);

    let sum_x = _mm256_add_ps(x, x);

    t = _mm256_mul_ps(
        _mm256_div_ps(
            _mm256_add_ps(sum_x, r),
            _mm256_add_ps(_mm256_add_ps(r, r), x),
        ),
        t,
    );

    r = _mm256_mul_ps(_mm256_mul_ps(t, t), t);
    t = _mm256_mul_ps(
        _mm256_div_ps(
            _mm256_add_ps(sum_x, r),
            _mm256_add_ps(_mm256_add_ps(r, r), x),
        ),
        t,
    );
    if HANDLE_NAN {
        t = _mm256_selecti_ps(nan_mask, _mm256_set1_ps(f32::NAN), t);
        t = _mm256_selecti_ps(is_zero_mask, _mm256_setzero_ps(), t);
    }
    t
}
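For reference, a scalar sketch of the same approach (a hypothetical, illustrative function, not part of this commit): an initial guess built from integer arithmetic on the float's bit pattern, followed by the same two refinement steps; the vector code above approximates the integer division by 3 with a multiply by 1/3.

// Illustrative scalar equivalent; the 709958130 / 642849266 biases match the vector path.
fn cbrtf_scalar(x: f32) -> f32 {
    let mut ui = x.to_bits();
    let mut hx = ui & 0x7fff_ffff;
    if hx >= 0x7f80_0000 {
        return x + x; // NaN or infinity
    }
    if hx < 0x0080_0000 {
        if hx == 0 {
            return x; // +-0
        }
        // Scale subnormals by 2^24 and use the adjusted bias.
        ui = (x * f32::from_bits(0x4b80_0000)).to_bits();
        hx = (ui & 0x7fff_ffff) / 3 + 642_849_266;
    } else {
        hx = hx / 3 + 709_958_130;
    }
    ui = (ui & 0x8000_0000) | hx;
    // Two refinement steps: t <- t * (2x + t^3) / (x + 2*t^3).
    let mut t = f32::from_bits(ui);
    for _ in 0..2 {
        let r = t * t * t;
        t = t * (x + x + r) / (x + r + r);
    }
    t
}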

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
49 changes: 0 additions & 49 deletions src/neon/math.rs
@@ -1,10 +1,6 @@
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
use std::arch::aarch64::*;

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn vfmodq_f32(a: float32x4_t, b: float32x4_t) -> float32x4_t {
@@ -17,10 +13,6 @@ pub(crate) unsafe fn vfmodq_f32(a: float32x4_t, b: float32x4_t) -> float32x4_t {
remainder
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn prefer_vfmaq_f32(
@@ -38,10 +30,6 @@ pub(crate) unsafe fn prefer_vfmaq_f32(
}
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
unsafe fn vtaylor_polyq_f32(
@@ -65,10 +53,6 @@ unsafe fn vtaylor_polyq_f32(
return res;
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vrintq_s32(d: float32x4_t) -> int32x4_t {
@@ -260,20 +244,12 @@ pub unsafe fn vlogq_f32(x: float32x4_t) -> float32x4_t {
return poly;
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn visnanq_f32(x: float32x4_t) -> uint32x4_t {
return vmvnq_u32(vceqq_f32(x, x));
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vispinfq_f32(d: float32x4_t) -> uint32x4_t {
@@ -328,10 +304,6 @@ pub unsafe fn vpowq_n_f32(t: float32x4_t, power: f32) -> float32x4_t {
return vpowq_f32(t, vdupq_n_f32(power));
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vilogbk_vi2_vf(d: float32x4_t) -> int32x4_t {
@@ -345,20 +317,12 @@ pub unsafe fn vilogbk_vi2_vf(d: float32x4_t) -> int32x4_t {
return q;
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vpow2i(q: int32x4_t) -> float32x4_t {
return vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(q, vdupq_n_s32(0x7f))));
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vldexp2q_f32(d: float32x4_t, e: int32x4_t) -> float32x4_t {
@@ -368,10 +332,6 @@ pub unsafe fn vldexp2q_f32(d: float32x4_t, e: int32x4_t) -> float32x4_t {
);
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vsignbit_vm_vf(f: float32x4_t) -> uint32x4_t {
@@ -381,28 +341,19 @@ pub unsafe fn vsignbit_vm_vf(f: float32x4_t) -> uint32x4_t {
);
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vmulsignq_f32(x: float32x4_t, y: float32x4_t) -> float32x4_t {
return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(x), vsignbit_vm_vf(y)));
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn vmlafq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
prefer_vfmaq_f32(c, b, a)
}

#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vcbrtq_f32(d: float32x4_t) -> float32x4_t {
vcbrtq_f32_ulp2::<false>(d)
}
