Skip to content

Commit c166a90

Browse files
authored
[Sampler] Use pivot-based renormalization for top-p sampling (#2272)
This PR integrates the pivot-based prob renormalization for top-p sampling, whose performance is a few times faster than the current sort-based top-p sampling on CUDA.
1 parent 8d58e52 commit c166a90

File tree

6 files changed

+109
-52
lines changed

6 files changed

+109
-52
lines changed

cpp/serve/engine_actions/batch_decode.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,10 @@ class BatchDecodeActionObj : public EngineActionObj {
114114
// Fill range [0, num_rsentries) into `sample_indices`.
115115
std::vector<int> sample_indices(num_rsentries);
116116
std::iota(sample_indices.begin(), sample_indices.end(), 0);
117-
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP(
118-
probs_on_device, sample_indices, request_ids, generation_cfg, rngs);
117+
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
118+
probs_on_device, sample_indices, request_ids, generation_cfg);
119+
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
120+
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
119121
ICHECK_EQ(sample_results.size(), num_rsentries);
120122

121123
// - Update the committed tokens of states.

cpp/serve/engine_actions/new_request_prefill.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,10 @@ class NewRequestPrefillActionObj : public EngineActionObj {
229229
rsentry_activated.push_back(true);
230230
}
231231
}
232-
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP(
233-
probs_on_device, sample_indices, request_ids, generation_cfg, rngs);
232+
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
233+
probs_on_device, sample_indices, request_ids, generation_cfg);
234+
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
235+
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
234236
ICHECK_EQ(sample_results.size(), rsentries_for_sample.size());
235237

236238
// - Update the committed tokens of states.

cpp/serve/sampler/gpu_sampler.cc

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class GPUSampler : public SamplerObj {
6060
uniform_samples_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu);
6161
sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu);
6262
top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu);
63+
top_p_init_pivots_host_ =
64+
NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device_cpu);
6365
top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu);
6466
draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu);
6567
token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu);
@@ -73,6 +75,8 @@ class GPUSampler : public SamplerObj {
7375
uniform_samples_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device);
7476
sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device);
7577
top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device);
78+
top_p_init_pivots_device_ =
79+
NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device);
7680
top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device);
7781
draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device);
7882
token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device);
@@ -118,21 +122,35 @@ class GPUSampler : public SamplerObj {
118122
return probs_on_device;
119123
}
120124

121-
// - Argsort the probability.
122-
Array<NDArray> argsort_results = gpu_argsort_probs_func_(probs_on_device);
123-
ICHECK_EQ(argsort_results.size(), 2);
124-
NDArray sorted_probs_on_device = argsort_results[0];
125-
NDArray sorted_indices_on_device = argsort_results[1];
126-
127-
// - Copy auxiliary array for top-p.
125+
// - Copy auxiliary array for top-p and initial pivots.
128126
NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_);
129127
NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_);
130128
CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_);
129+
130+
NDArray top_p_init_pivots_host =
131+
top_p_init_pivots_host_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);
132+
NDArray top_p_init_pivots_device =
133+
top_p_init_pivots_device_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);
134+
const float* p_top_p = static_cast<const float*>(top_p_host->data);
135+
float* p_top_p_init_pivots = static_cast<float*>(top_p_init_pivots_host->data);
136+
for (int i = 0; i < num_probs; ++i) {
137+
if (1 - p_top_p[i] >= 0.02) {
138+
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] =
139+
std::min(1 - p_top_p[i], static_cast<float>(0.5));
140+
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = 0.02;
141+
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = 0.01;
142+
} else {
143+
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = 1 - p_top_p[i];
144+
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = (1 - p_top_p[i]) / 2;
145+
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = (1 - p_top_p[i]) / 4;
146+
}
147+
}
148+
CopyArray(/*src=*/top_p_init_pivots_host, /*dst=*/top_p_init_pivots_device, copy_stream_);
131149
SyncCopyStream(device_, compute_stream_, copy_stream_);
132150

133151
// - Renormalize the prob with top p.
134152
NDArray renormed_probs_on_device =
135-
gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device);
153+
gpu_renormalize_by_top_p_func_(probs_on_device, top_p_device, top_p_init_pivots_device);
136154

137155
RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p");
138156
return renormed_probs_on_device;
@@ -500,6 +518,9 @@ class GPUSampler : public SamplerObj {
500518
<< "GPU sampler requires the top_p values for each prob distribution are the same.";
501519
}
502520
}
521+
for (int i = 0; i < num_probs; ++i) {
522+
p_top_p[i] = std::max(p_top_p[i], eps_);
523+
}
503524
return need_top_p;
504525
}
505526

@@ -665,6 +686,7 @@ class GPUSampler : public SamplerObj {
665686
NDArray uniform_samples_host_;
666687
NDArray sample_indices_host_;
667688
NDArray top_p_host_;
689+
NDArray top_p_init_pivots_host_;
668690
NDArray top_prob_offsets_host_;
669691
NDArray draft_tokens_host_;
670692
NDArray token_tree_first_child_host_;
@@ -678,6 +700,7 @@ class GPUSampler : public SamplerObj {
678700
NDArray uniform_samples_device_;
679701
NDArray sample_indices_device_;
680702
NDArray top_p_device_;
703+
NDArray top_p_init_pivots_device_;
681704
NDArray top_prob_offsets_device_;
682705
NDArray draft_tokens_device_;
683706
NDArray token_tree_first_child_device_;
@@ -691,6 +714,7 @@ class GPUSampler : public SamplerObj {
691714
// The device stream for copying auxiliary data structure to GPU.
692715
TVMStreamHandle copy_stream_ = nullptr;
693716
const float eps_ = 1e-5;
717+
const int num_top_p_cutoff_pivots_ = 3;
694718
};
695719

696720
Sampler Sampler::CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft,

python/mlc_llm/compiler_pass/attach_sampler.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from tvm.relax.frontend import nn
88
from tvm.script import tir as T
99

10-
from ..op.batch_spec_verify import batch_spec_verify
10+
from mlc_llm.op.batch_spec_verify import batch_spec_verify
11+
from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm
1112

1213

1314
@tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc")
@@ -49,7 +50,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
4950
_attach_sample_with_top_p(bb, vocab_size),
5051
_attach_take_probs_func(bb, vocab_size),
5152
_attach_batch_verifier(bb, vocab_size),
52-
_attach_renormalize_by_top_p(bb, vocab_size),
53+
_attach_renormalize_by_top_p(bb, vocab_size, self.target),
5354
]
5455
]
5556

@@ -227,41 +228,36 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals
227228
return gv
228229

229230

230-
def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr):
231+
def _attach_renormalize_by_top_p(
232+
bb: relax.BlockBuilder, vocab_size: tir.PrimExpr, target: tvm.target.Target
233+
):
231234
batch_size = tir.Var("batch_size", "int64")
235+
num_pivots = 3
232236
probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32"))
233-
sorted_probs = relax.Var(
234-
"sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")
235-
)
236237
top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32"))
237-
with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]):
238+
init_pivots = relax.Var(
239+
"init_pivots", relax.TensorStructInfo((batch_size, num_pivots), "float32")
240+
)
241+
with bb.function("renormalize_by_top_p", [probs, top_p, init_pivots]):
238242
with bb.dataflow():
239-
probs_tensor = nn.wrap_nested(probs, name="probs")
240-
sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs")
241-
top_p_shape = relax.ShapeExpr([batch_size, 1])
242-
top_p_tensor = nn.wrap_nested(
243-
relax.call_pure_packed(
244-
"vm.builtin.reshape",
245-
top_p,
246-
top_p_shape,
247-
sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"),
248-
),
249-
name="sample_indices",
250-
)
251-
top_k_tensor = nn.tensor_ir_op(
252-
full,
253-
name_hint="full",
254-
args=[vocab_size],
255-
out=nn.Tensor.placeholder(
256-
[batch_size, 1],
257-
"int32",
258-
),
243+
cutoff_output = bb.emit(
244+
relax.call_tir(
245+
bb.add_func(top_p_pivot(num_pivots, target), "top_p_pivot_cutoff"),
246+
args=[probs, top_p, init_pivots],
247+
out_sinfo=[top_p.struct_info, top_p.struct_info], # pylint: disable=no-member
248+
)
259249
)
260-
renormalized_probs = nn.renormalize_top_p_top_k_prob(
261-
probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor
250+
final_pivot = cutoff_output[0]
251+
renorm_sum = cutoff_output[1]
252+
renormalized_probs = bb.emit(
253+
relax.call_tir(
254+
bb.add_func(top_p_renorm(target), "top_p_renorm_after_cutoff"),
255+
args=[probs, final_pivot, renorm_sum],
256+
out_sinfo=probs.struct_info, # pylint: disable=no-member
257+
)
262258
)
263-
bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access
264-
gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access
259+
bb.emit_output(renormalized_probs)
260+
gv = bb.emit_func_output(renormalized_probs)
265261
return gv
266262

267263

python/mlc_llm/compiler_pass/rewrite_softmax.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-re
7979
def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements
8080
target: tvm.target.Target, chunk_size: int
8181
):
82+
# NOTE: A quick note on the softmax implementation.
83+
# We once tried to multiply every element by log2e which can be computed
84+
# potentially more efficiently on hardware.
85+
# However, when the input values are large, multiplying by the factor of log2e
86+
# causes numerical issue in float32 dtype.
87+
# This leads to the softmax output not summing up to 1.
88+
# For numerical stability, we removed the log2e factor and switched back
89+
# to the standard log/exp computation.
90+
8291
# pylint: disable=invalid-name
8392
@T.prim_func
8493
def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals

python/mlc_llm/op/top_p_pivot.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import tvm
44
from tvm.script import tir as T
55

6+
from mlc_llm.support.max_thread_check import get_max_num_threads_per_block
7+
68
# mypy: disable-error-code="attr-defined,valid-type,name-defined"
79
# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda
810
# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches
911

1012

11-
def top_p_pivot(pN):
13+
def top_p_pivot(pN, target: tvm.target.Target):
1214
"""Top-p pivot function. This function finds the pivot to cut-off top-p percentile.
1315
1416
A valide pivot should satisfy the following conditions:
@@ -23,19 +25,26 @@ def top_p_pivot(pN):
2325
prob:
2426
The probability vector
2527
26-
top_p_global:
28+
top_p_arr:
2729
The top-p threshold
2830
2931
init_pivots:
3032
The initial pivot candidates
3133
3234
final_pivot:
3335
The final pivot to cut-off top-p percentile
36+
37+
final_lsum:
38+
The final sum of the values after top-p filtering.
3439
"""
3540
TX = 1024
3641
K = 32
3742
eps_LR = 1e-7
3843

44+
max_num_threads_per_block = get_max_num_threads_per_block(target)
45+
if max_num_threads_per_block < TX:
46+
TX = max_num_threads_per_block
47+
3948
def _var(dtype="int32"):
4049
return T.alloc_buffer((1,), dtype, scope="local")
4150

@@ -46,7 +55,7 @@ def valid(lsum, lmin, cmin, top_p):
4655
@T.prim_func(private=True)
4756
def _func(
4857
var_prob: T.handle,
49-
top_p_global: T.buffer([1], dtype="float32"),
58+
var_top_p_arr: T.handle,
5059
var_init_pivots: T.handle,
5160
var_final_pivot: T.handle,
5261
var_final_lsum: T.handle,
@@ -55,7 +64,8 @@ def _func(
5564
B = T.int32()
5665
N = T.int32()
5766
prob = T.match_buffer(var_prob, (B, N,), "float32")
58-
init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32")
67+
top_p_arr = T.match_buffer(var_top_p_arr, (B,), dtype="float32")
68+
init_pivots = T.match_buffer(var_init_pivots, (B, pN), "float32")
5969
final_pivot = T.match_buffer(var_final_pivot, (B,), "float32")
6070
final_lsum = T.match_buffer(var_final_lsum, (B,), "float32")
6171

@@ -92,7 +102,7 @@ def _func(
92102
with T.block("CTA"):
93103
b, tx = T.axis.remap("SS", [_bx, _tx])
94104

95-
top_p[0] = top_p_global[0]
105+
top_p[0] = top_p_arr[b]
96106

97107
if tx == 0:
98108
# leader thread initializes L, R
@@ -105,8 +115,14 @@ def _func(
105115
R_local[0] = R[0]
106116
for i in T.unroll(0, pN):
107117
# pivots are in descending order
108-
pivot[i] = init_pivots[i]
118+
pivot[i] = init_pivots[b, i]
109119
find_pivot_local[0] = False
120+
if L_local[0] - R_local[0] <= eps_LR:
121+
# When the initial value is too small, set the result directly.
122+
if tx == 0:
123+
final_lsum[b] = 1.0
124+
final_pivot[b] = 0.0
125+
find_pivot_local[0] = True
110126

111127
while T.tvm_thread_invariant(
112128
L_local[0] - R_local[0] > eps_LR
@@ -118,7 +134,7 @@ def _func(
118134
### get lsum, lmin, total_sum
119135
for pidx in T.unroll(0, pN):
120136
lsum[pidx] = 0.0
121-
lmin[pidx] = 1.0
137+
lmin[pidx] = T.max_value("float32")
122138
cmin[pidx] = 0
123139
total_sum[0] = 0.0
124140
it[0] = 0
@@ -226,6 +242,7 @@ def _func(
226242
final_lsum[b] = lsum[pidx]
227243
elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]:
228244
R[0] = pivot[pidx]
245+
final_lsum[b] = lsum[pidx]
229246
elif lsum[pidx] < top_p[0]:
230247
L[0] = pivot[pidx]
231248
it[0] += 1
@@ -243,13 +260,15 @@ def _func(
243260
if tx == 0:
244261
# leader thread writes back the pivot
245262
if T.Not(find_pivot_local[0]):
246-
final_pivot[b] = -1e5
263+
final_pivot[b] = R_local[0]
264+
if R_local[0] == eps_LR:
265+
final_lsum[b] = lsum[pN - 1]
247266
# fmt: on
248267

249268
return _func
250269

251270

252-
def top_p_renorm():
271+
def top_p_renorm(target: tvm.target.Target = None):
253272
"""Top-p renormalization function. This function renormalizes the probability vector.
254273
255274
Given the pivot, the probability vector is renormalized as follows:
@@ -273,6 +292,11 @@ def top_p_renorm():
273292
TX = 1024
274293
CTA_COUNT = 512
275294

295+
if target:
296+
max_num_threads_per_block = get_max_num_threads_per_block(target)
297+
if max_num_threads_per_block < TX:
298+
TX = max_num_threads_per_block
299+
276300
def _var(dtype="int32"):
277301
return T.alloc_buffer((1,), dtype, scope="local")
278302

0 commit comments

Comments
 (0)