Skip to content

Commit 2f02b1e

Browse files
authored
support Torch all and any op (#9185)
1 parent b9f2284 commit 2f02b1e

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# pylint: disable=missing-function-docstring
2121
"""PT: PyTorch frontend."""
2222
import itertools
23+
import functools
2324
import logging
2425
import math
2526
import sys
@@ -2763,6 +2764,16 @@ def lstm(self, inputs, input_types):
27632764

27642765
return (output, _op.stack(hy, 0), _op.stack(cy, 0))
27652766

2767+
def all_any_common(self, op, inputs, input_types):
2768+
dim = inputs[1]
2769+
keepdim = inputs[2]
2770+
if self.infer_type(inputs[0]).dtype != "bool":
2771+
# The input dtype can be uint8.
2772+
inp = _op.cast(inputs[0], "bool")
2773+
else:
2774+
inp = inputs[0]
2775+
return op(inp, axis=dim, keepdims=keepdim)
2776+
27662777
# Operator mappings
27672778
def create_convert_map(self):
27682779
self.convert_map = {
@@ -2986,6 +2997,8 @@ def create_convert_map(self):
29862997
"aten::flip": self.flip,
29872998
"aten::gru": self.gru,
29882999
"aten::lstm": self.lstm,
3000+
"aten::all": functools.partial(self.all_any_common, _op.all),
3001+
"aten::any": functools.partial(self.all_any_common, _op.any),
29893002
}
29903003

29913004
def update_convert_map(self, custom_map):

tests/python/frontend/pytorch/test_forward.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3948,5 +3948,17 @@ def test_annotate_span():
39483948
relay.transform.AnnotateSpans()(mod)
39493949

39503950

3951+
@tvm.testing.uses_gpu
3952+
def test_all_any():
3953+
def test_fn(f, dim=None, keepdim=False):
3954+
return lambda x: f(x, dim=dim, keepdim=keepdim)
3955+
3956+
for f in [torch.all, torch.any]:
3957+
verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
3958+
verify_model(test_fn(f, 0), [torch.arange(0, 3).to(torch.uint8)])
3959+
verify_model(test_fn(f, 1), [torch.rand(4, 2).bool()])
3960+
verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()])
3961+
3962+
39513963
if __name__ == "__main__":
39523964
pytest.main([__file__])

0 commit comments

Comments
 (0)