3232}
3333
3434
35- def verify_batch_matmul (batch , M , N , K ):
36- x = te .placeholder ((batch , M , K ), name = "x" )
37- y = te .placeholder ((batch , N , K ), name = "y" )
35+ def verify_batch_matmul (x_batch , y_batch , M , N , K ):
36+ x = te .placeholder ((x_batch , M , K ), name = "x" )
37+ y = te .placeholder ((y_batch , N , K ), name = "y" )
3838 dtype = x .dtype
3939
4040 # use memoize to pickle the test data for next time use
4141 @memoize ("topi.tests.test_topi_batch_matmul" )
4242 def get_ref_data ():
43- a_np = np .random .uniform (size = (batch , M , K )).astype (dtype )
44- b_np = np .random .uniform (size = (batch , N , K )).astype (dtype )
43+ a_np = np .random .uniform (size = (x_batch , M , K )).astype (dtype )
44+ b_np = np .random .uniform (size = (y_batch , N , K )).astype (dtype )
4545 c_np = tvm .topi .testing .batch_matmul (a_np , b_np )
4646 return (a_np , b_np , c_np )
4747
@@ -67,10 +67,13 @@ def check_device(device, ctx):
6767
6868@tvm .testing .uses_gpu
6969def test_batch_matmul ():
70- verify_batch_matmul (1 , 16 , 16 , 32 )
71- verify_batch_matmul (5 , 16 , 16 , 32 )
72- verify_batch_matmul (5 , 16 , 20 , 32 )
73- verify_batch_matmul (30 , 16 , 20 , 32 )
70+ verify_batch_matmul (1 , 1 , 16 , 16 , 32 )
71+ verify_batch_matmul (5 , 5 , 16 , 16 , 32 )
72+ verify_batch_matmul (5 , 5 , 16 , 20 , 32 )
73+ verify_batch_matmul (30 , 30 , 16 , 20 , 32 )
74+ # Test batch broadcasting.
75+ verify_batch_matmul (1 , 5 , 16 , 16 , 32 )
76+ verify_batch_matmul (5 , 1 , 16 , 16 , 32 )
7477
7578
7679if __name__ == "__main__" :
0 commit comments