Missed optimization: lshr(smax(x, 0), y) can be transformed to smax(ashr(x, y), 0) #109471

Open

okaneco commented Sep 20, 2024

While refactoring Rust code, I ended up with the src loop body below, which produces extra max/min instructions compared to tgt. Reordering the right shift so that it happens before the signed max with 0 produces better auto-vectorized assembly.

Because the arithmetic shift right preserves the sign, the saturating truncation instructions can handle the clamping to 0 on their own.
https://rust.godbolt.org/z/WW4sGoaf4

pub fn src(input: &[i32], output: &mut [u8]) {
    const N: usize = 2;
    for (&i, o) in input.iter().zip(output.iter_mut()) {
        *o = (i.max(0) >> N).min(255) as u8;
    }
}

pub fn tgt(input: &[i32], output: &mut [u8]) {
    const N: usize = 2;
    for (&i, o) in input.iter().zip(output.iter_mut()) {
        *o = (i >> N).max(0).min(255) as u8;
    }
}
Assembly instructions
; src                                          ; tgt
.LBB0_6:                                       .LBB0_6:
movdqu  xmm2, xmmword ptr [rdi + 4*r8]         movdqu  xmm0, xmmword ptr [rdi + 4*r8]
movdqu  xmm3, xmmword ptr [rdi + 4*r8 + 16]    movdqu  xmm1, xmmword ptr [rdi + 4*r8 + 16]
movdqa  xmm4, xmm2                             psrad   xmm0, 2
pcmpgtd xmm4, xmm0                             packssdw        xmm0, xmm0
pand    xmm4, xmm2                             packuswb        xmm0, xmm0
movdqa  xmm2, xmm3                             psrad   xmm1, 2
pcmpgtd xmm2, xmm0                             packssdw        xmm1, xmm1
pand    xmm2, xmm3                             packuswb        xmm1, xmm1
psrld   xmm4, 2                                movd    dword ptr [rdx + r8], xmm0
psrld   xmm2, 2                                movd    dword ptr [rdx + r8 + 4], xmm1
movdqa  xmm3, xmm1                             add     r8, 8
pcmpgtd xmm3, xmm4                             cmp     rsi, r8
pand    xmm4, xmm3                             jne     .LBB0_6
pandn   xmm3, xmm1                             cmp     rcx, rsi
por     xmm3, xmm4                             je      .LBB0_8
packuswb        xmm3, xmm3
packuswb        xmm3, xmm3
movdqa  xmm4, xmm1
pcmpgtd xmm4, xmm2
pand    xmm2, xmm4
pandn   xmm4, xmm1
por     xmm4, xmm2
packuswb        xmm4, xmm4
packuswb        xmm4, xmm4
movd    dword ptr [rdx + r8], xmm3
movd    dword ptr [rdx + r8 + 4], xmm4
add     r8, 8
cmp     rsi, r8
jne     .LBB0_6
cmp     rcx, rsi
je      .LBB0_8

; `src` with rustc flags `-Copt-level=3 -Ctarget-cpu=x86-64-v2`
.LBB0_6:
movdqu  xmm2, xmmword ptr [rdi + 4*r8]
movdqu  xmm3, xmmword ptr [rdi + 4*r8 + 16]
pmaxsd  xmm2, xmm0
pmaxsd  xmm3, xmm0
psrld   xmm2, 2
psrld   xmm3, 2
pminud  xmm2, xmm1
packusdw        xmm2, xmm2
packuswb        xmm2, xmm2
pminud  xmm3, xmm1
packusdw        xmm3, xmm3
packuswb        xmm3, xmm3
movd    dword ptr [rdx + r8], xmm2
movd    dword ptr [rdx + r8 + 4], xmm3
add     r8, 8
cmp     rsi, r8
jne     .LBB0_6
cmp     rcx, rsi
je      .LBB0_8
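
For a sense of why the tgt column above needs so few instructions: the saturating packs do the clamping by themselves, so no compare/blend sequence is required. Here is a minimal sketch of that inner-loop step with SSE2 intrinsics (the function name and the fixed shift of 2 are placeholders for illustration, not taken from the generated code):

#[cfg(target_arch = "x86_64")]
unsafe fn shift_clamp_4(input: &[i32; 4]) -> [u8; 4] {
    use std::arch::x86_64::*;
    let v = _mm_loadu_si128(input.as_ptr() as *const __m128i); // movdqu
    let shifted = _mm_srai_epi32::<2>(v);                      // psrad xmm, 2
    // packssdw: i32 -> i16 with signed saturation
    let packed16 = _mm_packs_epi32(shifted, shifted);
    // packuswb: i16 -> u8 with unsigned saturation, so negatives clamp to 0
    // and anything above 255 clamps to 255
    let packed8 = _mm_packus_epi16(packed16, packed16);
    (_mm_cvtsi128_si32(packed8) as u32).to_le_bytes()          // movd
}

fn main() {
    #[cfg(target_arch = "x86_64")]
    {
        // -8 >> 2 = -2 clamps to 0; 5000 >> 2 = 1250 clamps to 255
        assert_eq!(unsafe { shift_clamp_4(&[-8, 3, 1000, 5000]) }, [0, 0, 250, 255]);
    }
}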

Emitted IR - https://alive2.llvm.org/ce/z/fa8cRT

src body

vector.body:
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %0 = getelementptr inbounds i32, ptr %input.0, i64 %index
  %1 = getelementptr inbounds i8, ptr %output.0, i64 %index
  %2 = getelementptr inbounds i8, ptr %0, i64 16
  %wide.load = load <4 x i32>, ptr %0, align 4
  %wide.load8 = load <4 x i32>, ptr %2, align 4
  %3 = tail call <4 x i32> @llvm.smax.v4i32(<4 x i32> %wide.load, <4 x i32> zeroinitializer)
  %4 = tail call <4 x i32> @llvm.smax.v4i32(<4 x i32> %wide.load8, <4 x i32> zeroinitializer)
  %5 = lshr <4 x i32> %3, <i32 2, i32 2, i32 2, i32 2>
  %6 = lshr <4 x i32> %4, <i32 2, i32 2, i32 2, i32 2>
  %7 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %5, <4 x i32> <i32 255, i32 255, i32 255, i32 255>)
  %8 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %6, <4 x i32> <i32 255, i32 255, i32 255, i32 255>)
  %9 = trunc nuw <4 x i32> %7 to <4 x i8>
  %10 = trunc nuw <4 x i32> %8 to <4 x i8>
  %11 = getelementptr inbounds i8, ptr %1, i64 4
  store <4 x i8> %9, ptr %1, align 1
  store <4 x i8> %10, ptr %11, align 1
  %index.next = add nuw i64 %index, 8
  %12 = icmp eq i64 %index.next, %n.vec
  br i1 %12, label %middle.block, label %vector.body

tgt body

vector.body:
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %0 = getelementptr inbounds i32, ptr %input.0, i64 %index
  %1 = getelementptr inbounds i8, ptr %output.0, i64 %index
  %2 = getelementptr inbounds i8, ptr %0, i64 16
  %wide.load = load <4 x i32>, ptr %0, align 4
  %wide.load9 = load <4 x i32>, ptr %2, align 4
  %3 = ashr <4 x i32> %wide.load, <i32 2, i32 2, i32 2, i32 2>
  %4 = ashr <4 x i32> %wide.load9, <i32 2, i32 2, i32 2, i32 2>
  %5 = tail call <4 x i32> @llvm.smax.v4i32(<4 x i32> %3, <4 x i32> zeroinitializer)
  %6 = tail call <4 x i32> @llvm.smax.v4i32(<4 x i32> %4, <4 x i32> zeroinitializer)
  %7 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %5, <4 x i32> <i32 255, i32 255, i32 255, i32 255>)
  %8 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %6, <4 x i32> <i32 255, i32 255, i32 255, i32 255>)
  %9 = trunc nuw <4 x i32> %7 to <4 x i8>
  %10 = trunc nuw <4 x i32> %8 to <4 x i8>
  %11 = getelementptr inbounds i8, ptr %1, i64 4
  store <4 x i8> %9, ptr %1, align 1
  store <4 x i8> %10, ptr %11, align 1
  %index.next = add nuw i64 %index, 8
  %12 = icmp eq i64 %index.next, %n.vec
  br i1 %12, label %middle.block, label %vector.body

alive2 proof - https://alive2.llvm.org/ce/z/iUbk-i

define i8 @src(i32 %input, i32 %shift) {
  %1 = tail call i32 @llvm.smax.i32(i32 %input, i32 0)
  %2 = lshr i32 %1, %shift
  %3 = tail call i32 @llvm.umin.i32(i32 %2, i32 255)
  %4 = trunc nuw i32 %3 to i8
  ret i8 %4
}

define i8 @tgt(i32 %input, i32 %shift) {
  %1 = ashr i32 %input, %shift
  %2 = tail call i32 @llvm.smax.i32(i32 %1, i32 0)
  %3 = tail call i32 @llvm.umin.i32(i32 %2, i32 255)
  %4 = trunc nuw i32 %3 to i8
  ret i8 %4
}
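
As a quick cross-check of the same identity outside of Alive2, here is a scalar Rust version of both sides (helper names made up for illustration): if the input is non-negative, the logical and arithmetic shifts agree and the max with 0 is a no-op; if it is negative, both sides clamp to 0.

fn src_scalar(x: i32, shift: u32) -> u8 {
    // lshr(smax(x, 0), shift), then umin with 255 and truncate
    ((x.max(0) as u32 >> shift).min(255)) as u8
}

fn tgt_scalar(x: i32, shift: u32) -> u8 {
    // smax(ashr(x, shift), 0), then umin with 255 and truncate
    (((x >> shift).max(0) as u32).min(255)) as u8
}

fn main() {
    for &x in &[i32::MIN, -1024, -1, 0, 1, 255, 256, 1 << 20, i32::MAX] {
        for shift in 0..32 {
            assert_eq!(src_scalar(x, shift), tgt_scalar(x, shift));
        }
    }
}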

A real-world case of this came up in the Rust image-webp crate.
image-rs/image-webp#72

You can see some of the effect by right-clicking on L11 in both editors and choosing "Reveal linked code". The reordering results in about 40% fewer instructions in the main loop at label .LBB0_9 (L20 in the editors once the linked code is revealed).
https://rust.godbolt.org/z/h1nM6cG4K
