@@ -99,7 +99,6 @@ def __call__(self, inputs, attrs, *args):
9999        self ._ignores .append ('_node_name' )
100100        self ._ignores .append ('is_training' )
101101        self ._ignores .append ('_target_layout' )
102-         self ._ignores .append ('_input_0d_mismatch' )
103102
104103        # apply custom check 
105104        if  self ._custom_check :
@@ -458,9 +457,9 @@ def _impl(inputs, attr, params):
458457def  _expand_dims ():
459458    def  _impl (inputs , attr , params ):
460459        dim_input  =  inputs .pop (1 )
461-         axis  =  params [ dim_input . name_hint ]
462-         params . pop ( dim_input . name_hint ) 
463-         return   _expand_dims_0d_aware ( inputs [ 0 ],  attr ,  axis = axis . asnumpy ()[ 0 ] )
460+         axis  =  params . pop ( _get_name_hint ( dim_input )). asnumpy ()[ 0 ]
461+         return   AttrCvt ( op_name = "expand_dims" ,  ignores = [ 'Tdim' ,  'N' ], 
462+                         extras = { ' axis' :  axis ,  'num_newaxis' :  1 })( inputs ,  attr )
464463    return  _impl 
465464
466465def  _resize_bilinear ():
@@ -528,7 +527,7 @@ def _impl(inputs, attr, params):
528527def  _pack ():
529528    def  _impl (inputs , attr , params ):
530529        axis  =  int (attr ["axis" ])
531-         inputs_reshaped  =  [_expand_dims_0d_aware ( i ,  attr , axis = axis , num_newaxis = 1 ) for  i  in  inputs ]
530+         inputs_reshaped  =  [_op . expand_dims ( i , axis = axis , num_newaxis = 1 ) for  i  in  inputs ]
532531        return  _op .concatenate (inputs_reshaped , axis )
533532    return  _impl 
534533
@@ -820,9 +819,9 @@ def _transform_mask(stride_dim, ellipsis_mask):
820819                pass 
821820            else :
822821                final_output .append (out_shape [gather_index ])
823-          # Prevent 0-dim tensors which are not accepted by Relay 
822+ 
824823        if  not  final_output :
825-             final_output . append ( 1 ) 
824+             return   out 
826825        return  _op .reshape (out , newshape = tuple (final_output ))
827826    return  _impl 
828827
@@ -984,16 +983,6 @@ def _impl(inputs, attr, params):
984983            for  split_item  in  splitted ]), len (splitted ))
985984    return  _impl 
986985
987- def  _expand_dims_0d_aware (data , attr , axis , num_newaxis = 1 ):
988-     if  data  in  attr ['_input_0d_mismatch' ]:
989-         return  data  if  num_newaxis  ==  1  else  \
990-             AttrCvt (op_name = "expand_dims" , ignores = ['Tdim' , 'N' ],
991-                     extras = {'axis' : int (axis ), 'num_newaxis' : int (num_newaxis - 1 )})([data ], attr )
992- 
993-     return  AttrCvt (op_name = "expand_dims" , ignores = ['Tdim' , 'N' ],
994-                    extras = {'axis' : int (axis ), 'num_newaxis' : int (num_newaxis )})([data ], attr )
995- 
996- 
997986def  _softmax ():
998987    def  _impl (inputs , attr , params ):
999988        return  AttrCvt (op_name = 'softmax' ,
@@ -1647,7 +1636,6 @@ def __init__(self):
16471636        self ._output_shapes  =  {}
16481637        self ._num_param  =  0 
16491638        self ._num_rnn_layer  =  False 
1650-         self ._outputs_are_0d  =  {}
16511639        self ._input_shapes  =  {}
16521640        self ._loops  =  {}
16531641        self ._branches  =  {}
@@ -1737,7 +1725,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
17371725            # Operator name 'Const' is treated as a parameter to build params dict. 
17381726
17391727            input_shapes  =  {}
1740-             input_0d_mismatch  =  set ()
17411728            attr  =  self ._parse_attr (node .attr )
17421729
17431730            # Variable converted to Const will not have only value attr 
@@ -1753,10 +1740,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
17531740                # Will infer shapes if the graph is not frozen with add_shapes=True 
17541741                self ._output_shapes [node .name ] =  [None ]
17551742
1756-             self ._outputs_are_0d [node .name ] =  [ \
1757-                 not  shape  if  isinstance (tshape , list ) else  False  \
1758-                 for  tshape  in  self ._output_shapes [node .name ]]
1759- 
17601743            if  node .op  ==  "Const" :
17611744                # All Const nodes are Param nodes, lets parse 
17621745                self ._num_param  +=  1 
@@ -1810,14 +1793,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
18101793                            input_shape  =  self ._output_shapes [node_name ][0 ]
18111794                        inputs .append (in_sym [0 ])
18121795                        input_shapes [in_sym [0 ]] =  input_shape 
1813-                         # This means the node is 1d in Relay and 0d in TF. 
1814-                         # See `_expand_dims_0d_aware`. 
1815-                         if  node_name  in  self ._outputs_are_0d  \
1816-                                 and  self ._outputs_are_0d [node_name ][tensor_slot ] and  input_shape :
1817-                             input_0d_mismatch .add (in_sym [0 ])
18181796
18191797                attr ['_input_shapes' ] =  input_shapes 
1820-                 attr ['_input_0d_mismatch' ] =  input_0d_mismatch 
18211798
18221799                if  node .op  in  _control_flow_nodes :
18231800                    op  =  self ._convert_control_flow_operator (node , inputs ,
0 commit comments