Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,12 @@ def _argsort(self, node: fx.Node) -> relax.Var:
descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
return self.block_builder.emit(relax.op.argsort(x, dim, descending))

def _broadcast_to(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
shape = args[1] if len(args) > 1 else args[0]
return self.block_builder.emit(relax.op.broadcast_to(x, shape))

def _cat(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,13 @@ def _flatten_module(self, node: fx.Node) -> relax.Var:
end_dim = module.end_dim
return self._flatten_impl(x, start_dim, end_dim)

def _narrow(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
start = node.args[2]
length = node.args[3]
return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length]))

def _numel(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
shape = self.shape_of(x)
Expand Down Expand Up @@ -755,6 +762,7 @@ def create_convert_map(
"where": self._where,
# tensor manipulation
"argsort": self._argsort,
"broadcast_to": self._broadcast_to,
"cat": self._cat,
"chunk": self._chunk,
"concat": self._cat,
Expand All @@ -766,6 +774,7 @@ def create_convert_map(
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
"narrow": self._narrow,
"numel": self._numel,
"permute": self._permute,
"repeat": self._repeat,
Expand Down
43 changes: 43 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4430,5 +4430,48 @@ def main(
verify_model(Topk(), [([5, 3], "float32")], {}, Expected)


def test_broadcast_to():
class BroadcastTo(Module):
def forward(self, x):
return torch.broadcast_to(x, (5, 3))

@tvm.script.ir_module
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 1), dtype="float32"),
) -> R.Tensor((5, 3), dtype="float32"):
with R.dataflow():
lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0, (5, 3))
gv: R.Tensor((5, 3), dtype="float32") = lv
R.output(gv)
return gv

verify_model(BroadcastTo(), [([5, 1], "float32")], {}, Expected)


def test_narrow():
class Narrow(Module):
def forward(self, x):
return torch.narrow(x, 1, 0, 2)

@tvm.script.ir_module
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="float32"),
) -> R.Tensor((5, 2), dtype="float32"):
with R.dataflow():
lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
inp_0, axes=[1], begin=[0], end=[2]
)
gv: R.Tensor((5, 2), dtype="float32") = lv
R.output(gv)

return gv

verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading