From 06ab153421ffd68fb06aac0eb610f0ea3dac5a4b Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 18 Aug 2021 14:58:07 +0900 Subject: [PATCH 1/5] [Frontend][TFLite] Implement fake quant --- python/tvm/relay/frontend/tflite.py | 52 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 20 +++++++- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index db6e053628bf..b55fb85d4955 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -91,6 +91,7 @@ def __init__(self, model, subgraph, exp_tab): "EQUAL": self.convert_equal, "EXP": self.convert_exp, "EXPAND_DIMS": self.convert_expand_dims, + "FAKE_QUANT": self.convert_fake_quant, "FILL": self.convert_fill, "FLOOR_DIV": self.convert_floor_div, "FLOOR_MOD": self.convert_floor_mod, @@ -3336,6 +3337,57 @@ def convert_densify(self, op): self.set_prefetched_node(output_tensor.tensor_idx, dense_weight) + def convert_fake_quant(self, op): + """Convert TFLite FAKE_QUANT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + from tflite.BuiltinOptions import BuiltinOptions + from tflite.FakeQuantOptions import FakeQuantOptions + + assert op.BuiltinOptionsType() == BuiltinOptions.FakeQuantOptions + + op_options = op.BuiltinOptions() + fake_quant_options = FakeQuantOptions() + fake_quant_options.Init(op_options.Bytes, op_options.Pos) + + min = fake_quant_options.Min() + max = fake_quant_options.Max() + narrow_range = fake_quant_options.NarrowRange() + num_bits = fake_quant_options.NumBits() + + assert 2 <= num_bits <= 16 + + quant_min = 1 if narrow_range else 0 + quant_max = (1 << num_bits) - 1 + scale = (max - min) / (quant_max - quant_min) + + zero_point_from_min = quant_min - min / scale + if zero_point_from_min <= quant_min: + nudged_zero_point = quant_min + elif zero_point_from_min >= quant_max: + nudged_zero_point = quant_max + else: + nudged_zero_point = round(zero_point_from_min) + + nudged_min = (quant_min - nudged_zero_point) * scale + nudged_max = (quant_max - nudged_zero_point) * scale + + nudged_min_expr = _op.const(nudged_min) + nudged_max_expr = _op.const(nudged_max) + clamped = _op.clip(in_expr, nudged_min, nudged_max) + clamped_shifted = _op.subtract(clamped, nudged_min_expr) + + half = _op.const(0.5) + one = _op.const(1.0) + scale_expr = _op.const(scale) + inv_scale = _op.divide(one, scale_expr) + rounded = _op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half)) + return _op.add(_op.multiply(rounded, scale_expr), nudged_min_expr) + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7b7f1b1c43b8..9ddc750c716b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -322,7 +322,6 @@ def compare_tflite_with_tvm( out_names=out_names, mode=mode, ) - # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output # range for the specific operator. While adding test ensure that we aren't getting only clipped values # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range() @@ -2618,6 +2617,25 @@ def test_forward_select(): ) +def _test_fake_quant(value, min, max, num_bits): + with tf.Graph().as_default(): + with tf.Session() as sess: + input = tf.placeholder(tf.float32, shape=[1], name="input") + out = tf.quantization.fake_quant_with_min_max_args( + input, min=min, max=max, num_bits=num_bits, name=None + ) + + in_data = np.float32(value) + compare_tflite_with_tvm([in_data], ["input:0"], [input], [out]) + + +def test_forward_fake_quant(): + for quant_bits in [2, 4, 8, 16]: + _test_fake_quant(-10.11, -6, 6, quant_bits) + _test_fake_quant(3.55, -6, 6, quant_bits) + _test_fake_quant(10.11, -6, 6, quant_bits) + + # Squeeze # ------- From 476b717b981cfa902afd860589a5fed0ee4a51c8 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 18 Aug 2021 15:10:38 +0900 Subject: [PATCH 2/5] remove unused variable --- python/tvm/relay/frontend/tflite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b55fb85d4955..2cffc67e57a6 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3377,7 +3377,6 @@ def convert_fake_quant(self, op): nudged_max = (quant_max - nudged_zero_point) * scale nudged_min_expr = _op.const(nudged_min) - nudged_max_expr = _op.const(nudged_max) clamped = _op.clip(in_expr, nudged_min, nudged_max) clamped_shifted = _op.subtract(clamped, nudged_min_expr) From 23697e95c3c64c3034bad786325499ad865305f7 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 18 Aug 2021 15:26:15 +0900 Subject: [PATCH 3/5] fix linting errors --- python/tvm/relay/frontend/tflite.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2cffc67e57a6..4d607e46c97f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3354,8 +3354,8 @@ def convert_fake_quant(self, op): fake_quant_options = FakeQuantOptions() fake_quant_options.Init(op_options.Bytes, op_options.Pos) - min = fake_quant_options.Min() - max = fake_quant_options.Max() + opt_min = fake_quant_options.Min() + opt_max = fake_quant_options.Max() narrow_range = fake_quant_options.NarrowRange() num_bits = fake_quant_options.NumBits() @@ -3363,9 +3363,9 @@ def convert_fake_quant(self, op): quant_min = 1 if narrow_range else 0 quant_max = (1 << num_bits) - 1 - scale = (max - min) / (quant_max - quant_min) + scale = (opt_max - opt_min) / (quant_max - quant_min) - zero_point_from_min = quant_min - min / scale + zero_point_from_min = quant_min - opt_min / scale if zero_point_from_min <= quant_min: nudged_zero_point = quant_min elif zero_point_from_min >= quant_max: From c95da2a07ed5968f5daddc3072fce43f95c8f58c Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 18 Aug 2021 21:20:22 +0900 Subject: [PATCH 4/5] add more tests --- tests/python/frontend/tflite/test_forward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 9ddc750c716b..92d1346aef67 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2632,6 +2632,8 @@ def _test_fake_quant(value, min, max, num_bits): def test_forward_fake_quant(): for quant_bits in [2, 4, 8, 16]: _test_fake_quant(-10.11, -6, 6, quant_bits) + _test_fake_quant(-3.55, -6, 6, quant_bits) + _test_fake_quant(0, -6, 6, quant_bits) _test_fake_quant(3.55, -6, 6, quant_bits) _test_fake_quant(10.11, -6, 6, quant_bits) From 6944b711b28e3a7a7d450bfaadb46a3792689bdd Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Thu, 19 Aug 2021 09:41:51 +0900 Subject: [PATCH 5/5] use pytest parametrize instead of a separate function --- tests/python/frontend/tflite/test_forward.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 92d1346aef67..f2941030f0ab 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2617,27 +2617,22 @@ def test_forward_select(): ) -def _test_fake_quant(value, min, max, num_bits): +@pytest.mark.parametrize("quant_bits", [2, 4, 8, 16]) +@pytest.mark.parametrize( + "value, min, max", [[-10.11, -6, 6], [-3.55, -6, 6], [0, -6, 6], [3.55, -6, 6], [10.11, -6, 6]] +) +def test_forward_fake_quant(value, min, max, quant_bits): with tf.Graph().as_default(): with tf.Session() as sess: input = tf.placeholder(tf.float32, shape=[1], name="input") out = tf.quantization.fake_quant_with_min_max_args( - input, min=min, max=max, num_bits=num_bits, name=None + input, min=min, max=max, num_bits=quant_bits, name=None ) in_data = np.float32(value) compare_tflite_with_tvm([in_data], ["input:0"], [input], [out]) -def test_forward_fake_quant(): - for quant_bits in [2, 4, 8, 16]: - _test_fake_quant(-10.11, -6, 6, quant_bits) - _test_fake_quant(-3.55, -6, 6, quant_bits) - _test_fake_quant(0, -6, 6, quant_bits) - _test_fake_quant(3.55, -6, 6, quant_bits) - _test_fake_quant(10.11, -6, 6, quant_bits) - - # Squeeze # -------