@@ -105,8 +105,7 @@ class EngineImpl : public Engine {
105105 model->SetPrefillChunkSize (engine_config->prefill_chunk_size );
106106 model->CreateKVCache (engine_config->kv_cache_page_size , engine_config->max_num_sequence ,
107107 engine_config->max_total_sequence_length ,
108- engine_config->prefill_chunk_size , engine_config->max_history_size ,
109- engine_config->kv_state_kind );
108+ engine_config->prefill_chunk_size , engine_config->max_history_size );
110109 n->model_workspaces_ .push_back (
111110 ModelWorkspace{model->AllocEmbeddingTensor (), model->AllocHiddenStatesTensor ()});
112111 }
@@ -161,6 +160,18 @@ class EngineImpl : public Engine {
161160 n->model_workspaces_ , draft_token_workspace_manager,
162161 engine_config, n->trace_recorder_ )};
163162 break ;
163+ case SpeculativeMode::kMedusa :
164+ n->actions_ = {EngineAction::EagleNewRequestPrefill (n->models_ , //
165+ logit_processor, //
166+ sampler, //
167+ n->model_workspaces_ , //
168+ draft_token_workspace_manager, //
169+ engine_config, //
170+ n->trace_recorder_ ),
171+ EngineAction::EagleBatchVerify (
172+ n->models_ , logit_processor, sampler, n->model_workspaces_ ,
173+ draft_token_workspace_manager, engine_config, n->trace_recorder_ )};
174+ break ;
164175 default :
165176 n->actions_ = {
166177 EngineAction::NewRequestPrefill (n->models_ , //
@@ -422,13 +433,9 @@ class EngineImpl : public Engine {
422433 json::LookupOptional<int64_t >(config, " max_history_size" );
423434 std::optional<std::string> kv_state_kind_str =
424435 json::LookupOptional<std::string>(config, " kv_state_kind" );
425- std::optional<KVStateKind> kv_state_kind;
426- if (kv_state_kind_str.has_value ()) {
427- kv_state_kind = KVStateKindFromString (kv_state_kind_str.value ());
428- }
429- InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length,
436+ InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length,
430437 max_single_sequence_length, prefill_chunk_size,
431- max_history_size, kv_state_kind };
438+ max_history_size};
432439
433440 // - Get the model metadata.
434441 std::vector<ModelMetadata> model_metadata;
@@ -440,28 +447,13 @@ class EngineImpl : public Engine {
440447 if (use_kv_cache.IsErr ()) {
441448 return TResult::Error (use_kv_cache.UnwrapErr ());
442449 }
443- KVStateKind inferred_kv_state_kind;
444450 Result<InferrableEngineConfig> inferrable_cfg_res;
445451 if (use_kv_cache.Unwrap ()) {
446- inferred_kv_state_kind = KVStateKind::kKVCache ;
447- // - Check if the kv state kind from config is valid.
448- if (kv_state_kind.has_value () && kv_state_kind.value () != inferred_kv_state_kind) {
449- return TResult::Error (
450- " Invalid kv state kind in EngineConfig. The models use KV cache, but RNN state is "
451- " specified in EngineConfig." );
452- }
453452 // - Infer configuration.
454453 inferrable_cfg_res = InferrableEngineConfig::InferForKVCache (
455454 mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,
456455 verbose);
457456 } else {
458- inferred_kv_state_kind = KVStateKind::kRNNState ;
459- // - Check if the kv state kind from config is valid.
460- if (kv_state_kind.has_value () && kv_state_kind.value () != inferred_kv_state_kind) {
461- return TResult::Error (
462- " Invalid kv state kind in EngineConfig. The models use RNN state, but KV cache is "
463- " specified in EngineConfig." );
464- }
465457 // - Infer configuration.
466458 inferrable_cfg_res = InferrableEngineConfig::InferForRNNState (
467459 mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,
@@ -477,7 +469,6 @@ class EngineImpl : public Engine {
477469 ICHECK (inferrable_cfg.max_single_sequence_length .has_value ());
478470 ICHECK (inferrable_cfg.prefill_chunk_size .has_value ());
479471 ICHECK (inferrable_cfg.max_history_size .has_value ());
480- ICHECK (inferrable_cfg.kv_state_kind .has_value ());
481472 return TResult::Ok (EngineConfig::FromJSONAndInferredConfig (config, inferrable_cfg));
482473 }
483474
0 commit comments