Skip to content

Commit 5f0d376

Browse files
Matthew Brookhartalexwong
authored andcommitted
Simplify full broadcast (apache#7423)
* convert argwhere(full(const)) to reshape(arange()) * Add IsWildcard syntatic sugar * add a simplify expression to fold full into broadcast ops * Allow constant folding of full-like ops after SimplifyExpr * fix a bug with the Attr Pattern matching * remove skip_list
1 parent 65f3a11 commit 5f0d376

File tree

10 files changed

+185
-37
lines changed

10 files changed

+185
-37
lines changed

include/tvm/relay/dataflow_pattern.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,8 @@ class DominatorPattern : public DFPattern {
524524
DFPattern IsVar(const String& name);
525525
/*! \brief Syntatic Sugar for creating a ConstantPattern */
526526
DFPattern IsConstant();
527+
/*! \brief Syntatic Sugar for creating a WildcardPattern */
528+
DFPattern IsWildcard();
527529
/*! \brief Syntatic Sugar for creating a ExprPattern */
528530
DFPattern IsExpr(const Expr& expr);
529531
/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/

src/relay/ir/dataflow_matcher.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,12 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
162162
if (Op::HasAttrMap(attr_name)) {
163163
auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
164164
if (op_map.count(op)) {
165-
matches = MatchRetValue(attr_value, op_map[op]);
165+
matches &= MatchRetValue(attr_value, op_map[op]);
166+
} else {
167+
matches = false;
166168
}
169+
} else {
170+
matches = false;
167171
}
168172
}
169173
} else if (auto* op = expr.as<CallNode>()) {
@@ -196,6 +200,8 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
196200
break;
197201
}
198202
}
203+
} else {
204+
matches = false;
199205
}
200206
return matches;
201207
}

src/relay/ir/dataflow_pattern.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) {
357357
}
358358
DFPattern IsVar(const String& name) { return VarPattern(name); }
359359
DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
360+
DFPattern IsWildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
360361
DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
361362
DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
362363
DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }

src/relay/op/make_op.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
100100

101101
Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);
102102

103+
Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
104+
105+
Expr MakeShapeOf(Expr data, DataType dtype);
106+
107+
Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
108+
103109
} // namespace relay
104110
} // namespace tvm
105111
#endif // TVM_RELAY_OP_MAKE_OP_H_

src/relay/op/tensor/unary.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
430430
return {topi::shape(inputs[0], param->dtype)};
431431
}
432432

433-
TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
433+
Expr MakeShapeOf(Expr data, DataType dtype) {
434434
auto attrs = make_object<ShapeOfAttrs>();
435435
attrs->dtype = dtype;
436436
static const Op& op = Op::Get("shape_of");
437437
return Call(op, {data}, Attrs(attrs), {});
438-
});
438+
}
439+
440+
TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);
439441

440442
RELAY_REGISTER_OP("shape_of")
441443
.describe(R"code(Returns a tensor representing the shape of a tensor.

src/relay/transforms/fold_constant.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,6 @@ class ConstantFolder : public MixedModeMutator {
148148
}
149149
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
150150

151-
std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
152-
153151
auto origin_args = call->args;
154152
call = post.as<CallNode>();
155153
// We don't constant fold function with zero arguments.
@@ -158,9 +156,6 @@ class ConstantFolder : public MixedModeMutator {
158156
if (call->args.size() == 0) return post;
159157
const OpNode* op = call->op.as<OpNode>();
160158
if (op == nullptr) return post;
161-
if (skip_list.count(op->name)) {
162-
return post;
163-
}
164159
// skip stateful ops.
165160
if (op_stateful.get(GetRef<Op>(op), false)) return post;
166161
// Try to evaluate shape_of op

src/relay/transforms/simplify_expr.cc

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,38 @@
2929
#include <tvm/support/logging.h>
3030

3131
#include "../op/tensor/transform.h"
32+
#include "pattern_utils.h"
3233

3334
namespace tvm {
3435
namespace relay {
3536

37+
class SimplifyPattern {
38+
public:
39+
virtual Expr callback(const Expr& pre, const Expr& post,
40+
const Map<DFPattern, Array<Expr>>& node_map) const = 0;
41+
42+
DFPattern pattern() const { return pattern_; }
43+
44+
protected:
45+
/*! \brief Pattern for rewriting */
46+
DFPattern pattern_;
47+
};
48+
3649
/*!
3750
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
3851
* and merges into one reshape op.
3952
*/
40-
class SimplifyReshape {
53+
class SimplifyReshape : public SimplifyPattern {
4154
public:
4255
SimplifyReshape() {
43-
x_ = WildcardPattern(make_object<WildcardPatternNode>());
56+
x_ = IsWildcard();
4457
auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
4558
auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
4659
pattern_ = reshape1({reshape2({x_})});
4760
}
4861

49-
Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
62+
Expr callback(const Expr& pre, const Expr& post,
63+
const Map<DFPattern, Array<Expr>>& node_map) const override {
5064
auto x = node_map[x_][0];
5165
bool const_shape = true;
5266
Array<Integer> newshape;
@@ -63,13 +77,82 @@ class SimplifyReshape {
6377
return post;
6478
}
6579

66-
DFPattern pattern() const { return pattern_; }
67-
6880
private:
6981
/*! \brief Pattern input */
7082
DFPattern x_;
71-
/*! \brief Pattern for consecutive reshape or reverse_reshape ops */
72-
DFPattern pattern_;
83+
};
84+
85+
/*!
86+
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
87+
*/
88+
class FullElementwise : public SimplifyPattern {
89+
public:
90+
FullElementwise() {
91+
x_ = IsWildcard();
92+
data_ = IsWildcard();
93+
value_ = IsConstant();
94+
95+
full_ = IsOp("full")({value_}) || IsOp("full_like")({data_, value_});
96+
ones_ = IsOp("ones")({}) || IsOp("ones_like")({data_});
97+
zeros_ = IsOp("zeros")({}) || IsOp("zeros_like")({data_});
98+
99+
Map<String, ObjectRef> attrs;
100+
attrs.Set("TOpPattern", Integer(static_cast<int>(kBroadcast)));
101+
DFPattern op = IsWildcard().HasAttr(attrs);
102+
DFPattern full = full_ || ones_ || zeros_;
103+
pattern_ = op({full, x_}) || op({x_, full});
104+
}
105+
106+
Expr callback(const Expr& pre, const Expr& post,
107+
const Map<DFPattern, Array<Expr>>& node_map) const override {
108+
const CallNode* call = pre.as<CallNode>();
109+
ICHECK(call);
110+
Type pre_type = pre->checked_type_;
111+
ICHECK(pre_type.as<TensorTypeNode>());
112+
auto dtype = pre_type.as<TensorTypeNode>()->dtype;
113+
auto x = node_map[x_][0];
114+
bool is_left = post.as<CallNode>()->args[1] == x;
115+
Type x_type;
116+
if (is_left) {
117+
x_type = call->args[1]->checked_type_;
118+
} else {
119+
x_type = call->args[0]->checked_type_;
120+
}
121+
122+
if (StructuralEqual()(x_type, pre_type)) {
123+
Expr value;
124+
if (node_map.count(full_)) {
125+
value = node_map[value_][0];
126+
ICHECK(IsConstScalar(value));
127+
} else if (node_map.count(ones_)) {
128+
value = MakeConstantScalar(dtype, 1);
129+
} else if (node_map.count(zeros_)) {
130+
value = MakeConstantScalar(dtype, 0);
131+
} else {
132+
ICHECK(false) << "Didn't find a full op while matching full + elementwise";
133+
}
134+
if (is_left) {
135+
return Call(call->op, {value, x}, call->attrs, call->type_args, call->span);
136+
} else {
137+
return Call(call->op, {x, value}, call->attrs, call->type_args, call->span);
138+
}
139+
}
140+
return post;
141+
}
142+
143+
private:
144+
/*! \brief binary argument */
145+
DFPattern x_;
146+
/*! \brief data ops get shape from */
147+
DFPattern data_;
148+
/*! \brief constant input */
149+
DFPattern value_;
150+
/*! \brief full op */
151+
DFPattern full_;
152+
/*! \brief ones op */
153+
DFPattern ones_;
154+
/*! \brief zeros op */
155+
DFPattern zeros_;
73156
};
74157

75158
/*!
@@ -78,22 +161,24 @@ class SimplifyReshape {
78161
class ExprSimplifier {
79162
public:
80163
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
81-
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
164+
CreateCallback(SimplifyReshape());
165+
CreateCallback(FullElementwise());
166+
}
167+
template <typename T>
168+
void CreateCallback(const T& pattern) {
169+
auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
82170
Expr pre = args[0];
83171
Expr post = args[1];
84172
Map<DFPattern, Array<Expr>> node_map = args[2];
85-
*rv = simplify_reshape_.callback(pre, post, node_map);
173+
*rv = pattern.callback(pre, post, node_map);
86174
};
87-
callbacks_.push_back(
88-
DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
175+
callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
89176
}
90177

91178
Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
92179

93180
private:
94181
IRModule mod_;
95-
/*! \brief Simplify reshape pattern */
96-
SimplifyReshape simplify_reshape_;
97182
/*! \brief Callbacks for expr simplification */
98183
Array<DFPatternCallback> callbacks_;
99184
};

tests/python/relay/test_dataflow_pattern.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,8 @@ def test_no_match_op_attr():
437437
x = relay.var("x")
438438
y = relay.var("y")
439439
assert not op_pat.match(x - y)
440+
z = relay.var("z")
441+
assert not op_pat.match(relay.Let(z, x + y, z))
440442

441443

442444
def test_match_func_attr():

tests/python/relay/test_pass_fold_constant.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -231,22 +231,6 @@ def expected(dtype):
231231
assert tvm.ir.structural_equal(zz, zexpected)
232232

233233

234-
def test_fold_full():
235-
c_shape = (8, 9, 10)
236-
237-
def before():
238-
dtype = "float32"
239-
return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)
240-
241-
def expected():
242-
# expect no changes
243-
return before()
244-
245-
zz = run_opt_pass(before(), transform.FoldConstant())
246-
zexpected = run_opt_pass(expected(), transform.InferType())
247-
assert tvm.ir.structural_equal(zz, zexpected)
248-
249-
250234
def test_fold_batch_norm():
251235
def expected():
252236
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))

tests/python/relay/test_pass_simplify_expr.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,70 @@ def symbolic():
5858
assert tvm.ir.structural_equal(zz, after)
5959

6060

61+
def test_simplify_full_elementwise():
62+
def validate(shape, value, dtype):
63+
def before_left(x, elem_op, full):
64+
return elem_op(full, x)
65+
66+
def after_left(x, elem_op, value):
67+
return elem_op(relay.const(value, dtype), x)
68+
69+
def before_right(x, elem_op, full):
70+
return elem_op(x, full)
71+
72+
def after_right(x, elem_op, value):
73+
return elem_op(x, relay.const(value, dtype))
74+
75+
x = relay.var("x", shape=shape, dtype=dtype)
76+
elem_ops = [relay.add, relay.multiply, relay.subtract, relay.divide]
77+
full_ops = []
78+
if value == 0:
79+
full_ops.append(relay.zeros(shape, dtype))
80+
full_ops.append(relay.zeros_like(x))
81+
if value == 1:
82+
full_ops.append(relay.ones(shape, dtype))
83+
full_ops.append(relay.ones_like(x))
84+
else:
85+
full_ops.append(relay.full(relay.const(value, dtype), shape))
86+
full_ops.append(relay.full_like(x, relay.const(value, dtype)))
87+
for op in elem_ops:
88+
for full in full_ops:
89+
z = before_left(x, op, full)
90+
zz = run_opt_pass(z, transform.SimplifyExpr())
91+
after = run_opt_pass(after_left(x, op, value), transform.InferType())
92+
assert tvm.ir.structural_equal(zz, after)
93+
94+
z = before_right(x, op, full)
95+
zz = run_opt_pass(z, transform.SimplifyExpr())
96+
after = run_opt_pass(after_right(x, op, value), transform.InferType())
97+
assert tvm.ir.structural_equal(zz, after)
98+
99+
# Test the case in which x is broadcast to full's shape
100+
full_ops = []
101+
if value == 0:
102+
full_ops.append(relay.zeros(shape * 2, dtype))
103+
if value == 1:
104+
full_ops.append(relay.ones(shape * 2, dtype))
105+
else:
106+
full_ops.append(relay.full(relay.const(value, dtype), shape * 2))
107+
for op in elem_ops:
108+
for full in full_ops:
109+
z = before_left(x, op, full)
110+
zz = run_opt_pass(z, transform.SimplifyExpr())
111+
after = run_opt_pass(before_left(x, op, full), transform.InferType())
112+
assert tvm.ir.structural_equal(zz, after)
113+
114+
z = before_right(x, op, full)
115+
zz = run_opt_pass(z, transform.SimplifyExpr())
116+
after = run_opt_pass(before_right(x, op, full), transform.InferType())
117+
assert tvm.ir.structural_equal(zz, after)
118+
119+
for shape in [[10], [10, 10], [10, 10, 10]]:
120+
for dtype in ["float32", "int32"]:
121+
for value in [0, 1, 2]:
122+
validate(shape, value, dtype)
123+
124+
61125
if __name__ == "__main__":
62126
test_simplify_reshape()
127+
test_simplify_full_elementwise()

0 commit comments

Comments
 (0)