55
66from tvm .autotvm .task .task import compute_flop
77
def random_dtypes():
    """Pick one (input dtype, accumulator dtype) pair at random.

    Returns
    -------
    tuple of (str, str)
        One of ("float32", "float32"), ("float16", "float32") or
        ("int8", "int32"), selected with np.random.choice.
    """
    dtype_pairs = [
        ("float32", "float32"),
        ("float16", "float32"),
        ("int8", "int32"),
    ]
    # Draw an index rather than an element so the tuple is returned as-is.
    picked = np.random.choice(len(dtype_pairs))
    return dtype_pairs[picked]
12+
813def test_conv ():
914 for i in range (5 ):
1015 N , H , W , CO , CI , KH , KW = [np .random .randint (10 , 32 ) for _ in range (7 )]
11- D = tvm .placeholder ((N , CI , H , W ))
12- K = tvm .placeholder ((CO , CI , KH , KW ))
16+ (input_dtype , acc_dtype ) = random_dtypes ()
17+ D = tvm .placeholder ((N , CI , H , W ), dtype = input_dtype )
18+ K = tvm .placeholder ((CO , CI , KH , KW ), dtype = input_dtype )
1319
1420 KH = min (H , KH )
1521 KW = min (W , KW )
@@ -22,7 +28,8 @@ def test_conv():
2228 OW = (W - KW ) + 1
2329
2430 C = tvm .compute ((N , CO , OH , OW ), lambda n , co , h , w :
25- tvm .sum (D [n ][ci ][h ][w ] * K [co ][ci ][h ][w ], axis = [ci , kh , kw ]))
31+ tvm .sum (D [n ][ci ][h ][w ].astype (acc_dtype ) * K [co ][ci ][h ][w ].astype (acc_dtype ),
32+ axis = [ci , kh , kw ]))
2633
2734 s = tvm .create_schedule ([C .op ])
2835
@@ -31,15 +38,16 @@ def test_conv():
3138def test_pack_gemm ():
3239 for i in range (5 ):
3340 N , L , M = [np .random .randint (10 , 128 ) * 4 for _ in range (3 )]
34- A = tvm .placeholder ((N , L ))
35- B = tvm .placeholder ((M , L ))
41+ (input_dtype , acc_dtype ) = random_dtypes ()
42+ A = tvm .placeholder ((N , L ), dtype = input_dtype )
43+ B = tvm .placeholder ((M , L ), dtype = input_dtype )
3644 k = tvm .reduce_axis ((0 , L ))
3745
3846 bn = 4
3947 A_pack = tvm .compute ((N // bn , L , bn ), lambda i , j , k : A [i * bn + k ][j ])
4048 B_pack = tvm .compute ((M // bn , L , bn ), lambda i , j , k : B [i * bn + k ][j ])
4149 C_pack = tvm .compute ((N // bn , M // bn , bn , bn ), lambda i , j , ii , jj :
42- tvm .sum (A_pack [i , k , ii ] * B_pack [j , k , jj ], axis = [k ]))
50+ tvm .sum (A_pack [i , k , ii ]. astype ( acc_dtype ) * B_pack [j , k , jj ]. astype ( acc_dtype ) , axis = [k ]))
4351 C = tvm .compute ((N , M ), lambda i , j : C_pack [i // bn ][j // bn ][i % bn ][j % bn ])
4452
4553 s = tvm .create_schedule ([C .op ])
@@ -48,14 +56,61 @@ def test_pack_gemm():
def test_outer_dot():
    """Check the flop count of an outer product of two vectors.

    Each of the N * M output elements is a single multiplication of the
    two inputs after they are cast to the accumulator dtype, so the
    estimator is expected to report exactly N * M operations.
    """
    for _ in range(5):
        # Sizes are random multiples of 4, drawn one at a time.
        N, M = (np.random.randint(10, 128) * 4 for _ in range(2))
        input_dtype, acc_dtype = random_dtypes()

        A = tvm.placeholder((N,), dtype=input_dtype)
        B = tvm.placeholder((M,), dtype=input_dtype)

        def outer_product(i, j):
            return A[i].astype(acc_dtype) * B[j].astype(acc_dtype)

        C = tvm.compute((N, M), outer_product)

        s = tvm.create_schedule([C.op])
        expected_flop = N * M
        assert compute_flop(s) == expected_flop
5867
def test_max_pool():
    """Check the flop count of a 2-D max pooling written with tvm.max.

    The window slides one element at a time with no padding, so the
    output spatial size is (H - KH + 1, W - KW + 1). The test expects one
    operation per element of every pooling window:
    N * CO * OH * OW * KH * KW.
    """
    for _ in range(5):
        N, H, W, CO, KH, KW = [np.random.randint(10, 32) for _ in range(6)]
        (input_dtype, _acc_dtype) = random_dtypes()
        # Pooling keeps the channel count, and the compute below indexes the
        # channel axis with `co`, so the input must have CO channels.
        D = tvm.placeholder((N, CO, H, W), dtype=input_dtype)

        # The window can never be larger than the input.
        KH = min(H, KH)
        KW = min(W, KW)

        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.max(D[n][co][h + kh][w + kw], axis=[kh, kw]))

        s = tvm.create_schedule([C.op])

        assert compute_flop(s) == N * CO * OH * OW * KH * KW
90+
def test_average_pool():
    """Check the flop count of a 2-D average pooling.

    Every window element is cast to the accumulator dtype, divided by the
    window area and accumulated with tvm.sum. The test expects two
    operations per window element: 2 * N * CO * OH * OW * KH * KW.
    The window slides one element at a time with no padding.
    """
    for _ in range(5):
        N, H, W, CO, KH, KW = [np.random.randint(10, 32) for _ in range(6)]
        (input_dtype, acc_dtype) = random_dtypes()
        # Pooling keeps the channel count, and the compute below indexes the
        # channel axis with `co`, so the input must have CO channels.
        D = tvm.placeholder((N, CO, H, W), dtype=input_dtype)

        # The window can never be larger than the input.
        KH = min(H, KH)
        KW = min(W, KW)

        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.sum(
                D[n][co][h + kh][w + kw].astype(acc_dtype) / (KW * KH),
                axis=[kh, kw]))

        s = tvm.create_schedule([C.op])

        assert compute_flop(s) == 2 * N * CO * OH * OW * KH * KW
113+
59114def test_move ():
60115 """No float number operation in simple move. So the estimator should raise an error """
61116 N = 1024
0 commit comments