@@ -227,8 +227,8 @@ def create_py_executor(
     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.MODEL_ENGINE_MAIN):
         model_engine = PyTorchModelEngine(
-            checkpoint_dir,
-            pytorch_backend_config,
+            model_path=checkpoint_dir,
+            pytorch_backend_config=pytorch_backend_config,
             batch_size=executor_config.max_batch_size,
             max_beam_width=executor_config.max_beam_width,
             max_num_tokens=executor_config.max_num_tokens,
@@ -250,8 +250,8 @@ def create_py_executor(
             draft_spec_config.max_draft_tokens = 0
 
         draft_model_engine = PyTorchModelEngine(
-            spec_config.draft_model_path,
-            pytorch_backend_config,
+            model_path=spec_config.draft_model_path,
+            pytorch_backend_config=pytorch_backend_config,
             batch_size=executor_config.max_batch_size,
             max_beam_width=executor_config.max_beam_width,
             max_num_tokens=executor_config.max_num_tokens,
@@ -358,24 +358,36 @@ def create_py_executor(
             if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
         kv_cache_creator.build_managers(resources)
 
+    # Drafter for speculative decoding
+    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
+        drafter = get_spec_drafter(model_engine)
+
     # Resource managers for speculative decoding
     spec_resource_manager = get_spec_resource_manager(model_engine,
-                                                      draft_model_engine)
+                                                      draft_model_engine,
+                                                      drafter)
     if spec_resource_manager is not None:
         resources[
             ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
 
-    # Drafter for speculative decoding
-    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
-        drafter = get_spec_drafter(model_engine, spec_resource_manager)
-
     with mem_monitor.observe_creation_stage(
         _ExecutorCreationStage.INIT_EXTRA_RESOURCES
         if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):
         py_executor = create_py_executor_instance(
-            dist, resources, mapping, pytorch_backend_config, executor_config,
-            ctx_chunk_config, model_engine, draft_model_engine, False, sampler,
-            drafter, lora_config, garbage_collection_gen0_threshold)
+            dist=dist,
+            resources=resources,
+            mapping=mapping,
+            pytorch_backend_config=pytorch_backend_config,
+            executor_config=executor_config,
+            ctx_chunk_config=ctx_chunk_config,
+            model_engine=model_engine,
+            draft_model_engine=draft_model_engine,
+            start_worker=False,
+            sampler=sampler,
+            drafter=drafter,
+            lora_config=lora_config,
+            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+        )
 
     if estimating_kv_cache:
         assert kv_cache_creator is not None
@@ -404,10 +416,21 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.EXTRA_RESOURCES):
             py_executor = create_py_executor_instance(
-                dist, resources, mapping, pytorch_backend_config,
-                executor_config, ctx_chunk_config, model_engine,
-                draft_model_engine, False, sampler, drafter, lora_config,
-                garbage_collection_gen0_threshold)
+                dist=dist,
+                resources=resources,
+                mapping=mapping,
+                pytorch_backend_config=pytorch_backend_config,
+                executor_config=executor_config,
+                ctx_chunk_config=ctx_chunk_config,
+                model_engine=model_engine,
+                draft_model_engine=draft_model_engine,
+                start_worker=False,
+                sampler=sampler,
+                drafter=drafter,
+                lora_config=lora_config,
+                garbage_collection_gen0_threshold=
+                garbage_collection_gen0_threshold,
+            )
 
     py_executor.start_worker()
     return py_executor
0 commit comments