Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,6 +3019,7 @@ def _op_dispatch(cls, operator, inputs, attr, params):
op_map = {
"size": cls._size,
"arange": cls._arange,
"index_put": cls._index_put,
"reshape": cls._reshape,
"embedding_bag": cls._embedding_bag,
}
Expand All @@ -3038,6 +3039,47 @@ def _size(cls, inputs, attr, params):
def _arange(cls, inputs, attr, params):
return _op.arange(inputs[0], inputs[1], inputs[2], dtype="int64")

@classmethod
def _check_index(cls, indices, values):
def unfolding_indices(indices, values):
n = len(indices)
flatten_indices = []
slices_size = []
for index in indices:
flatten_indices.append(_op.reshape(index, _op.const([-1])))
slices_size.append(infer_shape(flatten_indices[-1])[0])
repeat_size = [1]
tile_size = [1]
for i in range(1, n):
repeat_size.append(slices_size[-i] * repeat_size[-1])
tile_size.append(slices_size[i - 1] * tile_size[-1])
repeat_size.reverse()
unflod_slices = []
for i in range(n):
unflod_slices.append(
fold_constant(
_op.repeat(_op.tile(flatten_indices[i], (tile_size[i],)), repeat_size[i], 0)
)
)
return unflod_slices, _op.reshape(values, _op.const([-1]))

values_shape = infer_shape(values)
if len(values_shape) != 1:
return unfolding_indices(indices, values)
return indices, values

@classmethod
def _index_put(cls, inputs, attr, params):
in_tensor = inputs[0]
indices, values = cls._check_index(inputs[1 : len(inputs) - 2], inputs[len(inputs) - 2])
accumulate = inputs[len(inputs) - 1].data.asnumpy() != 0
if not accumulate:
mode = "update"
else:
mode = "add"
index_tensor = _op.stack(indices, axis=0)
return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)

@classmethod
def _reshape(cls, inputs, attr, params):
return _op.reshape(inputs[0], inputs[1])
Expand Down
76 changes: 76 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5045,6 +5045,81 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
verify_embedding_bag(32, 2, [3, 3])


@tvm.testing.parametrize_targets
def test_index_put(target, dev):
class _index_put_model(torch.nn.Module):
def __init__(self, indices, values, accumulate):
super(_index_put_model, self).__init__()
self.indices = indices
self.values = values
self.accumulate = accumulate

def forward(self, x):
return x.index_put(self.indices, self.values, self.accumulate)

def _convert_to_onnx(model, dummy_data):
file_name = "{}.onnx".format("aten_model")
torch.onnx.export(
model,
dummy_data,
file_name,
export_params=True,
verbose=False,
opset_version=11,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
onnx_model = onnx.load(file_name)
return onnx_model

def verify_index_put(data_shape, indices, accumulate):
dummy_data = torch.ones(data_shape)
tvm_inputs = [dummy_data.numpy()]
values = torch.rand(indices[0].size())
model = _index_put_model(indices, values, accumulate)
onnx_model = _convert_to_onnx(model, dummy_data)
torch_out = model(dummy_data)

tvm_out = get_tvm_output_with_vm(
onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True
)
tvm.testing.assert_allclose(torch_out.numpy(), tvm_out)

shape = (3, 5)
xidx = torch.tensor([0, 1, 2, 2])
yidx = torch.tensor([0, 1, 3, 4])
verify_index_put(shape, [xidx, yidx], True)

shape = (3, 5, 3)
xidx = torch.tensor([0, 1, 2, 2, 0])
yidx = torch.tensor([0, 1, 3, 4, 0])
zidx = torch.tensor([0, 1, 1, 2, 0])
verify_index_put(shape, [xidx, yidx, zidx], False)

def verify_index_put_slice(data_shape, value_shape, accumulate):
dummy_data = torch.ones(data_shape)
tvm_inputs = [dummy_data.numpy()]
indices = []
index_shape = [1] * len(value_shape)
index_shape[0] = -1
for i in range(len(value_shape)):
indices.append(torch.arange(0, value_shape[i]).reshape(tuple(index_shape)))
index_shape.pop()
values = torch.rand(value_shape)

model = _index_put_model(indices, values, accumulate)
onnx_model = _convert_to_onnx(model, dummy_data)
torch_out = model(dummy_data)

tvm_out = get_tvm_output_with_vm(
onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True
)
tvm.testing.assert_allclose(torch_out.numpy(), tvm_out)

verify_index_put_slice((3, 3), (2, 2), False)
verify_index_put_slice((2, 3, 4), (1, 2, 3), True)
verify_index_put_slice((2, 3, 4, 5), (1, 2, 3, 1), False)


@tvm.testing.parametrize_targets
def test_reverse_sequence(target, dev):
def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
Expand Down Expand Up @@ -5621,6 +5696,7 @@ def repeat(N, D):
test_cumsum()
test_wrong_input()
test_aten()
test_index_put()
test_reverse_sequence()
test_eyelike()
test_qlinearconv()
Expand Down