Skip to content

Commit 35fdf8b

Browse files
authored
[relay][qnn]: Fix qnn.avg_pool2d layout inference (#17339)
1 parent e468426 commit 35fdf8b

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

src/relay/qnn/op/avg_pool2d.cc

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -132,9 +132,11 @@ InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
   auto avgpool_new_layouts =
       PoolInferCorrectLayout<AvgPool2DAttrs>(attrs, new_in_layouts, old_in_layouts, old_in_types);

-  // Scales and zero points are scalars, use the "undef" layout for them.
-  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], Layout::Undef(),
-                                 Layout::Undef(), Layout::Undef(), Layout::Undef()};
+  // Scales and zero points are scalars, the layouts of these tensors can be treated as channel
+  // layout.
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], channel_layout,
+                                 channel_layout, channel_layout, channel_layout};
   Array<Layout> output_layouts = avgpool_new_layouts->output_layouts;
   return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs);
 }

tests/python/relay/test_pass_convert_op_layout.py

Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1542,6 +1542,85 @@ def expected():
     tvm.ir.assert_structural_equal(a, b)


+def test_qnn_conv_avgpool_2d_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.cast(y, "int8")
+        y = relay.qnn.op.avg_pool2d(
+            y,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            layout="NHWC",
+            out_layout="NHWC",
+            pool_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+        )
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = relay.cast(y, "int8")
+        y = relay.qnn.op.avg_pool2d(
+            y,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            layout="NCHW",
+            out_layout="NCHW",
+            pool_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+        )
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(
+        a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"], "qnn.avg_pool2d": ["NCHW"]})
+    )
+    b = run_opt_pass(expected(), transform.InferType())
+
+    tvm.ir.assert_structural_equal(a, b)
+
+
 def test_conv_roi_align_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))

0 commit comments

Comments
 (0)