diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc
index f1e5a3cfafe0..f7253a7689ae 100644
--- a/src/s_tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -742,11 +742,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
                        int contiguous_reduce_extent) {
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
-        (target_->kind->name != "metal")) {
+        (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
       return false;
     }
 
-    need_warp_shuffle_mask_ = target_->kind->name != "metal";
+    need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
 
     // rocm only supports 32 bit operands for shuffling at the moment
     if ((target_->kind->name == "rocm") &&
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 25a691659e6e..73820d2bf380 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() {
   if (enable_fp16_) {
     header_stream << "enable f16;\n\n";
   }
+  if (enable_subgroups_) {
+    header_stream << "enable subgroups;\n\n";
+  }
   return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
 }
 
@@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
   }
 }
 
-CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
+CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
+  enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
+}
 
 runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) {
   // clear previous generated state.
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index f53d090e586a..d0a541677a6f 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -92,6 +92,8 @@ class CodeGenWebGPU final : public CodeGenC {
 
   // whether enable fp16
   bool enable_fp16_{false};
+  // whether enable subgroups
+  bool enable_subgroups_{false};
   /*! \brief the header stream for function label and enable directive if any, goes before any other
    * declaration */
   std::ostringstream header_stream;
diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc
index 968df9a579f4..86658d8e28cd 100644
--- a/src/target/source/intrin_rule_webgpu.cc
+++ b/src/target/source/intrin_rule_webgpu.cc
@@ -32,6 +32,30 @@ namespace intrin {
 
 using tirx::FLowerIntrinsic;
 
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+struct WebGPUWarpIntrinsic {
+  const Op operator()(DataType t, const Op& orig_op) const {
+    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+      return Op::Get("tirx.webgpu.subgroup_shuffle");
+    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+      return Op::Get("tirx.webgpu.subgroup_shuffle_up");
+    } else {
+      TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+      return Op::Get("tirx.webgpu.subgroup_shuffle_down");
+    }
+  }
+};
+
+template <typename T>
+static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
+  const CallNode* call = e.as<CallNode>();
+  TVM_FFI_ICHECK(call != nullptr);
+  TVM_FFI_ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  PrimExpr lane_or_delta = Cast(DataType::UInt(32, call->args[2].dtype().lanes()), call->args[2]);
+  ffi::Array<PrimExpr> webgpu_args{{call->args[1], lane_or_delta}};
+  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), webgpu_args);
+}
+
 // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions
 
 struct ReturnAbs {
@@ -113,6 +137,41 @@ TVM_REGISTER_OP("tirx.trunc")
 
 // extra dispatch
 TVM_REGISTER_OP("tirx.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);
 
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("lane", "Expr", "The source thread id.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be added.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 8c328dd9cff6..f61aad2f09b9 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -427,8 +427,41 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)  // Tags
     .set_default_keys({"vulkan", "gpu"});
 
+/*!
+ * \brief Update WebGPU target attributes for subgroup-enabled lowering.
+ * Runtime routing on the WebLLM side guarantees subgroup size == 32.
+ * Runtime routing on the WebLLM side guarantees
+ * maxComputeInvocationsPerWorkgroup >= 1024.
+ * This is intentionally constrained for the subgroup-enabled WASM variant.
+ * When supports_subgroups is true, canonicalize thread_warp_size to 32 so
+ * TIR lowering can emit subgroup shuffle reductions.
+ */
+ffi::Map<String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<String, ffi::Any> target) {
+  bool subgroups = false;
+  if (target.count("supports_subgroups")) {
+    subgroups = Downcast<Bool>(target.at("supports_subgroups"));
+  }
+
+  if (target.count("thread_warp_size")) {
+    int64_t thread_warp_size = Downcast<Integer>(target.at("thread_warp_size"))->value;
+    TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1)
+        << "WebGPU target with thread_warp_size=" << thread_warp_size
+        << " requires supports_subgroups=true";
+  }
+
+  if (subgroups) {
+    target.Set("thread_warp_size", int64_t(32));
+  }
+  return target;
+}
+
 TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
     .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
+    .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
+    // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
+    // subgroup ops are emitted.
+    .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
+    .set_target_canonicalizer(UpdateWebGPUAttrs)
     .set_default_keys({"webgpu", "gpu"});
 
 TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index aff3376052bf..558d67cadee6 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -406,5 +406,106 @@ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32
     assert "tvm_storage_sync" in After_script
 
 
+def test_webgpu_warp_reduce():
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+            T.func_attr(
+                {
+                    "target": T.target(
+                        {
+                            "kind": "webgpu",
+                            "supports_subgroups": True,
+                            "host": "llvm",
+                        }
+                    ),
+                }
+            )
+            A_flat = T.decl_buffer(4096, data=A.data)
+
+            for i in range(128):
+                threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                reduce_data = T.alloc_buffer((1,), "float32", scope="local")
+                reduce = T.decl_buffer(1, data=reduce_data.data, scope="local")
+
+                with T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret("handle", T.uint64(0)),
+                ):
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        A_flat[0],
+                        T.bool(True),
+                        reduce[0],
+                        threadIdx_x,
+                    )
+                if threadIdx_x == 0:
+                    B[i] = reduce[0]
+
+    After = transform(Before)
+    assert After is not None
+    After_script = After.script()
+    assert "tvm_warp_shuffle_down" in After_script
+    assert "tvm_warp_shuffle(" in After_script
+    assert "tvm_storage_sync" not in After_script
+    assert "T.uint32(" not in After_script
+
+
+def test_webgpu_multi_warp_reduce():
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+            T.func_attr(
+                {
+                    "target": T.target(
+                        {
+                            "kind": "webgpu",
+                            "max_num_threads": 1024,
+                            "supports_subgroups": True,
+                            "host": "llvm",
+                        }
+                    ),
+                }
+            )
+            blockIdx_x = T.launch_thread("blockIdx.x", 1)
+            cross_thread_B = T.alloc_buffer((1,), "float32", scope="local")
+            threadIdx_z = T.launch_thread("threadIdx.z", 1)
+            threadIdx_y = T.launch_thread("threadIdx.y", 2)
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
+            cross_thread_B_1 = T.decl_buffer((1,), data=cross_thread_B.data, scope="local")
+            with T.attr(
+                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                A_1 = T.decl_buffer((256,), data=A.data)
+                T.tvm_thread_allreduce(
+                    T.uint32(1),
+                    A_1[threadIdx_y * 128 + threadIdx_x],
+                    T.bool(True),
+                    cross_thread_B_1[0],
+                    threadIdx_x,
+                )
+            if threadIdx_x == 0:
+                B_1 = T.decl_buffer((2,), data=B.data)
+                B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    After = transform(Before)
+    assert After is not None
+    After_script = After.script()
+    assert "tvm_warp_shuffle_down" in After_script
+    assert "tvm_storage_sync" in After_script
+    assert "\"tirx.volatile\": T.bool(True)" in After_script
+    assert "T.uint32(" not in After_script
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 05c79abea7a2..43d55a27fcdc 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -426,5 +426,25 @@ def test_cli_string_rejected():
         Target("llvm -mcpu=cortex-a53")
 
 
+def test_webgpu_target_subgroup_attrs():
+    """Test WebGPU target defaults and supports_subgroups canonicalization."""
+    # Default: thread_warp_size=1, supports_subgroups=False
+    tgt_default = Target({"kind": "webgpu"})
+    assert tgt_default.attrs["thread_warp_size"] == 1
+    assert tgt_default.attrs["supports_subgroups"] == 0
+
+    # With supports_subgroups=True: thread_warp_size is set to 32
+    tgt_subgroups = Target({"kind": "webgpu", "supports_subgroups": True})
+    assert tgt_subgroups.attrs["thread_warp_size"] == 32
+    assert tgt_subgroups.attrs["supports_subgroups"] == 1
+
+    for config in [
+        {"kind": "webgpu", "thread_warp_size": 32},
+        {"kind": "webgpu", "thread_warp_size": 32, "supports_subgroups": False},
+    ]:
+        with pytest.raises(tvm.TVMError, match="requires supports_subgroups=true"):
+            Target(config)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc
index 7471ad592e20..14fcd425aa75 100644
--- a/web/emcc/webgpu_runtime.cc
+++ b/web/emcc/webgpu_runtime.cc
@@ -40,6 +40,8 @@
 #include "../../src/runtime/metadata.h"
 #include "../../src/runtime/workspace_pool.h"
 #include "../../src/support/bytes_io.h"
+#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
+#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
 
 namespace tvm {
 namespace runtime {
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 55d188516d40..199fa14235ee 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -127,6 +127,9 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
   if (adapter.features.has("shader-f16")) {
     requiredFeatures.push("shader-f16");
   }
+  if (adapter.features.has("subgroups")) {
+    requiredFeatures.push("subgroups");
+  }
   // requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
   // issue when building. However, it is still needed for older browsers, hence `as any`.
   const adapterInfo = adapter.info || await (adapter as any).requestAdapterInfo();