Skip to content

Commit b1ab4fe

Browse files
committed
check iter type
1 parent b32dd62 commit b1ab4fe

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

src/tir/schedule/analysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ class AutoTensorizeMappingInfoNode : public Object {
726726

727727
void VisitAttrs(AttrVisitor* v) {
728728
v->Visit("mappings", &mappings);
729+
v->Visit("lhs_buffer_map", &lhs_buffer_map);
729730
v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
730731
v->Visit("lhs_iters", &lhs_iters);
731732
v->Visit("rhs_iters", &rhs_iters);

src/tir/schedule/analysis/analysis.cc

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2288,7 +2288,8 @@ class AutoTensorizeMappingProposer {
22882288
// Collect the set of potential iter var mapping between the workload and the tensor intrin.
22892289
// We analyze the appearance of each variable in the buffer indices of each buffer on LHS and
22902290
// RHS. The appearance of a variable in the buffer indices is encoded as bit-masks (BufferMask).
2291-
// Variables on the LHS and the RHS with the same bit-mask are potential mappings.
2291+
// Variables on the LHS and the RHS with the same bit-mask and the same iter type are potential
2292+
// mappings.
22922293
//
22932294
// For example, consider the conv2d case. We will try to match the workload
22942295
// conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c]
@@ -2302,7 +2303,7 @@ class AutoTensorizeMappingProposer {
23022303
// both buffer conv2d and W, and not in other buffers. Therefore, {n, h, w} <=> m is a potential
23032304
// mapping.
23042305

2305-
// Note: the mapping is not unique when multiple variables in RHS has the same bit-mask.
2306+
// Note: the mapping is not unique when multiple variables on RHS has the same bit-mask.
23062307
// This is currently not supported.
23072308

23082309
using BufferMask = std::vector<bool>;
@@ -2358,16 +2359,25 @@ class AutoTensorizeMappingProposer {
23582359
}
23592360
}
23602361

2361-
// Step 3: Find variables on LHS and RHS with the same buffer mask
2362+
// Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure LHS and RHS vars
2363+
// have the same iter type.
23622364
std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars;
23632365
for (const auto& kv : rhs_buffer_masks) {
23642366
const VarNode* rhs_var = kv.first;
23652367
const BufferMask& mask = kv.second;
23662368
mask_to_rhs_vars[mask].insert(GetRef<Var>(rhs_var));
23672369
}
2368-
2370+
std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type;
2371+
for (const auto& iter : extractor_->rhs_iters_) {
2372+
rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type);
2373+
}
23692374
for (const auto& iter : extractor_->lhs_iters_) {
2370-
lhs_feasible_vars_[iter->var] = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
2375+
auto& potential_mappings = lhs_feasible_vars_[iter->var];
2376+
VarSet rhs_candidates = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
2377+
std::copy_if(
2378+
rhs_candidates.begin(), rhs_candidates.end(),
2379+
std::inserter(potential_mappings, potential_mappings.begin()),
2380+
[&](const Var& var) { return rhs_var_iter_type.at(var.get()) == iter->iter_type; });
23712381
}
23722382
}
23732383

0 commit comments

Comments
 (0)