17 changes: 17 additions & 0 deletions include/tvm/driver/driver_api.h
@@ -166,6 +166,23 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
*/
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);

/*!
* \brief Creates a TIR Buffer for the provided parameters
* \param shape shape of the buffer
* \param dtype data type
* \param name buffer name
* \param data_alignment alignment requirement of the data pointer in bytes
* \param offset_factor factor of the elem_offset field; elem_offset is guaranteed to be
*        a multiple of offset_factor.
*        User can specify data_alignment and offset_factor to be 0; in that case
*        a default value will be picked.
* \param compact whether the statement has already been bound to a compact buffer
* \param memory_scope memory scope of the buffer
*/
TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
std::string name, int data_alignment,
int offset_factor, bool compact,
std::string memory_scope = "");
} // namespace tvm

#endif // TVM_DRIVER_DRIVER_API_H_
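
For reference, a minimal usage sketch of the new overload, assuming only this header; the shape and function name below are illustrative, and "global.texture" matches the scope used by the Adreno schedules later in this diff:

#include <tvm/driver/driver_api.h>

// Declare a 5d float32 buffer whose data pointer is placed in the
// "global.texture" memory scope. Passing 0 for data_alignment and
// offset_factor picks the default values.
void ExampleBufferWithScope() {
  tvm::Array<tvm::PrimExpr> shape = {1, 32, 56, 56, 4};  // illustrative shape
  tvm::tir::Buffer buf = tvm::BufferWithOffsetAlignment(
      shape, tvm::DataType::Float(32), "data",
      /*data_alignment=*/0, /*offset_factor=*/0,
      /*compact=*/false, /*memory_scope=*/"global.texture");
  (void)buf;  // the buffer would then be used when lowering a PrimFunc
}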
5 changes: 5 additions & 0 deletions include/tvm/relay/transform.h
@@ -550,6 +550,11 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
*/
TVM_DLL Pass FlattenAtrousConv();

/*!
* \brief Runs the device-dependent memory scope analysis pass, collects a mapping of
* desirable expr->memory_scope, and annotates expressions with a VirtualDevice that
* carries the required memory_scope
*/
TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config);
} // namespace transform

/*!
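
For reference, a minimal sketch of applying the new pass, assuming an existing IRModule and a populated CompilationConfig (the helper function name is hypothetical):

#include <tvm/relay/transform.h>

tvm::IRModule ApplyMemoryScopes(tvm::IRModule mod, tvm::CompilationConfig config) {
  // Type information must be available before the scope analysis runs,
  // mirroring the pass order in RelayBuildModule later in this diff.
  mod = tvm::relay::transform::InferType()(mod);
  // Annotate expressions with a VirtualDevice carrying the required memory_scope.
  mod = tvm::relay::transform::AnnotateMemoryScope(config)(mod);
  return mod;
}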
19 changes: 14 additions & 5 deletions python/tvm/topi/adreno/conv2d_nchw.py
@@ -28,6 +28,7 @@
expand_spatial_dimensions,
add_pad,
bind_data_copy,
get_texture_storage,
)


@@ -213,8 +214,11 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
5d tensors
4. pad should be scheduled separately to create independent opencl kernel. If pad is
inlined into convolution, this gives 1.5x performance drop
5. We are using cache_read to produce texture and guarantee the best performance
on the next stage.
5. We are using cache_read for intermediate tensors to produce texture and guarantee
the best performance on the next stage.
The weights are managed through the static texture planning mechanism and are guaranteed
to arrive in the texture memory scope.
Thus we call cache_read only for the data tensor
6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
for textures
For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
@@ -285,10 +289,15 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
s[output].compute_inline()

# create cache stage
AT = s.cache_read(pad_data, "global.texture", [conv])
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
WT = s.cache_read(kernel, "global.texture-weight", [conv])
bind_data_copy(s[WT])
if (
autotvm.GLOBAL_SCOPE.in_tuning
or isinstance(kernel.op, tvm.te.ComputeOp)
and "filter_pack" in kernel.op.tag
):
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])

# tile and bind spatial axes
n, fc, y, x, fb = s[latest_blocked].op.axis
16 changes: 12 additions & 4 deletions python/tvm/topi/adreno/conv2d_nhwc.py
@@ -209,8 +209,11 @@ def schedule_conv2d_NHWC(cfg, s, output):
5d tensors
4. pad should be scheduled separately to create independent opencl kernel. If pad is
inlined into convolution, this gives 1.5x performance drop
5. We are using cache_read to produce texture and guarantee the best performance
on the next stage.
5. We are using cache_read for intermediate tensors to produce texture and guarantee
the best performance on the next stage.
The weights are managed through the static texture planning mechanism and are guaranteed
to arrive in the texture memory scope.
Thus we call cache_read only for the data tensor
6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
for textures
For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
@@ -282,8 +285,13 @@ def schedule_conv2d_NHWC(cfg, s, output):
# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])
if (
autotvm.GLOBAL_SCOPE.in_tuning
or isinstance(kernel.op, tvm.te.ComputeOp)
and "filter_pack" in kernel.op.tag
):
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])

# tile and bind spatial axes
n, y, x, fc, fb = s[latest_blocked].op.axis
12 changes: 7 additions & 5 deletions python/tvm/topi/adreno/depthwise_conv2d_nchw.py
@@ -28,6 +28,7 @@
expand_spatial_dimensions,
add_pad,
bind_data_copy,
get_texture_storage,
)


@@ -260,11 +261,12 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
if latest_blocked == latest and output != latest:
s[output].compute_inline()

# create cache stage
AT = s.cache_read(pad_data, "global.texture", [conv])
WT = s.cache_read(kernel, "global.texture-weight", [conv])
bind_data_copy(s[AT])
bind_data_copy(s[WT])
if autotvm.GLOBAL_SCOPE.in_tuning or len(latest.op.axis) == 4:
# create cache stage only while tuning or in the 4d case
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])

# tile and bind spatial axes
n, fc, y, x, fb = s[latest_blocked].op.axis
11 changes: 6 additions & 5 deletions python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
@@ -255,11 +255,12 @@ def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
if latest_blocked == latest and output != latest:
s[output].compute_inline()

# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[AT])
bind_data_copy(s[WT])
if autotvm.GLOBAL_SCOPE.in_tuning or len(latest.op.axis) == 4:
# create cache stage only while tuning or in the 4d case
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])

# tile and bind spatial axes
n, y, x, fc, fb = s[latest_blocked].op.axis
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
@@ -84,9 +84,10 @@ Target DefaultTargetHost(Target target) {
}

tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
int data_alignment, int offset_factor, bool compact) {
int data_alignment, int offset_factor, bool compact,
std::string memory_scope) {
DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
@@ -390,6 +390,7 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module = transform::InferType()(relay_module);
relay_module = transform::LabelOps()(relay_module);

relay_module = transform::AnnotateMemoryScope(config_)(relay_module);
ICHECK(relay_module.defined());

return relay_module;
27 changes: 26 additions & 1 deletion src/relay/backend/graph_executor_codegen.cc
@@ -316,6 +316,12 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
if (num_unknown_devices == 0) {
node->attrs_["device_index"] = device_types;
}
// storage scope
std::vector<std::string> storage_scope;
for (const auto& virtual_device : storage_info->virtual_devices) {
storage_scope.push_back(std::string(virtual_device->memory_scope));
}
node->attrs_["storage_scope"] = std::move(storage_scope);
auto node_id = nodes_.size();
nodes_.push_back(node);
// Tuple return value, flatten as tuple
@@ -432,7 +438,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return AddNode(node, call);
}
} else if (!call_node->attrs.defined()) { // Call is an extern function
std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
const auto* func = call_node->op.as<GlobalVarNode>();
ICHECK(func) << "Expected the operator to be a global var, but got "
<< call_node->op->GetTypeKey(); // getting a relay fn here, not sure why.
@@ -529,12 +534,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
size_t num_entry = 0;
ShapeVector shapes;
std::vector<size_t> storage_ids;
std::vector<std::string> storage_scopes;
std::vector<size_t> device_types;
std::vector<std::string> dltypes;
std::vector<size_t> node_row_ptr{0};
for (auto node : nodes_) {
const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
const auto& storage_scope =
dmlc::get<std::vector<std::string>>(node->attrs_["storage_scope"]);
const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);

ICHECK_EQ(node->num_outputs_, shape_vec.size());
@@ -543,12 +551,25 @@
shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
storage_scopes.insert(storage_scopes.end(), storage_scope.begin(), storage_scope.end());
if (node->attrs_.count("device_index")) {
const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
}
node_row_ptr.push_back(num_entry);
}

// Check whether storage_scope contains any non-global memory scope;
// otherwise it is better not to write scopes to the JSON at all
bool global_only_scope = true;
for (const auto& ss : storage_scopes) {
if (!(ss.empty() || ss == "global")) {
global_only_scope = false;
}
}
if (global_only_scope) {
storage_scopes.clear();
}
writer->BeginObject();
writer->WriteObjectKeyValue("nodes", nodes_);
writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
@@ -562,6 +583,10 @@
attrs["device_index"].emplace_back(std::string("list_int"));
attrs["device_index"].emplace_back(device_types);
}
if (storage_scopes.size()) {
attrs["storage_scope"].emplace_back(std::string("list_str"));
attrs["storage_scope"].emplace_back(storage_scopes);
}
attrs["dltype"].emplace_back(std::string("list_str"));
attrs["dltype"].emplace_back(dltypes);
writer->WriteObjectKeyValue("attrs", attrs);
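
For illustration, a hypothetical fragment of the graph JSON "attrs" object emitted above when at least one node uses a non-global scope (all ids, scopes, and dtypes here are made up):

"attrs": {
  "storage_id": ["list_int", [0, 1, 2]],
  "storage_scope": ["list_str", ["global", "global.texture", "global"]],
  "dltype": ["list_str", ["float32", "float32", "float32"]]
}

When every entry is "global" or empty, storage_scopes is cleared and the "storage_scope" key is omitted entirely, keeping the JSON backward compatible.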