Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/serve/draft_token_workspace_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
if (require_hidden_states) {
workspace->draft_hidden_states_storage =
NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
}
}

Expand Down
22 changes: 8 additions & 14 deletions cpp/serve/engine_actions/eagle_batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,15 @@ class EagleBatchDraftActionObj : public EngineActionObj {
mstates.push_back(rsentry->mstates[model_id]);
}
// draft_length_ rounds of draft proposal.
ObjectRef last_hidden_states{nullptr};
NDArray hidden_states = Downcast<NDArray>(model_workspaces_[model_id].hidden_states);
ObjectRef hidden_states = model_workspaces_[model_id].hidden_states;
// Concat last hidden_states
draft_token_slots_.clear();
if (draft_length_ > 1) {
for (int i = 0; i < num_rsentries; ++i) {
draft_token_slots_.push_back(mstates[i]->draft_token_slots.back());
}
hidden_states = Downcast<NDArray>(models_[model_id]->GatherHiddenStates(
model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states));
ICHECK(hidden_states->ndim == 2);
last_hidden_states = hidden_states.CreateView(
{hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype);
hidden_states = models_[model_id]->GatherHiddenStates(
model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states);
}
// The first draft token has been generated in prefill/verify stage
for (int draft_id = 1; draft_id < draft_length_; ++draft_id) {
Expand All @@ -114,11 +110,10 @@ class EagleBatchDraftActionObj : public EngineActionObj {

// - Invoke model decode.
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode");
ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
hidden_states =
models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids);
last_hidden_states = hidden_states;
ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states,
request_internal_ids);
NDArray logits;
if (models_[model_id]->CanGetLogits()) {
logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
Expand All @@ -145,11 +140,10 @@ class EagleBatchDraftActionObj : public EngineActionObj {
// Fill range [0, num_rsentries) into `sample_indices`.
std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
std::vector<NDArray> prob_dist;
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), num_rsentries);

// - Add draft token to the state.
Expand Down
55 changes: 10 additions & 45 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
Array<GenerationConfig> generation_cfg;
std::vector<RandomGenerator*> rngs;
std::vector<std::vector<SampleResult>> draft_output_tokens;
std::vector<std::vector<NDArray>> draft_output_prob_dist;
request_internal_ids.reserve(num_rsentries);
all_tokens_to_verify.reserve(total_draft_length);
verify_request_mstates.reserve(num_rsentries);
Expand Down Expand Up @@ -113,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden(
embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]);
NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
fused_hidden_states, request_internal_ids, verify_lengths);
ICHECK_EQ(hidden_states->ndim, 3);
ICHECK_EQ(hidden_states->shape[0], 1);
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
embeddings, request_internal_ids, verify_lengths);
NDArray logits =
models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]);
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
Expand Down Expand Up @@ -179,16 +174,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {

{
// One step draft for the following steps
NDArray last_hidden_states_nd = hidden_states.CreateView(
{hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]},
hidden_states->dtype);

hidden_states = Downcast<NDArray>(models_[draft_model_id_]->GatherHiddenStates(
last_hidden_states_nd, last_accepted_hidden_positions,
&model_workspaces_[draft_model_id_].hidden_states));
ICHECK(hidden_states->ndim == 2);
hidden_states = hidden_states.CreateView(
{hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype);
// Gather hidden states for the last accepted tokens.
hidden_states = models_[draft_model_id_]->GatherHiddenStates(
hidden_states, last_accepted_hidden_positions,
&model_workspaces_[draft_model_id_].hidden_states);

std::vector<int> input_tokens;
Array<RequestModelState> mstates;
Expand All @@ -210,10 +200,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {

// - Invoke model decode.
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode");
ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(
ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(
embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states,
request_internal_ids);
hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(
fused_embedding_hidden_states, request_internal_ids);

if (models_[draft_model_id_]->CanGetLogits()) {
logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
Expand All @@ -239,22 +229,17 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// Fill range [0, num_rsentries) into `sample_indices`.
std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
std::vector<NDArray> prob_dist;
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), num_rsentries);

// - Slice and save hidden_states_for_sample
draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_);
models_[draft_model_id_]->ScatterDraftProbs(
renormalized_probs, draft_token_slots_,
&model_workspaces_[verify_model_id_].draft_probs_storage);
ICHECK(hidden_states->ndim == 3);
hidden_states = hidden_states.CreateView(
{hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]},
hidden_states->dtype);
models_[draft_model_id_]->ScatterHiddenStates(
hidden_states, draft_token_slots_,
&model_workspaces_[verify_model_id_].draft_hidden_states_storage);
Expand Down Expand Up @@ -326,26 +311,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
return num_required_pages <= num_available_pages;
}

/*!
 * \brief Copy the hidden states of a single token out of a hidden_states array.
 * \param hidden_states The hidden_states of all the tokens. Must be 3-dimensional;
 * tokens are addressed as consecutive rows of `shape[2]` elements.
 * \param token_pos The desired token position (row index) in the sequence.
 * \return A new 1-D array of length `hidden_states->shape[2]`, created with the same
 * dtype and device as the input, holding the desired token's hidden_states.
 * \note The source pointer is passed to `CopyFromBytes` unchanged.
 * NOTE(review): presumably this requires `hidden_states->data` to be readable from
 * the host -- confirm against the device the caller uses.
 */
NDArray GetTokenHidden(NDArray hidden_states, int token_pos) {
  ICHECK_EQ(hidden_states->ndim, 3);
  const int64_t ndata = hidden_states->shape[2];

  // Derive the element width from the array's dtype instead of assuming a 16-bit
  // type, so that the row stride and the copied byte count stay consistent with
  // the destination array for any hidden-state dtype.
  const int64_t elem_bytes =
      (static_cast<int64_t>(hidden_states->dtype.bits) * hidden_states->dtype.lanes + 7) / 8;
  const int64_t row_bytes = ndata * elem_bytes;

  NDArray last_hidden_on_device =
      NDArray::Empty({ndata}, hidden_states->dtype, hidden_states->device);

  // The tensor's payload starts `byte_offset` bytes past `data`; the requested row
  // lies `token_pos` full rows after that.
  const char* p_hidden = static_cast<const char*>(hidden_states->data) +
                         hidden_states->byte_offset +
                         static_cast<int64_t>(token_pos) * row_bytes;

  last_hidden_on_device.CopyFromBytes(p_hidden, static_cast<size_t>(row_bytes));
  return last_hidden_on_device;
}

/*!
* \brief The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
Expand Down
70 changes: 29 additions & 41 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
// - Get embedding and run prefill for each model.
std::vector<int> prefill_lengths;
prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1);
NDArray hidden_states_for_input{nullptr};
NDArray hidden_states_for_sample{nullptr};
ObjectRef hidden_states_for_input{nullptr};
ObjectRef hidden_states_for_sample{nullptr};
NDArray logits_for_sample{nullptr};
// A map used to record the entry and child_idx pair needed to fork sequence.
// The base model (id 0) should record all the pairs and all the small models
Expand Down Expand Up @@ -167,14 +167,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
}

RECORD_EVENT(trace_recorder_, request_ids, "start prefill");
ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden(
fused_hidden_states, request_internal_ids, prefill_lengths);
ObjectRef embedding_or_hidden_states{nullptr};
if (model_id == 0) {
embedding_or_hidden_states = embeddings;
} else {
embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
}
// hidden_states: (b * s, h)
ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden(
embedding_or_hidden_states, request_internal_ids, prefill_lengths);
RECORD_EVENT(trace_recorder_, request_ids, "finish prefill");
ICHECK_EQ(hidden_states->ndim, 3);
ICHECK_EQ(hidden_states->shape[0], 1);
ICHECK_EQ(hidden_states->shape[1], cum_prefill_length);

if (model_id == 0) {
// We only need to sample for model 0 in prefill.
Expand All @@ -183,14 +186,23 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {

// Whether to use base model to get logits.
int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id;
hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden(
hidden_states, request_internal_ids, prefill_lengths);

std::vector<int> logit_positions;
{
// Prepare the logit positions
logit_positions.reserve(prefill_lengths.size());
int total_len = 0;
for (int i = 0; i < prefill_lengths.size(); ++i) {
total_len += prefill_lengths[i];
logit_positions.push_back(total_len - 1);
}
}
// hidden_states_for_sample: (b * s, h)
hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(
hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);
// logits_for_sample: (b * s, v)
logits_for_sample =
models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries);
ICHECK_EQ(hidden_states_for_sample->ndim, 3);
ICHECK_EQ(hidden_states_for_sample->shape[0], 1);
ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries);

// - Update logits.
ICHECK(logits_for_sample.defined());
Array<GenerationConfig> generation_cfg;
Expand Down Expand Up @@ -278,11 +290,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
rsentry_activated.push_back(true);
}
}
std::vector<NDArray> prob_dist;

NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), rsentries_for_sample.size());

// - Update the committed tokens of states.
Expand Down Expand Up @@ -311,10 +323,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_,
&model_workspaces_[0].draft_probs_storage);
if (engine_config_->spec_draft_length > 1) {
hidden_states_for_sample = hidden_states_for_sample.CreateView(
{hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1],
hidden_states_for_sample->shape[2]},
hidden_states_for_sample->dtype);
models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_,
&model_workspaces_[0].draft_hidden_states_storage);
}
Expand Down Expand Up @@ -567,26 +575,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
ICHECK(false) << "Cannot reach here";
}

/*!
 * \brief Copy the hidden states of a single token out of a hidden_states array.
 * \param hidden_states The hidden_states of all the tokens. Must be 3-dimensional;
 * tokens are addressed as consecutive rows of `shape[2]` elements.
 * \param token_pos The desired token position (row index) in the sequence.
 * \return A new 1-D array of length `hidden_states->shape[2]`, created with the same
 * dtype and device as the input, holding the desired token's hidden_states.
 * \note The source pointer is passed to `CopyFromBytes` unchanged.
 * NOTE(review): presumably this requires `hidden_states->data` to be readable from
 * the host -- confirm against the device the caller uses.
 */
NDArray GetTokenHidden(NDArray hidden_states, int token_pos) {
  ICHECK_EQ(hidden_states->ndim, 3);
  const int64_t ndata = hidden_states->shape[2];

  // Derive the element width from the array's dtype instead of assuming a 16-bit
  // type, so that the row stride and the copied byte count stay consistent with
  // the destination array for any hidden-state dtype.
  const int64_t elem_bytes =
      (static_cast<int64_t>(hidden_states->dtype.bits) * hidden_states->dtype.lanes + 7) / 8;
  const int64_t row_bytes = ndata * elem_bytes;

  NDArray last_hidden_on_device =
      NDArray::Empty({ndata}, hidden_states->dtype, hidden_states->device);

  // The tensor's payload starts `byte_offset` bytes past `data`; the requested row
  // lies `token_pos` full rows after that.
  const char* p_hidden = static_cast<const char*>(hidden_states->data) +
                         hidden_states->byte_offset +
                         static_cast<int64_t>(token_pos) * row_bytes;

  last_hidden_on_device.CopyFromBytes(p_hidden, static_cast<size_t>(row_bytes));
  return last_hidden_on_device;
}

/*! \brief The models to run prefill in. */
Array<Model> models_;
/*! \brief The logit processor. */
Expand Down
7 changes: 4 additions & 3 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ void FunctionTable::_InitFunctions() {
Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm;
this->get_logits_func_ = mod_get_func("get_logits");
this->batch_get_logits_func_ = mod_get_func("batch_get_logits");
this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true);
this->batch_select_last_hidden_func_ = mod_get_func("batch_select_last_hidden_states");
this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true);
this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true);
this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true);
Expand Down Expand Up @@ -259,11 +259,12 @@ void FunctionTable::_InitFunctions() {
this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of");
this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset");
support_backtracking_kv_ = true;
this->tuple_getitem_func_ = get_global_func("vm.builtin.tuple_getitem");

this->gather_probs_func_ = mod->GetFunction("gather_probs", true);
this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true);
this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true);
this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true);
this->gather_hidden_states_func_ = mod_get_func("gather_hidden_states");
this->scatter_hidden_states_func_ = mod_get_func("scatter_hidden_states");
}

ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const {
Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct FunctionTable {
PackedFunc nd_view_func_;
PackedFunc nd_get_shape_func_;
PackedFunc nd_copy_embedding_to_offset_func_;
PackedFunc tuple_getitem_func_;
// Auxiliary functions for speculative decoding.
PackedFunc gather_probs_func_;
PackedFunc scatter_probs_func_;
Expand Down
Loading