
Commit c3ca260

Matthew Brookhart authored and trevor-m committed
Fix QNN type inference (apache#7074)
* check for incomplete types in QNN Relation functions
* add regression test from apache#7067
* respond to review comments
1 parent 800b523 commit c3ca260

File tree: 7 files changed (+121 -15 lines changed)

src/relay/qnn/op/concatenate.cc
src/relay/qnn/op/convolution.cc
src/relay/qnn/op/convolution_transpose.cc
src/relay/qnn/op/dense.cc
src/relay/qnn/op/op_common.h
src/relay/qnn/op/requantize.cc
tests/python/frontend/pytorch/qnn_test.py

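For context, the failure mode can be sketched without PyTorch. The snippet below is illustrative and not part of the commit: it hand-builds the kind of call that MergeComposite produces, where the inner function's quantization parameters are plain variables with no type annotation, so the QNN type relations first fire while those argument slots are still IncompleteType.

import tvm
from tvm import relay

# Inner-function parameters without annotations start as IncompleteType.
lhs = relay.var("lhs", shape=(1, 4), dtype="int8")
rhs = relay.var("rhs", shape=(1, 4), dtype="int8")
qparams = [relay.var(n) for n in
           ("lhs_scale", "lhs_zp", "rhs_scale", "rhs_zp", "out_scale", "out_zp")]
inner = relay.Function([lhs, rhs] + qparams, relay.qnn.op.add(lhs, rhs, *qparams))

# The outer call supplies concrete arguments, as MergeComposite does.
x = relay.var("x", shape=(1, 4), dtype="int8")
y = relay.var("y", shape=(1, 4), dtype="int8")
f32 = lambda v: relay.const(v, "float32")
i32 = lambda v: relay.const(v, "int32")
call = inner(x, y, f32(0.5), i32(0), f32(0.5), i32(0), f32(0.5), i32(0))
mod = tvm.IRModule.from_expr(relay.Function([x, y], call))

# Before this commit the QNN relations could hit an ICHECK on the unresolved
# types; afterwards they return false, the solver retries once the call-site
# types propagate, and inference converges.
mod = relay.transform.InferType()(mod)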

src/relay/qnn/op/concatenate.cc

Lines changed: 30 additions & 6 deletions
@@ -38,29 +38,53 @@ namespace qnn {
 
 bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter) {
+  // Expected Types: data, input_scales, input_zero_points, output_scale, output_zero_point,
+  // out_type
   ICHECK_EQ(types.size(), 6);
 
+  if (types[0].as<IncompleteTypeNode>()) {
+    return false;
+  }
   // Check the scale and zero point types
   const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
   if (input_scales_tuple == nullptr) {
-    throw Error(ErrorBuilder()
-                << "qnn concatenate requires a tuple of scales as the second argument, found "
-                << PrettyPrint(types[1]));
+    if (types[1].as<IncompleteTypeNode>()) {
+      return false;
+    } else {
+      throw Error(ErrorBuilder()
+                  << "qnn concatenate requires a tuple of scales as the second argument, found "
+                  << PrettyPrint(types[1]));
+    }
   }
   for (const auto& input_scale : input_scales_tuple->fields) {
+    if (input_scale.as<IncompleteTypeNode>()) {
+      return false;
+    }
     ICHECK(IsScalarType(input_scale, DataType::Float(32)));  // input_scales[idx]
   }
 
   const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
   if (input_zero_points_tuple == nullptr) {
-    throw Error(ErrorBuilder()
-                << "qnn concatenate requires a tuple of zero_points as the third argument, found "
-                << PrettyPrint(types[2]));
+    if (types[2].as<IncompleteTypeNode>()) {
+      return false;
+    } else {
+      throw Error(ErrorBuilder()
+                  << "qnn concatenate requires a tuple of zero_points as the third argument, found "
+                  << PrettyPrint(types[2]));
+    }
   }
   for (const auto& input_zero_point : input_zero_points_tuple->fields) {
+    if (input_zero_point.as<IncompleteTypeNode>()) {
+      return false;
+    }
     ICHECK(IsScalarType(input_zero_point, DataType::Int(32)));  // input_zero_points[idx]
   }
 
+  for (size_t i = 3; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
   ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
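
As a quick orientation to the six slots checked above, here is a hedged Python sketch of a well-typed qnn.concatenate call (values are illustrative; the wrapper packs the lists into the tuples the relation inspects):

import tvm
from tvm import relay

a = relay.var("a", shape=(1, 3), dtype="int8")
b = relay.var("b", shape=(1, 3), dtype="int8")
out = relay.qnn.op.concatenate(
    [a, b],                                             # data -> types[0]
    input_scales=[relay.const(0.1), relay.const(0.2)],  # types[1]
    input_zero_points=[relay.const(0, "int32"),
                       relay.const(3, "int32")],        # types[2]
    output_scale=relay.const(0.15),                     # types[3]
    output_zero_point=relay.const(0, "int32"),          # types[4]
    axis=1,
)
mod = relay.transform.InferType()(
    tvm.IRModule.from_expr(relay.Function([a, b], out)))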

src/relay/qnn/op/convolution.cc

Lines changed: 10 additions & 3 deletions
@@ -42,6 +42,8 @@ namespace qnn {
 
 bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -57,22 +59,27 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
 
   // Check the types of scale and zero points.
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
     size_t axis = param->kernel_layout.operator std::string().find('O');
     ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // weight_scale
   } else {
     // Here, total number of output channels depend on depth multiplier.
     size_t o_axis = param->kernel_layout.operator std::string().find('O');
     size_t i_axis = param->kernel_layout.operator std::string().find('I');
     ICHECK(o_axis != std::string::npos || i_axis != std::string::npos)
         << "Kernel layout attribute is not defined";
     AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
-               reporter);  // kernel scale
+               reporter);  // weight_scale
   }
 
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
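
The renamed weight_scale comments encode a shape rule worth spelling out: with groups == 1 the scale may be a scalar or a vector whose length is the 'O' extent of kernel_layout, while for grouped convolutions the output-channel count is the product of the kernel's 'I' and 'O' extents because of the depth multiplier. A small illustrative helper, not part of the commit; the same rule appears in convolution_transpose.cc below:

def expected_weight_scale_len(kernel_shape, kernel_layout, groups):
    """Length of a per-channel weight_scale vector, mirroring QnnConv2DRel."""
    o = kernel_layout.index("O")
    if groups == 1:
        return kernel_shape[o]
    # Grouped/depthwise: output channels are the product of the kernel's
    # 'I' and 'O' extents (depth multiplier times input groups).
    i = kernel_layout.index("I")
    return kernel_shape[i] * kernel_shape[o]

print(expected_weight_scale_len((32, 16, 3, 3), "OIHW", groups=1))   # 32
print(expected_weight_scale_len((16, 1, 3, 3), "OIHW", groups=16))   # 16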

src/relay/qnn/op/convolution_transpose.cc

Lines changed: 9 additions & 2 deletions
@@ -81,6 +81,8 @@ Array<Array<Layout>> QnnConvTransposeInferCorrectLayout(
 
 bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                            const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -96,14 +98,19 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
 
   // Check the types of scale and zero points.
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
     size_t axis = param->kernel_layout.find('O');
     ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // weight_scale
   } else {
     // Here, total number of output channels depend on depth multiplier.
     size_t o_axis = param->kernel_layout.find('O');

src/relay/qnn/op/dense.cc

Lines changed: 11 additions & 4 deletions
@@ -39,6 +39,8 @@ namespace qnn {
 
 bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -53,10 +55,15 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
 
   // Check the types of scale and zero points.
-  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
-  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
-  AssignType(types[5], DataType::Float(32), param->units, reporter);
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
+  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
+  AssignType(types[5], DataType::Float(32), param->units, reporter);  // weight_scale
 
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
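
For reference, a hedged sketch of the corresponding Python call, assuming TVM's Python wrapper, which still names these arguments kernel_zero_point and kernel_scale even though the C++ comments now say weight_*. The kernel_scale may be a scalar or a vector of length units, matching the AssignType above:

import tvm
from tvm import relay

d = relay.var("d", shape=(1, 8), dtype="int8")
w = relay.var("w", shape=(4, 8), dtype="int8")
out = relay.qnn.op.dense(
    d, w,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.1),
    kernel_scale=relay.const(0.2),  # or a length-4 vector, one scale per unit
    units=4,
)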

src/relay/qnn/op/op_common.h

Lines changed: 13 additions & 0 deletions
@@ -168,9 +168,22 @@ inline Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,
 
 static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                    const TypeReporter& reporter) {
+  // Expected Types: lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
+  // output_zero_point, out_type
   ICHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);
 
+  // Check the lhs and rhs types
+  for (size_t i = 0; i < 2; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   // Check the scale and zero point types
+  for (size_t i = 2; i < 8; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Float(32)));  // lhs_scale
   ICHECK(IsScalarType(types[3], DataType::Int(32)));    // lhs_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // rhs_scale
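
The binary QNN ops registered through this header (qnn.add, qnn.subtract, qnn.mul) all share QnnBroadcastRel, so the nine-slot layout in the new comment applies to each of them. A hedged sketch with illustrative values:

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="int8")
y = relay.var("y", shape=(2, 2), dtype="int8")
f32 = lambda v: relay.const(v, "float32")
i32 = lambda v: relay.const(v, "int32")
# lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
# output_scale, output_zero_point -> types[0] .. types[7]
out = relay.qnn.op.mul(x, y, f32(0.1), i32(0), f32(0.1), i32(0),
                       f32(0.05), i32(0))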

src/relay/qnn/op/requantize.cc

Lines changed: 7 additions & 0 deletions
@@ -256,13 +256,20 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 */
 bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
+  // Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, output
   ICHECK_EQ(types.size(), 6);
   const auto* data = types[0].as<TensorTypeNode>();
 
   if (data == nullptr) {
     return false;
   }
 
+  // Check the scale and zero point types
+  for (size_t i = 3; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   const auto in_dtype = data->dtype;
   ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
          in_dtype == DataType::Int(32))
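
A hedged usage sketch matching the six expected types (keyword names follow TVM's Python wrapper; values are illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="int32")
y = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(0.25),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.5),
    output_zero_point=relay.const(1, "int32"),
    out_dtype="int8",
)
mod = relay.transform.InferType()(
    tvm.IRModule.from_expr(relay.Function([x], y)))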

tests/python/frontend/pytorch/qnn_test.py

Lines changed: 41 additions & 0 deletions
@@ -32,17 +32,58 @@
 from tvm.relay.frontend.pytorch_utils import is_version_greater_than
 from tvm.contrib.download import download_testdata
 
+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.op.contrib.register import register_pattern_table
+from tvm.relay.op.contrib.register import get_pattern_table
+
 
 def torch_version_check():
     from packaging import version
 
     return version.parse(torch.__version__) > version.parse("1.4.0")
 
 
+def make_qnn_add_pattern():
+    lhs = wildcard()
+    rhs = wildcard()
+    lhs_scale = wildcard()
+    lhs_zero_point = wildcard()
+    rhs_scale = wildcard()
+    rhs_zero_point = wildcard()
+    output_scale = wildcard()
+    output_zero_point = wildcard()
+    qadd = is_op("qnn.add")(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
+    return qadd.optional(is_op("clip"))
+
+
+@register_pattern_table("test_table")
+def pattern_table():
+    return [
+        ("qnn_add", make_qnn_add_pattern()),
+    ]
+
+
 def get_tvm_runtime(script_module, input_name, ishape):
 
     input_shapes = [(input_name, ishape)]
     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    pattern_table = get_pattern_table("test_table")
+    with tvm.transform.PassContext(opt_level=3):
+        pass_list = [
+            tvm.relay.transform.SimplifyInference(),
+            tvm.relay.transform.MergeComposite(pattern_table),
+        ]
+        composite_partition = tvm.transform.Sequential(pass_list)
+        partitioned = composite_partition(mod)
 
     with tvm.transform.PassContext(opt_level=3):
         # test on only cpu for now, torch cannot run quant models on cuda
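
Note that the test never inspects `partitioned` further: running SimplifyInference plus MergeComposite, which wraps the matched qnn.add (with optional clip) in a composite function whose parameters start untyped and then re-runs type inference, is itself the regression check, since before this commit that step crashed inside the QNN type relations.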
