Skip to content

Commit fee4f73

Browse files
Ubuntuanijain2305
authored and committed
[QNN][Relay] Calling Dialect passes from inside Relay Build API.
1 parent 4ba911a commit fee4f73

File tree

10 files changed

+178
-34
lines changed

10 files changed

+178
-34
lines changed

include/tvm/relay/op.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ class Op : public relay::Expr {
153153
*/
154154
template <typename ValueType>
155155
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
156+
/*!
157+
* \brief Checks if an attr is present in the registry.
158+
* \param attr_name The name of the attribute.
159+
* \return True if the attr is present in the registry, false otherwise.
160+
* \note Unlike GetAttr, this is not a template and returns a plain bool.
161+
*/
162+
inline static bool HasAttr(const std::string& attr_name);
156163
/*!
157164
* \brief Get an Op for a given operator name.
158165
* Will raise an error if the op has not been registered.
@@ -171,6 +178,12 @@ class Op : public relay::Expr {
171178
* \return reference to GenericOpMap
172179
*/
173180
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
181+
/*!
182+
* \brief Checks if the key is present in the registry.
183+
* \param key The attribute key
184+
* \return bool True if the key is present
185+
*/
186+
TVM_DLL static const bool HasGenericAttr(const std::string& key);
174187
};
175188

176189
/*! \brief Helper structure to register operators */
@@ -393,6 +406,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
393406
return OpMap<ValueType>(Op::GetGenericAttr(key));
394407
}
395408

409+
inline bool Op::HasAttr(const std::string& key) {
410+
return Op::HasGenericAttr(key);
411+
}
412+
396413
inline OpNode* OpRegistry::get() {
397414
return const_cast<OpNode*>(op_.operator->());
398415
}

include/tvm/relay/qnn/transform.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relay/qnn/transform.h
22+
*
23+
* This file declares the QNN dialect passes, which are built on the Relay pass manager.
24+
*/
25+
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
26+
#define TVM_RELAY_QNN_TRANSFORM_H_
27+
28+
#include <tvm/runtime/c_runtime_api.h>
29+
#include <tvm/relay/transform.h>
30+
31+
namespace tvm {
32+
namespace relay {
33+
34+
using relay::transform::Pass;
35+
36+
namespace qnn {
37+
namespace transform {
38+
39+
/*!
40+
* \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
41+
* converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
42+
* Each QNN op is lowered to a sequence of existing Relay ops. This is a target-independent pass.
43+
* One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
44+
* attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
45+
* only QNN ops. One can register a transformation/legalization function for an op by using the
46+
* FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
47+
* gives us separation of concerns, leading to a better software practice. The legalization can be
48+
* configured to happen per target.
49+
*
50+
* \return The pass.
51+
*/
52+
TVM_DLL Pass Legalize();
53+
54+
} // namespace transform
55+
56+
} // namespace qnn
57+
} // namespace relay
58+
} // namespace tvm
59+
60+
#endif // TVM_RELAY_QNN_TRANSFORM_H_

src/relay/backend/build_module.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/runtime/vm.h>
2828
#include <tvm/relay/expr.h>
2929
#include <tvm/relay/transform.h>
30+
#include <tvm/relay/qnn/transform.h>
3031
#include <memory>
3132

3233
#include "utils.h"
@@ -282,6 +283,15 @@ class RelayBuildModule : public runtime::ModuleNode {
282283
const TargetsMap& targets,
283284
const std::unordered_map<std::string, runtime::NDArray>& params) {
284285
Array<Pass> pass_seqs;
286+
287+
// Run all dialect legalization passes.
288+
pass_seqs.push_back(relay::qnn::transform::Legalize());
289+
290+
// Legalize pass is restricted to homogeneous execution for now.
291+
if (targets.size() == 1) {
292+
pass_seqs.push_back(transform::Legalize());
293+
}
294+
285295
pass_seqs.push_back(transform::SimplifyInference());
286296
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
287297
Expr expr = args[0];
@@ -304,11 +314,6 @@ class RelayBuildModule : public runtime::ModuleNode {
304314
pass_seqs.push_back(transform::CanonicalizeCast());
305315
pass_seqs.push_back(transform::CanonicalizeOps());
306316

307-
// Legalize pass is restricted to homogeneous execution for now.
308-
if (targets.size() == 1) {
309-
pass_seqs.push_back(transform::Legalize());
310-
}
311-
312317
// Alter layout transformation is only applied to homogeneous execution for now.
313318
if (targets.size() == 1) {
314319
pass_seqs.push_back(transform::AlterOpLayout());

src/relay/ir/op.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
8484
return *it->second.get();
8585
}
8686

87+
// Check if a key is present in the registry.
88+
const bool Op::HasGenericAttr(const std::string& key) {
89+
OpManager* mgr = OpManager::Global();
90+
std::lock_guard<std::mutex> lock(mgr->mutex);
91+
auto it = mgr->attr.find(key);
92+
if (it == mgr->attr.end()) {
93+
return false;
94+
}
95+
return true;
96+
}
97+
8798
void OpRegistry::UpdateAttr(const std::string& key,
8899
TVMRetValue value,
89100
int plevel) {

src/relay/pass/legalize.cc

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
4646
Expr new_e = ExprMutator::VisitExpr_(call_node);
4747
Call new_call = Downcast<Call>(new_e);
4848

49+
// Check if the string is registered in the OpRegistry.
50+
if (!Op::HasAttr(legalize_map_attr_name_)) {
51+
return new_e;
52+
}
53+
4954
// Collect the registered legalize function.
5055
auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
51-
Op op = Downcast<Op>(call_node->op);
52-
53-
if (fop_legalize.count(op)) {
54-
// Collect the new_args.
55-
tvm::Array<Expr> call_args = new_call->args;
56-
57-
// Collect input and output dtypes to pass on to Legalize API.
58-
tvm::Array<tvm::relay::Type> types;
59-
for (auto arg : call_node->args) {
60-
types.push_back(arg->checked_type());
61-
}
62-
types.push_back(call_node->checked_type());
63-
64-
// Transform the op by calling the registered legalize function.
65-
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
66-
67-
// Reassign new_e if the transformation succeeded.
68-
if (legalized_value.defined()) {
69-
// Check that the returned Expr from legalize is CallNode.
70-
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
71-
CHECK(legalized_call_node)
72-
<< "Can only replace the original operator with another call node";
73-
74-
new_e = legalized_value;
56+
auto call_op = call_node->op;
57+
if (call_op.as<OpNode>()) {
58+
Op op = Downcast<Op>(call_node->op);
59+
60+
if (fop_legalize.count(op)) {
61+
// Collect the new_args.
62+
tvm::Array<Expr> call_args = new_call->args;
63+
64+
// Collect input and output dtypes to pass on to Legalize API.
65+
tvm::Array<tvm::relay::Type> types;
66+
for (auto arg : call_node->args) {
67+
types.push_back(arg->checked_type());
68+
}
69+
types.push_back(call_node->checked_type());
70+
71+
// Transform the op by calling the registered legalize function.
72+
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
73+
74+
// Reassign new_e if the transformation succeeded.
75+
if (legalized_value.defined()) {
76+
// Check that the returned Expr from legalize is CallNode.
77+
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
78+
CHECK(legalized_call_node)
79+
<< "Can only replace the original operator with another call node";
80+
81+
new_e = legalized_value;
82+
}
7583
}
7684
}
7785

@@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
95103
[=](Function f, Module m, PassContext pc) {
96104
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
97105
};
98-
return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
106+
return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
99107
}
100108

101109
TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);

src/relay/qnn/pass/legalize.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file relay/qnn/pass/legalize.cc
22+
* \brief The Legalize wrapper for QNN.
23+
*/
24+
25+
#include <tvm/relay/qnn/transform.h>
26+
27+
namespace tvm {
28+
namespace relay {
29+
namespace qnn {
30+
31+
namespace transform {
32+
33+
Pass Legalize() {
34+
Array<Pass> pass_seqs;
35+
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
36+
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
37+
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
38+
return seq;
39+
}
40+
41+
TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);
42+
43+
} // namespace transform
44+
45+
} // namespace qnn
46+
} // namespace relay
47+
} // namespace tvm

tests/python/relay/test_op_qnn_conv2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def get_qnn_func(data,
7878

7979
mod = relay.Function(relay.analysis.free_vars(func), func)
8080
mod = relay.Module.from_expr(mod)
81-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
8281
return mod
8382

8483
def get_funcs(data_shape,

tests/python/relay/test_op_qnn_dequantize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
3131
input_zero_point=input_zero_point)
3232
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
3333
mod = relay.Module.from_expr(mod)
34-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
3534
with relay.build_config(opt_level=3):
3635
graph, lib, params = relay.build(mod, "llvm", params=None)
3736
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))

tests/python/relay/test_op_qnn_quantize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output
3131
output_zero_point=output_zero_point,out_dtype=out_dtype)
3232
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
3333
mod = relay.Module.from_expr(mod)
34-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
3534
with relay.build_config(opt_level=3):
3635
graph, lib, params = relay.build(mod, "llvm", params=None)
3736
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))

tests/python/relay/test_op_qnn_requantize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
4949

5050
mod = relay.Function(relay.analysis.free_vars(mod), mod)
5151
mod = relay.Module.from_expr(mod)
52-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
5352
return mod
5453

5554
def same_scale_test():

0 commit comments

Comments
 (0)