diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1b86b120dfcc..30f14b490b1b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2323,8 +2323,19 @@ def one_hot(self, inputs, input_types): def index(self, inputs, input_types): data = inputs[0] - indices = inputs[1] - return _op.adv_index([data] + indices) + indices_list = [] + + for indices in inputs[1]: + if self.infer_type(indices).dtype == "bool": + # adv_index does not support a mask as the index tensor (it will treat 0/1 as + # an index rather than a flag). + # So we use argwhere to turn the mask into indices, which will also take care + # of the dynamism in the indexing by mask. + indices_list.append(_op.squeeze(_op.transform.argwhere(indices), axis=[1])) + else: + indices_list.append(indices) + + return _op.adv_index([data] + indices_list) def meshgrid(self, inputs, input_types): data = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8045635127bb..36bb5bede475 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4034,6 +4034,14 @@ def forward(self, x): input_data = torch.rand(input_shape).float() verify_model(Index1().eval(), input_data=input_data) + def test_fn_bool_mask(): + return lambda data, mask: data[0, mask] + + data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + mask = torch.tensor([True, True, False]) + + verify_trace_model(test_fn_bool_mask(), [data, mask], ["llvm", "cuda"]) + def test_logsumexp(): """test_logsumexp"""