@@ -543,25 +543,23 @@ def _impl(inputs, attr, params):
                 op_name="reshape",
                 extras={'newshape': tuple(shape_arg.asnumpy())},
                 ignores=['Tshape'])(inputs, attr)
-        except KeyError:
+        except AttributeError:
             # Shape operator is already pruned, hence
             # try to infer shape by precompute prune if possible.
-            if all(in_node in params for in_node in inputs[1].list_input_names()):
-                func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
-                with tvm.relay.build_config(opt_level=0):
-                    graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-                ctx = tvm.context("llvm", 0)
-                from tvm.contrib import graph_runtime
-                m = graph_runtime.create(graph, lib, ctx)
-                m.set_input(**params)
-                m.run()
-                params_new = m.get_output(0)
-                inputs.pop(1)
-                return AttrCvt(
-                    op_name="reshape",
-                    extras={'newshape': tuple(params_new.asnumpy().flatten())},
-                    ignores=['Tshape'])(inputs, attr)
-            raise RuntimeError("Reshape with dynamic shape input not supported yet.")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            with tvm.relay.build_config(opt_level=0):
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            return AttrCvt(
+                op_name="reshape",
+                extras={'newshape': tuple(params_new.asnumpy().astype('int64').flatten())},
+                ignores=['Tshape'])(inputs, attr)
     return _impl
 
 def _bias_add():
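
For reference, the precompute path in the new branch can be read as a standalone sketch. The helper name eval_static_shape below is hypothetical, and the sketch assumes the same older Relay API that the hunk itself uses (relay.ir_pass, relay.build_config, tvm.contrib.graph_runtime): it compiles the pruned shape subexpression at opt_level=0, runs it on CPU, and casts the output to int64 so the reshape target becomes a tuple of plain integers.

# Hypothetical standalone helper mirroring the precompute path above.
# Assumes the pre-0.6 TVM Relay API used in this diff (relay.ir_pass,
# relay.build_config, tvm.contrib.graph_runtime), not the current tvm.relay.
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

def eval_static_shape(shape_expr, params):
    """Evaluate a constant Relay shape subexpression to a tuple of ints."""
    # Wrap the shape subgraph in a function over its free variables.
    func = relay.Function(relay.ir_pass.free_vars(shape_expr), shape_expr)
    # Compile without optimizations; the subgraph is tiny and constant.
    with relay.build_config(opt_level=0):
        graph, lib, params = relay.build(func, target="llvm", params=params)
    mod = graph_runtime.create(graph, lib, tvm.context("llvm", 0))
    mod.set_input(**params)
    mod.run()
    # Cast to int64 so 'newshape' is a tuple of integer dimensions.
    return tuple(mod.get_output(0).asnumpy().astype('int64').flatten())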