Skip to content

Commit adf0bbd

Browse files
haojin2wweic
authored andcommitted
add MXNet converter for where operator for both NNVM and Relay (apache#2647)
1 parent 02db9ed commit adf0bbd

File tree

4 files changed

+80
-4
lines changed

4 files changed

+80
-4
lines changed

nnvm/python/nnvm/frontend/mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _argmin(inputs, attrs):
317317
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
318318
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
319319
'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd',
320-
'reshape_like']
320+
'reshape_like', 'where']
321321

322322
_convert_map = {
323323
'_copy' : _rename('copy'),

nnvm/tests/python/frontend/mxnet/test_forward.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_forward_ones():
158158
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
159159
mx_sym = mx.sym.elemwise_add(data, ones)
160160
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
161-
161+
162162
def test_forward_zeros():
163163
data = mx.sym.var('data')
164164
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
@@ -184,7 +184,42 @@ def test_forward_argmin():
184184
data = mx.sym.var('data')
185185
mx_sym = mx.sym.argmin(data, axis=0)
186186
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
187-
187+
188+
def test_forward_where():
189+
cond = mx.sym.var('cond')
190+
x = mx.sym.var('x')
191+
y = mx.sym.var('y')
192+
dshape = (2, 2)
193+
dtype = 'float32'
194+
mx_sym = mx.sym.where(cond, x, y)
195+
np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
196+
np_x = np.random.uniform(size=dshape).astype(dtype)
197+
np_y = np.random.uniform(size=dshape).astype(dtype)
198+
mx_cond = mx.nd.array(np_cond)
199+
mx_x = mx.nd.array(np_x)
200+
mx_y = mx.nd.array(np_y)
201+
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
202+
mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
203+
mod.init_params()
204+
args, auxs = mod.get_params()
205+
mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
206+
out_shape = dshape
207+
new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
208+
shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
209+
for target, ctx in ctx_list():
210+
with nnvm.compiler.build_config(opt_level=3):
211+
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
212+
m = graph_runtime.create(graph, lib, ctx)
213+
# set inputs
214+
m.set_input("cond", tvm.nd.array(np_cond))
215+
m.set_input("x", tvm.nd.array(np_x))
216+
m.set_input("y", tvm.nd.array(np_y))
217+
m.set_input(**params)
218+
m.run()
219+
# get outputs
220+
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
221+
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
222+
188223
if __name__ == '__main__':
189224
test_forward_mlp()
190225
test_forward_vgg()
@@ -206,4 +241,5 @@ def test_forward_argmin():
206241
test_forward_zeros_like()
207242
test_forward_argmax()
208243
test_forward_argmin()
209-
244+
test_forward_where()
245+

python/tvm/relay/frontend/mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def _mx_roi_align(inputs, attrs):
290290
"slice_like",
291291
"zeros_like",
292292
"ones_like",
293+
"where",
293294
]
294295

295296
_convert_map = {

tests/python/frontend/mxnet/test_forward.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,44 @@ def test_forward_argmin():
190190
mx_sym = mx.sym.argmin(data, axis=0)
191191
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
192192

193+
def test_forward_where():
194+
cond = mx.sym.var('cond')
195+
x = mx.sym.var('x')
196+
y = mx.sym.var('y')
197+
dshape = (2, 2)
198+
dtype = 'float32'
199+
mx_sym = mx.sym.where(cond, x, y)
200+
np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
201+
np_x = np.random.uniform(size=dshape).astype(dtype)
202+
np_y = np.random.uniform(size=dshape).astype(dtype)
203+
mx_cond = mx.nd.array(np_cond)
204+
mx_x = mx.nd.array(np_x)
205+
mx_y = mx.nd.array(np_y)
206+
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
207+
mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
208+
mod.init_params()
209+
args, auxs = mod.get_params()
210+
mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
211+
out_shape = dshape
212+
shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
213+
new_sym, params = relay.frontend.from_mxnet(mx_sym,
214+
shape_dict,
215+
arg_params=args,
216+
aux_params=auxs)
217+
for target, ctx in ctx_list():
218+
with relay.build_config(opt_level=3):
219+
graph, lib, params = relay.build(new_sym, target, params=params)
220+
m = graph_runtime.create(graph, lib, ctx)
221+
# set inputs
222+
m.set_input("cond", tvm.nd.array(np_cond))
223+
m.set_input("x", tvm.nd.array(np_x))
224+
m.set_input("y", tvm.nd.array(np_y))
225+
m.set_input(**params)
226+
m.run()
227+
# get outputs
228+
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
229+
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
230+
193231

194232
if __name__ == '__main__':
195233
test_forward_mlp()
@@ -212,3 +250,4 @@ def test_forward_argmin():
212250
test_forward_zeros_like()
213251
test_forward_argmax()
214252
test_forward_argmin()
253+
test_forward_where()

0 commit comments

Comments
 (0)