Commit 07672d0

[Texture] Add memory scope entity into graph JSON/runtime (#11875)

elvin-ncsullivan and Chris Sullivan authored. This PR is split out of the original PR #11357.
Co-authored-by: Chris Sullivan <[email protected]>

1 parent a81e69a · commit 07672d0

4 files changed: +85, -18 lines

src/relay/backend/graph_executor_codegen.cc (26 additions, 1 deletion)

@@ -326,6 +326,12 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     if (num_unknown_devices == 0) {
       node->attrs_["device_index"] = device_types;
     }
+    // storage scope
+    std::vector<std::string> storage_scope;
+    for (const auto& virtual_device : storage_info->virtual_devices) {
+      storage_scope.push_back(std::string(virtual_device->memory_scope));
+    }
+    node->attrs_["storage_scope"] = std::move(storage_scope);
     auto node_id = nodes_.size();
     nodes_.push_back(node);
     // Tuple return value, flatten as tuple
@@ -442,7 +448,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
         return AddNode(node, call);
       }
     } else if (!call_node->attrs.defined()) {  // Call is an extern function
-      std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
       const auto* func = call_node->op.as<GlobalVarNode>();
       ICHECK(func) << "Expected the operator to be a global var, but got "
                    << call_node->op->GetTypeKey();  // getting a relay fn here, not sure why.
@@ -539,12 +544,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     size_t num_entry = 0;
     ShapeVector shapes;
     std::vector<size_t> storage_ids;
+    std::vector<std::string> storage_scopes;
     std::vector<size_t> device_types;
     std::vector<std::string> dltypes;
     std::vector<size_t> node_row_ptr{0};
     for (auto node : nodes_) {
       const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
       const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
+      const auto& storage_scope =
+          dmlc::get<std::vector<std::string>>(node->attrs_["storage_scope"]);
       const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);

       ICHECK_EQ(node->num_outputs_, shape_vec.size());
@@ -553,12 +561,25 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
       shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
       dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
       storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
+      storage_scopes.insert(storage_scopes.end(), storage_scope.begin(), storage_scope.end());
       if (node->attrs_.count("device_index")) {
         const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
         device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
       }
       node_row_ptr.push_back(num_entry);
     }
+
+    // Check whether storage_scope contains any non-global memory scope;
+    // if it does not, it is better not to write scopes to the JSON at all.
+    bool global_only_scope = true;
+    for (const auto& ss : storage_scopes) {
+      if (!(ss.empty() || ss == "global")) {
+        global_only_scope = false;
+      }
+    }
+    if (global_only_scope) {
+      storage_scopes.clear();
+    }
     writer->BeginObject();
     writer->WriteObjectKeyValue("nodes", nodes_);
     writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
@@ -572,6 +593,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
       attrs["device_index"].emplace_back(std::string("list_int"));
       attrs["device_index"].emplace_back(device_types);
     }
+    if (storage_scopes.size()) {
+      attrs["storage_scope"].emplace_back(std::string("list_str"));
+      attrs["storage_scope"].emplace_back(storage_scopes);
+    }
     attrs["dltype"].emplace_back(std::string("list_str"));
     attrs["dltype"].emplace_back(dltypes);
     writer->WriteObjectKeyValue("attrs", attrs);
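
For illustration: when any non-global scope survives the check above, the "attrs" object of the emitted graph JSON gains a "storage_scope" entry encoded like the other list-typed attributes. A hypothetical fragment follows; the node count and scope values are invented, not taken from this PR:

"attrs": {
  "dltype": ["list_str", ["float32", "float32", "float32"]],
  "storage_id": ["list_int", [0, 1, 2]],
  "storage_scope": ["list_str", ["global", "global.texture", "global"]]
}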

src/runtime/graph_executor/graph_executor.cc (44 additions, 14 deletions)

@@ -42,6 +42,7 @@
 #include <vector>

 #include "../file_utils.h"
+#include "../texture.h"

 namespace tvm {
 namespace runtime {
@@ -51,6 +52,7 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
   if (align < kAllocAlignment) return kAllocAlignment;
   return align;
 }
+constexpr auto Is2DStorage = IsTextureStorage;
 }  // namespace details

 /*!
@@ -361,24 +363,16 @@ void GraphExecutor::SetupStorage() {
   // Find the maximum space size.
   for (size_t i = 0; i < attrs_.shape.size(); ++i) {
     int storage_id = attrs_.storage_id[i];
+    std::string storage_scope = attrs_.storage_scope.empty() ? "" : attrs_.storage_scope[i];
     // Use the fallback device if no device index is available.
     int device_type = static_cast<int>(devices_[0].device_type);
     if (!attrs_.device_index.empty()) {
       device_type = attrs_.device_index[i];
     }
-    size_t size = 1;
-    for (int64_t sz : attrs_.shape[i]) {
-      size *= static_cast<size_t>(sz);
-    }
-    ICHECK_GE(storage_id, 0) << "Do not support runtime shape op";
-    DLDataType t = vtype[i];
-    size_t bits = t.bits * t.lanes;
-    ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U);
-    size_t bytes = ((bits + 7U) / 8U) * size;

     uint32_t sid = static_cast<uint32_t>(storage_id);
     if (sid >= pool_entry.size()) {
-      pool_entry.resize(sid + 1, {0, -1});
+      pool_entry.resize(sid + 1, {-1, {0}, {}});
     } else {
       ICHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type)
           << "The same pool entry cannot be assigned to multiple devices";
@@ -395,8 +389,38 @@ void GraphExecutor::SetupStorage() {
       pool_entry[sid].linked_param = lookup_rv;
     }
     pool_entry[sid].param_data_entry = i;
-    pool_entry[sid].size = std::max(pool_entry[sid].size, bytes);
     pool_entry[sid].device_type = device_type;
+    pool_entry[sid].scope = storage_scope;
+
+    DLDataType t = vtype[i];
+    if (!details::Is2DStorage(storage_scope)) {
+      size_t size = 1;
+      for (int64_t sz : attrs_.shape[i]) {
+        size *= static_cast<size_t>(sz);
+      }
+      size_t bits = t.bits * t.lanes;
+      ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U);
+      int64_t bytes = ((bits + 7U) / 8U) * size;
+      pool_entry[sid].shape[0] = std::max(pool_entry[sid].shape[0], bytes);
+      pool_entry[sid].dtype = DLDataType{kDLFloat, 32, 1};
+    } else {
+      if (pool_entry[sid].shape.size() == 1) {
+        pool_entry[sid].shape.resize(3, 0);
+      }
+      size_t axis = runtime::DefaultTextureLayoutSeparator(attrs_.shape[i].size(), storage_scope);
+      auto shape = ApplyTexture2DFlattening<int64_t>(attrs_.shape[i], attrs_.shape[i].size(), axis);
+      pool_entry[sid].shape[0] = std::max(pool_entry[sid].shape[0], shape.height);
+      pool_entry[sid].shape[1] = std::max(pool_entry[sid].shape[1], shape.width);
+      CHECK(pool_entry[sid].shape[2] == 0 || pool_entry[sid].shape[2] == shape.channel)
+          << pool_entry[sid].shape[2] << " != " << shape.channel
+          << ", texture channel length must be consistent within a storage pool";
+      pool_entry[sid].shape[2] = shape.channel;
+      CHECK(pool_entry[sid].dtype.bits == 0 || TypeEqual(pool_entry[sid].dtype, t))
+          << DLDataType2String(pool_entry[sid].dtype) << " != " << DLDataType2String(t)
+          << ", pool entry for 2d texture allocations must be of the same type;"
+          << " downstream error from memory planner likely";
+      pool_entry[sid].dtype = t;
+    }
   }

   // Allocate the space.
@@ -410,9 +434,15 @@ void GraphExecutor::SetupStorage() {
     if (pit.linked_param.defined()) {
       storage_pool_.push_back(pit.linked_param);
     } else {
-      std::vector<int64_t> shape;
-      shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
-      storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, dev));
+      std::vector<int64_t> shape = pit.shape;
+      if (shape.size() == 1) {
+        shape[0] = (shape[0] + 3) / 4;
+      }
+      Optional<String> mem_scope;
+      if (!pit.scope.empty()) {
+        mem_scope = String(pit.scope);
+      }
+      storage_pool_.push_back(NDArray::Empty(shape, pit.dtype, dev, mem_scope));
     }
   }
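
To make the 2-D sizing step above concrete, here is a minimal, self-contained sketch of the flattening that ApplyTexture2DFlattening is assumed to perform: collapse an N-D shape into a {height, width, channel} texture extent around a separator axis (which DefaultTextureLayoutSeparator picks per scope; here we pass it in directly). The struct name, helper name, and example shape are illustrative, not TVM's exact API.

#include <cstdint>
#include <iostream>
#include <vector>

// {height, width, channel} extent of a flattened 2-D texture.
struct Texture2DShapeSketch {
  int64_t height;
  int64_t width;
  int64_t channel;
};

// Dims before `axis` fold into height, dims from `axis` up to the
// second-to-last fold into width, and the innermost dim is the channel block.
Texture2DShapeSketch Flatten2D(const std::vector<int64_t>& shape, size_t axis) {
  Texture2DShapeSketch t{1, 1, shape.back()};
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    (i < axis ? t.height : t.width) *= shape[i];
  }
  return t;
}

int main() {
  // e.g. a hypothetical NCHWc activation [1, 32, 28, 28, 4] with axis = ndim - 2 = 3:
  Texture2DShapeSketch t = Flatten2D({1, 32, 28, 28, 4}, 3);
  std::cout << t.height << " x " << t.width << " x " << t.channel << "\n";  // 896 x 28 x 4
}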

src/runtime/graph_executor/graph_executor.h (13 additions, 1 deletion)

@@ -204,10 +204,12 @@ class TVM_DLL GraphExecutor : public ModuleNode {
  protected:
   // Memory pool entry.
   struct PoolEntry {
-    size_t size;
     int device_type;
+    std::vector<int64_t> shape;
+    DLDataType dtype;
     int param_data_entry;
     NDArray linked_param;
+    std::string scope;
     // PoolEntry(int s, int dev_type, void* pre_linked_param) :
     //          size(s), device_type(dev_type), pre_linked_param(std::move(pre_linked_param)) {}
   };
@@ -303,6 +305,7 @@ class TVM_DLL GraphExecutor : public ModuleNode {
     std::vector<int> storage_id;
     std::vector<int> device_index;
     std::vector<std::string> dltype;
+    std::vector<std::string> storage_scope;
     std::vector<std::vector<int64_t>> shape;
     // The graph attribute fields.
     void Load(dmlc::JSONReader* reader) {
@@ -328,6 +331,15 @@ class TVM_DLL GraphExecutor : public ModuleNode {
         reader->Read(&storage_id);
         ICHECK(!reader->NextArrayItem());
         bitmask |= 2;
+      } else if (key == "storage_scope") {
+        reader->BeginArray();
+        ICHECK(reader->NextArrayItem());
+        reader->Read(&type);
+        ICHECK_EQ(type, "list_str");
+        ICHECK(reader->NextArrayItem());
+        reader->Read(&storage_scope);
+        ICHECK(!reader->NextArrayItem());
+        bitmask |= 1;
       } else if (key == "shape") {
         reader->BeginArray();
         ICHECK(reader->NextArrayItem());
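
The repurposed PoolEntry::shape replaces the old size field and carries either a 1-D byte count (flat buffers) or a 3-D texture extent. A small sketch of how SetupStorage above interprets the two forms at allocation time; the values are hypothetical:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // 1-D entry: shape[0] is a byte count; allocation packs it into
  // 4-byte float32 words, mirroring (shape[0] + 3) / 4 above.
  std::vector<int64_t> flat = {4096};
  if (flat.size() == 1) flat[0] = (flat[0] + 3) / 4;
  std::cout << "flat pool: " << flat[0] << " float32 words\n";  // 1024

  // 3-D entry: {height, width, channel} for a 2-D texture allocation,
  // allocated with the entry's own dtype and memory scope.
  std::vector<int64_t> texture = {896, 28, 4};
  std::cout << "texture pool: " << texture[0] << " x " << texture[1] << " x "
            << texture[2] << "\n";
}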

src/target/source/codegen_opencl.cc (2 additions, 2 deletions)

@@ -98,7 +98,7 @@ std::string CodeGenOpenCL::Finish() {
         "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
         "#else\n"
         "#error \"Half precision floating point not supported"
-        "by OpenCL implementation on your device.\" \n"
+        " by OpenCL implementation on your device.\" \n"
         "#endif\n\n";
   }

@@ -109,7 +109,7 @@ std::string CodeGenOpenCL::Finish() {
         "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
         "#else\n"
         "#error \"Double precision floating point not supported"
-        "by OpenCL implementation on your device.\" \n"
+        " by OpenCL implementation on your device.\" \n"
        "#endif\n\n";
   }

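The one-character fix matters because adjacent C++ string literals concatenate with no separator, so the original #error text rendered as "...not supportedby OpenCL...". A minimal illustration:

#include <iostream>

int main() {
  // Adjacent literals fuse directly: note the missing space in `before`.
  const char* before = "Half precision floating point not supported"
                       "by OpenCL implementation on your device.";
  const char* after = "Half precision floating point not supported"
                      " by OpenCL implementation on your device.";
  std::cout << before << "\n" << after << "\n";
}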
