
Commit 7d911f4

bindogvinx13 authored and committed
[Relay][Op] Add instance norm op (#4004)
* [Relay][Op] Add instance norm op
* mend
1 parent 36201fe · commit 7d911f4

File tree

include/tvm/relay/attrs/nn.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/onnx.py
python/tvm/relay/op/nn/nn.py
src/relay/op/nn/nn.cc
src/relay/pass/simplify_inference.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/onnx/test_forward.py

8 files changed: +286, -2 lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 23 additions & 0 deletions
@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
 };  // struct BatchNormAttrs


+/*! \brief Attributes used in instance_norm operator */
+struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
+    TVM_ATTR_FIELD(axis)
+      .describe("Specify which shape axis denotes the channel.")
+      .set_default(1);
+    TVM_ATTR_FIELD(epsilon)
+      .describe("Small float added to variance to avoid dividing by zero")
+      .set_default(1e-5);
+    TVM_ATTR_FIELD(center).set_default(true)
+      .describe("If true, add offset of beta to normalized tensor; "
+                "otherwise, beta is ignored.");
+    TVM_ATTR_FIELD(scale).set_default(true)
+      .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+  }
+};  // struct InstanceNormAttrs
+
+
 /*! \brief Attributes used in layer_norm operator */
 struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   int axis;

python/tvm/relay/frontend/mxnet.py

Lines changed: 9 additions & 0 deletions
@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
     return _op.nn.batch_norm(*inputs, **new_attrs)


+def _mx_instance_norm(inputs, attrs):
+    assert len(inputs) == 3
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis", 1)
+    new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
+    return _op.nn.instance_norm(*inputs, **new_attrs)
+
+
 def _mx_layer_norm(inputs, attrs):
     assert len(inputs) == 3
     if attrs.get_bool("output_mean_var", False):

@@ -1133,6 +1141,7 @@ def _mx_one_hot(inputs, attrs):
     "Dropout" : _mx_dropout,
     "BatchNorm" : _mx_batch_norm,
     "BatchNorm_v1" : _mx_batch_norm,
+    "InstanceNorm" : _mx_instance_norm,
     "LayerNorm" : _mx_layer_norm,
     "LRN" : _mx_lrn,
     "L2Normalization" : _mx_l2_normalize,

python/tvm/relay/frontend/onnx.py

Lines changed: 10 additions & 1 deletion
@@ -176,6 +176,15 @@ def _impl_v1(cls, inputs, attr, params):
         return out[0]


+class InstanceNorm(OnnxOpConverter):
+    """ Operator converter for InstanceNorm.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(op_name='instance_norm')(inputs, attr, params)
+
+
 class Conv(OnnxOpConverter):
     """ Operator converter for Conv.
     """

@@ -999,7 +1008,7 @@ def _get_convert_map(opset):
         'GlobalAveragePool': Renamer('global_avg_pool2d'),
         'GlobalMaxPool': Renamer('global_max_pool2d'),
         'BatchNormalization': BatchNorm.get_converter(opset),
-        # 'InstanceNormalization'
+        'InstanceNormalization': InstanceNorm.get_converter(opset),
         # 'LpNormalization'
         'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
         'Flatten': Flatten.get_converter(opset),
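
On the ONNX side the converter is a one-line AttrCvt rename: ONNX's InstanceNormalization carries only an epsilon attribute, and the Relay op uses the same name, so it passes through unchanged. A compact sketch of the import path (shapes and names are illustrative; the full end-to-end check is in the ONNX test added below):

from onnx import helper, TensorProto
from tvm import relay

node = helper.make_node('InstanceNormalization', ['x', 'gamma', 'beta'], ['y'], epsilon=1e-5)
graph = helper.make_graph(
    [node], "instance_norm_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4, 5]),
            helper.make_tensor_value_info("gamma", TensorProto.FLOAT, [3]),
            helper.make_tensor_value_info("beta", TensorProto.FLOAT, [3])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4, 5])])
model = helper.make_model(graph, producer_name='instance_norm_example')

# from_onnx maps the InstanceNormalization node onto relay.nn.instance_norm.
mod, params = relay.frontend.from_onnx(model, shape={"x": (2, 3, 4, 5)})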

python/tvm/relay/op/nn/nn.py

Lines changed: 68 additions & 1 deletion
@@ -935,6 +935,73 @@ def batch_norm(data,
     return TupleWrapper(result, 3)


+def instance_norm(data,
+                  gamma,
+                  beta,
+                  axis=1,
+                  epsilon=1e-5,
+                  center=True,
+                  scale=True):
+    r"""
+    Instance Normalization (Ulyanov et al., 2016)
+    Applies instance normalization to the n-dimensional input array.
+
+    .. math::
+
+        out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
+            * gamma + beta
+
+    Instance normalization is similar to batch normalization, but unlike
+    batch normalization, the mean and variance are calculated per dimension
+    separately for each object (instance) in a mini-batch, not over the whole
+    batch. The same normalization is applied at both test and train time.
+
+    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+    have shape *(k,)*.
+
+    The parameter ``axis`` specifies which axis of the input shape denotes
+    the 'channel'. The default is 1. Specifying -1 sets the channel axis
+    to be the last item in the input shape.
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        Input to which instance_norm will be applied.
+
+    gamma : tvm.relay.Expr
+        The gamma scale factor.
+
+    beta : tvm.relay.Expr
+        The beta offset factor.
+
+    axis : int, optional, default=1
+        Specify along which shape axis the channel is specified.
+
+    epsilon : double, optional, default=1e-5
+        Small float added to variance to avoid dividing by zero.
+
+    center : boolean, optional, default=True
+        If True, add offset of beta to the normalized tensor; if False,
+        beta is ignored.
+
+    scale : boolean, optional, default=True
+        If True, multiply by gamma. If False, gamma is not used.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The normalized data.
+
+    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+        https://arxiv.org/abs/1607.08022
+    """
+    return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)
+
+
 def layer_norm(data,
                gamma,
                beta,

@@ -964,7 +1031,7 @@ def layer_norm(data,
     Parameters
     ----------
     data : tvm.relay.Expr
-        Input to which batch_norm will be applied.
+        Input to which layer_norm will be applied.

     gamma : tvm.relay.Expr
         The gamma scale factor.
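
With the Python wrapper registered, the op can be used like any other Relay operator. A minimal usage sketch (NCHW input with channel axis 1; the executor and target choices follow the pattern of the tests below):

import numpy as np
import tvm
from tvm import relay

shape = (2, 3, 4, 5)                          # (N, C, H, W)
data = relay.var("data", shape=shape)
gamma = relay.var("gamma", shape=(shape[1],))
beta = relay.var("beta", shape=(shape[1],))

out = relay.nn.instance_norm(data, gamma, beta, axis=1, epsilon=1e-5)
func = relay.Function([data, gamma, beta], out)

x = np.random.uniform(size=shape).astype("float32")
g = np.random.uniform(size=(shape[1],)).astype("float32")
b = np.random.uniform(size=(shape[1],)).astype("float32")

# The debug interpreter evaluates the expression directly on CPU.
intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate(func)(x, g, b)
print(res.asnumpy().shape)                    # (2, 3, 4, 5)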

src/relay/op/nn/nn.cc

Lines changed: 70 additions & 0 deletions
@@ -640,6 +640,76 @@ axis to be the last item in the input shape.
 .add_type_rel("BatchNorm", BatchNormRel);


+// instance_norm
+TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);
+
+bool InstanceNormRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
+  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
+  CHECK(axis >= 0 && axis < (int)data->shape.size());
+  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+
+  return true;
+}
+
+Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
+                      bool center, bool scale) {
+  auto attrs = make_node<InstanceNormAttrs>();
+  attrs->axis = axis;
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+  static const Op& op = Op::Get("nn.instance_norm");
+  return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.instance_norm")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.instance_norm")
+.describe(R"code(Instance Normalization (Ulyanov et al., 2016)
+Applies instance normalization to the n-dimensional input array.
+
+.. math::
+
+    out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
+        * gamma + beta
+
+Instance normalization is similar to batch normalization, but unlike
+batch normalization, the mean and variance are calculated per dimension
+separately for each object (instance) in a mini-batch, not over the whole
+batch. The same normalization is applied at both test and train time.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+have shape *(k,)*.
+
+The parameter ``axis`` specifies which axis of the input shape denotes
+the 'channel'. The default is 1. Specifying -1 sets the channel axis
+to be the last item in the input shape.
+
+.. note::
+
+    This operator can be optimized away for inference.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.InstanceNormAttrs")
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
+.add_argument("gamma", "Tensor", "The gamma scale factor.")
+.add_argument("beta", "Tensor", "The beta offset factor.")
+.set_support_level(1)
+.add_type_rel("InstanceNorm", InstanceNormRel);
+
+
 // layer_norm
 TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
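
The type relation above fixes gamma and beta to rank-1 tensors of the channel size and gives the output the same shape and dtype as the input. A quick sanity sketch from Python (assuming the standard InferType pass):

from tvm import relay

data = relay.var("data", shape=(2, 3, 4, 5))
gamma = relay.var("gamma", shape=(3,))
beta = relay.var("beta", shape=(3,))

out = relay.nn.instance_norm(data, gamma, beta, axis=1)
mod = relay.Module.from_expr(relay.Function([data, gamma, beta], out))
mod = relay.transform.InferType()(mod)

# The checked return type matches the input: Tensor[(2, 3, 4, 5), float32].
print(mod["main"].ret_type)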

src/relay/pass/simplify_inference.cc

Lines changed: 40 additions & 0 deletions
@@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   return out;
 }

+
+Expr InstanceNormToInferUnpack(const Attrs attrs,
+                               Expr data,
+                               Expr gamma,
+                               Expr beta,
+                               Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  const auto param = attrs.as<InstanceNormAttrs>();
+  CHECK(param);
+
+  int ndim = ttype->shape.size();
+  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
+  Array<Integer> reduced_axes;
+  for (int i = 1; i < ndim; ++i) {
+    if (i != axis)
+      reduced_axes.push_back(i);
+  }
+
+  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr mean = Mean(data, reduced_axes, true, false);
+  Expr var = Variance(data, mean, reduced_axes, true, false);
+  Expr denom = Sqrt(Add(var, epsilon));
+  Expr out = Divide(Subtract(data, mean), denom);
+
+  if (param->scale) {
+    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
+  }
+  if (param->center) {
+    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
+  }
+  return out;
+}
+
+
 class InferenceSimplifier : public ExprMutator {
  public:
  Expr VisitExpr_(const TupleGetItemNode* n) final {

@@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator {

  Expr VisitExpr_(const CallNode* n) {
    static const Op& batch_norm = Op::Get("nn.batch_norm");
+    static const Op& instance_norm = Op::Get("nn.instance_norm");
    static const Op& layer_norm = Op::Get("nn.layer_norm");
    auto new_n = ExprMutator::VisitExpr_(n);
    if (n->op.same_as(batch_norm)) {

@@ -124,6 +160,10 @@ class InferenceSimplifier : public ExprMutator {
      const auto* call = new_n.as<CallNode>();
      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                    call->args[2], n->args[0]->checked_type());
+    } else if (n->op.same_as(instance_norm)) {
+      const auto* call = new_n.as<CallNode>();
+      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
+                                       call->args[2], n->args[0]->checked_type());
    }
    return new_n;
  }
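
InstanceNormToInferUnpack lowers the op into mean, variance, a divide, and optional scale/shift, reducing over every axis except the batch axis (0) and the channel axis. A NumPy sketch of the equivalent computation (assumes the default axis=1):

import numpy as np

def instance_norm_ref(x, gamma, beta, axis=1, epsilon=1e-5):
    # Reduce over all axes except batch (0) and channel (axis), keeping dims
    # like the Mean/Variance calls above.
    reduce_axes = tuple(i for i in range(1, x.ndim) if i != axis)
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    out = (x - mean) / np.sqrt(var + epsilon)
    # Broadcast gamma/beta along the channel axis, as ExpandBiasToMatchAxis does.
    bshape = [1] * x.ndim
    bshape[axis] = -1
    return out * gamma.reshape(bshape) + beta.reshape(bshape)

x = np.random.randn(2, 3, 4, 5).astype("float32")
gamma = np.random.randn(3).astype("float32")
beta = np.random.randn(3).astype("float32")
print(instance_norm_ref(x, gamma, beta).shape)   # (2, 3, 4, 5)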

tests/python/frontend/mxnet/test_forward.py

Lines changed: 21 additions & 0 deletions
@@ -758,6 +758,26 @@ def verify(shape, axis=1, fix_gamma=False):
     verify((2, 3, 4, 5), fix_gamma=True)


+def test_forward_instance_norm():
+    def verify(shape, axis=1, epsilon=1e-5):
+        x = np.random.uniform(size=shape).astype("float32")
+        gamma = np.random.uniform(size=(shape[axis])).astype("float32")
+        beta = np.random.uniform(size=(shape[axis])).astype("float32")
+        ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
+        mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
+        shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x, gamma, beta)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+    verify((2, 3, 4, 5))
+    verify((32, 64, 80, 64))
+    verify((8, 6, 5))
+    verify((8, 7, 6, 5, 4))
+
+
 def test_forward_layer_norm():
     def verify(shape, axis=-1):
         x = np.random.uniform(size=shape).astype("float32")

@@ -938,6 +958,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
     test_forward_sequence_mask()
     test_forward_contrib_div_sqrt_dim()
     test_forward_batch_norm()
+    test_forward_instance_norm()
     test_forward_layer_norm()
     test_forward_one_hot()
     test_forward_convolution()

tests/python/frontend/onnx/test_forward.py

Lines changed: 45 additions & 0 deletions
@@ -416,6 +416,50 @@ def test_lrn():
     verify_lrn((5, 5, 5, 5), 3, 'float32')
     verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)

+
+def verify_instance_norm(shape, axis=1):
+
+    def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
+        dims_x = len(x.shape)
+        axis = tuple(range(2, dims_x))
+        mean = np.mean(x, axis=axis, keepdims=True)
+        var = np.var(x, axis=axis, keepdims=True)
+        dim_ones = (1,) * (dims_x - 2)
+        gamma = gamma.reshape(-1, *dim_ones)
+        beta = beta.reshape(-1, *dim_ones)
+        return gamma * (x - mean) / np.sqrt(var + epsilon) + beta
+
+    x = np.random.randn(*shape).astype(np.float32)
+    gamma = np.random.randn(shape[1]).astype(np.float32)
+    beta = np.random.randn(shape[1]).astype(np.float32)
+    epsilon = 1e-5
+    y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
+
+    node = onnx.helper.make_node(
+        'InstanceNormalization',
+        inputs=['x', 'gamma', 'beta'],
+        outputs=['y'],
+        epsilon=epsilon,
+    )
+    graph = helper.make_graph([node],
+                              "instance_norm_test",
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
+                                      helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
+                                      helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
+    model = helper.make_model(graph, producer_name='instance_norm_test')
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
+        tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_instance_norm():
+    verify_instance_norm((2, 3, 4, 5))
+    verify_instance_norm((32, 64, 80, 64))
+    verify_instance_norm((8, 6, 5))
+    verify_instance_norm((8, 7, 6, 5, 4))
+
+
 def _test_upsample_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3)

@@ -1270,6 +1314,7 @@ def test_erf():
     test_matmul()
     test_gather()
     test_lrn()
+    test_instance_norm()
     test_upsample()
     test_forward_min()
     test_forward_max()