Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port of stb_image optimized paeth unfiltering #539

Merged
merged 16 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ jobs:
feature_check:
strategy:
matrix:
features: ["", "benchmarks"]
runs-on: ubuntu-latest
features: ["", "unstable", "benchmarks"]
os: [ubuntu-latest, macos-latest] # macos-latest is ARM
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
Expand All @@ -54,7 +55,10 @@ jobs:
rustup target add powerpc-unknown-linux-gnu
cargo build --target powerpc-unknown-linux-gnu
test_all:
runs-on: ubuntu-latest
strategy:
matrix:
os: [ubuntu-latest, macos-latest] # macos-latest is ARM
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- run: rustup default stable
Expand Down
139 changes: 62 additions & 77 deletions src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ use crate::common::BytesPerPixel;
/// TODO(https://github.com/rust-lang/rust/issues/86656): Stop gating this module behind the
/// "unstable" feature of the `png` crate. This should be possible once the "portable_simd"
/// feature of Rust gets stabilized.
#[cfg(feature = "unstable")]
///
/// This is only known to help on x86, with no change measured on most benchmarks on ARM,
/// and even severely regressing some of them.
/// So despite the code being portable, we only enable this for x86.
/// We can add more platforms once this code is proven to be beneficial for them.
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
mod simd {
use std::simd::cmp::{SimdOrd, SimdPartialEq, SimdPartialOrd};
use std::simd::num::{SimdInt, SimdUint};
use std::simd::{u8x4, u8x8, LaneCount, Simd, SimdElement, SupportedLaneCount};

Expand Down Expand Up @@ -39,18 +43,6 @@ mod simd {
out.into()
}

/// This is an equivalent of the `PaethPredictor` function from
/// [the spec](http://www.libpng.org/pub/png/spec/1.2/PNG-Filters.html#Filter-type-4-Paeth)
/// except that it simultaneously calculates the predictor for all SIMD lanes.
/// Mapping between parameter names and pixel positions can be found in
/// [a diagram here](https://www.w3.org/TR/png/#filter-byte-positions).
///
/// Examples of how different pixel types may be represented as multiple SIMD lanes:
/// - RGBA => 4 lanes of `i16x4` contain R, G, B, A
/// - RGB => 4 lanes of `i16x4` contain R, G, B, and a ignored 4th value
///
/// The SIMD algorithm below is based on [`libpng`](https://github.com/glennrp/libpng/blob/f8e5fa92b0e37ab597616f554bee254157998227/intel/filter_sse2_intrinsics.c#L261-L280).
///
/// Functionally equivalent to `simd::paeth_predictor` but does not temporarily convert
/// the SIMD elements to `i16`.
fn paeth_predictor_u8<const N: usize>(
Expand All @@ -61,44 +53,11 @@ mod simd {
where
LaneCount<N>: SupportedLaneCount,
{
// Calculates the absolute difference between `a` and `b`.
fn abs_diff_simd<const N: usize>(a: Simd<u8, N>, b: Simd<u8, N>) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
{
a.simd_max(b) - b.simd_min(a)
let mut out = [0; N];
for i in 0..N {
out[i] = super::filter_paeth_decode(a[i].into(), b[i].into(), c[i].into());
}

// Uses logic from `filter::filter_paeth` to calculate absolute values
// entirely in `Simd<u8, N>`. This method avoids unpacking and packing
// penalties resulting from conversion to and from `Simd<i16, N>`.
// ```
// let pa = b.max(c) - c.min(b);
// let pb = a.max(c) - c.min(a);
// let pc = if (a < c) == (c < b) {
// pa.max(pb) - pa.min(pb)
// } else {
// 255
// };
// ```
let pa = abs_diff_simd(b, c);
let pb = abs_diff_simd(a, c);
let pc = a
.simd_lt(c)
.simd_eq(c.simd_lt(b))
.select(abs_diff_simd(pa, pb), Simd::splat(255));

let smallest = pc.simd_min(pa.simd_min(pb));

// Paeth algorithm breaks ties favoring a over b over c, so we execute the following
// lane-wise selection:
//
// if smalest == pa
// then select a
// else select (if smallest == pb then select b else select c)
smallest
.simd_eq(pa)
.select(a, smallest.simd_eq(pb).select(b, c))
out.into()
}

/// Memory of previous pixels (as needed to unfilter `FilterType::Paeth`).
Expand Down Expand Up @@ -318,32 +277,44 @@ impl Default for AdaptiveFilterType {
}
}

#[cfg(target_arch = "x86_64")]
fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
// Decoding seems to optimize better with this algorithm
let pa = (i16::from(b) - i16::from(c)).abs();
let pb = (i16::from(a) - i16::from(c)).abs();
let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();

let mut out = a;
let mut min = pa;

if pb < min {
min = pb;
out = b;
}
if pc < min {
out = c;
}

out
// Decoding optimizes better with this algorithm than with `filter_paeth()`
//
// This formulation looks very different from the reference in the PNG spec, but is
// actually equivalent and has favorable data dependencies and admits straightforward
// generation of branch-free code, which helps performance significantly.
//
// Adapted from public domain PNG implementation:
// https://github.com/nothings/stb/blob/5c205738c191bcb0abc65c4febfa9bd25ff35234/stb_image.h#L4657-L4668
let thresh = i16::from(c) * 3 - (i16::from(a) + i16::from(b));
let lo = a.min(b);
let hi = a.max(b);
let t0 = if hi as i16 <= thresh { lo } else { c };
let t1 = if thresh <= lo as i16 { hi } else { t0 };
return t1;
}

#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
fn filter_paeth_decode_i16(a: i16, b: i16, c: i16) -> i16 {
// Like `filter_paeth_decode` but vectorizes better when wrapped in SIMD
let pa = (b - c).abs();
let pb = (a - c).abs();
let pc = ((a - c) + (b - c)).abs();
// Like `filter_paeth_decode` but vectorizes better when wrapped in SIMD types.
// Used for bpp=3 and bpp=6
let thresh = c * 3 - (a + b);
let lo = a.min(b);
let hi = a.max(b);
let t0 = if hi <= thresh { lo } else { c };
let t1 = if thresh <= lo { hi } else { t0 };
return t1;
}

#[cfg(not(target_arch = "x86_64"))]
fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
// On ARM this algorithm performs much better than the one above adapted from stb,
// and this is the better-studied algorithm we've always used here,
// so we default to it on all non-x86 platforms.
let pa = (i16::from(b) - i16::from(c)).abs();
let pb = (i16::from(a) - i16::from(c)).abs();
let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();

let mut out = a;
let mut min = pa;
Expand Down Expand Up @@ -769,7 +740,8 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Three => {
#[cfg(feature = "unstable")]
// Do not enable this algorithm on ARM, that would be a big performance hit
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth3(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -797,7 +769,7 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Four => {
#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth_u8::<4>(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -828,7 +800,7 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Six => {
#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth6(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -865,7 +837,7 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Eight => {
#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth_u8::<8>(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -1160,6 +1132,19 @@ mod test {
}
}

#[test]
#[ignore] // takes ~20s without optimizations
fn paeth_impls_are_equivalent() {
use super::{filter_paeth, filter_paeth_decode};
for a in 0..=255 {
for b in 0..=255 {
for c in 0..=255 {
assert_eq!(filter_paeth(a, b, c), filter_paeth_decode(a, b, c));
}
}
}
}

#[test]
fn roundtrip_ascending_previous_line() {
// A multiple of 8, 6, 4, 3, 2, 1
Expand Down
Loading