-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[WebGPU] Add gating logic for subgroup shuffle primitives #18823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
097d05f
f119bbd
b1e3688
07d011c
3298e94
b1139a9
e9697fe
d95827a
397ac1b
89d6142
9a3edc9
4fb4cce
3c2ab40
53409cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -742,11 +742,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { | |||||||||||||||||||||||||||||||
| bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent, | ||||||||||||||||||||||||||||||||
| int contiguous_reduce_extent) { | ||||||||||||||||||||||||||||||||
| if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") && | ||||||||||||||||||||||||||||||||
| (target_->kind->name != "metal")) { | ||||||||||||||||||||||||||||||||
| (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) { | ||||||||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| need_warp_shuffle_mask_ = target_->kind->name != "metal"; | ||||||||||||||||||||||||||||||||
| need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu"; | ||||||||||||||||||||||||||||||||
|
Comment on lines
744
to
+749
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve maintainability, consider using
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // rocm only supports 32 bit operands for shuffling at the moment | ||||||||||||||||||||||||||||||||
| if ((target_->kind->name == "rocm") && | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() { | |||||||||||||||||||||||
| if (enable_fp16_) { | ||||||||||||||||||||||||
| header_stream << "enable f16;\n\n"; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| if (enable_subgroups_) { | ||||||||||||||||||||||||
| header_stream << "enable subgroups;\n\n"; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str(); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { | |||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} | ||||||||||||||||||||||||
| CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) { | ||||||||||||||||||||||||
| enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false)); | |
| Bool supports_subgroups = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false)); | |
| Optional<Integer> thread_warp_size = target_->GetAttr<Integer>("thread_warp_size"); | |
| bool warp_uses_subgroups = | |
| thread_warp_size.defined() && thread_warp_size.value()->value > 1; | |
| if (warp_uses_subgroups && !supports_subgroups) { | |
| LOG(FATAL) << "WebGPU target has thread_warp_size=" << thread_warp_size.value()->value | |
| << " but does not support subgroups. Either enable the 'supports_subgroups' " | |
| << "target attribute or set thread_warp_size <= 1."; | |
| } | |
| enable_subgroups_ = supports_subgroups || warp_uses_subgroups; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add the following check here:
Bool supports_subgroups = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
int64_t thread_warp_size = target_->GetAttr<Integer>("thread_warp_size", 1).value()->value;
if (thread_warp_size > 1 && !supports_subgroups) {
LOG(FATAL) << "WebGPU target has thread_warp_size=" << thread_warp_size
<< " but supports_subgroups is false.";
}
enable_subgroups_ = supports_subgroups;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a check here of the following form:
This is to avoid scenarios where a target such as
{"kind":"webgpu","thread_warp_size":32,"supports_subgroups":false}would still emit subgroup ops, but the WGSL would not containenable subgroups;.