
Commit 9e4b804

Revert "[relay][frontend] clean up tf frontend (apache#3710)"
This reverts commit be8fa6a.
1 parent 16bdd94 commit 9e4b804

File tree

3 files changed: +197 -47 lines changed


python/tvm/relay/frontend/common.py

Lines changed: 3 additions & 27 deletions
@@ -17,8 +17,6 @@
 """Common utilities"""
 from __future__ import absolute_import as _abs
 import logging
-
-import tvm
 from topi.util import get_const_tuple
 from .. import expr as _expr
 from .. import module as _module
@@ -226,7 +224,6 @@ def get_bool(self, key, default=RequiredAttr()):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default

-
 def get_relay_op(op_name):
     """Get the callable function from Relay based on operator name.
     Parameters
@@ -249,10 +246,9 @@ def get_relay_op(op_name):
             if op is not None:
                 break
     if not op:
-        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
+        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
     return op

-
 class ExprTable(object):
     """Table storing Relay expressions by names."""
     def __init__(self):
@@ -302,27 +298,21 @@ class AttrCvt(object):
         If set as str, returned operator name is the str.
         If set as callable, returned operator is the str returned by calling:
         `op_name = func(attr)`
-
     transforms : dict of `new_name, or (new_name, default_value, transform function)`
         If only a new_name is provided, it's like renaming the attribute name.
         If default_value if provided, then the attribute is considered as optional.
         If transform function is provided, the original attribute value is handled
         by transform function.
-
     excludes : list
         A list of excluded attributes that should `NOT` appear.
         Raise NotImplementedError if occurred.
-
     disables : list
         A list of attributes that is disabled in relay. Log warnings.
-
     ignores : list
         A list of attributes that is ignored in relay. Debug level logging.
-
     extras : dict
         A series of additional attributes should be added anyway to the returned
         attribute dict.
-
     custom_check : callable
         A custom function takes attribute, and return True/False.
         Raise RuntimeError if not bool(True) returned.
@@ -339,14 +329,6 @@ def __init__(self, op_name, transforms=None,
         self._custom_check = custom_check

     def __call__(self, inputs, attrs, *args):
-        self._ignores.append('_output_shapes')
-        self._ignores.append('_input_shapes')
-        self._ignores.append('T')
-        self._ignores.append('use_cudnn_on_gpu')
-        self._ignores.append('_node_name')
-        self._ignores.append('is_training')
-        self._ignores.append('_target_layout')
-
         # apply custom check
         if self._custom_check:
             func, msg = self._custom_check
@@ -366,8 +348,7 @@ def __call__(self, inputs, attrs, *args):
         new_attrs = {}
         for k in attrs.keys():
             if k in self._excludes:
-                raise NotImplementedError('Attribute %s in operator %s is not' +
-                                          ' supported.', k, op_name)
+                raise NotImplementedError("Attribute {} not supported yet.".format(k))
             elif k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
@@ -420,7 +401,6 @@ def _required_attr(self, attr, key):
             raise AttributeError("Required attribute {} not found.".format(key))
         return attr[key]

-
 def get_name(node):
     name = ''
     if hasattr(node, "name_hint"):
@@ -430,19 +410,17 @@ def get_name(node):

 def infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
-    mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
+    mod = _module.Module.from_expr(node)
     mod = _transform.InferType()(mod)
     entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body

-
 def infer_shape(inputs):
     """A method to get the output shape of an intermediate node in the graph."""
     out_type = infer_type(inputs)
     out_shapes = get_const_tuple(out_type.checked_type.shape)
     return out_shapes

-
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
@@ -452,14 +430,12 @@
     channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
     return channels

-
 def new_var(name_hint,
             type_annotation=None,
             shape=None,
             dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)

-
 class Renamer(object):
     """A simply renamer for operators.

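For context, a minimal sketch (not part of this commit) of how a frontend converter typically drives AttrCvt and get_relay_op from common.py; the converter function _convert_max_pool and the attribute names it handles are hypothetical examples:

# Minimal sketch, not from this commit: the converter function and the
# attribute names ('kernel_shape', 'unsupported_flag') are hypothetical.
from tvm.relay.frontend.common import AttrCvt, get_relay_op

def _convert_max_pool(inputs, attrs):
    # Rename framework attributes to Relay ones and drop bookkeeping ones;
    # AttrCvt then calls the Relay op resolved from op_name with the new attrs.
    return AttrCvt(
        op_name='max_pool2d',
        transforms={'kernel_shape': 'pool_size'},   # rename attribute
        ignores=['T', '_output_shapes'],            # framework bookkeeping
        excludes=['unsupported_flag'])(inputs, attrs)

# get_relay_op resolves an operator name to a Relay callable; after this
# revert an unknown name raises RuntimeError instead of tvm.error.OpNotImplemented.
relu_fn = get_relay_op('relu')   # resolves to tvm.relay.nn.relu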
python/tvm/relay/frontend/mxnet.py

Lines changed: 8 additions & 2 deletions
@@ -20,14 +20,13 @@

 import json
 import tvm
-from .. import analysis
+from .. import analysis, transform
 from .. import expr as _expr
 from .. import op as _op
 from .. import module as _module
 from ... import nd as _nd

 from .common import StrAttrsDict
-from .common import infer_type as _infer_type
 from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
 from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
 from .nnvm_common import _clip, _transpose, _upsampling
@@ -42,6 +41,13 @@
     "relu" : _op.nn.relu
 }

+def _infer_type(node):
+    """A method to infer the type of an intermediate node in the relay graph."""
+    mod = _module.Module.from_expr(node)
+    mod = transform.InferType()(mod)
+    entry = mod["main"]
+    return entry if isinstance(node, _expr.Function) else entry.body
+
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
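For context, a minimal sketch (not part of this commit) of how a converter in mxnet.py might use the module-local _infer_type restored above; the converter _mx_example is hypothetical and assumes a static input shape:

# Minimal sketch, not from this commit: _mx_example is a hypothetical converter
# that would live alongside the other _mx_* functions in mxnet.py.
def _mx_example(inputs, attrs):
    # Run InferType on a throwaway module and read the checked type of the input.
    checked = _infer_type(inputs[0]).checked_type
    in_shape = [int(dim) for dim in checked.shape]   # assumes static dimensions
    batch = in_shape[0] if in_shape else 1
    # Flatten everything except the batch dimension (illustrative only).
    return _op.reshape(inputs[0], newshape=(batch, -1))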
