
Commit 22b65bc

MasterJH5574 and spectrometerHBH authored and committed
[Unity][Op] Group normalization (#14194)
* [TOPI] Group normalization

As more and more ML models nowadays contain the group normalization computation, we find it beneficial to introduce this op at the TOPI level. It enables us to optimize the group normalization operation as a whole in a more convenient way. This PR introduces the group normalization op to TOPI. The group norm operation was introduced in https://arxiv.org/abs/1803.08494. The implementation uses tuple reduction, the same as the implementation of layer norm. With tuple reduction, the corresponding generated TIR function can be optimized by cross-thread reduction or rfactor through MetaSchedule.

Prior to this PR, the group normalization operations in frontend models were translated to a series of operations, which made it inconvenient to optimize the group norm op as a whole. With the TOPI implementation of group norm introduced by #14193, we can now use it to legalize the high-level group norm op and optimize it using cross-thread reduction or rfactor via MetaSchedule.

Co-authored-by: Bohan Hou <[email protected]>
1 parent 1f04221 commit 22b65bc
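
For reference, the semantics described above can be sketched in NumPy: split the channel axis into num_groups groups, normalize each group over the grouped channels and the remaining spatial axes, then apply the per-channel gamma and beta. This is only an illustrative NCHW sketch of the math (the function name and shapes are hypothetical), not the tuple-reduction TOPI implementation added by this commit.

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, epsilon=1e-5):
    # Illustrative NCHW reference: reshape (N, C, H, W) -> (N, G, C//G, H, W),
    # normalize each group over (C//G, H, W), then scale/shift per channel.
    n, c, h, w = x.shape
    assert c % num_groups == 0
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    return normed * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)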

11 files changed: +638 additions, -57 deletions

include/tvm/relax/attrs/nn.h

Lines changed: 21 additions & 0 deletions
@@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   }
 };  // struct LayerNormAttrs
 
+/*! \brief Attributes used in group_norm operator */
+struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
+  int num_groups;
+  int channel_axis;
+  Array<Integer> axes;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") {
+    TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into.");
+    TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
+    TVM_ATTR_FIELD(axes).describe(
+        "The axes along which the normalization is applied (excluding the channel axis).");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center).describe(
+        "Indicating if the beta offset will be added to the normalized tensor.");
+    TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
+  }
+};  // struct GroupNormAttrs
+
 /*! \brief Attributes used in dropout operator */
 struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   double rate;

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 20 additions & 34 deletions
@@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
         )
 
     def _group_norm(self, node: fx.node.Node) -> relax.Var:
-        # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
-        #                    affine=True, device=None, dtype=None)
+        import torch  # type: ignore
+
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
-        num_groups = module.num_groups
-        num_channels = module.num_channels
-        eps = module.eps
-        affine = module.affine
 
-        shape = self.shape_of(x)
-        assert len(shape) == 4
-        N, C, H, W = shape[0], shape[1], shape[2], shape[3]
-        assert C == num_channels
-        assert C % num_groups == 0
-        grouped_x = self.block_builder.emit(
-            relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
-        )
-        mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True))
-        sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
-        square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
-        sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True))
-        var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value)
-        var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
-        std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
-        norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))
-
-        if affine:
-            weight = self.params[module.weight]
-            bias = self.params[module.bias]
-            weight_reshape = self.block_builder.emit(
-                relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1))
-            )
-            bias_reshape = self.block_builder.emit(
-                relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
+        if module.affine:
+            gamma = self.params[module.weight]
+            beta = self.params[module.bias]
+        else:
+            gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type)
+            beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type)
+
+        dim = len(self.shape_of(x))
+        return self.block_builder.emit(
+            relax.op.nn.group_norm(
+                x,
+                gamma,
+                beta,
+                num_groups=module.num_groups,
+                channel_axis=1,
+                axes=list(range(2, dim)),
+                epsilon=module.eps,
             )
-            norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape))
-            norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape))
-        return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
+        )
 
     def _embedding(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
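
A hedged sketch of how the updated translator above is exercised from PyTorch. The wrapper module, shapes, and dtype are illustrative assumptions; from_fx is the frontend entry point used by the tests in this commit.

# Sketch only: trace a module containing torch.nn.GroupNorm and import it.
# The FX translator now emits a single relax.op.nn.group_norm call for it.
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

class Wrapper(torch.nn.Module):  # hypothetical wrapper module
    def __init__(self):
        super().__init__()
        self.gn = torch.nn.GroupNorm(num_groups=3, num_channels=3, eps=1e-5)

    def forward(self, x):
        return self.gn(x)

graph_model = fx.symbolic_trace(Wrapper())
mod = from_fx(graph_model, [((1, 3, 10, 10), "float32")])
mod.show()  # main should contain R.nn.group_norm(..., num_groups=3, channel_axis=1, axes=[2, 3])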

python/tvm/relax/op/nn/nn.py

Lines changed: 58 additions & 0 deletions
@@ -527,6 +527,64 @@ def layer_norm(
     return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale)  # type: ignore
 
 
+def group_norm(
+    data: Expr,
+    gamma: Expr,
+    beta: Expr,
+    num_groups: int,
+    channel_axis: int,
+    axes: Union[int, List[int]],
+    epsilon: float = 1e-5,
+    center: bool = True,
+    scale: bool = True,
+) -> Expr:
+    r"""
+    Group normalization (Yuxin Wu et al., 2018).
+    Applies group normalization to the n-dimensional input array. This operator
+    takes an n-dimensional input array, first separates the input array into
+    groups along the channel axis, and then applies layer normalization to each group.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        Input to which group_norm will be applied.
+
+    gamma : relax.Expr
+        The gamma scale factor.
+
+    beta : relax.Expr
+        The beta offset factor.
+
+    num_groups : int
+        Number of groups to separate the channels into.
+
+    channel_axis : int
+        The index of the channel axis in the input data.
+
+    axes : Union[int, List[int]]
+        The axes along which the normalization is applied (excluding the group axis).
+
+    epsilon : float
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool
+        Indicating if the beta offset will be added to the normalized tensor.
+
+    scale : bool
+        Indicating if the gamma scale will be multiplied.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(axes, int):
+        axes = [axes]
+    return _ffi_api.group_norm(  # type: ignore
+        data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale
+    )
+
+
 def dropout(data: Expr, rate: float = 0.5) -> Expr:
     """Applies the dropout operation to the input tensor.
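
A hedged usage sketch of the new Python binding; the variable names, shapes, and the choice of num_groups=2 are illustrative assumptions.

# Sketch only: build a small Relax function that calls the new operator.
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((1, 4, 8, 8), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((4,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((4,), "float32"))
with bb.function("main", [x, gamma, beta]):
    # Normalize over groups of 2 channels together with the spatial axes [2, 3].
    gv = bb.emit(
        relax.op.nn.group_norm(
            x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3], epsilon=1e-5
        )
    )
    bb.emit_func_output(gv)
mod = bb.get()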

python/tvm/relax/transform/legalize_ops/nn.py

Lines changed: 14 additions & 0 deletions
@@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.group_norm")
+def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.nn.group_norm,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.attrs.num_groups,
+        call.attrs.channel_axis,
+        call.attrs.axes,
+        call.attrs.epsilon,
+    )
+
+
 @register_legalize("relax.nn.dropout")
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
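
With this registration in place, the high-level op can be lowered to the TOPI-backed TIR implementation, which is what cross-thread reduction or rfactor in MetaSchedule then operates on. A minimal, hedged sketch follows; the module name, shapes, and attribute values are illustrative assumptions.

# Sketch only: legalize relax.nn.group_norm into a call_tir of the
# TIR PrimFunc generated from topi.nn.group_norm.
import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((1, 4, 8, 8), "float32"),
        gamma: R.Tensor((4,), "float32"),
        beta: R.Tensor((4,), "float32"),
    ) -> R.Tensor((1, 4, 8, 8), "float32"):
        gv: R.Tensor((1, 4, 8, 8), "float32") = R.nn.group_norm(
            x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3], epsilon=1e-5
        )
        return gv

lowered = relax.transform.LegalizeOps()(Module)
lowered.show()  # a "group_norm" TIR PrimFunc now appears alongside main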

src/relax/op/nn/nn.cc

Lines changed: 83 additions & 0 deletions
@@ -233,6 +233,89 @@ TVM_REGISTER_OP("relax.nn.layer_norm")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);
 
+/* relax.nn.group_norm */
+TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
+
+Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
+                Array<Integer> axes, double epsilon, bool center, bool scale) {
+  ObjectPtr<GroupNormAttrs> attrs = make_object<GroupNormAttrs>();
+  attrs->num_groups = num_groups;
+  attrs->channel_axis = channel_axis;
+  attrs->axes = std::move(axes);
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+
+  static const Op& op = Op::Get("relax.nn.group_norm");
+  return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm);
+
+StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) {
+  Op op = Downcast<Op>(call->op);
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<GroupNormAttrs>();
+
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  int channel_axis = -1;
+  if (!data_sinfo->IsUnknownNdim()) {
+    channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis);
+    std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
+    // channel_axis must not be in axes.
+    if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op
+                       << " expects that channel_axis must not be in axes, but got channel_axis: "
+                       << channel_axis << ", axes: " << attrs->axes);
+    }
+  }
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << op << " expects that data must be float, but got " << data_sinfo->dtype);
+  }
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape != nullptr && channel_axis != -1 &&
+      analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << op << " expects that the size of channel_axis must be divisible by "
+                     << attrs->num_groups << ", but got " << data_shape->values[channel_axis]);
+  }
+  for (int i = 1; i < static_cast<int>(op->arguments.size()); ++i) {
+    if (input_sinfo[i]->dtype != data_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op << " expects that all inputs must have the same dtype, but got "
+                       << input_sinfo[i]->dtype << " and " << data_sinfo->dtype);
+    } else if (input_sinfo[i]->ndim != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op << " expects that all inputs must have ndim=1, but got "
+                       << input_sinfo[i]->ndim);
+    } else if (channel_axis != -1) {
+      const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>();
+      if (shape != nullptr && data_shape != nullptr) {
+        PrimExpr channel_size = data_shape->values[channel_axis];
+        PrimExpr input_size = shape->values[0];
+        if (analyzer->CanProve(channel_size != input_size)) {
+          ctx->ReportFatal(Diagnostic::Error(call)
+                           << op << " expects that the size of input " << i
+                           << " must be equal to the size of channel_axis, but got " << input_size
+                           << " and " << channel_size);
+        }
+      }
+    }
+  }
+  return data_sinfo;
+}
+
+TVM_REGISTER_OP("relax.nn.group_norm")
+    .set_attrs_type<GroupNormAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "Input to which group_norm will be applied.")
+    .add_argument("gamma", "Tensor", "The gamma scale factor.")
+    .add_argument("beta", "Tensor", "The beta offset factor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm);
+
 /* relax.nn.dropout */
 TVM_REGISTER_NODE_TYPE(DropoutAttrs);

src/relax/op/nn/nn.h

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_
 Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
                 bool scale);
 
+/*! \brief Compute group normalization. */
+Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
+                Array<Integer> axes, double epsilon, bool center, bool scale);
+
 /*!
  * \brief Applies the dropout operation to the input tensor.
  * \param data The input data to the operator.

tests/python/relax/test_ast_printer.py

Lines changed: 2 additions & 2 deletions
@@ -362,7 +362,7 @@ def f(
         y: R.Tensor(("m",), "float32"),
         r: R.Tensor(dtype="int64"),
     ) -> R.Object:
-        m = T.var("int64")
+        m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
         w: R.Tensor = R.multiply(z, z)
         q: R.Tensor(ndim=2) = R.add(w, w)
@@ -431,7 +431,7 @@ def test_call_tir():
     # also from test_parser
     @R.function
     def foo(x: R.Tensor(("m", "n"), "float32")):
-        m, n = T.var("int64"), T.var("int64")
+        m, n = T.int64(), T.int64()
         gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
         return gv0

tests/python/relax/test_frontend_from_fx.py

Lines changed: 11 additions & 21 deletions
@@ -708,29 +708,19 @@ def main(
             w1: R.Tensor((3,), dtype="float32"),
             w2: R.Tensor((3,), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape(
-                    input_1, (1, 3, 1, 10, 10)
-                )
-                lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean(
-                    lv, axis=[2, 3, 4], keepdims=True
-                )
-                lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1)
-                lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2)
-                lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum(
-                    lv3, axis=[2, 3, 4], keepdims=True
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    num_groups=3,
+                    channel_axis=1,
+                    axes=[2, 3],
+                    epsilon=1.0000000000000001e-05,
+                    center=True,
+                    scale=True,
                 )
-                lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0))
-                lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05))
-                lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6)
-                lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7)
-                lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1))
-                lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1))
-                lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9)
-                lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10)
-                lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10))
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
