54 changes: 54 additions & 0 deletions cpp/serve/draft_token_workspace_manager.cc
@@ -0,0 +1,54 @@
/*!
* Copyright (c) 2024 by Contributors
* \file serve/draft_token_workspace_manager.cc
*/

#include "draft_token_workspace_manager.h"

#include "model.h"

namespace mlc {
namespace llm {
namespace serve {

DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size,
int hidden_size,
DLDataType hidden_states_dtype,
DLDevice device,
const FunctionTable& ft)
: max_num_tokens_(max_num_tokens),
vocab_size_(vocab_size),
hidden_size_(hidden_size),
hidden_states_dtype_(hidden_states_dtype),
device_(device),
ft_(ft) {
free_slots_.resize(max_num_tokens);
std::iota(free_slots_.begin(), free_slots_.end(), 0);
}

void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {
  ICHECK_LE(num_slots, free_slots_.size());
  // Hand out the last `num_slots` entries of the free list, then shrink the list.
  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
  free_slots_.resize(free_slots_.size() - num_slots);
}

void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {
std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_));
}

void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
bool require_hidden_states) {
workspace->draft_probs =
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
workspace->draft_probs_storage =
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
if (require_hidden_states) {
workspace->draft_hidden_states_storage =
NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
}
}

} // namespace serve
} // namespace llm
} // namespace mlc
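
Note: the free list above implements a plain LIFO slot pool. A minimal standalone sketch of the same discipline (hypothetical SlotPool name, no TVM dependencies) behaves as follows:

#include <cassert>
#include <numeric>
#include <vector>

// Standalone sketch of the slot-pool discipline used by the manager above.
class SlotPool {
 public:
  explicit SlotPool(int capacity) : free_slots_(capacity) {
    std::iota(free_slots_.begin(), free_slots_.end(), 0);  // fill with 0, 1, ..., capacity - 1
  }
  // Hand out the last `n` entries of the free list and shrink it.
  void Alloc(int n, std::vector<int>* result) {
    assert(n <= static_cast<int>(free_slots_.size()));
    result->assign(free_slots_.rbegin(), free_slots_.rbegin() + n);
    free_slots_.resize(free_slots_.size() - n);
  }
  // Returned slots become available for reuse, in any order.
  void Free(const std::vector<int>& slots) {
    free_slots_.insert(free_slots_.end(), slots.begin(), slots.end());
  }

 private:
  std::vector<int> free_slots_;
};

int main() {
  SlotPool pool(4);   // free list: {0, 1, 2, 3}
  std::vector<int> a, b;
  pool.Alloc(2, &a);  // a == {3, 2}; slots come from the back
  pool.Free(a);       // free list: {0, 1, 3, 2}
  pool.Alloc(3, &b);  // b == {2, 3, 1}
  return 0;
}

Allocating from the back keeps both operations cheap: nothing in the vector has to shift, and freed slots are simply appended.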
95 changes: 95 additions & 0 deletions cpp/serve/draft_token_workspace_manager.h
@@ -0,0 +1,95 @@
/*!
* Copyright (c) 2024 by Contributors
* \file serve/draft_token_workspace_manager.h
*/

#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
#include <tvm/runtime/device_api.h>

#include <numeric>
#include <optional>
#include <vector>

#include "data.h"
#include "function_table.h"
namespace mlc {
namespace llm {
namespace serve {

using tvm::Device;
using namespace tvm::runtime;

struct ModelWorkspace;

/*!
 * \brief Manages the workspace for draft token generation.
*
* The workspace is used to store the associated states for each draft token, including the
* probability distribution of the draft token, the hidden states, etc. The workspace manager
* maintains a pool of slots for the draft tokens to store the states.
*/
class DraftTokenWorkspaceManagerObj : public Object {
public:
/*!
* \brief Constructor
* \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace.
* \param vocab_size The size of the vocabulary.
* \param hidden_size The size of the hidden states.
* \param hidden_states_dtype The data type of the hidden states.
* \param device The device running the model.
* \param ft The function table.
*/
DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size,
DLDataType hidden_states_dtype, DLDevice device,
const FunctionTable& ft);

/*!
* \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure.
 * \param workspace The object in which to store the allocated draft token workspace.
* \param require_hidden_states Whether to allocate workspace for the hidden states.
*/
void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states);

/*!
* \brief Allocate slots for the draft tokens.
* \param num_slots The number of slots to allocate.
* \param result The vector to store the allocated slots.
*/
void AllocSlots(int num_slots, std::vector<int>* result);

/*!
* \brief Free the slots.
* \param slots The slots to free.
*/
void FreeSlots(const std::vector<int>& slots);

static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager";

private:
std::vector<int> free_slots_;
int max_num_tokens_;
int vocab_size_;
int hidden_size_;
DataType hidden_states_dtype_;
DLDevice device_;
const FunctionTable& ft_;
};

class DraftTokenWorkspaceManager : public ObjectRef {
public:
DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size,
DLDataType hidden_states_dtype, DLDevice device,
const FunctionTable& ft) {
data_ = make_object<DraftTokenWorkspaceManagerObj>(max_num_tokens, vocab_size, hidden_size,
hidden_states_dtype, device, ft);
}
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef,
DraftTokenWorkspaceManagerObj);
};

} // namespace serve
} // namespace llm
} // namespace mlc

#endif // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
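
Note: DraftTokenWorkspaceManager is the ObjectRef handle over the Obj above, so call sites hold it by value and reach the methods through operator->. A hedged usage sketch follows; the function table `ft`, the dtype/device literals, and the sizes are illustrative assumptions, and the real wiring goes through the model's CreateDraftTokenWorkspaceManager as shown in the engine.cc changes below.

// Hedged usage sketch; `ft`, sizes, dtype, and device are assumptions.
DLDataType f16{kDLFloat, 16, 1};
DLDevice dev{kDLCUDA, 0};
DraftTokenWorkspaceManager mgr(/*max_num_tokens=*/40, /*vocab_size=*/32000,
                               /*hidden_size=*/4096, f16, dev, ft);

ModelWorkspace workspace;
mgr->AllocWorkspace(&workspace, /*require_hidden_states=*/true);  // Eagle keeps hidden states

std::vector<int> slots;
mgr->AllocSlots(/*num_slots=*/5, &slots);  // one slot per stored draft token
// ... write per-token probs / hidden states into the slot indices ...
mgr->FreeSlots(slots);                     // return the slots once verified or dropped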
55 changes: 33 additions & 22 deletions cpp/serve/engine.cc
@@ -101,8 +101,13 @@ class EngineImpl : public Engine {
}

int max_num_tokens = engine_config->max_num_sequence;
DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr};
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
max_num_tokens *= engine_config->spec_draft_length + 1;
draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens);
draft_token_workspace_manager->AllocWorkspace(
&model_workspaces_[0],
/*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle);
}
LogitProcessor logit_processor =
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
@@ -114,30 +119,36 @@
ICHECK_GT(this->models_.size(), 1U);
switch (engine_config->speculative_mode) {
case SpeculativeMode::kEagle:
this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
engine_config, //
this->trace_recorder_),
EngineAction::EagleBatchDraft(
this->models_, logit_processor, sampler, this->model_workspaces_,
this->trace_recorder_, engine_config->spec_draft_length),
EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler,
this->model_workspaces_, engine_config,
this->trace_recorder_)};
this->actions_ = {
EngineAction::EagleNewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
draft_token_workspace_manager, //
engine_config, //
this->trace_recorder_),
EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler,
this->model_workspaces_, draft_token_workspace_manager,
this->trace_recorder_,
engine_config->spec_draft_length),
EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler,
this->model_workspaces_, draft_token_workspace_manager,
engine_config, this->trace_recorder_)};
break;
default:
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
engine_config, //
this->trace_recorder_),
EngineAction::BatchDraft(this->models_, logit_processor, sampler,
this->trace_recorder_),
EngineAction::BatchVerify(this->models_, logit_processor, sampler,
engine_config, this->trace_recorder_)};
this->actions_ = {
EngineAction::NewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
engine_config, //
this->trace_recorder_),
EngineAction::BatchDraft(this->models_, logit_processor, sampler,
this->model_workspaces_, draft_token_workspace_manager,
this->trace_recorder_),
EngineAction::BatchVerify(this->models_, logit_processor, sampler,
this->model_workspaces_, draft_token_workspace_manager,
engine_config, this->trace_recorder_)};
}
} else {
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //
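Note: a worked sizing example for the speculative branch above, with illustrative numbers that are not from this PR:

#include <cstdio>

int main() {
  // Example values standing in for EngineConfig and the model's vocabulary.
  int max_num_sequence = 8;    // engine_config->max_num_sequence
  int spec_draft_length = 4;   // engine_config->spec_draft_length
  int max_num_tokens = max_num_sequence * (spec_draft_length + 1);  // 8 * 5 = 40 slots
  // draft_probs and draft_probs_storage are each {max_num_tokens, vocab_size} float32.
  long long vocab_size = 32000;
  long long bytes_per_buffer = max_num_tokens * vocab_size * 4;  // 5,120,000 bytes, ~4.9 MiB
  std::printf("%d slots, %lld bytes per probs buffer\n", max_num_tokens, bytes_per_buffer);
  return 0;
}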
29 changes: 21 additions & 8 deletions cpp/serve/engine_actions/action.h
@@ -8,6 +8,7 @@
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_

#include "../config.h"
#include "../draft_token_workspace_manager.h"
#include "../engine_state.h"
#include "../event_trace_recorder.h"
#include "../model.h"
@@ -72,15 +73,16 @@ class EngineAction : public ObjectRef {
* \param logit_processor The logit processor.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param engine_config The engine config.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction EagleNewRequestPrefill(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder);
static EngineAction EagleNewRequestPrefill(
Array<Model> models, LogitProcessor logit_processor, Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder);
/*!
* \brief Create the action that runs one-step decode for requests in the
* `running_queue` of engine state. Preempt low-priority requests
@@ -104,13 +106,16 @@
* \param models The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param trace_recorder The event trace recorder for requests.
* \param draft_length The number of draft proposal rounds.
* \return The created action object.
*/
static EngineAction BatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder,
int draft_length = 4);
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
Optional<EventTraceRecorder> trace_recorder, int draft_length = 4);

/*!
* \brief Create the action that runs one-step speculative draft proposal for
@@ -120,12 +125,14 @@
* models, the `Step` function of the created action will not take effect.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param trace_recorder The event trace recorder for requests.
* \param draft_length The number of draft proposal rounds.
* \return The created action object.
*/
static EngineAction EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
Optional<EventTraceRecorder> trace_recorder,
int draft_length = 4);

@@ -135,13 +142,17 @@
* accordingly when it is impossible to decode all the running requests.
* \param models The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
 * \param sampler The sampler to sample new tokens.
 * \param model_workspaces The workspace of each model.
 * \param draft_token_workspace_manager The draft token workspace manager.
* \param engine_config The engine config.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction BatchVerify(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, EngineConfig engine_config,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder);

/*!
@@ -152,13 +163,15 @@
* models, the `Step` function of the created action will not take effect.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param engine_config The engine config.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction EagleBatchVerify(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder);

13 changes: 9 additions & 4 deletions cpp/serve/engine_actions/action_commons.cc
@@ -142,9 +142,10 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, Array<Mo
std::move(models), max_single_sequence_length);
}

RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
const Array<Model>& models,
Optional<EventTraceRecorder> trace_recorder) {
RequestStateEntry PreemptLastRunningRequestStateEntry(
EngineState estate, const Array<Model>& models,
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
Optional<EventTraceRecorder> trace_recorder) {
ICHECK(!estate->running_queue.empty());
Request request = estate->running_queue.back();

Expand All @@ -168,8 +169,12 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
// - Update `inputs` for future prefill.
RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt");
rsentry->status = RequestStateStatus::kPending;
std::vector<int> draft_token_slots;
for (RequestModelState mstate : rsentry->mstates) {
mstate->RemoveAllDraftTokens();
if (draft_token_workspace_manager.defined()) {
mstate->RemoveAllDraftTokens(&draft_token_slots);
draft_token_workspace_manager.value()->FreeSlots(draft_token_slots);
}
std::vector<int32_t> committed_token_ids;
committed_token_ids.reserve(mstate->committed_tokens.size());
for (const SampleResult& committed_token : mstate->committed_tokens) {
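Note: the guard added above releases a preempted request's draft-token slots back to the pool only when a manager is present; non-speculative paths pass NullOpt, as batch_decode.cc does below. A hedged sketch of that contract as a hypothetical helper (the out-parameter form of RemoveAllDraftTokens is taken from the call site above):

// Hypothetical helper illustrating the Optional-manager guard pattern.
void ReleaseDraftSlots(RequestModelState mstate,
                       Optional<DraftTokenWorkspaceManager> mgr) {
  if (mgr.defined()) {
    std::vector<int> slots;
    mstate->RemoveAllDraftTokens(&slots);  // collects the freed slot ids
    mgr.value()->FreeSlots(slots);         // returns them to the shared pool
  }
  // With NullOpt there are no draft tokens to release, so nothing to do.
}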
13 changes: 8 additions & 5 deletions cpp/serve/engine_actions/action_commons.h
@@ -7,6 +7,7 @@
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_

#include "../../tokenizers.h"
#include "../draft_token_workspace_manager.h"
#include "../engine.h"
#include "../engine_state.h"
#include "../event_trace_recorder.h"
@@ -52,12 +53,14 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, Array<Mo
* If it is not in the waiting request queue, add it to the waiting queue.
* \param estate The engine state to update due to preemption.
* \param models The models to remove preempted requests from.
* \param trace_recorder The event trace recorder for requests.
* \return The preempted request state.
 * \param draft_token_workspace_manager The draft token workspace manager. Must be provided
 * if speculative decoding is enabled.
 * \param trace_recorder The event trace recorder for requests.
 * \return The preempted request state.
*/
RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
const Array<Model>& models,
Optional<EventTraceRecorder> trace_recorder);
RequestStateEntry PreemptLastRunningRequestStateEntry(
EngineState estate, const Array<Model>& models,
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
Optional<EventTraceRecorder> trace_recorder);

/*! \brief Get the running request entries from the engine state. */
inline std::vector<RequestStateEntry> GetRunningRequestStateEntries(const EngineState& estate) {
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_decode.cc
@@ -48,7 +48,7 @@ class BatchDecodeActionObj : public EngineActionObj {
running_rsentries = GetRunningRequestStateEntries(estate);
while (!CanDecode(running_rsentries.size())) {
RequestStateEntry preempted =
PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_);
PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_);
if (preempted.same_as(running_rsentries.back())) {
running_rsentries.pop_back();
}