Skip to content

Commit f467f66

Browse files
authored
Support for Tuple Inputs of Reducer and ComputeOp (#175)
* Support for batch ComputeOp
* Support for batch ComputeOp
* Fix CrossThreadReduction
* Fix lint
* Add UpdateArray, remove support for batch reduce
* Tuple input support for reduce
* rfactor works with multiple reducer; support multiple reducers with different types
* Small fix
* Small fix
* Change return type of rfactor to Array<Expr>
* Fix lint
* Improve
* Add tutorial
* Improve tutorial
* Improve tutorial
1 parent ef50162 commit f467f66

27 files changed

+739
-228
lines changed

include/tvm/ir.h

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
4747
* binary operator with identity element
4848
*/
4949
struct CommReducerNode : public Node {
50-
/*! \brief The arguments of reducer */
51-
Array<Var> args;
50+
/*! \brief The left argument of reducer */
51+
Array<Var> lhs;
52+
/*! \brief The right argument of reducer */
53+
Array<Var> rhs;
5254
/*! \brief The result of reducer */
53-
Expr result;
55+
Array<Expr> result;
5456
/*!
5557
* \brief The identity element of reducer, which leaves other
5658
* elements unchanged when combined with it, with respect to
5759
* the binary operation of this reducer uses.
5860
*/
59-
Expr identity_element;
61+
Array<Expr> identity_element;
6062
/*! \brief Function call operator to combine a and b */
61-
Expr operator()(Expr a, Expr b) const;
63+
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
6264
/*! \brief construct CommReducer from lhs, rhs, result and identity_element */
63-
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
65+
static CommReducer make(Array<Var> lhs, Array<Var> rhs,
66+
Array<Expr> result, Array<Expr> identity_element);
6467

6568
void VisitAttrs(AttrVisitor* v) final {
66-
v->Visit("args", &args);
69+
v->Visit("lhs", &lhs);
70+
v->Visit("rhs", &rhs);
6771
v->Visit("result", &result);
6872
v->Visit("identity_element", &identity_element);
6973
}
@@ -84,26 +88,30 @@ struct Reduce : public ExprNode<Reduce> {
8488
/*! \brief The commutative combiner */
8589
CommReducer combiner;
8690
/*! \brief The source operand */
87-
Expr source;
91+
Array<Expr> source;
8892
/*! \brief The reduction axis */
8993
Array<IterVar> axis;
9094
/*!
9195
* \brief Predicate on the reduction
9296
* Only add the body to reduction if condition is true.
9397
*/
9498
Expr condition;
99+
/*! \brief the index of the value this reduce node yields among the combiner's results */
100+
int value_index;
95101

96102
/*! \brief construct Reduce expr from combiner, source, rdom, condition and value_index */
97103
static Expr make(CommReducer combiner,
98-
Expr src,
104+
Array<Expr> src,
99105
Array<IterVar> rdom,
100-
Expr condition = const_true());
106+
Expr condition,
107+
int value_index);
101108

102109
void VisitAttrs(AttrVisitor* v) final {
103110
v->Visit("dtype", &type);
104111
v->Visit("source", &source);
105112
v->Visit("axis", &axis);
106113
v->Visit("condition", &condition);
114+
v->Visit("value_index", &value_index);
107115
}
108116
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
109117
static constexpr const char* _type_key = "Reduce";
@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
292300
/*!
293301
* \brief See pseudo code
294302
*
295-
* Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
296-
* Var thread_idx1, thread_idx2...) {
303+
* void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
304+
* Var reduce_temp0, ..., Var thread_idx1, ...) {
297305
* // constraint by the other thread_idx remain the same.
298-
* return reduce(combiner, value, cond,
299-
* over [thread_idx1, thread_idx2] passed by any caller)
306+
* // reduce_temp is used to save intermediate result.
307+
* reduce_temp0, ... = reduce(combiner, source0, ..., cond,
308+
* over [thread_idx1, thread_idx2] passed by any caller)
300309
* }
301310
*/
302311
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";

include/tvm/ir_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
9696
/*!
9797
* \brief inline all calls of f in stmt.
9898
*
99+
* \param stmt The statement to apply inline optimization.
99100
* \param f The function reference to be inlined
100101
* \param args The arguments variable of the function.
101-
* \param body The defintion body of the function.
102-
* \param stmt The statement to apply inline optimization.
102+
* \param body The definition body of the function.
103103
* \return The result stmt
104104
*
105105
* \note All the passes in this file uses SSA form and outputs SSA form.

include/tvm/operation.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
182182
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
183183
Array<IterVar> reduce_axis;
184184
/*! \brief the compute expression */
185-
Expr body;
185+
Array<Expr> body;
186186
/*! \brief constructor */
187187
ComputeOpNode() {}
188188
// override functions
@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
218218
}
219219
static Operation make(std::string name,
220220
Array<IterVar> axis,
221-
Expr body);
221+
Array<Expr> body);
222222

223223
static constexpr const char* _type_key = "ComputeOp";
224224
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
358358
/*! \brief The compute function to specify the input source of a Tensor */
359359
using FCompute = std::function<Expr (const Array<Var>& i)>;
360360

361+
/*! \brief The compute function to specify the input sources of Tensors */
362+
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
363+
361364
/*!
362365
* \brief create a place holder tensor.
363366
* \param shape The shape of the tensor.
@@ -377,6 +380,15 @@ Tensor placeholder(Array<Expr> shape,
377380
*/
378381
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
379382

383+
/*!
384+
* \brief Construct new tensors by computing over shape,
385+
* using the computation rule: result_tensor[axis] = fcompute(axis)
386+
* \param shape Shape of the tensor.
387+
* \param fcompute The compute function to create the tensors.
388+
* \param name The optional name of the tensor.
389+
*/
390+
Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
391+
380392
/*!
381393
* \brief Construct new tensors by scan.
382394
*

include/tvm/schedule.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,15 @@ class Schedule : public NodeRef {
252252
/*!
253253
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
254254
* This will create a new stage that generates the new tensor with axis
255-
* as the first dimension. The tensor's body wil be rewriten as a reduction
255+
* as the first dimension. The tensor's body will be rewritten as a reduction
256256
* over the factored tensor.
257257
*
258258
* \param tensor The tensor to be factored.
259259
* \param axis The reduction axis in tensor's schedule to be factored.
260-
* \return The created factored tensor.
260+
* \return The created factored tensors.
261261
*/
262-
Tensor rfactor(const Tensor& tensor,
263-
const IterVar& axis);
262+
Array<Tensor> rfactor(const Tensor& tensor,
263+
const IterVar& axis);
264264
/*!
265265
* \brief Normalize the schedule.
266266
* This is needed before bound inference.

python/tvm/api.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"):
174174

175175
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
176176
body = fcompute(*[v.var for v in dim_var])
177+
if not isinstance(body, (list, tuple)):
178+
body = [body]
177179
body = convert(body)
178180
op_node = _api_internal._ComputeOp(
179181
name, dim_var, body)
180-
return op_node.output(0)
182+
num = op_node.num_outputs
183+
outputs = tuple(op_node.output(i) for i in range(num))
184+
return outputs[0] if num == 1 else outputs
181185

182186

183187
def scan(init, update, state_placeholder, inputs=None, name="scan"):
@@ -525,18 +529,45 @@ def _reduce_directly(*args):
525529
return res
526530

527531
def _make_reduce(expr, axis, where=None):
528-
expr = convert(expr)
529-
dtype = expr.dtype
530532
code = fcombine.__code__
531533
assert fcombine.__code__.co_argcount == 2
532-
arg_vars = [var(name, dtype) for name in code.co_varnames]
533-
result = fcombine(*[v for v in arg_vars])
534+
expr = convert(expr)
535+
if isinstance(expr, _collections.Array):
536+
size = len(expr)
537+
larr = []
538+
rarr = []
539+
dtypes = []
540+
for i in range(size):
541+
dtype = expr[i].dtype
542+
dtypes.append(dtype)
543+
lname = code.co_varnames[0] + '_' + str(i)
544+
larr.append(var(lname, dtype))
545+
rname = code.co_varnames[1] + '_' + str(i)
546+
rarr.append(var(rname, dtype))
547+
lhs = convert(larr)
548+
rhs = convert(rarr)
549+
result = fcombine(lhs, rhs)
550+
id_elem = fidentity(*dtypes)
551+
else:
552+
assert isinstance(expr, _expr.Expr)
553+
size = 1
554+
dtype = expr.dtype
555+
lvar = var(code.co_varnames[0], dtype)
556+
rvar = var(code.co_varnames[1], dtype)
557+
result = [fcombine(lvar, rvar)]
558+
id_elem = [fidentity(dtype)]
559+
lhs = convert([lvar])
560+
rhs = convert([rvar])
561+
expr = convert([expr])
534562
result = convert(result)
535-
id_elem = fidentity(dtype)
536-
assert isinstance(id_elem, _expr.Expr)
537-
combiner = _make.CommReducer(arg_vars, result, id_elem)
538-
axis = axis if isinstance(axis, list) else [axis]
539-
return _make.Reduce(combiner, expr, axis, where)
563+
id_elem = convert(id_elem)
564+
combiner = _make.CommReducer(lhs, rhs, result, id_elem)
565+
axis = convert(axis if isinstance(axis, list) else [axis])
566+
if where is None:
567+
where = convert(True)
568+
outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
569+
for i in range(size))
570+
return outputs[0] if size == 1 else outputs
540571

541572
def reducer(expr, axis, where=None, *args):
542573
if isinstance(axis, (_schedule.IterVar, list)):

python/tvm/schedule.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def rfactor(self, tensor, axis):
181181
""" Factor a reduction axis in tensor's schedule to be an explicit axis.
182182
183183
This will create a new stage that generates the new tensor with axis
184-
as the first dimension. The tensor's body wil be rewriten as a reduction
184+
as the first dimension. The tensor's body will be rewritten as a reduction
185185
over the factored tensor.
186186
187187
Parameters
@@ -193,10 +193,11 @@ def rfactor(self, tensor, axis):
193193
194194
Returns
195195
-------
196-
tfactor : Tensor
196+
tfactor : Tensor or Array of Tensor
197197
The created factored tensor, or an array of tensors when more than one is created.
198198
"""
199-
return _api_internal._ScheduleRFactor(self, tensor, axis)
199+
factored = _api_internal._ScheduleRFactor(self, tensor, axis)
200+
return factored[0] if len(factored) == 1 else factored
200201

201202

202203
@register_node

src/api/api_ir.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
6868
});
6969

7070
TVM_REGISTER_API("make.CommReducer")
71-
.set_body([](TVMArgs args, TVMRetValue *ret) {
72-
*ret = CommReducerNode::make(args[0], args[1], args[2]);
71+
.set_body([](TVMArgs args, TVMRetValue *ret) {
72+
*ret = CommReducerNode::make(args[0],
73+
args[1],
74+
args[2],
75+
args[3]);
7376
});
7477

75-
7678
// make from two arguments
7779
#define REGISTER_MAKE1(Node) \
7880
TVM_REGISTER_API("make."#Node) \
@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
112114
*ret = Node::make(a, b); \
113115
})
114116

115-
REGISTER_MAKE4(Reduce);
117+
REGISTER_MAKE5(Reduce);
116118
REGISTER_MAKE4(AttrStmt);
117119

118120
REGISTER_MAKE2(IntImm);

src/lang/expr.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
5050
Var x("x"), y("y");
5151
Expr result = ir::Add::make(x, y);
5252
Expr identity_element = make_zero(source.type());
53-
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
54-
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
53+
ir::CommReducer combiner =
54+
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
55+
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
5556
}
5657

5758
Expr max(Expr source, Array<IterVar> rdom) {
5859
Var x("x"), y("y");
5960
Expr result = ir::Max::make(x, y);
6061
Expr identity_element = source.type().min();
61-
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
62-
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
62+
ir::CommReducer combiner =
63+
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
64+
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
6365
}
6466

6567
Expr min(Expr source, Array<IterVar> rdom) {
6668
Var x("x"), y("y");
6769
Expr result = ir::Min::make(x, y);
6870
Expr identity_element = source.type().max();
69-
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
70-
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
71+
ir::CommReducer combiner =
72+
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
73+
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
7174
}
7275

7376
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)

0 commit comments

Comments
 (0)