2929#include < tvm/relay/transform.h>
3030#include < tvm/tir/data_layout.h>
3131
32+ #include " ../backend/utils.h"
3233#include " ../op/tensor/transform.h"
3334#include " pass_utils.h"
3435#include " pattern_utils.h"
@@ -492,11 +493,11 @@ RELAY_REGISTER_OP("multiply")
492493 .set_attr<FForwardRewrite>(" FScaleAxisForwardRewrite" , MultiplyForwardRewrite);
493494
494495// Consumer operators
495- // Conv2D send out requirement of axis folding.
496- Array<Message> Conv2DForwardPrep (const Call& call, const Message& out_message) {
496+ // Conv send out requirement of axis folding.
497+ template <typename ATTRS>
498+ Array<Message> ConvForwardPrep (const Call& call, const ATTRS* param, const Message& out_message) {
497499 // TODO(tvm-team) support general data layout
498500 // by transforming weight
499- const auto * param = call->attrs .as <Conv2DAttrs>();
500501 ICHECK (param != nullptr );
501502 Layout data_layout (param->data_layout );
502503 Layout kernel_layout (param->kernel_layout );
@@ -512,8 +513,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
512513 //
513514 // only handle depthwise or full conv2d.
514515 // TODO(tvm-team) handle grouped conv by reshape + bcast
515- bool is_depthwise_conv2d = IsDepthwiseConv2D (call, param, kernel_layout);
516- if (param->groups == 1 || is_depthwise_conv2d ) {
516+ bool is_depthwise_conv = IsDepthwiseConv (call, param, kernel_layout);
517+ if (param->groups == 1 || is_depthwise_conv ) {
517518 auto ko_small_axis = kernel_layout.IndexOf (LayoutAxis::Get (' o' ));
518519 auto ki_small_axis = kernel_layout.IndexOf (LayoutAxis::Get (' i' ));
519520 if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0 ) || // simple layout
@@ -529,14 +530,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
529530}
530531
531532// Conv2D consumes the scale axis during transformation.
532- Expr Conv2DForwardRewrite (const Call& ref_call, const Array<Expr>& new_args,
533- const Message& message) {
533+ template <typename ATTRS>
534+ Expr ConvForwardRewrite (const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
535+ const Message& message) {
534536 // if data do not have scale, normal transform path.
535537 const auto * sdata = new_args[0 ].as <ScaledExprNode>();
536538 const auto * sweight = new_args[1 ].as <ScaledExprNode>();
537539 if (sdata == nullptr ) return Expr ();
538540 if (sweight != nullptr ) return Expr ();
539- const auto * param = ref_call->attrs .as <Conv2DAttrs>();
540541 ICHECK (param != nullptr );
541542 Layout data_layout (param->data_layout );
542543 Layout kernel_layout (param->kernel_layout );
@@ -552,13 +553,13 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
552553 ICHECK (is_simple || is_blocking);
553554
554555 // Check it must be depthwise or full conv2d.
555- bool is_depthwise_conv2d = IsDepthwiseConv2D (ref_call, param, kernel_layout);
556- ICHECK (param->groups == 1 || is_depthwise_conv2d );
556+ bool is_depthwise_conv = IsDepthwiseConv (ref_call, param, kernel_layout);
557+ ICHECK (param->groups == 1 || is_depthwise_conv );
557558
558559 Expr weight = new_args[1 ];
559560
560561 // match the ic_axis
561- if (is_depthwise_conv2d ) {
562+ if (is_depthwise_conv ) {
562563 if (is_simple) {
563564 Expr scale = ExpandBiasToMatchAxis (sdata->scale , kernel_layout.ndim (), {big_ko_axis});
564565 weight = Multiply (weight, scale);
@@ -580,14 +581,38 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
580581 if (!weight.defined ()) return Expr ();
581582 }
582583 }
583- // return transformed conv2d
584+ // return transformed conv
584585 return Call (ref_call->op , {sdata->value , weight}, ref_call->attrs , ref_call->type_args );
585586}
586587
587- RELAY_REGISTER_OP (" nn.conv2d" ).set_attr<FForwardPrep>(" FScaleAxisForwardPrep" , Conv2DForwardPrep);
588+ Array<Message> PreConvForwardPrep (const Call& call, const Message& out_message) {
589+ if (backend::IsOp (call.as <CallNode>(), " nn.conv2d" )) {
590+ const auto * param = call->attrs .as <Conv2DAttrs>();
591+ return ConvForwardPrep (call, param, out_message);
592+ }
593+ const auto * param = call->attrs .as <Conv3DAttrs>();
594+ return ConvForwardPrep (call, param, out_message);
595+ }
596+
597+ Expr PreConvForwardRewrite (const Call& ref_call, const Array<Expr>& new_args,
598+ const Message& message) {
599+ if (backend::IsOp (ref_call.as <CallNode>(), " nn.conv2d" )) {
600+ const auto * param = ref_call->attrs .as <Conv2DAttrs>();
601+ return ConvForwardRewrite (ref_call, param, new_args, message);
602+ }
603+ const auto * param = ref_call->attrs .as <Conv3DAttrs>();
604+ return ConvForwardRewrite (ref_call, param, new_args, message);
605+ }
606+
// Register the shared conv dispatchers so both 2-D and 3-D convolutions go
// through the same fold-scale-axis forward prep/rewrite logic.
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
591616
592617// Dense send out requirement of axis folding.
593618Array<Message> DenseForwardPrep (const Call& call, const Message& out_message) {
@@ -937,9 +962,9 @@ RELAY_REGISTER_OP("multiply")
937962 .set_attr<FBackwardTransform>(" FScaleAxisBackwardTransform" , MultiplyBackwardTransform);
938963
939964// Consumer operators
940- // Conv2D send out requirement of axis folding.
941- Message Conv2DBackwardPrep ( const Call& call, const Array<Message>& in_messages) {
942- const auto * param = call-> attrs . as <Conv2DAttrs>();
965+ // Conv send out requirement of axis folding.
966+ template < typename ATTRS>
967+ Message ConvBackwardPrep ( const Call& call, const ATTRS * param, const Array<Message>& in_messages) {
943968 ICHECK (param != nullptr );
944969 Layout kernel_layout (param->kernel_layout );
945970 Layout out_layout (param->out_layout == " " ? param->data_layout : param->out_layout );
@@ -952,10 +977,10 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
952977 // By using a unified layout transformation.
953978 // We only need to change the Prep and Mutate function.
954979 //
955- // only handle depthwise or full conv2d .
980+ // only handle depthwise or full conv .
956981 // TODO(tvm-team) handle grouped conv by reshape + bcast
957- bool is_depthwise_conv2d = IsDepthwiseConv2D (call, param, kernel_layout);
958- if (param->groups == 1 || is_depthwise_conv2d ) {
982+ bool is_depthwise_conv = IsDepthwiseConv (call, param, kernel_layout);
983+ if (param->groups == 1 || is_depthwise_conv ) {
959984 auto ko_small_axis = kernel_layout.IndexOf (LayoutAxis::Get (' o' ));
960985 auto ki_small_axis = kernel_layout.IndexOf (LayoutAxis::Get (' i' ));
961986 if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0 ) || // simple layout
@@ -970,13 +995,13 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
970995 return NullValue<Message>();
971996}
972997
973- // Conv2D consumes the scale axis during transformation.
974- Expr Conv2DBackwardTransform (const Call& call, const Message& message, const Expr& scale,
975- const BackwardTransformer& transformer) {
998+ // Conv consumes the scale axis during transformation.
999+ template <typename ATTRS>
1000+ Expr ConvBackwardTransform (const Call& call, const ATTRS* param, const Message& message,
1001+ const Expr& scale, const BackwardTransformer& transformer) {
9761002 if (!message.defined ()) {
9771003 return transformer->NormalCallTransform (call.operator ->());
9781004 }
979- const auto * param = call->attrs .as <Conv2DAttrs>();
9801005 ICHECK (param != nullptr );
9811006 Layout kernel_layout (param->kernel_layout );
9821007 Layout out_layout (param->out_layout == " " ? param->data_layout : param->out_layout );
@@ -988,9 +1013,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
9881013 int small_ki_axis = kernel_layout.IndexOf (LayoutAxis::Get (' i' ));
9891014 int big_ki_axis = kernel_layout.IndexOf (LayoutAxis::Get (' I' ));
9901015 int big_ko_axis = kernel_layout.IndexOf (LayoutAxis::Get (' O' ));
991- // Check it must be depthwise or full conv2d .
992- bool is_depthwise_conv2d = IsDepthwiseConv2D (call, param, kernel_layout);
993- ICHECK (param->groups == 1 || is_depthwise_conv2d );
1016+ // Check it must be depthwise or full conv .
1017+ bool is_depthwise_conv = IsDepthwiseConv (call, param, kernel_layout);
1018+ ICHECK (param->groups == 1 || is_depthwise_conv );
9941019 bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0 );
9951020 bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0 );
9961021 ICHECK (is_simple || is_blocking);
@@ -1012,11 +1037,36 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
10121037 return Call (call->op , {data, weight}, call->attrs , call->type_args );
10131038}
10141039
1040+ Message PreConvBackwardPrep (const Call& call, const Array<Message>& in_messages) {
1041+ if (backend::IsOp (call.as <CallNode>(), " nn.conv2d" )) {
1042+ const auto * param = call->attrs .as <Conv2DAttrs>();
1043+ return ConvBackwardPrep (call, param, in_messages);
1044+ }
1045+ const auto * param = call->attrs .as <Conv3DAttrs>();
1046+ return ConvBackwardPrep (call, param, in_messages);
1047+ }
1048+
1049+ Expr PreConvBackwardTransform (const Call& call, const Message& message, const Expr& scale,
1050+ const BackwardTransformer& transformer) {
1051+ if (backend::IsOp (call.as <CallNode>(), " nn.conv2d" )) {
1052+ const auto * param = call->attrs .as <Conv2DAttrs>();
1053+ return ConvBackwardTransform (call, param, message, scale, transformer);
1054+ }
1055+ const auto * param = call->attrs .as <Conv3DAttrs>();
1056+ return ConvBackwardTransform (call, param, message, scale, transformer);
1057+ }
1058+
// Register the shared conv dispatchers so both 2-D and 3-D convolutions go
// through the same fold-scale-axis backward prep/transform logic.
RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
10201070
10211071Message BiasAddBackwardPrep (const Call& call, const Array<Message>& in_messages) {
10221072 const BiasAddAttrs* attrs = call->attrs .as <BiasAddAttrs>();
0 commit comments