@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
4747 * binary operator with identity element
4848 */
4949struct CommReducerNode : public Node {
50- /* ! \brief The arguments of reducer */
51- Array<Var> args;
50+ /* ! \brief The left argument of reducer */
51+ Array<Var> lhs;
52+ /* ! \brief The right argument of reducer */
53+ Array<Var> rhs;
5254 /* ! \brief The result of reducer */
53- Expr result;
55+ Array< Expr> result;
5456 /* !
5557 * \brief The identity element of reducer, which leaves other
5658 * elements unchanged when combined with it, with respect to
5759 * the binary operation of this reducer uses.
5860 */
59- Expr identity_element;
61+ Array< Expr> identity_element;
6062 /* ! \brief Function call operator to combine a and b */
61- Expr operator ()(Expr a, Expr b) const ;
63+ Array< Expr> operator ()(Array< Expr> a, Array< Expr> b) const ;
6264 /* ! \brief construct CommReducer from args, result and identity_element */
63- static CommReducer make (Array<Var> args, Expr result, Expr identity_element);
65+ static CommReducer make (Array<Var> lhs, Array<Var> rhs,
66+ Array<Expr> result, Array<Expr> identity_element);
6467
6568 void VisitAttrs (AttrVisitor* v) final {
66- v->Visit (" args" , &args);
69+ v->Visit (" lhs" , &lhs);
70+ v->Visit (" rhs" , &rhs);
6771 v->Visit (" result" , &result);
6872 v->Visit (" identity_element" , &identity_element);
6973 }
@@ -84,26 +88,30 @@ struct Reduce : public ExprNode<Reduce> {
8488 /* ! \brief The commutative combiner */
8589 CommReducer combiner;
8690 /* ! \brief The source operand */
87- Expr source;
91+ Array< Expr> source;
8892 /* ! \brief The reduction axis */
8993 Array<IterVar> axis;
9094 /* !
9195 * \brief Predicate on the reduction
9296 * Only add the body to reduction if condition is true.
9397 */
9498 Expr condition;
99+ /* ! \brief the index of this reduce node */
100+ int value_index;
95101
96102 /* ! \brief construct expr from op and rdom */
97103 static Expr make (CommReducer combiner,
98- Expr src,
104+ Array< Expr> src,
99105 Array<IterVar> rdom,
100- Expr condition = const_true());
106+ Expr condition,
107+ int value_index);
101108
102109 void VisitAttrs (AttrVisitor* v) final {
103110 v->Visit (" dtype" , &type);
104111 v->Visit (" source" , &source);
105112 v->Visit (" axis" , &axis);
106113 v->Visit (" condition" , &condition);
114+ v->Visit (" value_index" , &value_index);
107115 }
108116 static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
109117 static constexpr const char * _type_key = " Reduce" ;
@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
292300/* !
293301 * \brief See pesudo code
294302 *
295- * Expr tvm_thread_allreduce(CommReducer combiner , Expr value , Expr cond,
296- * Var thread_idx1, thread_idx2 ...) {
303+ * void tvm_thread_allreduce(UIntImm size , Expr source0, ... , Expr cond,
304+ * Var reduce_temp0, .., Var thread_idx1, ...) {
297305 * // constraint by the other thread_idx remain the same.
298- * return reduce(combiner, value, cond,
299- * over [thread_idx1, thread_idx2] passed by any caller)
306+ * // reduce_temp is used to save intermediate result.
307+ * reduce_temp0, ... = reduce(combiner, source0, ..., cond
308+ * over [thread_idx1, thread_idx2] passed by any caller)
300309 * }
301310 */
302311constexpr const char * tvm_thread_allreduce = " tvm_thread_allreduce" ;
0 commit comments