@@ -106,22 +106,22 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
106106 f_push_block_vars (compute_op->axis );
107107 f_push_block_vars (compute_op->reduce_axis );
108108
109- // Step 2. Declare buffer and update op2buffers
109+ // Step 2.
110+ // - Declare buffers
111+ // - Update `op2buffers`
112+ // - Add the non-argument tensors to `alloc_buffer` of the root block
110113 Array<Buffer> buffers;
111114 for (const te::Tensor& tensor : tensors) {
112115 Buffer buffer = decl_buffer (tensor->shape , tensor->dtype , tensor->GetNameHint (), " global" );
113116 info->tensor2buffers [tensor] = buffer;
114117 buffers.push_back (buffer);
115- }
116118
117- // Step 3. Add Buffer to root_alloc
118- for (const te::Tensor& tensor : tensors) {
119119 if (!info->IsArg (tensor)) {
120120 info->root_alloc .push_back (info->tensor2buffers [tensor]);
121121 }
122122 }
123123
124- // Step 4 . Calculate indices for BufferStore
124+ // Step 3 . Calculate indices for BufferStore
125125 Array<PrimExpr> indices;
126126 indices.reserve (compute_op->axis .size ());
127127 for (const IterVar& iter_var : compute_op->axis ) {
@@ -130,7 +130,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
130130 indices.push_back (it->second );
131131 }
132132
133- // Step 5 . Create block body.
133+ // Step 4 . Create block body.
134134 String block_name{nullptr };
135135 Optional<Stmt> init = NullOpt;
136136 Stmt body;
@@ -144,6 +144,9 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
144144 lhs.reserve (n_buffers);
145145 rhs.reserve (n_buffers);
146146
147+ // Make the LHS operands and RHS operands:
148+ // - An LHS operand is the buffer storing the reduction result, with corresponding indices.
149+ // - An RHS operand is the value to be reduced.
147150 for (int i = 0 ; i < n_buffers; ++i) {
148151 const PrimExpr& left = BufferLoad (buffers[i], indices);
149152 const PrimExpr& right =
@@ -160,6 +163,11 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
160163 body_stmts.reserve (n_buffers);
161164 init_stmts.reserve (n_buffers);
162165
166+ // - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
167+ // rhs)" into the target buffer position.
168+ // - In case there are multiple buffers, to avoid incorrect results, we create some intermediate
169+ // variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we
170+ // then store the value of the variables into the target buffer positions.
163171 for (int i = 0 ; i < n_buffers; ++i) {
164172 const Buffer& buffer = buffers[i];
165173 init_stmts.push_back (BufferStore (buffer, reduce->combiner ->identity_element [i], indices));
@@ -176,6 +184,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
176184 init = SeqStmt::Flatten (init_stmts);
177185 body = SeqStmt::Flatten (body_stmts);
178186 if (n_buffers > 1 ) {
187+ // When there are multiple buffers, we wrap the body with LetStmts.
179188 for (int i = n_buffers - 1 ; i >= 0 ; --i) {
180189 PrimExpr value = reduce->combiner .get ()->operator ()(lhs, rhs)[i];
181190 body = LetStmt (temp_vars[i], std::move (value), std::move (body));
@@ -189,7 +198,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
189198 body = BufferStore (info->tensor2buffers [tensors[0 ]], analyzer->Simplify (compute_body), indices);
190199 }
191200
192- // Step 6 . Add script_parsing_detect_access attr for auto complete the whole IR.
201+ // Step 5 . Add script_parsing_detect_access attr to auto-complete the whole IR.
193202 Map<String, ObjectRef> annotations;
194203 auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
195204 if (const auto * tensor_value = value.as <te::TensorNode>()) {
@@ -213,7 +222,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
213222 // Set script_parsing_detect_access
214223 annotations.Set (tir::attr::script_parsing_detect_access, IntImm (DataType::Int (32 ), 3 ));
215224
216- // Step 7 . Create Block and BlockRealize.
225+ // Step 6 . Create Block and BlockRealize.
217226 return BlockRealize (/* iter_values=*/ std::move (bindings),
218227 /* predicate=*/ Bool (true ),
219228 /* block=*/
@@ -228,12 +237,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
228237 /* annotations=*/ std::move (annotations)));
229238}
230239
231- inline bool ReduceEqual (const tir::ReduceNode* a, const tir::ReduceNode* b) {
232- return (a->combiner .same_as (b->combiner )) && (a->source .same_as (b->source )) &&
233- (a->axis .same_as (b->axis )) && (a->condition .same_as (b->condition )) &&
234- ((a->init .empty () && b->init .empty ()) || (a->init .same_as (b->init )));
235- }
236-
237240Stmt GenerateStmtFromCompute (const te::ComputeOp& compute_op, CreateFuncInfo* info,
238241 arith::Analyzer* analyzer) {
239242 // Step 1. Creating loop vars for block bindings.
@@ -246,15 +249,23 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
246249 // Step 2. Generate block bodies.
247250 Array<Stmt> seq_stmt;
248251 if (compute_op->body [0 ]->IsInstance <ReduceNode>()) {
252+ auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
253+ return a->combiner .same_as (b->combiner ) && //
254+ a->source .same_as (b->source ) && //
255+ a->axis .same_as (b->axis ) && //
256+ a->condition .same_as (b->condition ) && //
257+ ((a->init .empty () && b->init .empty ()) || a->init .same_as (b->init ));
258+ };
259+
249260 PrimExpr expr_body = compute_op->body [0 ];
250261 Array<te::Tensor> tensors = {compute_op.output (0 )};
251262 const tir::ReduceNode* reduce = expr_body.as <tir::ReduceNode>();
252263 // Specially handle reduction inline for multiple reductions.
253264 for (size_t k = 1 ; k < compute_op->body .size (); ++k) {
254265 const tir::ReduceNode* reduce_ = compute_op->body [k].as <tir::ReduceNode>();
255266 ICHECK (reduce_);
256- ICHECK (ReduceEqual (reduce_, reduce)) << " The Reduce inputs of ComputeOp should "
257- << " have the same attribute except value_index" ;
267+ ICHECK (f_reducer_equal (reduce_, reduce))
268+ << " The Reduce inputs of ComputeOp should have the same attribute except value_index" ;
258269 tensors.push_back (compute_op.output (k));
259270 }
260271
0 commit comments