@@ -321,8 +321,124 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     return s


+def schedule_group_conv2d_nchw_direct(cfg, s, conv):
+    """Schedule group conv2d NCHW direct template"""
+    workload = conv.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+    num_filters = get_const_int(conv.shape[1])
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc, ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
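+    # nvptx and rocm are limited to explicit unrolling; other targets tune both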
+    target = tvm.target.current_target()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, 'local')
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope('local')
+        OL = conv
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [OL])
+    WW = s.cache_read(kernel, 'shared', [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
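+    # split off a length-1 "kernel scope" axis; it is only used to anchor the
+    # unroll pragmas applied at the end of this schedule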
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    g, f = s[output].split(f, nparts=groups)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rxo)
+    s[WW].compute_at(s[OL], rxo)
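+    # the shared tiles are therefore refilled once per outer reduction iteration (rxo)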
+
+    # cooperative fetching
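+    # split the fused load loop by the thread extents chosen above so every
+    # thread in the block helps fill the shared buffers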
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, y, x)
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # unroll
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    N, CO, OH, OW = get_const_tuple(output.shape)
+    _, KH, KW, CI = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW // groups)
+
+
 @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["int8"])
+                                ["cuda", "gpu"], ["int8", "direct"])
 def schedule_conv2d_nchw_cuda(cfg, outs):
     """TOPI schedule callback of group conv2d for cuda gpu

@@ -347,7 +463,7 @@ def _callback(op):
         if op.tag == "group_conv2d_NCHWc_int8":
             schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
         if op.tag == "group_conv2d_nchw":
-            raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported")
+            schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))

     traverse_inline(s, outs[0].op, _callback)
     return s