diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 54eeb9d82447..a010555938f9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -174,17 +174,45 @@ def __init__(self, model, subgraph, exp_tab): def check_unsupported_ops(self): """Check unsupported TFLite ops in our converter.""" unsupported_ops_set = set() - + dynamic_range_ops_set = set() for op_idx in range(self.subgraph.OperatorsLength()): op = self.subgraph.Operators(op_idx) op_code_str = self.get_op_code_str(op) if op_code_str not in self.convert_map: unsupported_ops_set.add(op_code_str) + continue + + # Trying to exclude "dynamic range quantization" optimized ops as not supported in TVM + qnn_in_cnt = len( + [_.qnn_params for _ in self.get_input_tensors(op)[0:1] if _.qnn_params is not None] + ) + qnn_weight_cnt = len( + [_.qnn_params for _ in self.get_input_tensors(op)[1:] if _.qnn_params is not None] + ) + qnn_out_cnt = len( + [_.qnn_params for _ in self.get_output_tensors(op) if _.qnn_params is not None] + ) + + if qnn_in_cnt == 0 and qnn_out_cnt == 0 and qnn_weight_cnt > 0: + dynamic_range_ops_set.add(op_code_str) + + raise_msg = "" if unsupported_ops_set: - msg = "The following operators are not supported in frontend " "TFLite: {}" + msg = "The following operators are not supported in frontend " "TFLite: {}\n" ops = str(list(unsupported_ops_set)).strip("[,]") - raise tvm.error.OpNotImplemented(msg.format(ops)) + raise_msg += msg.format(ops) + + if dynamic_range_ops_set: + msg = ( + "The following operators are likely to have dynamic range quantization: {}. " + "If you are running an optimized graph, please turn off dynamic range quantization " + "or use full integer quantization" + ) + raise_msg += msg.format(str(list(dynamic_range_ops_set)).strip("[,]")) + + if len(raise_msg) > 0: + raise tvm.error.OpNotImplemented(raise_msg) def convert_op_to_relay(self): """Convert TFLite ops to relay ops""" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 56c50c315a70..1f61a055e210 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -3941,6 +3941,27 @@ def test_forward_mediapipe_hand_landmark(): ) +####################################################################### +# Test check for Tensorflow "dynamic range quantization" optimization +# -------------- +def test_prevent_tensorflow_dynamic_range(): + """ + Should prevent runnung "dynamic range quantization" optimized TFLite graph + """ + data_array = np.random.randint(0, 2, (1, 1024, 1024)).astype(dtype=np.float32) + filter_array = np.random.randint(0, 2, (1024, 1024)).astype(dtype=np.float32) + data_in = tf.keras.layers.Input(shape=data_array.shape[1:]) + dense = tf.keras.layers.Dense(units=filter_array.shape[-1], use_bias=False)(data_in) + keras_model = tf.keras.models.Model(data_in, dense) + keras_model.layers[1].set_weights([filter_array]) + + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + tflite_model = converter.convert() + with pytest.raises(tvm.error.OpNotImplemented): + tvm_output = run_tvm_graph(tflite_model, data_array, data_in.name.replace(":0", "")) + + ####################################################################### # Main # ----