                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import BaseLlmArgs, PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -63,7 +63,7 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -102,39 +102,54 @@ def _get_comm_ranks_device_id():
             device_ids = mpi_comm().allgather(device_id)
             return comm_ranks, device_ids

-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
                 args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor

         def _create_engine(executor_config):
             if executor_config is None:
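For illustration only, a minimal, self-contained sketch of the max_seq_len bookkeeping introduced at the end of the new _create_py_executor: the worker starts from llm_args.max_seq_len and prefers the value reported by the created executor (the model engine may have adjusted it). `_StubExecutor` and `resolve_max_seq_len` are stand-in names, not TensorRT-LLM APIs.

from typing import Optional


class _StubExecutor:
    # stand-in for the object returned by create_py_executor / create_autodeploy_executor
    def __init__(self, max_seq_len: Optional[int]):
        self.max_seq_len = max_seq_len


def resolve_max_seq_len(args_max_seq_len: Optional[int],
                        executor: _StubExecutor) -> Optional[int]:
    max_seq_len = args_max_seq_len
    if executor.max_seq_len is not None:
        # the executor may report a value updated by the model engine
        max_seq_len = executor.max_seq_len
    return max_seq_len


# Example: llm_args requested 8192, but the engine settled on 4096.
assert resolve_max_seq_len(8192, _StubExecutor(4096)) == 4096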
@@ -158,8 +173,7 @@ def _create_engine(executor_config):
                                  executor_config)

         self.engine = _create_py_executor(
-            executor_config) if llm_args is not None else _create_engine(
-                executor_config)
+        ) if self.llm_args is not None else _create_engine(executor_config)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -182,8 +196,9 @@ def _create_engine(executor_config):
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -465,26 +480,43 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None

         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: tllm.ExecutorConfig,
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             # deduce max_tokens when it's not set by user
             max_tokens = request.sampling_params.max_tokens
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
+
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                max_seq_len = getattr(self, "max_seq_len", None)
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                max_seq_len = getattr(executor_config, "max_seq_len", None)
+            if max_seq_len is None:
                 logger.warning("`default_max_tokens` cannot be deduced")
                 if max_tokens is None:
                     raise ValueError(
                         "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                     )
+                else:
+                    # use max_tokens if can't deduce default_max_tokens
+                    return max_tokens
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
                 logger.warning(
                     f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
-                    f"`default_max_tokens` ({default_max_tokens}) = `max_seq_len` ({executor_config.max_seq_len})"
+                    f"`default_max_tokens` ({default_max_tokens}) = `max_seq_len` ({max_seq_len})"
                     f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                 )
                 if max_tokens is None:
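For reference, a self-contained sketch of the deduction implemented in _deduce_max_tokens above. Here `max_seq_len` and `cp_size` are plain parameters instead of worker or executor_config attributes, and the warning/fallback handling is simplified; `deduce_max_tokens_sketch` is a hypothetical helper, not part of the codebase.

from typing import Optional


def deduce_max_tokens_sketch(prompt_len: int,
                             max_seq_len: Optional[int],
                             cp_size: int = 1,
                             query_token_len: int = 0,
                             max_tokens: Optional[int] = None) -> int:
    if max_seq_len is None:
        # cannot deduce a default; the caller must have set max_tokens
        if max_tokens is None:
            raise ValueError("`max_tokens` must be set when "
                             "`default_max_tokens` cannot be deduced")
        return max_tokens
    splited_prompt_len = int(prompt_len / cp_size)
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    if default_max_tokens <= 0 and max_tokens is None:
        raise ValueError("prompt does not leave room for any new tokens")
    return max_tokens if max_tokens is not None else default_max_tokens


# Example: max_seq_len=4096, a 1000-token prompt, cp_size=1 -> 3096 new tokens.
assert deduce_max_tokens_sketch(prompt_len=1000, max_seq_len=4096) == 3096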
@@ -506,7 +538,8 @@ def _deduce_max_tokens(request: GenerationRequest,
             executor_request = tllm.Request(
                 client_id=request.id,
                 input_token_ids=prompt_token_ids,
-                max_tokens=_deduce_max_tokens(request, self._executor_config),
+                max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                              self.llm_args),
                 streaming=request.streaming,
                 sampling_config=request.sampling_params._get_sampling_config(),
                 end_id=-1 if request.sampling_params.ignore_eos else
@@ -632,11 +665,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None

-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None

         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
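A small illustration of the cleanup ownership after this hunk, using stand-in objects rather than TensorRT-LLM classes: on the llm_args path the checkpoint loader is stored on the worker (set in _create_py_executor) and released in shutdown(), while the legacy path still reaches it through executor_config. `_StubLoader` and `cleanup_checkpoint_loader` are hypothetical names for the sketch only.

class _StubLoader:
    # stand-in for the object built by _construct_checkpoint_loader
    def __init__(self):
        self.closed = False

    def cleanup(self):
        self.closed = True


def cleanup_checkpoint_loader(worker) -> None:
    # mirrors the branching in shutdown() above
    if worker.llm_args is not None:
        if getattr(worker, "checkpoint_loader", None) is not None:
            worker.checkpoint_loader.cleanup()
            worker.checkpoint_loader = None
    elif getattr(worker._executor_config, "checkpoint_loader", None) is not None:
        worker._executor_config.checkpoint_loader.cleanup()
        worker._executor_config.checkpoint_loader = None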
@@ -682,7 +723,7 @@ def worker_main(
     lora_config: Optional[LoraConfig] = None,
     hf_model_dir: Optional[Path] = None,
     tokenizer: Optional[TokenizerBase] = None,
-    llm_args: Optional[TorchLlmArgs] = None,
+    llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()}\n",