@@ -858,7 +858,7 @@ def test_matmul():
858858 tvm .testing .assert_allclose (out_np , tvm_out , rtol = 1e-5 , atol = 1e-5 )
859859
860860
861- def verify_batch_matmul (a_shape , b_shape ):
861+ def verify_batch_matmul (a_shape , b_shape , target , ctx ):
862862 a_array = np .random .uniform (size = a_shape ).astype ("float32" )
863863 b_array = np .random .uniform (size = b_shape ).astype ("float32" )
864864 out_np = np .matmul (a_array , b_array )
@@ -877,17 +877,16 @@ def verify_batch_matmul(a_shape, b_shape):
877877
878878 model = helper .make_model (graph , producer_name = "matmul_test" )
879879
880- for target , ctx in tvm .testing .enabled_targets ():
881- tvm_out = get_tvm_output_with_vm (model , [a_array , b_array ], target , ctx )
882- tvm .testing .assert_allclose (out_np , tvm_out , rtol = 1e-5 , atol = 1e-5 )
880+ tvm_out = get_tvm_output_with_vm (model , [a_array , b_array ], target , ctx )
881+ tvm .testing .assert_allclose (out_np , tvm_out , rtol = 1e-5 , atol = 1e-5 )
883882
884883
885884# TODO(mbrookhart): enable cuda once VM supports heterogeneous execution
886885@tvm .testing .parametrize_targets ("llvm" )
887- def test_batch_matmul ():
888- verify_batch_matmul ((2 , 3 , 4 , 3 ), (2 , 3 , 3 , 4 ))
889- verify_batch_matmul ((2 , 4 , 3 ), (3 , 4 ))
890- verify_batch_matmul ((2 , 3 , 4 , 3 ), (3 , 4 ))
886+ def test_batch_matmul (target , ctx ):
887+ verify_batch_matmul ((2 , 3 , 4 , 3 ), (2 , 3 , 3 , 4 ), target , ctx )
888+ verify_batch_matmul ((2 , 4 , 3 ), (3 , 4 ), target , ctx )
889+ verify_batch_matmul ((2 , 3 , 4 , 3 ), (3 , 4 ), target , ctx )
891890
892891
893892def verify_lrn (shape , nsize , dtype , alpha = None , beta = None , bias = None ):
0 commit comments