diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 1b1d60119967..4e6540fb08de 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -795,19 +795,78 @@ def _mx_multibox_detection(inputs, attrs):
 
 def _mx_dot(inputs, attrs):
+    """Convert MXNet ``dot`` to a 2-D ``nn.matmul`` wrapped in reshapes.
+
+    ``a`` is flattened to ``(-1, a.shape[-1])`` and ``b`` to ``(b.shape[0], -1)``.
+    The product is reshaped to ``a.shape[:-1] + b.shape[1:]``, where ``b.shape``
+    is taken after the optional ``transpose_b`` rotation. A 1-D ``a`` is used as
+    a row vector and a 1-D ``b`` as a column vector; two 1-D inputs give a
+    result of shape ``[1]``.
+
+    Parameters
+    ----------
+    inputs : sequence
+        Exactly two expressions, ``a`` and ``b``.
+    attrs : object exposing ``get_bool``
+        Supplies ``transpose_a`` and ``transpose_b`` (both default to ``False``).
+        ``transpose_b`` is not consulted when ``b`` is 1-D.
+
+    Returns
+    -------
+    out
+        ``_op.nn.matmul(a, b)`` reshaped to the output shape described above.
+
+    Raises
+    ------
+    tvm.error.OpAttributeInvalid
+        If either input has rank 0, or if ``transpose_a`` is true.
+    """
     assert len(inputs) == 2
+
     a, b = inputs
-    rank_a = len(_infer_type(a).checked_type.shape)
-    rank_b = len(_infer_type(b).checked_type.shape)
-    if rank_a != 2 or rank_b != 2:
-        raise tvm.error.OpAttributeUnimplemented("Only 2-D arrays are supported.")
+
+    # Type inference is run once per input; every shape needed further down is
+    # derived from these two results instead of re-inferring rewritten values.
+    shape_a = list(_infer_type(a).checked_type.shape)
+    shape_b = list(_infer_type(b).checked_type.shape)
+    rank_a = len(shape_a)
+    rank_b = len(shape_b)
+
+    if rank_a < 1 or rank_b < 1:
+        raise tvm.error.OpAttributeInvalid("Unsupported shape of input tensors.")
+
     transpose_a = attrs.get_bool("transpose_a", False)
     transpose_b = attrs.get_bool("transpose_b", False)
-    if transpose_a is True:
+
+    if transpose_a:
         msg = 'Value {} in attribute "transpose_a" of operator dot ' "is not valid."
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
-    if transpose_b is False:
-        b = _op.transpose(b, axes=[1, 0])
-    return _op.nn.dense(a, b)
+
+    # Collapse the leading dimensions of a; they become the leading dimensions
+    # of the result.
+    if rank_a == 1:
+        out_shape = []
+        a = _op.expand_dims(a, axis=0)
+    else:
+        out_shape = shape_a[:-1]
+        a = _op.reshape(a, newshape=(-1, shape_a[-1]))
+
+    if rank_b == 1:
+        if not out_shape:
+            # Both inputs are 1-D: the result is emitted with shape [1].
+            out_shape = [1]
+        b = _op.expand_dims(b, axis=1)
+    else:
+        if transpose_b:
+            # Rotate the last axis of b to the front so it becomes the
+            # contraction axis; permute the known shape the same way.
+            trans_axes = [rank_b - 1] + list(range(rank_b - 1))
+            b = _op.transpose(b, axes=trans_axes)
+            shape_b = shape_b[-1:] + shape_b[:-1]
+
+        out_shape += shape_b[1:]
+        b = _op.reshape(b, newshape=(shape_b[0], -1))
+
+    return _op.reshape(_op.nn.matmul(a, b), newshape=out_shape)
 
 
 def _mx_batch_dot(inputs, attrs):
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 44aa93061a62..0e34719ea27d 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -690,6 +690,17 @@ def verify(a_shape, b_shape, transpose_b=False):
 
     verify((1, 256), (256, 1))
     verify((1, 256), (1, 256), transpose_b=True)
+    verify((5,), (5,))
+    verify((3,), (3, 5))
+    verify((3,), (5, 3), transpose_b=True)
+    verify((3,), (3, 5, 3, 5))
+    verify((3,), (5, 5, 3, 3), transpose_b=True)
+    verify((10, 1), (1,))
+    verify((1, 1), (4, 3, 2, 1), transpose_b=True)
+    verify((4, 3, 2, 1), (1,))
+    verify((1, 2, 3, 4), (1, 4), transpose_b=True)
+    verify((4, 1, 1), (1, 2, 3))
+    verify((1, 1, 4), (2, 3, 4), transpose_b=True)
 
 
 @tvm.testing.uses_gpu