Skip to content

Commit 1e9c1bf

Browse files
authored
[Relay][Pass] SimplifyCastLike/Cast and ConcretizeFullLikeRewrite rewrites for SimplifyExpr (#7827)
1 parent 90dce48 commit 1e9c1bf

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

src/relay/transforms/simplify_expr.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,33 @@ class SimplifyReshape : public DFPatternRewrite {
7575
DFPattern x_;
7676
};
7777

78+
/*!
79+
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
80+
*/
81+
class SimplifyCast : public DFPatternRewrite {
82+
public:
83+
SimplifyCast() {
84+
data_pat_ = IsWildcard();
85+
like_pat_ = IsWildcard();
86+
pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
87+
}
88+
89+
Expr Callback(const Expr& pre, const Expr& post,
90+
const Map<DFPattern, Array<Expr>>& node_map) const override {
91+
const CallNode* call = pre.as<CallNode>();
92+
const TensorTypeNode* data_ty = call->args[0]->checked_type().as<TensorTypeNode>();
93+
const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
94+
if (like_ty->dtype == data_ty->dtype) {
95+
return node_map[data_pat_][0];
96+
}
97+
return post;
98+
}
99+
100+
protected:
101+
DFPattern data_pat_;
102+
DFPattern like_pat_;
103+
};
104+
78105
/*!
79106
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
80107
* and merges or cancels them.
@@ -321,6 +348,17 @@ class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
321348
}
322349
};
323350

351+
class ConcretizeFullLikeRewrite : public ConcretizeLikeRewrite {
352+
public:
353+
ConcretizeFullLikeRewrite() : ConcretizeLikeRewrite(Op::Get("full_like")) {}
354+
355+
Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
356+
DataType dtype) const override {
357+
// `like_pat_` here is `fill_value`
358+
return MakeFull(node_map[like_pat_][0], shape, dtype);
359+
}
360+
};
361+
324362
class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
325363
public:
326364
ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
@@ -439,12 +477,14 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
439477
DFPatternRewriteComposer composer;
440478
composer.AddRewrite<ConcretizeZerosLikeRewrite>();
441479
composer.AddRewrite<ConcretizeOnesLikeRewrite>();
480+
composer.AddRewrite<ConcretizeFullLikeRewrite>();
442481
composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
443482
composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
444483
composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
445484
composer.AddRewrite<EliminateIdentityRewrite>();
446485
composer.AddRewrite<SimplifyReshape>();
447486
composer.AddRewrite<SimplifyTranspose>();
487+
composer.AddRewrite<SimplifyCast>();
448488
composer.AddRewrite<FullElementwise>();
449489
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
450490
}

tests/python/relay/test_pass_simplify_expr.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ def check(x, y=None, do_nothing=False):
236236
check(id_op(const, x), id_op(op_like(x), x))
237237

238238

239+
def test_simplify_cast():
240+
dtype = "int32"
241+
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
242+
expr1 = relay.cast(data, dtype)
243+
dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype)
244+
expr2 = relay.cast_like(data, dtype_like)
245+
246+
expected = run_infer_type(data)
247+
actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr())
248+
assert tvm.ir.structural_equal(actual1, expected)
249+
actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
250+
assert tvm.ir.structural_equal(actual2, expected)
251+
252+
239253
def test_concretize_reshape_like():
240254
data = relay.var("data", shape=(2, 3, 4), dtype="float32")
241255
shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
@@ -276,6 +290,17 @@ def test_concretize_ones_like():
276290
assert tvm.ir.structural_equal(actual, expected)
277291

278292

293+
def test_concretize_full_like():
294+
dtype = "int32"
295+
shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
296+
fill_value = relay.var("fill", relay.TensorType((), "float32"))
297+
expr = relay.full_like(shape_like, fill_value)
298+
299+
expected = run_infer_type(relay.full(fill_value, (3, 4, 5), dtype))
300+
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
301+
assert tvm.ir.structural_equal(actual, expected)
302+
303+
279304
def test_concretize_collapse_sum_like():
280305
data = relay.var("data", shape=(3, 3, 3), dtype="float32")
281306
shape_like = relay.var("shape_like", shape=(3,), dtype="float32")

0 commit comments

Comments
 (0)