From fb660c23f550d9fe744a1688fc0a8c9378f6db4e Mon Sep 17 00:00:00 2001
From: Jonathan Behrens
Date: Thu, 5 Dec 2024 18:46:40 -0800
Subject: [PATCH] Small refactoring of paeth filter logic (#544)

---
 src/filter.rs | 293 ++++++++++++++++++++++++--------------------------
 1 file changed, 140 insertions(+), 153 deletions(-)

diff --git a/src/filter.rs b/src/filter.rs
index aa8c4aa4..da1e2966 100644
--- a/src/filter.rs
+++ b/src/filter.rs
@@ -38,7 +38,7 @@ mod simd {
     {
         let mut out = [0; N];
         for i in 0..N {
-            out[i] = super::filter_paeth_decode_i16(a[i].into(), b[i].into(), c[i].into());
+            out[i] = super::filter_paeth_stbi_i16(a[i].into(), b[i].into(), c[i].into());
         }
         out.into()
     }
@@ -55,7 +55,7 @@ mod simd {
     {
         let mut out = [0; N];
         for i in 0..N {
-            out[i] = super::filter_paeth_decode(a[i].into(), b[i].into(), c[i].into());
+            out[i] = super::filter_paeth_stbi(a[i].into(), b[i].into(), c[i].into());
         }
         out.into()
     }
@@ -277,9 +277,30 @@ impl Default for AdaptiveFilterType {
     }
 }
 
-#[cfg(target_arch = "x86_64")]
-fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
-    // Decoding optimizes better with this algorithm than with `filter_paeth()`
+fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
+    // On ARM this algorithm performs much better than the one below adapted from stb,
+    // and this is the better-studied algorithm we've always used here,
+    // so we default to it on all non-x86 platforms.
+    let pa = (i16::from(b) - i16::from(c)).abs();
+    let pb = (i16::from(a) - i16::from(c)).abs();
+    let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();
+
+    let mut out = a;
+    let mut min = pa;
+
+    if pb < min {
+        min = pb;
+        out = b;
+    }
+    if pc < min {
+        out = c;
+    }
+
+    out
+}
+
+fn filter_paeth_stbi(a: u8, b: u8, c: u8) -> u8 {
+    // Decoding optimizes better with this algorithm than with `filter_paeth`
     //
     // This formulation looks very different from the reference in the PNG spec, but is
     // actually equivalent and has favorable data dependencies and admits straightforward
@@ -295,9 +316,9 @@
     return t1;
 }
 
-#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-fn filter_paeth_decode_i16(a: i16, b: i16, c: i16) -> i16 {
-    // Like `filter_paeth_decode` but vectorizes better when wrapped in SIMD types.
+#[cfg(any(test, all(feature = "unstable", target_arch = "x86_64")))]
+fn filter_paeth_stbi_i16(a: i16, b: i16, c: i16) -> i16 {
+    // Like `filter_paeth_stbi` but vectorizes better when wrapped in SIMD types.
     // Used for bpp=3 and bpp=6
     let thresh = c * 3 - (a + b);
     let lo = a.min(b);
@@ -307,30 +328,7 @@
     return t1;
 }
 
-#[cfg(not(target_arch = "x86_64"))]
-fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
-    // On ARM this algorithm performs much better than the one above adapted from stb,
-    // and this is the better-studied algorithm we've always used here,
-    // so we default to it on all non-x86 platforms.
-    let pa = (i16::from(b) - i16::from(c)).abs();
-    let pb = (i16::from(a) - i16::from(c)).abs();
-    let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();
-
-    let mut out = a;
-    let mut min = pa;
-
-    if pb < min {
-        min = pb;
-        out = b;
-    }
-    if pc < min {
-        out = c;
-    }
-
-    out
-}
-
-fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
+fn filter_paeth_fpnge(a: u8, b: u8, c: u8) -> u8 {
     // This is an optimized version of the paeth filter from the PNG specification, proposed by
     // Luca Versari for [FPNGE](https://www.lucaversari.it/FJXL_and_FPNGE.pdf). It operates
     // entirely on unsigned 8-bit quantities, making it more conducive to vectorization.
@@ -706,7 +704,15 @@ pub(crate) fn unfilter(
                 }
             }
         },
+        #[allow(unreachable_code)]
         Paeth => {
+            // Select the fastest Paeth filter implementation based on the target architecture.
+            let filter_paeth_decode = if cfg!(target_arch = "x86_64") {
+                filter_paeth_stbi
+            } else {
+                filter_paeth
+            };
+
             // Paeth filter pixels:
             // C B D
             // A X
@@ -742,141 +748,116 @@
                 BytesPerPixel::Three => {
                     // Do not enable this algorithm on ARM, that would be a big performance hit
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth3(previous, current);
+                    {
+                        simd::unfilter_paeth3(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 3];
+                    let mut c_bpp = [0; 3];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(3).zip(previous.chunks_exact(3))
                     {
-                        let mut a_bpp = [0; 3];
-                        let mut c_bpp = [0; 3];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(3).zip(previous.chunks_exact(3))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                        ];
+                        *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                     }
                 }
                 BytesPerPixel::Four => {
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth_u8::<4>(previous, current);
+                    {
+                        simd::unfilter_paeth_u8::<4>(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 4];
+                    let mut c_bpp = [0; 4];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(4).zip(previous.chunks_exact(4))
                     {
-                        let mut a_bpp = [0; 4];
-                        let mut c_bpp = [0; 4];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(4).zip(previous.chunks_exact(4))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                                chunk[3].wrapping_add(filter_paeth_decode(
-                                    a_bpp[3], b_bpp[3], c_bpp[3],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                            chunk[3]
+                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
+                        ];
+                        *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                    }
                 }
                 BytesPerPixel::Six => {
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth6(previous, current);
+                    {
+                        simd::unfilter_paeth6(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 6];
+                    let mut c_bpp = [0; 6];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
                     {
-                        let mut a_bpp = [0; 6];
-                        let mut c_bpp = [0; 6];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                                chunk[3].wrapping_add(filter_paeth_decode(
-                                    a_bpp[3], b_bpp[3], c_bpp[3],
-                                )),
-                                chunk[4].wrapping_add(filter_paeth_decode(
-                                    a_bpp[4], b_bpp[4], c_bpp[4],
-                                )),
-                                chunk[5].wrapping_add(filter_paeth_decode(
-                                    a_bpp[5], b_bpp[5], c_bpp[5],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                            chunk[3]
+                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
+                            chunk[4]
+                                .wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
+                            chunk[5]
+                                .wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
+                        ];
+                        *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                     }
                 }
                 BytesPerPixel::Eight => {
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth_u8::<8>(previous, current);
+                    {
+                        simd::unfilter_paeth_u8::<8>(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 8];
+                    let mut c_bpp = [0; 8];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(8).zip(previous.chunks_exact(8))
                     {
-                        let mut a_bpp = [0; 8];
-                        let mut c_bpp = [0; 8];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(8).zip(previous.chunks_exact(8))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                                chunk[3].wrapping_add(filter_paeth_decode(
-                                    a_bpp[3], b_bpp[3], c_bpp[3],
-                                )),
-                                chunk[4].wrapping_add(filter_paeth_decode(
-                                    a_bpp[4], b_bpp[4], c_bpp[4],
-                                )),
-                                chunk[5].wrapping_add(filter_paeth_decode(
-                                    a_bpp[5], b_bpp[5], c_bpp[5],
-                                )),
-                                chunk[6].wrapping_add(filter_paeth_decode(
-                                    a_bpp[6], b_bpp[6], c_bpp[6],
-                                )),
-                                chunk[7].wrapping_add(filter_paeth_decode(
-                                    a_bpp[7], b_bpp[7], c_bpp[7],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                            chunk[3]
+                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
+                            chunk[4]
+                                .wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
+                            chunk[5]
+                                .wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
+                            chunk[6]
+                                .wrapping_add(filter_paeth_decode(a_bpp[6], b_bpp[6], c_bpp[6])),
+                            chunk[7]
+                                .wrapping_add(filter_paeth_decode(a_bpp[7], b_bpp[7], c_bpp[7])),
+                        ];
+                        *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                     }
                 }
             }
@@ -1000,7 +981,7 @@ fn filter_internal(
                 .zip(&mut c_chunks)
             {
                 for i in 0..CHUNK_SIZE {
-                    out[i] = cur[i].wrapping_sub(filter_paeth(a[i], b[i], c[i]));
+                    out[i] = cur[i].wrapping_sub(filter_paeth_fpnge(a[i], b[i], c[i]));
                 }
             }
 
@@ -1012,11 +993,11 @@
                 .zip(b_chunks.remainder())
                 .zip(c_chunks.remainder())
            {
-                *out = cur.wrapping_sub(filter_paeth(a, b, c));
+                *out = cur.wrapping_sub(filter_paeth_fpnge(a, b, c));
            }

            for i in 0..bpp {
-                output[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0));
+                output[i] = current[i].wrapping_sub(filter_paeth_fpnge(0, previous[i], 0));
            }
            Paeth
        }
@@ -1085,7 +1066,7 @@ fn sum_buffer(buf: &[u8]) -> u64 {
 
 #[cfg(test)]
 mod test {
-    use super::{filter, unfilter, AdaptiveFilterType, BytesPerPixel, FilterType};
+    use super::*;
     use core::iter;
 
     #[test]
@@ -1135,11 +1116,17 @@
     #[test]
     #[ignore] // takes ~20s without optimizations
     fn paeth_impls_are_equivalent() {
-        use super::{filter_paeth, filter_paeth_decode};
         for a in 0..=255 {
             for b in 0..=255 {
                 for c in 0..=255 {
-                    assert_eq!(filter_paeth(a, b, c), filter_paeth_decode(a, b, c));
+                    let baseline = filter_paeth(a, b, c);
+                    let fpnge = filter_paeth_fpnge(a, b, c);
+                    let stbi = filter_paeth_stbi(a, b, c);
+                    let stbi_i16 = filter_paeth_stbi_i16(a as i16, b as i16, c as i16);
+
+                    assert_eq!(baseline, fpnge);
+                    assert_eq!(baseline, stbi);
+                    assert_eq!(baseline as i16, stbi_i16);
                 }
             }
         }
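
Note on the dispatch in `unfilter`: unlike the removed `#[cfg]` attributes, `cfg!(target_arch = "x86_64")` is an ordinary boolean expression, so both Paeth implementations are now type-checked on every target and the constant-false branch is optimized away; the two fn items coerce to a common `fn(u8, u8, u8) -> u8` pointer, and `#[allow(unreachable_code)]` silences the warning caused by the early `return`s inside the cfg-gated SIMD blocks. A minimal standalone sketch of the pattern (the function names here are illustrative, not from the patch):

    // Two interchangeable implementations with identical signatures.
    fn fast_impl(x: u8) -> u8 {
        x.wrapping_add(1)
    }

    fn portable_impl(x: u8) -> u8 {
        x.wrapping_add(1)
    }

    fn dispatch(x: u8) -> u8 {
        // `cfg!` folds to a compile-time constant, so the untaken branch is
        // eliminated, but both arms still have to compile on every target.
        // The two distinct fn-item types coerce to a common fn pointer.
        let f: fn(u8) -> u8 = if cfg!(target_arch = "x86_64") {
            fast_impl
        } else {
            portable_impl
        };
        f(x)
    }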
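
For reference, the predictor that `filter_paeth`, `filter_paeth_stbi`, and `filter_paeth_fpnge` must all agree with is the one from the PNG specification; the `paeth_impls_are_equivalent` test above checks this exhaustively over all 256^3 inputs. A direct transcription of the spec's pseudocode, shown for comparison only and not part of the patch:

    /// Paeth predictor as written in the PNG specification.
    fn paeth_reference(a: u8, b: u8, c: u8) -> u8 {
        // Initial estimate: left + above - upper-left, in a wide enough type.
        let p = i16::from(a) + i16::from(b) - i16::from(c);
        let pa = (p - i16::from(a)).abs();
        let pb = (p - i16::from(b)).abs();
        let pc = (p - i16::from(c)).abs();
        // Ties break in the order a, then b, then c, as the spec requires.
        if pa <= pb && pa <= pc {
            a
        } else if pb <= pc {
            b
        } else {
            c
        }
    }

Because the equivalence test is marked `#[ignore]`, it only runs when requested explicitly, e.g. `cargo test paeth_impls_are_equivalent -- --ignored`.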