Skip to content

Commit e58007c

Browse files
inadobwweic
authored andcommitted
Add parser support for ReLU tflite operator (apache#4022)
1 parent 399fc28 commit e58007c

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab):
8484
'PACK': self.convert_pack,
8585
'LOGISTIC': self.convert_logistic,
8686
'TANH':self.convert_tanh,
87+
'RELU':self.convert_relu,
8788
'SPLIT': self.convert_split,
8889
'TRANSPOSE': self.convert_transpose,
8990
'TILE': self.convert_tile,
@@ -345,6 +346,23 @@ def convert_tanh(self, op):
345346

346347
return out
347348

349+
def convert_relu(self, op):
350+
"""Convert TFLite ReLU"""
351+
try:
352+
from tflite.Operator import Operator
353+
except ImportError:
354+
raise ImportError("The tflite package must be installed")
355+
356+
assert isinstance(op, Operator)
357+
input_tensors = self.get_input_tensors(op)
358+
assert len(input_tensors) == 1, "input tensors length should be 1"
359+
360+
input_tensor = input_tensors[0]
361+
in_expr = self.get_expr(input_tensor.tensor_idx)
362+
out = _op.nn.relu(in_expr)
363+
364+
return out
365+
348366
def convert_concatenation(self, op):
349367
"""Convert TFLite concatenation"""
350368
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,21 @@ def test_forward_tanh():
836836
""" TANH """
837837
_test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
838838

839+
#######################################################################
840+
# ReLu
841+
# --------
842+
843+
def _test_relu(data):
844+
""" One iteration of ReLU """
845+
with tf.Graph().as_default():
846+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
847+
out = nn_ops.relu(in_data)
848+
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
849+
850+
def test_forward_relu():
851+
""" ReLU """
852+
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
853+
839854
#######################################################################
840855
# Fully Connected
841856
# -------
@@ -999,6 +1014,7 @@ def test_forward_ssd_mobilenet_v1():
9991014
test_forward_pooling()
10001015
test_forward_softmax()
10011016
test_forward_tanh()
1017+
test_forward_relu()
10021018
test_forward_fully_connected()
10031019

10041020
# Elemwise

0 commit comments

Comments
 (0)