@@ -788,33 +788,52 @@ def _conv1d(self, node: fx.node.Node) -> relax.Var:
788788 groups = module .groups ,
789789 )
790790
791- def _conv3d (self , node : fx .node .Node ) -> relax .Var :
791+ def _conv1d_functional (self , node : fx .node .Node ) -> relax .Var :
792+ args = self .retrieve_args (node )
793+ x = args [0 ]
794+ weight = args [1 ]
795+ bias = args [2 ] if len (args ) > 2 else None
796+ stride = args [3 ] if len (args ) > 3 else 1
797+ padding = args [4 ] if len (args ) > 4 else 0
798+ dilation = args [5 ] if len (args ) > 5 else 1
799+ groups = args [6 ] if len (args ) > 6 else 1
800+ return self ._conv1d_impl (
801+ x ,
802+ weight ,
803+ bias = bias ,
804+ strides = stride ,
805+ padding = padding ,
806+ dilation = dilation ,
807+ groups = groups ,
808+ )
809+
810+ def _conv1d_transpose (self , node : fx .node .Node ) -> relax .Var :
792811 x = self .env [node .args [0 ]]
793812 module = self .named_modules [node .target ]
794813 weight = self .params [module .weight ]
795814
796- conv3d = self .block_builder .emit (
797- relax .op .nn .conv3d (
815+ conv1d_transpose = self .block_builder .emit (
816+ relax .op .nn .conv1d_transpose (
798817 x ,
799818 weight ,
800819 strides = module .stride ,
801820 padding = module .padding ,
802821 dilation = module .dilation ,
803822 groups = module .groups ,
804- data_layout = "NCDHW " ,
805- kernel_layout = "OIDHW " ,
823+ data_layout = "NCW " ,
824+ kernel_layout = "OIW " ,
806825 out_dtype = "float32" ,
807826 )
808827 )
809828
810829 if module .bias is None :
811- return conv3d
830+ return conv1d_transpose
812831
813832 bias = self .params [module .bias ]
814833 assert len (self .shape_of (bias )) == 1
815- bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 , 1 ))
834+ bias = relax .op .reshape (bias , (1 , - 1 , 1 ))
816835
817- return self .block_builder .emit (relax .op .add (conv3d , bias ))
836+ return self .block_builder .emit (relax .op .add (conv1d_transpose , bias ))
818837
819838 def _conv2d_impl (
820839 self ,
@@ -846,7 +865,25 @@ def _conv2d_impl(
846865 bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 ))
847866 return self .block_builder .emit (relax .op .add (conv2d , bias ))
848867
849- def _conv1d_functional (self , node : fx .node .Node ) -> relax .Var :
868+ def _conv2d (self , node : fx .node .Node ) -> relax .Var :
869+ x = self .env [node .args [0 ]]
870+ module = self .named_modules [node .target ]
871+ weight = self .params [module .weight ]
872+ bias = None
873+ if module .bias is not None :
874+ bias = self .params [module .bias ]
875+
876+ return self ._conv2d_impl (
877+ x ,
878+ weight ,
879+ bias = bias ,
880+ strides = module .stride ,
881+ padding = module .padding ,
882+ dilation = module .dilation ,
883+ groups = module .groups ,
884+ )
885+
886+ def _conv2d_functional (self , node : fx .node .Node ) -> relax .Var :
850887 args = self .retrieve_args (node )
851888 x = args [0 ]
852889 weight = args [1 ]
@@ -855,7 +892,7 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
855892 padding = args [4 ] if len (args ) > 4 else 0
856893 dilation = args [5 ] if len (args ) > 5 else 1
857894 groups = args [6 ] if len (args ) > 6 else 1
858- return self ._conv1d_impl (
895+ return self ._conv2d_impl (
859896 x ,
860897 weight ,
861898 bias = bias ,
@@ -865,98 +902,61 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
865902 groups = groups ,
866903 )
867904
868- def _conv1d_transpose (self , node : fx .node .Node ) -> relax .Var :
905+ def _conv2d_transpose (self , node : fx .node .Node ) -> relax .Var :
869906 x = self .env [node .args [0 ]]
870907 module = self .named_modules [node .target ]
871908 weight = self .params [module .weight ]
872909
873- conv1d_transpose = self .block_builder .emit (
874- relax .op .nn .conv1d_transpose (
910+ conv2d_transpose = self .block_builder .emit (
911+ relax .op .nn .conv2d_transpose (
875912 x ,
876913 weight ,
877914 strides = module .stride ,
878915 padding = module .padding ,
879916 dilation = module .dilation ,
880917 groups = module .groups ,
881- data_layout = "NCW " ,
882- kernel_layout = "OIW " ,
918+ data_layout = "NCHW " ,
919+ kernel_layout = "OIHW " ,
883920 out_dtype = "float32" ,
884921 )
885922 )
886923
887924 if module .bias is None :
888- return conv1d_transpose
925+ return conv2d_transpose
889926
890927 bias = self .params [module .bias ]
891928 assert len (self .shape_of (bias )) == 1
892- bias = relax .op .reshape (bias , (1 , - 1 , 1 ))
929+ bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 ))
893930
894- return self .block_builder .emit (relax .op .add (conv1d_transpose , bias ))
931+ return self .block_builder .emit (relax .op .add (conv2d_transpose , bias ))
895932
896- def _conv2d_transpose (self , node : fx .node .Node ) -> relax .Var :
933+ def _conv3d (self , node : fx .node .Node ) -> relax .Var :
897934 x = self .env [node .args [0 ]]
898935 module = self .named_modules [node .target ]
899936 weight = self .params [module .weight ]
900937
901- conv2d_transpose = self .block_builder .emit (
902- relax .op .nn .conv2d_transpose (
938+ conv3d = self .block_builder .emit (
939+ relax .op .nn .conv3d (
903940 x ,
904941 weight ,
905942 strides = module .stride ,
906943 padding = module .padding ,
907944 dilation = module .dilation ,
908945 groups = module .groups ,
909- data_layout = "NCHW " ,
910- kernel_layout = "OIHW " ,
946+ data_layout = "NCDHW " ,
947+ kernel_layout = "OIDHW " ,
911948 out_dtype = "float32" ,
912949 )
913950 )
914951
915952 if module .bias is None :
916- return conv2d_transpose
953+ return conv3d
917954
918955 bias = self .params [module .bias ]
919956 assert len (self .shape_of (bias )) == 1
920- bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 ))
921-
922- return self .block_builder .emit (relax .op .add (conv2d_transpose , bias ))
923-
924- def _conv2d (self , node : fx .node .Node ) -> relax .Var :
925- x = self .env [node .args [0 ]]
926- module = self .named_modules [node .target ]
927- weight = self .params [module .weight ]
928- bias = None
929- if module .bias is not None :
930- bias = self .params [module .bias ]
931-
932- return self ._conv2d_impl (
933- x ,
934- weight ,
935- bias = bias ,
936- strides = module .stride ,
937- padding = module .padding ,
938- dilation = module .dilation ,
939- groups = module .groups ,
940- )
957+ bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 , 1 ))
941958
942- def _conv2d_functional (self , node : fx .node .Node ) -> relax .Var :
943- args = self .retrieve_args (node )
944- x = args [0 ]
945- weight = args [1 ]
946- bias = args [2 ] if len (args ) > 2 else None
947- stride = args [3 ] if len (args ) > 3 else 1
948- padding = args [4 ] if len (args ) > 4 else 0
949- dilation = args [5 ] if len (args ) > 5 else 1
950- groups = args [6 ] if len (args ) > 6 else 1
951- return self ._conv2d_impl (
952- x ,
953- weight ,
954- bias = bias ,
955- strides = stride ,
956- padding = padding ,
957- dilation = dilation ,
958- groups = groups ,
959- )
959+ return self .block_builder .emit (relax .op .add (conv3d , bias ))
960960
961961 def _max_pool2d (self , node : fx .node .Node ) -> relax .Var :
962962 x = self .env [node .args [0 ]]
0 commit comments