@@ -275,5 +275,112 @@ def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int
     tvm.ir.assert_structural_equal(mod["main"], before)


+def test_small_spatial_axis():
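+    # The dlight LowBatchGEMV(4) rule, applied under a CUDA target, should schedule
+    # an NT-matmul with a dynamic batch dimension and a small (8-wide) spatial output
+    # axis; the result is compared structurally against `expected` below.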
+
+    @T.prim_func(private=True)
+    def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16")
+        C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16")
+        for i0, i1, k in T.grid(batch_size, T.int64(8), T.int64(4096)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                T.reads(A[v_i0, v_k], B[v_i1, v_k])
+                T.writes(C[v_i0, v_i1])
+                with T.init():
+                    C[v_i0, v_i1] = T.float16(0)
+                C[v_i0, v_i1] = C[v_i0, v_i1] + A[v_i0, v_k] * B[v_i1, v_k]
+
+    # fmt: off
+    @T.prim_func(private=True)
+    def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16")
+        C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16")
+        # with T.block("root"):
+        C_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local")
+        C_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local")
+        C_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local")
+        for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                        for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)):
+                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
+                                with T.block("NT_matmul_rf_init"):
+                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
+                                    v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
+                                    T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2_init < T.int64(8))
+                                    T.reads()
+                                    T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1])
+                                    C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = T.float16(0)
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
+                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
+                                    with T.block("NT_matmul_rf_update"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
+                                        vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
+                                        T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2 < T.int64(8))
+                                        T.reads(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1], A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)])
+                                        T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1])
+                                        C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] + T.if_then_else(v0 < batch_size, A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)]
+                for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                        for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax2 in range(T.int64(4)):
+                                for ax3_fused_2_1 in T.vectorized(T.int64(2)):
+                                    with T.block("NT_matmul_rf_init"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                        v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
+                                        T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8))
+                                        T.reads()
+                                        T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1])
+                                        C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = T.float16(0)
+                                    for ax1 in range(T.int64(4)):
+                                        with T.block("NT_matmul_rf_update"):
+                                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                            v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
+                                            T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8))
+                                            T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1], C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1])
+                                            T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1])
+                                            C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1]
+                for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                            with T.block("NT_matmul"):
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
+                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
+                                v1 = T.axis.spatial(T.int64(8), ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
+                                T.where((T.Mul(T.int64(0), T.int64(16)) + ax2_fused_0_ax2_fused_1_fused % T.int64(16)) * T.int64(2) + ax2_fused_2 < T.int64(8))
+                                T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1])
+                                T.writes(C_pad_local[v0, v1])
+                                with T.init():
+                                    C_pad_local[v0, v1] = T.float16(0)
+                                C_pad_local[v0, v1] = C_pad_local[v0, v1] + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]
+                for ax0 in range(T.int64(4)):
+                    for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        for ax1_fused_2 in range(T.int64(2)):
+                            with T.block("C_pad"):
+                                v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
+                                v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
+                                T.reads(C_pad_local[v0, v1])
+                                T.writes(C[v0, v1])
+                                C[v0, v1] = C_pad_local[v0, v1]
+    # fmt: on
+
+    mod = tvm.IRModule({"main": func})
+    with Target("cuda"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()