Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,45 @@ def __init__(self, model, subgraph, exp_tab):
def check_unsupported_ops(self):
"""Check unsupported TFLite ops in our converter."""
unsupported_ops_set = set()

dynamic_range_ops_set = set()
for op_idx in range(self.subgraph.OperatorsLength()):
op = self.subgraph.Operators(op_idx)
op_code_str = self.get_op_code_str(op)
if op_code_str not in self.convert_map:
unsupported_ops_set.add(op_code_str)
continue

# Trying to exclude "dynamic range quantization" optimized ops as not supported in TVM
qnn_in_cnt = len(
[_.qnn_params for _ in self.get_input_tensors(op)[0:1] if _.qnn_params is not None]
)
qnn_weight_cnt = len(
[_.qnn_params for _ in self.get_input_tensors(op)[1:] if _.qnn_params is not None]
)
qnn_out_cnt = len(
[_.qnn_params for _ in self.get_output_tensors(op) if _.qnn_params is not None]
)

if qnn_in_cnt == 0 and qnn_out_cnt == 0 and qnn_weight_cnt > 0:
dynamic_range_ops_set.add(op_code_str)

raise_msg = ""

if unsupported_ops_set:
msg = "The following operators are not supported in frontend " "TFLite: {}"
msg = "The following operators are not supported in frontend " "TFLite: {}\n"
ops = str(list(unsupported_ops_set)).strip("[,]")
raise tvm.error.OpNotImplemented(msg.format(ops))
raise_msg += msg.format(ops)

if dynamic_range_ops_set:
msg = (
"The following operators are likely to have dynamic range quantization: {}. "
"If you are running an optimized graph, please turn off dynamic range quantization "
"or use full integer quantization"
)
raise_msg += msg.format(str(list(dynamic_range_ops_set)).strip("[,]"))

if len(raise_msg) > 0:
raise tvm.error.OpNotImplemented(raise_msg)

def convert_op_to_relay(self):
"""Convert TFLite ops to relay ops"""
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3941,6 +3941,27 @@ def test_forward_mediapipe_hand_landmark():
)


#######################################################################
# Test check for Tensorflow "dynamic range quantization" optimization
# --------------
def test_prevent_tensorflow_dynamic_range():
"""
Should prevent runnung "dynamic range quantization" optimized TFLite graph
"""
data_array = np.random.randint(0, 2, (1, 1024, 1024)).astype(dtype=np.float32)
filter_array = np.random.randint(0, 2, (1024, 1024)).astype(dtype=np.float32)
data_in = tf.keras.layers.Input(shape=data_array.shape[1:])
dense = tf.keras.layers.Dense(units=filter_array.shape[-1], use_bias=False)(data_in)
keras_model = tf.keras.models.Model(data_in, dense)
keras_model.layers[1].set_weights([filter_array])

converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with pytest.raises(tvm.error.OpNotImplemented):
tvm_output = run_tvm_graph(tflite_model, data_array, data_in.name.replace(":0", ""))


#######################################################################
# Main
# ----
Expand Down