4242#include < vector>
4343
4444#include " ../file_utils.h"
45+ #include " ../texture.h"
4546
4647namespace tvm {
4748namespace runtime {
@@ -51,6 +52,7 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
5152 if (align < kAllocAlignment ) return kAllocAlignment ;
5253 return align;
5354}
55+ constexpr auto Is2DStorage = IsTextureStorage;
5456} // namespace details
5557
5658/* !
@@ -361,24 +363,16 @@ void GraphExecutor::SetupStorage() {
361363 // Find the maximum space size.
362364 for (size_t i = 0 ; i < attrs_.shape .size (); ++i) {
363365 int storage_id = attrs_.storage_id [i];
366+ std::string storage_scope = attrs_.storage_scope .empty () ? " " : attrs_.storage_scope [i];
364367 // Use the fallback device if no device index is available.
365368 int device_type = static_cast <int >(devices_[0 ].device_type );
366369 if (!attrs_.device_index .empty ()) {
367370 device_type = attrs_.device_index [i];
368371 }
369- size_t size = 1 ;
370- for (int64_t sz : attrs_.shape [i]) {
371- size *= static_cast <size_t >(sz);
372- }
373- ICHECK_GE (storage_id, 0 ) << " Do not support runtime shape op" ;
374- DLDataType t = vtype[i];
375- size_t bits = t.bits * t.lanes ;
376- ICHECK (bits % 8U == 0U || bits == 1U || bits == 4U );
377- size_t bytes = ((bits + 7U ) / 8U ) * size;
378372
379373 uint32_t sid = static_cast <uint32_t >(storage_id);
380374 if (sid >= pool_entry.size ()) {
381- pool_entry.resize (sid + 1 , {0 , - 1 });
375+ pool_entry.resize (sid + 1 , {- 1 , { 0 }, {} });
382376 } else {
383377 ICHECK (pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type)
384378 << " The same pool entry cannot be assigned to multiple devices" ;
@@ -395,8 +389,38 @@ void GraphExecutor::SetupStorage() {
395389 pool_entry[sid].linked_param = lookup_rv;
396390 }
397391 pool_entry[sid].param_data_entry = i;
398- pool_entry[sid].size = std::max (pool_entry[sid].size , bytes);
399392 pool_entry[sid].device_type = device_type;
393+ pool_entry[sid].scope = storage_scope;
394+
395+ DLDataType t = vtype[i];
396+ if (!details::Is2DStorage (storage_scope)) {
397+ size_t size = 1 ;
398+ for (int64_t sz : attrs_.shape [i]) {
399+ size *= static_cast <size_t >(sz);
400+ }
401+ size_t bits = t.bits * t.lanes ;
402+ ICHECK (bits % 8U == 0U || bits == 1U || bits == 4U );
403+ int64_t bytes = ((bits + 7U ) / 8U ) * size;
404+ pool_entry[sid].shape [0 ] = std::max (pool_entry[sid].shape [0 ], bytes);
405+ pool_entry[sid].dtype = DLDataType{kDLFloat , 32 , 1 };
406+ } else {
407+ if (pool_entry[sid].shape .size () == 1 ) {
408+ pool_entry[sid].shape .resize (3 , 0 );
409+ }
410+ size_t axis = runtime::DefaultTextureLayoutSeparator (attrs_.shape [i].size (), storage_scope);
411+ auto shape = ApplyTexture2DFlattening<int64_t >(attrs_.shape [i], attrs_.shape [i].size (), axis);
412+ pool_entry[sid].shape [0 ] = std::max (pool_entry[sid].shape [0 ], shape.height );
413+ pool_entry[sid].shape [1 ] = std::max (pool_entry[sid].shape [1 ], shape.width );
414+ CHECK (pool_entry[sid].shape [2 ] == 0 || pool_entry[sid].shape [2 ] == shape.channel )
415+ << pool_entry[sid].shape [2 ] << " != " << shape.channel
416+ << " , texture channel length must be consistent within a storage pool" ;
417+ pool_entry[sid].shape [2 ] = shape.channel ;
418+ CHECK (pool_entry[sid].dtype .bits == 0 || TypeEqual (pool_entry[sid].dtype , t))
419+ << DLDataType2String (pool_entry[sid].dtype ) << " != " << DLDataType2String (t)
420+ << " , pool entry for 2d texure allocations must be of the same type;"
421+ << " downstream error from memory planner likely" ;
422+ pool_entry[sid].dtype = t;
423+ }
400424 }
401425
402426 // Allocate the space.
@@ -410,9 +434,15 @@ void GraphExecutor::SetupStorage() {
410434 if (pit.linked_param .defined ()) {
411435 storage_pool_.push_back (pit.linked_param );
412436 } else {
413- std::vector<int64_t > shape;
414- shape.push_back (static_cast <int64_t >(pit.size + 3 ) / 4 );
415- storage_pool_.push_back (NDArray::Empty (shape, DLDataType{kDLFloat , 32 , 1 }, dev));
437+ std::vector<int64_t > shape = pit.shape ;
438+ if (shape.size () == 1 ) {
439+ shape[0 ] = (shape[0 ] + 3 ) / 4 ;
440+ }
441+ Optional<String> mem_scope;
442+ if (!pit.scope .empty ()) {
443+ mem_scope = String (pit.scope );
444+ }
445+ storage_pool_.push_back (NDArray::Empty (shape, pit.dtype , dev, mem_scope));
416446 }
417447 }
418448
0 commit comments