Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for leading_zeros and trailing_zeros #213

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 3 additions & 30 deletions crates/rustc_codegen_spirv/src/builder/ext_inst.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::custom_insts;
use rspirv::dr::Operand;
use rspirv::spirv::{GLOp, Word};
use rspirv::{dr::Operand, spirv::Capability};

const GLSL_STD_450: &str = "GLSL.std.450";

Expand All @@ -13,7 +13,6 @@ pub struct ExtInst {
custom: Option<Word>,

glsl: Option<Word>,
integer_functions_2_intel: bool,
}

impl ExtInst {
Expand All @@ -38,32 +37,11 @@ impl ExtInst {
id
}
}

pub fn require_integer_functions_2_intel(&mut self, bx: &Builder<'_, '_>, to_zombie: Word) {
if !self.integer_functions_2_intel {
self.integer_functions_2_intel = true;
if !bx
.builder
.has_capability(Capability::IntegerFunctions2INTEL)
{
bx.zombie(to_zombie, "capability IntegerFunctions2INTEL is required");
}
if !bx
.builder
.has_extension(bx.sym.spv_intel_shader_integer_functions2)
{
bx.zombie(
to_zombie,
"extension SPV_INTEL_shader_integer_functions2 is required",
);
}
}
}
}

impl<'a, 'tcx> Builder<'a, 'tcx> {
pub fn custom_inst(
&mut self,
&self,
result_type: Word,
inst: custom_insts::CustomInst<Operand>,
) -> SpirvValue {
Expand All @@ -80,12 +58,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.with_type(result_type)
}

pub fn gl_op(
&mut self,
op: GLOp,
result_type: Word,
args: impl AsRef<[SpirvValue]>,
) -> SpirvValue {
pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
let args = args.as_ref();
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
self.emit()
Expand Down
74 changes: 48 additions & 26 deletions crates/rustc_codegen_spirv/src/builder/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::codegen_cx::CodegenCx;
use crate::custom_insts::CustomInst;
use crate::spirv_type::SpirvType;
use rspirv::dr::Operand;
use rspirv::spirv::GLOp;
use rspirv::spirv::{GLOp, Word};
use rustc_codegen_ssa::mir::operand::OperandRef;
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{BuilderMethods, IntrinsicCallBuilderMethods};
Expand Down Expand Up @@ -211,34 +211,11 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
self.rotate(val, shift, is_left)
}

// TODO: Do we want to manually implement these instead of using intel instructions?
sym::ctlz | sym::ctlz_nonzero => {
let result = self
.emit()
.u_count_leading_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
self.count_leading_trailing_zeros(ret_ty, args[0].immediate(), false)
}
sym::cttz | sym::cttz_nonzero => {
let result = self
.emit()
.u_count_trailing_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
self.count_leading_trailing_zeros(ret_ty, args[0].immediate(), true)
}

sym::ctpop => self
Expand Down Expand Up @@ -398,6 +375,51 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
}

impl Builder<'_, '_> {
pub fn count_leading_trailing_zeros(
&self,
ret_ty: Word,
arg: SpirvValue,
trailing: bool,
) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, _) => {
let int_0 = self.constant_int(ty, 0);
let int_bits = self.constant_int(ret_ty, bits as u128).def(self);
let glsl = self.ext_inst.borrow_mut().import_glsl(self);

let mut emit = self.emit();
let is_0 = emit
.i_equal(ty, None, arg.def(self), int_0.def(self))
.unwrap();
let end_label = emit.id();
let xsb_label = emit.id();
emit.branch_conditional(is_0, end_label, xsb_label, [])
.unwrap();

emit.begin_block(Some(xsb_label)).unwrap();
// rust is always unsigned
let gl_op = if trailing {
GLOp::FindILsb
} else {
GLOp::FindUMsb
};
let find_xsb = emit
.ext_inst(ret_ty, None, glsl, gl_op as u32, [Operand::IdRef(
arg.def(self),
)])
.unwrap();
emit.branch(end_label).unwrap();

emit.begin_block(Some(end_label)).unwrap();
emit.phi(ret_ty, None, [(end_label, int_bits), (xsb_label, find_xsb)])
.unwrap()
.with_type(ret_ty)
}
_ => self.fatal("counting leading / trailing zeros on a non-integer type"),
}
}

pub fn abort_with_kind_and_message_debug_printf(
&mut self,
kind: &str,
Expand Down
4 changes: 0 additions & 4 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub struct Symbols {
pub spirv: Symbol,
pub libm: Symbol,
pub entry_point_name: Symbol,
pub spv_intel_shader_integer_functions2: Symbol,
pub spv_khr_vulkan_memory_model: Symbol,

descriptor_set: Symbol,
Expand Down Expand Up @@ -411,9 +410,6 @@ impl Symbols {
spirv: Symbol::intern("spirv"),
libm: Symbol::intern("libm"),
entry_point_name: Symbol::intern("entry_point_name"),
spv_intel_shader_integer_functions2: Symbol::intern(
"SPV_INTEL_shader_integer_functions2",
),
spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),

descriptor_set: Symbol::intern("descriptor_set"),
Expand Down
37 changes: 37 additions & 0 deletions tests/ui/lang/bitcount/trailing_leading_bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Test all trailing and leading zeros. No need to test ones, they just call the zero variant with !value

// build-pass

use spirv_std::spirv;

#[spirv(fragment)]
pub fn leading_zeros_u32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
out: &mut u32,
) {
*out = u32::leading_zeros(*buffer);
}

#[spirv(fragment)]
pub fn trailing_zeros_u32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
out: &mut u32,
) {
*out = u32::trailing_zeros(*buffer);
}

#[spirv(fragment)]
pub fn leading_zeros_i32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
out: &mut u32,
) {
*out = i32::leading_zeros(*buffer);
}

#[spirv(fragment)]
pub fn trailing_zeros_i32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
out: &mut u32,
) {
*out = i32::trailing_zeros(*buffer);
}
Loading