Skip to content

Commit 59ef0ee

Browse files
authored
[Bugfix][ONNX] Improve broadcast and batch_matmul conversion (#16961)
* [Bugfix][VTA] Fix FSIM compile error on macOS. VTA FSIM could not be built on macOS, for it leverages malloc.h and memalign, yet both have been deprecated and are not provided by macOS. This issue was captured in #13173. This commit stops including malloc.h in VTA Runtime as stdlib.h has provided functions we need. This commit uses posix_memalign instead of memalign. It is a portable standard function. * Fix format. * [Bugfix][ONNX] Improve broadcast and batch_matmul conversion This commit provides batch_matmul conversions between a 3D or above matrix and a 1D matrix with proper broadcasting, which improves the robustness of the ONNX frontend. This issue was captured in #16891. * Fix format.
1 parent 944d180 commit 59ef0ee

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,21 @@ def matmul_out_dtype(inputs, out_dtype):
307307
a = flatten_to_nd(inputs[0], a_shape, 2)
308308
b = _op.transpose(inputs[1])
309309
output = _op.nn.dense(a, b, out_dtype=out_dtype)
310+
elif a_rank == 1 or b_rank == 1:
311+
a, b = inputs
312+
_a_shape = tuple(a_shape.data.numpy())
313+
_b_shape = tuple(b_shape.data.numpy())
314+
if a_rank == 1:
315+
axis = -2
316+
a = _op.expand_dims(a, axis=0)
317+
batches = _b_shape[:-2]
318+
a = _op.broadcast_to(a, (*batches, 1, _a_shape[0]))
319+
else:
320+
axis = -1
321+
b = _op.expand_dims(b, axis=-1)
322+
batches = _a_shape[:-2]
323+
b = _op.broadcast_to(b, (*batches, _b_shape[0], 1))
324+
return _op.squeeze(_op.nn.batch_matmul(a, b, transpose_b=False), axis=axis)
310325
else:
311326
a = inputs[0]
312327
b = inputs[1]

tests/python/frontend/onnx/test_forward.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,8 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
14931493
verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4))
14941494
verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4))
14951495
# Test implicit broadcasting.
1496+
verify_batch_matmul((5,), (5, 5, 4), (5, 4))
1497+
verify_batch_matmul((5, 4, 5), (5,), (5, 4))
14961498
verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4))
14971499
verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4))
14981500
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))

0 commit comments

Comments
 (0)