Skip to content

Commit b67afcd

Browse files
yongwwwtqchen
authored andcommitted
[Relay] add ClipByValue and Neg in tf frontend converter (#3211)
1 parent 29ee8a2 commit b67afcd

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,13 @@ def _impl(inputs, attr, params):
941941
return AttrCvt(op_name="where")(inputs, attr)
942942
return _impl
943943

944+
def _clip_by_value():
945+
def _impl(inputs, attr, params):
946+
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
947+
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
948+
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
949+
return _impl
950+
944951
def _reverse_v2():
945952
def _impl(inputs, attr, params):
946953
axis = _get_num_param(params, inputs[1])
@@ -1212,6 +1219,7 @@ def _impl(inputs, attr, params):
12121219
'Cast' : _cast(),
12131220
'Ceil' : AttrCvt('ceil'),
12141221
'CheckNumerics' : _check_numerics(),
1222+
'ClipByValue' : _clip_by_value(),
12151223
'Concat' : _concat(),
12161224
'ConcatV2' : _concatV2(),
12171225
'Conv2D' : _conv('conv'),
@@ -1245,6 +1253,7 @@ def _impl(inputs, attr, params):
12451253
'Mean' : _mean(),
12461254
'Minimum' : _elemwise('minimum'),
12471255
'Mul' : _elemwise('multiply'),
1256+
'Neg' : AttrCvt('negative'),
12481257
'NotEqual' : _broadcast('not_equal'),
12491258
'Pack' : _pack(),
12501259
'Pad' : _pad('Pad'),

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,23 @@ def test_forward_tile():
833833
_test_tile((2, 4, 6), (6, 7, 8), "float64")
834834

835835

836+
#######################################################################
837+
# ClipByValue
838+
# -----------
839+
840+
def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype):
841+
tf.reset_default_graph()
842+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
843+
tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue")
844+
np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
845+
compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
846+
847+
def test_forward_clip_by_value():
848+
'''test ClipByValue op'''
849+
if tf.__version__ < LooseVersion('1.9'):
850+
_test_forward_clip_by_value((4,), .1, 5., 'float32')
851+
_test_forward_clip_by_value((4, 4), 1, 5, 'int32')
852+
836853
#######################################################################
837854
# Multi Input to graph
838855
# --------------------
@@ -1591,6 +1608,14 @@ def test_forward_log():
15911608
tf.log(in_data, name="log")
15921609
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
15931610

1611+
def test_forward_negative():
1612+
"""test tf operator Neg """
1613+
np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
1614+
tf.reset_default_graph()
1615+
in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
1616+
tf.negative(in_data, name="negative")
1617+
compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
1618+
15941619
def test_forward_softplus():
15951620
"""test operator Softplus"""
15961621
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
@@ -1738,6 +1763,7 @@ def test_placeholder():
17381763
test_forward_unstack()
17391764
test_forward_tile()
17401765
test_forward_top_k_v2()
1766+
test_forward_clip_by_value()
17411767

17421768
# Activations
17431769
test_forward_sigmoid()
@@ -1753,6 +1779,7 @@ def test_placeholder():
17531779
test_forward_pow_exp()
17541780
test_forward_sign()
17551781
test_forward_log()
1782+
test_forward_negative()
17561783
test_forward_softplus()
17571784
test_forward_sqrt()
17581785
test_forward_rsqrt()
@@ -1802,4 +1829,4 @@ def test_placeholder():
18021829
test_where()
18031830

18041831
test_forward_matmul()
1805-
# TODO missing tests: rank, range
1832+
# TODO missing tests: rank, range

0 commit comments

Comments
 (0)