@@ -2288,7 +2288,8 @@ class AutoTensorizeMappingProposer {
22882288 // Collect the set of potential iter var mapping between the workload and the tensor intrin.
22892289 // We analyze the appearance of each variable in the buffer indices of each buffer on LHS and
22902290 // RHS. The appearance of a variable in the buffer indices is encoded as bit-masks (BufferMask).
2291- // Variables on the LHS and the RHS with the same bit-mask are potential mappings.
2291+ // Variables on the LHS and the RHS with the same bit-mask and the same iter type are potential
2292+ // mappings.
22922293 //
22932294 // For example, consider the conv2d case. We will try to match the workload
22942295 // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c]
@@ -2302,7 +2303,7 @@ class AutoTensorizeMappingProposer {
23022303 // both buffer conv2d and W, and not in other buffers. Therefore, {n, h, w} <=> m is a potential
23032304 // mapping.
23042305
2305- // Note: the mapping is not unique when multiple variables in RHS has the same bit-mask.
2306+ // Note: the mapping is not unique when multiple variables on RHS has the same bit-mask.
23062307 // This is currently not supported.
23072308
23082309 using BufferMask = std::vector<bool >;
@@ -2358,16 +2359,25 @@ class AutoTensorizeMappingProposer {
23582359 }
23592360 }
23602361
2361- // Step 3: Find variables on LHS and RHS with the same buffer mask
2362+ // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure LHS and RHS vars
2363+ // have the same iter type.
23622364 std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars;
23632365 for (const auto & kv : rhs_buffer_masks) {
23642366 const VarNode* rhs_var = kv.first ;
23652367 const BufferMask& mask = kv.second ;
23662368 mask_to_rhs_vars[mask].insert (GetRef<Var>(rhs_var));
23672369 }
2368-
2370+ std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type;
2371+ for (const auto & iter : extractor_->rhs_iters_ ) {
2372+ rhs_var_iter_type.emplace (iter->var .get (), iter->iter_type );
2373+ }
23692374 for (const auto & iter : extractor_->lhs_iters_ ) {
2370- lhs_feasible_vars_[iter->var ] = mask_to_rhs_vars[lhs_buffer_masks[iter->var .get ()]];
2375+ auto & potential_mappings = lhs_feasible_vars_[iter->var ];
2376+ VarSet rhs_candidates = mask_to_rhs_vars[lhs_buffer_masks[iter->var .get ()]];
2377+ std::copy_if (
2378+ rhs_candidates.begin (), rhs_candidates.end (),
2379+ std::inserter (potential_mappings, potential_mappings.begin ()),
2380+ [&](const Var& var) { return rhs_var_iter_type.at (var.get ()) == iter->iter_type ; });
23712381 }
23722382 }
23732383
0 commit comments