From 45d82bbf1e3a0fdc08a7c9d158a5a6140e889554 Mon Sep 17 00:00:00 2001
From: Neil Hickey
Date: Tue, 7 Feb 2023 11:27:32 +0000
Subject: [PATCH 1/4] [TFLite] Support for BATCH_MATMUL tflite operator

Adds support for the BATCH_MATMUL operator in the TFLite frontend.
Adds a test that checks the supported TFLite types.
---
 python/tvm/relay/frontend/tflite.py          | 150 +++++++++++++++++++
 tests/python/frontend/tflite/test_forward.py |  71 +++++++--
 2 files changed, 212 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index db21fa6668d1..98f914d667a5 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -32,7 +32,9 @@
 from .. import op as _op
 from .. import qnn as _qnn
 from .common import ExprTable
+from .common import fold_constant as _fold_constant
 from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
 from .common import lstm_cell, to_int_list, shape_of, try_infer_value
 from .common import set_span
 from .tflite_flexbuffer import FlexBufferDecoder
@@ -80,6 +82,7 @@ def __init__(self, model, subgraph, exp_tab):
             "ARG_MIN": self.convert_arg_min,
             "AVERAGE_POOL_2D": self.convert_average_pool2d,
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
+            "BATCH_MATMUL": self.convert_batch_matmul,
             "CAST": self.convert_cast,
             "CEIL": self.convert_ceil,
             "CONCATENATION": self.convert_concatenation,
@@ -492,6 +495,21 @@ def get_tensor_type_str(self, tensor_type):
             "Tensor type {} is currently not supported".format(str(tensor_type))
         )
 
+    def flatten_to_nd(self, x, x_shape, nd=3):
+        """Flatten input tensor to nd rank"""
+        ndims = _infer_shape(x_shape)[0]
+        if ndims == nd:
+            return x
+        newshape = _op.concatenate(
+            [
+                _expr.const([-1], dtype=_infer_type(x_shape).checked_type.dtype),
+                _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+            ],
+            0,
+        )
+        out = _op.reshape(x, _fold_constant(newshape))
+        return out
+
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         lhs_scale = lhs_tensor.qnn_params["scale"]
         rhs_scale = rhs_tensor.qnn_params["scale"]
@@ -2959,6 +2977,138 @@ def convert_batch_to_space_nd(self, op):
 
         return out
 
+    def convert_batch_matmul(self, op):
+        """batch_matmul implementation."""
+        try:
+            from tflite.BatchMatMulOptions import BatchMatMulOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensor = self.get_output_tensors(op)
+
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        batch_matmul_options = BatchMatMulOptions()
+        op_options = op.BuiltinOptions()
+        batch_matmul_options.Init(op_options.Bytes, op_options.Pos)
+
+        input_a = self.get_expr(input_tensors[0].tensor_idx)
+        input_b = self.get_expr(input_tensors[1].tensor_idx)
+
+        shape_a = shape_of(input_a)
+        shape_b = shape_of(input_b)
+        rank_a = _infer_shape(shape_a)[0]
+        rank_b = _infer_shape(shape_b)[0]
+
+        if rank_a > 2 or rank_b > 2:
+            # Determine the output batch dimension
+            new_a_shape = shape_a
+            new_b_shape = shape_b
+            if rank_a > rank_b:
+                rank_diff = rank_a - rank_b
+                new_b_shape = _op.concatenate(
+                    [
+                        _expr.const([1] * rank_diff, dtype=_infer_type(shape_b).checked_type.dtype),
+                        shape_b,
+                    ],
+                    0,
+                )
+            elif rank_a < rank_b:
+                rank_diff = rank_b - rank_a
+                new_a_shape = _op.concatenate(
+                    [
+                        _expr.const([1] * rank_diff, dtype=_infer_type(shape_a).checked_type.dtype),
+                        shape_a,
+                    ],
+                    0,
+                )
+            else:
+                pass
+
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(new_b_shape, [i], [i + 1]),
+                        _op.strided_slice(new_a_shape, [i], [i + 1]),
+                    )
+                    for i in range(max(rank_a, rank_b) - 2)
+                ],
+                0,
+            )
+
+            out_batch_shape = _fold_constant(out_batch)
+
+            a_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_a, [rank_a - 2], [rank_a]),
+                    ],
+                    0,
+                )
+            )
+            b_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_b, [rank_b - 2], [rank_b]),
+                    ],
+                    0,
+                )
+            )
+            if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
+                input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
+            if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
+                input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
+
+            input_a = self.flatten_to_nd(input_a, shape_a, 3)
+            input_b = self.flatten_to_nd(input_b, shape_b, 3)
+
+        if batch_matmul_options.AdjX():
+            input_a = _op.transpose(input_a, [0, 2, 1])
+        if batch_matmul_options.AdjY() == False:
+            input_b = _op.transpose(input_b, [0, 2, 1])
+
+        if self.is_quantized(op):
+            output = _qnn.op.batch_matmul(
+                input_a,
+                input_b,
+                relay.const(0, "int32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(1.0, "float32"),
+            )
+        else:
+            output = _op.nn.batch_matmul(input_a, input_b)
+
+        # Reshape output to original dimensions.
+        output_shape = shape_of(output)
+
+        rank_out = _infer_shape(output_shape)[0]
+
+        final_shape = _op.concatenate(
+            [
+                _op.strided_slice(shape_a, [0], [rank_a - 2]),
+                _op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
+            ],
+            0,
+        )
+
+        reshape = _op.reshape(output, _fold_constant(final_shape))
+        # qnn batch matmul returns a int32 tensor so we need to requantize
+        if self.is_quantized(op):
+            return _qnn.op.requantize(
+                reshape,
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                out_dtype="int8",
+            )
+        else:
+            return reshape
+
     def convert_space_to_batch_nd(self, op):
         """space_to_batch_nd implementation."""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 42a27bbd2671..35d12942b522 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -61,6 +61,7 @@
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import variables
+from tensorflow import raw_ops
 
 try:
     from tensorflow import lite as interpreter_wrapper
@@ -319,6 +320,12 @@ def compare_tflite_with_tvm(
         sess.run(variables.global_variables_initializer())
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+
+        if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4:
+            converter._experimental_disable_batchmatmul_unfold = True
+        else:
+            converter._experimental_disable_batchmatmul_unfold = False
+
         converter.experimental_new_converter = experimental_new_converter
         if quantized:
             if int_quant_dtype == tf.int16:
@@ -734,24 +741,70 @@ def test_forward_cast():
 #######################################################################
 # Batch Mat Mul
 # ----
-def _test_batch_matmul(a_shape, b_shape, dtype, adjoint_a=False, adjoint_b=False):
+def _test_batch_matmul(
+    a_shape, b_shape, dtype, out_dtype, adjoint_a=False, adjoint_b=False, quantized=False
+):
     with tf.Graph().as_default():
         a = array_ops.placeholder(shape=a_shape, dtype=dtype, name="A")
         b = 
array_ops.placeholder(shape=b_shape, dtype=dtype, name="B") - result = math_ops.matmul(a, b, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul") + print(tf.__version__) + + result = raw_ops.BatchMatMulV3( + x=a, y=b, Tout=out_dtype, adj_x=adjoint_a, adj_y=adjoint_b, name="batchmatmul" + ) + input_range = {"A": (-100, 100), "B": (-100, 100)} if quantized else None a_np = np.random.uniform(high=5.0, size=a_shape).astype(dtype) b_np = np.random.uniform(high=5.0, size=b_shape).astype(dtype) - compare_tflite_with_tvm([a_np, b_np], [a.name, b.name], [a, b], [result]) + compare_tflite_with_tvm( + [a_np, b_np], + [a.name, b.name], + [a, b], + [result], + experimental_new_converter=True, + quantized=quantized, + input_range=input_range, + ) -def test_forward_batch_matmul(): +@pytest.mark.parametrize("config", [("int8", "int32", True), ("float32", "float32", False)]) +def test_forward_batch_matmul(config): """BATCH_MAT_MUL""" - _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32") - _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True) - _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", True, False) - _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True) - _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "float32") + _test_batch_matmul( + (3, 5, 4), (3, 4, 5), dtype=config[0], out_dtype=config[1], quantized=config[2] + ) + _test_batch_matmul( + (3, 5, 4), + (3, 4, 5), + dtype=config[0], + out_dtype=config[1], + adjoint_a=True, + adjoint_b=True, + quantized=config[2], + ) + _test_batch_matmul( + (3, 5, 4), + (3, 5, 4), + dtype=config[0], + out_dtype=config[1], + adjoint_a=True, + adjoint_b=False, + quantized=config[2], + ) + _test_batch_matmul( + (3, 5, 4), + (3, 5, 4), + dtype=config[0], + out_dtype=config[1], + adjoint_a=False, + adjoint_b=True, + quantized=config[2], + ) + _test_batch_matmul( + (3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2] + ) + # BatchMatMul doesn't support larger than 4D tensors + # _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]) ####################################################################### From f1c73ec28c466ae356c412be02f11fbd48eee024 Mon Sep 17 00:00:00 2001 From: Neil Hickey Date: Wed, 29 Mar 2023 10:43:13 +0100 Subject: [PATCH 2/4] Fixing linting issues --- python/tvm/relay/frontend/tflite.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 98f914d667a5..9daf7f716fbf 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2985,7 +2985,6 @@ def convert_batch_matmul(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) - output_tensor = self.get_output_tensors(op) assert len(input_tensors) == 2, "two input tensor arguments expected" @@ -3037,8 +3036,6 @@ def convert_batch_matmul(self, op): 0, ) - out_batch_shape = _fold_constant(out_batch) - a_broadcasted_shape = _fold_constant( _op.concatenate( [ @@ -3067,7 +3064,7 @@ def convert_batch_matmul(self, op): if batch_matmul_options.AdjX(): input_a = _op.transpose(input_a, [0, 2, 1]) - if batch_matmul_options.AdjY() == False: + if not batch_matmul_options.AdjY(): input_b = _op.transpose(input_b, [0, 2, 1]) if self.is_quantized(op): From a5be7caac40ea38967e6160806d55ed111e7c5c1 Mon Sep 17 00:00:00 2001 From: Neil Hickey Date: Wed, 29 Mar 2023 14:43:39 +0100 Subject: [PATCH 3/4] Fixing more 
lint issues --- tests/python/frontend/tflite/test_forward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 35d12942b522..1e9b7475e3ea 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -804,7 +804,9 @@ def test_forward_batch_matmul(config): (3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2] ) # BatchMatMul doesn't support larger than 4D tensors - # _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]) + # _test_batch_matmul( + # (2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2] + # ) ####################################################################### From a98abe7ff79666ee5e519d8123aee6fa61cf6d49 Mon Sep 17 00:00:00 2001 From: Neil Hickey Date: Thu, 30 Mar 2023 11:49:46 +0100 Subject: [PATCH 4/4] Fixing compare_tflite function for input_tensors < 2 --- tests/python/frontend/tflite/test_forward.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 1e9b7475e3ea..41eb1f3067ad 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -321,10 +321,11 @@ def compare_tflite_with_tvm( # convert to tflite model converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors) - if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4: - converter._experimental_disable_batchmatmul_unfold = True - else: - converter._experimental_disable_batchmatmul_unfold = False + if len(input_tensors) > 1: + if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4: + converter._experimental_disable_batchmatmul_unfold = True + else: + converter._experimental_disable_batchmatmul_unfold = False converter.experimental_new_converter = experimental_new_converter if quantized:
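
Reviewer note (not part of the patches): the sketch below shows one way the new
BATCH_MATMUL conversion path could be exercised end to end, outside of the test
suite changes above. It is a rough sketch under assumptions: the tensor names,
the private _experimental_disable_batchmatmul_unfold converter flag (the same
flag the test changes use), and the tflite parsing entry point all vary across
TensorFlow and tflite-package versions, so adjust them to your environment.

# Sketch only: drive the new BATCH_MATMUL path through relay.frontend.from_tflite.
import tensorflow as tf
from tvm import relay


@tf.function(
    input_signature=[
        tf.TensorSpec((3, 5, 4), tf.float32, name="A"),
        tf.TensorSpec((3, 4, 5), tf.float32, name="B"),
    ]
)
def batched_matmul(a, b):
    # A 3D x 3D matmul typically lowers to a TFLite BATCH_MATMUL op.
    return tf.linalg.matmul(a, b)


converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [batched_matmul.get_concrete_function()]
)
# Keep the op as BATCH_MATMUL rather than letting the converter unroll it
# (private flag; may not exist in every TensorFlow release).
converter._experimental_disable_batchmatmul_unfold = True
tflite_buf = converter.convert()

# Parse the flatbuffer; the entry point differs between tflite package versions.
try:
    import tflite

    tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
except AttributeError:
    from tflite.Model import Model

    tflite_model = Model.GetRootAsModel(tflite_buf, 0)

# The input names below are assumptions; inspect the converted model if the
# converter renamed the placeholders.
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"A": (3, 5, 4), "B": (3, 4, 5)},
    dtype_dict={"A": "float32", "B": "float32"},
)
print(mod)  # the printed module should contain nn.batch_matmul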