@@ -78,8 +78,7 @@ def _visit(op):
7878 if not fail [0 ]:
7979 begin = tvm .call_extern (
8080 "int32" , "VTAUopLoopBegin" , stmt .extent , * gemm_offsets )
81- end = tvm .call_extern (
82- "int32" , "VTAUopLoopEnd" , stmt .extent , * gemm_offsets )
81+ end = tvm .call_extern ("int32" , "VTAUopLoopEnd" )
8382 return [begin , ret , end ]
8483 raise ValueError ("Failed to fold the GEMM instructions.." )
8584
@@ -683,8 +682,14 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
683682 else :
684683 raise RuntimeError (
685684 "Function call not recognized %s" % (loop_body .value .name ))
685+ elif isinstance (loop_body .value , tvm .expr .Load ):
686+ alu_opcode = env .dev .ALU_OPCODE_SHR
687+ lhs = loop_body .value
688+ rhs = tvm .const (0 )
686689 else :
687- raise RuntimeError ("Expression not recognized %s" % (type (loop_body .value )))
690+ raise RuntimeError (
691+ "Expression not recognized %s, %s, %s" % (
692+ type (loop_body .value ), str (loop_body .value ), str (stmt )))
688693
689694 # Derive array index coefficients
690695 dst_coeff = tvm .arith .DetectLinearEquation (dst_idx , indices )
@@ -772,7 +777,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
772777 irb = tvm .ir_builder .create ()
773778 for idx , extent in enumerate (extents ):
774779 irb .emit (tvm .call_extern (
775- "int32" , "VTAUopLoopBegin" , extent , dst_coeff [idx ], src_coeff [idx ]))
780+ "int32" , "VTAUopLoopBegin" ,
781+ extent , dst_coeff [idx ], src_coeff [idx ], 0 ))
782+ use_imm = int (use_imm )
776783 irb .emit (tvm .call_extern (
777784 "int32" , "VTAUopPush" ,
778785 1 , 0 ,
@@ -804,5 +811,6 @@ def debug_print(stmt):
804811 stmt : Stmt
805812 The
806813 """
814+ # pylint: disable=superfluous-parens
807815 print (stmt )
808816 return stmt
0 commit comments