@@ -83,9 +83,10 @@ struct CreateFuncInfo {
8383 }
8484};
8585
86- BlockRealize GenerateBlockFromTensor (const te::ComputeOp& compute_op, const te::Tensor& tensor,
87- Array<PrimExpr> bindings, PrimExpr expr_body,
88- CreateFuncInfo* info, arith::Analyzer* analyzer) {
86+ BlockRealize GenerateBlockFromTensors (const te::ComputeOp& compute_op,
87+ const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
88+ PrimExpr expr_body, CreateFuncInfo* info,
89+ arith::Analyzer* analyzer) {
8990 // Step 1. Push_back data_par axis and reduce_axis into block_vars.
9091 Array<IterVar> iter_vars;
9192 std::unordered_map<const VarNode*, PrimExpr> var_map;
@@ -105,16 +106,22 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
105106 f_push_block_vars (compute_op->axis );
106107 f_push_block_vars (compute_op->reduce_axis );
107108
108- // Step 2. Declare buffer and update op2buffers
109- Buffer buffer = decl_buffer (tensor->shape , tensor->dtype , tensor->GetNameHint (), " global" );
110- info->tensor2buffers [tensor] = buffer;
111-
112- // Step 3. Add Buffer to root_alloc
113- if (!info->IsArg (tensor)) {
114- info->root_alloc .push_back (buffer);
109+ // Step 2.
110+ // - Declare buffers
111+ // - Update `tensor2buffers`
112+ // - Add the non-argument tensors to `alloc_buffer` of the root block
113+ Array<Buffer> buffers;
114+ for (const te::Tensor& tensor : tensors) {
115+ Buffer buffer = decl_buffer (tensor->shape , tensor->dtype , tensor->GetNameHint (), " global" );
116+ info->tensor2buffers [tensor] = buffer;
117+ buffers.push_back (buffer);
118+
119+ if (!info->IsArg (tensor)) {
120+ info->root_alloc .push_back (info->tensor2buffers [tensor]);
121+ }
115122 }
116123
117- // Step 4 . Calculate indices for BufferStore
124+ // Step 3 . Calculate indices for BufferStore
118125 Array<PrimExpr> indices;
119126 indices.reserve (compute_op->axis .size ());
120127 for (const IterVar& iter_var : compute_op->axis ) {
@@ -123,26 +130,75 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
123130 indices.push_back (it->second );
124131 }
125132
126- // Step 5. Create block body.
133+ // Step 4. Create block body.
134+ String block_name{nullptr };
127135 Optional<Stmt> init = NullOpt;
128136 Stmt body;
129137 if (const auto * reduce = expr_body.as <ReduceNode>()) {
130138 // Case 1. Reduce compute
131- ICHECK_EQ (reduce->source .size (), 1 );
132- const PrimExpr& lhs = BufferLoad (buffer, indices);
133- const PrimExpr& rhs = Substitute (info->transformer (reduce->source [0 ]), var_map);
134- ICHECK (lhs->dtype == rhs->dtype );
135- const PrimExpr& reduce_body = reduce->combiner .get ()->operator ()({lhs}, {rhs})[0 ];
136- const PrimExpr& init_body = reduce->combiner ->identity_element [0 ];
137- body = BufferStore (buffer, analyzer->Simplify (reduce_body), indices);
138- init = BufferStore (buffer, analyzer->Simplify (init_body), indices);
139+ block_name = compute_op->name ;
140+ int n_buffers = buffers.size ();
141+
142+ Array<PrimExpr> lhs;
143+ Array<PrimExpr> rhs;
144+ lhs.reserve (n_buffers);
145+ rhs.reserve (n_buffers);
146+
147+ // Make the LHS operands and RHS operands:
148+ // - A LHS operand is the buffer storing the reduction result, with corresponding indices.
149+ // - A RHS operand is the value to be reduced.
150+ for (int i = 0 ; i < n_buffers; ++i) {
151+ const PrimExpr& left = BufferLoad (buffers[i], indices);
152+ const PrimExpr& right =
153+ analyzer->Simplify (Substitute (info->transformer (reduce->source [i]), var_map));
154+ lhs.push_back (left);
155+ rhs.push_back (right);
156+ ICHECK_EQ (left->dtype , right->dtype );
157+ }
158+
159+ Array<Var> temp_vars;
160+ Array<Stmt> body_stmts;
161+ Array<Stmt> init_stmts;
162+ temp_vars.reserve (n_buffers);
163+ body_stmts.reserve (n_buffers);
164+ init_stmts.reserve (n_buffers);
165+
166+ // - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
167+ // rhs)" into the target buffer position.
168+ // - In case there are multiple buffers, to avoid incorrect results, we create some intermediate
169+ // variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we
170+ // then store the value of the variables into the target buffer positions.
171+ for (int i = 0 ; i < n_buffers; ++i) {
172+ const Buffer& buffer = buffers[i];
173+ init_stmts.push_back (BufferStore (buffer, reduce->combiner ->identity_element [i], indices));
174+ PrimExpr value{nullptr };
175+ if (n_buffers > 1 ) {
176+ temp_vars.push_back (Var (" v_" + buffer->name , PrimType (lhs[i].dtype ())));
177+ value = temp_vars.back ();
178+ } else {
179+ value = reduce->combiner .get ()->operator ()(lhs, rhs)[i];
180+ }
181+ body_stmts.push_back (BufferStore (buffer, value, indices));
182+ }
183+
184+ init = SeqStmt::Flatten (init_stmts);
185+ body = SeqStmt::Flatten (body_stmts);
186+ if (n_buffers > 1 ) {
187+ // When there are multiple buffers, we wrap the body with LetStmts.
188+ for (int i = n_buffers - 1 ; i >= 0 ; --i) {
189+ PrimExpr value = reduce->combiner .get ()->operator ()(lhs, rhs)[i];
190+ body = LetStmt (temp_vars[i], std::move (value), std::move (body));
191+ }
192+ }
139193 } else {
140194 // Case 2. Data parallel compute
195+ ICHECK_EQ (tensors.size (), 1 );
196+ block_name = info->GetUniqueName (tensors[0 ]->GetNameHint ());
141197 const PrimExpr& compute_body = Substitute (info->transformer (expr_body), var_map);
142- body = BufferStore (buffer , analyzer->Simplify (compute_body), indices);
198+ body = BufferStore (info-> tensor2buffers [tensors[ 0 ]] , analyzer->Simplify (compute_body), indices);
143199 }
144200
145- // Step 6 . Add script_parsing_detect_access attr for auto complete the whole IR.
201+ // Step 5. Add the script_parsing_detect_access attr for auto-completing the whole IR.
146202 Map<String, ObjectRef> annotations;
147203 auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
148204 if (const auto * tensor_value = value.as <te::TensorNode>()) {
@@ -166,14 +222,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
166222 // Set script_parsing_detect_access
167223 annotations.Set (tir::attr::script_parsing_detect_access, IntImm (DataType::Int (32 ), 3 ));
168224
169- // Step 7 . Create Block and BlockRealize.
225+ // Step 6 . Create Block and BlockRealize.
170226 return BlockRealize (/* iter_values=*/ std::move (bindings),
171227 /* predicate=*/ Bool (true ),
172228 /* block=*/
173229 Block (/* iter_vars=*/ std::move (iter_vars),
174230 /* reads=*/ {},
175231 /* writes=*/ {},
176- /* name_hint=*/ info-> GetUniqueName (tensor-> GetNameHint ()) ,
232+ /* name_hint=*/ block_name ,
177233 /* body=*/ std::move (body),
178234 /* init=*/ std::move (init),
179235 /* alloc_buffers=*/ {},
@@ -192,12 +248,38 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
192248 }
193249 // Step 2. Generate block bodies.
194250 Array<Stmt> seq_stmt;
195- for (int i = 0 ; i < compute_op->num_outputs (); ++i) {
196- const te::Tensor& tensor = compute_op.output (i);
197- PrimExpr expr_body = compute_op->body [i];
198- seq_stmt.push_back (GenerateBlockFromTensor (compute_op, tensor, bindings, std::move (expr_body),
199- info, analyzer));
251+ if (compute_op->body [0 ]->IsInstance <ReduceNode>()) {
252+ auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
253+ return a->combiner .same_as (b->combiner ) && //
254+ a->source .same_as (b->source ) && //
255+ a->axis .same_as (b->axis ) && //
256+ a->condition .same_as (b->condition ) && //
257+ ((a->init .empty () && b->init .empty ()) || a->init .same_as (b->init ));
258+ };
259+
260+ PrimExpr expr_body = compute_op->body [0 ];
261+ Array<te::Tensor> tensors = {compute_op.output (0 )};
262+ const tir::ReduceNode* reduce = expr_body.as <tir::ReduceNode>();
263+ // specially handle reduction inline for multiple reductions.
264+ for (size_t k = 1 ; k < compute_op->body .size (); ++k) {
265+ const tir::ReduceNode* reduce_ = compute_op->body [k].as <tir::ReduceNode>();
266+ ICHECK (reduce_);
267+ ICHECK (f_reducer_equal (reduce_, reduce))
268+ << " The Reduce inputs of ComputeOp should have the same attribute except value_index" ;
269+ tensors.push_back (compute_op.output (k));
270+ }
271+
272+ seq_stmt.push_back (GenerateBlockFromTensors (compute_op, tensors, bindings, std::move (expr_body),
273+ info, analyzer));
274+ } else {
275+ for (int i = 0 ; i < compute_op->num_outputs (); ++i) {
276+ const te::Tensor& tensor = compute_op.output (i);
277+ PrimExpr expr_body = compute_op->body [i];
278+ seq_stmt.push_back (GenerateBlockFromTensors (compute_op, {tensor}, bindings,
279+ std::move (expr_body), info, analyzer));
280+ }
200281 }
282+
201283 Stmt body = SeqStmt::Flatten (seq_stmt);
202284
203285 // Step 3. Generate loop nesting.
0 commit comments