Skip to content

Commit 32341dd

Browse files
committed
[QNN][TFLite] Parsing QNN Add op. Adding MobilenetV2.
1 parent 02c1e11 commit 32341dd

File tree

2 files changed

+86
-2
lines changed

2 files changed

+86
-2
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,18 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
224224
return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \
225225
lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point']
226226

227+
def is_quantized(self, op):
    """Report whether the operator's first input tensor has QNN params.

    Only the first tensor returned by ``get_input_tensors`` is inspected;
    the operator is treated as quantized when that tensor's ``qnn_params``
    is not None.

    Raises
    ------
    ImportError
        If the tflite package cannot be imported.
    """
    # Imported lazily so the module loads even when tflite is absent.
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    leading_input = self.get_input_tensors(op)[0]
    has_qnn_params = leading_input.qnn_params is not None
    return has_qnn_params
238+
227239
def convert_conv2d(self, op):
    """Convert TFLite conv2d by delegating to ``convert_conv``."""
    conv_kind = "conv2d"
    return self.convert_conv(op, conv_kind)
@@ -498,7 +510,25 @@ def _convert_elemwise(self, relay_op, op):
498510
rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
499511
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
500512
dtype=rhs_type_str)
501-
out = relay_op(lhs_expr, rhs_expr)
513+
514+
output_tensors = self.get_output_tensors(op)
515+
assert len(output_tensors) == 1, "output tensors length should be 1"
516+
output_tensor = output_tensors[0]
517+
518+
# If quantized, extract the qnn params and call the QNN add operator.
519+
if lhs_tensor.qnn_params:
520+
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
521+
assert output_tensor.qnn_params, "Output tensor should be quantized."
522+
out = relay_op(lhs=lhs_expr,
523+
rhs=rhs_expr,
524+
lhs_scale=lhs_tensor.qnn_params['scale'],
525+
lhs_zero_point=lhs_tensor.qnn_params['zero_point'],
526+
rhs_scale=rhs_tensor.qnn_params['scale'],
527+
rhs_zero_point=rhs_tensor.qnn_params['zero_point'],
528+
output_scale=output_tensor.qnn_params['scale'],
529+
output_zero_point=output_tensor.qnn_params['zero_point'])
530+
else:
531+
out = relay_op(lhs_expr, rhs_expr)
502532

503533
# Options (fused_activation_function)
504534
options = None
@@ -517,36 +547,70 @@ def _convert_elemwise(self, relay_op, op):
517547
fused_activation_fn = options.FusedActivationFunction()
518548
# if we have activation fn
519549
if fused_activation_fn != ActivationFunctionType.NONE:
550+
if output_tensor.qnn_params:
551+
raise tvm.error.OpNotImplemented(
552+
'Elemwise operators with fused activation are not supported yet.')
520553
out = self.convert_fused_activation_function(out, fused_activation_fn)
521554

522555
return out
523556

524557
def convert_add(self, op):
    """Convert TFLite ADD.

    Uses the QNN add operator when ``is_quantized(op)`` is true and the
    regular Relay add otherwise; both go through ``_convert_elemwise``.
    """
    relay_add = _qnn.op.add if self.is_quantized(op) else _op.add
    return self._convert_elemwise(relay_add, op)
527563

528564
def convert_sub(self, op):
    """Convert TFLite SUB.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.subtract, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized sub operator is not supported yet.')
531571

532572
def convert_mul(self, op):
    """Convert TFLite MUL.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.multiply, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized mul operator is not supported yet.')
535579

536580
def convert_div(self, op):
    """Convert TFLite DIV.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.divide, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized div operator is not supported yet.')
539587

540588
def convert_pow(self, op):
    """Convert TFLite POW.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.power, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized pow operator is not supported yet.')
542594

543595
def convert_maximum(self, op):
    """Convert TFLite MAXIMUM.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.maximum, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized maximum operator is not supported yet.')
545601

546602
def convert_minimum(self, op):
    """Convert TFLite MINIMUM.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.minimum, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized minimum operator is not supported yet.')
548608

549609
def convert_greater(self, op):
    """Convert TFLite GREATER.

    Raises ``tvm.error.OpNotImplemented`` when ``is_quantized(op)`` is true.
    """
    if not self.is_quantized(op):
        return self._convert_elemwise(_op.greater, op)
    # Quantized inputs are rejected rather than converted.
    raise tvm.error.OpNotImplemented(
        'TFlite quantized greater operator is not supported yet.')
551615

552616
def convert_zeros_like(self, op):

tests/python/frontend/tflite/test_forward.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,26 @@ def test_forward_qnn_mobilenet_v1_net():
10371037
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
10381038
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
10391039

1040+
def test_forward_qnn_mobilenet_v2_net():
    """Test the Quantized TFLite Mobilenet V2 model.

    Runs the same seeded random uint8 input through TFLite and TVM and
    compares the top-3 predicted labels of the two outputs.
    """
    # MobilenetV2
    tflite_model_file = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz",
        "mobilenet_v2_1.0_224_quant.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    # Check the labels because the requantize implementation differs between TFLite and
    # Relay. This causes the final output numbers to mismatch, so accuracy is tested via labels.
    np.random.seed(0)
    # np.random.random_integers is deprecated; randint's upper bound is exclusive,
    # so high=129 keeps the original inclusive [0, 128] range.
    data = np.random.randint(low=0, high=129, size=(1, 224, 224, 3)).astype('uint8')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tflite_predictions = np.squeeze(tflite_output)
    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm_predictions = np.squeeze(tvm_output)
    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
1059+
10401060
#######################################################################
10411061
# SSD Mobilenet
10421062
# -------------
@@ -1111,6 +1131,6 @@ def test_forward_ssd_mobilenet_v1():
11111131
test_forward_ssd_mobilenet_v1()
11121132

11131133
# End to End quantized
1114-
# TODO - MobilenetV2 fails for now. Remove when fixed.
11151134
test_forward_qnn_inception_v1_net()
11161135
test_forward_qnn_mobilenet_v1_net()
1136+
test_forward_qnn_mobilenet_v2_net()

0 commit comments

Comments
 (0)