8 changes: 4 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
@@ -88,7 +88,7 @@ class MaxRequestsScheduler : public BaseCapacityScheduler
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
- MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
+ MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

@@ -98,8 +98,8 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler

private:
SizeType32 mMaxNumRequests;
- /// @brief Boolean that indicates if multiple micro batches might be in flight
- bool mManyMicroBatches;
+ /// @brief Boolean that indicates if two-step lookahead is enabled
+ bool mTwoStepsLookAhead;
};

/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
@@ -146,7 +146,7 @@ class CapacityScheduler : public Algorithm
constexpr static auto name{"CapacityScheduler"};

explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
- bool hasKvCacheManager, std::optional<bool> manyMicroBatches = std::nullopt,
+ bool hasKvCacheManager, bool twoStepsLookAhead = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

14 changes: 7 additions & 7 deletions cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp
@@ -129,11 +129,11 @@ MaxRequestsScheduler::MaxRequestsScheduler(
{
}

- MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
+ MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
: BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
, mMaxNumRequests(maxNumRequests)
- , mManyMicroBatches{manyMicroBatches}
+ , mTwoStepsLookAhead{twoStepsLookAhead}
{
}

@@ -346,7 +346,7 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(

// Keep track of number of requests and block needed for the scheduled requests
auto scheduledBlocksManager
- = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mManyMicroBatches);
+ = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mTwoStepsLookAhead);
SizeType32 numScheduledPeftPages{0};
std::unordered_set<uint64_t> seenTaskIds;

@@ -456,17 +456,17 @@ bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
}

CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
- executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager,
- std::optional<bool> manyMicroBatches, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
+ executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager, bool twoStepsLookAhead,
+ LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
{
if (!hasKvCacheManager)
{
mScheduler = MaxRequestsScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kMAX_UTILIZATION)
{
- mScheduler = MaxUtilizationScheduler{
-     maxNumRequests, manyMicroBatches ? *manyMicroBatches : false, noScheduleUntilState, noScheduleAfterState};
+ mScheduler
+     = MaxUtilizationScheduler{maxNumRequests, twoStepsLookAhead, noScheduleUntilState, noScheduleAfterState};
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
{
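For orientation, below is a minimal Python sketch of the dispatch that the `CapacityScheduler` constructor above implements; the class names mirror the C++ code, but this standalone mock is illustrative only and not the TensorRT-LLM API.

```python
from enum import Enum, auto


class CapacitySchedulerPolicy(Enum):
    MAX_UTILIZATION = auto()
    GUARANTEED_NO_EVICT = auto()


def select_scheduler(max_num_requests: int,
                     policy: CapacitySchedulerPolicy,
                     has_kv_cache_manager: bool,
                     two_steps_lookahead: bool = False) -> str:
    # Without a KV-cache manager there is nothing to evict, so only the
    # request count is capped.
    if not has_kv_cache_manager:
        return f"MaxRequestsScheduler(max={max_num_requests})"
    # MAX_UTILIZATION forwards the lookahead flag so KV-block reservation can
    # account for a second in-flight step (pipeline parallelism or overlap).
    if policy is CapacitySchedulerPolicy.MAX_UTILIZATION:
        return (f"MaxUtilizationScheduler(max={max_num_requests}, "
                f"twoStepsLookAhead={two_steps_lookahead})")
    return f"GuaranteedNoEvictScheduler(max={max_num_requests})"


print(select_scheduler(16, CapacitySchedulerPolicy.MAX_UTILIZATION,
                       has_kv_cache_manager=True, two_steps_lookahead=True))
```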
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp
@@ -75,8 +75,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
// handling of maximizing utilization or pause/evict
// TODO: finer control on encoder requests scheduling
mCapacityScheduler = std::make_unique<tensorrt_llm::batch_manager::CapacityScheduler>(
- getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false,
- std::nullopt, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
+ getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, false,
+ LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);

mMicroBatchScheduler = std::make_unique<tensorrt_llm::batch_manager::MicroBatchScheduler>(
std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
4 changes: 1 addition & 3 deletions cpp/tensorrt_llm/batch_manager/trtGptModel.h
@@ -116,9 +116,7 @@ class TrtGptModel : public executor::Model
? optionalParams.kvCacheConfig.sinkTokenLength.value()
: 0;

- auto const numBatches
-     = worldConfig.isPipelineParallel() ? worldConfig.getPipelineParallelism() : (mEnableTrtOverlap ? 2 : 1);
- mMaxNumSequences = numBatches * mMaxBatchSize;
+ mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism();

auto const numTotalAttenLayers = modelConfig.getNbAttentionLayers();
auto const numRepeatsAttenWindow = numTotalAttenLayers / mMaxAttentionWindowVec.size();
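A small worked example of the sequence-budget change above, under hypothetical values (max batch size 8, no pipeline parallelism, TRT overlap enabled): enabling overlap no longer doubles `mMaxNumSequences` here, the extra in-flight step presumably being covered by the scheduler's two-step lookahead flag instead.

```python
# Hypothetical values, chosen only to show where the two formulas differ.
max_batch_size = 8                 # mMaxBatchSize
pipeline_parallelism = 1           # worldConfig.getPipelineParallelism()
is_pipeline_parallel = pipeline_parallelism > 1
enable_trt_overlap = True          # mEnableTrtOverlap

# Old formula: TRT overlap doubled the budget when pipeline parallelism was off.
num_batches_old = pipeline_parallelism if is_pipeline_parallel else (2 if enable_trt_overlap else 1)
max_num_sequences_old = num_batches_old * max_batch_size        # 16

# New formula: the budget depends only on pipeline parallelism.
max_num_sequences_new = max_batch_size * pipeline_parallelism   # 8

print(max_num_sequences_old, max_num_sequences_new)
```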
@@ -412,7 +412,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
ctxChunkConfig.value().chunkUnitSize, mKvCacheManager->getTokensPerBlock());
}

- mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxBatchSize() * mNumMicroBatches,
+ mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxNumSequences(),
optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), mKvCacheManager != nullptr, mNumMicroBatches > 1);

mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(ctxChunkConfig, maxContextLength);
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
@@ -56,7 +56,7 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
py::class_<CapacityScheduler>(m, CapacityScheduler::name)
.def(py::init<SizeType32, executor::CapacitySchedulerPolicy, bool, bool, LlmRequestState, LlmRequestState>(),
py::arg("max_num_requests"), py::arg("capacity_scheduler_policy"), py::arg("has_kv_cache_manager"),
py::arg("many_micro_batches") = false,
py::arg("two_step_lookahead") = false,
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
"LlmRequestState.GENERATION_COMPLETE"))
13 changes: 5 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -411,13 +411,9 @@ def create_py_executor_instance(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules)

- if mapping.has_pp():
-     num_micro_batches = mapping.pp_size
- else:
-     num_micro_batches = 1 if pytorch_backend_config.disable_overlap_scheduler else 2
+ max_num_sequences = executor_config.max_batch_size * mapping.pp_size

resources["seq_slot_manager"] = SeqSlotManager(
executor_config.max_batch_size * num_micro_batches)
resources["seq_slot_manager"] = SeqSlotManager(max_num_sequences)

resource_manager = ResourceManager(resources)

@@ -428,10 +424,11 @@
last=True)

capacity_scheduler = BindCapacityScheduler(
- executor_config.max_batch_size,
+ max_num_sequences,
kv_cache_manager.impl if kv_cache_manager is not None else None,
executor_config.scheduler_config.capacity_scheduler_policy,
- num_micro_batches=num_micro_batches)
+ two_step_lookahead=mapping.has_pp()
+ or not pytorch_backend_config.disable_overlap_scheduler)
mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
executor_config.max_num_tokens,
ctx_chunk_config)
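To make the new flag concrete, here is a standalone check of how `two_step_lookahead` evaluates for a few configurations; the inputs are mocked, and only the boolean expression mirrors the diff above.

```python
# Mocked inputs; the real code reads mapping.has_pp() and
# pytorch_backend_config.disable_overlap_scheduler.
configs = [
    {"has_pp": False, "disable_overlap_scheduler": True},   # no PP, overlap scheduler off
    {"has_pp": False, "disable_overlap_scheduler": False},  # no PP, overlap scheduler on
    {"has_pp": True, "disable_overlap_scheduler": True},    # PP on, overlap scheduler off
]

for cfg in configs:
    two_step_lookahead = cfg["has_pp"] or not cfg["disable_overlap_scheduler"]
    print(cfg, "->", two_step_lookahead)   # False, True, True
```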
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/scheduler.py
@@ -77,16 +77,16 @@ def __init__(
kv_cache_manager,
scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
- num_micro_batches: int = 1,
+ two_step_lookahead: bool = False,
):
super(BindCapacityScheduler, self).__init__()
self.kv_cache_manager = kv_cache_manager

self.impl = tb_internal.algorithms.CapacityScheduler(
- max_num_requests=max_num_requests * num_micro_batches,
+ max_num_requests=max_num_requests,
capacity_scheduler_policy=scheduler_policy,
has_kv_cache_manager=kv_cache_manager is not None,
- many_micro_batches=num_micro_batches > 1,
+ two_step_lookahead=two_step_lookahead,
no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE)

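Finally, a hedged usage sketch of the updated wrapper: the module path and keyword names are taken from the diff above, but the values are hypothetical and the snippet assumes a tensorrt_llm build that already includes this change.

```python
# Hypothetical usage; not runnable without a matching tensorrt_llm installation.
from tensorrt_llm._torch.pyexecutor.scheduler import BindCapacityScheduler

max_batch_size = 16
pp_size = 2

scheduler = BindCapacityScheduler(
    max_num_requests=max_batch_size * pp_size,  # per-sequence budget, as in _util.py above
    kv_cache_manager=None,                      # None selects the max-requests fallback
    two_step_lookahead=True,                    # PP and/or overlap scheduler in flight
)
```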
3 changes: 1 addition & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -191,8 +191,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
@skip_pre_hopper
def test_fp8_llm_sampler(self):
model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
- pytorch_config = dict(enable_trtllm_sampler=True)
- llm = LLM(model_path, **pytorch_config)
+ llm = LLM(model_path, enable_trtllm_sampler=True, max_batch_size=256)
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8

sampling_params = SamplingParams(