Skip to content

Commit 6a13b44

Browse files
committed
[Vulkan] Broke out implicit device requirements into SPIRVSupport
Codifies the current requirements that are implicit in the shaders built by CodeGenSPIRV (e.g. can read from 8-bit buffers). The next steps for this development are (1) to query driver/device support information from the device, (2) to pass these query parameters through the Target, and (3) to ensure correct shader generation even when features are not supported. Step (3) will require exposing the target properties to relay optimization passes.
1 parent 21a7b49 commit 6a13b44

File tree

7 files changed

+448
-46
lines changed

7 files changed

+448
-46
lines changed

src/target/spirv/build_vulkan.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction)
7575

7676
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
7777

78-
CodeGenSPIRV cg;
78+
CodeGenSPIRV cg(target);
7979

8080
for (auto kv : mod->functions) {
8181
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenSPIRV: Can only take PrimFunc";

src/target/spirv/codegen_spirv.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,17 @@
3737
namespace tvm {
3838
namespace codegen {
3939

40+
CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {}
41+
4042
runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
4143
this->InitFuncState();
4244
ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
4345
std::vector<Var> pod_args;
4446
uint32_t num_buffer = 0;
4547

4648
// Currently, all storage and uniform buffer arguments are passed as
47-
// a single descriptor set at index 0.
49+
// a single descriptor set at index 0. If ever non-zero, must
50+
// ensure it is less than maxBoundDescriptorSets.
4851
const uint32_t descriptor_set = 0;
4952

5053
for (Var arg : f->params) {
@@ -114,7 +117,7 @@ void CodeGenSPIRV::InitFuncState() {
114117
var_map_.clear();
115118
storage_info_.clear();
116119
analyzer_.reset(new arith::Analyzer());
117-
builder_.reset(new spirv::IRBuilder());
120+
builder_.reset(new spirv::IRBuilder(spirv_support_));
118121
builder_->InitHeader();
119122
}
120123

src/target/spirv/codegen_spirv.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
2626

2727
#include <tvm/arith/analyzer.h>
28+
#include <tvm/target/target.h>
2829
#include <tvm/tir/analysis.h>
2930
#include <tvm/tir/expr.h>
3031
#include <tvm/tir/function.h>
@@ -38,6 +39,7 @@
3839
#include "../../runtime/thread_storage_scope.h"
3940
#include "../../runtime/vulkan/vulkan_shader.h"
4041
#include "ir_builder.h"
42+
#include "spirv_support.h"
4143

4244
namespace tvm {
4345
namespace codegen {
@@ -50,6 +52,14 @@ using namespace tir;
5052
class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
5153
public StmtFunctor<void(const Stmt&)> {
5254
public:
55+
/*!
56+
* \brief Initialize the codegen based on a specific target.
57+
*
58+
* \param target The target for which code should be generated. The
59+
* device_type for this target must be kDLVulkan.
60+
*/
61+
CodeGenSPIRV(Target target);
62+
5363
/*!
5464
* \brief Compile and add function f to the current module.
5565
* \param f The function to be added.
@@ -131,6 +141,8 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
131141
spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);
132142
spirv::Value CreateStorageSync(const CallNode* op);
133143
void Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f);
144+
// SPIRV-related capabilities of the target
145+
SPIRVSupport spirv_support_;
134146
// The builder
135147
std::unique_ptr<spirv::IRBuilder> builder_;
136148
// Work group size of three

src/target/spirv/ir_builder.cc

Lines changed: 111 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,35 @@
2323
*/
2424
#include "ir_builder.h"
2525

26+
#include <spirv.hpp>
27+
2628
namespace tvm {
2729
namespace codegen {
2830
namespace spirv {
2931

3032
// implementations
3133

34+
IRBuilder::IRBuilder(const SPIRVSupport& support) : spirv_support_(support) {}
35+
3236
void IRBuilder::InitHeader() {
3337
ICHECK_EQ(header_.size(), 0U);
3438
header_.push_back(spv::MagicNumber);
3539

36-
// Use the spirv version as indicated in the SDK.
37-
#if SPV_VERSION >= 0x10300
38-
header_.push_back(0x10300);
39-
#else
40+
// Target SPIR-V version 1.0. Additional functionality will be
41+
// enabled through extensions.
4042
header_.push_back(0x10000);
41-
#endif
4243

4344
// generator: set to 0, unknown
4445
header_.push_back(0U);
4546
// Bound: set during Finalize
4647
header_.push_back(0U);
4748
// Schema: reserved
4849
header_.push_back(0U);
49-
// shader
50-
ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_);
51-
// Declare int64 capability by default
52-
ib_.Begin(spv::OpCapability).Add(spv::CapabilityInt64).Commit(&header_);
50+
51+
// Declare CapabilityShader by default. All other capabilities are
52+
// determined by the types declared.
53+
capabilities_used_.insert(spv::CapabilityShader);
54+
5355
// memory model
5456
ib_.Begin(spv::OpMemoryModel)
5557
.AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450)
@@ -71,6 +73,30 @@ void IRBuilder::InitPreDefs() {
7173
ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_);
7274
}
7375

76+
std::vector<uint32_t> IRBuilder::Finalize() {
77+
std::vector<uint32_t> data;
78+
// Index for upper bound of id numbers.
79+
const int kBoundLoc = 3;
80+
header_[kBoundLoc] = id_counter_;
81+
data.insert(data.end(), header_.begin(), header_.end());
82+
for (const auto& capability : capabilities_used_) {
83+
ib_.Begin(spv::OpCapability).Add(capability).Commit(&data);
84+
}
85+
for (const auto& ext_name : extensions_used_) {
86+
ib_.Begin(spv::OpExtension).Add(ext_name).Commit(&data);
87+
}
88+
data.insert(data.end(), extended_instruction_section_.begin(),
89+
extended_instruction_section_.end());
90+
data.insert(data.end(), entry_.begin(), entry_.end());
91+
data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
92+
data.insert(data.end(), debug_.begin(), debug_.end());
93+
data.insert(data.end(), decorate_.begin(), decorate_.end());
94+
data.insert(data.end(), global_.begin(), global_.end());
95+
data.insert(data.end(), func_header_.begin(), func_header_.end());
96+
data.insert(data.end(), function_.begin(), function_.end());
97+
return data;
98+
}
99+
74100
SType IRBuilder::GetSType(const DataType& dtype) {
75101
if (dtype == DataType::Int(32)) {
76102
return t_int32_;
@@ -145,16 +171,19 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems)
145171
.AddSeq(struct_type, 0, spv::DecorationOffset, 0)
146172
.Commit(&decorate_);
147173

148-
#if SPV_VERSION < 0x10300
149-
// NOTE: BufferBlock was deprecated in SPIRV 1.3
150-
// use StorageClassStorageBuffer instead.
151-
// runtime array are always decorated as BufferBlock(shader storage buffer)
152-
if (num_elems == 0) {
153-
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
174+
// Runtime array are always decorated as Block or BufferBlock
175+
// (shader storage buffer)
176+
if (spirv_support_.supports_StorageBufferStorageClass) {
177+
// If SPIRV 1.3+, or with extension
178+
// SPV_KHR_storage_buffer_storage_class, BufferBlock is
179+
// deprecated.
180+
extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
181+
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
182+
} else {
183+
if (num_elems == 0) {
184+
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
185+
}
154186
}
155-
#else
156-
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
157-
#endif
158187
struct_array_type_tbl_[key] = struct_type;
159188
return struct_type;
160189
}
@@ -186,13 +215,14 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) {
186215

187216
Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set,
188217
uint32_t binding) {
189-
// NOTE: BufferBlock was deprecated in SPIRV 1.3
190-
// use StorageClassStorageBuffer instead.
191-
#if SPV_VERSION >= 0x10300
192-
spv::StorageClass storage_class = spv::StorageClassStorageBuffer;
193-
#else
194-
spv::StorageClass storage_class = spv::StorageClassUniform;
195-
#endif
218+
// If SPIRV 1.3+, or with extension SPV_KHR_storage_buffer_storage_class, BufferBlock is
219+
// deprecated.
220+
spv::StorageClass storage_class;
221+
if (spirv_support_.supports_StorageBufferStorageClass) {
222+
storage_class = spv::StorageClassStorageBuffer;
223+
} else {
224+
storage_class = spv::StorageClassUniform;
225+
}
196226

197227
SType sarr_type = GetStructArrayType(value_type, 0);
198228
SType ptr_type = GetPointerType(sarr_type, storage_class);
@@ -383,6 +413,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
383413
}
384414

385415
SType IRBuilder::DeclareType(const DataType& dtype) {
416+
AddCapabilityFor(dtype);
417+
386418
if (dtype.lanes() == 1) {
387419
SType t;
388420
t.id = id_counter_++;
@@ -410,6 +442,60 @@ SType IRBuilder::DeclareType(const DataType& dtype) {
410442
}
411443
}
412444

445+
void IRBuilder::AddCapabilityFor(const DataType& dtype) {
446+
// Declare appropriate capabilities for int/float types
447+
if (dtype.is_int() || dtype.is_uint()) {
448+
if (dtype.bits() == 8) {
449+
ICHECK(spirv_support_.supports_Int8) << "Vulkan target does not support Int8 capability";
450+
capabilities_used_.insert(spv::CapabilityInt8);
451+
} else if (dtype.bits() == 16) {
452+
ICHECK(spirv_support_.supports_Int16) << "Vulkan target does not support Int16 capability";
453+
capabilities_used_.insert(spv::CapabilityInt16);
454+
} else if (dtype.bits() == 64) {
455+
ICHECK(spirv_support_.supports_Int64) << "Vulkan target does not support Int64 capability";
456+
capabilities_used_.insert(spv::CapabilityInt64);
457+
}
458+
459+
} else if (dtype.is_float()) {
460+
if (dtype.bits() == 16) {
461+
ICHECK(spirv_support_.supports_Float16)
462+
<< "Vulkan target does not support Float16 capability";
463+
capabilities_used_.insert(spv::CapabilityFloat16);
464+
} else if (dtype.bits() == 64) {
465+
ICHECK(spirv_support_.supports_Float64)
466+
<< "Vulkan target does not support Float64 capability";
467+
capabilities_used_.insert(spv::CapabilityFloat64);
468+
}
469+
}
470+
471+
// Declare ability to read type to/from storage buffers. Doing so
472+
// here is a little bit overzealous, should be relaxed in the
473+
// future. Requiring StorageBuffer8BitAccess in order to declare an
474+
// Int8 prevents use of an 8-bit loop iterator on a device that
475+
// supports Int8 but doesn't support 8-bit buffer access.
476+
if (dtype.bits() == 8) {
477+
ICHECK(spirv_support_.supports_StorageBuffer8BitAccess)
478+
<< "Vulkan target does not support StorageBuffer8BitAccess";
479+
capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess);
480+
extensions_used_.insert("SPV_KHR_8bit_storage");
481+
482+
ICHECK(spirv_support_.supports_StorageBufferStorageClass)
483+
<< "Illegal Vulkan target description. "
484+
<< "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class "
485+
<< "if VK_KHR_8bit_storage is supported";
486+
} else if (dtype.bits() == 16) {
487+
ICHECK(spirv_support_.supports_StorageBuffer8BitAccess)
488+
<< "Vulkan target does not support StorageBuffer16BitAccess";
489+
490+
extensions_used_.insert("SPV_KHR_16bit_storage");
491+
if (spirv_support_.supports_StorageBufferStorageClass) {
492+
capabilities_used_.insert(spv::CapabilityStorageBuffer16BitAccess);
493+
} else {
494+
capabilities_used_.insert(spv::CapabilityStorageUniformBufferBlock16);
495+
}
496+
}
497+
}
498+
413499
PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) {
414500
Value val = NewValue(out_type, kNormal);
415501
ib_.Begin(spv::OpPhi).AddSeq(out_type, val);

src/target/spirv/ir_builder.h

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@
3030
// clang-format off
3131
#include <algorithm>
3232
#include <map>
33+
#include <set>
3334
#include <string>
3435
#include <unordered_map>
3536
#include <utility>
3637
#include <vector>
3738
#include <spirv.hpp>
3839
// clang-format on
3940

41+
#include "spirv_support.h"
42+
4043
namespace tvm {
4144
namespace codegen {
4245
namespace spirv {
@@ -268,6 +271,14 @@ class InstrBuilder {
268271
*/
269272
class IRBuilder {
270273
public:
274+
/*!
275+
* \brief Initialize the codegen based on a specific feature set.
276+
*
277+
* \param support The features in SPIRV that are supported by the
278+
* target device.
279+
*/
280+
explicit IRBuilder(const SPIRVSupport& support);
281+
271282
/*! \brief Initialize header */
272283
void InitHeader();
273284
/*! \brief Initialize the predefined contents */
@@ -278,29 +289,21 @@ class IRBuilder {
278289
* \return The finalized binary instruction.
279290
*/
280291
Value ExtInstImport(const std::string& name) {
292+
auto it = ext_inst_tbl_.find(name);
293+
if (it != ext_inst_tbl_.end()) {
294+
return it->second;
295+
}
281296
Value val = NewValue(SType(), kExtInst);
282-
ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&header_);
297+
ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&extended_instruction_section_);
298+
ext_inst_tbl_[name] = val;
283299
return val;
284300
}
285301
/*!
286302
* \brief Get the final binary built from the builder
287303
* \return The finalized binary instruction.
288304
*/
289-
std::vector<uint32_t> Finalize() {
290-
std::vector<uint32_t> data;
291-
// set bound
292-
const int kBoundLoc = 3;
293-
header_[kBoundLoc] = id_counter_;
294-
data.insert(data.end(), header_.begin(), header_.end());
295-
data.insert(data.end(), entry_.begin(), entry_.end());
296-
data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
297-
data.insert(data.end(), debug_.begin(), debug_.end());
298-
data.insert(data.end(), decorate_.begin(), decorate_.end());
299-
data.insert(data.end(), global_.begin(), global_.end());
300-
data.insert(data.end(), func_header_.begin(), func_header_.end());
301-
data.insert(data.end(), function_.begin(), function_.end());
302-
return data;
303-
}
305+
std::vector<uint32_t> Finalize();
306+
304307
/*!
305308
* \brief Create new label
306309
* \return The created new label
@@ -599,6 +602,19 @@ class IRBuilder {
599602
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
600603
// declare type
601604
SType DeclareType(const DataType& dtype);
605+
606+
// Declare the appropriate SPIR-V capabilities and extensions to use
607+
// this data type.
608+
void AddCapabilityFor(const DataType& dtype);
609+
610+
/*! \brief SPIRV-related capabilities of the target
611+
*
612+
* This SPIRVSupport object is owned by the same CodeGenSPIRV
613+
* object that owns the IRBuilder. Therefore, safe to use a
614+
* reference as the CodeGenSPIRV will live longer.
615+
*/
616+
const SPIRVSupport& spirv_support_;
617+
602618
/*! \brief internal instruction builder */
603619
InstrBuilder ib_;
604620
/*! \brief Current label */
@@ -623,9 +639,22 @@ class IRBuilder {
623639
std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
624640
/*! \brief map from constant int to its value */
625641
std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
626-
/*! \brief Header segment, include import */
642+
/*! \brief map from name of a ExtInstImport to its value */
643+
std::map<std::string, Value> ext_inst_tbl_;
644+
645+
/*! \brief Header segment
646+
*
647+
* 5 words long, described in "First Words of Physical Layout"
648+
* section of SPIR-V documentation.
649+
*/
627650
std::vector<uint32_t> header_;
628-
/*! \brief engtry point segment */
651+
/*! \brief SPIR-V capabilities used by this module. */
652+
std::set<spv::Capability> capabilities_used_;
653+
/*! \brief SPIR-V extensions used by this module. */
654+
std::set<std::string> extensions_used_;
655+
/*! \brief entry point segment */
656+
std::vector<uint32_t> extended_instruction_section_;
657+
/*! \brief entry point segment */
629658
std::vector<uint32_t> entry_;
630659
/*! \brief Header segment */
631660
std::vector<uint32_t> exec_mode_;

0 commit comments

Comments
 (0)