Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class BatchDraftActionObj : public EngineActionObj {
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), num_rsentries);

// - Add draft token to the state.
Expand Down
61 changes: 14 additions & 47 deletions cpp/serve/sampler/cpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ namespace serve {
* \param input_prob_offset The offset specifying which distribution to sample from.
* \param top_p The top-p value of sampling.
* \param uniform_sample The random number in [0, 1] for sampling.
* \param output_prob_dist Optional pointer to store the corresponding probability distribution of
* each token, offset by unit_offset. If nullptr provided, nothing will be stored out.
* \return The sampled value and probability.
* \note This function is an enhanced version of SampleTopPFromProb in TVM Unity.
* The enhancement will be upstreamed once it has stabilized.
*/
TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_offset, double top_p,
double uniform_sample,
std::vector<NDArray>* output_prob_dist = nullptr) {
double uniform_sample) {
// prob: (*, v)
// The prob array may have arbitrary ndim and shape.
// The last dimension corresponds to the prob distribution size.
Expand All @@ -51,13 +48,6 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o
static_cast<float*>(__builtin_assume_aligned(prob->data, 4)) + (input_prob_offset * ndata);
constexpr double one = 1.0f - 1e-5f;

if (output_prob_dist) {
ICHECK_LT(unit_offset, static_cast<int>(output_prob_dist->size()));
if (!(*output_prob_dist)[unit_offset].defined()) {
(*output_prob_dist)[unit_offset] = NDArray::Empty({ndata}, prob->dtype, DLDevice{kDLCPU, 0});
}
}

if (top_p == 0) {
// Specially handle case where top_p == 0.
// This case is equivalent to doing argmax.
Expand All @@ -75,20 +65,9 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o
break;
}
}
if (output_prob_dist) {
float* __restrict p_output_prob =
static_cast<float*>(__builtin_assume_aligned((*output_prob_dist)[unit_offset]->data, 4));
for (int i = 0; i < ndata; ++i) {
p_output_prob[i] = i == argmax_pos ? 1.0 : 0.0;
}
}
return {argmax_pos, 1.0};
}

if (output_prob_dist) {
(*output_prob_dist)[unit_offset].CopyFromBytes(p_prob, ndata * sizeof(float));
}

if (top_p >= one) {
// Specially handle case where top_p == 1.
double prob_sum = 0.0f;
Expand Down Expand Up @@ -419,10 +398,9 @@ class CPUSampler : public SamplerObj {
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
std::vector<NDArray>* output_prob_dist) final {
const std::vector<RandomGenerator*>& rngs) final {
return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs,
/*top_p_applied=*/true, output_prob_dist);
/*top_p_applied=*/true);
}

std::vector<std::vector<SampleResult>> BatchVerifyDraftTokensWithProbAfterTopP(
Expand Down Expand Up @@ -520,14 +498,12 @@ class CPUSampler : public SamplerObj {
}

private:
std::vector<SampleResult> BatchSampleTokensImpl(
NDArray probs_on_host, //
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
bool top_p_applied, //
std::vector<NDArray>* output_prob_dist = nullptr) {
std::vector<SampleResult> BatchSampleTokensImpl(NDArray probs_on_host, //
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
bool top_p_applied) {
// probs_on_host: (n, v)
RECORD_EVENT(trace_recorder_, request_ids, "start sampling");
ICHECK_EQ(probs_on_host->ndim, 2);
Expand All @@ -540,29 +516,20 @@ class CPUSampler : public SamplerObj {

std::vector<SampleResult> sample_results;
sample_results.resize(n);
if (output_prob_dist) {
output_prob_dist->resize(n);
}

tvm::runtime::parallel_for_with_threading_backend(
[this, &sample_results, &probs_on_host, &generation_cfg, &rngs, &request_ids, top_p_applied,
sample_indices, output_prob_dist](int i) {
sample_indices](int i) {
RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token");
// Sample top p from probability.
double top_p =
top_p_applied
? 1.0f
: (generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p);
sample_results[i].sampled_token_id =
SampleTopPFromProb(probs_on_host, i, sample_indices[i], top_p,
rngs[i]->GetRandomNumber(), output_prob_dist);
if (output_prob_dist == nullptr) {
// When `output_prob_dist` is not nullptr, it means right now
// we are sampling for a small model in speculation, in which
// case we do not need to get the top probs.
sample_results[i].top_prob_tokens =
ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs);
}
sample_results[i].sampled_token_id = SampleTopPFromProb(
probs_on_host, i, sample_indices[i], top_p, rngs[i]->GetRandomNumber());
sample_results[i].top_prob_tokens =
ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs);
RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token");
},
0, n);
Expand Down
29 changes: 8 additions & 21 deletions cpp/serve/sampler/gpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,10 @@ class GPUSampler : public SamplerObj {
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
std::vector<NDArray>* output_prob_dist = nullptr) final {
const std::vector<RandomGenerator*>& rngs) final {
NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbAfterTopP");
return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids,
generation_cfg, rngs, /*top_p_applied=*/true, output_prob_dist);
generation_cfg, rngs, /*top_p_applied=*/true);
}

std::vector<std::vector<SampleResult>> BatchVerifyDraftTokensWithProbAfterTopP(
Expand Down Expand Up @@ -326,14 +325,12 @@ class GPUSampler : public SamplerObj {
}

private:
std::vector<SampleResult> BatchSampleTokensImpl(
NDArray probs_on_device, //
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
bool top_p_applied, //
std::vector<NDArray>* output_prob_dist = nullptr) {
std::vector<SampleResult> BatchSampleTokensImpl(NDArray probs_on_device, //
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
bool top_p_applied) {
// probs_on_device: (n, v)
RECORD_EVENT(trace_recorder_, request_ids, "start sampling");
CHECK_EQ(probs_on_device->ndim, 2);
Expand All @@ -342,16 +339,6 @@ class GPUSampler : public SamplerObj {
int num_samples = sample_indices.size();
int num_probs = probs_on_device->shape[0];
int vocab_size = probs_on_device->shape[1];
if (output_prob_dist != nullptr) {
ICHECK(output_prob_dist->empty());
output_prob_dist->reserve(num_samples);
for (int i = 0; i < num_samples; ++i) {
NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_);
float* p_prob = static_cast<float*>(probs_on_device->data) + sample_indices[i] * vocab_size;
prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float));
output_prob_dist->push_back(std::move(prob_dist));
}
}
if (num_samples == 0) {
// This synchronization is necessary to ensure that this round
// of the model forward pass has finished.
Expand Down
4 changes: 1 addition & 3 deletions cpp/serve/sampler/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ class SamplerObj : public Object {
* \param generation_cfg The generation config of each request
* in the input batch.
* \param rngs The random number generator of each sequence.
* \param output_prob_dist The output probability distribution
* \return The batch of sampling results, which contain the sampled token id
* and other probability info.
*/
Expand All @@ -92,8 +91,7 @@ class SamplerObj : public Object {
const std::vector<int>& sample_indices, //
const Array<String>& request_ids, //
const Array<GenerationConfig>& generation_cfg, //
const std::vector<RandomGenerator*>& rngs, //
std::vector<NDArray>* output_prob_dist = nullptr) = 0;
const std::vector<RandomGenerator*>& rngs) = 0;

/*!
* \brief Verify draft tokens generated by small models in the large model
Expand Down