1414# KIND, either express or implied. See the License for the
1515# specific language governing permissions and limitations
1616# under the License.
17+
1718# pylint: disable=invalid-name,too-many-locals
19+
1820"""Utility functions for Relax"""
21+
1922import functools
2023import inspect
24+ import itertools
25+
2126from typing import Tuple as typing_Tuple
2227from typing import Any , Callable , List , Dict , Optional , TypeVar
2328
29+ import tvm
2430from .. import tir
2531from ..tir import PrimExpr
2632from ..runtime import String , convert_to_object
@@ -302,9 +308,22 @@ def gen_call_tir_inputs(
302308 out_sinfo, and tir_vars.
303309 """
304310
305- def _convert_te_arg (
306- te_args : Any , tir_var_map : Dict [tir .Var , tir .PrimExpr ]
307- ) -> typing_Tuple [Any , List [te_Tensor ]]:
311+ tir_var_map : Dict [tir .Var , tir .PrimExpr ] = {}
312+
313+ call_tir_args = []
314+ # extra list of tir expression arguments
315+ # that are not covered by Tensor
316+ extra_tir_args_list = []
317+
318+ def _copy_undefined_var (expr : tir .PrimExpr ):
319+ def _visit_expr (e : tir .PrimExpr ):
320+ if isinstance (e , tir .Var ) and e not in tir_var_map :
321+ new_var = tir .Var (e .name , e .dtype )
322+ tir_var_map [e ] = new_var
323+
324+ tir .stmt_functor .post_order_visit (expr , _visit_expr )
325+
326+ def _convert_te_arg (te_args : Any ) -> Any :
308327 """Helper function used to convert Relax expressions to TE tensor.
309328
310329 In the common case, the type of te_args is a Relax expression and is converted
@@ -335,18 +354,6 @@ def _convert_te_arg(
335354 A tuple of the converted te_args, and a list of te tensors for each converted
336355 Relax expression
337356 """
338- te_args_list = []
339- # extra list of tir expression arguments
340- # that are not covered by Tensor
341- extra_tir_args_list = []
342-
343- def _copy_undefined_var (expr : tir .PrimExpr ):
344- def _visit_expr (e : tir .PrimExpr ):
345- if isinstance (e , tir .Var ) and e not in tir_var_map :
346- new_var = tir .Var (e .name , e .dtype )
347- tir_var_map [e ] = new_var
348-
349- tir .stmt_functor .post_order_visit (expr , _visit_expr )
350357
351358 n_tensor = 0
352359
@@ -363,18 +370,23 @@ def _convert_te_arg_helper(arg):
363370 name = chr (ord ("A" ) + n_tensor ) if n_tensor < 26 else f"input{ n_tensor } "
364371 arg = te_tensor (arg , tir_var_map , name )
365372 n_tensor += 1
366- te_args_list .append (arg )
373+ call_tir_args .append (arg )
367374 return arg
368375 if isinstance (arg .struct_info , ShapeStructInfo ):
369376 assert isinstance (
370377 arg , ShapeExpr
371378 ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
372379 return [_convert_te_arg_helper (val ) for val in arg .values ]
373- if (
374- isinstance (arg .struct_info , PrimStructInfo )
375- and arg .struct_info .value is not None
376- ):
377- return _convert_te_arg_helper (arg .struct_info .value )
380+ if isinstance (arg .struct_info , PrimStructInfo ):
381+ if arg .struct_info .value is None :
382+ name = arg .name_hint if isinstance (arg , tvm .relax .Var ) else "prim_arg"
383+ call_tir_args .append (arg )
384+ return tir .Var (name , arg .struct_info .dtype )
385+ # call_tir_args.append(tir.Var(name, arg.struct_info.dtype))
386+ # return arg
387+ else :
388+ return _convert_te_arg_helper (arg .struct_info .value )
389+
378390 elif isinstance (arg , (list , Array )):
379391 return [_convert_te_arg_helper (x ) for x in arg ]
380392 elif isinstance (arg , tuple ):
@@ -388,35 +400,43 @@ def _convert_te_arg_helper(arg):
388400 elif isinstance (arg , tir .PrimExpr ):
389401 _copy_undefined_var (arg )
390402 new_arg = tir .stmt_functor .substitute (arg , tir_var_map )
391- extra_tir_args_list .append (new_arg )
403+ extra_tir_args_list .append (arg )
392404 return new_arg
393405 elif isinstance (arg , (int , float , str , Type , Attrs )) or arg is None :
394406 return arg
395407 raise TypeError ("not supported type in emit_te: {}" .format (type (arg )))
396408
397409 new_arg = _convert_te_arg_helper (te_args )
398- return new_arg , te_args_list , extra_tir_args_list
410+ return new_arg
399411
400412 def _get_unbound_tir_vars (
401413 args : List [te_Tensor ], extra_tir_args : List [PrimExpr ]
402414 ) -> List [tir .Var ]:
403415 """get unbound TIR vars (i.e TIR vars used in the shape but is not
404416 itself a dimension of a shape)"""
417+
405418 bound_vars = set ()
406419 used_vars = set ()
407420
421+ def _populate_bound_vars (expr ):
422+ if isinstance (expr , te_Tensor ):
423+ for dim in expr .shape :
424+ _populate_bound_vars (dim )
425+ elif isinstance (expr , tir .Var ):
426+ bound_vars .add (expr )
427+
408428 def _populate_used_vars (expr ):
409- if isinstance (expr , tir .Var ):
410- used_vars .add (expr )
429+ if isinstance (expr , te_Tensor ):
430+ for dim in expr .shape :
431+ _populate_used_vars (dim )
432+ elif isinstance (expr , tir .PrimExpr ):
433+ used_vars .update (tir .analysis .undefined_vars (expr ))
411434
412- for val in extra_tir_args :
413- tir . stmt_functor . post_order_visit ( val , _populate_used_vars )
435+ for arg in itertools . chain ( args , extra_tir_args ) :
436+ _populate_used_vars ( arg )
414437
415- for x in args :
416- for s in x .shape :
417- tir .stmt_functor .post_order_visit (s , _populate_used_vars )
418- if isinstance (s , tir .Var ):
419- bound_vars .add (s )
438+ for arg in args :
439+ _populate_bound_vars (arg )
420440
421441 diff = used_vars - bound_vars
422442 return list (diff )
@@ -448,19 +468,16 @@ def _shape_with_old_tir_var(
448468
449469 primfunc_attrs = kwargs .pop ("primfunc_attrs" , None )
450470
451- tir_var_map : Dict [tir .Var , tir .PrimExpr ] = {}
452- new_args , te_arg_list , tir_arg_list = _convert_te_arg (args , tir_var_map )
453- new_kwargs , te_kwarg_list , tir_kwarg_list = _convert_te_arg (kwargs , tir_var_map )
454-
455- te_args = te_arg_list + te_kwarg_list
471+ te_args = _convert_te_arg (args )
472+ te_kwargs = _convert_te_arg (kwargs )
456473
457- te_out = func (* new_args , ** new_kwargs )
474+ te_out = func (* te_args , ** te_kwargs )
458475 assert isinstance (te_out , te_Tensor ) or (
459476 isinstance (te_out , (tuple , list , Array )) and all (isinstance (t , te_Tensor ) for t in te_out )
460477 ), "only support te.tensor or tuple/list/Array of te.tensor as function output"
461478
462479 outs = [te_out ] if isinstance (te_out , te_Tensor ) else list (te_out )
463- unbound_tir_vars = _get_unbound_tir_vars (te_args + outs , tir_arg_list + tir_kwarg_list )
480+ unbound_tir_vars = _get_unbound_tir_vars ([ * te_args , * outs ], extra_tir_args_list )
464481
465482 inputs = [* te_args ] + outs + unbound_tir_vars
466483 tir_func = create_prim_func (inputs , "int64" )
@@ -470,7 +487,7 @@ def _shape_with_old_tir_var(
470487
471488 tir_func = tir_func .without_attr ("global_symbol" )
472489
473- call_tir_args = [x .op .value for x in te_args ]
490+ call_tir_args = [arg .op .value if isinstance ( arg , te_Tensor ) else arg for arg in call_tir_args ]
474491
475492 # Invert the TIR variable mapping, to convert the output shape back
476493 # with old set of variables.
0 commit comments