Skip to content

Commit 2892e6a

Browse files
author
Matthew Brookhart
committed
fix batch matmul test
1 parent 6ae7c02 commit 2892e6a

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

tests/python/frontend/onnx/test_forward.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def test_matmul():
858858
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
859859

860860

861-
def verify_batch_matmul(a_shape, b_shape):
861+
def verify_batch_matmul(a_shape, b_shape, target, ctx):
862862
a_array = np.random.uniform(size=a_shape).astype("float32")
863863
b_array = np.random.uniform(size=b_shape).astype("float32")
864864
out_np = np.matmul(a_array, b_array)
@@ -877,17 +877,16 @@ def verify_batch_matmul(a_shape, b_shape):
877877

878878
model = helper.make_model(graph, producer_name="matmul_test")
879879

880-
for target, ctx in tvm.testing.enabled_targets():
881-
tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
882-
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
880+
tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
881+
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
883882

884883

885884
# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
886885
@tvm.testing.parametrize_targets("llvm")
887-
def test_batch_matmul():
888-
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
889-
verify_batch_matmul((2, 4, 3), (3, 4))
890-
verify_batch_matmul((2, 3, 4, 3), (3, 4))
886+
def test_batch_matmul(target, ctx):
887+
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
888+
verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
889+
verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)
891890

892891

893892
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):

0 commit comments

Comments
 (0)