@@ -379,7 +379,6 @@ def test_forward_l2_normalize():
379379 mx_sym = mx .sym .L2Normalization (data , mode = "channel" )
380380 verify_mxnet_frontend_impl (mx_sym , (2 , 3 , 4 , 5 ), (2 , 3 , 4 , 5 ))
381381
382-
383382def test_forward_shape_array ():
384383 def verify (shape ):
385384 x_np = np .random .uniform (size = shape ).astype ("float32" )
@@ -395,6 +394,75 @@ def verify(shape):
395394 verify ((3 , 4 , 5 ))
396395 verify ((3 , 4 , 5 , 6 ))
397396
397+ def test_forward_squeeze ():
398+ def verify (shape , axis ):
399+ x_np = np .random .uniform (size = shape ).astype ("float32" )
400+ if axis is None :
401+ ref_res = mx .nd .squeeze (mx .nd .array (x_np ))
402+ mx_sym = mx .sym .squeeze (mx .sym .var ("x" ))
403+ else :
404+ ref_res = mx .nd .squeeze (mx .nd .array (x_np ), axis = axis )
405+ mx_sym = mx .sym .squeeze (mx .sym .var ("x" ), axis = axis )
406+ new_sym , _ = relay .frontend .from_mxnet (mx_sym , {"x" : shape })
407+ for target , ctx in ctx_list ():
408+ for kind in ["graph" , "debug" ]:
409+ intrp = relay .create_executor (kind , ctx = ctx , target = target )
410+ op_res = intrp .evaluate (new_sym )(x_np )
411+ tvm .testing .assert_allclose (op_res .asnumpy (), ref_res .asnumpy ())
412+ verify ((1 , 3 , 1 ), None )
413+ verify ((1 , 3 , 1 ), 0 )
414+ verify ((1 , 3 , 1 ), 2 )
415+ verify ((1 , 3 , 1 ), (0 , 2 ))
416+
417+ def test_forward_broadcast_axis ():
418+ def verify (shape , axis , size ):
419+ x_np = np .random .uniform (size = shape ).astype ("float32" )
420+ ref_res = mx .nd .broadcast_axis (mx .nd .array (x_np ), axis = axis , size = size )
421+ mx_sym = mx .sym .broadcast_axis (mx .sym .var ("x" ), axis = axis , size = size )
422+ new_sym , _ = relay .frontend .from_mxnet (mx_sym , {"x" : shape })
423+ for target , ctx in ctx_list ():
424+ for kind in ["graph" , "debug" ]:
425+ intrp = relay .create_executor (kind , ctx = ctx , target = target )
426+ op_res = intrp .evaluate (new_sym )(x_np )
427+ tvm .testing .assert_allclose (op_res .asnumpy (), ref_res .asnumpy ())
428+ verify ((1 , 2 , 1 ), 2 , 3 )
429+ verify ((1 , 2 , 1 ), (0 , 2 ), (2 , 3 ))
430+
431+ def test_forward_full ():
432+ def verify (val , shape , dtype ):
433+ ctx = mx .cpu ()
434+ ref_res = mx .nd .full (shape , val , dtype = dtype )
435+ mx_sym = mx .sym .full (shape , val , dtype = dtype )
436+ new_sym , _ = relay .frontend .from_mxnet (mx_sym , {})
437+ for target , ctx in ctx_list ():
438+ # Skip testing graph runtime because this op will be optimized out
439+ # by constant folding.
440+ for kind in ["debug" ]:
441+ intrp = relay .create_executor (kind , ctx = ctx , target = target )
442+ op_res = intrp .evaluate (new_sym )()
443+ tvm .testing .assert_allclose (op_res .asnumpy (), ref_res .asnumpy ())
444+ verify (2 , (3 , 4 ), "float32" )
445+ verify (2 , (3 , 4 ), "int32" )
446+ verify (3.5 , (1 , 3 , 4 ), "float32" )
447+
448+ def test_forward_embedding ():
449+ def verify (data_shape , weight_shape ):
450+ in_dim , out_dim = weight_shape
451+ x_np = np .random .randint (0 , weight_shape [0 ], size = data_shape ).astype ("float32" )
452+ w_np = np .random .uniform (size = weight_shape ).astype ("float32" )
453+ ref_res = mx .nd .Embedding (mx .nd .array (x_np ), mx .nd .array (w_np ),
454+ input_dim = in_dim , output_dim = out_dim )
455+ mx_sym = mx .sym .Embedding (mx .sym .var ("x" ), mx .sym .var ("w" ),
456+ input_dim = in_dim , output_dim = out_dim )
457+ new_sym , _ = relay .frontend .from_mxnet (
458+ mx_sym , {"x" : data_shape , "w" : weight_shape })
459+ for target , ctx in ctx_list ():
460+ for kind in ["graph" , "debug" ]:
461+ intrp = relay .create_executor (kind , ctx = ctx , target = target )
462+ op_res = intrp .evaluate (new_sym )(x = x_np , w = w_np )
463+ tvm .testing .assert_allclose (op_res .asnumpy (), ref_res .asnumpy ())
464+ verify ((2 , 2 ), (4 , 5 ))
465+ verify ((2 , 3 , 4 ), (4 , 5 ))
398466
399467if __name__ == '__main__' :
400468 test_forward_mlp ()
@@ -426,3 +494,7 @@ def verify(shape):
426494 test_forward_slice_axis ()
427495 test_forward_l2_normalize ()
428496 test_forward_shape_array ()
497+ test_forward_squeeze ()
498+ test_forward_broadcast_axis ()
499+ test_forward_full ()
500+ test_forward_embedding ()
0 commit comments