Skip to content

Commit ba14bd1

Browse files
committed
cleanup conv1d
1 parent aa31762 commit ba14bd1

File tree

1 file changed

+63
-63
lines changed

1 file changed

+63
-63
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 63 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -788,33 +788,52 @@ def _conv1d(self, node: fx.node.Node) -> relax.Var:
788788
groups=module.groups,
789789
)
790790

791-
def _conv3d(self, node: fx.node.Node) -> relax.Var:
791+
def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
792+
args = self.retrieve_args(node)
793+
x = args[0]
794+
weight = args[1]
795+
bias = args[2] if len(args) > 2 else None
796+
stride = args[3] if len(args) > 3 else 1
797+
padding = args[4] if len(args) > 4 else 0
798+
dilation = args[5] if len(args) > 5 else 1
799+
groups = args[6] if len(args) > 6 else 1
800+
return self._conv1d_impl(
801+
x,
802+
weight,
803+
bias=bias,
804+
strides=stride,
805+
padding=padding,
806+
dilation=dilation,
807+
groups=groups,
808+
)
809+
810+
def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
792811
x = self.env[node.args[0]]
793812
module = self.named_modules[node.target]
794813
weight = self.params[module.weight]
795814

796-
conv3d = self.block_builder.emit(
797-
relax.op.nn.conv3d(
815+
conv1d_transpose = self.block_builder.emit(
816+
relax.op.nn.conv1d_transpose(
798817
x,
799818
weight,
800819
strides=module.stride,
801820
padding=module.padding,
802821
dilation=module.dilation,
803822
groups=module.groups,
804-
data_layout="NCDHW",
805-
kernel_layout="OIDHW",
823+
data_layout="NCW",
824+
kernel_layout="OIW",
806825
out_dtype="float32",
807826
)
808827
)
809828

810829
if module.bias is None:
811-
return conv3d
830+
return conv1d_transpose
812831

813832
bias = self.params[module.bias]
814833
assert len(self.shape_of(bias)) == 1
815-
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
834+
bias = relax.op.reshape(bias, (1, -1, 1))
816835

817-
return self.block_builder.emit(relax.op.add(conv3d, bias))
836+
return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
818837

819838
def _conv2d_impl(
820839
self,
@@ -846,7 +865,25 @@ def _conv2d_impl(
846865
bias = relax.op.reshape(bias, (1, -1, 1, 1))
847866
return self.block_builder.emit(relax.op.add(conv2d, bias))
848867

849-
def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
868+
def _conv2d(self, node: fx.node.Node) -> relax.Var:
869+
x = self.env[node.args[0]]
870+
module = self.named_modules[node.target]
871+
weight = self.params[module.weight]
872+
bias = None
873+
if module.bias is not None:
874+
bias = self.params[module.bias]
875+
876+
return self._conv2d_impl(
877+
x,
878+
weight,
879+
bias=bias,
880+
strides=module.stride,
881+
padding=module.padding,
882+
dilation=module.dilation,
883+
groups=module.groups,
884+
)
885+
886+
def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
850887
args = self.retrieve_args(node)
851888
x = args[0]
852889
weight = args[1]
@@ -855,7 +892,7 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
855892
padding = args[4] if len(args) > 4 else 0
856893
dilation = args[5] if len(args) > 5 else 1
857894
groups = args[6] if len(args) > 6 else 1
858-
return self._conv1d_impl(
895+
return self._conv2d_impl(
859896
x,
860897
weight,
861898
bias=bias,
@@ -865,98 +902,61 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
865902
groups=groups,
866903
)
867904

868-
def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
905+
def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
869906
x = self.env[node.args[0]]
870907
module = self.named_modules[node.target]
871908
weight = self.params[module.weight]
872909

873-
conv1d_transpose = self.block_builder.emit(
874-
relax.op.nn.conv1d_transpose(
910+
conv2d_transpose = self.block_builder.emit(
911+
relax.op.nn.conv2d_transpose(
875912
x,
876913
weight,
877914
strides=module.stride,
878915
padding=module.padding,
879916
dilation=module.dilation,
880917
groups=module.groups,
881-
data_layout="NCW",
882-
kernel_layout="OIW",
918+
data_layout="NCHW",
919+
kernel_layout="OIHW",
883920
out_dtype="float32",
884921
)
885922
)
886923

887924
if module.bias is None:
888-
return conv1d_transpose
925+
return conv2d_transpose
889926

890927
bias = self.params[module.bias]
891928
assert len(self.shape_of(bias)) == 1
892-
bias = relax.op.reshape(bias, (1, -1, 1))
929+
bias = relax.op.reshape(bias, (1, -1, 1, 1))
893930

894-
return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
931+
return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
895932

896-
def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
933+
def _conv3d(self, node: fx.node.Node) -> relax.Var:
897934
x = self.env[node.args[0]]
898935
module = self.named_modules[node.target]
899936
weight = self.params[module.weight]
900937

901-
conv2d_transpose = self.block_builder.emit(
902-
relax.op.nn.conv2d_transpose(
938+
conv3d = self.block_builder.emit(
939+
relax.op.nn.conv3d(
903940
x,
904941
weight,
905942
strides=module.stride,
906943
padding=module.padding,
907944
dilation=module.dilation,
908945
groups=module.groups,
909-
data_layout="NCHW",
910-
kernel_layout="OIHW",
946+
data_layout="NCDHW",
947+
kernel_layout="OIDHW",
911948
out_dtype="float32",
912949
)
913950
)
914951

915952
if module.bias is None:
916-
return conv2d_transpose
953+
return conv3d
917954

918955
bias = self.params[module.bias]
919956
assert len(self.shape_of(bias)) == 1
920-
bias = relax.op.reshape(bias, (1, -1, 1, 1))
921-
922-
return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
923-
924-
def _conv2d(self, node: fx.node.Node) -> relax.Var:
925-
x = self.env[node.args[0]]
926-
module = self.named_modules[node.target]
927-
weight = self.params[module.weight]
928-
bias = None
929-
if module.bias is not None:
930-
bias = self.params[module.bias]
931-
932-
return self._conv2d_impl(
933-
x,
934-
weight,
935-
bias=bias,
936-
strides=module.stride,
937-
padding=module.padding,
938-
dilation=module.dilation,
939-
groups=module.groups,
940-
)
957+
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
941958

942-
def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
943-
args = self.retrieve_args(node)
944-
x = args[0]
945-
weight = args[1]
946-
bias = args[2] if len(args) > 2 else None
947-
stride = args[3] if len(args) > 3 else 1
948-
padding = args[4] if len(args) > 4 else 0
949-
dilation = args[5] if len(args) > 5 else 1
950-
groups = args[6] if len(args) > 6 else 1
951-
return self._conv2d_impl(
952-
x,
953-
weight,
954-
bias=bias,
955-
strides=stride,
956-
padding=padding,
957-
dilation=dilation,
958-
groups=groups,
959-
)
959+
return self.block_builder.emit(relax.op.add(conv3d, bias))
960960

961961
def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
962962
x = self.env[node.args[0]]

0 commit comments

Comments
 (0)