Skip to content

Commit b6453da

Browse files
committed
Docs
1 parent fde75c1 commit b6453da

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

src/te/operation/create_primfunc.cc

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,22 +106,22 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
106106
f_push_block_vars(compute_op->axis);
107107
f_push_block_vars(compute_op->reduce_axis);
108108

109-
// Step 2. Declare buffer and update op2buffers
109+
// Step 2.
110+
// - Declare buffers
111+
// - Update `op2buffers`
112+
// - Add the non-argument tensors to `alloc_buffer` of the root block
110113
Array<Buffer> buffers;
111114
for (const te::Tensor& tensor : tensors) {
112115
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
113116
info->tensor2buffers[tensor] = buffer;
114117
buffers.push_back(buffer);
115-
}
116118

117-
// Step 3. Add Buffer to root_alloc
118-
for (const te::Tensor& tensor : tensors) {
119119
if (!info->IsArg(tensor)) {
120120
info->root_alloc.push_back(info->tensor2buffers[tensor]);
121121
}
122122
}
123123

124-
// Step 4. Calculate indices for BufferStore
124+
// Step 3. Calculate indices for BufferStore
125125
Array<PrimExpr> indices;
126126
indices.reserve(compute_op->axis.size());
127127
for (const IterVar& iter_var : compute_op->axis) {
@@ -130,7 +130,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
130130
indices.push_back(it->second);
131131
}
132132

133-
// Step 5. Create block body.
133+
// Step 4. Create block body.
134134
String block_name{nullptr};
135135
Optional<Stmt> init = NullOpt;
136136
Stmt body;
@@ -144,6 +144,9 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
144144
lhs.reserve(n_buffers);
145145
rhs.reserve(n_buffers);
146146

147+
// Make the LHS operands and RHS operands:
148+
// - A LHS operand is the buffer storing the reduction result, with corresponding indices.
149+
// - A RHS operand is the value to be reduced.
147150
for (int i = 0; i < n_buffers; ++i) {
148151
const PrimExpr& left = BufferLoad(buffers[i], indices);
149152
const PrimExpr& right =
@@ -160,6 +163,11 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
160163
body_stmts.reserve(n_buffers);
161164
init_stmts.reserve(n_buffers);
162165

166+
// - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
167+
// rhs)" into the target buffer position.
168+
// - In case there are multiple buffers, to avoid incorrect results, we create some intermediate
169+
// variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we
170+
// store the values of the variables into the target buffer positions.
163171
for (int i = 0; i < n_buffers; ++i) {
164172
const Buffer& buffer = buffers[i];
165173
init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
@@ -176,6 +184,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
176184
init = SeqStmt::Flatten(init_stmts);
177185
body = SeqStmt::Flatten(body_stmts);
178186
if (n_buffers > 1) {
187+
// When there are multiple buffers, we wrap the body with LetStmts.
179188
for (int i = n_buffers - 1; i >= 0; --i) {
180189
PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
181190
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
@@ -189,7 +198,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
189198
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
190199
}
191200

192-
// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
201+
// Step 5. Add script_parsing_detect_access attr for auto-completing the whole IR.
193202
Map<String, ObjectRef> annotations;
194203
auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
195204
if (const auto* tensor_value = value.as<te::TensorNode>()) {
@@ -213,7 +222,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
213222
// Set script_parsing_detect_access
214223
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
215224

216-
// Step 7. Create Block and BlockRealize.
225+
// Step 6. Create Block and BlockRealize.
217226
return BlockRealize(/*iter_values=*/std::move(bindings),
218227
/*predicate=*/Bool(true),
219228
/*block=*/
@@ -228,12 +237,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
228237
/*annotations=*/std::move(annotations)));
229238
}
230239

231-
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
232-
return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
233-
(a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
234-
((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
235-
}
236-
237240
Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
238241
arith::Analyzer* analyzer) {
239242
// Step 1. Creating loop vars for block bindings.
@@ -246,15 +249,23 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
246249
// Step 2. Generate block bodies.
247250
Array<Stmt> seq_stmt;
248251
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
252+
auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
253+
return a->combiner.same_as(b->combiner) && //
254+
a->source.same_as(b->source) && //
255+
a->axis.same_as(b->axis) && //
256+
a->condition.same_as(b->condition) && //
257+
((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
258+
};
259+
249260
PrimExpr expr_body = compute_op->body[0];
250261
Array<te::Tensor> tensors = {compute_op.output(0)};
251262
const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
252263
// Specially handle reduction inlining for multiple reductions.
253264
for (size_t k = 1; k < compute_op->body.size(); ++k) {
254265
const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
255266
ICHECK(reduce_);
256-
ICHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
257-
<< "have the same attribute except value_index";
267+
ICHECK(f_reducer_equal(reduce_, reduce))
268+
<< "The Reduce inputs of ComputeOp should have the same attribute except value_index";
258269
tensors.push_back(compute_op.output(k));
259270
}
260271

0 commit comments

Comments
 (0)