diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index d1919039..045d473c 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -3,7 +3,7 @@ // This file is autogenerated by fearless_simd_gen -use crate::{Level, Simd, SimdFrom, SimdInto, arch_types::ArchTypes, seal::Seal}; +use crate::{Level, arch_types::ArchTypes, prelude::*, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, i32x8, i32x16, mask8x16, mask8x32, mask8x64, mask16x8, mask16x16, mask16x32, mask32x4, @@ -14,8 +14,7 @@ use crate::{ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use core::ops::*; -#[doc = r#" The SIMD token for the "AVX2" and "FMA" level."#] +#[doc = "The SIMD token for the \"AVX2\" and \"FMA\" level."] #[derive(Clone, Copy, Debug)] pub struct Avx2 { pub avx2: crate::core_arch::x86::Avx2, @@ -25,10 +24,10 @@ impl Avx2 { #[doc = r""] #[doc = r" # Safety"] #[doc = r""] - #[doc = r" The AVX2 and FMA CPU feature must be available."] + #[doc = r" The AVX2 and FMA CPU features must be available."] #[inline] pub const unsafe fn new_unchecked() -> Self { - Avx2 { + Self { avx2: unsafe { crate::core_arch::x86::Avx2::new_unchecked() }, } } diff --git a/fearless_simd/src/generated/fallback.rs b/fearless_simd/src/generated/fallback.rs index 7766bfec..1701a140 100644 --- a/fearless_simd/src/generated/fallback.rs +++ b/fearless_simd/src/generated/fallback.rs @@ -3,7 +3,7 @@ // This file is autogenerated by fearless_simd_gen -use crate::{Bytes, Level, Simd, SimdInto, arch_types::ArchTypes, seal::Seal}; +use crate::{Level, arch_types::ArchTypes, prelude::*, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, i32x8, i32x16, mask8x16, mask8x32, mask8x64, mask16x8, mask16x16, mask16x32, mask32x4, @@ -74,7 +74,7 @@ impl FloatExt for f64 { libm::trunc(self) } } -#[doc = r#" The SIMD token for the "fallback" level."#] +#[doc = "The SIMD token for the \"fallback\" level."] #[derive(Clone, Copy, Debug)] pub struct Fallback { pub fallback: crate::core_arch::fallback::Fallback, @@ -82,7 +82,7 @@ pub struct Fallback { impl Fallback { #[inline] pub const fn new() -> Self { - Fallback { + Self { fallback: crate::core_arch::fallback::Fallback::new(), } } diff --git a/fearless_simd/src/generated/neon.rs b/fearless_simd/src/generated/neon.rs index dc462973..62f7c69d 100644 --- a/fearless_simd/src/generated/neon.rs +++ b/fearless_simd/src/generated/neon.rs @@ -3,7 +3,7 @@ // This file is autogenerated by fearless_simd_gen -use crate::{Level, Simd, SimdFrom, SimdInto, arch_types::ArchTypes, seal::Seal}; +use crate::{Level, arch_types::ArchTypes, prelude::*, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, i32x8, i32x16, mask8x16, mask8x32, mask8x64, mask16x8, mask16x16, mask16x32, mask32x4, @@ -11,7 +11,7 @@ use crate::{ u32x4, u32x8, u32x16, }; use core::arch::aarch64::*; -#[doc = r#" The SIMD token for the "neon" level."#] +#[doc = "The SIMD token for the \"neon\" level."] #[derive(Clone, Copy, Debug)] pub struct Neon { pub neon: crate::core_arch::aarch64::Neon, @@ -25,6 +25,44 @@ impl Neon { } } impl Seal for Neon {} +impl ArchTypes for Neon { + type f32x4 = crate::support::Aligned128; + type i8x16 = crate::support::Aligned128; + type u8x16 = crate::support::Aligned128; + type mask8x16 = crate::support::Aligned128; + type i16x8 
= crate::support::Aligned128; + type u16x8 = crate::support::Aligned128; + type mask16x8 = crate::support::Aligned128; + type i32x4 = crate::support::Aligned128; + type u32x4 = crate::support::Aligned128; + type mask32x4 = crate::support::Aligned128; + type f64x2 = crate::support::Aligned128; + type mask64x2 = crate::support::Aligned128; + type f32x8 = crate::support::Aligned256; + type i8x32 = crate::support::Aligned256; + type u8x32 = crate::support::Aligned256; + type mask8x32 = crate::support::Aligned256; + type i16x16 = crate::support::Aligned256; + type u16x16 = crate::support::Aligned256; + type mask16x16 = crate::support::Aligned256; + type i32x8 = crate::support::Aligned256; + type u32x8 = crate::support::Aligned256; + type mask32x8 = crate::support::Aligned256; + type f64x4 = crate::support::Aligned256; + type mask64x4 = crate::support::Aligned256; + type f32x16 = crate::support::Aligned512; + type i8x64 = crate::support::Aligned512; + type u8x64 = crate::support::Aligned512; + type mask8x64 = crate::support::Aligned512; + type i16x32 = crate::support::Aligned512; + type u16x32 = crate::support::Aligned512; + type mask16x32 = crate::support::Aligned512; + type i32x16 = crate::support::Aligned512; + type u32x16 = crate::support::Aligned512; + type mask32x16 = crate::support::Aligned512; + type f64x8 = crate::support::Aligned512; + type mask64x8 = crate::support::Aligned512; +} impl Simd for Neon { type f32s = f32x4; type f64s = f64x2; @@ -6712,44 +6750,6 @@ impl Simd for Neon { ) } } -impl ArchTypes for Neon { - type f32x4 = crate::support::Aligned128; - type i8x16 = crate::support::Aligned128; - type u8x16 = crate::support::Aligned128; - type mask8x16 = crate::support::Aligned128; - type i16x8 = crate::support::Aligned128; - type u16x8 = crate::support::Aligned128; - type mask16x8 = crate::support::Aligned128; - type i32x4 = crate::support::Aligned128; - type u32x4 = crate::support::Aligned128; - type mask32x4 = crate::support::Aligned128; - type f64x2 = crate::support::Aligned128; - type mask64x2 = crate::support::Aligned128; - type f32x8 = crate::support::Aligned256; - type i8x32 = crate::support::Aligned256; - type u8x32 = crate::support::Aligned256; - type mask8x32 = crate::support::Aligned256; - type i16x16 = crate::support::Aligned256; - type u16x16 = crate::support::Aligned256; - type mask16x16 = crate::support::Aligned256; - type i32x8 = crate::support::Aligned256; - type u32x8 = crate::support::Aligned256; - type mask32x8 = crate::support::Aligned256; - type f64x4 = crate::support::Aligned256; - type mask64x4 = crate::support::Aligned256; - type f32x16 = crate::support::Aligned512; - type i8x64 = crate::support::Aligned512; - type u8x64 = crate::support::Aligned512; - type mask8x64 = crate::support::Aligned512; - type i16x32 = crate::support::Aligned512; - type u16x32 = crate::support::Aligned512; - type mask16x32 = crate::support::Aligned512; - type i32x16 = crate::support::Aligned512; - type u32x16 = crate::support::Aligned512; - type mask32x16 = crate::support::Aligned512; - type f64x8 = crate::support::Aligned512; - type mask64x8 = crate::support::Aligned512; -} impl SimdFrom for f32x4 { #[inline(always)] fn simd_from(arch: float32x4_t, simd: S) -> Self { diff --git a/fearless_simd/src/generated/sse4_2.rs b/fearless_simd/src/generated/sse4_2.rs index 68167e3a..8f0ebda3 100644 --- a/fearless_simd/src/generated/sse4_2.rs +++ b/fearless_simd/src/generated/sse4_2.rs @@ -3,7 +3,7 @@ // This file is autogenerated by fearless_simd_gen -use crate::{Level, Simd, SimdFrom, 
SimdInto, arch_types::ArchTypes, seal::Seal}; +use crate::{Level, arch_types::ArchTypes, prelude::*, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, i32x8, i32x16, mask8x16, mask8x32, mask8x64, mask16x8, mask16x16, mask16x32, mask32x4, @@ -14,8 +14,7 @@ use crate::{ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use core::ops::*; -#[doc = r#" The SIMD token for the "SSE 4.2" level."#] +#[doc = "The SIMD token for the \"SSE4.2\" level."] #[derive(Clone, Copy, Debug)] pub struct Sse4_2 { pub sse4_2: crate::core_arch::x86::Sse4_2, diff --git a/fearless_simd/src/generated/wasm.rs b/fearless_simd/src/generated/wasm.rs index 63f06e58..b8550602 100644 --- a/fearless_simd/src/generated/wasm.rs +++ b/fearless_simd/src/generated/wasm.rs @@ -3,7 +3,7 @@ // This file is autogenerated by fearless_simd_gen -use crate::{Level, Simd, SimdFrom, SimdInto, arch_types::ArchTypes, seal::Seal}; +use crate::{Level, arch_types::ArchTypes, prelude::*, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, i32x8, i32x16, mask8x16, mask8x32, mask8x64, mask16x8, mask16x16, mask16x32, mask32x4, @@ -11,7 +11,7 @@ use crate::{ u32x4, u32x8, u32x16, }; use core::arch::wasm32::*; -#[doc = r#" The SIMD token for the "wasm128" level."#] +#[doc = "The SIMD token for the \"wasm128\" level."] #[derive(Clone, Copy, Debug)] pub struct WasmSimd128 { pub wasmsimd128: crate::core_arch::wasm32::WasmSimd128, @@ -82,11 +82,7 @@ impl Simd for WasmSimd128 { } #[inline] fn vectorize R, R>(self, f: F) -> R { - #[inline] - unsafe fn vectorize_simd128 R, R>(f: F) -> R { - f() - } - unsafe { vectorize_simd128(f) } + f() } #[inline(always)] fn splat_f32x4(self, val: f32) -> f32x4 { diff --git a/fearless_simd_gen/src/arch/fallback.rs b/fearless_simd_gen/src/arch/fallback.rs index e31788cd..9aab5f9f 100644 --- a/fearless_simd_gen/src/arch/fallback.rs +++ b/fearless_simd_gen/src/arch/fallback.rs @@ -60,22 +60,12 @@ pub(crate) fn translate_op(op: &str, is_float: bool) -> Option<&'static str> { } pub(crate) fn simple_intrinsic(name: &str, ty: &VecType) -> TokenStream { - let ty_prefix = arch_ty(ty); + let ty_prefix = ty.scalar.rust(ty.scalar_bits); let ident = Ident::new(name, Span::call_site()); quote! 
{#ty_prefix::#ident} } -pub(crate) fn arch_ty(ty: &VecType) -> Ident { - let scalar = match ty.scalar { - ScalarType::Float => "f", - ScalarType::Unsigned => "u", - ScalarType::Int | ScalarType::Mask => "i", - }; - let name = format!("{}{}", scalar, ty.scalar_bits); - Ident::new(&name, Span::call_site()) -} - pub(crate) fn expr(op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream { let Some(translated) = translate_op(op, ty.scalar == ScalarType::Float) else { unimplemented!("missing {op}"); diff --git a/fearless_simd_gen/src/arch/neon.rs b/fearless_simd_gen/src/arch/neon.rs index 6221e012..c5626278 100644 --- a/fearless_simd_gen/src/arch/neon.rs +++ b/fearless_simd_gen/src/arch/neon.rs @@ -41,22 +41,6 @@ fn translate_op(op: &str) -> Option<&'static str> { }) } -pub(crate) fn arch_ty(ty: &VecType) -> Ident { - let scalar = match ty.scalar { - ScalarType::Float => "float", - ScalarType::Unsigned => "uint", - ScalarType::Int | ScalarType::Mask => "int", - }; - let name = if ty.n_bits() == 256 { - format!("{}{}x{}x2_t", scalar, ty.scalar_bits, ty.len / 2) - } else if ty.n_bits() == 512 { - format!("{}{}x{}x4_t", scalar, ty.scalar_bits, ty.len / 4) - } else { - format!("{}{}x{}_t", scalar, ty.scalar_bits, ty.len) - }; - Ident::new(&name, Span::call_site()) -} - // expects args and return value in arch dialect pub(crate) fn expr(op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream { // There is no logical NOT for 64-bit, so we need this workaround. diff --git a/fearless_simd_gen/src/arch/wasm.rs b/fearless_simd_gen/src/arch/wasm.rs index 4c20b05c..781ee132 100644 --- a/fearless_simd_gen/src/arch/wasm.rs +++ b/fearless_simd_gen/src/arch/wasm.rs @@ -37,7 +37,7 @@ fn translate_op(op: &str) -> Option<&'static str> { } pub(crate) fn simple_intrinsic(name: &str, ty: &VecType) -> Ident { - let ty_prefix = arch_ty(ty); + let ty_prefix = arch_prefix(ty); let ident = Ident::new(name, Span::call_site()); Ident::new(&format!("{}_{}", ty_prefix, ident), Span::call_site()) } @@ -48,7 +48,7 @@ pub(crate) fn v128_intrinsic(name: &str) -> Ident { Ident::new(&format!("{}_{}", ty_prefix, ident), Span::call_site()) } -pub(crate) fn arch_ty(ty: &VecType) -> Ident { +pub(crate) fn arch_prefix(ty: &VecType) -> Ident { let scalar = match ty.scalar { ScalarType::Float => "f", ScalarType::Unsigned => "u", diff --git a/fearless_simd_gen/src/generic.rs b/fearless_simd_gen/src/generic.rs index d110ff81..1900e171 100644 --- a/fearless_simd_gen/src/generic.rs +++ b/fearless_simd_gen/src/generic.rs @@ -2,11 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{ToTokens, quote}; use crate::{ ops::{Op, OpSig, RefKind}, - types::{SIMD_TYPES, ScalarType, VecType}, + types::{ScalarType, VecType}, }; pub(crate) fn generic_op_name(op: &str, ty: &VecType) -> Ident { @@ -22,10 +22,6 @@ pub(crate) fn generic_op(op: &Op, ty: &VecType) -> TokenStream { let combine = generic_op_name("combine", &half); let do_half = generic_op_name(op.method, &half); let method_sig = op.simd_trait_method_sig(ty); - let method_sig = quote! { - #[inline(always)] - #method_sig - }; match op.sig { OpSig::Splat => { quote! 
{ @@ -306,18 +302,19 @@ pub(crate) fn generic_from_array( } } -pub(crate) fn generic_as_array( +pub(crate) fn generic_as_array<T: ToTokens>( method_sig: TokenStream, vec_ty: &VecType, kind: RefKind, max_block_size: usize, - arch_ty: impl Fn(&VecType) -> Ident, + arch_ty: impl Fn(&VecType) -> T, ) -> TokenStream { let rust_scalar = vec_ty.scalar.rust(vec_ty.scalar_bits); let num_scalars = vec_ty.len; let ref_tok = kind.token(); - let native_ty = vec_ty.wrapped_native_ty(arch_ty, max_block_size); + let native_ty = + vec_ty.wrapped_native_ty(|vec_ty| arch_ty(vec_ty).into_token_stream(), max_block_size); quote! { #method_sig { @@ -358,25 +355,3 @@ pub(crate) fn generic_from_bytes(method_sig: TokenStream, vec_ty: &VecType) -> T } } } - -pub(crate) fn impl_arch_types( - level_name: &str, - max_block_size: usize, - arch_ty: impl Fn(&VecType) -> Ident, -) -> TokenStream { - let mut assoc_types = vec![]; - for vec_ty in SIMD_TYPES { - let ty_ident = vec_ty.rust(); - let wrapper_ty = vec_ty.aligned_wrapper_ty(&arch_ty, max_block_size); - assoc_types.push(quote! { - type #ty_ident = #wrapper_ty; - }); - } - let level_tok = Ident::new(level_name, Span::call_site()); - - quote! { - impl ArchTypes for #level_tok { - #( #assoc_types )* - } - } -} diff --git a/fearless_simd_gen/src/level.rs b/fearless_simd_gen/src/level.rs new file mode 100644 index 00000000..e867e448 --- /dev/null +++ b/fearless_simd_gen/src/level.rs @@ -0,0 +1,239 @@ +// Copyright 2025 the Fearless_SIMD Authors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{format_ident, quote}; + +use crate::{ + generic::generic_op, + ops::{Op, ops_for_type}, + types::{SIMD_TYPES, ScalarType, VecType, type_imports}, +}; + +/// Trait implemented by each SIMD level code generator. The first group of methods must be provided by each code +/// generator; the rest are default trait methods that call into the required ones. +pub(crate) trait Level { + /// The name of this SIMD level token (e.g. `Neon` or `Sse4_2`). + fn name(&self) -> &'static str; + /// The highest vector width, in bits, that SIMD instructions can directly operate on. Operations above this width + /// will be implemented via split/combine. + fn native_width(&self) -> usize; + /// The highest bit width available to *store* vector types in. This is usually the same as [`Level::native_width`], + /// but may differ if the implementation provides wider vector types than most instructions actually operate on. For + /// instance, NEON provides tuples of vectors like `int32x4x4_t` up to 512 bits, and the fallback implementation + /// stores everything as arrays but only operates on 128-bit chunks. + fn max_block_size(&self) -> usize; + /// The names of the target features to enable within vectorized code. This goes in the + /// `#[target_feature(enable = "...")]` attribute. + /// + /// If this SIMD level is not runtime-toggleable (for instance, the fallback implementation or WASM SIMD128), + /// returns `None`. + fn enabled_target_features(&self) -> Option<&'static str>; + /// A function that takes a given vector type and returns the corresponding native vector type. For instance, + /// `f32x8` would map to `__m256` on `Avx2`, and to `[f32; 8]` on `Fallback`. This will never be passed a vector + /// type *larger* than [`Level::max_block_size`], since [`VecType::aligned_wrapper_ty`] will split those up into + /// smaller blocks. + fn arch_ty(&self, vec_ty: &VecType) -> TokenStream; + /// The docstring for this SIMD level token.
+ fn token_doc(&self) -> &'static str; + /// The full path to the `core_arch` token wrapped by this SIMD level token. + fn token_inner(&self) -> TokenStream; + + /// Any additional imports or supporting code necessary for the module (for instance, importing + /// implementation-specific functions from `core::arch`). + fn make_module_prelude(&self) -> TokenStream; + /// The body of the SIMD token's inherent `impl` block. By convention, this contains an unsafe `new_unchecked` + /// method for constructing a SIMD token that may not be supported on current hardware, or a safe `new` method for + /// constructing a SIMD token that is statically known to be supported. + fn make_impl_body(&self) -> TokenStream; + /// Generate a single operation's method on the `Simd` implementation. + fn make_method(&self, op: Op, vec_ty: &VecType) -> TokenStream; + + fn token(&self) -> Ident { + Ident::new(self.name(), Span::call_site()) + } + + fn impl_arch_types(&self) -> TokenStream { + let mut assoc_types = vec![]; + for vec_ty in SIMD_TYPES { + let ty_ident = vec_ty.rust(); + let wrapper_ty = + vec_ty.aligned_wrapper_ty(|vec_ty| self.arch_ty(vec_ty), self.max_block_size()); + assoc_types.push(quote! { + type #ty_ident = #wrapper_ty; + }); + } + let level_tok = self.token(); + + quote! { + impl ArchTypes for #level_tok { + #( #assoc_types )* + } + } + } + + /// The body of the `Simd::level` function. This can be overridden, e.g. to return `Level::baseline()` if we know a + /// higher SIMD level is statically enabled. + fn make_level_body(&self) -> TokenStream { + let level_tok = self.token(); + + quote! { + Level::#level_tok(self) + } + } + + fn make_simd_impl(&self) -> TokenStream { + let level_tok = self.token(); + let native_width = self.native_width(); + let mut methods = vec![]; + for vec_ty in SIMD_TYPES { + for op in ops_for_type(vec_ty) { + if op.sig.should_use_generic_op(vec_ty, native_width) { + methods.push(generic_op(&op, vec_ty)); + continue; + } + + let method = self.make_method(op, vec_ty); + methods.push(method); + } + } + + let vectorize_body = if let Some(target_features) = self.enabled_target_features() { + let vectorize = format_ident!("vectorize_{}", self.name().to_ascii_lowercase()); + quote! { + #[target_feature(enable = #target_features)] + #[inline] + unsafe fn #vectorize<F: FnOnce() -> R, R>(f: F) -> R { + f() + } + unsafe { #vectorize(f) } + } + } else { + // If this SIMD level doesn't do runtime feature detection/enabling, just call the inner function as-is + quote! { + f() + } + }; + + let level_body = self.make_level_body(); + + let mut assoc_types = vec![]; + for (scalar, scalar_bits) in [ + (ScalarType::Float, 32), + (ScalarType::Float, 64), + (ScalarType::Unsigned, 8), + (ScalarType::Int, 8), + (ScalarType::Unsigned, 16), + (ScalarType::Int, 16), + (ScalarType::Unsigned, 32), + (ScalarType::Int, 32), + (ScalarType::Mask, 8), + (ScalarType::Mask, 16), + (ScalarType::Mask, 32), + (ScalarType::Mask, 64), + ] { + let native_width_ty = VecType::new(scalar, scalar_bits, native_width / scalar_bits); + let name = native_width_ty.rust(); + let native_width_name = scalar.native_width_name(scalar_bits); + assoc_types.push(quote! { + type #native_width_name = #name; + }); + } + + quote!
{ + impl Simd for #level_tok { + #( #assoc_types )* + + #[inline(always)] + fn level(self) -> Level { + #level_body + } + + #[inline] + fn vectorize<F: FnOnce() -> R, R>(self, f: F) -> R { + #vectorize_body + } + + #( + #[inline(always)] + #methods + )* + } + } + } + + fn make_type_impl(&self) -> TokenStream { + let native_width = self.native_width(); + let max_block_size = self.max_block_size(); + let mut result = vec![]; + for ty in SIMD_TYPES { + let n_bits = ty.n_bits(); + // If n_bits is below our native width (e.g. 128 bits for AVX2), another module will have already + // implemented the conversion; types wider than the max block size aren't stored as a single native type, so they get no conversion. + if n_bits > max_block_size || n_bits < native_width { + continue; + } + let simd = ty.rust(); + let arch = self.arch_ty(ty); + result.push(quote! { + impl<S: Simd> SimdFrom<#arch, S> for #simd<S> { + #[inline(always)] + fn simd_from(arch: #arch, simd: S) -> Self { + Self { + val: unsafe { core::mem::transmute_copy(&arch) }, + simd + } + } + } + impl<S: Simd> From<#simd<S>> for #arch { + #[inline(always)] + fn from(value: #simd<S>) -> Self { + unsafe { core::mem::transmute_copy(&value.val) } + } + } + }); + } + quote! { + #( #result )* + } + } + + fn make_module(&self) -> TokenStream { + let level_tok = self.token(); + let token_doc = self.token_doc(); + let field_name = Ident::new(&self.name().to_ascii_lowercase(), Span::call_site()); + let token_inner = self.token_inner(); + let imports = type_imports(); + let module_prelude = self.make_module_prelude(); + let impl_body = self.make_impl_body(); + let arch_types_impl = self.impl_arch_types(); + let simd_impl = self.make_simd_impl(); + let ty_impl = self.make_type_impl(); + + quote! { + use crate::{prelude::*, seal::Seal, arch_types::ArchTypes, Level}; + + #imports + + #module_prelude + + #[doc = #token_doc] + #[derive(Clone, Copy, Debug)] + pub struct #level_tok { + pub #field_name: #token_inner, + } + + impl #level_tok { + #impl_body + } + + impl Seal for #level_tok {} + + #arch_types_impl + + #simd_impl + + #ty_impl + } + } +} diff --git a/fearless_simd_gen/src/main.rs b/fearless_simd_gen/src/main.rs index 96600f90..65e9ac0c 100644 --- a/fearless_simd_gen/src/main.rs +++ b/fearless_simd_gen/src/main.rs @@ -11,8 +11,11 @@ use std::{fs::File, io::Write, path::Path}; use clap::{Parser, ValueEnum}; use proc_macro2::TokenStream; +use crate::level::Level as _; + mod arch; mod generic; +mod level; mod mk_avx2; mod mk_fallback; mod mk_neon; @@ -59,11 +62,11 @@ impl Module { Self::SimdTypes => mk_simd_types::mk_simd_types(), Self::SimdTrait => mk_simd_trait::mk_simd_trait(), Self::Ops => mk_ops::mk_ops(), - Self::Neon => mk_neon::mk_neon_impl(mk_neon::Level::Neon), - Self::Wasm => mk_wasm::mk_wasm128_impl(mk_wasm::Level::WasmSimd128), - Self::Fallback => mk_fallback::mk_fallback_impl(), - Self::Sse4_2 => mk_sse4_2::mk_sse4_2_impl(), - Self::Avx2 => mk_avx2::mk_avx2_impl(), + Self::Neon => mk_neon::Neon.make_module(), + Self::Wasm => mk_wasm::WasmSimd128.make_module(), + Self::Fallback => mk_fallback::Fallback.make_module(), + Self::Sse4_2 => mk_sse4_2::Sse4_2.make_module(), + Self::Avx2 => mk_avx2::Avx2.make_module(), } } diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index 42b92c10..3b73e741 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -2,239 +2,153 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::arch::x86::{ - self, arch_ty, coarse_type, extend_intrinsic, intrinsic_ident, op_suffix, pack_intrinsic, + arch_ty, coarse_type, extend_intrinsic, intrinsic_ident, op_suffix, pack_intrinsic,
set1_intrinsic, simple_intrinsic, }; use crate::generic::{ generic_as_array, generic_block_combine, generic_block_split, generic_from_array, - generic_from_bytes, generic_op, generic_op_name, generic_to_bytes, impl_arch_types, + generic_from_bytes, generic_op_name, generic_to_bytes, }; +use crate::level::Level; use crate::mk_sse4_2; -use crate::ops::{Op, OpSig, ops_for_type}; -use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; -use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use crate::ops::{Op, OpSig}; +use crate::types::{ScalarType, VecType}; +use proc_macro2::TokenStream; +use quote::{ToTokens as _, quote}; #[derive(Clone, Copy)] -pub(crate) struct Level; +pub(crate) struct Avx2; -impl Level { - fn name(self) -> &'static str { +impl Level for Avx2 { + fn name(&self) -> &'static str { "Avx2" } - fn token(self) -> TokenStream { - let ident = Ident::new(self.name(), Span::call_site()); - quote! { #ident } + fn native_width(&self) -> usize { + 256 } -} -pub(crate) fn mk_avx2_impl() -> TokenStream { - let imports = type_imports(); - let arch_types_impl = impl_arch_types(Level.name(), 256, arch_ty); - let simd_impl = mk_simd_impl(); - let ty_impl = mk_type_impl(); + fn max_block_size(&self) -> usize { + 256 + } - quote! { - #[cfg(target_arch = "x86")] - use core::arch::x86::*; - #[cfg(target_arch = "x86_64")] - use core::arch::x86_64::*; + fn enabled_target_features(&self) -> Option<&'static str> { + Some("avx2,fma") + } - use core::ops::*; - use crate::{seal::Seal, arch_types::ArchTypes, Level, Simd, SimdFrom, SimdInto}; + fn arch_ty(&self, vec_ty: &VecType) -> TokenStream { + arch_ty(vec_ty).into_token_stream() + } - #imports + fn token_doc(&self) -> &'static str { + r#"The SIMD token for the "AVX2" and "FMA" level."# + } - /// The SIMD token for the "AVX2" and "FMA" level. - #[derive(Clone, Copy, Debug)] - pub struct Avx2 { - pub avx2: crate::core_arch::x86::Avx2, + fn token_inner(&self) -> TokenStream { + quote!(crate::core_arch::x86::Avx2) + } + + fn make_module_prelude(&self) -> TokenStream { + quote! { + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; } + } - impl Avx2 { + fn make_impl_body(&self) -> TokenStream { + quote! { /// Create a SIMD token. /// /// # Safety /// - /// The AVX2 and FMA CPU feature must be available. + /// The AVX2 and FMA CPU features must be available. #[inline] pub const unsafe fn new_unchecked() -> Self { - Avx2 { + Self { avx2: unsafe { crate::core_arch::x86::Avx2::new_unchecked() }, } } } - - impl Seal for Avx2 {} - - #arch_types_impl - - #simd_impl - - #ty_impl - } -} - -fn mk_simd_impl() -> TokenStream { - let level_tok = Level.token(); - let mut methods = vec![]; - for vec_ty in SIMD_TYPES { - for op in ops_for_type(vec_ty) { - if op.sig.should_use_generic_op(vec_ty, 256) { - methods.push(generic_op(&op, vec_ty)); - continue; - } - - let method = make_method(op, vec_ty); - - methods.push(method); - } } - // Note: the `vectorize` implementation is pretty boilerplate and should probably - // be factored out for DRY. - quote! 
{ - impl Simd for #level_tok { - type f32s = f32x8; - type f64s = f64x4; - type u8s = u8x32; - type i8s = i8x32; - type u16s = u16x16; - type i16s = i16x16; - type u32s = u32x8; - type i32s = i32x8; - type mask8s = mask8x32; - type mask16s = mask16x16; - type mask32s = mask32x8; - type mask64s = mask64x4; - #[inline(always)] - fn level(self) -> Level { - Level::#level_tok(self) - } + fn make_method(&self, op: Op, vec_ty: &VecType) -> TokenStream { + let scalar_bits = vec_ty.scalar_bits; + let Op { sig, method, .. } = op; + let method_sig = op.simd_trait_method_sig(vec_ty); - #[inline] - fn vectorize R, R>(self, f: F) -> R { - #[target_feature(enable = "avx2,fma")] - #[inline] - unsafe fn vectorize_avx2 R, R>(f: F) -> R { - f() - } - unsafe { vectorize_avx2(f) } + match sig { + OpSig::Splat => mk_sse4_2::handle_splat(method_sig, vec_ty), + OpSig::Compare => mk_sse4_2::handle_compare(method_sig, method, vec_ty), + OpSig::Unary => mk_sse4_2::handle_unary(method_sig, method, vec_ty), + OpSig::WidenNarrow { target_ty } => { + handle_widen_narrow(method_sig, method, vec_ty, target_ty) } - - #( #methods )* - } - } -} - -fn mk_type_impl() -> TokenStream { - let mut result = vec![]; - for ty in SIMD_TYPES { - let n_bits = ty.n_bits(); - if n_bits != 256 { - continue; - } - let simd = ty.rust(); - let arch = x86::arch_ty(ty); - result.push(quote! { - impl SimdFrom<#arch, S> for #simd { - #[inline(always)] - fn simd_from(arch: #arch, simd: S) -> Self { - Self { - val: unsafe { core::mem::transmute_copy(&arch) }, - simd + OpSig::Binary => match method { + "shlv" if scalar_bits >= 32 => handle_shift_vectored(method_sig, method, vec_ty), + "shrv" if scalar_bits >= 32 => handle_shift_vectored(method_sig, method, vec_ty), + _ => mk_sse4_2::handle_binary(method_sig, method, vec_ty), + }, + OpSig::Shift => mk_sse4_2::handle_shift(method_sig, method, vec_ty), + OpSig::Ternary => match method { + "mul_add" => { + let intrinsic = simple_intrinsic("fmadd", vec_ty); + quote! { + #method_sig { + unsafe { #intrinsic(a.into(), b.into(), c.into()).simd_into(self) } + } } } - } - impl From<#simd> for #arch { - #[inline(always)] - fn from(value: #simd) -> Self { - unsafe { core::mem::transmute_copy(&value.val) } - } - } - }); - } - quote! { - #( #result )* - } -} - -fn make_method(op: Op, vec_ty: &VecType) -> TokenStream { - let scalar_bits = vec_ty.scalar_bits; - let Op { sig, method, .. } = op; - let method_sig = op.simd_trait_method_sig(vec_ty); - let method_sig = quote! { - #[inline(always)] - #method_sig - }; - - match sig { - OpSig::Splat => mk_sse4_2::handle_splat(method_sig, vec_ty), - OpSig::Compare => mk_sse4_2::handle_compare(method_sig, method, vec_ty), - OpSig::Unary => mk_sse4_2::handle_unary(method_sig, method, vec_ty), - OpSig::WidenNarrow { target_ty } => { - handle_widen_narrow(method_sig, method, vec_ty, target_ty) - } - OpSig::Binary => match method { - "shlv" if scalar_bits >= 32 => handle_shift_vectored(method_sig, method, vec_ty), - "shrv" if scalar_bits >= 32 => handle_shift_vectored(method_sig, method, vec_ty), - _ => mk_sse4_2::handle_binary(method_sig, method, vec_ty), - }, - OpSig::Shift => mk_sse4_2::handle_shift(method_sig, method, vec_ty), - OpSig::Ternary => match method { - "mul_add" => { - let intrinsic = simple_intrinsic("fmadd", vec_ty); - quote! { - #method_sig { - unsafe { #intrinsic(a.into(), b.into(), c.into()).simd_into(self) } + "mul_sub" => { + let intrinsic = simple_intrinsic("fmsub", vec_ty); + quote! 
{ + #method_sig { + unsafe { #intrinsic(a.into(), b.into(), c.into()).simd_into(self) } + } } } + _ => mk_sse4_2::handle_ternary(method_sig, method, vec_ty), + }, + OpSig::Select => mk_sse4_2::handle_select(method_sig, vec_ty), + OpSig::Combine { combined_ty } => handle_combine(method_sig, vec_ty, &combined_ty), + OpSig::Split { half_ty } => handle_split(method_sig, vec_ty, &half_ty), + OpSig::Zip { select_low } => mk_sse4_2::handle_zip(method_sig, vec_ty, select_low), + OpSig::Unzip { select_even } => { + mk_sse4_2::handle_unzip(method_sig, vec_ty, select_even) } - "mul_sub" => { - let intrinsic = simple_intrinsic("fmsub", vec_ty); - quote! { - #method_sig { - unsafe { #intrinsic(a.into(), b.into(), c.into()).simd_into(self) } - } - } + OpSig::Cvt { + target_ty, + scalar_bits, + precise, + } => mk_sse4_2::handle_cvt(method_sig, vec_ty, target_ty, scalar_bits, precise), + OpSig::Reinterpret { + target_ty, + scalar_bits, + } => mk_sse4_2::handle_reinterpret(self, method_sig, vec_ty, target_ty, scalar_bits), + OpSig::MaskReduce { + quantifier, + condition, + } => mk_sse4_2::handle_mask_reduce(method_sig, vec_ty, quantifier, condition), + OpSig::LoadInterleaved { + block_size, + block_count, + } => mk_sse4_2::handle_load_interleaved(method_sig, vec_ty, block_size, block_count), + OpSig::StoreInterleaved { + block_size, + block_count, + } => mk_sse4_2::handle_store_interleaved(method_sig, vec_ty, block_size, block_count), + OpSig::FromArray { kind } => { + generic_from_array(method_sig, vec_ty, kind, 256, |block_ty| { + intrinsic_ident("loadu", coarse_type(block_ty), block_ty.n_bits()) + }) + } + OpSig::AsArray { kind } => { + generic_as_array(method_sig, vec_ty, kind, 256, |vec_ty| self.arch_ty(vec_ty)) } - _ => mk_sse4_2::handle_ternary(method_sig, method, vec_ty), - }, - OpSig::Select => mk_sse4_2::handle_select(method_sig, vec_ty), - OpSig::Combine { combined_ty } => handle_combine(method_sig, vec_ty, &combined_ty), - OpSig::Split { half_ty } => handle_split(method_sig, vec_ty, &half_ty), - OpSig::Zip { select_low } => mk_sse4_2::handle_zip(method_sig, vec_ty, select_low), - OpSig::Unzip { select_even } => mk_sse4_2::handle_unzip(method_sig, vec_ty, select_even), - OpSig::Cvt { - target_ty, - scalar_bits, - precise, - } => mk_sse4_2::handle_cvt(method_sig, vec_ty, target_ty, scalar_bits, precise), - OpSig::Reinterpret { - target_ty, - scalar_bits, - } => mk_sse4_2::handle_reinterpret(method_sig, vec_ty, target_ty, scalar_bits), - OpSig::MaskReduce { - quantifier, - condition, - } => mk_sse4_2::handle_mask_reduce(method_sig, vec_ty, quantifier, condition), - OpSig::LoadInterleaved { - block_size, - block_count, - } => mk_sse4_2::handle_load_interleaved(method_sig, vec_ty, block_size, block_count), - OpSig::StoreInterleaved { - block_size, - block_count, - } => mk_sse4_2::handle_store_interleaved(method_sig, vec_ty, block_size, block_count), - OpSig::FromArray { kind } => { - generic_from_array(method_sig, vec_ty, kind, 256, |block_ty| { - intrinsic_ident("loadu", coarse_type(block_ty), block_ty.n_bits()) - }) + OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), + OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), } - OpSig::AsArray { kind } => generic_as_array(method_sig, vec_ty, kind, 256, arch_ty), - OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), - OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), } } diff --git a/fearless_simd_gen/src/mk_fallback.rs b/fearless_simd_gen/src/mk_fallback.rs index 2de2db72..cd253243 100644 --- 
a/fearless_simd_gen/src/mk_fallback.rs +++ b/fearless_simd_gen/src/mk_fallback.rs @@ -2,565 +2,499 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::arch::fallback; -use crate::generic::{generic_from_bytes, generic_op, generic_op_name, generic_to_bytes}; -use crate::ops::{Op, OpSig, RefKind, ops_for_type, valid_reinterpret}; -use crate::types::{SIMD_TYPES, ScalarType, type_imports}; -use proc_macro2::{Ident, Span, TokenStream}; +use crate::generic::{generic_from_bytes, generic_op_name, generic_to_bytes}; +use crate::level::Level; +use crate::ops::{Op, OpSig, RefKind, valid_reinterpret}; +use crate::types::{ScalarType, VecType}; +use proc_macro2::TokenStream; use quote::quote; #[derive(Clone, Copy)] -pub(crate) struct Level; +pub(crate) struct Fallback; -impl Level { - fn name(self) -> &'static str { +impl Level for Fallback { + fn name(&self) -> &'static str { "Fallback" } - fn token(self) -> TokenStream { - let ident = Ident::new(self.name(), Span::call_site()); - quote! { #ident } + fn native_width(&self) -> usize { + 128 } -} -pub(crate) fn mk_fallback_impl() -> TokenStream { - let imports = type_imports(); - let arch_types_impl = mk_arch_types(); - let simd_impl = mk_simd_impl(); - - quote! { - use core::ops::*; - use crate::{seal::Seal, arch_types::ArchTypes, Bytes, Level, Simd, SimdInto}; - - #imports - - #[cfg(all(feature = "libm", not(feature = "std")))] - trait FloatExt { - fn floor(self) -> Self; - fn ceil(self) -> Self; - fn round_ties_even(self) -> Self; - fn fract(self) -> Self; - fn sqrt(self) -> Self; - fn trunc(self) -> Self; - } - #[cfg(all(feature = "libm", not(feature = "std")))] - impl FloatExt for f32 { - #[inline(always)] - fn floor(self) -> f32 { - libm::floorf(self) - } - #[inline(always)] - fn ceil(self) -> f32 { - libm::ceilf(self) - } - #[inline(always)] - fn round_ties_even(self) -> f32 { - libm::rintf(self) - } - #[inline(always)] - fn sqrt(self) -> f32 { - libm::sqrtf(self) - } - #[inline(always)] - fn fract(self) -> f32 { - self - self.trunc() - } - #[inline(always)] - fn trunc(self) -> f32 { - libm::truncf(self) - } - } + fn max_block_size(&self) -> usize { + 512 + } - #[cfg(all(feature = "libm", not(feature = "std")))] - impl FloatExt for f64 { - #[inline(always)] - fn floor(self) -> f64 { - libm::floor(self) - } - #[inline(always)] - fn ceil(self) -> f64 { - libm::ceil(self) - } - #[inline(always)] - fn round_ties_even(self) -> f64 { - libm::rint(self) - } - #[inline(always)] - fn sqrt(self) -> f64 { - libm::sqrt(self) + fn enabled_target_features(&self) -> Option<&'static str> { + None + } + + fn arch_ty(&self, vec_ty: &VecType) -> TokenStream { + let scalar_rust = vec_ty.scalar.rust(vec_ty.scalar_bits); + let len = vec_ty.len; + quote!([#scalar_rust; #len]) + } + + fn token_doc(&self) -> &'static str { + r#"The SIMD token for the "fallback" level."# + } + + fn token_inner(&self) -> TokenStream { + quote!(crate::core_arch::fallback::Fallback) + } + + fn make_module_prelude(&self) -> TokenStream { + quote! 
{ + use core::ops::*; + + #[cfg(all(feature = "libm", not(feature = "std")))] + trait FloatExt { + fn floor(self) -> Self; + fn ceil(self) -> Self; + fn round_ties_even(self) -> Self; + fn fract(self) -> Self; + fn sqrt(self) -> Self; + fn trunc(self) -> Self; } - #[inline(always)] - fn fract(self) -> f64 { - self - self.trunc() + #[cfg(all(feature = "libm", not(feature = "std")))] + impl FloatExt for f32 { + #[inline(always)] + fn floor(self) -> f32 { + libm::floorf(self) + } + #[inline(always)] + fn ceil(self) -> f32 { + libm::ceilf(self) + } + #[inline(always)] + fn round_ties_even(self) -> f32 { + libm::rintf(self) + } + #[inline(always)] + fn sqrt(self) -> f32 { + libm::sqrtf(self) + } + #[inline(always)] + fn fract(self) -> f32 { + self - self.trunc() + } + #[inline(always)] + fn trunc(self) -> f32 { + libm::truncf(self) + } } - #[inline(always)] - fn trunc(self) -> f64 { - libm::trunc(self) + + #[cfg(all(feature = "libm", not(feature = "std")))] + impl FloatExt for f64 { + #[inline(always)] + fn floor(self) -> f64 { + libm::floor(self) + } + #[inline(always)] + fn ceil(self) -> f64 { + libm::ceil(self) + } + #[inline(always)] + fn round_ties_even(self) -> f64 { + libm::rint(self) + } + #[inline(always)] + fn sqrt(self) -> f64 { + libm::sqrt(self) + } + #[inline(always)] + fn fract(self) -> f64 { + self - self.trunc() + } + #[inline(always)] + fn trunc(self) -> f64 { + libm::trunc(self) + } } } + } - /// The SIMD token for the "fallback" level. - #[derive(Clone, Copy, Debug)] - pub struct Fallback { - pub fallback: crate::core_arch::fallback::Fallback, + fn make_level_body(&self) -> TokenStream { + let level_tok = Self.token(); + quote! { + #[cfg(feature = "force_support_fallback")] + return Level::#level_tok(self); + #[cfg(not(feature = "force_support_fallback"))] + Level::baseline() } + } - impl Fallback { + fn make_impl_body(&self) -> TokenStream { + quote! { #[inline] pub const fn new() -> Self { - Fallback { + Self { fallback: crate::core_arch::fallback::Fallback::new(), } } } - - impl Seal for Fallback {} - - #arch_types_impl - - #simd_impl } -} -fn mk_arch_types() -> TokenStream { - // We can't use the generic version, because the fallback implementation is the only one that doesn't provide native - // vector types and instead uses plain arrays - let mut arch_types = vec![]; - for vec_ty in SIMD_TYPES { - let ty_ident = vec_ty.rust(); - let scalar_rust = vec_ty.scalar.rust(vec_ty.scalar_bits); - let len = vec_ty.len; - let wrapper_name = vec_ty.aligned_wrapper(); - arch_types.push(quote! { - type #ty_ident = #wrapper_name<[#scalar_rust; #len]>; - }); - } + fn make_method(&self, op: Op, vec_ty: &VecType) -> TokenStream { + let Op { sig, method, .. } = op; + let method_sig = op.simd_trait_method_sig(vec_ty); - quote! { - impl ArchTypes for Fallback { - #( #arch_types )* - } - } -} - -fn mk_simd_impl() -> TokenStream { - let level_tok = Level.token(); - let mut methods = vec![]; - for vec_ty in SIMD_TYPES { - let scalar_bits = vec_ty.scalar_bits; - for op in ops_for_type(vec_ty) { - let Op { sig, method, .. } = op; - if sig.should_use_generic_op(vec_ty, 128) { - methods.push(generic_op(&op, vec_ty)); - continue; + match sig { + OpSig::Splat => { + let num_elements = vec_ty.len; + quote! { + #method_sig { + [val; #num_elements].simd_into(self) + } + } } - let method_sig = op.simd_trait_method_sig(vec_ty); - let method_sig = quote! { - #[inline(always)] - #method_sig - }; - - let method = match sig { - OpSig::Splat => { - let num_elements = vec_ty.len; - quote! 
{ - #method_sig { - [val; #num_elements].simd_into(self) - } + OpSig::Unary => { + let items = make_list( + (0..vec_ty.len) + .map(|idx| { + let args = [quote! { a[#idx] }]; + let expr = fallback::expr(method, vec_ty, &args); + quote! { #expr } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #items.simd_into(self) } } - OpSig::Unary => { - let items = make_list( - (0..vec_ty.len) - .map(|idx| { - let args = [quote! { a[#idx] }]; - let expr = fallback::expr(method, vec_ty, &args); - quote! { #expr } - }) - .collect::>(), - ); - - quote! { - #method_sig { - #items.simd_into(self) - } + } + OpSig::WidenNarrow { target_ty } => { + let items = make_list( + (0..vec_ty.len) + .map(|idx| { + let scalar_ty = target_ty.scalar.rust(target_ty.scalar_bits); + quote! { a[#idx] as #scalar_ty } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #items.simd_into(self) } } - OpSig::WidenNarrow { target_ty } => { - let items = make_list( - (0..vec_ty.len) - .map(|idx| { - let scalar_ty = target_ty.scalar.rust(target_ty.scalar_bits); - quote! { a[#idx] as #scalar_ty } - }) - .collect::>(), - ); - + } + OpSig::Binary => { + let items = make_list( + (0..vec_ty.len) + .map(|idx| { + let b = if fallback::translate_op( + method, + vec_ty.scalar == ScalarType::Float, + ) + .map(rhs_reference) + .unwrap_or(true) + { + quote! { &b[#idx] } + } else { + quote! { b[#idx] } + }; + + let args = [quote! { a[#idx] }, quote! { #b }]; + let expr = fallback::expr(method, vec_ty, &args); + quote! { #expr } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #items.simd_into(self) + } + } + } + OpSig::Shift => { + let rust_scalar = vec_ty.scalar.rust(vec_ty.scalar_bits); + let items = make_list( + (0..vec_ty.len) + .map(|idx| { + let args = [quote! { a[#idx] }, quote! { shift as #rust_scalar }]; + let expr = fallback::expr(method, vec_ty, &args); + quote! { #expr } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #items.simd_into(self) + } + } + } + OpSig::Ternary => { + if method == "mul_add" { quote! { #method_sig { - #items.simd_into(self) + a.mul(b).add(c) } } - } - OpSig::Binary => { - let items = make_list( - (0..vec_ty.len) - .map(|idx| { - let b = if fallback::translate_op( - method, - vec_ty.scalar == ScalarType::Float, - ) - .map(rhs_reference) - .unwrap_or(true) - { - quote! { &b[#idx] } - } else { - quote! { b[#idx] } - }; - - let args = [quote! { a[#idx] }, quote! { #b }]; - let expr = fallback::expr(method, vec_ty, &args); - quote! { #expr } - }) - .collect::>(), - ); - + } else if method == "mul_sub" { quote! { #method_sig { - #items.simd_into(self) + a.mul(b).sub(c) } } - } - OpSig::Shift => { - let arch_ty = fallback::arch_ty(vec_ty); - let items = make_list( - (0..vec_ty.len) - .map(|idx| { - let args = [quote! { a[#idx] }, quote! { shift as #arch_ty }]; - let expr = fallback::expr(method, vec_ty, &args); - quote! { #expr } - }) - .collect::>(), - ); - + } else { + let args = [ + quote! { a.into() }, + quote! { b.into() }, + quote! { c.into() }, + ]; + + let expr = fallback::expr(method, vec_ty, &args); quote! { #method_sig { - #items.simd_into(self) + #expr.simd_into(self) } } } - OpSig::Ternary => { - if method == "mul_add" { - quote! { - #method_sig { - a.mul(b).add(c) - } - } - } else if method == "mul_sub" { - quote! { - #method_sig { - a.mul(b).sub(c) - } - } - } else { - let args = [ - quote! { a.into() }, - quote! { b.into() }, - quote! { c.into() }, - ]; - - let expr = fallback::expr(method, vec_ty, &args); - quote! 
{ - #method_sig { - #expr.simd_into(self) - } - } + } + OpSig::Compare => { + let mask_type = vec_ty.cast(ScalarType::Mask); + let items = make_list( + (0..vec_ty.len) + .map(|idx: usize| { + let args = [quote! { &a[#idx] }, quote! { &b[#idx] }]; + let expr = fallback::expr(method, vec_ty, &args); + let mask_ty = mask_type.scalar.rust(vec_ty.scalar_bits); + quote! { -(#expr as #mask_ty) } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #items.simd_into(self) } } - OpSig::Compare => { - let mask_type = vec_ty.cast(ScalarType::Mask); - let items = make_list( - (0..vec_ty.len) - .map(|idx: usize| { - let args = [quote! { &a[#idx] }, quote! { &b[#idx] }]; - let expr = fallback::expr(method, vec_ty, &args); - let mask_ty = mask_type.scalar.rust(scalar_bits); - quote! { -(#expr as #mask_ty) } - }) - .collect::>(), - ); - - quote! { - #method_sig { - #items.simd_into(self) - } + } + OpSig::Select => { + let items = make_list( + (0..vec_ty.len) + .map(|idx| { + quote! { if a[#idx] != 0 { b[#idx] } else { c[#idx] } } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #items.simd_into(self) } } - OpSig::Select => { - let items = make_list( - (0..vec_ty.len) - .map(|idx| { - quote! { if a[#idx] != 0 { b[#idx] } else { c[#idx] } } - }) - .collect::>(), - ); - - quote! { - #method_sig { - #items.simd_into(self) - } + } + OpSig::Combine { combined_ty } => { + let n = vec_ty.len; + let n2 = combined_ty.len; + let default = match vec_ty.scalar { + ScalarType::Float => quote! { 0.0 }, + _ => quote! { 0 }, + }; + quote! { + #method_sig { + let mut result = [#default; #n2]; + result[0..#n].copy_from_slice(&a.val.0); + result[#n..#n2].copy_from_slice(&b.val.0); + result.simd_into(self) } } - OpSig::Combine { combined_ty } => { - let n = vec_ty.len; - let n2 = combined_ty.len; - let ty_rust = vec_ty.rust(); - let result = combined_ty.rust(); - let name = Ident::new( - &format!("combine_{}", vec_ty.rust_name()), - Span::call_site(), - ); - let default = match vec_ty.scalar { - ScalarType::Float => quote! { 0.0 }, - _ => quote! { 0 }, - }; - quote! { - #[inline(always)] - fn #name(self, a: #ty_rust, b: #ty_rust) -> #result { - let mut result = [#default; #n2]; - result[0..#n].copy_from_slice(&a.val.0); - result[#n..#n2].copy_from_slice(&b.val.0); - result.simd_into(self) - } + } + OpSig::Split { half_ty } => { + let n = vec_ty.len; + let nhalf = half_ty.len; + let default = match vec_ty.scalar { + ScalarType::Float => quote! { 0.0 }, + _ => quote! { 0 }, + }; + quote! { + #method_sig { + let mut b0 = [#default; #nhalf]; + let mut b1 = [#default; #nhalf]; + b0.copy_from_slice(&a.val.0[0..#nhalf]); + b1.copy_from_slice(&a.val.0[#nhalf..#n]); + (b0.simd_into(self), b1.simd_into(self)) } } - OpSig::Split { half_ty } => { - let n = vec_ty.len; - let nhalf = half_ty.len; - let ty_rust = vec_ty.rust(); - let result = half_ty.rust(); - let name = - Ident::new(&format!("split_{}", vec_ty.rust_name()), Span::call_site()); - let default = match vec_ty.scalar { - ScalarType::Float => quote! { 0.0 }, - _ => quote! { 0 }, - }; - quote! { - #[inline(always)] - fn #name(self, a: #ty_rust) -> (#result, #result) { - let mut b0 = [#default; #nhalf]; - let mut b1 = [#default; #nhalf]; - b0.copy_from_slice(&a.val.0[0..#nhalf]); - b1.copy_from_slice(&a.val.0[#nhalf..#n]); - (b0.simd_into(self), b1.simd_into(self)) - } + } + OpSig::Zip { select_low } => { + let indices = if select_low { + 0..vec_ty.len / 2 + } else { + (vec_ty.len / 2)..vec_ty.len + }; + + let zip = make_list( + indices + .map(|idx| { + quote! 
{a[#idx], b[#idx] } + }) + .collect::>(), + ); + + quote! { + #method_sig { + #zip.simd_into(self) } } - OpSig::Zip { select_low } => { - let indices = if select_low { - 0..vec_ty.len / 2 - } else { - (vec_ty.len / 2)..vec_ty.len - }; - - let zip = make_list( - indices - .map(|idx| { - quote! {a[#idx], b[#idx] } - }) - .collect::>(), - ); - + } + OpSig::Unzip { select_even } => { + let indices = if select_even { + (0..vec_ty.len).step_by(2) + } else { + (1..vec_ty.len).step_by(2) + }; + + let unzip = make_list( + indices + .clone() + .map(|idx| { + quote! {a[#idx]} + }) + .chain(indices.map(|idx| { + quote! {b[#idx]} + })) + .collect::>(), + ); + + quote! { + #method_sig { + #unzip.simd_into(self) + } + } + } + OpSig::Cvt { + target_ty, + scalar_bits, + precise, + } => { + if precise { + let non_precise = + generic_op_name(method.strip_suffix("_precise").unwrap(), vec_ty); quote! { #method_sig { - #zip.simd_into(self) + self.#non_precise(a) } } - } - OpSig::Unzip { select_even } => { - let indices = if select_even { - (0..vec_ty.len).step_by(2) - } else { - (1..vec_ty.len).step_by(2) - }; - - let unzip = make_list( - indices - .clone() + } else { + let to_ty = vec_ty.reinterpret(target_ty, scalar_bits); + let scalar = to_ty.scalar.rust(scalar_bits); + let items = make_list( + (0..vec_ty.len) .map(|idx| { - quote! {a[#idx]} + quote! { a[#idx] as #scalar } }) - .chain(indices.map(|idx| { - quote! {b[#idx]} - })) .collect::>(), ); - quote! { #method_sig { - #unzip.simd_into(self) + #items.simd_into(self) } } } - OpSig::Cvt { - target_ty, - scalar_bits, - precise, - } => { - if precise { - let non_precise = - generic_op_name(method.strip_suffix("_precise").unwrap(), vec_ty); - quote! { - #method_sig { - self.#non_precise(a) - } - } - } else { - let to_ty = vec_ty.reinterpret(target_ty, scalar_bits); - let scalar = to_ty.scalar.rust(scalar_bits); - let items = make_list( - (0..vec_ty.len) - .map(|idx| { - quote! { a[#idx] as #scalar } - }) - .collect::>(), - ); - quote! { - #method_sig { - #items.simd_into(self) - } + } + OpSig::Reinterpret { + target_ty, + scalar_bits, + } => { + if valid_reinterpret(vec_ty, target_ty, scalar_bits) { + quote! { + #method_sig { + a.bitcast() } } + } else { + quote! {} } - OpSig::Reinterpret { - target_ty, - scalar_bits, - } => { - if valid_reinterpret(vec_ty, target_ty, scalar_bits) { - quote! { - #method_sig { - a.bitcast() - } - } - } else { - quote! {} + } + OpSig::MaskReduce { + quantifier, + condition, + } => { + let indices = (0..vec_ty.len).map(|idx| quote! { #idx }); + let check = if condition { + quote! { != } + } else { + quote! { == } + }; + + let expr = match quantifier { + crate::ops::Quantifier::Any => { + quote! { #(a[#indices] #check 0)||* } } - } - OpSig::MaskReduce { - quantifier, - condition, - } => { - let indices = (0..vec_ty.len).map(|idx| quote! { #idx }); - let check = if condition { - quote! { != } - } else { - quote! { == } - }; - - let expr = match quantifier { - crate::ops::Quantifier::Any => { - quote! { #(a[#indices] #check 0)||* } - } - crate::ops::Quantifier::All => { - quote! { #(a[#indices] #check 0)&&* } - } - }; + crate::ops::Quantifier::All => { + quote! { #(a[#indices] #check 0)&&* } + } + }; - quote! { - #method_sig { - #expr - } + quote! { + #method_sig { + #expr } } - OpSig::LoadInterleaved { - block_size, - block_count, - } => { - let len = (block_size * block_count) as usize / vec_ty.scalar_bits; - let items = - interleave_indices(len, block_count as usize, |idx| quote! { src[#idx] }); - - quote! 
{ - #method_sig { - #items.simd_into(self) - } + } + OpSig::LoadInterleaved { + block_size, + block_count, + } => { + let len = (block_size * block_count) as usize / vec_ty.scalar_bits; + let items = + interleave_indices(len, block_count as usize, |idx| quote! { src[#idx] }); + + quote! { + #method_sig { + #items.simd_into(self) } } - OpSig::StoreInterleaved { - block_size, - block_count, - } => { - let len = (block_size * block_count) as usize / vec_ty.scalar_bits; - let items = interleave_indices( - len, - len / block_count as usize, - |idx| quote! { a[#idx] }, - ); - - quote! { - #method_sig { - *dest = #items; - } + } + OpSig::StoreInterleaved { + block_size, + block_count, + } => { + let len = (block_size * block_count) as usize / vec_ty.scalar_bits; + let items = + interleave_indices(len, len / block_count as usize, |idx| quote! { a[#idx] }); + + quote! { + #method_sig { + *dest = #items; } } - OpSig::FromArray { kind } => { - let vec_rust = vec_ty.rust(); - let wrapper = vec_ty.aligned_wrapper(); - let expr = match kind { - RefKind::Value => quote! { val }, - RefKind::Ref | RefKind::Mut => quote! { *val }, - }; - quote! { - #method_sig { - #vec_rust { val: #wrapper(#expr), simd: self } - } + } + OpSig::FromArray { kind } => { + let vec_rust = vec_ty.rust(); + let wrapper = vec_ty.aligned_wrapper(); + let expr = match kind { + RefKind::Value => quote! { val }, + RefKind::Ref | RefKind::Mut => quote! { *val }, + }; + quote! { + #method_sig { + #vec_rust { val: #wrapper(#expr), simd: self } } } - OpSig::AsArray { kind } => { - let ref_tok = kind.token(); - quote! { - #method_sig { - #ref_tok a.val.0 - } + } + OpSig::AsArray { kind } => { + let ref_tok = kind.token(); + quote! { + #method_sig { + #ref_tok a.val.0 } } - OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), - OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), - }; - methods.push(method); + } + OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), + OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), } } - // Note: the `vectorize` implementation is pretty boilerplate and should probably - // be factored out for DRY. - quote! 
{ - impl Simd for #level_tok { - type f32s = f32x4; - type f64s = f64x2; - type u8s = u8x16; - type i8s = i8x16; - type u16s = u16x8; - type i16s = i16x8; - type u32s = u32x4; - type i32s = i32x4; - type mask8s = mask8x16; - type mask16s = mask16x8; - type mask32s = mask32x4; - type mask64s = mask64x2; - #[inline(always)] - fn level(self) -> Level { - #[cfg(feature = "force_support_fallback")] - return Level::#level_tok(self); - #[cfg(not(feature = "force_support_fallback"))] - Level::baseline() - } - - #[inline] - fn vectorize R, R>(self, f: F) -> R { - f() - } - - #( #methods )* - } + fn make_type_impl(&self) -> TokenStream { + TokenStream::new() } } diff --git a/fearless_simd_gen/src/mk_neon.rs b/fearless_simd_gen/src/mk_neon.rs index 2fed692a..2e4af67a 100644 --- a/fearless_simd_gen/src/mk_neon.rs +++ b/fearless_simd_gen/src/mk_neon.rs @@ -2,60 +2,72 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use proc_macro2::{Ident, Literal, Span, TokenStream}; -use quote::{format_ident, quote}; +use quote::{ToTokens as _, format_ident, quote}; use crate::arch::neon::load_intrinsic; use crate::generic::{ generic_as_array, generic_from_array, generic_from_bytes, generic_op_name, generic_to_bytes, - impl_arch_types, }; +use crate::level::Level; use crate::ops::{Op, valid_reinterpret}; use crate::{ - arch::neon::{self, arch_ty, cvt_intrinsic, simple_intrinsic, split_intrinsic}, - generic::generic_op, - ops::{OpSig, ops_for_type}, - types::{SIMD_TYPES, ScalarType, VecType, type_imports}, + arch::neon::{self, cvt_intrinsic, simple_intrinsic, split_intrinsic}, + ops::OpSig, + types::{ScalarType, VecType}, }; #[derive(Clone, Copy)] -pub(crate) enum Level { - Neon, - // TODO: Fp16, -} +pub(crate) struct Neon; -impl Level { - fn name(self) -> &'static str { - match self { - Self::Neon => "Neon", - } +impl Level for Neon { + fn name(&self) -> &'static str { + "Neon" } - fn token(self) -> TokenStream { - let ident = Ident::new(self.name(), Span::call_site()); - quote! { #ident } + fn native_width(&self) -> usize { + 128 } -} -pub(crate) fn mk_neon_impl(level: Level) -> TokenStream { - let imports = type_imports(); - let arch_types_impl = impl_arch_types(level.name(), 512, arch_ty); - let simd_impl = mk_simd_impl(level); - let ty_impl = mk_type_impl(); + fn max_block_size(&self) -> usize { + 512 + } - quote! { - use core::arch::aarch64::*; + fn enabled_target_features(&self) -> Option<&'static str> { + Some("neon") + } - use crate::{seal::Seal, arch_types::ArchTypes, Level, Simd, SimdFrom, SimdInto}; + fn arch_ty(&self, vec_ty: &VecType) -> TokenStream { + let scalar = match vec_ty.scalar { + ScalarType::Float => "float", + ScalarType::Unsigned => "uint", + ScalarType::Int | ScalarType::Mask => "int", + }; + let name = if vec_ty.n_bits() == 256 { + format!("{}{}x{}x2_t", scalar, vec_ty.scalar_bits, vec_ty.len / 2) + } else if vec_ty.n_bits() == 512 { + format!("{}{}x{}x4_t", scalar, vec_ty.scalar_bits, vec_ty.len / 4) + } else { + format!("{}{}x{}_t", scalar, vec_ty.scalar_bits, vec_ty.len) + }; + Ident::new(&name, Span::call_site()).into_token_stream() + } + + fn token_doc(&self) -> &'static str { + r#"The SIMD token for the "neon" level."# + } - #imports + fn token_inner(&self) -> TokenStream { + quote!(crate::core_arch::aarch64::Neon) + } - /// The SIMD token for the "neon" level. - #[derive(Clone, Copy, Debug)] - pub struct Neon { - pub neon: crate::core_arch::aarch64::Neon, + fn make_module_prelude(&self) -> TokenStream { + quote! 
{ + use core::arch::aarch64::*; } + } - impl Neon { + fn make_impl_body(&self) -> TokenStream { + quote! { #[inline] pub const unsafe fn new_unchecked() -> Self { Neon { @@ -63,497 +75,406 @@ pub(crate) fn mk_neon_impl(level: Level) -> TokenStream { } } } - - impl Seal for Neon {} - - #simd_impl - - #arch_types_impl - - #ty_impl } -} - -fn mk_simd_impl(level: Level) -> TokenStream { - let level_tok = level.token(); - let mut methods = vec![]; - for vec_ty in SIMD_TYPES { - let scalar_bits = vec_ty.scalar_bits; - for op in ops_for_type(vec_ty) { - let Op { sig, method, .. } = op; - if sig.should_use_generic_op(vec_ty, 128) { - methods.push(generic_op(&op, vec_ty)); - continue; + fn make_method(&self, op: Op, vec_ty: &VecType) -> TokenStream { + let Op { sig, method, .. } = op; + let method_sig = op.simd_trait_method_sig(vec_ty); + + match sig { + OpSig::Splat => { + let expr = neon::expr(method, vec_ty, &[quote! { val }]); + quote! { + #method_sig { + unsafe { + #expr.simd_into(self) + } + } + } + } + OpSig::Shift => { + let dup_type = vec_ty.cast(ScalarType::Int); + let scalar = dup_type.scalar.rust(dup_type.scalar_bits); + let dup_intrinsic = split_intrinsic("vdup", "n", &dup_type); + let shift = if method == "shr" { + quote! { -(shift as #scalar) } + } else { + quote! { shift as #scalar } + }; + let expr = neon::expr( + method, + vec_ty, + &[quote! { a.into() }, quote! { #dup_intrinsic ( #shift ) }], + ); + quote! { + #method_sig { + unsafe { + #expr.simd_into(self) + } + } + } } + OpSig::Unary => { + let args = [quote! { a.into() }]; - let method_sig = op.simd_trait_method_sig(vec_ty); - let method_sig = quote! { - #[inline(always)] - #method_sig - }; + let expr = neon::expr(method, vec_ty, &args); - let method = match sig { - OpSig::Splat => { - let expr = neon::expr(method, vec_ty, &[quote! { val }]); - quote! { - #method_sig { - unsafe { - #expr.simd_into(self) - } + quote! { + #method_sig { + unsafe { + #expr.simd_into(self) } } } - OpSig::Shift => { - let dup_type = vec_ty.cast(ScalarType::Int); - let scalar = dup_type.scalar.rust(dup_type.scalar_bits); - let dup_intrinsic = split_intrinsic("vdup", "n", &dup_type); - let shift = if method == "shr" { - quote! { -(shift as #scalar) } - } else { - quote! { shift as #scalar } - }; - let expr = neon::expr( - method, - vec_ty, - &[quote! { a.into() }, quote! { #dup_intrinsic ( #shift ) }], + } + OpSig::LoadInterleaved { + block_size, + block_count, + } => { + assert_eq!(block_count, 4, "only count of 4 is currently supported"); + let intrinsic = { + // The function expects 64-bit or 128-bit + let ty = VecType::new( + vec_ty.scalar, + vec_ty.scalar_bits, + block_size as usize / vec_ty.scalar_bits, ); - quote! { - #method_sig { - unsafe { - #expr.simd_into(self) - } + simple_intrinsic("vld4", &ty) + }; + + quote! { + #method_sig { + unsafe { + #intrinsic(src.as_ptr()).simd_into(self) } } } - OpSig::Unary => { - let args = [quote! { a.into() }]; - - let expr = neon::expr(method, vec_ty, &args); + } + OpSig::StoreInterleaved { + block_size, + block_count, + } => { + assert_eq!(block_count, 4, "only count of 4 is currently supported"); + let intrinsic = { + // The function expects 64-bit or 128-bit + let ty = VecType::new( + vec_ty.scalar, + vec_ty.scalar_bits, + block_size as usize / vec_ty.scalar_bits, + ); + simple_intrinsic("vst4", &ty) + }; - quote! { - #method_sig { - unsafe { - #expr.simd_into(self) - } + quote! 
{ + #method_sig { + unsafe { + #intrinsic(dest.as_mut_ptr(), a.into()) } } } - OpSig::LoadInterleaved { - block_size, - block_count, - } => { - assert_eq!(block_count, 4, "only count of 4 is currently supported"); - let intrinsic = { - // The function expects 64-bit or 128-bit - let ty = VecType::new( - vec_ty.scalar, - vec_ty.scalar_bits, - block_size as usize / vec_ty.scalar_bits, - ); - simple_intrinsic("vld4", &ty) - }; + } + OpSig::WidenNarrow { target_ty } => { + let vec_scalar_ty = vec_ty.scalar.rust(vec_ty.scalar_bits); + let target_scalar_ty = target_ty.scalar.rust(target_ty.scalar_bits); + + if method == "narrow" { + let arch = self.arch_ty(vec_ty); + + let id1 = Ident::new(&format!("vmovn_{}", vec_scalar_ty), Span::call_site()); + let id2 = + Ident::new(&format!("vcombine_{}", target_scalar_ty), Span::call_site()); quote! { #method_sig { unsafe { - #intrinsic(src.as_ptr()).simd_into(self) + let converted: #arch = a.into(); + let low = #id1(converted.0); + let high = #id1(converted.1); + + #id2(low, high).simd_into(self) } } } - } - OpSig::StoreInterleaved { - block_size, - block_count, - } => { - assert_eq!(block_count, 4, "only count of 4 is currently supported"); - let intrinsic = { - // The function expects 64-bit or 128-bit - let ty = VecType::new( - vec_ty.scalar, - vec_ty.scalar_bits, - block_size as usize / vec_ty.scalar_bits, - ); - simple_intrinsic("vst4", &ty) - }; + } else { + let arch = self.arch_ty(&target_ty); + let id1 = Ident::new(&format!("vmovl_{}", vec_scalar_ty), Span::call_site()); + let id2 = Ident::new(&format!("vget_low_{}", vec_scalar_ty), Span::call_site()); + let id3 = + Ident::new(&format!("vget_high_{}", vec_scalar_ty), Span::call_site()); quote! { #method_sig { unsafe { - #intrinsic(dest.as_mut_ptr(), a.into()) + let low = #id1(#id2(a.into())); + let high = #id1(#id3(a.into())); + + #arch(low, high).simd_into(self) } } } } - OpSig::WidenNarrow { target_ty } => { - let vec_scalar_ty = vec_ty.scalar.rust(vec_ty.scalar_bits); - let target_scalar_ty = target_ty.scalar.rust(target_ty.scalar_bits); - - if method == "narrow" { - let arch = neon::arch_ty(vec_ty); - - let id1 = - Ident::new(&format!("vmovn_{}", vec_scalar_ty), Span::call_site()); - let id2 = Ident::new( - &format!("vcombine_{}", target_scalar_ty), - Span::call_site(), - ); - - quote! { - #method_sig { - unsafe { - let converted: #arch = a.into(); - let low = #id1(converted.0); - let high = #id1(converted.1); - - #id2(low, high).simd_into(self) - } - } + } + OpSig::Binary => { + let expr = match method { + "shlv" | "shrv" => { + let mut args = if vec_ty.scalar == ScalarType::Int { + // Signed case + [quote! { a.into() }, quote! { b.into() }] + } else { + // Unsigned case + let bits = vec_ty.scalar_bits; + let reinterpret = format_ident!("vreinterpretq_s{bits}_u{bits}"); + [quote! { a.into() }, quote! { #reinterpret(b.into()) }] + }; + + // For a right shift, we need to negate the shift amount + if method == "shrv" { + let neg = simple_intrinsic("vneg", &vec_ty.cast(ScalarType::Int)); + let arg1 = &args[1]; + args[1] = quote! { #neg(#arg1) }; } - } else { - let arch = neon::arch_ty(&target_ty); - let id1 = - Ident::new(&format!("vmovl_{}", vec_scalar_ty), Span::call_site()); - let id2 = - Ident::new(&format!("vget_low_{}", vec_scalar_ty), Span::call_site()); - let id3 = - Ident::new(&format!("vget_high_{}", vec_scalar_ty), Span::call_site()); + let expr = neon::expr(method, vec_ty, &args); quote! 
{ - #method_sig { - unsafe { - let low = #id1(#id2(a.into())); - let high = #id1(#id3(a.into())); - - #arch(low, high).simd_into(self) - } - } + #expr.simd_into(self) } } - } - OpSig::Binary => { - let expr = match method { - "shlv" | "shrv" => { - let mut args = if vec_ty.scalar == ScalarType::Int { - // Signed case - [quote! { a.into() }, quote! { b.into() }] - } else { - // Unsigned case - let bits = vec_ty.scalar_bits; - let reinterpret = format_ident!("vreinterpretq_s{bits}_u{bits}"); - [quote! { a.into() }, quote! { #reinterpret(b.into()) }] - }; + "copysign" => { + let shift_amt = Literal::usize_unsuffixed(vec_ty.scalar_bits - 1); + let unsigned_ty = vec_ty.cast(ScalarType::Unsigned); + let sign_mask = + neon::expr("splat", &unsigned_ty, &[quote! { 1 << #shift_amt }]); + let vbsl = simple_intrinsic("vbsl", vec_ty); - // For a right shift, we need to negate the shift amount - if method == "shrv" { - let neg = simple_intrinsic("vneg", &vec_ty.cast(ScalarType::Int)); - let arg1 = &args[1]; - args[1] = quote! { #neg(#arg1) }; - } - - let expr = neon::expr(method, vec_ty, &args); - quote! { - #expr.simd_into(self) - } - } - "copysign" => { - let shift_amt = Literal::usize_unsuffixed(vec_ty.scalar_bits - 1); - let unsigned_ty = vec_ty.cast(ScalarType::Unsigned); - let sign_mask = - neon::expr("splat", &unsigned_ty, &[quote! { 1 << #shift_amt }]); - let vbsl = simple_intrinsic("vbsl", vec_ty); - - quote! { - let sign_mask = #sign_mask; - #vbsl(sign_mask, b.into(), a.into()).simd_into(self) - } + quote! { + let sign_mask = #sign_mask; + #vbsl(sign_mask, b.into(), a.into()).simd_into(self) } - _ => { - let args = [quote! { a.into() }, quote! { b.into() }]; - let expr = neon::expr(method, vec_ty, &args); - quote! { - #expr.simd_into(self) - } + } + _ => { + let args = [quote! { a.into() }, quote! { b.into() }]; + let expr = neon::expr(method, vec_ty, &args); + quote! { + #expr.simd_into(self) } - }; + } + }; - quote! { - #method_sig { - unsafe { - #expr - } + quote! { + #method_sig { + unsafe { + #expr } } } - OpSig::Ternary => { - let args = match method { - "mul_add" | "mul_sub" => [ - quote! { c.into() }, - quote! { b.into() }, - quote! { a.into() }, - ], - _ => [ - quote! { a.into() }, - quote! { b.into() }, - quote! { c.into() }, - ], - }; - - let mut expr = neon::expr(method, vec_ty, &args); - if method == "mul_sub" { - // -(c - a * b) = (a * b - c) - let neg = simple_intrinsic("vneg", vec_ty); - expr = quote! { #neg(#expr) }; - } - quote! { - #method_sig { - unsafe { - #expr.simd_into(self) - } + } + OpSig::Ternary => { + let args = match method { + "mul_add" | "mul_sub" => [ + quote! { c.into() }, + quote! { b.into() }, + quote! { a.into() }, + ], + _ => [ + quote! { a.into() }, + quote! { b.into() }, + quote! { c.into() }, + ], + }; + + let mut expr = neon::expr(method, vec_ty, &args); + if method == "mul_sub" { + // -(c - a * b) = (a * b - c) + let neg = simple_intrinsic("vneg", vec_ty); + expr = quote! { #neg(#expr) }; + } + quote! { + #method_sig { + unsafe { + #expr.simd_into(self) } } } - OpSig::Compare => { - let args = [quote! { a.into() }, quote! { b.into() }]; - let expr = neon::expr(method, vec_ty, &args); - let opt_q = crate::arch::neon::opt_q(vec_ty); - let reinterpret_str = - format!("vreinterpret{opt_q}_s{scalar_bits}_u{scalar_bits}"); - let reinterpret = Ident::new(&reinterpret_str, Span::call_site()); - quote! { - #method_sig { - unsafe { - #reinterpret(#expr).simd_into(self) - } + } + OpSig::Compare => { + let args = [quote! { a.into() }, quote! 
{ b.into() }]; + let expr = neon::expr(method, vec_ty, &args); + let opt_q = crate::arch::neon::opt_q(vec_ty); + let scalar_bits = vec_ty.scalar_bits; + let reinterpret_str = format!("vreinterpret{opt_q}_s{scalar_bits}_u{scalar_bits}"); + let reinterpret = Ident::new(&reinterpret_str, Span::call_site()); + quote! { + #method_sig { + unsafe { + #reinterpret(#expr).simd_into(self) } } } - OpSig::Select => { - let opt_q = crate::arch::neon::opt_q(vec_ty); - let reinterpret_str = - format!("vreinterpret{opt_q}_u{scalar_bits}_s{scalar_bits}"); - let reinterpret = Ident::new(&reinterpret_str, Span::call_site()); - let vbsl = simple_intrinsic("vbsl", vec_ty); - quote! { - #method_sig { - unsafe { - #vbsl(#reinterpret(a.into()), b.into(), c.into()).simd_into(self) - } + } + OpSig::Select => { + let opt_q = crate::arch::neon::opt_q(vec_ty); + let scalar_bits = vec_ty.scalar_bits; + let reinterpret_str = format!("vreinterpret{opt_q}_u{scalar_bits}_s{scalar_bits}"); + let reinterpret = Ident::new(&reinterpret_str, Span::call_site()); + let vbsl = simple_intrinsic("vbsl", vec_ty); + quote! { + #method_sig { + unsafe { + #vbsl(#reinterpret(a.into()), b.into(), c.into()).simd_into(self) } } } - OpSig::Combine { combined_ty } => { - let combined_wrapper = combined_ty.aligned_wrapper(); - let combined_arch_ty = arch_ty(&combined_ty); - let combined_rust = combined_ty.rust(); - let expr = match combined_ty.n_bits() { - 512 => quote! { - #combined_rust {val: #combined_wrapper(#combined_arch_ty(a.val.0.0, a.val.0.1, b.val.0.0, b.val.0.1)), simd: self } - }, - 256 => quote! { - #combined_rust {val: #combined_wrapper(#combined_arch_ty(a.val.0, b.val.0)), simd: self } - }, - _ => unimplemented!(), - }; - quote! { - #method_sig { - #expr + } + OpSig::Combine { combined_ty } => { + let combined_wrapper = combined_ty.aligned_wrapper(); + let combined_arch_ty = self.arch_ty(&combined_ty); + let combined_rust = combined_ty.rust(); + let expr = match combined_ty.n_bits() { + 512 => quote! { + #combined_rust {val: #combined_wrapper(#combined_arch_ty(a.val.0.0, a.val.0.1, b.val.0.0, b.val.0.1)), simd: self } + }, + 256 => quote! { + #combined_rust {val: #combined_wrapper(#combined_arch_ty(a.val.0, b.val.0)), simd: self } + }, + _ => unimplemented!(), + }; + quote! { + #method_sig { + #expr + } + } + } + OpSig::Split { half_ty } => { + let split_wrapper = half_ty.aligned_wrapper(); + let split_arch_ty = self.arch_ty(&half_ty); + let half_rust = half_ty.rust(); + let expr = match half_ty.n_bits() { + 256 => quote! { + ( + #half_rust { val: #split_wrapper(#split_arch_ty(a.val.0.0, a.val.0.1)), simd: self }, + #half_rust { val: #split_wrapper(#split_arch_ty(a.val.0.2, a.val.0.3)), simd: self }, + ) + }, + 128 => quote! { + ( + #half_rust { val: #split_wrapper(a.val.0.0), simd: self }, + #half_rust { val: #split_wrapper(a.val.0.1), simd: self }, + ) + }, + _ => unimplemented!(), + }; + quote! { + #method_sig { + #expr + } + } + } + OpSig::Zip { select_low } => { + let neon = if select_low { "vzip1" } else { "vzip2" }; + let zip = simple_intrinsic(neon, vec_ty); + quote! { + #method_sig { + let x = a.into(); + let y = b.into(); + unsafe { + #zip(x, y).simd_into(self) } } } - OpSig::Split { half_ty } => { - let split_wrapper = half_ty.aligned_wrapper(); - let split_arch_ty = arch_ty(&half_ty); - let half_rust = half_ty.rust(); - let expr = match half_ty.n_bits() { - 256 => quote! 
{ - ( - #half_rust { val: #split_wrapper(#split_arch_ty(a.val.0.0, a.val.0.1)), simd: self }, - #half_rust { val: #split_wrapper(#split_arch_ty(a.val.0.2, a.val.0.3)), simd: self }, - ) - }, - 128 => quote! { - ( - #half_rust { val: #split_wrapper(a.val.0.0), simd: self }, - #half_rust { val: #split_wrapper(a.val.0.1), simd: self }, - ) - }, - _ => unimplemented!(), - }; - quote! { - #method_sig { - #expr + } + OpSig::Unzip { select_even } => { + let neon = if select_even { "vuzp1" } else { "vuzp2" }; + let zip = simple_intrinsic(neon, vec_ty); + quote! { + #method_sig { + let x = a.into(); + let y = b.into(); + unsafe { + #zip(x, y).simd_into(self) } } } - OpSig::Zip { select_low } => { - let neon = if select_low { "vzip1" } else { "vzip2" }; - let zip = simple_intrinsic(neon, vec_ty); + } + OpSig::Cvt { + target_ty, + scalar_bits, + precise, + } => { + if precise { + let non_precise = + generic_op_name(method.strip_suffix("_precise").unwrap(), vec_ty); quote! { #method_sig { - let x = a.into(); - let y = b.into(); - unsafe { - #zip(x, y).simd_into(self) - } + self.#non_precise(a) } } - } - OpSig::Unzip { select_even } => { - let neon = if select_even { "vuzp1" } else { "vuzp2" }; - let zip = simple_intrinsic(neon, vec_ty); + } else { + let to_ty = &vec_ty.reinterpret(target_ty, scalar_bits); + let neon = cvt_intrinsic("vcvt", to_ty, vec_ty); quote! { #method_sig { - let x = a.into(); - let y = b.into(); unsafe { - #zip(x, y).simd_into(self) - } - } - } - } - OpSig::Cvt { - target_ty, - scalar_bits, - precise, - } => { - if precise { - let non_precise = - generic_op_name(method.strip_suffix("_precise").unwrap(), vec_ty); - quote! { - #method_sig { - self.#non_precise(a) - } - } - } else { - let to_ty = &vec_ty.reinterpret(target_ty, scalar_bits); - let neon = cvt_intrinsic("vcvt", to_ty, vec_ty); - quote! { - #method_sig { - unsafe { - #neon(a.into()).simd_into(self) - } + #neon(a.into()).simd_into(self) } } } } - OpSig::Reinterpret { - target_ty, - scalar_bits, - } => { - if valid_reinterpret(vec_ty, target_ty, scalar_bits) { - let to_ty = vec_ty.reinterpret(target_ty, scalar_bits); - let neon = cvt_intrinsic("vreinterpret", &to_ty, vec_ty); - - quote! { - #method_sig { - unsafe { - #neon(a.into()).simd_into(self) - } - } - } - } else { - quote! {} - } - } - OpSig::MaskReduce { - quantifier, - condition, - } => { - let (reduction, target) = match (quantifier, condition) { - (crate::ops::Quantifier::Any, true) => ("vmaxv", quote! { != 0 }), - (crate::ops::Quantifier::Any, false) => ("vminv", quote! { != 0xffffffff }), - (crate::ops::Quantifier::All, true) => ("vminv", quote! { == 0xffffffff }), - (crate::ops::Quantifier::All, false) => ("vmaxv", quote! { == 0 }), - }; + } + OpSig::Reinterpret { + target_ty, + scalar_bits, + } => { + if valid_reinterpret(vec_ty, target_ty, scalar_bits) { + let to_ty = vec_ty.reinterpret(target_ty, scalar_bits); + let neon = cvt_intrinsic("vreinterpret", &to_ty, vec_ty); - let u32_ty = vec_ty.reinterpret(ScalarType::Unsigned, 32); - let min_max = simple_intrinsic(reduction, &u32_ty); - let reinterpret = format_ident!("vreinterpretq_u32_s{}", vec_ty.scalar_bits); quote! { #method_sig { unsafe { - #min_max(#reinterpret(a.into())) #target + #neon(a.into()).simd_into(self) } } } + } else { + quote! 
{} } - OpSig::FromArray { kind } => { - generic_from_array(method_sig, vec_ty, kind, 512, load_intrinsic) - } - OpSig::AsArray { kind } => generic_as_array(method_sig, vec_ty, kind, 512, arch_ty), - OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), - OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), - }; - methods.push(method); - } - } - - // Note: the `vectorize` implementation is pretty boilerplate and should probably - // be factored out for DRY. - quote! { - impl Simd for #level_tok { - type f32s = f32x4; - type f64s = f64x2; - type u8s = u8x16; - type i8s = i8x16; - type u16s = u16x8; - type i16s = i16x8; - type u32s = u32x4; - type i32s = i32x4; - type mask8s = mask8x16; - type mask16s = mask16x8; - type mask32s = mask32x4; - type mask64s = mask64x2; - #[inline(always)] - fn level(self) -> Level { - Level::#level_tok(self) } - - #[inline] - fn vectorize R, R>(self, f: F) -> R { - #[target_feature(enable = "neon")] - #[inline] - // unsafe not needed here with tf11, but can be justified - unsafe fn vectorize_neon R, R>(f: F) -> R { - f() - } - unsafe { vectorize_neon(f) } - } - - #( #methods )* - } - } -} - -fn mk_type_impl() -> TokenStream { - let mut result = vec![]; - for ty in SIMD_TYPES { - let n_bits = ty.n_bits(); - if !(n_bits == 64 || n_bits == 128 || n_bits == 256 || n_bits == 512) { - continue; - } - let simd = ty.rust(); - let arch = neon::arch_ty(ty); - result.push(quote! { - impl SimdFrom<#arch, S> for #simd { - #[inline(always)] - fn simd_from(arch: #arch, simd: S) -> Self { - Self { - val: unsafe { core::mem::transmute_copy(&arch) }, - simd + OpSig::MaskReduce { + quantifier, + condition, + } => { + let (reduction, target) = match (quantifier, condition) { + (crate::ops::Quantifier::Any, true) => ("vmaxv", quote! { != 0 }), + (crate::ops::Quantifier::Any, false) => ("vminv", quote! { != 0xffffffff }), + (crate::ops::Quantifier::All, true) => ("vminv", quote! { == 0xffffffff }), + (crate::ops::Quantifier::All, false) => ("vmaxv", quote! { == 0 }), + }; + + let u32_ty = vec_ty.reinterpret(ScalarType::Unsigned, 32); + let min_max = simple_intrinsic(reduction, &u32_ty); + let reinterpret = format_ident!("vreinterpretq_u32_s{}", vec_ty.scalar_bits); + quote! { + #method_sig { + unsafe { + #min_max(#reinterpret(a.into())) #target + } } } } - impl From<#simd> for #arch { - #[inline(always)] - fn from(value: #simd) -> Self { - unsafe { core::mem::transmute_copy(&value.val) } - } + OpSig::FromArray { kind } => generic_from_array( + method_sig, + vec_ty, + kind, + self.max_block_size(), + load_intrinsic, + ), + OpSig::AsArray { kind } => { + generic_as_array(method_sig, vec_ty, kind, self.max_block_size(), |vec_ty| { + self.arch_ty(vec_ty) + }) } - }); - } - quote! 
{ - #( #result )* + OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), + OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), + } } } diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index 5bae214d..485c66a3 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -8,52 +8,69 @@ use crate::arch::x86::{ }; use crate::generic::{ generic_as_array, generic_block_combine, generic_block_split, generic_from_array, - generic_from_bytes, generic_op, generic_op_name, generic_to_bytes, impl_arch_types, - scalar_binary, + generic_from_bytes, generic_op_name, generic_to_bytes, scalar_binary, }; -use crate::ops::{Op, OpSig, Quantifier, ops_for_type, valid_reinterpret}; -use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; +use crate::level::Level; +use crate::ops::{Op, OpSig, Quantifier, valid_reinterpret}; +use crate::types::{ScalarType, VecType}; use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{ToTokens as _, quote}; #[derive(Clone, Copy)] -pub(crate) struct Level; +pub(crate) struct Sse4_2; -impl Level { - fn name(self) -> &'static str { +impl Level for Sse4_2 { + fn name(&self) -> &'static str { "Sse4_2" } - fn token(self) -> TokenStream { - let ident = Ident::new(self.name(), Span::call_site()); - quote! { #ident } + fn native_width(&self) -> usize { + 128 } -} -pub(crate) fn mk_sse4_2_impl() -> TokenStream { - let imports = type_imports(); - let arch_types_impl = impl_arch_types(Level.name(), 128, arch_ty); - let simd_impl = mk_simd_impl(); - let ty_impl = mk_type_impl(); + fn max_block_size(&self) -> usize { + 128 + } - quote! { - #[cfg(target_arch = "x86")] - use core::arch::x86::*; - #[cfg(target_arch = "x86_64")] - use core::arch::x86_64::*; + fn enabled_target_features(&self) -> Option<&'static str> { + Some("sse4.2") + } - use core::ops::*; - use crate::{seal::Seal, arch_types::ArchTypes, Level, Simd, SimdFrom, SimdInto}; + fn arch_ty(&self, vec_ty: &VecType) -> TokenStream { + arch_ty(vec_ty).into_token_stream() + } - #imports + fn token_doc(&self) -> &'static str { + r#"The SIMD token for the "SSE4.2" level."# + } - /// The SIMD token for the "SSE 4.2" level. - #[derive(Clone, Copy, Debug)] - pub struct Sse4_2 { - pub sse4_2: crate::core_arch::x86::Sse4_2, + fn token_inner(&self) -> TokenStream { + quote!(crate::core_arch::x86::Sse4_2) + } + + fn make_module_prelude(&self) -> TokenStream { + quote! { + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; + } + } + + fn make_level_body(&self) -> TokenStream { + let level_tok = Self.token(); + quote! { + #[cfg(not(all(target_feature = "avx2", target_feature = "fma")))] + return Level::#level_tok(self); + #[cfg(all(target_feature = "avx2", target_feature = "fma"))] + { + Level::baseline() + } } + } - impl Sse4_2 { + fn make_impl_body(&self) -> TokenStream { + quote! { /// Create a SIMD token. 
/// /// # Safety @@ -66,157 +83,59 @@ pub(crate) fn mk_sse4_2_impl() -> TokenStream { } } } - - impl Seal for Sse4_2 {} - - #arch_types_impl - - #simd_impl - - #ty_impl } -} - -fn mk_simd_impl() -> TokenStream { - let level_tok = Level.token(); - let mut methods = vec![]; - for vec_ty in SIMD_TYPES { - for op in ops_for_type(vec_ty) { - if op.sig.should_use_generic_op(vec_ty, 128) { - methods.push(generic_op(&op, vec_ty)); - continue; - } - let method = make_method(op, vec_ty); - - methods.push(method); - } - } - // Note: the `vectorize` implementation is pretty boilerplate and should probably - // be factored out for DRY. - quote! { - impl Simd for #level_tok { - type f32s = f32x4; - type f64s = f64x2; - type u8s = u8x16; - type i8s = i8x16; - type u16s = u16x8; - type i16s = i16x8; - type u32s = u32x4; - type i32s = i32x4; - type mask8s = mask8x16; - type mask16s = mask16x8; - type mask32s = mask32x4; - type mask64s = mask64x2; - #[inline(always)] - fn level(self) -> Level { - #[cfg(not(all(target_feature = "avx2", target_feature = "fma")))] - return Level::#level_tok(self); - #[cfg(all(target_feature = "avx2", target_feature = "fma"))] - { - Level::baseline() - } - } + fn make_method(&self, op: Op, vec_ty: &VecType) -> TokenStream { + let Op { sig, method, .. } = op; + let method_sig = op.simd_trait_method_sig(vec_ty); - #[inline] - fn vectorize R, R>(self, f: F) -> R { - #[target_feature(enable = "sse4.2")] - #[inline] - unsafe fn vectorize_sse4_2 R, R>(f: F) -> R { - f() - } - unsafe { vectorize_sse4_2(f) } + match sig { + OpSig::Splat => handle_splat(method_sig, vec_ty), + OpSig::Compare => handle_compare(method_sig, method, vec_ty), + OpSig::Unary => handle_unary(method_sig, method, vec_ty), + OpSig::WidenNarrow { target_ty } => { + handle_widen_narrow(method_sig, method, vec_ty, target_ty) } - - #( #methods )* - } - } -} - -fn mk_type_impl() -> TokenStream { - let mut result = vec![]; - for ty in SIMD_TYPES { - let n_bits = ty.n_bits(); - if n_bits != 128 { - continue; - } - let simd = ty.rust(); - let arch = x86::arch_ty(ty); - result.push(quote! 
{ - impl SimdFrom<#arch, S> for #simd { - #[inline(always)] - fn simd_from(arch: #arch, simd: S) -> Self { - Self { - val: unsafe { core::mem::transmute_copy(&arch) }, - simd - } - } + OpSig::Binary => handle_binary(method_sig, method, vec_ty), + OpSig::Shift => handle_shift(method_sig, method, vec_ty), + OpSig::Ternary => handle_ternary(method_sig, method, vec_ty), + OpSig::Select => handle_select(method_sig, vec_ty), + OpSig::Combine { combined_ty } => generic_block_combine(method_sig, &combined_ty, 128), + OpSig::Split { half_ty } => generic_block_split(method_sig, &half_ty, 128), + OpSig::Zip { select_low } => handle_zip(method_sig, vec_ty, select_low), + OpSig::Unzip { select_even } => handle_unzip(method_sig, vec_ty, select_even), + OpSig::Cvt { + target_ty, + scalar_bits, + precise, + } => handle_cvt(method_sig, vec_ty, target_ty, scalar_bits, precise), + OpSig::Reinterpret { + target_ty, + scalar_bits, + } => handle_reinterpret(self, method_sig, vec_ty, target_ty, scalar_bits), + OpSig::MaskReduce { + quantifier, + condition, + } => handle_mask_reduce(method_sig, vec_ty, quantifier, condition), + OpSig::LoadInterleaved { + block_size, + block_count, + } => handle_load_interleaved(method_sig, vec_ty, block_size, block_count), + OpSig::StoreInterleaved { + block_size, + block_count, + } => handle_store_interleaved(method_sig, vec_ty, block_size, block_count), + OpSig::FromArray { kind } => { + generic_from_array(method_sig, vec_ty, kind, 128, |block_ty| { + intrinsic_ident("loadu", coarse_type(block_ty), block_ty.n_bits()) + }) } - impl From<#simd> for #arch { - #[inline(always)] - fn from(value: #simd) -> Self { - unsafe { core::mem::transmute_copy(&value.val) } - } + OpSig::AsArray { kind } => { + generic_as_array(method_sig, vec_ty, kind, 128, |vec_ty| self.arch_ty(vec_ty)) } - }); - } - quote! { - #( #result )* - } -} - -fn make_method(op: Op, vec_ty: &VecType) -> TokenStream { - let Op { sig, method, .. } = op; - let method_sig = op.simd_trait_method_sig(vec_ty); - let method_sig = quote! 
{ - #[inline(always)] - #method_sig - }; - - match sig { - OpSig::Splat => handle_splat(method_sig, vec_ty), - OpSig::Compare => handle_compare(method_sig, method, vec_ty), - OpSig::Unary => handle_unary(method_sig, method, vec_ty), - OpSig::WidenNarrow { target_ty } => { - handle_widen_narrow(method_sig, method, vec_ty, target_ty) - } - OpSig::Binary => handle_binary(method_sig, method, vec_ty), - OpSig::Shift => handle_shift(method_sig, method, vec_ty), - OpSig::Ternary => handle_ternary(method_sig, method, vec_ty), - OpSig::Select => handle_select(method_sig, vec_ty), - OpSig::Combine { combined_ty } => generic_block_combine(method_sig, &combined_ty, 128), - OpSig::Split { half_ty } => generic_block_split(method_sig, &half_ty, 128), - OpSig::Zip { select_low } => handle_zip(method_sig, vec_ty, select_low), - OpSig::Unzip { select_even } => handle_unzip(method_sig, vec_ty, select_even), - OpSig::Cvt { - target_ty, - scalar_bits, - precise, - } => handle_cvt(method_sig, vec_ty, target_ty, scalar_bits, precise), - OpSig::Reinterpret { - target_ty, - scalar_bits, - } => handle_reinterpret(method_sig, vec_ty, target_ty, scalar_bits), - OpSig::MaskReduce { - quantifier, - condition, - } => handle_mask_reduce(method_sig, vec_ty, quantifier, condition), - OpSig::LoadInterleaved { - block_size, - block_count, - } => handle_load_interleaved(method_sig, vec_ty, block_size, block_count), - OpSig::StoreInterleaved { - block_size, - block_count, - } => handle_store_interleaved(method_sig, vec_ty, block_size, block_count), - OpSig::FromArray { kind } => { - generic_from_array(method_sig, vec_ty, kind, 128, |block_ty| { - intrinsic_ident("loadu", coarse_type(block_ty), block_ty.n_bits()) - }) + OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), + OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), } - OpSig::AsArray { kind } => generic_as_array(method_sig, vec_ty, kind, 128, arch_ty), - OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), - OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), } } @@ -968,6 +887,7 @@ pub(crate) fn handle_cvt( } pub(crate) fn handle_reinterpret( + level: &impl Level, method_sig: TokenStream, vec_ty: &VecType, target_ty: ScalarType, @@ -980,7 +900,7 @@ pub(crate) fn handle_reinterpret( ); if coarse_type(vec_ty) == coarse_type(&dst_ty) { - let arch_ty = x86::arch_ty(vec_ty); + let arch_ty = level.arch_ty(vec_ty); quote! 
{ #method_sig { #arch_ty::from(a).simd_into(self) diff --git a/fearless_simd_gen/src/mk_wasm.rs b/fearless_simd_gen/src/mk_wasm.rs index 14078827..61da3d70 100644 --- a/fearless_simd_gen/src/mk_wasm.rs +++ b/fearless_simd_gen/src/mk_wasm.rs @@ -4,720 +4,627 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; -use crate::arch::wasm::{arch_ty, v128_intrinsic}; +use crate::arch::wasm::{arch_prefix, v128_intrinsic}; use crate::generic::{ generic_as_array, generic_block_combine, generic_block_split, generic_from_array, - generic_from_bytes, generic_op_name, generic_to_bytes, impl_arch_types, scalar_binary, + generic_from_bytes, generic_op_name, generic_to_bytes, scalar_binary, }; +use crate::level::Level; use crate::ops::{Op, Quantifier, valid_reinterpret}; use crate::{ arch::wasm::{self, simple_intrinsic}, - generic::generic_op, - ops::{OpSig, ops_for_type}, - types::{SIMD_TYPES, ScalarType, VecType, type_imports}, + ops::OpSig, + types::{ScalarType, VecType}, }; #[derive(Clone, Copy)] -pub(crate) enum Level { - WasmSimd128, -} +pub(crate) struct WasmSimd128; + +impl Level for WasmSimd128 { + fn name(&self) -> &'static str { + "WasmSimd128" + } + + fn native_width(&self) -> usize { + 128 + } + + fn max_block_size(&self) -> usize { + 128 + } + + fn enabled_target_features(&self) -> Option<&'static str> { + None + } + + fn arch_ty(&self, _vec_ty: &VecType) -> TokenStream { + quote! { v128 } + } + + fn token_doc(&self) -> &'static str { + r#"The SIMD token for the "wasm128" level."# + } + + fn token_inner(&self) -> TokenStream { + quote!(crate::core_arch::wasm32::WasmSimd128) + } -impl Level { - fn name(self) -> &'static str { - match self { - Self::WasmSimd128 => "WasmSimd128", + fn make_module_prelude(&self) -> TokenStream { + quote! { + use core::arch::wasm32::*; } } - fn token(self) -> TokenStream { - let ident = Ident::new(self.name(), Span::call_site()); - quote! { #ident } + fn make_impl_body(&self) -> TokenStream { + quote! { + #[inline] + pub const fn new_unchecked() -> Self { + Self { wasmsimd128: crate::core_arch::wasm32::WasmSimd128::new() } + } + } } -} -fn mk_simd_impl(level: Level) -> TokenStream { - let level_tok = level.token(); - let mut methods = vec![]; + fn make_method(&self, op: Op, vec_ty: &VecType) -> TokenStream { + let Op { sig, method, .. } = op; - for vec_ty in SIMD_TYPES { - for op in ops_for_type(vec_ty) { - let Op { sig, method, .. } = op; - if sig.should_use_generic_op(vec_ty, 128) { - methods.push(generic_op(&op, vec_ty)); - continue; + let method_sig = op.simd_trait_method_sig(vec_ty); + match sig { + OpSig::Splat => { + let expr = wasm::expr(method, vec_ty, &[quote! { val }]); + quote! { + #method_sig { + #expr.simd_into(self) + } + } } + OpSig::Unary => { + let args = [quote! { a.into() }]; + let expr = if matches!(method, "fract") { + assert_eq!( + vec_ty.scalar, + ScalarType::Float, + "only float supports fract" + ); - let method_sig = op.simd_trait_method_sig(vec_ty); - let method_sig = quote! { - #[inline(always)] - #method_sig - }; - let m = match sig { - OpSig::Splat => { - let expr = wasm::expr(method, vec_ty, &[quote! { val }]); + let trunc = generic_op_name("trunc", vec_ty); + let sub = generic_op_name("sub", vec_ty); quote! { - #method_sig { - #expr.simd_into(self) - } + self.#sub(a, self.#trunc(a)) + } + } else { + let expr = wasm::expr(method, vec_ty, &args); + quote! { #expr.simd_into(self) } + }; + + quote! { + #method_sig { + #expr } } - OpSig::Unary => { - let args = [quote! 
{ a.into() }]; - let expr = if matches!(method, "fract") { - assert_eq!( - vec_ty.scalar, - ScalarType::Float, - "only float supports fract" - ); + } + OpSig::Binary => { + let args = [quote! { a.into() }, quote! { b.into() }]; + let expr = match method { + "mul" if vec_ty.scalar_bits == 8 && vec_ty.len == 16 => { + let (extmul_low, extmul_high) = match vec_ty.scalar { + ScalarType::Unsigned => ( + quote! { u16x8_extmul_low_u8x16 }, + quote! { u16x8_extmul_high_u8x16 }, + ), + ScalarType::Int => ( + quote! { i16x8_extmul_low_i8x16 }, + quote! { i16x8_extmul_high_i8x16 }, + ), + _ => unreachable!(), + }; - let trunc = generic_op_name("trunc", vec_ty); - let sub = generic_op_name("sub", vec_ty); quote! { - self.#sub(a, self.#trunc(a)) - } - } else { - let expr = wasm::expr(method, vec_ty, &args); - quote! { #expr.simd_into(self) } - }; - - quote! { - #method_sig { - #expr + let low = #extmul_low(a.into(), b.into()); + let high = #extmul_high(a.into(), b.into()); + u8x16_shuffle::<0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30>(low, high).simd_into(self) } } - } - OpSig::Binary => { - let args = [quote! { a.into() }, quote! { b.into() }]; - let expr = match method { - "mul" if vec_ty.scalar_bits == 8 && vec_ty.len == 16 => { - let (extmul_low, extmul_high) = match vec_ty.scalar { - ScalarType::Unsigned => ( - quote! { u16x8_extmul_low_u8x16 }, - quote! { u16x8_extmul_high_u8x16 }, - ), - ScalarType::Int => ( - quote! { i16x8_extmul_low_i8x16 }, - quote! { i16x8_extmul_high_i8x16 }, - ), - _ => unreachable!(), - }; - - quote! { - let low = #extmul_low(a.into(), b.into()); - let high = #extmul_high(a.into(), b.into()); - u8x16_shuffle::<0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30>(low, high).simd_into(self) - } - } - "max_precise" | "min_precise" => { - let intrinsic = simple_intrinsic( - if method == "max_precise" { - "pmax" - } else { - "pmin" - }, - vec_ty, - ); - let compare_ne = simple_intrinsic("ne", vec_ty); - quote! { - let intermediate = #intrinsic(b.into(), a.into()); - - // See the x86 min_precise/max_precise code in `arch::x86` for more info on how this - // works. - let b_is_nan = #compare_ne(b.into(), b.into()); - v128_bitselect(a.into(), intermediate, b_is_nan).simd_into(self) - } - } - "max" | "min" if vec_ty.scalar == ScalarType::Float => { - let expr = wasm::expr(method, vec_ty, &args); - let relaxed_intrinsic = simple_intrinsic( - if method == "max" { - "relaxed_max" - } else { - "relaxed_min" - }, - vec_ty, - ); - let relaxed_expr = quote! { #relaxed_intrinsic ( #( #args ),* ) }; - - quote! { - #[cfg(target_feature = "relaxed-simd")] - { #relaxed_expr.simd_into(self) } - - #[cfg(not(target_feature = "relaxed-simd"))] - { #expr.simd_into(self) } - } - } - "shlv" => scalar_binary(quote!(core::ops::Shl::shl)), - "shrv" => scalar_binary(quote!(core::ops::Shr::shr)), - "copysign" => { - let splat = simple_intrinsic("splat", vec_ty); - let sign_mask_literal = match vec_ty.scalar_bits { - 32 => quote! { -0.0_f32 }, - 64 => quote! { -0.0_f64 }, - _ => unimplemented!(), - }; - quote! { - let sign_mask = #splat(#sign_mask_literal); - let sign_bits = v128_and(b.into(), sign_mask.into()); - let magnitude = v128_andnot(a.into(), sign_mask.into()); - v128_or(magnitude, sign_bits).simd_into(self) - } - } - _ => { - let expr = wasm::expr(method, vec_ty, &args); - quote! 
{ #expr.simd_into(self) } - } - }; + "max_precise" | "min_precise" => { + let intrinsic = simple_intrinsic( + if method == "max_precise" { + "pmax" + } else { + "pmin" + }, + vec_ty, + ); + let compare_ne = simple_intrinsic("ne", vec_ty); + quote! { + let intermediate = #intrinsic(b.into(), a.into()); - quote! { - #method_sig { - #expr + // See the x86 min_precise/max_precise code in `arch::x86` for more info on how this + // works. + let b_is_nan = #compare_ne(b.into(), b.into()); + v128_bitselect(a.into(), intermediate, b_is_nan).simd_into(self) } } - } - OpSig::Ternary => { - if matches!(method, "mul_add" | "mul_sub") { - let add_sub = generic_op_name( - if method == "mul_add" { "add" } else { "sub" }, + "max" | "min" if vec_ty.scalar == ScalarType::Float => { + let expr = wasm::expr(method, vec_ty, &args); + let relaxed_intrinsic = simple_intrinsic( + if method == "max" { + "relaxed_max" + } else { + "relaxed_min" + }, vec_ty, ); - let mul = generic_op_name("mul", vec_ty); - - let c = if method == "mul_sub" { - // WebAssembly just... forgot fused multiply-subtract? It seems the - // initial proposal - // (https://github.com/WebAssembly/relaxed-simd/issues/27) confused it - // with negate multiply-add, and nobody ever resolved the confusion. - let negate = simple_intrinsic("neg", vec_ty); - quote! { #negate(c.into()) } - } else { - quote! { c.into() } - }; - let relaxed_madd = simple_intrinsic("relaxed_madd", vec_ty); + let relaxed_expr = quote! { #relaxed_intrinsic ( #( #args ),* ) }; quote! { - #method_sig { - #[cfg(target_feature = "relaxed-simd")] - { #relaxed_madd(a.into(), b.into(), #c).simd_into(self) } + #[cfg(target_feature = "relaxed-simd")] + { #relaxed_expr.simd_into(self) } - #[cfg(not(target_feature = "relaxed-simd"))] - { self.#add_sub(self.#mul(a, b), c) } - } + #[cfg(not(target_feature = "relaxed-simd"))] + { #expr.simd_into(self) } } - } else { - unimplemented!() } - } - OpSig::Compare => { - let args = [quote! { a.into() }, quote! { b.into() }]; - let expr = wasm::expr(method, vec_ty, &args); - quote! { - #method_sig { - #expr.simd_into(self) + "shlv" => scalar_binary(quote!(core::ops::Shl::shl)), + "shrv" => scalar_binary(quote!(core::ops::Shr::shr)), + "copysign" => { + let splat = simple_intrinsic("splat", vec_ty); + let sign_mask_literal = match vec_ty.scalar_bits { + 32 => quote! { -0.0_f32 }, + 64 => quote! { -0.0_f64 }, + _ => unimplemented!(), + }; + quote! { + let sign_mask = #splat(#sign_mask_literal); + let sign_bits = v128_and(b.into(), sign_mask.into()); + let magnitude = v128_andnot(a.into(), sign_mask.into()); + v128_or(magnitude, sign_bits).simd_into(self) } } + _ => { + let expr = wasm::expr(method, vec_ty, &args); + quote! { #expr.simd_into(self) } + } + }; + + quote! { + #method_sig { + #expr + } } - OpSig::Select => { - // Rust includes unsigned versions of the lane select intrinsics, but they're - // just aliases for the signed ones - let lane_ty = vec_ty.cast(ScalarType::Int); - let lane_select = simple_intrinsic("relaxed_laneselect", &lane_ty); + } + OpSig::Ternary => { + if matches!(method, "mul_add" | "mul_sub") { + let add_sub = + generic_op_name(if method == "mul_add" { "add" } else { "sub" }, vec_ty); + let mul = generic_op_name("mul", vec_ty); + + let c = if method == "mul_sub" { + // WebAssembly just... forgot fused multiply-subtract? It seems the + // initial proposal + // (https://github.com/WebAssembly/relaxed-simd/issues/27) confused it + // with negate multiply-add, and nobody ever resolved the confusion. 
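+                        // Descriptive note (added): with `c` negated below, `mul_sub(a, b, c)`
+                        // is emitted as relaxed_madd(a, b, -c), i.e. a * b - c.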
+ let negate = simple_intrinsic("neg", vec_ty); + quote! { #negate(c.into()) } + } else { + quote! { c.into() } + }; + let relaxed_madd = simple_intrinsic("relaxed_madd", vec_ty); quote! { #method_sig { #[cfg(target_feature = "relaxed-simd")] - { #lane_select(b.into(), c.into(), a.into()).simd_into(self) } + { #relaxed_madd(a.into(), b.into(), #c).simd_into(self) } #[cfg(not(target_feature = "relaxed-simd"))] - { v128_bitselect(b.into(), c.into(), a.into()).simd_into(self) } + { self.#add_sub(self.#mul(a, b), c) } } } + } else { + unimplemented!() } - OpSig::Combine { combined_ty } => { - generic_block_combine(method_sig, &combined_ty, 128) + } + OpSig::Compare => { + let args = [quote! { a.into() }, quote! { b.into() }]; + let expr = wasm::expr(method, vec_ty, &args); + quote! { + #method_sig { + #expr.simd_into(self) + } } - OpSig::Split { half_ty } => generic_block_split(method_sig, &half_ty, 128), - OpSig::Zip { select_low } => { - let (indices, shuffle_fn) = match vec_ty.scalar_bits { - 8 => { - let indices = if select_low { - quote! { 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23 } - } else { - quote! { 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 } - }; - (indices, quote! { u8x16_shuffle }) - } - 16 => { - let indices = if select_low { - quote! { 0, 8, 1, 9, 2, 10, 3, 11 } - } else { - quote! { 4, 12, 5, 13, 6, 14, 7, 15 } - }; - (indices, quote! { u16x8_shuffle }) - } - 32 => { - let indices = if select_low { - quote! { 0, 4, 1, 5 } - } else { - quote! { 2, 6, 3, 7 } - }; - (indices, quote! { u32x4_shuffle }) - } - 64 => { - let indices = if select_low { - quote! { 0, 2 } - } else { - quote! { 1, 3 } - }; - (indices, quote! { u64x2_shuffle }) - } - _ => panic!("unsupported scalar_bits for zip operation"), - }; + } + OpSig::Select => { + // Rust includes unsigned versions of the lane select intrinsics, but they're + // just aliases for the signed ones + let lane_ty = vec_ty.cast(ScalarType::Int); + let lane_select = simple_intrinsic("relaxed_laneselect", &lane_ty); + + quote! { + #method_sig { + #[cfg(target_feature = "relaxed-simd")] + { #lane_select(b.into(), c.into(), a.into()).simd_into(self) } + + #[cfg(not(target_feature = "relaxed-simd"))] + { v128_bitselect(b.into(), c.into(), a.into()).simd_into(self) } + } + } + } + OpSig::Combine { combined_ty } => generic_block_combine(method_sig, &combined_ty, 128), + OpSig::Split { half_ty } => generic_block_split(method_sig, &half_ty, 128), + OpSig::Zip { select_low } => { + let (indices, shuffle_fn) = match vec_ty.scalar_bits { + 8 => { + let indices = if select_low { + quote! { 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23 } + } else { + quote! { 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 } + }; + (indices, quote! { u8x16_shuffle }) + } + 16 => { + let indices = if select_low { + quote! { 0, 8, 1, 9, 2, 10, 3, 11 } + } else { + quote! { 4, 12, 5, 13, 6, 14, 7, 15 } + }; + (indices, quote! { u16x8_shuffle }) + } + 32 => { + let indices = if select_low { + quote! { 0, 4, 1, 5 } + } else { + quote! { 2, 6, 3, 7 } + }; + (indices, quote! { u32x4_shuffle }) + } + 64 => { + let indices = if select_low { + quote! { 0, 2 } + } else { + quote! { 1, 3 } + }; + (indices, quote! { u64x2_shuffle }) + } + _ => panic!("unsupported scalar_bits for zip operation"), + }; - quote! { - #method_sig { - #shuffle_fn::<#indices>(a.into(), b.into()).simd_into(self) - } + quote! 
{ + #method_sig { + #shuffle_fn::<#indices>(a.into(), b.into()).simd_into(self) } } - OpSig::Unzip { select_even } => { - let (indices, shuffle_fn) = match vec_ty.scalar_bits { - 8 => { - let indices = if select_even { - quote! { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 } - } else { - quote! { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 } - }; - (indices, quote! { u8x16_shuffle }) - } - 16 => { - let indices = if select_even { - quote! { 0, 2, 4, 6, 8, 10, 12, 14 } - } else { - quote! { 1, 3, 5, 7, 9, 11, 13, 15 } - }; - (indices, quote! { u16x8_shuffle }) - } - 32 => { - let indices = if select_even { - quote! { 0, 2, 4, 6 } - } else { - quote! { 1, 3, 5, 7 } - }; - (indices, quote! { u32x4_shuffle }) - } - 64 => { - let indices = if select_even { - quote! { 0, 2 } - } else { - quote! { 1, 3 } - }; - (indices, quote! { u64x2_shuffle }) - } - _ => panic!("unsupported scalar_bits for unzip operation"), - }; - quote! { - #method_sig { - #shuffle_fn::<#indices>(a.into(), b.into()).simd_into(self) - } + } + OpSig::Unzip { select_even } => { + let (indices, shuffle_fn) = match vec_ty.scalar_bits { + 8 => { + let indices = if select_even { + quote! { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 } + } else { + quote! { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 } + }; + (indices, quote! { u8x16_shuffle }) + } + 16 => { + let indices = if select_even { + quote! { 0, 2, 4, 6, 8, 10, 12, 14 } + } else { + quote! { 1, 3, 5, 7, 9, 11, 13, 15 } + }; + (indices, quote! { u16x8_shuffle }) + } + 32 => { + let indices = if select_even { + quote! { 0, 2, 4, 6 } + } else { + quote! { 1, 3, 5, 7 } + }; + (indices, quote! { u32x4_shuffle }) + } + 64 => { + let indices = if select_even { + quote! { 0, 2 } + } else { + quote! { 1, 3 } + }; + (indices, quote! { u64x2_shuffle }) + } + _ => panic!("unsupported scalar_bits for unzip operation"), + }; + quote! { + #method_sig { + #shuffle_fn::<#indices>(a.into(), b.into()).simd_into(self) } } - OpSig::Shift => { - let prefix = vec_ty.scalar.prefix(); - let shift_name = - format!("{prefix}{}x{}_{method}", vec_ty.scalar_bits, vec_ty.len); - let shift_fn = Ident::new(&shift_name, Span::call_site()); - + } + OpSig::Shift => { + let prefix = vec_ty.scalar.prefix(); + let shift_name = format!("{prefix}{}x{}_{method}", vec_ty.scalar_bits, vec_ty.len); + let shift_fn = Ident::new(&shift_name, Span::call_site()); + + quote! { + #method_sig { + #shift_fn(a.into(), shift).simd_into(self) + } + } + } + OpSig::Reinterpret { + target_ty, + scalar_bits, + } => { + assert!( + valid_reinterpret(vec_ty, target_ty, scalar_bits), + "The underlying data for WASM SIMD is a v128, so a reinterpret is just that, a reinterpretation of the v128." + ); + + quote! 
{ + #method_sig { + ::from(a).simd_into(self) + } + } + } + OpSig::Cvt { + target_ty, + scalar_bits, + precise, + } => { + let (op, uses_relaxed) = match (vec_ty.scalar, target_ty, precise) { + (ScalarType::Float, ScalarType::Int | ScalarType::Unsigned, false) => { + ("relaxed_trunc", true) + } + (ScalarType::Float, ScalarType::Int | ScalarType::Unsigned, true) => { + ("trunc_sat", false) + } + (ScalarType::Int | ScalarType::Unsigned, ScalarType::Float, _) => { + ("convert", false) + } + _ => unimplemented!(), + }; + let dst_ty = arch_prefix(&vec_ty.reinterpret(target_ty, scalar_bits)); + let src_ty = arch_prefix(vec_ty); + let conversion_fn = format_ident!("{dst_ty}_{op}_{src_ty}"); + + if uses_relaxed { + let precise = generic_op_name(&[method, "_precise"].join(""), vec_ty); quote! { #method_sig { - #shift_fn(a.into(), shift).simd_into(self) + #[cfg(target_feature = "relaxed-simd")] + { #conversion_fn(a.into()).simd_into(self) } + + #[cfg(not(target_feature = "relaxed-simd"))] + { self.#precise(a) } } } - } - OpSig::Reinterpret { - target_ty, - scalar_bits, - } => { - assert!( - valid_reinterpret(vec_ty, target_ty, scalar_bits), - "The underlying data for WASM SIMD is a v128, so a reinterpret is just that, a reinterpretation of the v128." - ); - + } else { quote! { #method_sig { - ::from(a).simd_into(self) + #conversion_fn(a.into()).simd_into(self) } } } - OpSig::Cvt { - target_ty, - scalar_bits, - precise, - } => { - let (op, uses_relaxed) = match (vec_ty.scalar, target_ty, precise) { - (ScalarType::Float, ScalarType::Int | ScalarType::Unsigned, false) => { - ("relaxed_trunc", true) - } - (ScalarType::Float, ScalarType::Int | ScalarType::Unsigned, true) => { - ("trunc_sat", false) - } - (ScalarType::Int | ScalarType::Unsigned, ScalarType::Float, _) => { - ("convert", false) - } - _ => unimplemented!(), - }; - let dst_ty = arch_ty(&vec_ty.reinterpret(target_ty, scalar_bits)); - let src_ty = arch_ty(vec_ty); - let conversion_fn = format_ident!("{dst_ty}_{op}_{src_ty}"); - - if uses_relaxed { - let precise = generic_op_name(&[method, "_precise"].join(""), vec_ty); + } + OpSig::WidenNarrow { target_ty } => { + match method { + "widen" => { + assert_eq!( + vec_ty.rust_name(), + "u8x16", + "Currently only u8x16 -> u16x16 widening is supported" + ); + assert_eq!( + target_ty.rust_name(), + "u16x16", + "Currently only u8x16 -> u16x16 widening is supported" + ); quote! { #method_sig { - #[cfg(target_feature = "relaxed-simd")] - { #conversion_fn(a.into()).simd_into(self) } - - #[cfg(not(target_feature = "relaxed-simd"))] - { self.#precise(a) } + let low = u16x8_extend_low_u8x16(a.into()); + let high = u16x8_extend_high_u8x16(a.into()); + self.combine_u16x8(low.simd_into(self), high.simd_into(self)) } } - } else { + } + "narrow" => { + assert_eq!( + vec_ty.rust_name(), + "u16x16", + "Currently only u16x16 -> u8x16 narrowing is supported" + ); + assert_eq!( + target_ty.rust_name(), + "u8x16", + "Currently only u16x16 -> u8x16 narrowing is supported" + ); + // WASM SIMD only has saturating narrowing instructions, so we emulate + // truncated narrowing by masking out the quote! 
{ #method_sig { - #conversion_fn(a.into()).simd_into(self) + let mask = u16x8_splat(0xFF); + let (low, high) = self.split_u16x16(a); + let low_masked = v128_and(low.into(), mask); + let high_masked = v128_and(high.into(), mask); + let result = u8x16_narrow_i16x8(low_masked, high_masked); + result.simd_into(self) } } } + _ => unimplemented!(), } - OpSig::WidenNarrow { target_ty } => { - match method { - "widen" => { - assert_eq!( - vec_ty.rust_name(), - "u8x16", - "Currently only u8x16 -> u16x16 widening is supported" - ); - assert_eq!( - target_ty.rust_name(), - "u16x16", - "Currently only u8x16 -> u16x16 widening is supported" - ); - quote! { - #method_sig { - let low = u16x8_extend_low_u8x16(a.into()); - let high = u16x8_extend_high_u8x16(a.into()); - self.combine_u16x8(low.simd_into(self), high.simd_into(self)) - } - } - } - "narrow" => { - assert_eq!( - vec_ty.rust_name(), - "u16x16", - "Currently only u16x16 -> u8x16 narrowing is supported" - ); - assert_eq!( - target_ty.rust_name(), - "u8x16", - "Currently only u16x16 -> u8x16 narrowing is supported" - ); - // WASM SIMD only has saturating narrowing instructions, so we emulate - // truncated narrowing by masking out the - quote! { - #method_sig { - let mask = u16x8_splat(0xFF); - let (low, high) = self.split_u16x16(a); - let low_masked = v128_and(low.into(), mask); - let high_masked = v128_and(high.into(), mask); - let result = u8x16_narrow_i16x8(low_masked, high_masked); - result.simd_into(self) - } - } - } - _ => unimplemented!(), + } + OpSig::MaskReduce { + quantifier, + condition, + } => { + let (intrinsic, negate) = match (quantifier, condition) { + (Quantifier::Any, true) => (v128_intrinsic("any_true"), None), + (Quantifier::Any, false) => { + (simple_intrinsic("all_true", vec_ty), Some(quote! { ! })) } - } - OpSig::MaskReduce { - quantifier, - condition, - } => { - let (intrinsic, negate) = match (quantifier, condition) { - (Quantifier::Any, true) => (v128_intrinsic("any_true"), None), - (Quantifier::Any, false) => { - (simple_intrinsic("all_true", vec_ty), Some(quote! { ! })) - } - (Quantifier::All, true) => (simple_intrinsic("all_true", vec_ty), None), - (Quantifier::All, false) => { - (v128_intrinsic("any_true"), Some(quote! { ! })) - } - }; + (Quantifier::All, true) => (simple_intrinsic("all_true", vec_ty), None), + (Quantifier::All, false) => (v128_intrinsic("any_true"), Some(quote! { ! })), + }; - quote! { - #method_sig { - #negate #intrinsic(a.into()) - } + quote! { + #method_sig { + #negate #intrinsic(a.into()) } } - OpSig::LoadInterleaved { - block_size, - block_count, - } => { - assert_eq!(block_count, 4, "only count of 4 is currently supported"); - let elems_per_vec = block_size as usize / vec_ty.scalar_bits; - - // For WASM we need to simulate interleaving with shuffle, and we only have - // access to 2, 4 and 16 lanes. So, for 64 u8's, we need to split and recombine - // the vectors. - let (i1, i2, i3, i4, shuffle_fn) = match vec_ty.scalar_bits { - 8 => ( - quote! { 0, 4, 8, 12, 16, 20, 24, 28, 1, 5, 9, 13, 17, 21, 25, 29 }, - quote! { 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31 }, - quote! { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }, - quote! { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }, - quote! { u8x16_shuffle }, - ), - 16 => ( - quote! { 0, 4, 8, 12, 1, 5, 9, 13 }, - quote! { 2, 6, 10, 14, 3, 7, 11, 15 }, - quote! { 0, 1, 2, 3, 8, 9, 10, 11 }, - quote! { 4, 5, 6, 7, 12, 13, 14, 15 }, - quote! { u16x8_shuffle }, - ), - 32 => ( - quote! { 0, 4, 1, 5 }, - quote! 
{ 2, 6, 3, 7 }, - quote! { 0, 1, 4, 5 }, - quote! { 2, 3, 6, 7 }, - quote! { u32x4_shuffle }, - ), - _ => panic!("unsupported scalar_bits"), - }; - - let block_ty = vec_ty.block_ty(); - let block_ty_2x = - VecType::new(block_ty.scalar, block_ty.scalar_bits, block_ty.len * 2); - - let combine_method = generic_op_name("combine", &block_ty); - let combine_method_2x = generic_op_name("combine", &block_ty_2x); - - let combine_code = quote! { - let combined_lower = self.#combine_method(out0.simd_into(self), out1.simd_into(self)); - let combined_upper = self.#combine_method(out2.simd_into(self), out3.simd_into(self)); - self.#combine_method_2x(combined_lower, combined_upper) - }; - - quote! { - #method_sig { - let v0: v128 = unsafe { v128_load(src[0 * #elems_per_vec..].as_ptr() as *const v128) }; - let v1: v128 = unsafe { v128_load(src[1 * #elems_per_vec..].as_ptr() as *const v128) }; - let v2: v128 = unsafe { v128_load(src[2 * #elems_per_vec..].as_ptr() as *const v128) }; - let v3: v128 = unsafe { v128_load(src[3 * #elems_per_vec..].as_ptr() as *const v128) }; - - // InterleaveLowerLanes(v0, v2) and InterleaveLowerLanes(v1, v3) - let v01_lower = #shuffle_fn::<#i1>(v0, v1); - let v23_lower = #shuffle_fn::<#i1>(v2, v3); - - // InterleaveUpperLanes(v0, v2) and InterleaveUpperLanes(v1, v3) - let v01_upper = #shuffle_fn::<#i2>(v0, v1); - let v23_upper = #shuffle_fn::<#i2>(v2, v3); - - // Interleave lower and upper to get final result - let out0 = #shuffle_fn::<#i3>(v01_lower, v23_lower); - let out1 = #shuffle_fn::<#i4>(v01_lower, v23_lower); - let out2 = #shuffle_fn::<#i3>(v01_upper, v23_upper); - let out3 = #shuffle_fn::<#i4>(v01_upper, v23_upper); - - #combine_code - } - } - } - OpSig::StoreInterleaved { - block_size, - block_count, - } => { - assert_eq!(block_count, 4, "only count of 4 is currently supported"); - let elems_per_vec = block_size as usize / vec_ty.scalar_bits; - - let (lower_indices, upper_indices, shuffle_fn) = match vec_ty.scalar_bits { - 8 => ( - quote! { 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23 }, - quote! { 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 }, - quote! { u8x16_shuffle }, - ), - 16 => ( - quote! { 0, 8, 1, 9, 2, 10, 3, 11 }, - quote! { 4, 12, 5, 13, 6, 14, 7, 15 }, - quote! { u16x8_shuffle }, - ), - 32 => ( - quote! { 0, 4, 1, 5 }, - quote! { 2, 6, 3, 7 }, - quote! { u32x4_shuffle }, - ), - _ => panic!("unsupported scalar_bits"), - }; - - let block_ty = vec_ty.block_ty(); - let block_ty_2x = - VecType::new(block_ty.scalar, block_ty.scalar_bits, block_ty.len * 2); - let block_ty_4x = - VecType::new(block_ty.scalar, block_ty.scalar_bits, block_ty.len * 4); - - let split_method = generic_op_name("split", &block_ty_2x); - let split_method_2x = generic_op_name("split", &block_ty_4x); - - let split_code = quote! { - let (lower, upper) = self.#split_method_2x(a); - let (v0_vec, v1_vec) = self.#split_method(lower); - let (v2_vec, v3_vec) = self.#split_method(upper); - - let v0: v128 = v0_vec.into(); - let v1: v128 = v1_vec.into(); - let v2: v128 = v2_vec.into(); - let v3: v128 = v3_vec.into(); - }; - - quote! { - #method_sig { - #split_code + } + OpSig::LoadInterleaved { + block_size, + block_count, + } => { + assert_eq!(block_count, 4, "only count of 4 is currently supported"); + let elems_per_vec = block_size as usize / vec_ty.scalar_bits; + + // For WASM we need to simulate interleaving with shuffle, and we only have + // access to 2, 4 and 16 lanes. So, for 64 u8's, we need to split and recombine + // the vectors. 
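+                // Descriptive note (added): the two rounds of shuffles below de-interleave the
+                // four loaded vectors into out0..out3, which are then combined into one wide result.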
+ let (i1, i2, i3, i4, shuffle_fn) = match vec_ty.scalar_bits { + 8 => ( + quote! { 0, 4, 8, 12, 16, 20, 24, 28, 1, 5, 9, 13, 17, 21, 25, 29 }, + quote! { 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31 }, + quote! { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }, + quote! { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }, + quote! { u8x16_shuffle }, + ), + 16 => ( + quote! { 0, 4, 8, 12, 1, 5, 9, 13 }, + quote! { 2, 6, 10, 14, 3, 7, 11, 15 }, + quote! { 0, 1, 2, 3, 8, 9, 10, 11 }, + quote! { 4, 5, 6, 7, 12, 13, 14, 15 }, + quote! { u16x8_shuffle }, + ), + 32 => ( + quote! { 0, 4, 1, 5 }, + quote! { 2, 6, 3, 7 }, + quote! { 0, 1, 4, 5 }, + quote! { 2, 3, 6, 7 }, + quote! { u32x4_shuffle }, + ), + _ => panic!("unsupported scalar_bits"), + }; + + let block_ty = vec_ty.block_ty(); + let block_ty_2x = + VecType::new(block_ty.scalar, block_ty.scalar_bits, block_ty.len * 2); + + let combine_method = generic_op_name("combine", &block_ty); + let combine_method_2x = generic_op_name("combine", &block_ty_2x); + + let combine_code = quote! { + let combined_lower = self.#combine_method(out0.simd_into(self), out1.simd_into(self)); + let combined_upper = self.#combine_method(out2.simd_into(self), out3.simd_into(self)); + self.#combine_method_2x(combined_lower, combined_upper) + }; + + quote! { + #method_sig { + let v0: v128 = unsafe { v128_load(src[0 * #elems_per_vec..].as_ptr() as *const v128) }; + let v1: v128 = unsafe { v128_load(src[1 * #elems_per_vec..].as_ptr() as *const v128) }; + let v2: v128 = unsafe { v128_load(src[2 * #elems_per_vec..].as_ptr() as *const v128) }; + let v3: v128 = unsafe { v128_load(src[3 * #elems_per_vec..].as_ptr() as *const v128) }; // InterleaveLowerLanes(v0, v2) and InterleaveLowerLanes(v1, v3) - let v02_lower = #shuffle_fn::<#lower_indices>(v0, v2); - let v13_lower = #shuffle_fn::<#lower_indices>(v1, v3); + let v01_lower = #shuffle_fn::<#i1>(v0, v1); + let v23_lower = #shuffle_fn::<#i1>(v2, v3); // InterleaveUpperLanes(v0, v2) and InterleaveUpperLanes(v1, v3) - let v02_upper = #shuffle_fn::<#upper_indices>(v0, v2); - let v13_upper = #shuffle_fn::<#upper_indices>(v1, v3); + let v01_upper = #shuffle_fn::<#i2>(v0, v1); + let v23_upper = #shuffle_fn::<#i2>(v2, v3); // Interleave lower and upper to get final result - let out0 = #shuffle_fn::<#lower_indices>(v02_lower, v13_lower); - let out1 = #shuffle_fn::<#upper_indices>(v02_lower, v13_lower); - let out2 = #shuffle_fn::<#lower_indices>(v02_upper, v13_upper); - let out3 = #shuffle_fn::<#upper_indices>(v02_upper, v13_upper); - - unsafe { - v128_store(dest[0 * #elems_per_vec..].as_mut_ptr() as *mut v128, out0); - v128_store(dest[1 * #elems_per_vec..].as_mut_ptr() as *mut v128, out1); - v128_store(dest[2 * #elems_per_vec..].as_mut_ptr() as *mut v128, out2); - v128_store(dest[3 * #elems_per_vec..].as_mut_ptr() as *mut v128, out3); - } - } - } - } - OpSig::FromArray { kind } => { - generic_from_array(method_sig, vec_ty, kind, 128, |_| v128_intrinsic("load")) - } - OpSig::AsArray { kind } => generic_as_array(method_sig, vec_ty, kind, 128, |_| { - Ident::new("v128", Span::call_site()) - }), - OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), - OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), - }; + let out0 = #shuffle_fn::<#i3>(v01_lower, v23_lower); + let out1 = #shuffle_fn::<#i4>(v01_lower, v23_lower); + let out2 = #shuffle_fn::<#i3>(v01_upper, v23_upper); + let out3 = #shuffle_fn::<#i4>(v01_upper, v23_upper); - methods.push(m); - } - } - - quote! 
{ - impl Simd for #level_tok { - type f32s = f32x4; - type f64s = f64x2; - type u8s = u8x16; - type i8s = i8x16; - type u16s = u16x8; - type i16s = i16x8; - type u32s = u32x4; - type i32s = i32x4; - type mask8s = mask8x16; - type mask16s = mask16x8; - type mask32s = mask32x4; - type mask64s = mask64x2; - - #[inline(always)] - fn level(self) -> Level { - Level::#level_tok(self) - } - - #[inline] - fn vectorize R, R>(self, f: F) -> R { - #[inline] - // unsafe not needed here with tf11, but can be justified - unsafe fn vectorize_simd128 R, R>(f: F) -> R { - f() + #combine_code + } } - unsafe { vectorize_simd128(f) } } - - #( #methods )* - } - } -} - -pub(crate) fn mk_wasm128_impl(level: Level) -> TokenStream { - let imports = type_imports(); - let arch_types_impl = - impl_arch_types(level.name(), 128, |_| Ident::new("v128", Span::call_site())); - let simd_impl = mk_simd_impl(level); - let ty_impl = mk_type_impl(); - let level_tok = level.token(); - - quote! { - use core::arch::wasm32::*; - - use crate::{seal::Seal, arch_types::ArchTypes, Level, Simd, SimdFrom, SimdInto}; - - #imports - - /// The SIMD token for the "wasm128" level. - #[derive(Clone, Copy, Debug)] - pub struct #level_tok { - pub wasmsimd128: crate::core_arch::wasm32::WasmSimd128, - } - - impl #level_tok { - // TODO: this can be renamed to `new` like with `Fallback`. - #[inline] - pub const fn new_unchecked() -> Self { - Self { wasmsimd128: crate::core_arch::wasm32::WasmSimd128::new() } - } - } - - impl Seal for #level_tok {} - - #arch_types_impl - - #simd_impl - - #ty_impl - } -} - -fn mk_type_impl() -> TokenStream { - let mut result = vec![]; - for ty in SIMD_TYPES { - if ty.n_bits() != 128 { - continue; - } - let simd = ty.rust(); - result.push(quote! { - impl SimdFrom for #simd { - #[inline(always)] - fn simd_from(arch: v128, simd: S) -> Self { - Self { - val: unsafe { core::mem::transmute_copy(&arch) }, - simd + OpSig::StoreInterleaved { + block_size, + block_count, + } => { + assert_eq!(block_count, 4, "only count of 4 is currently supported"); + let elems_per_vec = block_size as usize / vec_ty.scalar_bits; + + let (lower_indices, upper_indices, shuffle_fn) = match vec_ty.scalar_bits { + 8 => ( + quote! { 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23 }, + quote! { 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 }, + quote! { u8x16_shuffle }, + ), + 16 => ( + quote! { 0, 8, 1, 9, 2, 10, 3, 11 }, + quote! { 4, 12, 5, 13, 6, 14, 7, 15 }, + quote! { u16x8_shuffle }, + ), + 32 => ( + quote! { 0, 4, 1, 5 }, + quote! { 2, 6, 3, 7 }, + quote! { u32x4_shuffle }, + ), + _ => panic!("unsupported scalar_bits"), + }; + + let block_ty = vec_ty.block_ty(); + let block_ty_2x = + VecType::new(block_ty.scalar, block_ty.scalar_bits, block_ty.len * 2); + let block_ty_4x = + VecType::new(block_ty.scalar, block_ty.scalar_bits, block_ty.len * 4); + + let split_method = generic_op_name("split", &block_ty_2x); + let split_method_2x = generic_op_name("split", &block_ty_4x); + + let split_code = quote! { + let (lower, upper) = self.#split_method_2x(a); + let (v0_vec, v1_vec) = self.#split_method(lower); + let (v2_vec, v3_vec) = self.#split_method(upper); + + let v0: v128 = v0_vec.into(); + let v1: v128 = v1_vec.into(); + let v2: v128 = v2_vec.into(); + let v3: v128 = v3_vec.into(); + }; + + quote! 
{ + #method_sig { + #split_code + + // InterleaveLowerLanes(v0, v2) and InterleaveLowerLanes(v1, v3) + let v02_lower = #shuffle_fn::<#lower_indices>(v0, v2); + let v13_lower = #shuffle_fn::<#lower_indices>(v1, v3); + + // InterleaveUpperLanes(v0, v2) and InterleaveUpperLanes(v1, v3) + let v02_upper = #shuffle_fn::<#upper_indices>(v0, v2); + let v13_upper = #shuffle_fn::<#upper_indices>(v1, v3); + + // Interleave lower and upper to get final result + let out0 = #shuffle_fn::<#lower_indices>(v02_lower, v13_lower); + let out1 = #shuffle_fn::<#upper_indices>(v02_lower, v13_lower); + let out2 = #shuffle_fn::<#lower_indices>(v02_upper, v13_upper); + let out3 = #shuffle_fn::<#upper_indices>(v02_upper, v13_upper); + + unsafe { + v128_store(dest[0 * #elems_per_vec..].as_mut_ptr() as *mut v128, out0); + v128_store(dest[1 * #elems_per_vec..].as_mut_ptr() as *mut v128, out1); + v128_store(dest[2 * #elems_per_vec..].as_mut_ptr() as *mut v128, out2); + v128_store(dest[3 * #elems_per_vec..].as_mut_ptr() as *mut v128, out3); + } } } } - impl From<#simd> for v128 { - #[inline(always)] - fn from(value: #simd) -> Self { - unsafe { core::mem::transmute_copy(&value.val) } - } + OpSig::FromArray { kind } => { + generic_from_array(method_sig, vec_ty, kind, self.max_block_size(), |_| { + v128_intrinsic("load") + }) } - }); - } - quote! { - #( #result )* + OpSig::AsArray { kind } => { + generic_as_array(method_sig, vec_ty, kind, self.max_block_size(), |_| { + Ident::new("v128", Span::call_site()) + }) + } + OpSig::FromBytes => generic_from_bytes(method_sig, vec_ty), + OpSig::ToBytes => generic_to_bytes(method_sig, vec_ty), + } } } diff --git a/fearless_simd_gen/src/types.rs b/fearless_simd_gen/src/types.rs index 327c1ecf..0f12c8b9 100644 --- a/fearless_simd_gen/src/types.rs +++ b/fearless_simd_gen/src/types.rs @@ -36,6 +36,16 @@ impl ScalarType { let ident = Ident::new(&self.rust_name(scalar_bits), Span::call_site()); quote! { #ident } } + + pub(crate) fn native_width_name(&self, scalar_bits: usize) -> Ident { + let prefix = match self { + Self::Float => "f", + Self::Unsigned => "u", + Self::Int => "i", + Self::Mask => "mask", + }; + format_ident!("{}{}s", prefix, scalar_bits) + } } impl VecType { @@ -79,7 +89,7 @@ impl VecType { /// array of them. pub(crate) fn wrapped_native_ty( &self, - arch_ty: impl Fn(&Self) -> Ident, + arch_ty: impl Fn(&Self) -> TokenStream, max_block_size: usize, ) -> TokenStream { let block_size = self.n_bits().min(max_block_size); @@ -97,7 +107,7 @@ impl VecType { /// Returns the full type name for this vector's `Aligned` wrapper, including the type parameter. pub(crate) fn aligned_wrapper_ty( &self, - arch_ty: impl Fn(&Self) -> Ident, + arch_ty: impl Fn(&Self) -> TokenStream, max_block_size: usize, ) -> TokenStream { let newtype = self.aligned_wrapper();