Skip to content

Non uniform for everything! #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use rspirv::spirv::{Decoration, Word};
use rustc_codegen_spirv_types::Capability;
use rustc_codegen_ssa::traits::BuilderMethods;
use rustc_errors::ErrorGuaranteed;
use rustc_span::DUMMY_SP;
Expand Down Expand Up @@ -41,11 +42,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let u32_ptr = self.type_ptr_to(u32_ty);
let array = array.def(self);
let actual_index = actual_index.def(self);
let ptr = self
.emit()
.in_bounds_access_chain(u32_ptr, None, array.def(self), [actual_index.def(self)])
.in_bounds_access_chain(u32_ptr, None, array, [actual_index])
.unwrap()
.with_type(u32_ptr);
if self.builder.has_capability(Capability::ShaderNonUniform) {
// apply NonUniform to the operation and the index
self.emit()
.decorate(ptr.def(self), Decoration::NonUniform, []);
self.emit()
.decorate(actual_index, Decoration::NonUniform, []);
}
self.load(u32_ty, ptr, Align::ONE)
}

Expand Down Expand Up @@ -233,11 +243,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let u32_ptr = self.type_ptr_to(u32_ty);
let array = array.def(self);
let actual_index = actual_index.def(self);
let ptr = self
.emit()
.in_bounds_access_chain(u32_ptr, None, array.def(self), [actual_index.def(self)])
.in_bounds_access_chain(u32_ptr, None, array, [actual_index])
.unwrap()
.with_type(u32_ptr);
if self.builder.has_capability(Capability::ShaderNonUniform) {
// apply NonUniform to the operation and the index
self.emit()
.decorate(ptr.def(self), Decoration::NonUniform, []);
self.emit()
.decorate(actual_index, Decoration::NonUniform, []);
}
self.store(value, ptr, Align::ONE);
Ok(())
}
Expand Down
64 changes: 35 additions & 29 deletions crates/spirv-std/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,14 +733,16 @@ impl<
) where
I: Integer,
{
asm! {
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%texels = OpLoad _ {texels}",
"OpImageWrite %image %coordinate %texels",
this = in(reg) self,
coordinate = in(reg) &coordinate,
texels = in(reg) &texels,
unsafe {
asm! {
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%texels = OpLoad _ {texels}",
"OpImageWrite %image %coordinate %texels",
this = in(reg) self,
coordinate = in(reg) &coordinate,
texels = in(reg) &texels,
}
}
}
}
Expand Down Expand Up @@ -802,14 +804,16 @@ impl<
) where
I: Integer,
{
asm! {
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%texels = OpLoad _ {texels}",
"OpImageWrite %image %coordinate %texels",
this = in(reg) self,
coordinate = in(reg) &coordinate,
texels = in(reg) &texels,
unsafe {
asm! {
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%texels = OpLoad _ {texels}",
"OpImageWrite %image %coordinate %texels",
this = in(reg) self,
coordinate = in(reg) &coordinate,
texels = in(reg) &texels,
}
}
}
}
Expand Down Expand Up @@ -848,13 +852,13 @@ impl<

unsafe {
asm! {
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%result = OpImageRead typeof*{result} %image %coordinate",
"OpStore {result} %result",
this = in(reg) self,
coordinate = in(reg) &coordinate,
result = in(reg) &mut result,
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%result = OpImageRead typeof*{result} %image %coordinate",
"OpStore {result} %result",
this = in(reg) self,
coordinate = in(reg) &coordinate,
result = in(reg) &mut result,
}
}

Expand All @@ -880,13 +884,14 @@ impl<
where
Self: HasQueryLevels,
{
let result: u32;
let mut result = Default::default();
unsafe {
asm! {
"%image = OpLoad _ {this}",
"{result} = OpImageQueryLevels typeof{result} %image",
"%result = OpImageQueryLevels typeof*{result} %image",
"OpStore {result} %result",
this = in(reg) self,
result = out(reg) result,
result = in(reg) &mut result,
}
}
result
Expand Down Expand Up @@ -1019,13 +1024,14 @@ impl<
#[crate::macros::gpu_only]
#[doc(alias = "OpImageQuerySamples")]
pub fn query_samples(&self) -> u32 {
let result: u32;
let mut result = Default::default();
unsafe {
asm! {
"%image = OpLoad _ {this}",
"{result} = OpImageQuerySamples typeof{result} %image",
"%result = OpImageQuerySamples typeof*{result} %image",
"OpStore {result} %result",
this = in(reg) self,
result = out(reg) result,
result = in(reg) &mut result,
}
}
result
Expand Down
29 changes: 29 additions & 0 deletions tests/compiletests/ui/dis/non_uniform_load.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// build-pass
// compile-flags: -Ctarget-feature=+ShaderNonUniform,+ext:SPV_EXT_descriptor_indexing
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpSource .*\n" -> ""

use spirv_std::{ByteAddressableBuffer, RuntimeArray, TypedBuffer, spirv};

pub struct BigStruct {
a: u32,
b: u32,
c: u32,
d: u32,
e: u32,
f: u32,
}

#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &RuntimeArray<
TypedBuffer<[u32]>,
>,
#[spirv(flat)] index_in: u32,
out: &mut BigStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf.index(index_in as usize));
*out = buf.load(5);
}
}
149 changes: 149 additions & 0 deletions tests/compiletests/ui/dis/non_uniform_load.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 69
OpCapability Shader
OpCapability Float64
OpCapability Int64
OpCapability Int16
OpCapability Int8
OpCapability ShaderClockKHR
OpCapability ShaderNonUniform
OpExtension "SPV_EXT_descriptor_indexing"
OpExtension "SPV_KHR_shader_clock"
OpMemoryModel Logical Simple
OpEntryPoint Fragment %1 "main" %2 %3
OpExecutionMode %1 OriginUpperLeft
%4 = OpString "/home/firestar99/workspace/frameworks/rust-gpu/crates/spirv-std/src/byte_addressable_buffer.rs"
%5 = OpString "/home/firestar99/workspace/frameworks/rust-gpu/crates/spirv-std/src/runtime_array.rs"
%6 = OpString "/home/firestar99/workspace/frameworks/rust-gpu/crates/spirv-std/src/typed_buffer.rs"
%7 = OpString "$DIR/non_uniform_load.rs"
OpName %8 "BigStruct"
OpMemberName %8 0 "a"
OpMemberName %8 1 "b"
OpMemberName %8 2 "c"
OpMemberName %8 3 "d"
OpMemberName %8 4 "e"
OpMemberName %8 5 "f"
OpName %2 "index_in"
OpName %9 "buf"
OpName %3 "out"
OpMemberDecorate %8 0 Offset 0
OpMemberDecorate %8 1 Offset 4
OpMemberDecorate %8 2 Offset 8
OpMemberDecorate %8 3 Offset 12
OpMemberDecorate %8 4 Offset 16
OpMemberDecorate %8 5 Offset 20
OpDecorate %2 Flat
OpDecorate %2 Location 0
OpDecorate %10 ArrayStride 4
OpDecorate %11 Block
OpMemberDecorate %11 0 Offset 0
OpDecorate %9 NonWritable
OpDecorate %9 Binding 0
OpDecorate %9 DescriptorSet 0
OpDecorate %3 Location 0
OpDecorate %12 NonUniform
OpDecorate %13 NonUniform
OpDecorate %14 NonUniform
OpDecorate %15 NonUniform
OpDecorate %16 NonUniform
OpDecorate %17 NonUniform
OpDecorate %18 NonUniform
OpDecorate %19 NonUniform
OpDecorate %20 NonUniform
OpDecorate %21 NonUniform
OpDecorate %22 NonUniform
OpDecorate %23 NonUniform
%24 = OpTypeInt 32 0
%25 = OpTypePointer Input %24
%8 = OpTypeStruct %24 %24 %24 %24 %24 %24
%26 = OpTypePointer Output %8
%27 = OpTypeVoid
%28 = OpTypeFunction %27
%2 = OpVariable %25 Input
%10 = OpTypeRuntimeArray %24
%11 = OpTypeStruct %10
%29 = OpTypePointer StorageBuffer %11
%30 = OpTypeRuntimeArray %11
%31 = OpTypePointer StorageBuffer %30
%9 = OpVariable %31 StorageBuffer
%32 = OpTypePointer StorageBuffer %10
%33 = OpConstant %24 0
%34 = OpTypeBool
%35 = OpConstant %24 4
%36 = OpConstant %24 5
%37 = OpConstant %24 24
%38 = OpConstant %24 2
%39 = OpTypePointer StorageBuffer %24
%40 = OpConstant %24 1
%41 = OpConstant %24 3
%3 = OpVariable %26 Output
%1 = OpFunction %27 None %28
%42 = OpLabel
OpLine %7 22 4
%43 = OpLoad %24 %2
OpLine %5 35 8
%44 = OpAccessChain %29 %9 %43
OpLine %6 69 12
%45 = OpAccessChain %32 %44 %33
%46 = OpArrayLength %24 %44 0
OpLine %4 70 7
%47 = OpIEqual %34 %35 %33
OpNoLine
OpSelectionMerge %48 None
OpBranchConditional %47 %49 %50
%49 = OpLabel
OpReturn
%50 = OpLabel
OpBranch %48
%48 = OpLabel
OpLine %4 70 7
%51 = OpUMod %24 %36 %35
%52 = OpIEqual %34 %51 %33
OpNoLine
OpSelectionMerge %53 None
OpBranchConditional %52 %54 %55
%54 = OpLabel
OpBranch %53
%55 = OpLabel
OpReturn
%53 = OpLabel
OpLine %4 74 14
%56 = OpIMul %24 %46 %35
OpLine %4 75 7
%57 = OpIAdd %24 %36 %37
%58 = OpUGreaterThan %34 %57 %56
OpNoLine
OpSelectionMerge %59 None
OpBranchConditional %58 %60 %61
%60 = OpLabel
OpReturn
%61 = OpLabel
OpBranch %59
%59 = OpLabel
OpLine %4 97 8
%12 = OpShiftRightLogical %24 %36 %38
%13 = OpInBoundsAccessChain %39 %45 %12
%62 = OpLoad %24 %13
%14 = OpIAdd %24 %12 %40
%15 = OpInBoundsAccessChain %39 %45 %14
%63 = OpLoad %24 %15
%16 = OpIAdd %24 %12 %38
%17 = OpInBoundsAccessChain %39 %45 %16
%64 = OpLoad %24 %17
%18 = OpIAdd %24 %12 %41
%19 = OpInBoundsAccessChain %39 %45 %18
%65 = OpLoad %24 %19
%20 = OpIAdd %24 %12 %35
%21 = OpInBoundsAccessChain %39 %45 %20
%66 = OpLoad %24 %21
%22 = OpIAdd %24 %12 %36
%23 = OpInBoundsAccessChain %39 %45 %22
%67 = OpLoad %24 %23
%68 = OpCompositeConstruct %8 %62 %63 %64 %65 %66 %67
OpLine %7 27 8
OpStore %3 %68
OpNoLine
OpReturn
OpFunctionEnd
29 changes: 29 additions & 0 deletions tests/compiletests/ui/dis/non_uniform_load_mut.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// build-pass
// compile-flags: -Ctarget-feature=+ShaderNonUniform,+ext:SPV_EXT_descriptor_indexing
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpSource .*\n" -> ""

use spirv_std::{ByteAddressableBuffer, RuntimeArray, TypedBuffer, spirv};

pub struct BigStruct {
a: u32,
b: u32,
c: u32,
d: u32,
e: u32,
f: u32,
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut RuntimeArray<
TypedBuffer<[u32]>,
>,
#[spirv(flat)] index_in: u32,
out: &mut BigStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::from_mut_slice(buf.index_mut(index_in as usize));
*out = buf.load(5);
}
}
Loading
Loading