Skip to content

Commit 613cb7e

Browse files
committed
hack to tensorize loop mapping to make conv2d work
1 parent 9e4f9df commit 613cb7e

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

src/tir/schedule/analysis/analysis.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2069,16 +2069,24 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
20692069
// For each block var binding, we find
20702070
const PrimExpr& block_bind = block->iter_values[i_block];
20712071
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
2072+
LOG(INFO) << "block bind: " << block_bind;
2073+
LOG(INFO) << "desc bind: " << desc_bind;
20722074
// Step 4.1. Find the corresponding loop of the i-th block var of block
20732075
const tir::ForNode* block_loop = nullptr;
20742076
for (int i = 0, n = block_loops.size(); i < n; ++i) {
20752077
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
20762078
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
2077-
if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
2079+
const auto* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
2080+
const auto* int_desc_extent = desc_loops[i_desc]->extent.as<IntImmNode>();
2081+
2082+
if ((i_desc == 0 && int_block_extent->value == int_desc_extent->value) || !tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
20782083
return block_loop_vars.count(var);
20792084
})) {
20802085
block_loop = block_loops[i];
2086+
LOG(INFO) << "Selected " << i << " th block loop " << block_loops[i]->loop_var << ", " << block_loop->extent;
20812087
break;
2088+
} else {
2089+
LOG(INFO) << i << " th block loop not ok " << ", " << block_loops[i]->loop_var << ", " << block_loops[i]->extent;
20822090
}
20832091
}
20842092
if (block_loop == nullptr) {
@@ -2093,6 +2101,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
20932101
return desc_loop_vars.count(var);
20942102
})) {
20952103
desc_loop = desc_loops[i];
2104+
LOG(INFO) << "Selected " << i << " th desc loop " << desc_loop->extent;;
20962105
break;
20972106
}
20982107
}

0 commit comments

Comments
 (0)