Skip to content

Commit

Permalink
add support for leading_zeros and trailing_zeros, panics during linking
Browse files Browse the repository at this point in the history
  • Loading branch information
Firestar99 committed Jan 29, 2025
1 parent 6e2c84d commit 702cb97
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 60 deletions.
33 changes: 3 additions & 30 deletions crates/rustc_codegen_spirv/src/builder/ext_inst.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::custom_insts;
use rspirv::dr::Operand;
use rspirv::spirv::{GLOp, Word};
use rspirv::{dr::Operand, spirv::Capability};

const GLSL_STD_450: &str = "GLSL.std.450";

Expand All @@ -13,7 +13,6 @@ pub struct ExtInst {
custom: Option<Word>,

glsl: Option<Word>,
integer_functions_2_intel: bool,
}

impl ExtInst {
Expand All @@ -38,32 +37,11 @@ impl ExtInst {
id
}
}

pub fn require_integer_functions_2_intel(&mut self, bx: &Builder<'_, '_>, to_zombie: Word) {
if !self.integer_functions_2_intel {
self.integer_functions_2_intel = true;
if !bx
.builder
.has_capability(Capability::IntegerFunctions2INTEL)
{
bx.zombie(to_zombie, "capability IntegerFunctions2INTEL is required");
}
if !bx
.builder
.has_extension(bx.sym.spv_intel_shader_integer_functions2)
{
bx.zombie(
to_zombie,
"extension SPV_INTEL_shader_integer_functions2 is required",
);
}
}
}
}

impl<'a, 'tcx> Builder<'a, 'tcx> {
pub fn custom_inst(
&mut self,
&self,
result_type: Word,
inst: custom_insts::CustomInst<Operand>,
) -> SpirvValue {
Expand All @@ -80,12 +58,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.with_type(result_type)
}

pub fn gl_op(
&mut self,
op: GLOp,
result_type: Word,
args: impl AsRef<[SpirvValue]>,
) -> SpirvValue {
pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
let args = args.as_ref();
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
self.emit()
Expand Down
74 changes: 48 additions & 26 deletions crates/rustc_codegen_spirv/src/builder/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::codegen_cx::CodegenCx;
use crate::custom_insts::CustomInst;
use crate::spirv_type::SpirvType;
use rspirv::dr::Operand;
use rspirv::spirv::GLOp;
use rspirv::spirv::{GLOp, Word};
use rustc_codegen_ssa::mir::operand::OperandRef;
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{BuilderMethods, IntrinsicCallBuilderMethods};
Expand Down Expand Up @@ -211,34 +211,11 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
self.rotate(val, shift, is_left)
}

// TODO: Do we want to manually implement these instead of using intel instructions?
sym::ctlz | sym::ctlz_nonzero => {
let result = self
.emit()
.u_count_leading_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
self.count_leading_trailing_zeros(ret_ty, args[0].immediate(), false)
}
sym::cttz | sym::cttz_nonzero => {
let result = self
.emit()
.u_count_trailing_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
self.count_leading_trailing_zeros(ret_ty, args[0].immediate(), true)
}

sym::ctpop => self
Expand Down Expand Up @@ -398,6 +375,51 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
}

impl Builder<'_, '_> {
pub fn count_leading_trailing_zeros(
&self,
ret_ty: Word,
arg: SpirvValue,
trailing: bool,
) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, _) => {
let int_0 = self.constant_int(ty, 0);
let int_bits = self.constant_int(ret_ty, bits as u128).def(self);
let glsl = self.ext_inst.borrow_mut().import_glsl(self);

let mut emit = self.emit();
let is_0 = emit
.i_equal(ty, None, arg.def(self), int_0.def(self))
.unwrap();
let end_label = emit.id();
let xsb_label = emit.id();
emit.branch_conditional(is_0, end_label, xsb_label, [])
.unwrap();

emit.begin_block(Some(xsb_label)).unwrap();
// rust is always unsigned
let gl_op = if trailing {
GLOp::FindILsb
} else {
GLOp::FindUMsb
};
let find_xsb = emit
.ext_inst(ret_ty, None, glsl, gl_op as u32, [Operand::IdRef(
arg.def(self),
)])
.unwrap();
emit.branch(end_label).unwrap();

emit.begin_block(Some(end_label)).unwrap();
emit.phi(ret_ty, None, [(end_label, int_bits), (xsb_label, find_xsb)])
.unwrap()
.with_type(ret_ty)
}
_ => self.fatal("counting leading / trailing zeros on a non-integer type"),
}
}

pub fn abort_with_kind_and_message_debug_printf(
&mut self,
kind: &str,
Expand Down
4 changes: 0 additions & 4 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub struct Symbols {
pub spirv: Symbol,
pub libm: Symbol,
pub entry_point_name: Symbol,
pub spv_intel_shader_integer_functions2: Symbol,
pub spv_khr_vulkan_memory_model: Symbol,

descriptor_set: Symbol,
Expand Down Expand Up @@ -411,9 +410,6 @@ impl Symbols {
spirv: Symbol::intern("spirv"),
libm: Symbol::intern("libm"),
entry_point_name: Symbol::intern("entry_point_name"),
spv_intel_shader_integer_functions2: Symbol::intern(
"SPV_INTEL_shader_integer_functions2",
),
spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),

descriptor_set: Symbol::intern("descriptor_set"),
Expand Down
37 changes: 37 additions & 0 deletions tests/ui/lang/bitcount/trailing_leading_bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Test all trailing and leading zeros. No need to test ones, they just call the zero variant with !value

// build-pass

use spirv_std::spirv;

#[spirv(fragment)]
pub fn leading_zeros_u32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
out: &mut u32,
) {
*out = u32::leading_zeros(*buffer);
}

#[spirv(fragment)]
pub fn trailing_zeros_u32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
out: &mut u32,
) {
*out = u32::trailing_zeros(*buffer);
}

#[spirv(fragment)]
pub fn leading_zeros_i32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
out: &mut u32,
) {
*out = i32::leading_zeros(*buffer);
}

#[spirv(fragment)]
pub fn trailing_zeros_i32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
out: &mut u32,
) {
*out = i32::trailing_zeros(*buffer);
}

0 comments on commit 702cb97

Please sign in to comment.