1717"""Common utilities"""
1818from __future__ import absolute_import as _abs
1919import logging
20-
21- import tvm
2220from topi .util import get_const_tuple
2321from .. import expr as _expr
2422from .. import module as _module
@@ -226,7 +224,6 @@ def get_bool(self, key, default=RequiredAttr()):
226224 raise AttributeError ("Required attribute {} not found." .format (key ))
227225 return default
228226
229-
230227def get_relay_op (op_name ):
231228 """Get the callable function from Relay based on operator name.
232229 Parameters
@@ -249,10 +246,9 @@ def get_relay_op(op_name):
249246 if op is not None :
250247 break
251248 if not op :
252- raise tvm . error . OpNotImplemented ("Unable to map op_name {} to relay" .format (op_name ))
249+ raise RuntimeError ("Unable to map op_name {} to relay" .format (op_name ))
253250 return op
254251
255-
256252class ExprTable (object ):
257253 """Table storing Relay expressions by names."""
258254 def __init__ (self ):
@@ -302,27 +298,21 @@ class AttrCvt(object):
302298 If set as str, returned operator name is the str.
303299 If set as callable, returned operator is the str returned by calling:
304300 `op_name = func(attr)`
305-
306301 transforms : dict of `new_name, or (new_name, default_value, transform function)`
307302 If only a new_name is provided, it's like renaming the attribute name.
308303 If default_value if provided, then the attribute is considered as optional.
309304 If transform function is provided, the original attribute value is handled
310305 by transform function.
311-
312306 excludes : list
313307 A list of excluded attributes that should `NOT` appear.
314308 Raise NotImplementedError if occurred.
315-
316309 disables : list
317310 A list of attributes that is disabled in relay. Log warnings.
318-
319311 ignores : list
320312 A list of attributes that is ignored in relay. Debug level logging.
321-
322313 extras : dict
323314 A series of additional attributes should be added anyway to the returned
324315 attribute dict.
325-
326316 custom_check : callable
327317 A custom function takes attribute, and return True/False.
328318 Raise RuntimeError if not bool(True) returned.
@@ -339,14 +329,6 @@ def __init__(self, op_name, transforms=None,
339329 self ._custom_check = custom_check
340330
341331 def __call__ (self , inputs , attrs , * args ):
342- self ._ignores .append ('_output_shapes' )
343- self ._ignores .append ('_input_shapes' )
344- self ._ignores .append ('T' )
345- self ._ignores .append ('use_cudnn_on_gpu' )
346- self ._ignores .append ('_node_name' )
347- self ._ignores .append ('is_training' )
348- self ._ignores .append ('_target_layout' )
349-
350332 # apply custom check
351333 if self ._custom_check :
352334 func , msg = self ._custom_check
@@ -366,8 +348,7 @@ def __call__(self, inputs, attrs, *args):
366348 new_attrs = {}
367349 for k in attrs .keys ():
368350 if k in self ._excludes :
369- raise NotImplementedError ('Attribute %s in operator %s is not' +
370- ' supported.' , k , op_name )
351+ raise NotImplementedError ("Attribute {} not supported yet." .format (k ))
371352 elif k in self ._disables :
372353 logging .warning ("Attribute %s is disabled in relay.sym.%s" , k , op_name )
373354 elif k in self ._ignores :
@@ -420,7 +401,6 @@ def _required_attr(self, attr, key):
420401 raise AttributeError ("Required attribute {} not found." .format (key ))
421402 return attr [key ]
422403
423-
424404def get_name (node ):
425405 name = ''
426406 if hasattr (node , "name_hint" ):
@@ -430,19 +410,17 @@ def get_name(node):
430410
431411def infer_type (node ):
432412 """A method to infer the type of an intermediate node in the relay graph."""
433- mod = node if isinstance ( node , _module . Module ) else _module .Module .from_expr (node )
413+ mod = _module .Module .from_expr (node )
434414 mod = _transform .InferType ()(mod )
435415 entry = mod ["main" ]
436416 return entry if isinstance (node , _expr .Function ) else entry .body
437417
438-
439418def infer_shape (inputs ):
440419 """A method to get the output shape of an intermediate node in the graph."""
441420 out_type = infer_type (inputs )
442421 out_shapes = get_const_tuple (out_type .checked_type .shape )
443422 return out_shapes
444423
445-
446424def infer_channels (inputs , transpose = False ):
447425 """A hack for getting 'channels' or 'units' since caffe2 does not provide
448426 these attributes. We check the shape of weights provided to get the number.
@@ -452,14 +430,12 @@ def infer_channels(inputs, transpose=False):
452430 channels = out_shapes [0 ][0 ] if not transpose else out_shapes [0 ][1 ]
453431 return channels
454432
455-
456433def new_var (name_hint ,
457434 type_annotation = None ,
458435 shape = None ,
459436 dtype = "float32" ):
460437 return _expr .var (name_hint , type_annotation , shape , dtype )
461438
462-
463439class Renamer (object ):
464440 """A simply renamer for operators.
465441
0 commit comments