@@ -554,14 +554,12 @@ def transform_Assign(self, node):
554554 4.1 var = T.allocate()
555555 """
556556
557- print ("parsing " , node .rhs .func_name )
558557 if isinstance (node .rhs , ast .Call ):
559558 # Pattern 1 & Pattern 4
560559 if isinstance (node .rhs .func_name , ast .Op ):
561560 func = None
562561 else :
563562 func = self .transform (node .rhs .func_name )
564- print (func )
565563
566564 if isinstance (func , WithScopeHandler ):
567565 if not func .concise_scope or not func .def_symbol :
@@ -582,27 +580,12 @@ def transform_Assign(self, node):
582580 elif callable (func ):
583581 args = [self .transform (arg ) for arg in node .rhs .params ]
584582 out = func (* args )
585- print (out )
586- print (node .lhs )
587583 assert len (out ) == len (node .lhs )
588584
589- lhs_vars = []
590585 for ast_var , value in zip (node .lhs , out ):
591- var = tvm .te .var (
592- ast_var .id .name ,
593- "int32" ,
594- span = tvm_span_from_synr (ast_var .span ),
595- )
596- self .context .update_symbol (var .name , var , node )
597- lhs_vars .append (var )
598-
599- body = self .parse_body (node )
586+ self .context .update_symbol (ast_var .id .name , value , node )
600587
601- for var , value in reversed (list (zip (lhs_vars , out ))):
602- self .context .remove_symbol (var .name )
603- body = tvm .tir .LetStmt (var , value , body , span = tvm_span_from_synr (node .span ))
604-
605- return body
588+ return self .parse_body (node )
606589
607590 if isinstance (node .rhs , (ast .Call , ast .Constant )):
608591 # Pattern 4 of let binding
0 commit comments