Skip to content

Commit 85c9931

Browse files
Lunderbergylc
authored andcommitted
[Vulkan] Implement sync for SyncThread("warp") (apache#8320)
- Add sync if a SyncThread("warp") node is present. The sync is done at spv::ScopeSubgroup if supported (Vulkan 1.1+), and at spv::ScopeWorkgroup otherwise. Co-authored-by: Eric Lunderberg <[email protected]>
1 parent a669ac2 commit 85c9931

File tree

5 files changed

+56
-14
lines changed

5 files changed

+56
-14
lines changed

src/target/spirv/build_vulkan.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,24 @@ namespace codegen {
3636

3737
class SPIRVTools {
3838
public:
39-
SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); }
39+
explicit SPIRVTools(Target target) {
40+
uint32_t vulkan_version =
41+
target->GetAttr<Integer>("vulkan_api_version").value_or(VK_API_VERSION_1_0);
42+
uint32_t spirv_version = target->GetAttr<Integer>("max_spirv_version").value_or(0x10000);
43+
44+
spv_target_env validation_version;
45+
if (vulkan_version >= VK_API_VERSION_1_2) {
46+
validation_version = SPV_ENV_VULKAN_1_2;
47+
} else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) {
48+
validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4;
49+
} else if (vulkan_version >= VK_API_VERSION_1_1) {
50+
validation_version = SPV_ENV_VULKAN_1_1;
51+
} else {
52+
validation_version = SPV_ENV_VULKAN_1_0;
53+
}
54+
55+
ctx_ = spvContextCreate(validation_version);
56+
}
4057
~SPIRVTools() { spvContextDestroy(ctx_); }
4158
std::string BinaryToText(const std::vector<uint32_t>& bin) {
4259
spv_text text = nullptr;
@@ -80,7 +97,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction)
8097
using tvm::runtime::VulkanShader;
8198

8299
std::ostringstream code_data;
83-
static SPIRVTools spirv_tools;
100+
SPIRVTools spirv_tools(target);
84101
std::unordered_map<std::string, VulkanShader> smap;
85102

86103
const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");

src/target/spirv/codegen_spirv.cc

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,27 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext
140140
spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
141141
const std::string& sync = op->args[0].as<StringImmNode>()->value;
142142
spirv::Value value;
143-
if (sync == "warp") {
144-
return value;
145-
} else if (sync == "shared") {
146-
auto type_int = builder_->GetSType(DataType::Int(32));
147-
builder_->MakeInst(
148-
spv::OpControlBarrier,
149-
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
150-
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
151-
builder_->IntImm(type_int,
152-
static_cast<int64_t>(spv::MemorySemanticsSequentiallyConsistentMask |
153-
spv::MemorySemanticsWorkgroupMemoryMask)));
143+
144+
uint32_t vulkan_api_version = spirv_support_.vulkan_api_version;
145+
146+
int64_t sync_scope;
147+
int64_t memory_semantics;
148+
if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) {
149+
sync_scope = spv::ScopeSubgroup;
150+
memory_semantics =
151+
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask;
152+
} else if ((sync == "shared") || (sync == "warp")) {
153+
sync_scope = spv::ScopeWorkgroup;
154+
memory_semantics =
155+
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask;
154156
} else {
155157
LOG(FATAL) << "Do not support sync " << sync;
156158
}
159+
160+
auto type_int = builder_->GetSType(DataType::Int(32));
161+
builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope),
162+
builder_->IntImm(type_int, sync_scope),
163+
builder_->IntImm(type_int, memory_semantics));
157164
return value;
158165
}
159166

src/target/spirv/spirv_support.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
3535
ICHECK_EQ(target->kind->device_type, kDLVulkan)
3636
<< "SPIRVSupport can only be checked for vulkan device type";
3737

38+
if (target->GetAttr<Integer>("vulkan_api_version")) {
39+
vulkan_api_version = target->GetAttr<Integer>("vulkan_api_version").value();
40+
}
41+
3842
if (target->GetAttr<Integer>("supported_subgroup_operations")) {
3943
supported_subgroup_operations =
4044
target->GetAttr<Integer>("supported_subgroup_operations").value();

src/target/spirv/spirv_support.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#define TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_
2828

2929
#include <tvm/target/target.h>
30+
#include <vulkan/vulkan_core.h>
3031

3132
namespace tvm {
3233
namespace codegen {
@@ -37,6 +38,19 @@ struct SPIRVSupport {
3738
*/
3839
explicit SPIRVSupport(Target target);
3940

41+
/*! \brief The Vulkan API version supported by the device.
42+
*
43+
* Vulkan struct: VkPhysicalDeviceProperties
44+
* Device property: apiVersion
45+
*
46+
* If VK_KHR_driver_properties is present, will also check the
47+
* driver conformance version. If the version advertised does not
48+
* pass the Vulkan conformance test, vulkan_api_version will be the
49+
* latest Vulkan version that does pass the conformance test
50+
* instead.
51+
*/
52+
uint32_t vulkan_api_version{VK_MAKE_VERSION(1, 0, 0)};
53+
4054
/*!
4155
* \brief The supported subgroup operations
4256
*

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
425425
while (reduce_align > 1) {
426426
reduce_align = reduce_align >> 1;
427427
in_warp_seq.emplace_back(freduce(reduce_align));
428-
seq.emplace_back(SyncThread("warp"));
428+
in_warp_seq.emplace_back(SyncThread("warp"));
429429
}
430430
if (in_warp_seq.size() != 0) {
431431
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);

0 commit comments

Comments
 (0)