@@ -64,6 +64,13 @@ def _stable_softrelu(x):
6464 raise RuntimeError ("Do not support act_type: {}" .format (act_type ))
6565
6666
67+ def _mx_compare (new_op , wrapper ):
68+ def impl (inputs , attrs ):
69+ dtype = ir_pass .infer_type (inputs [0 ]).checked_type .dtype
70+ return wrapper (new_op )(inputs , attrs ).astype (dtype )
71+ return impl
72+
73+
6774def _mx_conv2d (inputs , attrs ):
6875 kernel_size = attrs .get_int_tuple ("kernel" )
6976 if len (kernel_size ) != 2 :
@@ -333,32 +340,52 @@ def _mx_roi_align(inputs, attrs):
333340]
334341
335342_convert_map = {
336- "_copy" : _rename (_op .copy ),
337- "relu" : _rename (_op .nn .relu ),
338- "broadcast_add" : _rename (_op .add ),
339- "broadcast_sub" : _rename (_op .subtract ),
340- "broadcast_mul" : _rename (_op .multiply ),
341- "broadcast_div" : _rename (_op .divide ),
342- "elemwise_add" : _rename (_op .add ),
343- "elemwise_sub" : _rename (_op .subtract ),
344- "elemwise_mul" : _rename (_op .multiply ),
345- "elemwise_div" : _rename (_op .divide ),
346- "flatten" : _rename (_op .nn .batch_flatten ),
347- "Flatten" : _rename (_op .nn .batch_flatten ),
348- "_plus_scalar" : _binop_scalar (_op .add ),
349- "__add_scalar__" : _binop_scalar (_op .add ),
350- "__sub_scalar__" : _binop_scalar (_op .subtract ),
351- "_minus_scalar" : _binop_scalar (_op .subtract ),
352- "__mul_scalar__" : _binop_scalar (_op .multiply ),
353- "_mul_scalar" : _binop_scalar (_op .multiply ),
354- "__div_scalar__" : _binop_scalar (_op .divide ),
355- "_div_scalar" : _binop_scalar (_op .divide ),
356- "__pow_scalar__" : _binop_scalar (_op .power ),
357- "_rminus_scalar" : _rbinop_scalar (_op .subtract ),
358- "__rsub_scalar__" : _rbinop_scalar (_op .subtract ),
359- "_rdiv_scalar" : _rbinop_scalar (_op .divide ),
360- "__rdiv_scalar__" : _rbinop_scalar (_op .divide ),
361- "__rpow_scalar__" : _rbinop_scalar (_op .power ),
343+ "_copy" : _rename (_op .copy ),
344+ "relu" : _rename (_op .nn .relu ),
345+ "broadcast_add" : _rename (_op .add ),
346+ "broadcast_sub" : _rename (_op .subtract ),
347+ "broadcast_mul" : _rename (_op .multiply ),
348+ "broadcast_div" : _rename (_op .divide ),
349+ "broadcast_mod" : _rename (_op .mod ),
350+ "broadcast_maximum" : _rename (_op .maximum ),
351+ "broadcast_minimum" : _rename (_op .minimum ),
352+ "broadcast_equal" : _mx_compare (_op .equal , _rename ),
353+ "broadcast_not_equal" : _mx_compare (_op .not_equal , _rename ),
354+ "broadcast_greater" : _mx_compare (_op .greater , _rename ),
355+ "broadcast_greater_equal" : _mx_compare (_op .greater_equal , _rename ),
356+ "broadcast_lesser" : _mx_compare (_op .less , _rename ),
357+ "broadcast_lesser_equal" : _mx_compare (_op .less_equal , _rename ),
358+ "elemwise_add" : _rename (_op .add ),
359+ "elemwise_sub" : _rename (_op .subtract ),
360+ "elemwise_mul" : _rename (_op .multiply ),
361+ "elemwise_div" : _rename (_op .divide ),
362+ "_maximum" : _rename (_op .maximum ),
363+ "_minimum" : _rename (_op .minimum ),
364+ "flatten" : _rename (_op .nn .batch_flatten ),
365+ "Flatten" : _rename (_op .nn .batch_flatten ),
366+ "__add_scalar__" : _binop_scalar (_op .add ),
367+ "_plus_scalar" : _binop_scalar (_op .add ),
368+ "__sub_scalar__" : _binop_scalar (_op .subtract ),
369+ "_minus_scalar" : _binop_scalar (_op .subtract ),
370+ "__mul_scalar__" : _binop_scalar (_op .multiply ),
371+ "_mul_scalar" : _binop_scalar (_op .multiply ),
372+ "__div_scalar__" : _binop_scalar (_op .divide ),
373+ "_div_scalar" : _binop_scalar (_op .divide ),
374+ "__pow_scalar__" : _binop_scalar (_op .power ),
375+ "_power_scalar" : _binop_scalar (_op .power ),
376+ "__rsub_scalar__" : _rbinop_scalar (_op .subtract ),
377+ "_rminus_scalar" : _rbinop_scalar (_op .subtract ),
378+ "__rdiv_scalar__" : _rbinop_scalar (_op .divide ),
379+ "_rdiv_scalar" : _rbinop_scalar (_op .divide ),
380+ "__rpow_scalar__" : _rbinop_scalar (_op .power ),
381+ "_equal_scalar" : _mx_compare (_op .equal , _binop_scalar ),
382+ "_not_equal_scalar" : _mx_compare (_op .not_equal , _binop_scalar ),
383+ "_greater_scalar" : _mx_compare (_op .greater , _binop_scalar ),
384+ "_greater_equal_scalar" : _mx_compare (_op .greater_equal , _binop_scalar ),
385+ "_lesser_scalar" : _mx_compare (_op .less , _binop_scalar ),
386+ "_lesser_equal_scalar" : _mx_compare (_op .less_equal , _binop_scalar ),
387+ "_maximum_scalar" : _binop_scalar (_op .maximum ),
388+ "_minimum_scalar" : _binop_scalar (_op .minimum ),
362389 # reduction ops
363390 "max" : _reduce (_op .max ),
364391 "min" : _reduce (_op .min ),
0 commit comments