Commit e23dae5

lhutton1 authored and pfk-beta committed
[microNPU] Remove identity operations between non-compute operations (apache#10411)
Builds upon the work in apache#10254 to remove identity operations sandwiched between two non-compute operations (reshape/strided slice; concatenate is handled differently), under certain conditions. Specifically, an identity operation is not removed when the dimensionality is reduced between the two non-compute operations, since non-congruent values would otherwise be accessed incorrectly. For example,

```
strided_slice(dims=4) -> identity -> reshape(dims=4)
```

becomes

```
strided_slice -> reshape
```

but

```
strided_slice(dims=4) -> identity -> reshape(dims=2)
```

remains as

```
strided_slice -> identity -> reshape
```

Change-Id: Ie28ba384fcb3230d6f4651c0c19e2b9526ebcc42
1 parent f5178d4 commit e23dae5
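To make the two sequences above concrete, here is a minimal sketch (not part of the patch) that builds them in Relay, using the shapes from the new test case. It assumes the Ethos-U test helper `infra.make_ethosu_identity` from `tests/python/contrib/test_ethosu/infra.py` (used in the test diff below) is importable and stands in for the NPU identity operation.

```python
# Sketch only: the two sequences from the commit message, built with standard
# Relay ops plus the Ethos-U test-infra identity helper.
from tvm import relay

# Assumed import path for the test helpers (tests/python/contrib/test_ethosu/infra.py).
from tests.python.contrib.test_ethosu import infra

x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
sliced = relay.strided_slice(x, begin=(0, 0, 0, 0), end=(1, 2, 2, 2))

# Rank preserved (4D -> 4D): the sandwiched identity can be removed.
removable = relay.reshape(infra.make_ethosu_identity(sliced), newshape=(1, 1, 1, 8))

# Rank reduced (4D -> 3D): the identity is kept so values are accessed correctly.
kept = relay.reshape(infra.make_ethosu_identity(sliced), newshape=(1, 2, 4))
```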

File tree

3 files changed, +97 -12 lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -347,6 +347,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
     mod = OutlineCompilerFunctions("ethos-u")(mod)
     mod = LegalizeEthosU()(mod)
     mod = LUTsOptimizer()(mod)
+    mod = relay.transform.InferType()(mod)
     mod = IdentityOptimizer()(mod)
     mod = LayoutOptimizer()(mod)
     mod = relay.transform.InferType()(mod)
```

src/relay/backend/contrib/ethosu/codegen.cc

Lines changed: 54 additions & 7 deletions
```diff
@@ -115,24 +115,33 @@ class RemoveRedundantIdentities : public MixedModeMutator {
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
     Call call = Downcast<Call>(post);

-    // only consider rewrite if current op is an NPU compute op.
+    // don't consider rewrite if current op is an identity or concatenate.
     if (!call->op->IsInstance<OpNode>()) {
       return post;
     }
     const auto* op = call->op.as<OpNode>();
     std::string op_name = op->name;
-    if (op_name.substr(0, 15) != "contrib.ethosu." || op_name == "contrib.ethosu.identity") {
+    if (op_name == "contrib.ethosu.identity" || op_name == "concatenate") {
       return post;
     }

     // check if we can rewrite parent identity operations to current call.
     bool needs_rewrite = false;
     Array<Expr> new_args;
     for (const auto& arg : call->args) {
-      if (const auto* parent_callnode = arg.as<CallNode>()) {
+      Expr current_arg = arg;
+
+      // expand tuple to get parent op if we run into one - nested tuples are not supported.
+      if (const auto* tuple_get_item = arg.as<TupleGetItemNode>()) {
+        const auto* tuple = tuple_get_item->tuple.as<TupleNode>();
+        current_arg = tuple->fields[tuple_get_item->index];
+      }
+
+      if (const auto* parent_callnode = current_arg.as<CallNode>()) {
         if (const auto* parent_op = parent_callnode->op.as<OpNode>()) {
           Call parent_call = GetRef<Call>(parent_callnode);
-          if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call)) {
+          if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call) &&
+              CheckIdentityBetweenTransformOperations(call, parent_call)) {
            needs_rewrite = true;
            new_args.push_back(parent_call->args[0]);
            continue;
@@ -143,7 +152,10 @@ class RemoveRedundantIdentities : public MixedModeMutator {
     }

     if (needs_rewrite) {
-      return Call(call->op, new_args, call->attrs, call->type_args);
+      Call new_call = Call(call->op, new_args, call->attrs, call->type_args);
+      // since we are only removing an identity, we know the type information has not changed
+      new_call->checked_type_ = call->checked_type_;
+      return new_call;
     }
     return post;
   }
@@ -156,6 +168,41 @@ class RemoveRedundantIdentities : public MixedModeMutator {
     bool has_no_activation = attrs->activation == "NONE";
     return does_not_requantize && has_no_activation;
   }
+
+  bool CheckIdentityBetweenTransformOperations(const Call& call, const Call& identity_call) {
+    const auto* op = call->op.as<OpNode>();
+    std::vector<std::string> nc_ops = {"reshape", "strided_slice"};
+
+    if (op && (std::find(nc_ops.begin(), nc_ops.end(), op->name) != nc_ops.end())) {
+      // check if the parent to identity operation is also a non-compute operation,
+      // if it isn't we can safely remove the identity in question by returning true.
+      const auto* identity_arg = identity_call->args[0].as<CallNode>();
+      if (!identity_arg) {
+        return true;
+      }
+      const auto* identity_arg_op = identity_arg->op.as<OpNode>();
+      if (!identity_arg_op ||
+          !(std::find(nc_ops.begin(), nc_ops.end(), identity_arg_op->name) != nc_ops.end())) {
+        return true;
+      }
+
+      const auto* call_tt = call->checked_type_.as<TensorTypeNode>();
+      const auto* identity_arg_tt = identity_arg->checked_type_.as<TensorTypeNode>();
+      CHECK(call_tt && identity_arg_tt)
+          << "InferType should be run before RemoveRedundantIdentities";
+
+      // we can only remove the identity operation if the second non-compute operation
+      // in the sequence does not reduce the dimensionality of the output to the first
+      // non-compute operation. Doing so could lead to data being accessed incorrectly
+      // by the subsequent compute operation due to the reduction in dimensionality.
+      size_t first_transform_op_dims = identity_arg_tt->shape.size();
+      size_t second_transform_op_dims = call_tt->shape.size();
+      if (second_transform_op_dims < first_transform_op_dims) {
+        return false;
+      }
+    }
+    return true;
+  }
 };

 /*!
@@ -177,8 +224,8 @@ tvm::transform::Pass IdentityOptimizer() {
        }
        return mod;
      };
-  return tvm::transform::CreateModulePass(pass_func, 0,
-                                          "relay.backend.contrib.ethos-u.IdentityOptimizer", {});
+  return tvm::transform::CreateModulePass(
+      pass_func, 0, "relay.backend.contrib.ethos-u.IdentityOptimizer", {"InferType"});
 }

 TVM_REGISTER_GLOBAL("relay.ext.ethos-u.IdentityOptimizer").set_body_typed(IdentityOptimizer);
```
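To make the rule in `CheckIdentityBetweenTransformOperations` concrete, here is a hypothetical stand-alone helper (not in the patch) that mirrors the rank comparison, using the example shapes from the commit message.

```python
# Hypothetical illustration of the rank rule: an identity sandwiched between two
# non-compute ops is only removable if the second op does not reduce the rank
# produced by the first.
def identity_removable(first_op_rank: int, second_op_rank: int) -> bool:
    return second_op_rank >= first_op_rank


assert identity_removable(4, 4)      # strided_slice(4D) -> identity -> reshape(4D): removed
assert not identity_removable(4, 2)  # strided_slice(4D) -> identity -> reshape(2D): kept
```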

tests/python/contrib/test_ethosu/test_identity_optimizer.py

Lines changed: 42 additions & 5 deletions
```diff
@@ -179,12 +179,14 @@ def test_many_output_identity():
     def get_graph(get_expected=False):
         x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
         x = relay.reshape(x, newshape=(1, 1, 4, 4))
-        identity = infra.make_ethosu_identity(x)
+        if not get_expected:
+            x = infra.make_ethosu_identity(x)
         outputs = []
         for _ in range(4):
-            ifm = x if get_expected else identity
-            outputs.append(infra.make_ethosu_unary_elementwise(ifm, 4, "ABS"))
-        outputs.append(relay.strided_slice(identity, begin=(0, 0, 0, 0), end=(1, 1, 4, 4)))
+            outputs.append(infra.make_ethosu_unary_elementwise(x, 4, "ABS"))
+        ss = relay.strided_slice(x, begin=(0, 0, 0, 0), end=(1, 1, 4, 4))
+        identity_2 = infra.make_ethosu_identity(ss)
+        outputs.append(identity_2)
         out = relay.concatenate(outputs, axis=0)
         return relay.Function(relay.analysis.free_vars(out), out)

@@ -220,7 +222,8 @@ def test_identity_removal_with_multiple_transform_ops():
     def get_graph(get_expected=False):
         x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
         x = relay.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 2, 2, 2])
-        x = infra.make_ethosu_identity(x)
+        if not get_expected:
+            x = infra.make_ethosu_identity(x)
         x = relay.reshape(x, newshape=(1, 1, 1, 8))
         if not get_expected:
             x = infra.make_ethosu_identity(x)
@@ -267,6 +270,25 @@ def get_graph(get_expected=False):
     _assert_structural_equal(actual, expected)


+def test_multiple_transform_ops_with_reduction_in_dimensionality():
+    """Removal of an identity operation between two transform operations is usually okay.
+    However, if the dimensionality of the input is reduced by the second transformation
+    operation, it can lead to an output mismatch. Checking that the pass doesn't remove
+    an identity given this case."""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.strided_slice(x, begin=(0, 0, 0, 0), end=(1, 2, 2, 2))
+        x = infra.make_ethosu_identity(x)
+        x = relay.reshape(x, newshape=(1, 2, 4))
+        x = infra.make_ethosu_identity(x)
+        return relay.Function(relay.analysis.free_vars(x), x)
+
+    actual = _optimize(get_graph())
+    expected = _optimize(get_graph(), optimize=False)
+    _assert_structural_equal(actual, expected)
+
+
 def test_identity_optimizer_runs_in_compilation_pipeline():
     """Checks that the identity optimization pass is run as part of the NPU compilation pipeline."""

@@ -320,3 +342,18 @@ def model(x):
         return y

     _compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")
+
+
+def test_multiple_transform_ops_same_output():
+    """Check case of identity removal between transform ops and
+    then without, making sure they have the same output."""
+    ifm_shape = (1, 2, 2, 4)
+
+    @tf.function
+    def model(x):
+        x = tf.reshape(x, (1, 1, 4, 4))
+        x = tf.slice(x, (0, 0, 0, 0), (1, 1, 4, 3))
+        x = tf.reshape(x, (12,))
+        return x
+
+    _compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")
```
