Commit aa31762

add support for functional conv1d
1 parent: 8ca8df6

File tree

1 file changed (+53, -13)

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 53 additions & 13 deletions
@@ -740,34 +740,54 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var:
         bias = args[2] if len(args) > 2 else None
         return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
 
-    def _conv1d(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        weight = self.params[module.weight]
-
+    def _conv1d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
         conv1d = self.block_builder.emit(
             relax.op.nn.conv1d(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
                 data_layout="NCW",
                 kernel_layout="OIW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
+        if bias is None:
             return conv1d
-
-        bias = self.params[module.bias]
         assert len(self.shape_of(bias)) == 1
         bias = relax.op.reshape(bias, (1, -1, 1))
-
         return self.block_builder.emit(relax.op.add(conv1d, bias))
 
+    def _conv1d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
+
+        return self._conv1d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
     def _conv3d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -826,6 +846,25 @@ def _conv2d_impl(
         bias = relax.op.reshape(bias, (1, -1, 1, 1))
         return self.block_builder.emit(relax.op.add(conv2d, bias))
 
+    def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv1d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
     def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -1482,6 +1521,7 @@ def create_convert_map(self):
             "type": self._type,
             "astype": self._type,
             "matmul": self._matmul,
+            "conv1d": self._conv1d_functional,
             "conv2d": self._conv2d_functional,
             "linear": self._linear_functional,
             "addmm": self._addmm,

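Not part of the commit, but a minimal sketch of how the new mapping can be exercised end to end, assuming the frontend's from_fx entry point: trace a module that calls torch.nn.functional.conv1d with torch.fx, then import the traced graph. The module name and all shapes and hyperparameters below are illustrative; stride and padding are passed positionally because _conv1d_functional reads them from args[3] and args[4].

import torch
import torch.nn.functional as F
from torch import fx

from tvm.relax.frontend.torch import from_fx


class Conv1dFunctional(torch.nn.Module):
    """Hypothetical test module exercising the functional conv1d path."""

    def __init__(self):
        super().__init__()
        # Kernel layout "OIW": (out_channels=6, in_channels=3, kernel_width=7).
        self.weight = torch.nn.Parameter(torch.randn(6, 3, 7))
        self.bias = torch.nn.Parameter(torch.randn(6))

    def forward(self, x):
        # Positional stride=1 and padding=3 are picked up by
        # _conv1d_functional via the "conv1d" entry added to
        # create_convert_map above.
        return F.conv1d(x, self.weight, self.bias, 1, 3)


graph_module = fx.symbolic_trace(Conv1dFunctional())
# Data layout "NCW": (batch=1, channels=3, width=32).
mod = from_fx(graph_module, [((1, 3, 32), "float32")])
mod.show()  # print the imported Relax module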