Skip to content

Commit

Permalink
Math bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
awxkee committed Jun 21, 2024
1 parent 8203988 commit acdf13d
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 41 deletions.
18 changes: 12 additions & 6 deletions src/app/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::arch::aarch64::{vdupq_n_f32, vgetq_lane_f32};
use std::time::Instant;

use image::io::Reader as ImageReader;
Expand All @@ -14,6 +15,13 @@ pub const fn shuffle(z: u32, y: u32, x: u32, w: u32) -> i32 {
}

fn main() {
unsafe {
let y = vdupq_n_f32(0f32);
let x = vdupq_n_f32(0f32);
let v = vatan2q_f32(y, x);
let val = vgetq_lane_f32::<0>(v);
print!("{}", val);
}
let r = 140;
let g = 164;
let b = 177;
Expand All @@ -23,7 +31,7 @@ fn main() {
println!("HSL {:?}", hsl);
println!("Back RGB {:?}", hsl.to_rgb8());

let img = ImageReader::open("./assets/beach_horizon.jpg")
let img = ImageReader::open("./assets/horse.png")
.unwrap()
.decode()
.unwrap();
Expand All @@ -34,7 +42,7 @@ fn main() {
let mut src_bytes = img.as_bytes();
let width = dimensions.0;
let height = dimensions.1;
let components = 3;
let components = 4;
//
// let mut dst_rgba = vec![];
// dst_rgba.resize(4usize * width as usize * height as usize, 0u8);
Expand All @@ -58,14 +66,13 @@ fn main() {
lab_store.resize(width as usize * components * height as usize, 0f32);
let src_stride = width * components as u32;
let start_time = Instant::now();
rgb_to_linear(
rgba_to_lch_with_alpha(
src_bytes,
src_stride,
&mut lab_store,
store_stride as u32,
width,
height,
TransferFunction::Srgb,
);
let elapsed_time = start_time.elapsed();
// Print the elapsed time in milliseconds
Expand Down Expand Up @@ -93,14 +100,13 @@ fn main() {
// }

let start_time = Instant::now();
linear_to_rgb(
lch_with_alpha_to_rgba(
&lab_store,
store_stride as u32,
&mut dst_slice,
src_stride,
width,
height,
TransferFunction::Srgb,
);

let elapsed_time = start_time.elapsed();
Expand Down
10 changes: 7 additions & 3 deletions src/avx/gamma_curves.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,13 @@ pub unsafe fn avx2_rec709_to_linear(gamma: __m256) -> __m256 {

#[inline(always)]
pub unsafe fn avx2_pure_gamma(x: __m256, value: f32) -> __m256 {
let x = _mm256_max_ps(x, _mm256_setzero_ps());
let x = _mm256_min_ps(x, _mm256_set1_ps(1f32));
_mm256_pow_n_ps(x, value)
let zeros = _mm256_setzero_ps();
let ones = _mm256_set1_ps(1f32);
let zero_mask = _mm256_cmp_ps::<_CMP_LE_OS>(x, zeros);
let ones_mask = _mm256_cmp_ps::<_CMP_GE_OS>(x, ones);
let mut rs = _mm256_pow_n_ps(x, value);
rs = crate::avx::math::_mm256_select_ps(zero_mask, zeros, rs);
crate::avx::math::_mm256_select_ps(ones_mask, ones, rs)
}

#[inline(always)]
Expand Down
22 changes: 20 additions & 2 deletions src/avx/math.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::sse::{_mm_mulsign_ps, _mm_select_ps};
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
Expand Down Expand Up @@ -47,7 +48,9 @@ unsafe fn _mm256_taylorpoly_ps(
}

#[inline(always)]
pub unsafe fn _mm256_log_ps(v: __m256) -> __m256 {
pub unsafe fn _mm256_log_ps<const HANDLE_NAN: bool>(v: __m256) -> __m256 {
let zeros = _mm256_setzero_ps();
let nan_mask = _mm256_cmp_ps::<_CMP_LE_OS>(v, zeros);
let const_ln127 = _mm256_set1_epi32(127); // 127
let const_ln2 = _mm256_set1_ps(std::f32::consts::LN_2); // ln(2)

Expand All @@ -71,6 +74,13 @@ pub unsafe fn _mm256_log_ps(v: __m256) -> __m256 {
);

poly = _mm256_prefer_fma_ps(poly, _mm256_cvtepi32_ps(m), const_ln2);

if HANDLE_NAN {
poly = _mm256_select_ps(nan_mask, _mm256_set1_ps(-f32::INFINITY), poly);
} else {
poly = _mm256_select_ps(nan_mask, zeros, poly);
}

poly
}

Expand Down Expand Up @@ -576,7 +586,15 @@ unsafe fn _mm256_atan2q_ps_impl(y: __m256, x: __m256) -> __m256 {
#[inline(always)]
pub unsafe fn _mm256_atan2_ps(y: __m256, x: __m256) -> __m256 {
let r = _mm256_atan2q_ps_impl(_mm256_abs_ps(y), x);
let r = _mm256_mulsign_ps(r, x);
let mut r = _mm256_mulsign_ps(r, x);
let zeros = _mm256_setzero_ps();
let y_zero_mask = _mm256_cmp_ps::<_CMP_EQ_OS>(y, zeros);
r = _mm256_select_ps(
_mm256_cmp_ps::<_CMP_EQ_OS>(x, zeros),
_mm256_set1_ps(std::f32::consts::FRAC_PI_2),
r,
);
r = _mm256_select_ps(y_zero_mask, zeros, r);
_mm256_mulsign_ps(r, y)
}

Expand Down
8 changes: 0 additions & 8 deletions src/hsl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::rgb::Rgb;

#[allow(dead_code)]
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq)]
pub struct Hsl {
pub h: f32,
Expand All @@ -9,7 +8,6 @@ pub struct Hsl {
}

impl Hsl {
#[allow(dead_code)]
pub fn new(h: u16, s: u16, l: u16) -> Hsl {
Hsl {
h: h as f32,
Expand All @@ -22,12 +20,10 @@ impl Hsl {
Hsl { h, s, l }
}

#[allow(dead_code)]
pub fn from_rgb(rgb: &Rgb<u8>) -> Hsl {
rgb2hsl(rgb.r, rgb.g, rgb.b)
}

#[allow(dead_code)]
pub fn to_rgb8(&self) -> Rgb<u8> {
let c = (1f32 - (2f32 * self.l - 1f32).abs()) * self.s;
let x = c * (1f32 - ((self.h / 60f32) % 2f32 - 1f32).abs());
Expand All @@ -54,22 +50,18 @@ impl Hsl {
}
}

#[allow(dead_code)]
pub fn to_rgb(&self) -> Rgb<u8> {
self.to_rgb8()
}

#[allow(dead_code)]
pub fn get_saturation(&self) -> u16 {
((self.s * 100f32) as u16).min(100u16)
}

#[allow(dead_code)]
pub fn get_lightness(&self) -> u16 {
((self.l * 100f32) as u16).min(100u16)
}

#[allow(dead_code)]
pub fn get_hue(&self) -> u16 {
(self.h as u16).min(360)
}
Expand Down
10 changes: 0 additions & 10 deletions src/hsv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ static HSV_U8_SCALE: f32 = 1f32 / 255f32;
static HSV_PERCENTAGE_SCALE: f32 = 1f32 / 100f32;

impl Hsv {
#[allow(dead_code)]
pub fn new(h: u16, s: u16, l: u16) -> Hsv {
Hsv {
h: h as f32,
Expand All @@ -20,12 +19,10 @@ impl Hsv {
}
}

#[allow(dead_code)]
pub fn from_components(h: f32, s: f32, v: f32) -> Hsv {
Hsv { h, s, v }
}

#[allow(dead_code)]
pub fn from(rgb: &Rgb<u8>) -> Hsv {
let (h, s, v) = rgb_to_hsv(
rgb.r as f32 * HSV_U8_SCALE,
Expand All @@ -35,7 +32,6 @@ impl Hsv {
return Hsv { h, s, v };
}

#[allow(dead_code)]
pub fn to_rgb8(&self) -> Rgb<u8> {
let (rf, gf, bf) = hsv_to_rgb(self.h, self.s, self.v);
return Rgb {
Expand All @@ -45,32 +41,26 @@ impl Hsv {
};
}

#[allow(dead_code)]
pub fn get_hue(&self) -> f32 {
self.h
}

#[allow(dead_code)]
pub fn get_saturation(&self) -> f32 {
self.s
}

#[allow(dead_code)]
pub fn get_value(&self) -> f32 {
self.v
}

#[allow(dead_code)]
pub fn get_hue_p(&self) -> u16 {
self.h.max(0f32).min(360f32) as u16
}

#[allow(dead_code)]
pub fn get_saturation_p(&self) -> u16 {
(self.s * 100f32).max(0f32).min(100f32) as u16
}

#[allow(dead_code)]
pub fn get_value_p(&self) -> u16 {
(self.v * 100f32).max(0f32).min(100f32) as u16
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub use xyza_laba_to_image::xyz_with_alpha_to_rgba;
pub use image_to_sigmoidal::bgra_to_sigmoidal;
pub use image_to_sigmoidal::rgb_to_sigmoidal;
pub use image_to_sigmoidal::rgba_to_sigmoidal;
pub use neon::*;

Check failure on line 112 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Build

unresolved import `neon`
pub use rgb_expand::*;
pub use sigmoidal::Sigmoidal;
pub use sigmoidal_to_image::sigmoidal_to_bgra;
Expand Down
10 changes: 7 additions & 3 deletions src/neon/gamma_curves.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ pub unsafe fn neon_rec709_to_linear(gamma: float32x4_t) -> float32x4_t {

#[inline(always)]
pub unsafe fn neon_pure_gamma_function(gamma: float32x4_t, gamma_constant: f32) -> float32x4_t {
let gamma = vmaxq_f32(gamma, vdupq_n_f32(0f32));
let gamma = vminq_f32(gamma, vdupq_n_f32(1f32));
vpowq_n_f32(gamma, gamma_constant)
let zero_mask = vclezq_f32(gamma);
let ones = vdupq_n_f32(1f32);
let zeros = vdupq_n_f32(0f32);
let ones_mask = vcgeq_f32(gamma, ones);
let mut rs = vpowq_n_f32(gamma, gamma_constant);
rs = vbslq_f32(zero_mask, zeros, rs);
vbslq_f32(ones_mask, ones, rs)
}

#[inline(always)]
Expand Down
29 changes: 26 additions & 3 deletions src/neon/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ pub unsafe fn vexpq_f32_ulp1(d: float32x4_t) -> float32x4_t {
}

#[inline(always)]
pub unsafe fn vlogq_f32(x: float32x4_t) -> float32x4_t {
pub unsafe fn vlogq_f32<const HANDLE_NAN: bool>(x: float32x4_t) -> float32x4_t {
let nan_mask = vclezq_f32(x);
let const_ln127 = vdupq_n_s32(127); // 127
let const_ln2 = vdupq_n_f32(std::f32::consts::LN_2); // ln(2)

Expand All @@ -240,6 +241,12 @@ pub unsafe fn vlogq_f32(x: float32x4_t) -> float32x4_t {
// Reconstruct
poly = prefer_vfmaq_f32(poly, vcvtq_f32_s32(m), const_ln2);

if HANDLE_NAN {
poly = vbslq_f32(nan_mask, vdupq_n_f32(-f32::INFINITY), poly);
} else {
poly = vbslq_f32(nan_mask, vdupq_n_f32(0f32), poly);
}

return poly;
}

Expand Down Expand Up @@ -295,7 +302,7 @@ pub unsafe fn vlogq_f32_ulp35(d: float32x4_t) -> float32x4_t {

#[inline(always)]
pub unsafe fn vpowq_f32(val: float32x4_t, n: float32x4_t) -> float32x4_t {
return vexpq_f32(vmulq_f32(n, vlogq_f32(val)));
return vexpq_f32(vmulq_f32(n, vlogq_f32::<false>(val)));
}

#[inline(always)]
Expand Down Expand Up @@ -552,10 +559,26 @@ unsafe fn vatan2q_f32_impl(y: float32x4_t, x: float32x4_t) -> float32x4_t {
t
}

#[inline(always)]
pub unsafe fn visnegq_f32(x: float32x4_t) -> uint32x4_t {
vcltzq_f32(x)
}

#[inline(always)]
pub unsafe fn vatan2q_f32(y: float32x4_t, x: float32x4_t) -> float32x4_t {
let r = vatan2q_f32_impl(vabsq_f32(y), x);
let r = vmulsignq_f32(r, x);
let mut r = vmulsignq_f32(r, x);
let y_zero_mask = vceqzq_f32(y);
r = vbslq_f32(vceqzq_f32(x), vdupq_n_f32(std::f32::consts::FRAC_PI_2), r);
r = vbslq_f32(
y_zero_mask,
vbslq_f32(
visnegq_f32(x),
vdupq_n_f32(std::f32::consts::PI),
vdupq_n_f32(0f32),
),
r,
);
vmulsignq_f32(r, y)
}

Expand Down
1 change: 1 addition & 0 deletions src/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub use gamma_curves::*;
pub use hsv_to_image::*;
pub use image_to_hsv::*;
pub use linear_to_image::*;
pub use math::vatan2q_f32;
pub use to_linear::*;
pub use to_linear_u8::*;
pub use to_sigmoidal::neon_image_to_sigmoidal;
Expand Down
2 changes: 1 addition & 1 deletion src/neon/sigmoidal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub(crate) unsafe fn neon_sigmoidal_to_color(x: float32x4_t) -> float32x4_t {
let k = vmulq_f32(x, vrecpeq_f32(den));
let zeros = vdupq_n_f32(0f32);
let zero_mask_2 = vcleq_f32(k, zeros);
let ln = vlogq_f32(k);
let ln = vlogq_f32::<false>(k);
let rs = vbslq_f32(vandq_u32(zero_mask_1, zero_mask_2), zeros, ln);
return rs;
}
Expand Down
10 changes: 7 additions & 3 deletions src/sse/gamma_curves.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,13 @@ pub unsafe fn sse_rec709_to_linear(gamma: __m128) -> __m128 {

#[inline(always)]
pub unsafe fn sse_pure_gamma(gamma: __m128, value: f32) -> __m128 {
let gamma = _mm_max_ps(gamma, _mm_setzero_ps());
let gamma = _mm_min_ps(gamma, _mm_set1_ps(1f32));
_mm_pow_n_ps(gamma, value)
let zeros = _mm_setzero_ps();
let zero_mask = _mm_cmple_ps(gamma, zeros);
let ones = _mm_set1_ps(1f32);
let ones_mask = _mm_cmpge_ps(gamma, ones);
let mut rs = _mm_pow_n_ps(gamma, value);
rs = _mm_select_ps(zero_mask, zeros, rs);
_mm_select_ps(ones_mask, ones, rs)
}

#[inline(always)]
Expand Down
Loading

0 comments on commit acdf13d

Please sign in to comment.