Skip to content

Commit

Permalink
Speed increase for exp
Browse files Browse the repository at this point in the history
  • Loading branch information
awxkee committed Jun 16, 2024
1 parent 65250ea commit 5232464
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/app/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn main() {
println!("HSL {:?}", hsl);
println!("Back RGB {:?}", hsl.to_rgb8());

let img = ImageReader::open("./assets/asset.jpg")
let img = ImageReader::open("./assets/beach_horizon.jpg")
.unwrap()
.decode()
.unwrap();
Expand Down
27 changes: 20 additions & 7 deletions src/avx/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ pub unsafe fn _mm256_select_si256(

/// Lane-wise approximation of `e^x` for a 256-bit vector of `f32`.
///
/// Thin entry point that forwards to `_mm256_exp_ps_impl` with
/// `PROCESS_NAN = false`, so the final out-of-range selects performed by that
/// helper (zero below its lower bound, `+inf` above its upper bound) are
/// compiled out and the raw polynomial result is returned.
///
/// # Safety
/// NOTE(review): presumably the caller must guarantee the CPU supports the
/// AVX/AVX2 intrinsics used by the helper — confirm against call sites.
#[inline(always)]
pub unsafe fn _mm256_exp_ps(x: __m256) -> __m256 {
// `false` selects the fast path without the input-range guard.
_mm256_exp_ps_impl::<false>(x)
}

#[inline(always)]
unsafe fn _mm256_exp_ps_impl<const PROCESS_NAN: bool>(x: __m256) -> __m256 {
let l2e = _mm256_set1_ps(std::f32::consts::LOG2_E); /* log2(e) */
let c0 = _mm256_set1_ps(0.3371894346f32);
let c1 = _mm256_set1_ps(0.657636276f32);
Expand All @@ -117,15 +122,23 @@ pub unsafe fn _mm256_exp_ps(x: __m256) -> __m256 {
let f = _mm256_sub_ps(t, e); /* f = t - floor(t) */
let mut p = c0; /* c0 */
p = _mm256_prefer_fma_ps(c1, p, f); /* c0 * f + c1 */
p = _mm256_prefer_fma_ps(c2, p ,f); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
p = _mm256_prefer_fma_ps(c2, p, f); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
let j = _mm256_slli_epi32::<23>(i); /* i << 23 */
let r = _mm256_castsi256_ps(_mm256_add_epi32(j, _mm256_castps_si256(p))); /* r = p * 2^i*/
let inf = _mm256_set1_ps(f32::INFINITY);
let max_input = _mm256_set1_ps(88.72283f32); // Approximately ln(2^127.5)
let min_input = _mm256_set1_ps(-87.33654f32); // Approximately ln(2^-125)
let poly = _mm256_select_ps(_mm256_cmp_ps::<_CMP_LT_OS>(x, min_input), _mm256_setzero_ps(), r);
let poly = _mm256_select_ps(_mm256_cmp_ps::<_CMP_GT_OS>(x, max_input), inf, poly);
return poly;
if PROCESS_NAN {
let inf = _mm256_set1_ps(f32::INFINITY);
let max_input = _mm256_set1_ps(88.72283f32); // Approximately ln(2^127.5)
let min_input = _mm256_set1_ps(-87.33654f32); // Approximately ln(2^-125)
let poly = _mm256_select_ps(
_mm256_cmp_ps::<_CMP_LT_OS>(x, min_input),
_mm256_setzero_ps(),
r,
);
let poly = _mm256_select_ps(_mm256_cmp_ps::<_CMP_GT_OS>(x, max_input), inf, poly);
return poly;
} else {
return r;
}
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
Expand Down
18 changes: 14 additions & 4 deletions src/neon/gamma_curves.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::arch::aarch64::*;

#[inline(always)]
pub unsafe fn neon_srgb_from_linear(linear: float32x4_t) -> float32x4_t {
let linear = vmaxq_f32(linear, vdupq_n_f32(0f32));
let linear = vminq_f32(linear, vdupq_n_f32(1f32));
let low_cut_off = vdupq_n_f32(0.0030412825601275209f32);
let mask = vcgeq_f32(linear, low_cut_off);

Expand All @@ -20,6 +22,8 @@ pub unsafe fn neon_srgb_from_linear(linear: float32x4_t) -> float32x4_t {

#[inline(always)]
pub unsafe fn neon_srgb_to_linear(gamma: float32x4_t) -> float32x4_t {
let gamma = vmaxq_f32(gamma, vdupq_n_f32(0f32));
let gamma = vminq_f32(gamma, vdupq_n_f32(1f32));
let low_cut_off = vdupq_n_f32(12.92f32 * 0.0030412825601275209f32);
let mask = vcgeq_f32(gamma, low_cut_off);

Expand All @@ -37,6 +41,8 @@ pub unsafe fn neon_srgb_to_linear(gamma: float32x4_t) -> float32x4_t {

#[inline(always)]
pub unsafe fn neon_rec709_from_linear(linear: float32x4_t) -> float32x4_t {
let linear = vmaxq_f32(linear, vdupq_n_f32(0f32));
let linear = vminq_f32(linear, vdupq_n_f32(1f32));
let low_cut_off = vdupq_n_f32(0.018053968510807f32);
let mask = vcgeq_f32(linear, low_cut_off);

Expand All @@ -52,14 +58,16 @@ pub unsafe fn neon_rec709_from_linear(linear: float32x4_t) -> float32x4_t {
}

#[inline(always)]
pub unsafe fn neon_rec709_to_linear(linear: float32x4_t) -> float32x4_t {
pub unsafe fn neon_rec709_to_linear(gamma: float32x4_t) -> float32x4_t {
let gamma = vmaxq_f32(gamma, vdupq_n_f32(0f32));
let gamma = vminq_f32(gamma, vdupq_n_f32(1f32));
let low_cut_off = vdupq_n_f32(4.5f32 * 0.018053968510807f32);
let mask = vcgeq_f32(linear, low_cut_off);
let mask = vcgeq_f32(gamma, low_cut_off);

let mut low = linear;
let mut low = gamma;
let high = vpowq_n_f32(
vmulq_n_f32(
vaddq_f32(linear, vdupq_n_f32(0.09929682680944f32)),
vaddq_f32(gamma, vdupq_n_f32(0.09929682680944f32)),
1f32 / 1.09929682680944f32,
),
1.0f32 / 0.45f32,
Expand All @@ -70,6 +78,8 @@ pub unsafe fn neon_rec709_to_linear(linear: float32x4_t) -> float32x4_t {

/// Applies a pure power-law transfer curve to four lanes at once.
///
/// Each lane of `gamma` is first limited to the closed range `[0, 1]`
/// (lower bound applied first, then the upper bound) and the limited value is
/// then raised to `gamma_constant` via `vpowq_n_f32`.
///
/// # Safety
/// NOTE(review): presumably requires NEON support on the executing CPU —
/// confirm against call sites.
#[inline(always)]
pub unsafe fn neon_pure_gamma_function(gamma: float32x4_t, gamma_constant: f32) -> float32x4_t {
    let lower_bound = vdupq_n_f32(0f32);
    let upper_bound = vdupq_n_f32(1f32);
    // Same order as a max-then-min clamp: raise to the floor, then cap at the ceiling.
    let clamped = vminq_f32(vmaxq_f32(gamma, lower_bound), upper_bound);
    vpowq_n_f32(clamped, gamma_constant)
}

Expand Down
9 changes: 1 addition & 8 deletions src/neon/linear_to_image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,14 @@ unsafe fn neon_gamma_vld<const CHANNELS_CONFIGURATION: u8, const USE_ALPHA: bool
}
}

let zeros = vdupq_n_f32(0f32);
let ones = vdupq_n_f32(1f32);

r_f32 = vmaxq_f32(vminq_f32(r_f32, ones), zeros);
g_f32 = vmaxq_f32(vminq_f32(g_f32, ones), zeros);
b_f32 = vmaxq_f32(vminq_f32(b_f32, ones), zeros);

r_f32 = transfer(r_f32);
g_f32 = transfer(g_f32);
b_f32 = transfer(b_f32);
r_f32 = vmulq_f32(r_f32, v_scale_alpha);
g_f32 = vmulq_f32(g_f32, v_scale_alpha);
b_f32 = vmulq_f32(b_f32, v_scale_alpha);
if USE_ALPHA {
a_f32 = vminq_f32(vmulq_f32(a_f32, v_scale_alpha), ones);
a_f32 = vmulq_f32(a_f32, v_scale_alpha);
}
(
vcvtaq_u32_f32(r_f32),
Expand Down
64 changes: 45 additions & 19 deletions src/neon/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,53 @@ pub unsafe fn vrintq_s32(d: float32x4_t) -> int32x4_t {
));
}

/// Lane-wise `floor` built from an `f32 -> i32 -> f32` round trip.
///
/// The round trip discards the fractional part; wherever the result ended up
/// greater than the input, one is subtracted to reach the next lower integer.
///
/// NOTE(review): because the intermediate is an `i32`, inputs whose magnitude
/// does not fit in 32-bit signed range are presumably not floored correctly —
/// confirm callers only pass values within that range.
///
/// # Safety
/// NOTE(review): presumably requires NEON support on the executing CPU —
/// confirm against call sites.
#[cfg(all(
    any(target_arch = "aarch64", target_arch = "arm"),
    target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vfloorq_f32(x: float32x4_t) -> float32x4_t {
    // Integer round trip: keeps only the whole part of every lane.
    let truncated = vcvtq_f32_s32(vcvtq_s32_f32(x));
    // Lanes where dropping the fraction moved the value above the input.
    let overshoots = vcgtq_f32(truncated, x);
    let stepped_down = vsubq_f32(truncated, vdupq_n_f32(1f32));
    vbslq_f32(overshoots, stepped_down, truncated)
}

/// Lane-wise approximation of `e^x` for four `f32` lanes.
///
/// Thin entry point that forwards to `vexpq_f32_impl` with
/// `PROCESS_NAN = false`, so that helper's final out-of-range selects (zero
/// below its lower bound, `+inf` above its upper bound) are compiled out and
/// the raw polynomial result is returned.
///
/// # Safety
/// NOTE(review): presumably requires NEON support on the executing CPU —
/// confirm against call sites.
#[inline(always)]
pub unsafe fn vexpq_f32(x: float32x4_t) -> float32x4_t {
// `false` selects the fast path without the input-range guard.
vexpq_f32_impl::<false>(x)
}

/// Fast lane-wise approximation of `e^x`.
///
/// Uses the identity `exp(x) = 2^(x * log2(e)) = 2^i * 2^f`, where
/// `i = floor(x * log2(e))` and `f` is the remaining fraction. `2^f` is
/// approximated by a quadratic polynomial evaluated in Horner form, and the
/// multiplication by `2^i` is done by adding `i << 23` to the raw bits of the
/// polynomial result (i.e. directly onto the `f32` exponent field).
///
/// When `PROCESS_NAN` is `true`, lanes with `x < -87.33654` (about
/// `-126 * ln 2`) are replaced by `0` and lanes with `x > 88.72283` (about
/// `128 * ln 2`) by `+inf`. When it is `false` the raw result is returned
/// unchanged. Only ordered `<` / `>` comparisons are involved; there is no
/// dedicated NaN test in this function.
///
/// # Safety
/// NOTE(review): presumably requires NEON support on the executing CPU —
/// confirm against call sites.
#[inline(always)]
unsafe fn vexpq_f32_impl<const PROCESS_NAN: bool>(x: float32x4_t) -> float32x4_t {
    /* t = log2(e) * x, split into whole part and fraction */
    let scaled = vmulq_f32(x, vdupq_n_f32(std::f32::consts::LOG2_E));
    let whole = vfloorq_f32(scaled);
    let exponent = vcvtq_s32_f32(whole); /* (int)floor(t) */
    let fraction = vsubq_f32(scaled, whole); /* t - floor(t) */

    /* 2^fraction ~= (c0 * f + c1) * f + c2 */
    /* assumes prefer_vfmaq_f32(a, b, c) evaluates a + b * c (helper not shown here) */
    let mut poly = vdupq_n_f32(0.3371894346f32);
    poly = prefer_vfmaq_f32(vdupq_n_f32(0.657636276f32), poly, fraction);
    poly = prefer_vfmaq_f32(vdupq_n_f32(1.00172476f32), poly, fraction);

    /* scale by 2^exponent: add (exponent << 23) to the bit pattern of poly */
    let exponent_bits = vshlq_n_s32::<23>(exponent);
    let approx = vreinterpretq_f32_s32(vaddq_s32(exponent_bits, vreinterpretq_s32_f32(poly)));

    if !PROCESS_NAN {
        return approx;
    }

    let below_range = vcltq_f32(x, vdupq_n_f32(-87.33654f32)); // ~= -126 * ln(2)
    let above_range = vcgtq_f32(x, vdupq_n_f32(88.72283f32)); // ~= 128 * ln(2)
    let guarded = vbslq_f32(below_range, vdupq_n_f32(0f32), approx);
    vbslq_f32(above_range, vdupq_n_f32(f32::INFINITY), guarded)
}

#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vexpq_f32_ulp3(x: float32x4_t) -> float32x4_t {
let c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
let c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
let c3 = vreinterpretq_f32_u32(vdupq_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
Expand Down Expand Up @@ -149,10 +189,6 @@ pub unsafe fn vexpq_f32(x: float32x4_t) -> float32x4_t {
return poly;
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
/// High precision exp. ULP = 1.0
Expand Down Expand Up @@ -189,12 +225,7 @@ pub unsafe fn vexpq_f32_ulp1(d: float32x4_t) -> float32x4_t {
u
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vlogq_f32(x: float32x4_t) -> float32x4_t {
let const_ln127 = vdupq_n_s32(127); // 127
let const_ln2 = vdupq_n_f32(std::f32::consts::LN_2); // ln(2)
Expand Down Expand Up @@ -484,12 +515,7 @@ pub unsafe fn vcbrtq_f32_ulp35(d: float32x4_t) -> float32x4_t {
return y;
}

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm"),
target_feature = "neon"
))]
#[inline(always)]
#[allow(dead_code)]
pub unsafe fn vcolorq_matrix_f32(
r: float32x4_t,
g: float32x4_t,
Expand Down
21 changes: 15 additions & 6 deletions src/sse/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ pub unsafe fn _mm_select_si128(mask: __m128i, true_vals: __m128i, false_vals: __

/// Lane-wise approximation of `e^x` for a 128-bit vector of `f32`.
///
/// Thin entry point that forwards to `_mm_exp_ps_impl` with
/// `PROCESS_NAN = false`, so that helper's final out-of-range selects (zero
/// below its lower bound, `+inf` above its upper bound) are compiled out and
/// the raw polynomial result is returned.
///
/// # Safety
/// NOTE(review): presumably the caller must guarantee the CPU supports the
/// SSE intrinsics used by the helper — confirm against call sites.
#[inline(always)]
pub unsafe fn _mm_exp_ps(x: __m128) -> __m128 {
// `false` selects the fast path without the input-range guard.
_mm_exp_ps_impl::<false>(x)
}

#[inline(always)]
unsafe fn _mm_exp_ps_impl<const PROCESS_NAN: bool>(x: __m128) -> __m128 {
let l2e = _mm_set1_ps(std::f32::consts::LOG2_E); /* log2(e) */
let c0 = _mm_set1_ps(0.3371894346f32);
let c1 = _mm_set1_ps(0.657636276f32);
Expand All @@ -107,12 +112,16 @@ pub unsafe fn _mm_exp_ps(x: __m128) -> __m128 {
p = _mm_prefer_fma_ps(c2, p ,f); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
let j = _mm_slli_epi32::<23>(i); /* i << 23 */
let r = _mm_castsi128_ps(_mm_add_epi32(j, _mm_castps_si128(p))); /* r = p * 2^i*/
let inf = _mm_set1_ps(f32::INFINITY);
let max_input = _mm_set1_ps(88.72283f32); // Approximately ln(2^127.5)
let min_input = _mm_set1_ps(-87.33654f32); // Approximately ln(2^-125)
let poly = _mm_select_ps(_mm_cmplt_ps(x, min_input), _mm_setzero_ps(), r);
let poly = _mm_select_ps(_mm_cmpgt_ps(x, max_input), inf, poly);
return poly;
if PROCESS_NAN {
let inf = _mm_set1_ps(f32::INFINITY);
let max_input = _mm_set1_ps(88.72283f32); // Approximately ln(2^127.5)
let min_input = _mm_set1_ps(-87.33654f32); // Approximately ln(2^-125)
let poly = _mm_select_ps(_mm_cmplt_ps(x, min_input), _mm_setzero_ps(), r);
let poly = _mm_select_ps(_mm_cmpgt_ps(x, max_input), inf, poly);
return poly;
} else {
return r;
}
}

#[inline(always)]
Expand Down

0 comments on commit 5232464

Please sign in to comment.