diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e6b39c3eee0e..093f3ae4cf7a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1476,6 +1476,7 @@ def create_convert_map(self): "getitem": self._getitem, "contiguous": lambda node: self.env[node.args[0]], "to": self._to, + "max_pool2d": self._max_pool2d, "avg_pool2d": self._avg_pool2d, "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), "layer_norm": self._layer_norm, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index b4ac3fa60ce9..1a2cc5da6242 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -796,6 +796,13 @@ def __init__(self): def forward(self, input): return self.pool(input) + class MaxPool2d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1]) + @tvm.script.ir_module class expected1: @R.function @@ -876,6 +883,7 @@ def main( return gv verify_model(MaxPool2d(), input_info, {}, expected1) + verify_model(MaxPool2d_functional(), input_info, {}, expected1) verify_model(MaxPool2d2(), input_info, {}, expected2) verify_model(MaxPool2d3(), input_info, {}, expected3)