@@ -740,34 +740,54 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var:
740740 bias = args [2 ] if len (args ) > 2 else None
741741 return self .block_builder .emit (relax .op .linear (x , weight , bias , "float32" ))
742742
743- def _conv1d (self , node : fx .node .Node ) -> relax .Var :
744- x = self .env [node .args [0 ]]
745- module = self .named_modules [node .target ]
746- weight = self .params [module .weight ]
747-
743+ def _conv1d_impl (
744+ self ,
745+ x : relax .Expr ,
746+ weight : relax .Expr ,
747+ bias : Optional [relax .Expr ],
748+ strides : Optional [Tuple ],
749+ padding : Optional [Tuple ],
750+ dilation : Optional [Tuple ],
751+ groups : Optional [Tuple ],
752+ ) -> relax .Var :
748753 conv1d = self .block_builder .emit (
749754 relax .op .nn .conv1d (
750755 x ,
751756 weight ,
752- strides = module . stride ,
753- padding = module . padding ,
754- dilation = module . dilation ,
755- groups = module . groups ,
757+ strides = strides ,
758+ padding = padding ,
759+ dilation = dilation ,
760+ groups = groups ,
756761 data_layout = "NCW" ,
757762 kernel_layout = "OIW" ,
758763 out_dtype = "float32" ,
759764 )
760765 )
761766
762- if module . bias is None :
767+ if bias is None :
763768 return conv1d
764-
765- bias = self .params [module .bias ]
766769 assert len (self .shape_of (bias )) == 1
767770 bias = relax .op .reshape (bias , (1 , - 1 , 1 ))
768-
769771 return self .block_builder .emit (relax .op .add (conv1d , bias ))
770772
773+ def _conv1d (self , node : fx .node .Node ) -> relax .Var :
774+ x = self .env [node .args [0 ]]
775+ module = self .named_modules [node .target ]
776+ weight = self .params [module .weight ]
777+ bias = None
778+ if module .bias is not None :
779+ bias = self .params [module .bias ]
780+
781+ return self ._conv1d_impl (
782+ x ,
783+ weight ,
784+ bias = bias ,
785+ strides = module .stride ,
786+ padding = module .padding ,
787+ dilation = module .dilation ,
788+ groups = module .groups ,
789+ )
790+
771791 def _conv3d (self , node : fx .node .Node ) -> relax .Var :
772792 x = self .env [node .args [0 ]]
773793 module = self .named_modules [node .target ]
@@ -826,6 +846,25 @@ def _conv2d_impl(
826846 bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 ))
827847 return self .block_builder .emit (relax .op .add (conv2d , bias ))
828848
849+ def _conv1d_functional (self , node : fx .node .Node ) -> relax .Var :
850+ args = self .retrieve_args (node )
851+ x = args [0 ]
852+ weight = args [1 ]
853+ bias = args [2 ] if len (args ) > 2 else None
854+ stride = args [3 ] if len (args ) > 3 else 1
855+ padding = args [4 ] if len (args ) > 4 else 0
856+ dilation = args [5 ] if len (args ) > 5 else 1
857+ groups = args [6 ] if len (args ) > 6 else 1
858+ return self ._conv1d_impl (
859+ x ,
860+ weight ,
861+ bias = bias ,
862+ strides = stride ,
863+ padding = padding ,
864+ dilation = dilation ,
865+ groups = groups ,
866+ )
867+
829868 def _conv1d_transpose (self , node : fx .node .Node ) -> relax .Var :
830869 x = self .env [node .args [0 ]]
831870 module = self .named_modules [node .target ]
@@ -1482,6 +1521,7 @@ def create_convert_map(self):
14821521 "type" : self ._type ,
14831522 "astype" : self ._type ,
14841523 "matmul" : self ._matmul ,
1524+ "conv1d" : self ._conv1d_functional ,
14851525 "conv2d" : self ._conv2d_functional ,
14861526 "linear" : self ._linear_functional ,
14871527 "addmm" : self ._addmm ,
0 commit comments