@@ -2069,16 +2069,24 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
20692069 // For each block var binding, we find
20702070 const PrimExpr& block_bind = block->iter_values [i_block];
20712071 const PrimExpr& desc_bind = desc_block->iter_values [i_desc];
2072+ LOG (INFO) << " block bind: " << block_bind;
2073+ LOG (INFO) << " desc bind: " << desc_bind;
20722074 // Step 4.1. Find the corresponding loop of the i-th block var of block
20732075 const tir::ForNode* block_loop = nullptr ;
20742076 for (int i = 0 , n = block_loops.size (); i < n; ++i) {
20752077 // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
20762078 PrimExpr r = analyzer.Simplify (block_bind - block_loops[i]->loop_var );
2077- if (!tir::UsesVar (r, [&block_loop_vars](const tir::VarNode* var) {
2079+ const auto * int_block_extent = block_loops[i]->extent .as <IntImmNode>();
2080+ const auto * int_desc_extent = desc_loops[i_desc]->extent .as <IntImmNode>();
2081+
2082+ if ((i_desc == 0 && int_block_extent->value == int_desc_extent->value ) || !tir::UsesVar (r, [&block_loop_vars](const tir::VarNode* var) {
20782083 return block_loop_vars.count (var);
20792084 })) {
20802085 block_loop = block_loops[i];
2086+ LOG (INFO) << " Selected " << i << " th block loop " << block_loops[i]->loop_var << " , " << block_loop->extent ;
20812087 break ;
2088+ } else {
2089+ LOG (INFO) << i << " th block loop not ok " << " , " << block_loops[i]->loop_var << " , " << block_loops[i]->extent ;
20822090 }
20832091 }
20842092 if (block_loop == nullptr ) {
@@ -2093,6 +2101,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
20932101 return desc_loop_vars.count (var);
20942102 })) {
20952103 desc_loop = desc_loops[i];
2104+ LOG (INFO) << " Selected " << i << " th desc loop " << desc_loop->extent ;;
20962105 break ;
20972106 }
20982107 }
0 commit comments