@@ -25,53 +25,55 @@ def _masked_index_copy_group_quant_fp8(
     # mask indices
     start_offsets_ptr,
     row_indices_ptr,
-    # group size
-    group_size,
-    # output size
+    # dimensions
+    num_groups,
     row_size,
     col_size,
     dim_size,
-    # avoid to divide zero
+    group_size,
+    # quantization parameters
     eps,
+    fp8_max,
     # block size
     BLOCK: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
 ):
-    # get program id and block offset
-    pid = tl.program_id(0)
-    block_start = pid * group_size
+    group_block = tl.program_id(0)
+    token_block = tl.program_id(1)
+    token_block_num = tl.num_programs(1)
 
-    # compute mask and pointers
-    offsets = block_start + tl.arange(0, BLOCK)
-    mask = offsets < (block_start + group_size)
+    # calculate group and element offsets
     num_tokens = tl.load(start_offsets_ptr + row_size)
-    token_idx = offsets // dim_size
-    valid = (token_idx < num_tokens) & mask
-    row_idx = tl.load(row_indices_ptr + token_idx, mask=valid)
-    start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid)
-    col_idx = token_idx - start_offset
-    elem_idx = offsets % dim_size
-
-    # load input data
-    input = tl.load(input_ptr + offsets, mask=valid, other=0.0).to(tl.float32)
-
-    # quant
-    _absmax = tl.maximum(tl.max(tl.abs(input)), eps)
-    output_s = _absmax / 448.0
-    output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
-    output_s_inv = 1.0 / output_s
-    output_q = tl.clamp(input * output_s_inv, -448.0,
-                        448.0).to(out_q_ptr.dtype.element_ty)
-
-    # write output
-    s_dim_size = dim_size // group_size
-    out_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx
-    group_in_token = elem_idx // group_size
-    out_s_offset = row_idx * col_size * s_dim_size + col_idx * s_dim_size + group_in_token
-
-    # Only store scaling factor for the first element in each group to avoid race conditions
-    is_first_in_group = elem_idx % group_size == 0
-    tl.store(out_q_ptr + out_offsets, output_q, mask=valid)
-    tl.store(out_s_ptr + out_s_offset, output_s, mask=valid & is_first_in_group)
+    group_start = group_block * group_size
+    elem_offsets = group_start + tl.arange(0, BLOCK)
+    valid_elem = elem_offsets < (group_start + group_size)
+    input_ptr_offs = input_ptr + elem_offsets
+    output_ptr_offs = out_q_ptr + elem_offsets
+    output_s_offs = out_s_ptr + group_block
+
+    # process tokens
+    for token_index in tl.range(token_block,
+                                num_tokens,
+                                token_block_num,
+                                num_stages=NUM_STAGE):
+        # load input and indices
+        input_data = tl.load(input_ptr_offs + token_index * dim_size,
+                             mask=valid_elem,
+                             other=0.0)
+        row_idx = tl.load(row_indices_ptr + token_index)
+        start_offset = tl.load(start_offsets_ptr + row_idx)
+        idx = row_idx * col_size + token_index - start_offset
+
+        # quantization
+        _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps)
+        output_s = _absmax / fp8_max
+        output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
+        output_q = tl.clamp(input_data / output_s, -fp8_max,
+                            fp8_max).to(out_q_ptr.dtype.element_ty)
+
+        # store quantized values and scaling factor
+        tl.store(output_ptr_offs + idx * dim_size, output_q, mask=valid_elem)
+        tl.store(output_s_offs + idx * num_groups, output_s)
 
 
 def masked_index_copy_group_quant_fp8(
@@ -88,32 +90,50 @@ def masked_index_copy_group_quant_fp8(
8890    ), "the last dimension of `input` cannot be divisible by `group_size`" 
8991    assert  input .is_contiguous (), "`input` is not contiguous" 
9092    assert  input .ndim  ==  2 , "Input must be a 2D tensor" 
91-     assert  output .ndim  ==  3 , "Input  must be a 3D tensor, [row, col, dim]" 
93+     assert  output .ndim  ==  3 , "Output  must be a 3D tensor, [row, col, dim]" 
9294    assert  start_offsets .shape [
9395        0 ] ==  output .shape [0 ] +  1 , "Start offsets must be (num_experts + 1)" 
9496
9597    num_tokens  =  input .shape [0 ]
9698    row_size  =  output .shape [0 ]
9799    col_size  =  output .shape [1 ]
98100    dim_size  =  output .shape [2 ]
99-     total_elems  =  num_tokens  *  dim_size 
101+     num_groups  =  (dim_size  +  group_size  -  1 ) //  group_size 
102+ 
103+     # get block/grid/stage/warp 
104+     BLOCK  =  group_size 
105+     if  num_tokens  <=  4096 :
106+         TOKEN_BLOCK_NUM  =  128 
107+         NUM_STAGES  =  4 
108+         num_warps  =  2 
109+     else :
110+         TOKEN_BLOCK_NUM  =  64 
111+         NUM_STAGES  =  6 
112+         num_warps  =  1 
113+     grid  =  (
114+         num_groups ,
115+         TOKEN_BLOCK_NUM ,
116+     )
100117
101-     M   =   total_elems   //   group_size 
102-     BLOCK  =  triton . next_power_of_2 ( group_size )
103-     # heuristics for number of warps 
104-      num_warps   =   min ( max ( BLOCK   //   256 ,  1 ),  8 ) 
105-     _masked_index_copy_group_quant_fp8 [( M , ) ](
118+     # FP8 quantization parameters 
119+     finfo  =  torch . finfo ( torch . float8_e4m3fn )
120+     fp8_max   =   finfo . max 
121+ 
122+     _masked_index_copy_group_quant_fp8 [grid ](
106123        input ,
107124        output ,
108125        output_s ,
109126        start_offsets ,
110127        row_indices ,
111-         group_size ,
128+         num_groups ,
112129        row_size ,
113130        col_size ,
114131        dim_size ,
132+         group_size ,
115133        eps ,
134+         fp8_max ,
116135        BLOCK = BLOCK ,
136+         NUM_STAGE = NUM_STAGES ,
117137        num_warps = num_warps ,
118138    )
119139    return 
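
For reference, the per-group math the new kernel performs (absmax over each `group_size` slice, scale rounded up to a power of two, values clamped to the e4m3 range) can be sketched in plain PyTorch. This sketch is not part of the change: the function name is hypothetical, it assumes `group_size` divides the last dimension (as the wrapper asserts), and it does not emulate the masked scatter into `output[row, col, :]` driven by `row_indices` and `start_offsets`.

```python
import torch


def group_quant_fp8_reference(x: torch.Tensor,
                              group_size: int,
                              eps: float = 1e-10):
    """Quantize each contiguous `group_size` slice of the last dim to fp8."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    # Round the per-group scale up to a power of two, mirroring
    # tl.exp2(tl.ceil(tl.log2(absmax / fp8_max))) in the kernel.
    scale = torch.exp2(torch.ceil(torch.log2(absmax / fp8_max)))
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)
```

A power-of-two scale only changes the floating-point exponent, so dividing by `output_s` at quantization (and multiplying back at dequantization) adds no rounding error beyond the fp8 cast itself, which is presumably why the kernel rounds the scale with `exp2(ceil(log2(...)))`.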