2323 */
2424#include " ir_builder.h"
2525
26+ #include < spirv.hpp>
27+
2628namespace tvm {
2729namespace codegen {
2830namespace spirv {
2931
3032// implementations
3133
34+ IRBuilder::IRBuilder (const SPIRVSupport& support) : spirv_support_(support) {}
35+
3236void IRBuilder::InitHeader () {
3337 ICHECK_EQ (header_.size (), 0U );
3438 header_.push_back (spv::MagicNumber);
3539
36- // Use the spirv version as indicated in the SDK.
37- #if SPV_VERSION >= 0x10300
38- header_.push_back (0x10300 );
39- #else
40+ // Target SPIR-V version 1.0. Additional functionality will be
41+ // enabled through extensions.
4042 header_.push_back (0x10000 );
41- #endif
4243
4344 // generator: set to 0, unknown
4445 header_.push_back (0U );
4546 // Bound: set during Finalize
4647 header_.push_back (0U );
4748 // Schema: reserved
4849 header_.push_back (0U );
49- // shader
50- ib_.Begin (spv::OpCapability).Add (spv::CapabilityShader).Commit (&header_);
51- // Declare int64 capability by default
52- ib_.Begin (spv::OpCapability).Add (spv::CapabilityInt64).Commit (&header_);
50+
51+ // Declare CapabilityShader by default. All other capabilities are
52+ // determined by the types declared.
53+ capabilities_used_.insert (spv::CapabilityShader);
54+
5355 // memory model
5456 ib_.Begin (spv::OpMemoryModel)
5557 .AddSeq (spv::AddressingModelLogical, spv::MemoryModelGLSL450)
@@ -71,6 +73,30 @@ void IRBuilder::InitPreDefs() {
7173 ib_.Begin (spv::OpTypeFunction).AddSeq (t_void_func_, t_void_).Commit (&global_);
7274}
7375
76+ std::vector<uint32_t > IRBuilder::Finalize () {
77+ std::vector<uint32_t > data;
78+ // Index for upper bound of id numbers.
79+ const int kBoundLoc = 3 ;
80+ header_[kBoundLoc ] = id_counter_;
81+ data.insert (data.end (), header_.begin (), header_.end ());
82+ for (const auto & capability : capabilities_used_) {
83+ ib_.Begin (spv::OpCapability).Add (capability).Commit (&data);
84+ }
85+ for (const auto & ext_name : extensions_used_) {
86+ ib_.Begin (spv::OpExtension).Add (ext_name).Commit (&data);
87+ }
88+ data.insert (data.end (), extended_instruction_section_.begin (),
89+ extended_instruction_section_.end ());
90+ data.insert (data.end (), entry_.begin (), entry_.end ());
91+ data.insert (data.end (), exec_mode_.begin (), exec_mode_.end ());
92+ data.insert (data.end (), debug_.begin (), debug_.end ());
93+ data.insert (data.end (), decorate_.begin (), decorate_.end ());
94+ data.insert (data.end (), global_.begin (), global_.end ());
95+ data.insert (data.end (), func_header_.begin (), func_header_.end ());
96+ data.insert (data.end (), function_.begin (), function_.end ());
97+ return data;
98+ }
99+
74100SType IRBuilder::GetSType (const DataType& dtype) {
75101 if (dtype == DataType::Int (32 )) {
76102 return t_int32_;
@@ -145,16 +171,19 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems)
145171 .AddSeq (struct_type, 0 , spv::DecorationOffset, 0 )
146172 .Commit (&decorate_);
147173
148- #if SPV_VERSION < 0x10300
149- // NOTE: BufferBlock was deprecated in SPIRV 1.3
150- // use StorageClassStorageBuffer instead.
151- // runtime array are always decorated as BufferBlock(shader storage buffer)
152- if (num_elems == 0 ) {
153- this ->Decorate (spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
174+ // Runtime array are always decorated as Block or BufferBlock
175+ // (shader storage buffer)
176+ if (spirv_support_.supports_StorageBufferStorageClass ) {
177+ // If SPIRV 1.3+, or with extension
178+ // SPV_KHR_storage_buffer_storage_class, BufferBlock is
179+ // deprecated.
180+ extensions_used_.insert (" SPV_KHR_storage_buffer_storage_class" );
181+ this ->Decorate (spv::OpDecorate, struct_type, spv::DecorationBlock);
182+ } else {
183+ if (num_elems == 0 ) {
184+ this ->Decorate (spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
185+ }
154186 }
155- #else
156- this ->Decorate (spv::OpDecorate, struct_type, spv::DecorationBlock);
157- #endif
158187 struct_array_type_tbl_[key] = struct_type;
159188 return struct_type;
160189}
@@ -186,13 +215,14 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) {
186215
187216Value IRBuilder::BufferArgument (const SType& value_type, uint32_t descriptor_set,
188217 uint32_t binding) {
189- // NOTE: BufferBlock was deprecated in SPIRV 1.3
190- // use StorageClassStorageBuffer instead.
191- #if SPV_VERSION >= 0x10300
192- spv::StorageClass storage_class = spv::StorageClassStorageBuffer;
193- #else
194- spv::StorageClass storage_class = spv::StorageClassUniform;
195- #endif
218+ // If SPIRV 1.3+, or with extension SPV_KHR_storage_buffer_storage_class, BufferBlock is
219+ // deprecated.
220+ spv::StorageClass storage_class;
221+ if (spirv_support_.supports_StorageBufferStorageClass ) {
222+ storage_class = spv::StorageClassStorageBuffer;
223+ } else {
224+ storage_class = spv::StorageClassUniform;
225+ }
196226
197227 SType sarr_type = GetStructArrayType (value_type, 0 );
198228 SType ptr_type = GetPointerType (sarr_type, storage_class);
@@ -383,6 +413,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
383413}
384414
385415SType IRBuilder::DeclareType (const DataType& dtype) {
416+ AddCapabilityFor (dtype);
417+
386418 if (dtype.lanes () == 1 ) {
387419 SType t;
388420 t.id = id_counter_++;
@@ -410,6 +442,60 @@ SType IRBuilder::DeclareType(const DataType& dtype) {
410442 }
411443}
412444
445+ void IRBuilder::AddCapabilityFor (const DataType& dtype) {
446+ // Declare appropriate capabilities for int/float types
447+ if (dtype.is_int () || dtype.is_uint ()) {
448+ if (dtype.bits () == 8 ) {
449+ ICHECK (spirv_support_.supports_Int8 ) << " Vulkan target does not support Int8 capability" ;
450+ capabilities_used_.insert (spv::CapabilityInt8);
451+ } else if (dtype.bits () == 16 ) {
452+ ICHECK (spirv_support_.supports_Int16 ) << " Vulkan target does not support Int16 capability" ;
453+ capabilities_used_.insert (spv::CapabilityInt16);
454+ } else if (dtype.bits () == 64 ) {
455+ ICHECK (spirv_support_.supports_Int64 ) << " Vulkan target does not support Int64 capability" ;
456+ capabilities_used_.insert (spv::CapabilityInt64);
457+ }
458+
459+ } else if (dtype.is_float ()) {
460+ if (dtype.bits () == 16 ) {
461+ ICHECK (spirv_support_.supports_Float16 )
462+ << " Vulkan target does not support Float16 capability" ;
463+ capabilities_used_.insert (spv::CapabilityFloat16);
464+ } else if (dtype.bits () == 64 ) {
465+ ICHECK (spirv_support_.supports_Float64 )
466+ << " Vulkan target does not support Float64 capability" ;
467+ capabilities_used_.insert (spv::CapabilityFloat64);
468+ }
469+ }
470+
471+ // Declare ability to read type to/from storage buffers. Doing so
472+ // here is a little bit overzealous, should be relaxed in the
473+ // future. Requiring StorageBuffer8BitAccess in order to declare an
474+ // Int8 prevents use of an 8-bit loop iterator on a device that
475+ // supports Int8 but doesn't support 8-bit buffer access.
476+ if (dtype.bits () == 8 ) {
477+ ICHECK (spirv_support_.supports_StorageBuffer8BitAccess )
478+ << " Vulkan target does not support StorageBuffer8BitAccess" ;
479+ capabilities_used_.insert (spv::CapabilityStorageBuffer8BitAccess);
480+ extensions_used_.insert (" SPV_KHR_8bit_storage" );
481+
482+ ICHECK (spirv_support_.supports_StorageBufferStorageClass )
483+ << " Illegal Vulkan target description. "
484+ << " Vulkan spec requires extension VK_KHR_storage_buffer_storage_class "
485+ << " if VK_KHR_8bit_storage is supported" ;
486+ } else if (dtype.bits () == 16 ) {
487+ ICHECK (spirv_support_.supports_StorageBuffer8BitAccess )
488+ << " Vulkan target does not support StorageBuffer16BitAccess" ;
489+
490+ extensions_used_.insert (" SPV_KHR_16bit_storage" );
491+ if (spirv_support_.supports_StorageBufferStorageClass ) {
492+ capabilities_used_.insert (spv::CapabilityStorageBuffer16BitAccess);
493+ } else {
494+ capabilities_used_.insert (spv::CapabilityStorageUniformBufferBlock16);
495+ }
496+ }
497+ }
498+
413499PhiValue IRBuilder::MakePhi (const SType& out_type, uint32_t num_incoming) {
414500 Value val = NewValue (out_type, kNormal );
415501 ib_.Begin (spv::OpPhi).AddSeq (out_type, val);
0 commit comments