diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 1d1df91dc4a4..032bcddbb8a0 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -103,6 +103,8 @@ std::string CodeGenWebGPU::Finish() {
   if (enable_fp16_) {
     header_stream << "enable f16;\n\n";
   }
+  // TODO(Charlie): Add enable_subgroups_ to control this
+  header_stream << "enable subgroups;\n\n";
   return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
 }
diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc
index f3e561f71477..81994e36738e 100644
--- a/src/target/source/intrin_rule_webgpu.cc
+++ b/src/target/source/intrin_rule_webgpu.cc
@@ -32,6 +32,29 @@ namespace intrin {
 
 using tir::FLowerIntrinsic;
 
+// Warp-level primitives. Follows the implementation in intrin_rule_metal.cc.
+struct WebGPUWarpIntrinsic {
+  const Op operator()(DataType t, const Op& orig_op) const {
+    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+      return Op::Get("tir.webgpu.subgroup_shuffle");
+    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+      return Op::Get("tir.webgpu.subgroup_shuffle_up");
+    } else {
+      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+      return Op::Get("tir.webgpu.subgroup_shuffle_down");
+    }
+  }
+};
+
+template <typename T>
+static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
+  const CallNode* call = e.as<CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  Array<PrimExpr> webgpu_args{{call->args[1], call->args[2]}};
+  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), webgpu_args);
+}
+
 // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions
 
 struct ReturnAbs {
@@ -113,6 +136,39 @@ TVM_REGISTER_OP("tir.trunc")
 
 // extra dispatch
 TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);
 
+// Warp-level primitives. Follows the implementation in intrin_rule_metal.cc.
+TVM_REGISTER_OP("tir.tvm_warp_shuffle")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle") + .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane", "Expr", "The source thread id.") + .set_attr("TGlobalSymbol", "subgroupShuffle") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_up") + .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be added.") + .set_attr("TGlobalSymbol", "subgroupShuffleUp") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_down") + .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") + .set_attr("TGlobalSymbol", "subgroupShuffleDown") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 62ba2787a367..c831fcb975ae 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -423,6 +423,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .add_attr_option("max_num_threads", runtime::Int(256)) + // TODO(Charlie): Not all WebGPU supports this, need a control logic + .add_attr_option("thread_warp_size", runtime::Int(32)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index dde33fa2678d..dad4944e8e7d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -505,7 +505,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // // The former may cause dead lock as there is a divergent // branch with a warp sync call inside. - PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset); + bool cast_offset_to_uint = target_->kind->name == "webgpu"; + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset, + cast_offset_to_uint); Buffer local_buf = local_bufs[i]; Stmt s = BufferStore(local_buf, other, zero_indices); seq->push_back(s); @@ -694,7 +696,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Emit warp shuffle calls. 
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 27d68d887c32..deee50a6b645 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -110,11 +110,17 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
     );
   }
 
-  const requiredFeatures: GPUFeatureName[] = [];
+  // TODO(Charlie): cannot type annotate because @webgpu/types
+  // does not have "subgroups" as GPUFeatureName yet
+  // const requiredFeatures: GPUFeatureName[] = [];
+  const requiredFeatures = [];
   // Always require f16 if available
   if (adapter.features.has("shader-f16")) {
     requiredFeatures.push("shader-f16");
   }
+  if (adapter.features.has("subgroups")) {
+    requiredFeatures.push("subgroups");
+  }
 
   // requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
   // issue when building. However, it is still needed for older browsers, hence `as any`.
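For context, the requiredFeatures array built above is what detectGPUDevice eventually hands to requestDevice; that part of the function is outside this hunk. Here is a small self-contained TypeScript sketch of the flow, with a hypothetical helper name and the features kept as plain strings for the same @webgpu/types reason noted in the TODO.

// Sketch: request a device with the optional features detected above and log
// which ones were actually granted. Plain strings are used because "subgroups"
// is not yet a GPUFeatureName in @webgpu/types.
async function requestDeviceWithOptionalFeatures(
  adapter: GPUAdapter,
  requiredFeatures: string[]
): Promise<GPUDevice> {
  const device = await adapter.requestDevice({
    requiredFeatures: requiredFeatures as any,
  });
  console.log("shader-f16 granted:", device.features.has("shader-f16"));
  console.log("subgroups granted:", device.features.has("subgroups"));
  return device;
}

Requesting a feature the adapter does not support makes requestDevice reject, which is why the diff only pushes "shader-f16" and "subgroups" after checking adapter.features.has(...).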