 from ..attention_backend import get_sparse_attn_kv_cache_manager
 from ..model_config import ModelConfig
 from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
-from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
 from .guided_decoder import GuidedDecoder
 from .kv_cache_connector import KvCacheConnectorManager
@@ -73,7 +72,7 @@ def __init__(
             max_seq_len: int,
             max_batch_size: int,
             kv_cache_config: KvCacheConfig,
-            pytorch_backend_config: PyTorchConfig,
+            llm_args: TorchLlmArgs,
             speculative_config: SpeculativeConfig,
             sparse_attention_config: SparseAttentionConfig,
             profiling_stage_data: Optional[dict],
@@ -86,7 +85,7 @@ def __init__(
         self._max_num_tokens = max_num_tokens
         self._max_beam_width = max_beam_width
         self._kv_connector_manager = kv_connector_manager
-        self._pytorch_backend_config = pytorch_backend_config
+        self._llm_args = llm_args
         self._speculative_config = speculative_config
         self._sparse_attention_config = sparse_attention_config
         self._tokens_per_block = tokens_per_block
@@ -248,9 +247,8 @@ def _get_token_num_for_estimation(self) -> int:
         # estimate_max_kv_cache_tokens submits self._dummy_reqs
         num_cache_blocks = 0
         num_extra_tokens_per_seq = 1  # account for generated tokens
-        pytorch_backend_config = self._pytorch_backend_config
         spec_cfg = self._speculative_config
-        if not pytorch_backend_config.disable_overlap_scheduler:
+        if not self._llm_args.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
         if spec_cfg is not None:
             num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
@@ -653,7 +651,7 @@ def create_py_executor_instance(
         dist,
         resources,
         mapping,
-        pytorch_backend_config,
+        llm_args,
         ctx_chunk_config,
         model_engine,
         start_worker,
@@ -679,7 +677,7 @@ def create_py_executor_instance(
679677 f"max_seq_len={ max_seq_len } , max_num_requests={ max_batch_size } , max_num_tokens={ max_num_tokens } , max_batch_size={ max_batch_size } "
680678 )
681679
682- for key , value in pytorch_backend_config .extra_resource_managers .items ():
680+ for key , value in llm_args .extra_resource_managers .items ():
683681 if key in resources :
684682 raise ValueError (
685683 f"Cannot overwrite existing resource manager { key } ." )
@@ -804,8 +802,7 @@ def create_py_executor_instance(
         drafter=drafter,
         dist=dist,
         max_num_sequences=max_num_sequences,
-        disable_overlap_scheduler=pytorch_backend_config.
-        disable_overlap_scheduler,
+        disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
         max_batch_size=max_batch_size,
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
@@ -840,13 +837,11 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
     )


-def instantiate_sampler(engine: PyTorchModelEngine,
-                        pytorch_backend_config: PyTorchConfig, mapping: Mapping,
-                        max_batch_size: int, max_beam_width: int,
-                        max_seq_len: int, mm_encoder_only: bool,
-                        speculative_config: SpeculativeConfig,
-                        decoding_config: trtllm.DecodingConfig,
-                        kv_cache_config: KvCacheConfig):
+def instantiate_sampler(
+        engine: PyTorchModelEngine, llm_args: TorchLlmArgs, mapping: Mapping,
+        max_batch_size: int, max_beam_width: int, max_seq_len: int,
+        mm_encoder_only: bool, speculative_config: SpeculativeConfig,
+        decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
@@ -856,7 +851,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
-        assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
+        assert llm_args.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
         return TorchSampler(sampler_args)
     if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
     ):
@@ -865,15 +860,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
     if mm_encoder_only:
         # NOTE: handle model outputs specially for mm encoder executor/engine
         return EarlyStopWithMMResult()
-    if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
-            pytorch_backend_config.sampler_type == SamplerType.auto
+    if llm_args.sampler_type == SamplerType.TRTLLMSampler or (
+            llm_args.sampler_type == SamplerType.auto
             and decoding_mode.isBeamSearch()):
         logger.debug(f"DecodingMode: {decoding_mode.name}")
         return TRTLLMSampler(engine.model,
                              engine.dtype,
                              mapping,
                              decoding_mode,
-                             pytorch_backend_config.disable_overlap_scheduler,
+                             llm_args.disable_overlap_scheduler,
                              max_seq_len=max_seq_len,
                              max_batch_size=max_batch_size,
                              max_beam_width=max_beam_width,
@@ -935,7 +930,12 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
     return num_experts


-def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
+def _adjust_torch_mem_fraction():
+    # If true, adjust PyTorch CUDA memory fraction to correspond to the
+    # total GPU memory minus the statically allocated engine memory.
+    # If false, set the PyTorch CUDA memory fraction to 1.0.
+    _limit_torch_cuda_mem_fraction: bool = True
+
     # FIXME: PyTorch only uses the garbage_collection_threshold setting
     # if a memory fraction is set, cf.
     # https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
@@ -964,7 +964,7 @@ def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
     # lead PyTorch to release all unused memory before hitting the set fraction. This
     # still mitigates OOM, although at a higher performance impact, because it
     # effectively resets the allocator cache.
-    if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
+    if not _limit_torch_cuda_mem_fraction:
         return
     mem_reserved = torch.cuda.memory_reserved()
     mem_free, mem_total = torch.cuda.mem_get_info()