Skip to content

Commit ae61603

Browse files
[TIR] Tuple Reduction Support in CreatePrimFunc (#10671)
* [CreatePrimFunc] Support multi-source ReduceNode (#64) * initial * assert structural equal test * Enhancement and tests * Fix dtype * Docs Co-authored-by: Andrew Liu <[email protected]>
1 parent 5b5bf75 commit ae61603

File tree

2 files changed

+215
-29
lines changed

2 files changed

+215
-29
lines changed

src/te/operation/create_primfunc.cc

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ struct CreateFuncInfo {
8383
}
8484
};
8585

86-
BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
87-
Array<PrimExpr> bindings, PrimExpr expr_body,
88-
CreateFuncInfo* info, arith::Analyzer* analyzer) {
86+
BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
87+
const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
88+
PrimExpr expr_body, CreateFuncInfo* info,
89+
arith::Analyzer* analyzer) {
8990
// Step 1. Push_back data_par axis and reduce_axis into block_vars.
9091
Array<IterVar> iter_vars;
9192
std::unordered_map<const VarNode*, PrimExpr> var_map;
@@ -105,16 +106,22 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
105106
f_push_block_vars(compute_op->axis);
106107
f_push_block_vars(compute_op->reduce_axis);
107108

108-
// Step 2. Declare buffer and update op2buffers
109-
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
110-
info->tensor2buffers[tensor] = buffer;
111-
112-
// Step 3. Add Buffer to root_alloc
113-
if (!info->IsArg(tensor)) {
114-
info->root_alloc.push_back(buffer);
109+
// Step 2.
110+
// - Declare buffers
111+
// - Update `tensor2buffers`
112+
// - Add the non-argument tensors to `alloc_buffer` of the root block
113+
Array<Buffer> buffers;
114+
for (const te::Tensor& tensor : tensors) {
115+
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
116+
info->tensor2buffers[tensor] = buffer;
117+
buffers.push_back(buffer);
118+
119+
if (!info->IsArg(tensor)) {
120+
info->root_alloc.push_back(info->tensor2buffers[tensor]);
121+
}
115122
}
116123

117-
// Step 4. Calculate indices for BufferStore
124+
// Step 3. Calculate indices for BufferStore
118125
Array<PrimExpr> indices;
119126
indices.reserve(compute_op->axis.size());
120127
for (const IterVar& iter_var : compute_op->axis) {
@@ -123,26 +130,75 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
123130
indices.push_back(it->second);
124131
}
125132

126-
// Step 5. Create block body.
133+
// Step 4. Create block body.
134+
String block_name{nullptr};
127135
Optional<Stmt> init = NullOpt;
128136
Stmt body;
129137
if (const auto* reduce = expr_body.as<ReduceNode>()) {
130138
// Case 1. Reduce compute
131-
ICHECK_EQ(reduce->source.size(), 1);
132-
const PrimExpr& lhs = BufferLoad(buffer, indices);
133-
const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
134-
ICHECK(lhs->dtype == rhs->dtype);
135-
const PrimExpr& reduce_body = reduce->combiner.get()->operator()({lhs}, {rhs})[0];
136-
const PrimExpr& init_body = reduce->combiner->identity_element[0];
137-
body = BufferStore(buffer, analyzer->Simplify(reduce_body), indices);
138-
init = BufferStore(buffer, analyzer->Simplify(init_body), indices);
139+
block_name = compute_op->name;
140+
int n_buffers = buffers.size();
141+
142+
Array<PrimExpr> lhs;
143+
Array<PrimExpr> rhs;
144+
lhs.reserve(n_buffers);
145+
rhs.reserve(n_buffers);
146+
147+
// Make the LHS operands and RHS operands:
148+
// - A LHS operand is the buffer storing the reduction result, with corresponding indices.
149+
// - A RHS operand is the value to be reduced.
150+
for (int i = 0; i < n_buffers; ++i) {
151+
const PrimExpr& left = BufferLoad(buffers[i], indices);
152+
const PrimExpr& right =
153+
analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
154+
lhs.push_back(left);
155+
rhs.push_back(right);
156+
ICHECK_EQ(left->dtype, right->dtype);
157+
}
158+
159+
Array<Var> temp_vars;
160+
Array<Stmt> body_stmts;
161+
Array<Stmt> init_stmts;
162+
temp_vars.reserve(n_buffers);
163+
body_stmts.reserve(n_buffers);
164+
init_stmts.reserve(n_buffers);
165+
166+
// - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
167+
// rhs)" into the target buffer position.
168+
// - In case there are multiple buffers, to avoid incorrect results, we create some intermediate
169+
// variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we
170+
// then store the value of the variables into the target buffer positions.
171+
for (int i = 0; i < n_buffers; ++i) {
172+
const Buffer& buffer = buffers[i];
173+
init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
174+
PrimExpr value{nullptr};
175+
if (n_buffers > 1) {
176+
temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
177+
value = temp_vars.back();
178+
} else {
179+
value = reduce->combiner.get()->operator()(lhs, rhs)[i];
180+
}
181+
body_stmts.push_back(BufferStore(buffer, value, indices));
182+
}
183+
184+
init = SeqStmt::Flatten(init_stmts);
185+
body = SeqStmt::Flatten(body_stmts);
186+
if (n_buffers > 1) {
187+
// When there are multiple buffers, we wrap the body with LetStmts.
188+
for (int i = n_buffers - 1; i >= 0; --i) {
189+
PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
190+
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
191+
}
192+
}
139193
} else {
140194
// Case 2. Data parallel compute
195+
ICHECK_EQ(tensors.size(), 1);
196+
block_name = info->GetUniqueName(tensors[0]->GetNameHint());
141197
const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
142-
body = BufferStore(buffer, analyzer->Simplify(compute_body), indices);
198+
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
143199
}
144200

145-
// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
201+
// Step 5. Add the script_parsing_detect_access attr so that the whole IR can be auto-completed.
146202
Map<String, ObjectRef> annotations;
147203
auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
148204
if (const auto* tensor_value = value.as<te::TensorNode>()) {
@@ -166,14 +222,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
166222
// Set script_parsing_detect_access
167223
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
168224

169-
// Step 7. Create Block and BlockRealize.
225+
// Step 6. Create Block and BlockRealize.
170226
return BlockRealize(/*iter_values=*/std::move(bindings),
171227
/*predicate=*/Bool(true),
172228
/*block=*/
173229
Block(/*iter_vars=*/std::move(iter_vars),
174230
/*reads=*/{},
175231
/*writes=*/{},
176-
/*name_hint=*/info->GetUniqueName(tensor->GetNameHint()),
232+
/*name_hint=*/block_name,
177233
/*body=*/std::move(body),
178234
/*init=*/std::move(init),
179235
/*alloc_buffers=*/{},
@@ -192,12 +248,38 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
192248
}
193249
// Step 2. Generate block bodies.
194250
Array<Stmt> seq_stmt;
195-
for (int i = 0; i < compute_op->num_outputs(); ++i) {
196-
const te::Tensor& tensor = compute_op.output(i);
197-
PrimExpr expr_body = compute_op->body[i];
198-
seq_stmt.push_back(GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body),
199-
info, analyzer));
251+
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
252+
auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
253+
return a->combiner.same_as(b->combiner) && //
254+
a->source.same_as(b->source) && //
255+
a->axis.same_as(b->axis) && //
256+
a->condition.same_as(b->condition) && //
257+
((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
258+
};
259+
260+
PrimExpr expr_body = compute_op->body[0];
261+
Array<te::Tensor> tensors = {compute_op.output(0)};
262+
const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
263+
// Specially handle reduction inlining for multiple reductions.
264+
for (size_t k = 1; k < compute_op->body.size(); ++k) {
265+
const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
266+
ICHECK(reduce_);
267+
ICHECK(f_reducer_equal(reduce_, reduce))
268+
<< "The Reduce inputs of ComputeOp should have the same attribute except value_index";
269+
tensors.push_back(compute_op.output(k));
270+
}
271+
272+
seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
273+
info, analyzer));
274+
} else {
275+
for (int i = 0; i < compute_op->num_outputs(); ++i) {
276+
const te::Tensor& tensor = compute_op.output(i);
277+
PrimExpr expr_body = compute_op->body[i];
278+
seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
279+
std::move(expr_body), info, analyzer));
280+
}
200281
}
282+
201283
Stmt body = SeqStmt::Flatten(seq_stmt);
202284

203285
// Step 3. Generate loop nesting.

tests/python/unittest/test_te_create_primfunc.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,108 @@ def test_tensor_attr():
359359
tvm.ir.assert_structural_equal(func, rt_func)
360360

361361

362+
def te_argmax_idx_val():
363+
def f_combine(x, y):
364+
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
365+
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
366+
return lhs, rhs
367+
368+
def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
369+
return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1)
370+
371+
argmax = te.comm_reducer(f_combine, f_identity, name="argmax")
372+
373+
m = te.var("m")
374+
n = te.var("n")
375+
idx = te.placeholder((m, n), name="idx", dtype="int32")
376+
val = te.placeholder((m, n), name="val", dtype="float32")
377+
k = te.reduce_axis((0, n), "k")
378+
max_idx, max_val = te.compute(
379+
(m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="argmax"
380+
)
381+
return [idx, val, max_idx, max_val]
382+
383+
384+
@T.prim_func
385+
def tir_argmax_idx_val(
386+
var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
387+
) -> None:
388+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
389+
m = T.var("int32")
390+
n = T.var("int32")
391+
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
392+
val = T.match_buffer(var_val, [m, n], dtype="float32")
393+
argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32")
394+
argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="float32")
395+
for i0, i1 in T.grid(m, n):
396+
with T.block("argmax"):
397+
i, k = T.axis.remap("SR", [i0, i1])
398+
T.reads(argmax_v1[i], val[i, k], argmax_v0[i], idx[i, k])
399+
T.writes(argmax_v0[i], argmax_v1[i])
400+
with T.init():
401+
argmax_v0[i] = T.int32(-1)
402+
argmax_v1[i] = T.min_value("float32")
403+
v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
404+
v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k])
405+
argmax_v0[i] = v_argmax_v0
406+
argmax_v1[i] = v_argmax_v1
407+
408+
409+
def te_argmax_val_idx():
410+
def f_combine(x, y):
411+
lhs = tvm.tir.Select((x[0] >= y[0]), x[0], y[0])
412+
rhs = tvm.tir.Select((x[0] >= y[0]), x[1], y[1])
413+
return lhs, rhs
414+
415+
def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
416+
return tvm.te.min_value(dtype0), tvm.tir.const(-1, dtype1)
417+
418+
argmax = te.comm_reducer(f_combine, f_identity, name="argmax")
419+
420+
m = te.var("m")
421+
n = te.var("n")
422+
val = te.placeholder((m, n), name="val", dtype="float32")
423+
idx = te.placeholder((m, n), name="idx", dtype="int32")
424+
k = te.reduce_axis((0, n), "k")
425+
max_val, max_idx = te.compute(
426+
(m,), lambda i: argmax((val[i, k], idx[i, k]), axis=k), name="argmax"
427+
)
428+
return [val, idx, max_val, max_idx]
429+
430+
431+
@T.prim_func
432+
def tir_argmax_val_idx(
433+
var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
434+
) -> None:
435+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
436+
m = T.var("int32")
437+
n = T.var("int32")
438+
val = T.match_buffer(var_val, [m, n], dtype="float32")
439+
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
440+
argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
441+
argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
442+
for i0, i1 in T.grid(m, n):
443+
with T.block("argmax"):
444+
i, k = T.axis.remap("SR", [i0, i1])
445+
T.reads(argmax_v0[i], val[i, k], argmax_v1[i], idx[i, k])
446+
T.writes(argmax_v0[i], argmax_v1[i])
447+
with T.init():
448+
argmax_v0[i] = T.min_value("float32")
449+
argmax_v1[i] = T.int32(-1)
450+
v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
451+
v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
452+
argmax_v0[i] = v_argmax_v0
453+
argmax_v1[i] = v_argmax_v1
454+
455+
456+
def test_argmax_idx_val():
457+
_check_workload(te_argmax_idx_val, tir_argmax_idx_val)
458+
459+
460+
def test_argmax_val_idx():
461+
_check_workload(te_argmax_val_idx, tir_argmax_val_idx)
462+
463+
362464
if __name__ == "__main__":
363465
test_unique_name()
364466
test_matmul()
@@ -371,3 +473,5 @@ def test_tensor_attr():
371473
test_constant()
372474
test_select_simplify()
373475
test_tensor_attr()
476+
test_argmax_idx_val()
477+
test_argmax_val_idx()

0 commit comments

Comments
 (0)