File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -1364,12 +1364,16 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
13641364    threads_per_block *= 2 ;
13651365    cluster_size /= 2 ;
13661366  }
1367+   int  sm_count = get_sm_count ();
1368+   while  (cluster_num * cluster_size > sm_count && cluster_size > 1  && threads_per_block <= 512 ) {
1369+       threads_per_block *= 2 ;
1370+       cluster_size /= 2 ;
1371+   }
13671372  FLASHINFER_CHECK (oneshot || threads_per_block >= params.nranks ,
13681373                   " not oneshot, or threads_per_block < nranks" 
13691374  int  block_size = threads_per_block;
13701375  FLASHINFER_CHECK (block_size <= 1024  && cluster_size > 0 ,
13711376                   " block_size > 1024 or cluster_size <= 0" 
1372-   int  sm_count = get_sm_count ();
13731377  int  grid_size = (std::min (sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
13741378  cudaLaunchConfig_t cfg;
13751379  cudaLaunchAttribute attribute[2 ];
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments