Skip to content

Commit 2f431e4

Browse files
guomingzftian1
authored andcommitted
Disable the Scale Propagation if quantizedavgpooling op has multi outputs. (#147)
Signed-off-by: Zhang, Guoming <[email protected]>
1 parent 7789607 commit 2f431e4

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

lpot/adaptor/tf_utils/graph_rewriter/int8/scale_propagation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,17 @@ def _cac_transformation(self):
6565

6666
for match in target_nodes:
6767
pre_node_name = match[0]
68+
6869
pre_node = self.graph_info[pre_node_name].node
6970

7071
output_nodes_count = len(set(self.graph_info[pre_node_name].outputs))
7172

7273
if output_nodes_count > 1:
7374
continue
75+
# Skip transformation if avgpool has multi output nodes.
76+
pooling_nodes_count = len(set(self.graph_info[match[1]].outputs))
77+
if pooling_nodes_count > 1:
78+
continue
7479

7580
if pre_node.op == 'QuantizeV2':
7681
pre_min_index, pre_max_index = quantize_v2_min_index, quantize_v2_max_index
@@ -87,7 +92,6 @@ def _cac_transformation(self):
8792

8893
requantize_min_value = (requantize_min.attr['value'].tensor.float_val)[0]
8994
requantize_max_value = (requantize_max.attr['value'].tensor.float_val)[0]
90-
9195
self._create_new_const_node(pre_node_name + '_cac_requantize_min_value',
9296
requantize_min_value, pre_node.input[pre_min_index])
9397
self._create_new_const_node(pre_node_name + '_cac_requantize_max_value',

0 commit comments

Comments
 (0)