Commit caf55b8

Merge pull request #20 from lfr-0531/user/fanrongl/fix_scaling_factor
Opt padding+quant kernel
2 parents 4e23925 + 814a4cb commit caf55b8

File tree

1 file changed: +66 -46 lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 66 additions & 46 deletions
@@ -25,53 +25,55 @@ def _masked_index_copy_group_quant_fp8(
     # mask indices
     start_offsets_ptr,
     row_indices_ptr,
-    # group size
-    group_size,
-    # output size
+    # dimensions
+    num_groups,
     row_size,
     col_size,
     dim_size,
-    # avoid to divide zero
+    group_size,
+    # quantization parameters
     eps,
+    fp8_max,
     # block size
     BLOCK: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
 ):
-    # get program id and block offset
-    pid = tl.program_id(0)
-    block_start = pid * group_size
+    group_block = tl.program_id(0)
+    token_block = tl.program_id(1)
+    token_block_num = tl.num_programs(1)
 
-    # compute mask and pointers
-    offsets = block_start + tl.arange(0, BLOCK)
-    mask = offsets < (block_start + group_size)
+    # calculate group and element offsets
     num_tokens = tl.load(start_offsets_ptr + row_size)
-    token_idx = offsets // dim_size
-    valid = (token_idx < num_tokens) & mask
-    row_idx = tl.load(row_indices_ptr + token_idx, mask=valid)
-    start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid)
-    col_idx = token_idx - start_offset
-    elem_idx = offsets % dim_size
-
-    # load input data
-    input = tl.load(input_ptr + offsets, mask=valid, other=0.0).to(tl.float32)
-
-    # quant
-    _absmax = tl.maximum(tl.max(tl.abs(input)), eps)
-    output_s = _absmax / 448.0
-    output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
-    output_s_inv = 1.0 / output_s
-    output_q = tl.clamp(input * output_s_inv, -448.0,
-                        448.0).to(out_q_ptr.dtype.element_ty)
-
-    # write output
-    s_dim_size = dim_size // group_size
-    out_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx
-    group_in_token = elem_idx // group_size
-    out_s_offset = row_idx * col_size * s_dim_size + col_idx * s_dim_size + group_in_token
-
-    # Only store scaling factor for the first element in each group to avoid race conditions
-    is_first_in_group = elem_idx % group_size == 0
-    tl.store(out_q_ptr + out_offsets, output_q, mask=valid)
-    tl.store(out_s_ptr + out_s_offset, output_s, mask=valid & is_first_in_group)
+    group_start = group_block * group_size
+    elem_offsets = group_start + tl.arange(0, BLOCK)
+    valid_elem = elem_offsets < (group_start + group_size)
+    input_ptr_offs = input_ptr + elem_offsets
+    output_ptr_offs = out_q_ptr + elem_offsets
+    output_s_offs = out_s_ptr + group_block
+
+    # process tokens
+    for token_index in tl.range(token_block,
+                                num_tokens,
+                                token_block_num,
+                                num_stages=NUM_STAGE):
+        # load input and indices
+        input_data = tl.load(input_ptr_offs + token_index * dim_size,
+                             mask=valid_elem,
+                             other=0.0)
+        row_idx = tl.load(row_indices_ptr + token_index)
+        start_offset = tl.load(start_offsets_ptr + row_idx)
+        idx = row_idx * col_size + token_index - start_offset
+
+        # quantization
+        _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps)
+        output_s = _absmax / fp8_max
+        output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
+        output_q = tl.clamp(input_data / output_s, -fp8_max,
+                            fp8_max).to(out_q_ptr.dtype.element_ty)
+
+        # store quantized values and scaling factor
+        tl.store(output_ptr_offs + idx * dim_size, output_q, mask=valid_elem)
+        tl.store(output_s_offs + idx * num_groups, output_s)
 
 
 def masked_index_copy_group_quant_fp8(
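Aside: both the old and new kernels compute each group's scale as its absolute maximum divided by the FP8-E4M3 limit, rounded up to the next power of two, so dividing by the scale is exact in binary. A minimal PyTorch reference of that math follows; the function and variable names are illustrative, not taken from the kernel.

    import torch

    def group_quant_fp8_reference(x: torch.Tensor, group_size: int,
                                  eps: float = 1e-10):
        """Reference-only sketch of the kernel's per-group quantization."""
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3
        groups = x.float().reshape(-1, group_size)
        absmax = groups.abs().amax(dim=-1).clamp_min(eps)
        # Round the scale up to a power of two, e.g. absmax 30.0 gives
        # 30/448 ~= 0.067 -> ceil(log2) = -3 -> scale = 2**-3 = 0.125.
        scale = torch.exp2(torch.ceil(torch.log2(absmax / fp8_max)))
        q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
        return q.to(torch.float8_e4m3fn).reshape_as(x), scale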
@@ -88,32 +90,50 @@ def masked_index_copy_group_quant_fp8(
     ), "the last dimension of `input` cannot be divisible by `group_size`"
     assert input.is_contiguous(), "`input` is not contiguous"
     assert input.ndim == 2, "Input must be a 2D tensor"
-    assert output.ndim == 3, "Input must be a 3D tensor, [row, col, dim]"
+    assert output.ndim == 3, "Output must be a 3D tensor, [row, col, dim]"
     assert start_offsets.shape[
         0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)"
 
     num_tokens = input.shape[0]
     row_size = output.shape[0]
     col_size = output.shape[1]
     dim_size = output.shape[2]
-    total_elems = num_tokens * dim_size
+    num_groups = (dim_size + group_size - 1) // group_size
+
+    # get block/grid/stage/warp
+    BLOCK = group_size
+    if num_tokens <= 4096:
+        TOKEN_BLOCK_NUM = 128
+        NUM_STAGES = 4
+        num_warps = 2
+    else:
+        TOKEN_BLOCK_NUM = 64
+        NUM_STAGES = 6
+        num_warps = 1
+    grid = (
+        num_groups,
+        TOKEN_BLOCK_NUM,
+    )
 
-    M = total_elems // group_size
-    BLOCK = triton.next_power_of_2(group_size)
-    # heuristics for number of warps
-    num_warps = min(max(BLOCK // 256, 1), 8)
-    _masked_index_copy_group_quant_fp8[(M, )](
+    # FP8 quantization parameters
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+
+    _masked_index_copy_group_quant_fp8[grid](
         input,
         output,
         output_s,
         start_offsets,
         row_indices,
-        group_size,
+        num_groups,
         row_size,
         col_size,
         dim_size,
+        group_size,
         eps,
+        fp8_max,
         BLOCK=BLOCK,
+        NUM_STAGE=NUM_STAGES,
         num_warps=num_warps,
     )
     return
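For context, a hedged usage sketch of the updated launcher follows. The shapes, tensor values, and device setup are illustrative assumptions, not taken from this commit; arguments are passed by keyword to match the parameter names visible in the function body, since the full signature is outside the diff.

    import torch

    from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
        masked_index_copy_group_quant_fp8

    # Illustrative setup (assumed): 2 experts with 4 padded slots each,
    # hidden size 256, FP8 group size 128 (so 2 groups per token).
    num_experts, padded_slots, hidden, group_size = 2, 4, 256, 128

    # 3 routed tokens sorted by expert: tokens 0-1 -> expert 0, token 2 -> expert 1.
    inp = torch.randn(3, hidden, device="cuda", dtype=torch.bfloat16)
    row_indices = torch.tensor([0, 0, 1], dtype=torch.int32, device="cuda")
    # Per-expert prefix sums; the kernel reads the total token count
    # from start_offsets[num_experts].
    start_offsets = torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda")

    # Destination buffers in the padded [row, col, dim] layout the asserts
    # describe, plus one scale per group_size-wide chunk.
    output = torch.empty(num_experts, padded_slots, hidden,
                         dtype=torch.float8_e4m3fn, device="cuda")
    output_s = torch.empty(num_experts, padded_slots, hidden // group_size,
                           dtype=torch.float32, device="cuda")

    masked_index_copy_group_quant_fp8(output=output, output_s=output_s,
                                      input=inp, start_offsets=start_offsets,
                                      row_indices=row_indices,
                                      group_size=group_size)

With the new grid of (num_groups, TOKEN_BLOCK_NUM), each program owns one group column and strides over tokens, so the scale store needs no first-element masking as the old single-axis kernel did.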
