Commit f6c39b2

Ivy Zhang authored and Lucien0 committed
[BYOC-DNNL] enable conv3d->bn folding (apache#10837)
* support conv3d bn folding
* add test case for fold_scale_axis
* modify lint
* remove test cases
* unify conv2d 3d impls, and add test cases.
1 parent a7b2db8 commit f6c39b2
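
For context, the pattern this commit targets is nn.conv3d followed by nn.batch_norm: SimplifyInference lowers the batch norm to a per-channel multiply/add, and FoldScaleAxis can then fold the scale into the conv3d weights. Below is a minimal sketch of that pipeline; the shapes, variable names, and pass ordering are illustrative assumptions, not code from this commit.

import tvm
from tvm import relay

# Hypothetical conv3d + batch_norm graph (NCDHW data, OIDHW kernel).
x = relay.var("x", shape=(1, 3, 8, 8, 8), dtype="float32")
w = relay.var("w", shape=(16, 3, 3, 3, 3), dtype="float32")
gamma, beta = relay.var("gamma", shape=(16,)), relay.var("beta", shape=(16,))
mean, var = relay.var("mean", shape=(16,)), relay.var("var", shape=(16,))

y = relay.nn.conv3d(x, w, kernel_size=(3, 3, 3), channels=16, padding=(1, 1, 1),
                    data_layout="NCDHW", kernel_layout="OIDHW")
y = relay.nn.batch_norm(y, gamma, beta, mean, var)[0]
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

# SimplifyInference turns batch_norm into a per-channel multiply/add;
# FoldScaleAxis then merges the multiply into the conv3d kernel.
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.SimplifyInference(),
    relay.transform.FoldScaleAxis(),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)  # the per-channel scale should now be folded into the conv3d weights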

4 files changed: +272 -39 lines changed


src/relay/transforms/fold_scale_axis.cc

Lines changed: 79 additions & 29 deletions
@@ -29,6 +29,7 @@
 #include <tvm/relay/transform.h>
 #include <tvm/tir/data_layout.h>
 
+#include "../backend/utils.h"
 #include "../op/tensor/transform.h"
 #include "pass_utils.h"
 #include "pattern_utils.h"
@@ -492,11 +493,11 @@ RELAY_REGISTER_OP("multiply")
     .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
 
 // Consumer operators
-// Conv2D send out requirement of axis folding.
-Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
+// Conv send out requirement of axis folding.
+template <typename ATTRS>
+Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
   // TODO(tvm-team) support general data layout
   // by transforming weight
-  const auto* param = call->attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout kernel_layout(param->kernel_layout);
@@ -512,8 +513,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
   //
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (param->groups == 1 || is_depthwise_conv2d) {
+  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+  if (param->groups == 1 || is_depthwise_conv) {
     auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
     auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
     if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||  // simple layout
@@ -529,14 +530,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
 }
 
 // Conv2D consumes the scale axis during transformation.
-Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
-                          const Message& message) {
+template <typename ATTRS>
+Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
+                        const Message& message) {
   // if data do not have scale, normal transform path.
   const auto* sdata = new_args[0].as<ScaledExprNode>();
   const auto* sweight = new_args[1].as<ScaledExprNode>();
   if (sdata == nullptr) return Expr();
   if (sweight != nullptr) return Expr();
-  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout kernel_layout(param->kernel_layout);
@@ -552,13 +553,13 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
   ICHECK(is_simple || is_blocking);
 
   // Check it must be depthwise or full conv2d.
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
-  ICHECK(param->groups == 1 || is_depthwise_conv2d);
+  bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout);
+  ICHECK(param->groups == 1 || is_depthwise_conv);
 
   Expr weight = new_args[1];
 
   // match the ic_axis
-  if (is_depthwise_conv2d) {
+  if (is_depthwise_conv) {
     if (is_simple) {
       Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
       weight = Multiply(weight, scale);
@@ -580,14 +581,38 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
       if (!weight.defined()) return Expr();
     }
   }
-  // return transformed conv2d
+  // return transformed conv
   return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
 }
 
-RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
+Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
+  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = call->attrs.as<Conv2DAttrs>();
+    return ConvForwardPrep(call, param, out_message);
+  }
+  const auto* param = call->attrs.as<Conv3DAttrs>();
+  return ConvForwardPrep(call, param, out_message);
+}
+
+Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
+                           const Message& message) {
+  if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = ref_call->attrs.as<Conv2DAttrs>();
+    return ConvForwardRewrite(ref_call, param, new_args, message);
+  }
+  const auto* param = ref_call->attrs.as<Conv3DAttrs>();
+  return ConvForwardRewrite(ref_call, param, new_args, message);
+}
+
+RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);
 
 RELAY_REGISTER_OP("nn.conv2d")
-    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
+    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
+
+RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);
+
+RELAY_REGISTER_OP("nn.conv3d")
+    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
 
 // Dense send out requirement of axis folding.
 Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
@@ -937,9 +962,9 @@ RELAY_REGISTER_OP("multiply")
     .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
 
 // Consumer operators
-// Conv2D send out requirement of axis folding.
-Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
-  const auto* param = call->attrs.as<Conv2DAttrs>();
+// Conv send out requirement of axis folding.
+template <typename ATTRS>
+Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
   ICHECK(param != nullptr);
   Layout kernel_layout(param->kernel_layout);
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -952,10 +977,10 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   // By using a unified layout transformation.
   // We only need to change the Prep and Mutate function.
   //
-  // only handle depthwise or full conv2d.
+  // only handle depthwise or full conv.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (param->groups == 1 || is_depthwise_conv2d) {
+  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+  if (param->groups == 1 || is_depthwise_conv) {
     auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
     auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
     if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||  // simple layout
@@ -970,13 +995,13 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   return NullValue<Message>();
 }
 
-// Conv2D consumes the scale axis during transformation.
-Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
-                             const BackwardTransformer& transformer) {
+// Conv consumes the scale axis during transformation.
+template <typename ATTRS>
+Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
+                           const Expr& scale, const BackwardTransformer& transformer) {
   if (!message.defined()) {
     return transformer->NormalCallTransform(call.operator->());
   }
-  const auto* param = call->attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
   Layout kernel_layout(param->kernel_layout);
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -988,9 +1013,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
   int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
   int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
   int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
-  // Check it must be depthwise or full conv2d.
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  ICHECK(param->groups == 1 || is_depthwise_conv2d);
+  // Check it must be depthwise or full conv.
+  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+  ICHECK(param->groups == 1 || is_depthwise_conv);
   bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
   bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
   ICHECK(is_simple || is_blocking);
@@ -1012,11 +1037,36 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
   return Call(call->op, {data, weight}, call->attrs, call->type_args);
 }
 
+Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
+  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = call->attrs.as<Conv2DAttrs>();
+    return ConvBackwardPrep(call, param, in_messages);
+  }
+  const auto* param = call->attrs.as<Conv3DAttrs>();
+  return ConvBackwardPrep(call, param, in_messages);
+}
+
+Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale,
+                              const BackwardTransformer& transformer) {
+  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = call->attrs.as<Conv2DAttrs>();
+    return ConvBackwardTransform(call, param, message, scale, transformer);
+  }
+  const auto* param = call->attrs.as<Conv3DAttrs>();
+  return ConvBackwardTransform(call, param, message, scale, transformer);
+}
+
 RELAY_REGISTER_OP("nn.conv2d")
-    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
+    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);
 
 RELAY_REGISTER_OP("nn.conv2d")
-    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
+    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
+
+RELAY_REGISTER_OP("nn.conv3d")
+    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);
+
+RELAY_REGISTER_OP("nn.conv3d")
+    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
 
 Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
   const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
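
The registrations above wire both nn.conv2d and nn.conv3d to the shared Prep/Rewrite templates, with the PreConv wrappers dispatching on the op name to pick Conv2DAttrs or Conv3DAttrs. As a rough illustration of the forward direction (a scale on the convolution input being pushed into the kernel), here is a sketch in the style of the fold_scale_axis tests; the shapes and variable names are hypothetical.

import numpy as np
import tvm
from tvm import relay

# Per-input-channel scale applied before a conv3d (NCDHW, channel axis 1).
x = relay.var("x", shape=(1, 4, 8, 8, 8), dtype="float32")
w = relay.var("w", shape=(8, 4, 3, 3, 3), dtype="float32")
in_scale = relay.const(np.random.uniform(0.5, 2.0, (4, 1, 1, 1)).astype("float32"))

y = relay.multiply(x, in_scale)
y = relay.nn.conv3d(y, w, kernel_size=(3, 3, 3), channels=8,
                    data_layout="NCDHW", kernel_layout="OIDHW")
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

# With nn.conv3d now registered for FScaleAxisForwardPrep/Rewrite, the multiply
# can be rewritten into conv3d(x, w * scale) along the kernel's input-channel axis.
mod = relay.transform.InferType()(mod)
mod = relay.transform.ForwardFoldScaleAxis()(mod)
print(mod)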

src/relay/transforms/pattern_utils.h

Lines changed: 10 additions & 8 deletions
@@ -44,6 +44,7 @@
 #include <utility>
 #include <vector>
 
+#include "../backend/utils.h"
 #include "../op/make_op.h"
 
 namespace tvm {
@@ -183,16 +184,17 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Intege
 }
 
 /*!
- * \brief Check if the call is depthwise conv2d.
+ * \brief Check if the call is depthwise conv3d.
  *
- * \param call The conv2d call.
- * \param param The conv2d attributes.
- * \return Whether it is depthwise_conv2d.
+ * \param call The conv call.
+ * \param param The conv attributes.
+ * \return Whether it is depthwise_conv3d.
  */
-inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
-                              const Layout& kernel_layout) {
-  static const Layout kOIHW("OIHW");
-  const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
+template <typename ATTRS>
+inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) {
+  static const Layout kOIXX =
+      backend::IsOp(call.as<CallNode>(), "nn.conv2d") ? Layout("OIHW") : Layout("OIDHW");
+  const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX);
   auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
   return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
 }
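
The templated IsDepthwiseConv maps the kernel layout to OIHW for nn.conv2d and to OIDHW otherwise, then treats the call as depthwise when the transformed weight shape satisfies wshape[0] == groups and wshape[1] == 1. A hypothetical depthwise conv3d that meets this check (shapes chosen purely for illustration):

import tvm
from tvm import relay

# Depthwise conv3d: groups == input channels, OIDHW weight of shape (4, 1, 3, 3, 3),
# so wshape[0] == groups == 4 and wshape[1] == 1 after mapping to OIDHW.
x = relay.var("x", shape=(1, 4, 8, 8, 8), dtype="float32")
w = relay.var("w", shape=(4, 1, 3, 3, 3), dtype="float32")
y = relay.nn.conv3d(x, w, kernel_size=(3, 3, 3), channels=4, groups=4,
                    data_layout="NCDHW", kernel_layout="OIDHW")
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x, w], y)))
print(mod)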

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 5 additions & 2 deletions
@@ -157,6 +157,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       {"IODHW8i8o", tag::any},
       {"ODHWI8o", tag::Odhwi8o},
       {"ODHWI16o", tag::Odhwi16o},
+      {"ODHWI32o", tag::Odhwi32o},
+      {"ODHWI48o", tag::Odhwi48o},
+      {"ODHWI64o", tag::Odhwi64o},
   };
 
   bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
@@ -342,7 +345,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
     if (layout_dict.find(kernel_layout) == layout_dict.end()) {
       layout_dict.insert({kernel_layout, tag::any});
-      LOG(WARNING) << "Unregistered kernel layout for conv: " << data_layout
+      LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
                    << ", transfer to tag::any";
     }
 
@@ -382,7 +385,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
    auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
 
-    // Covn2d description.
+    // Conv description.
     auto conv_desc =
         has_bias ? dnnl::convolution_forward::desc(
                        dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
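
The three new map entries cover the wider blocked conv3d kernel layouts (ODHWI32o/48o/64o), so they resolve to matching DNNL memory tags instead of falling through to the tag::any warning fixed above. Such layouts arise when the output-channel axis of an OIDHW kernel is split into an inner block; below is a rough sketch of what an ODHWI48o kernel looks like, using hypothetical shapes.

import tvm
from tvm import relay

# An OIDHW kernel with 96 output channels rewritten to the blocked ODHWI48o
# layout: O is split into 96 / 48 = 2 outer blocks with an inner block of 48,
# giving a 6-D tensor of shape (2, 3, 3, 3, 3, 48).
w = relay.var("w", shape=(96, 3, 3, 3, 3), dtype="float32")
w_blocked = relay.layout_transform(w, "OIDHW", "ODHWI48o")
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([w], w_blocked)))
print(mod)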
