
Commit b247f8d

[Serving] Add Medusa speculative decoding (#2337)
Parent: 0c03537

25 files changed: +558 -226 lines

cpp/metadata/model.cc

Lines changed: 12 additions & 3 deletions
@@ -63,8 +63,17 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
   if (metadata.count("attention_sink_size"))  // remove after sink is decoupled from model lib
     result.attention_sink_size = json::Lookup<int64_t>(metadata, "attention_sink_size");
   result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, "tensor_parallel_shards");
-  result.kv_cache_metadata =
-      KVCacheMetadata::FromJSON(json::Lookup<picojson::object>(metadata, "kv_cache"));
+  result.kv_state_kind = KVStateKindFromString(
+      json::LookupOrDefault<std::string>(metadata, "kv_state_kind", "kv_cache"));
+  if (result.kv_state_kind != KVStateKind::kNone) {
+    result.kv_cache_metadata =
+        KVCacheMetadata::FromJSON(json::Lookup<picojson::object>(metadata, "kv_cache"));
+  } else {
+    result.kv_cache_metadata = {/*num_hidden_layers=*/0,
+                                /*head_dim=*/0,
+                                /*num_attention_heads=*/0,
+                                /*num_key_value_heads=*/0};
+  }
   {
     std::vector<ModelMetadata::Param>& params = result.params;
     picojson::array json_params = json::Lookup<picojson::array>(metadata, "params");
@@ -94,7 +103,7 @@ ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module,
   try {
     return ModelMetadata::FromJSON(json, model_config);
   } catch (const std::exception& e) {
-    LOG(WARNING) << "Failed to parse metadata:\n" << json_str;
+    LOG(WARNING) << "Failed to parse metadata:\n" << json_str << "\nerror: " << e.what();
     throw e;
   }
 }
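
The key behavioral point in this hunk is the backward-compatible default: model libraries built before this commit carry no "kv_state_kind" entry in their metadata, so LookupOrDefault treats them as "kv_cache", while a library that declares "none" (presumably the case for Medusa heads, which piggyback on the base model's cache rather than owning one) gets zeroed KV-cache metadata instead of a mandatory "kv_cache" object. A standalone sketch of that lookup behavior, with a std::map standing in for the picojson object:

#include <iostream>
#include <map>
#include <string>

// Stand-in for json::LookupOrDefault: return the mapped value if present,
// otherwise the fallback.
std::string LookupOrDefault(const std::map<std::string, std::string>& obj,
                            const std::string& key, const std::string& fallback) {
  auto it = obj.find(key);
  return it == obj.end() ? fallback : it->second;
}

int main() {
  std::map<std::string, std::string> old_lib;  // pre-existing lib: no entry at all
  std::map<std::string, std::string> no_cache{{"kv_state_kind", "none"}};
  std::cout << LookupOrDefault(old_lib, "kv_state_kind", "kv_cache") << "\n";   // kv_cache
  std::cout << LookupOrDefault(no_cache, "kv_state_kind", "kv_cache") << "\n";  // none
}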

cpp/metadata/model.h

Lines changed: 31 additions & 0 deletions
@@ -16,6 +16,36 @@
 namespace mlc {
 namespace llm {
 
+/*! \brief The kind of cache. */
+enum class KVStateKind : int {
+  kKVCache = 0,
+  kRNNState = 1,
+  kNone = 2,
+};
+
+inline std::string KVStateKindToString(KVStateKind kv_state_kind) {
+  if (kv_state_kind == KVStateKind::kKVCache) {
+    return "kv_cache";
+  } else if (kv_state_kind == KVStateKind::kRNNState) {
+    return "rnn_state";
+  } else if (kv_state_kind == KVStateKind::kNone) {
+    return "none";
+  } else {
+    LOG(FATAL) << "Invalid kv state kind: " << static_cast<int>(kv_state_kind);
+  }
+}
+
+inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) {
+  if (kv_state_kind == "kv_cache") {
+    return KVStateKind::kKVCache;
+  } else if (kv_state_kind == "rnn_state") {
+    return KVStateKind::kRNNState;
+  } else if (kv_state_kind == "none") {
+    return KVStateKind::kNone;
+  } else {
+    LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind;
+  }
+}
 struct ModelMetadata {
   struct Param {
     struct Preproc {
@@ -49,6 +79,7 @@ struct ModelMetadata {
   int64_t attention_sink_size;
   std::vector<Param> params;
   std::unordered_map<std::string, int64_t> memory_usage;
+  KVStateKind kv_state_kind;
   KVCacheMetadata kv_cache_metadata;
 
   static ModelMetadata FromJSON(const picojson::object& json_str,
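
KVStateKind and its string helpers move here from cpp/serve/config.h so that per-model metadata can carry the cache kind; the move adds the kNone member and fixes the old one-line KVStateKindToString, which returned the misspelled "rnn_State" for RNN state. A minimal usage sketch, assuming this header (and the TVM logging it relies on) is included:

using namespace mlc::llm;

KVStateKind kind = KVStateKindFromString("none");  // KVStateKind::kNone
std::string name = KVStateKindToString(kind);      // back to "none"
// Any unrecognized string or enum value aborts via LOG(FATAL) instead of
// silently mapping to a default, as the old ternary version did.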

cpp/serve/config.cc

Lines changed: 6 additions & 7 deletions
@@ -248,7 +248,6 @@ EngineConfig EngineConfig::FromJSONAndInferredConfig(
   CHECK(inferred_config.max_single_sequence_length.has_value());
   CHECK(inferred_config.prefill_chunk_size.has_value());
   CHECK(inferred_config.max_history_size.has_value());
-  CHECK(inferred_config.kv_state_kind.has_value());
   ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();
 
   // - Get models and model libs.
@@ -290,7 +289,6 @@ EngineConfig EngineConfig::FromJSONAndInferredConfig(
   n->max_single_sequence_length = inferred_config.max_single_sequence_length.value();
   n->prefill_chunk_size = inferred_config.prefill_chunk_size.value();
   n->max_history_size = inferred_config.max_history_size.value();
-  n->kv_state_kind = inferred_config.kv_state_kind.value();
 
   return EngineConfig(n);
 }
@@ -356,7 +354,6 @@ String EngineConfigNode::AsJSONString() const {
       picojson::value(static_cast<int64_t>(this->max_single_sequence_length));
   config["prefill_chunk_size"] = picojson::value(static_cast<int64_t>(this->prefill_chunk_size));
   config["max_history_size"] = picojson::value(static_cast<int64_t>(this->max_history_size));
-  config["kv_state_kind"] = picojson::value(KVStateKindToString(this->kv_state_kind));
   config["speculative_mode"] = picojson::value(SpeculativeModeToString(this->speculative_mode));
   config["spec_draft_length"] = picojson::value(static_cast<int64_t>(this->spec_draft_length));
   config["verbose"] = picojson::value(static_cast<bool>(this->verbose));
@@ -428,14 +425,18 @@ Result<ModelConfigLimits> GetModelConfigLimits(const std::vector<picojson::objec
            ") is larger than the prefill chunk size used at compile time (" +
            std::to_string(compile_time_prefill_chunk_size) + ").");
     }
-    model_max_prefill_chunk_size =
-        std::min(model_max_prefill_chunk_size, runtime_prefill_chunk_size);
+    if (runtime_prefill_chunk_size != -1) {
+      model_max_prefill_chunk_size =
+          std::min(model_max_prefill_chunk_size, runtime_prefill_chunk_size);
+    }
     // - The maximum batch size is the minimum max batch size among all models.
     model_max_batch_size = std::min(
         model_max_batch_size, json::Lookup<int64_t>(compile_time_model_config, "max_batch_size"));
   }
   ICHECK_NE(model_max_prefill_chunk_size, std::numeric_limits<int64_t>::max());
   ICHECK_NE(model_max_batch_size, std::numeric_limits<int64_t>::max());
+  ICHECK_GT(model_max_prefill_chunk_size, 0);
+  ICHECK_GT(model_max_batch_size, 0);
   return Result<ModelConfigLimits>::Ok(
       {model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size});
 }
@@ -689,7 +690,6 @@ Result<InferrableEngineConfig> InferrableEngineConfig::InferForKVCache(
         << " MB). The actual usage might be slightly larger than the estimated number.";
   }
 
-  inferred_config.kv_state_kind = KVStateKind::kKVCache;
   inferred_config.max_history_size = 0;
   return Result<InferrableEngineConfig>::Ok(inferred_config);
 }
@@ -853,7 +853,6 @@ Result<InferrableEngineConfig> InferrableEngineConfig::InferForRNNState(
         << " MB). The actual usage might be slightly larger than the estimated number.";
   }
 
-  inferred_config.kv_state_kind = KVStateKind::kRNNState;
   return Result<InferrableEngineConfig>::Ok(inferred_config);
 }
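
The clamping fix matters because -1 evidently serves as the sentinel for an unset runtime prefill chunk size: the old unconditional std::min dragged the model limit down to -1, a state the two new ICHECK_GT guards would now reject. A standalone mirror of the before/after arithmetic, with illustrative values:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  int64_t model_max_prefill_chunk_size = 2048;  // from the compile-time model config
  int64_t runtime_prefill_chunk_size = -1;      // user left the field unset

  // Old behavior: the sentinel wins the min() and poisons the limit.
  assert(std::min(model_max_prefill_chunk_size, runtime_prefill_chunk_size) == -1);

  // New behavior: the sentinel is skipped, so the positive limit survives
  // the added ICHECK_GT(model_max_prefill_chunk_size, 0) guard.
  if (runtime_prefill_chunk_size != -1) {
    model_max_prefill_chunk_size =
        std::min(model_max_prefill_chunk_size, runtime_prefill_chunk_size);
  }
  assert(model_max_prefill_chunk_size == 2048);
}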

cpp/serve/config.h

Lines changed: 27 additions & 29 deletions
@@ -114,12 +114,8 @@ enum class SpeculativeMode : int {
   kSmallDraft = 1,
   /*! \brief The eagle-style speculative decoding. */
   kEagle = 2,
-};
-
-/*! \brief The kind of cache. */
-enum class KVStateKind : int {
-  kKVCache = 0,
-  kRNNState = 1,
+  /*! \brief The Medusa-style speculative decoding. */
+  kMedusa = 3,
 };
 
 class InferrableEngineConfig;
@@ -172,8 +168,6 @@ class EngineConfigNode : public Object {
   int prefill_chunk_size = 1024;
   /*! \brief The maximum history size for RNN state. KV cache does not need this. */
   int max_history_size = 0;
-  /*! \brief The kind of cache. Whether it's KV cache or RNN state. */
-  KVStateKind kv_state_kind = KVStateKind::kKVCache;
 
   /*************** Speculative decoding ***************/
 
@@ -216,7 +210,6 @@ struct InferrableEngineConfig {
   std::optional<int64_t> max_single_sequence_length;
   std::optional<int64_t> prefill_chunk_size;
   std::optional<int64_t> max_history_size;
-  std::optional<KVStateKind> kv_state_kind;
 
   /*! \brief Infer the config for KV cache from a given initial config. */
   TVM_DLL static Result<InferrableEngineConfig> InferForKVCache(
@@ -238,9 +231,16 @@ struct InferrableEngineConfig {
 Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs);
 
 inline std::string EngineModeToString(EngineMode mode) {
-  return mode == EngineMode::kLocal ? "local"
-         : mode == EngineMode::kInteractive ? "interactive"
-         : "server";
+  if (mode == EngineMode::kLocal) {
+    return "local";
+  } else if (mode == EngineMode::kInteractive) {
+    return "interactive";
+  } else if (mode == EngineMode::kServer) {
+    return "server";
+  } else {
+    LOG(FATAL) << "Invalid engine mode: " << static_cast<int>(mode);
+    throw;
+  }
 }
 
 inline EngineMode EngineModeFromString(const std::string& mode) {
@@ -252,13 +252,22 @@ inline EngineMode EngineModeFromString(const std::string& mode) {
     return EngineMode::kServer;
   } else {
     LOG(FATAL) << "Invalid engine mode string: " << mode;
+    throw;
   }
 }
 
 inline std::string SpeculativeModeToString(SpeculativeMode speculative_mode) {
-  return speculative_mode == SpeculativeMode::kDisable ? "disable"
-         : speculative_mode == SpeculativeMode::kSmallDraft ? "small_draft"
-         : "eagle";
+  if (speculative_mode == SpeculativeMode::kDisable) {
+    return "disable";
+  } else if (speculative_mode == SpeculativeMode::kSmallDraft) {
+    return "small_draft";
+  } else if (speculative_mode == SpeculativeMode::kEagle) {
+    return "eagle";
+  } else if (speculative_mode == SpeculativeMode::kMedusa) {
+    return "medusa";
+  } else {
+    LOG(FATAL) << "Invalid speculative mode: " << static_cast<int>(speculative_mode);
+  }
 }
 
 inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_mode) {
@@ -268,22 +277,11 @@ inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_
     return SpeculativeMode::kSmallDraft;
   } else if (speculative_mode == "eagle") {
     return SpeculativeMode::kEagle;
+  } else if (speculative_mode == "medusa") {
+    return SpeculativeMode::kMedusa;
   } else {
    LOG(FATAL) << "Invalid speculative mode string: " << speculative_mode;
-  }
-}
-
-inline std::string KVStateKindToString(KVStateKind kv_state_kind) {
-  return kv_state_kind == KVStateKind::kKVCache ? "kv_cache" : "rnn_State";
-}
-
-inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) {
-  if (kv_state_kind == "kv_cache") {
-    return KVStateKind::kKVCache;
-  } else if (kv_state_kind == "rnn_state") {
-    return KVStateKind::kRNNState;
-  } else {
-    LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind;
+    throw;
   }
 }
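
Besides adding kMedusa, this file rewrites the ternary chains in EngineModeToString and SpeculativeModeToString as explicit if/else chains, so an out-of-range enum value now fails loudly through LOG(FATAL) instead of silently mapping to the last alternative; the bare throw; statements after LOG(FATAL) are presumably there to silence missing-return warnings on compilers that cannot see LOG(FATAL) never returns. The new mode name round-trips through the helpers used for engine-config JSON (usage sketch, assuming this header and the mlc::llm::serve namespace):

using namespace mlc::llm::serve;

SpeculativeMode mode = SpeculativeModeFromString("medusa");  // SpeculativeMode::kMedusa
std::string name = SpeculativeModeToString(mode);            // back to "medusa"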

cpp/serve/engine.cc

Lines changed: 15 additions & 24 deletions
@@ -105,8 +105,7 @@ class EngineImpl : public Engine {
       model->SetPrefillChunkSize(engine_config->prefill_chunk_size);
       model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
                            engine_config->max_total_sequence_length,
-                           engine_config->prefill_chunk_size, engine_config->max_history_size,
-                           engine_config->kv_state_kind);
+                           engine_config->prefill_chunk_size, engine_config->max_history_size);
       n->model_workspaces_.push_back(
           ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
     }
@@ -161,6 +160,18 @@ class EngineImpl : public Engine {
                                                 n->model_workspaces_, draft_token_workspace_manager,
                                                 engine_config, n->trace_recorder_)};
         break;
+      case SpeculativeMode::kMedusa:
+        n->actions_ = {EngineAction::EagleNewRequestPrefill(n->models_,            //
+                                                            logit_processor,       //
+                                                            sampler,               //
+                                                            n->model_workspaces_,  //
+                                                            draft_token_workspace_manager,  //
+                                                            engine_config,         //
+                                                            n->trace_recorder_),
+                       EngineAction::EagleBatchVerify(
+                           n->models_, logit_processor, sampler, n->model_workspaces_,
+                           draft_token_workspace_manager, engine_config, n->trace_recorder_)};
+        break;
       default:
         n->actions_ = {
             EngineAction::NewRequestPrefill(n->models_,  //
@@ -422,13 +433,9 @@ class EngineImpl : public Engine {
         json::LookupOptional<int64_t>(config, "max_history_size");
     std::optional<std::string> kv_state_kind_str =
         json::LookupOptional<std::string>(config, "kv_state_kind");
-    std::optional<KVStateKind> kv_state_kind;
-    if (kv_state_kind_str.has_value()) {
-      kv_state_kind = KVStateKindFromString(kv_state_kind_str.value());
-    }
-    InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length,
+    InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length,
                                           max_single_sequence_length, prefill_chunk_size,
-                                          max_history_size, kv_state_kind};
+                                          max_history_size};
 
     // - Get the model metadata.
     std::vector<ModelMetadata> model_metadata;
@@ -440,28 +447,13 @@ class EngineImpl : public Engine {
     if (use_kv_cache.IsErr()) {
       return TResult::Error(use_kv_cache.UnwrapErr());
     }
-    KVStateKind inferred_kv_state_kind;
     Result<InferrableEngineConfig> inferrable_cfg_res;
     if (use_kv_cache.Unwrap()) {
-      inferred_kv_state_kind = KVStateKind::kKVCache;
-      // - Check if the kv state kind from config is valid.
-      if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) {
-        return TResult::Error(
-            "Invalid kv state kind in EngineConfig. The models use KV cache, but RNN state is "
-            "specified in EngineConfig.");
-      }
      // - Infer configuration.
       inferrable_cfg_res = InferrableEngineConfig::InferForKVCache(
           mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,
           verbose);
     } else {
-      inferred_kv_state_kind = KVStateKind::kRNNState;
-      // - Check if the kv state kind from config is valid.
-      if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) {
-        return TResult::Error(
-            "Invalid kv state kind in EngineConfig. The models use RNN state, but KV cache is "
-            "specified in EngineConfig.");
-      }
       // - Infer configuration.
       inferrable_cfg_res = InferrableEngineConfig::InferForRNNState(
           mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,
@@ -477,7 +469,6 @@ class EngineImpl : public Engine {
     ICHECK(inferrable_cfg.max_single_sequence_length.has_value());
     ICHECK(inferrable_cfg.prefill_chunk_size.has_value());
     ICHECK(inferrable_cfg.max_history_size.has_value());
-    ICHECK(inferrable_cfg.kv_state_kind.has_value());
     return TResult::Ok(EngineConfig::FromJSONAndInferredConfig(config, inferrable_cfg));
   }
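
Two things stand out here. First, CreateKVCache no longer takes a kv_state_kind argument: each model now learns its cache kind from its own metadata (see cpp/metadata/model.cc above), so the engine-level redundancy checks could be deleted; note that kv_state_kind_str is still read from the config JSON but, in the lines shown, no longer feeds the inferred config. Second, the Medusa action list reuses the Eagle prefill and verify actions with no standalone draft action between them, which matches the Medusa design: the extra decoding heads propose draft tokens directly from the base model's hidden states, so no separate draft-model pass is needed. An illustrative stand-in for the dispatch, with plain strings in place of EngineAction objects (the Eagle and default entries include action names not shown in this hunk's context lines and are assumptions):

#include <string>
#include <vector>

enum class SpeculativeMode { kDisable, kSmallDraft, kEagle, kMedusa };

std::vector<std::string> ActionNames(SpeculativeMode mode) {
  switch (mode) {
    case SpeculativeMode::kEagle:
      // Assumed: Eagle interleaves a dedicated draft step.
      return {"EagleNewRequestPrefill", "EagleBatchDraft", "EagleBatchVerify"};
    case SpeculativeMode::kMedusa:
      // Same prefill/verify pair, but no separate draft action: the Medusa
      // heads emit the candidate tokens themselves.
      return {"EagleNewRequestPrefill", "EagleBatchVerify"};
    default:
      // Assumed non-speculative path.
      return {"NewRequestPrefill", "BatchDecode"};
  }
}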

cpp/serve/engine_actions/action_commons.cc

Lines changed: 20 additions & 0 deletions
@@ -211,6 +211,26 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(
   return rsentry;
 }
 
+std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
+    const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits,
+    const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,
+    const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,
+    const std::vector<int>& sample_indices) {
+  // - Update logits.
+  logit_processor->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);
+
+  // - Compute probability distributions.
+  NDArray probs_on_device =
+      logit_processor->ComputeProbsFromLogits(logits, generation_cfg, request_ids);
+
+  // - Sample tokens.
+  NDArray renormalized_probs = sampler->BatchRenormalizeProbsByTopP(probs_on_device, sample_indices,
+                                                                    request_ids, generation_cfg);
+  std::vector<SampleResult> sample_results = sampler->BatchSampleTokensWithProbAfterTopP(
+      renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
+  return {std::move(probs_on_device), std::move(sample_results)};
+}
+
 }  // namespace serve
 }  // namespace llm
 }  // namespace mlc
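
This helper folds the recurring serving-action pipeline (update logits in place, convert them to probability distributions, renormalize by top-p, then batch-sample) into one call. It returns the on-device probabilities alongside the sampled tokens, presumably so speculative verification can reuse the distribution without recomputing it from the logits.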

cpp/serve/engine_actions/action_commons.h

Lines changed: 18 additions & 0 deletions
@@ -75,6 +75,24 @@ inline std::vector<RequestStateEntry> GetRunningRequestStateEntries(const Engine
   return rsentries;
 }
 
+/*!
+ * \brief Apply the logit processor to the logits and sample one token for each request.
+ * \param logit_processor The logit processor to apply.
+ * \param sampler The sampler to sample tokens.
+ * \param logits The logits to process.
+ * \param generation_cfg The generation configurations of the requests.
+ * \param request_ids The request ids.
+ * \param mstates The model states of the requests.
+ * \param rngs The random generators of the requests.
+ * \param sample_indices The indices of the requests to sample.
+ * \return The processed logits and the sampled results.
+ */
+std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
+    const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits,
+    const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,
+    const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,
+    const std::vector<int>& sample_indices);
+
 }  // namespace serve
 }  // namespace llm
 }  // namespace mlc
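
A hypothetical call site for the helper declared above (sketch only: the local variable names are assumptions, not taken from this commit):

#include <numeric>  // std::iota

// One sample per request: sample_indices = {0, 1, ..., n - 1}.
std::vector<int> sample_indices(request_ids.size());
std::iota(sample_indices.begin(), sample_indices.end(), 0);

auto [probs_on_device, sample_results] = ApplyLogitProcessorAndSample(
    logit_processor, sampler, logits, generation_cfg, request_ids, mstates, rngs,
    sample_indices);
// probs_on_device keeps the full distribution for later reuse; sample_results
// holds one SampleResult per entry of sample_indices.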
