@@ -67,20 +67,20 @@ std::vector<int64_t> calc_strides(
6767 const api::GPUMemoryLayout memory_layout,
6868 const api::StorageType storage_type) {
6969 switch (storage_type) {
70- case api::StorageType::BUFFER :
70+ case api::kBuffer :
7171 switch (memory_layout) {
72- case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED :
72+ case api::kWidthPacked :
7373 return calc_contiguous_strides (sizes);
7474 break ;
75- case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED :
75+ case api::kChannelsPacked :
7676 return calc_channels_last_strides (sizes);
7777 break ;
7878 default :
7979 VK_THROW (" Invalid memory format used to create vTensor!" );
8080 }
8181 break ;
82- case api::StorageType::TEXTURE_3D :
83- case api::StorageType::TEXTURE_2D :
82+ case api::kTexture3D :
83+ case api::kTexture2D :
8484 return std::vector<int64_t >(sizes.size ());
8585 default :
8686 VK_THROW (" Invalid storage type used to create vTensor!" );
@@ -99,10 +99,8 @@ std::vector<int64_t> calc_gpu_sizes(
9999 const std::vector<int64_t >& sizes,
100100 const api::GPUMemoryLayout memory_layout,
101101 const api::StorageType storage_type) {
102- VK_CHECK_COND (storage_type != api::StorageType::UNKNOWN);
103-
104102 std::vector<int64_t > gpu_sizes;
105- if (storage_type == api::StorageType::BUFFER ) {
103+ if (storage_type == api::kBuffer ) {
106104 gpu_sizes.resize (sizes.size ());
107105 for (size_t i = 0 ; i < sizes.size (); i++) {
108106 gpu_sizes.at (i) = sizes.at (i);
@@ -127,21 +125,21 @@ std::vector<int64_t> calc_gpu_sizes(
127125
128126 size_t ndim = gpu_sizes.size ();
129127 switch (memory_layout) {
130- case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED :
128+ case api::kWidthPacked :
131129 if (ndim >= 1 ) {
132130 gpu_sizes.at (ndim - 1 ) =
133131 api::utils::align_up (api::utils::val_at (-1 , sizes), INT64_C (4 ));
134132 }
135133 break ;
136134
137- case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED :
135+ case api::kHeightPacked :
138136 if (ndim >= 2 ) {
139137 gpu_sizes.at (ndim - 2 ) =
140138 api::utils::align_up (api::utils::val_at (-2 , sizes), INT64_C (4 ));
141139 }
142140 break ;
143141
144- case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED :
142+ case api::kChannelsPacked :
145143 if (ndim >= 3 ) {
146144 gpu_sizes.at (ndim - 3 ) =
147145 api::utils::align_up (api::utils::val_at (-3 , sizes), INT64_C (4 ));
@@ -162,7 +160,7 @@ api::utils::uvec3 create_image_extents(
162160 const api::GPUMemoryLayout memory_layout) {
163161 size_t ndim = gpu_sizes.size ();
164162
165- if (storage_type == api::StorageType::BUFFER ) {
163+ if (storage_type == api::kBuffer ) {
166164 // image extents do not apply to buffer storage
167165 return {0u , 0u , 0u };
168166 } else {
@@ -177,15 +175,15 @@ api::utils::uvec3 create_image_extents(
177175 uint32_t batch = safe_downcast<uint32_t >(val_at (-4 , gpu_sizes));
178176
179177 switch (memory_layout) {
180- case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED :
178+ case api::kWidthPacked :
181179 VK_CHECK_COND (width % 4 == 0 , " Channels must be divisible by 4!" );
182180 width /= 4 ;
183181 break ;
184- case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED :
182+ case api::kHeightPacked :
185183 VK_CHECK_COND (height % 4 == 0 , " Channels must be divisible by 4!" );
186184 height /= 4 ;
187185 break ;
188- case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED :
186+ case api::kChannelsPacked :
189187 VK_CHECK_COND (channels % 4 == 0 , " Channels must be divisible by 4!" );
190188 channels /= 4 ;
191189 break ;
@@ -326,41 +324,35 @@ std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo() {
326324
327325VmaAllocationCreateInfo vTensor::get_allocation_create_info () const {
328326 switch (storage_type ()) {
329- case api::StorageType::BUFFER :
327+ case api::kBuffer :
330328 return view_->buffer_ .allocation_create_info ();
331- case api::StorageType::TEXTURE_2D :
332- case api::StorageType::TEXTURE_3D :
329+ case api::kTexture2D :
330+ case api::kTexture3D :
333331 return view_->image_ .allocation_create_info ();
334- case api::StorageType::UNKNOWN:
335- break ;
336332 }
337333 return {};
338334}
339335
340336VkMemoryRequirements vTensor::get_memory_requirements () const {
341337 switch (storage_type ()) {
342- case api::StorageType::BUFFER :
338+ case api::kBuffer :
343339 return view_->buffer_ .get_memory_requirements ();
344- case api::StorageType::TEXTURE_2D :
345- case api::StorageType::TEXTURE_3D :
340+ case api::kTexture2D :
341+ case api::kTexture3D :
346342 return view_->image_ .get_memory_requirements ();
347- case api::StorageType::UNKNOWN:
348- break ;
349343 }
350344 return {};
351345}
352346
353347void vTensor::bind_allocation (const api::MemoryAllocation& allocation) {
354348 switch (storage_type ()) {
355- case api::StorageType::BUFFER :
349+ case api::kBuffer :
356350 view_->buffer_ .bind_allocation (allocation);
357351 break ;
358- case api::StorageType::TEXTURE_2D :
359- case api::StorageType::TEXTURE_3D :
352+ case api::kTexture2D :
353+ case api::kTexture3D :
360354 view_->image_ .bind_allocation (allocation);
361355 break ;
362- case api::StorageType::UNKNOWN:
363- break ;
364356 }
365357}
366358
@@ -397,7 +389,7 @@ void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
397389
398390void vTensor::virtual_resize (const std::vector<int64_t >& new_sizes) {
399391 update_size_metadata (new_sizes);
400- if (storage_type () == api::StorageType::BUFFER ) {
392+ if (storage_type () == api::kBuffer ) {
401393 if (gpu_nbytes () > view_->buffer_ .mem_size ()) {
402394 VK_THROW (
403395 " Cannot virtual_resize a vTensor with sizes that require a larger "
@@ -446,11 +438,11 @@ api::VulkanImage allocate_image(
446438 VkImageViewType image_view_type = VK_IMAGE_VIEW_TYPE_3D;
447439
448440 switch (storage_type) {
449- case api::StorageType::TEXTURE_3D :
441+ case api::kTexture3D :
450442 image_type = VK_IMAGE_TYPE_3D;
451443 image_view_type = VK_IMAGE_VIEW_TYPE_3D;
452444 break ;
453- case api::StorageType::TEXTURE_2D :
445+ case api::kTexture2D :
454446 image_type = VK_IMAGE_TYPE_2D;
455447 image_view_type = VK_IMAGE_VIEW_TYPE_2D;
456448 break ;
@@ -481,7 +473,7 @@ api::VulkanBuffer allocate_buffer(
481473 api::Adapter* adapter_ptr = context_ptr->adapter_ptr ();
482474
483475 switch (storage_type) {
484- case api::StorageType::BUFFER :
476+ case api::kBuffer :
485477 break ;
486478 default :
487479 // Return an empty VulkanBuffer if Buffer storage is not used
0 commit comments