diff --git a/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h b/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
index 272758936ed..70690411797 100644
--- a/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
+++ b/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
@@ -88,7 +88,7 @@ class MaxRequestsScheduler : public BaseCapacityScheduler
 class MaxUtilizationScheduler : public BaseCapacityScheduler
 {
 public:
-    MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
+    MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
         LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
         LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
 
@@ -98,8 +98,8 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
 private:
     SizeType32 mMaxNumRequests;
-    /// @brief Boolean that indicates if multiple micro batches might be in flight
-    bool mManyMicroBatches;
+    /// @brief Boolean that indicates if two step lookahead is enabled
+    bool mTwoStepsLookAhead;
 };
 
 /// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
@@ -146,7 +146,7 @@ class CapacityScheduler : public Algorithm
     constexpr static auto name{"CapacityScheduler"};
 
     explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
-        bool hasKvCacheManager, std::optional<bool> manyMicroBatches = std::nullopt,
+        bool hasKvCacheManager, bool twoStepsLookAhead = false,
         LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
         LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
diff --git a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp
index 6805be14f2a..4608a94b891 100644
--- a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp
+++ b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp
@@ -129,11 +129,11 @@ MaxRequestsScheduler::MaxRequestsScheduler(
 {
 }
 
-MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
+MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
     LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
     : BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
     , mMaxNumRequests(maxNumRequests)
-    , mManyMicroBatches{manyMicroBatches}
+    , mTwoStepsLookAhead{twoStepsLookAhead}
 {
 }
 
@@ -346,7 +346,7 @@ std::tuple MaxUtilizationScheduler::operator()(
 
     // Keep track of number of requests and block needed for the scheduled requests
     auto scheduledBlocksManager
-        = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mManyMicroBatches);
+        = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mTwoStepsLookAhead);
     SizeType32 numScheduledPeftPages{0};
     std::unordered_set seenTaskIds;
 
@@ -456,8 +456,8 @@ bool trySchedulingRequestMaxUtilization(std::shared_ptr const& req,
 }
 
 CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
-    executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager,
-    std::optional<bool> manyMicroBatches, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
+    executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager, bool twoStepsLookAhead,
+    LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
 {
     if (!hasKvCacheManager)
     {
@@ -465,8 +465,8 @@ CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
     }
     else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kMAX_UTILIZATION)
    {
-        mScheduler = MaxUtilizationScheduler{
-            maxNumRequests, manyMicroBatches ? *manyMicroBatches : false, noScheduleUntilState, noScheduleAfterState};
+        mScheduler
+            = MaxUtilizationScheduler{maxNumRequests, twoStepsLookAhead, noScheduleUntilState, noScheduleAfterState};
     }
     else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
     {
diff --git a/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp b/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp
index 6f09753ea64..980423d7d8e 100644
--- a/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp
@@ -75,8 +75,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
     // handling of maximizing utilization or pause/evict
     // TODO: finer control on encoder requests scheduling
     mCapacityScheduler = std::make_unique<CapacityScheduler>(
-        getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false,
-        std::nullopt, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
+        getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, false,
+        LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
 
     mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(
         std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModel.h b/cpp/tensorrt_llm/batch_manager/trtGptModel.h
index 65075bc84be..25b6ef356d8 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModel.h
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModel.h
@@ -116,9 +116,7 @@ class TrtGptModel : public executor::Model
             ? optionalParams.kvCacheConfig.sinkTokenLength.value()
             : 0;
 
-        auto const numBatches
-            = worldConfig.isPipelineParallel() ? worldConfig.getPipelineParallelism() : (mEnableTrtOverlap ? 2 : 1);
-        mMaxNumSequences = numBatches * mMaxBatchSize;
+        mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism();
 
         auto const numTotalAttenLayers = modelConfig.getNbAttentionLayers();
         auto const numRepeatsAttenWindow = numTotalAttenLayers / mMaxAttentionWindowVec.size();
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index d5b5db4f097..491a104240a 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -412,7 +412,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptrgetTokensPerBlock());
     }
 
-    mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxBatchSize() * mNumMicroBatches,
+    mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxNumSequences(),
         optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), mKvCacheManager != nullptr, mNumMicroBatches > 1);
 
     mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(ctxChunkConfig, maxContextLength);
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
index ab2576941f0..b7a49cc4945 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
@@ -56,7 +56,7 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
     py::class_<CapacityScheduler>(m, CapacityScheduler::name)
         .def(py::init(),
             py::arg("max_num_requests"), py::arg("capacity_scheduler_policy"), py::arg("has_kv_cache_manager"),
-            py::arg("many_micro_batches") = false,
+            py::arg("two_step_lookahead") = false,
             py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
             py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
                 "LlmRequestState.GENERATION_COMPLETE"))
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 92b155df639..fa65e7dd7de 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -411,13 +411,9 @@ def create_py_executor_instance(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules)
 
-    if mapping.has_pp():
-        num_micro_batches = mapping.pp_size
-    else:
-        num_micro_batches = 1 if pytorch_backend_config.disable_overlap_scheduler else 2
+    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
 
-    resources["seq_slot_manager"] = SeqSlotManager(
-        executor_config.max_batch_size * num_micro_batches)
+    resources["seq_slot_manager"] = SeqSlotManager(max_num_sequences)
 
     resource_manager = ResourceManager(resources)
 
@@ -428,10 +424,11 @@ def create_py_executor_instance(
         last=True)
 
     capacity_scheduler = BindCapacityScheduler(
-        executor_config.max_batch_size,
+        max_num_sequences,
         kv_cache_manager.impl if kv_cache_manager is not None else None,
         executor_config.scheduler_config.capacity_scheduler_policy,
-        num_micro_batches=num_micro_batches)
+        two_step_lookahead=mapping.has_pp()
+        or not pytorch_backend_config.disable_overlap_scheduler)
     mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
                                            executor_config.max_num_tokens,
                                            ctx_chunk_config)
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py
index 80f36c4dfe9..15259835da5 100644
--- a/tensorrt_llm/_torch/pyexecutor/scheduler.py
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py
@@ -77,16 +77,16 @@ def __init__(
         kv_cache_manager,
         scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
         CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
-        num_micro_batches: int = 1,
+        two_step_lookahead: bool = False,
     ):
         super(BindCapacityScheduler, self).__init__()
         self.kv_cache_manager = kv_cache_manager
         self.impl = tb_internal.algorithms.CapacityScheduler(
-            max_num_requests=max_num_requests * num_micro_batches,
+            max_num_requests=max_num_requests,
             capacity_scheduler_policy=scheduler_policy,
             has_kv_cache_manager=kv_cache_manager is not None,
-            many_micro_batches=num_micro_batches > 1,
+            two_step_lookahead=two_step_lookahead,
             no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
             no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE)
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 6aa1b9a4bf1..1dd4a69b70a 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -191,8 +191,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
     @skip_pre_hopper
     def test_fp8_llm_sampler(self):
         model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
-        pytorch_config = dict(enable_trtllm_sampler=True)
-        llm = LLM(model_path, **pytorch_config)
+        llm = LLM(model_path, enable_trtllm_sampler=True, max_batch_size=256)
         assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
 
         sampling_params = SamplingParams(
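Net effect of the changes above: scheduler capacity is now sized explicitly as max_batch_size * pp_size, and the capacity scheduler takes a two_step_lookahead flag that is enabled whenever pipeline parallelism or the overlap scheduler keeps an extra batch in flight, instead of inferring both from a micro-batch count. The Python sketch below mirrors that sizing logic in isolation; it is a minimal illustration only, and the SchedulerSizing dataclass and derive_scheduler_sizing helper are hypothetical names, not part of the TensorRT-LLM codebase.

from dataclasses import dataclass


@dataclass
class SchedulerSizing:
    """Capacity-scheduler inputs derived from the parallelism/overlap configuration (illustrative only)."""
    max_num_sequences: int
    two_step_lookahead: bool


def derive_scheduler_sizing(max_batch_size: int, pp_size: int,
                            disable_overlap_scheduler: bool) -> SchedulerSizing:
    # Mirrors the new wiring in _util.py: capacity scales with pipeline depth,
    # and the scheduler looks one extra step ahead whenever pipeline parallelism
    # or the overlap scheduler keeps more than one batch in flight.
    has_pp = pp_size > 1
    return SchedulerSizing(
        max_num_sequences=max_batch_size * pp_size,
        two_step_lookahead=has_pp or not disable_overlap_scheduler,
    )


if __name__ == "__main__":
    # Example: max_batch_size=8, pp_size=2, overlap scheduler enabled.
    sizing = derive_scheduler_sizing(8, 2, disable_overlap_scheduler=False)
    print(sizing)  # SchedulerSizing(max_num_sequences=16, two_step_lookahead=True)

In the executor itself these two values are what the _util.py hunk passes to BindCapacityScheduler (max_num_sequences as max_num_requests, plus two_step_lookahead), which in turn forwards them to the C++ CapacityScheduler binding shown above.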