
Commit 04e8162

vinx13 authored and tqchen committed
[Relay][Pass] CanonicalizeCast (#3280)
1 parent fa35104 commit 04e8162

File tree

5 files changed, +232 −0 lines changed


include/tvm/relay/transform.h

Lines changed: 7 additions & 0 deletions
@@ -534,6 +534,13 @@ TVM_DLL Pass CanonicalizeOps();
  */
 TVM_DLL Pass AlterOpLayout();
 
+/*!
+ * \brief Canonicalize cast expressions to make operator fusion more efficient.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CanonicalizeCast();
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm

python/tvm/relay/transform.py

Lines changed: 10 additions & 0 deletions
@@ -445,6 +445,16 @@ def PartialEvaluate():
     """
     return _transform.PartialEvaluate()
 
+def CanonicalizeCast():
+    """
+    Canonicalize cast expressions to make operator fusion more efficient.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that canonicalizes cast expressions.
+    """
+    return _transform.CanonicalizeCast()
 
 def _wrap_class_module_pass(pass_cls, pass_info):
     """Wrap a python class as function pass"""

src/relay/backend/build_module.cc

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     pass_seqs.push_back(transform::CombineParallelConv2D(3));
     pass_seqs.push_back(transform::FoldConstant());
     pass_seqs.push_back(transform::FoldScaleAxis());
+    pass_seqs.push_back(transform::CanonicalizeCast());
     pass_seqs.push_back(transform::CanonicalizeOps());
 
     // Alter layout transformation is only applied to homogeneous execution yet.
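
One consequence of the registration in the new file below (the pass is created at opt_level 3): in this pipeline and in any Sequential, it only fires when the enclosing PassContext permits that level. A sketch under that assumption:

import tvm.relay as relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform

x = relay.var("x", shape=(4,), dtype="int8")
f = relay.Function([x], relay.cast(x, dtype="int32"))
mod = _module.Module.from_expr(f)

seq = _transform.Sequential([_transform.InferType(),
                             _transform.CanonicalizeCast()])
with _transform.PassContext(opt_level=2):
    mod_low = seq(mod)   # assumption: CanonicalizeCast is skipped below level 3
with _transform.PassContext(opt_level=3):
    mod_high = seq(mod)  # CanonicalizeCast runs here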
src/relay/pass/canonicalize_cast.cc

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file canonicalize_cast.cc
+ * \brief Canonicalize cast expressions to make operator fusion more efficient.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
+#include "pattern_util.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+// This pass finds upcasts that are referred to by multiple elemwise/broadcast operators and
+// creates a copy of the cast in each branch, so that after fusion the preceding function has an
+// output with fewer bits.
+//
+// Consider the following example:
+// \code
+// def @main(x: int8) {
+//   %1 = cast(%x, f32)
+//   %2 = exp(%1)
+//   %3 = log(%1)
+//   (%2, %3)
+// }
+// \endcode
+//
+// We would like to prevent sharing of the cast expression so that operator fusion can produce a
+// more efficient result, as below.
+// \code
+// def @main(x: int8) {
+//   %1 = fn (%p1: i8) {
+//     exp(cast(%p1, f32))
+//   }
+//   %3 = %1(%x)
+//   %2 = fn (%p1: i8) {
+//     log(cast(%p1, f32))
+//   }
+//   %4 = %2(%x)
+//   (%3, %4)
+// }
+// \endcode
+class CastCanonicalizer : public ExprMutator {
+ public:
+  Expr VisitExpr_(const CallNode* call) {
+    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
+      auto pattern = fpattern[GetRef<Op>(opnode)];
+      if (pattern <= kBroadcast) {
+        Array<Expr> call_args = call->args;
+        bool unchanged = true;
+        for (size_t i = 0; i < call_args.size(); ++i) {
+          Expr arg = call_args[i];
+          Expr new_arg = GetNewCallArg(arg);
+          if (!arg.same_as(new_arg)) {
+            call_args.Set(i, new_arg);
+            unchanged = false;
+          }
+        }
+        if (unchanged) {
+          return GetRef<Expr>(call);
+        }
+        return CallNode::make(call->op, call_args, call->attrs, call->type_args);
+      }
+    }
+
+    Expr new_expr = ExprMutator::VisitExpr_(call);
+    return new_expr;
+  }
+
+ private:
+  std::unordered_map<const Node*, size_t> ref_counter_;
+
+  Expr GetNewCallArg(const Expr& e) {
+    // if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
+    static auto& cast = Op::Get("cast");
+    Expr new_expr = this->VisitExpr(e);
+
+    if (const CallNode* call = e.as<CallNode>()) {
+      if (call->op.same_as(cast)) {
+        auto attrs = call->attrs.as<CastAttrs>();
+        const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
+        CHECK(from_type);
+
+        if (from_type->dtype.bits() < attrs->dtype.bits()) {
+          if (++ref_counter_[call] > 1) {
+            const CallNode* new_call = new_expr.as<CallNode>();
+            CHECK(new_call);
+            CHECK(new_call->op.same_as(cast));
+            return CallNode::make(new_call->op, new_call->args, new_call->attrs,
+                                  new_call->type_args);
+          }
+        }
+      }
+    }
+    return new_expr;
+  }
+};
+
+Expr CanonicalizeCast(const Expr& e) {
+  return CastCanonicalizer().Mutate(e);
+}
+
+namespace transform {
+
+Pass CanonicalizeCast() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(CanonicalizeCast(f));
+  };
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CanonicalizeCast")
+.set_body_typed(CanonicalizeCast);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
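
To see the rewrite the file comment describes, a quick sketch of the pass applied to the shared-upcast pattern (assumes this commit's Python bindings; the program mirrors the \code example above):

import tvm.relay as relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform

# One upcast feeding two elemwise ops, as in the comment's example.
x = relay.var("x", shape=(4,), dtype="int8")
c = relay.cast(x, dtype="float32")
f = relay.Function([x], relay.Tuple([relay.exp(c), relay.log(c)]))

mod = _module.Module.from_expr(f)
with _transform.PassContext(opt_level=3):
    mod = _transform.Sequential([_transform.InferType(),
                                 _transform.CanonicalizeCast()])(mod)
# Each branch now holds its own cast, so a later FuseOps can keep the
# int8 tensor as the boundary between the fused functions.
print(mod)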
tests/python/relay/test_pass_canonicalize_cast.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.relay as relay
+import tvm.relay.module as _module
+import tvm.relay.transform as _transform
+
+
+def test_canonicalize_cast():
+    def before(data, conv_weight, bias1, bias2):
+        x = relay.nn.conv2d(data, conv_weight,
+                            channels=16,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            out_dtype="int8")
+        x1 = relay.cast(x, dtype="int32")
+        y1 = relay.add(x1, bias1)
+        y2 = relay.add(x1, bias2)
+        y = relay.add(y1, y2)
+        return relay.Function([data, conv_weight, bias1, bias2], y)
+
+    def expected(data, conv_weight, bias1, bias2):
+        x = relay.nn.conv2d(data, conv_weight,
+                            channels=16,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            out_dtype="int8")
+        x1 = relay.cast(x, dtype="int32")
+        x2 = relay.cast(x, dtype="int32")
+        y1 = relay.add(x1, bias1)
+        y2 = relay.add(x2, bias2)
+        y = relay.add(y1, y2)
+        return relay.Function([data, conv_weight, bias1, bias2], y)
+
+    def check(shape):
+        data = relay.var("data", shape=shape, dtype="int8")
+        conv_weight = relay.var("weight")
+        bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
+        bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
+        y = before(data, conv_weight, bias1, bias2)
+        mod = _module.Module.from_expr(y)
+        seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+                                     _transform.InferType()])
+        with _transform.PassContext(opt_level=3):
+            mod = seq(mod)
+        y = mod[mod.entry_func.name_hint]
+        y_expected = expected(data, conv_weight, bias1, bias2)
+        y_expected = relay.ir_pass.infer_type(y_expected)
+        assert relay.ir_pass.alpha_equal(y, y_expected)
+
+    check((1, 16, 7, 7))
+
+
+if __name__ == '__main__':
+    test_canonicalize_cast()
