Skip to content

Commit 4344d4a

Browse files
srkreddy1238tqchen
authored andcommitted
[FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models. (#2864)
* [FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models. * * review comments
1 parent 53e84f0 commit 4344d4a

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -543,25 +543,23 @@ def _impl(inputs, attr, params):
543543
op_name="reshape",
544544
extras={'newshape':tuple(shape_arg.asnumpy())},
545545
ignores=['Tshape'])(inputs, attr)
546-
except KeyError:
546+
except AttributeError:
547547
# Shape operator is already pruned, hence
548548
# try to infer shape by precompute prune if possible.
549-
if all(in_node in params for in_node in inputs[1].list_input_names()):
550-
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
551-
with tvm.relay.build_config(opt_level=0):
552-
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
553-
ctx = tvm.context("llvm", 0)
554-
from tvm.contrib import graph_runtime
555-
m = graph_runtime.create(graph, lib, ctx)
556-
m.set_input(**params)
557-
m.run()
558-
params_new = m.get_output(0)
559-
inputs.pop(1)
560-
return AttrCvt(
561-
op_name="reshape",
562-
extras={'newshape':tuple(params_new.asnumpy().flatten())},
563-
ignores=['Tshape'])(inputs, attr)
564-
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
549+
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
550+
with tvm.relay.build_config(opt_level=0):
551+
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
552+
ctx = tvm.context("llvm", 0)
553+
from tvm.contrib import graph_runtime
554+
m = graph_runtime.create(graph, lib, ctx)
555+
m.set_input(**params)
556+
m.run()
557+
params_new = m.get_output(0)
558+
inputs.pop(1)
559+
return AttrCvt(
560+
op_name="reshape",
561+
extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
562+
ignores=['Tshape'])(inputs, attr)
565563
return _impl
566564

567565
def _bias_add():

0 commit comments

Comments
 (0)